From 40219ab0fce492f8ff91f99f909e3b5060483e32 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:33:09 -0500 Subject: [PATCH] Trellis2: share batched mesh helpers --- comfy_extras/mesh_batch_utils.py | 53 +++++++++++++++++++++++++++++ comfy_extras/nodes_hunyuan3d.py | 53 +---------------------------- comfy_extras/nodes_trellis2.py | 58 +++----------------------------- 3 files changed, 58 insertions(+), 106 deletions(-) create mode 100644 comfy_extras/mesh_batch_utils.py diff --git a/comfy_extras/mesh_batch_utils.py b/comfy_extras/mesh_batch_utils.py new file mode 100644 index 000000000..841328776 --- /dev/null +++ b/comfy_extras/mesh_batch_utils.py @@ -0,0 +1,53 @@ +import torch +from comfy_api.latest import Types + + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 0b7e17bb5..78ab3b841 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -10,6 +10,7 @@ from comfy.cli_args import args from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @@ -631,58 +632,6 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - - class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 7a72b2824..cdac6f103 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,6 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item import comfy.model_management from PIL import Image import numpy as np @@ -8,57 +9,6 @@ import torch import scipy import copy - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -130,14 +80,14 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): final_colors = linear_colors.unsqueeze(0) - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) out_mesh.colors = final_colors return out_mesh def paint_mesh_default_colors(mesh): - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) vertex_count = mesh.vertices.shape[1] out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) return out_mesh @@ -400,7 +350,7 @@ class Trellis2Conditioning(IO.ComfyNode): mask = mask.unsqueeze(0) batch_size = image.shape[0] if mask.shape[0] == 1 and batch_size > 1: - mask = mask.repeat(batch_size, 1, 1) + mask = mask.expand(batch_size, -1, -1) elif mask.shape[0] != batch_size: raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")