From 6d99b636c12315f340c668b05d97431b1b547b5c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 22:55:38 -0500 Subject: [PATCH] Trellis2/Hunyuan3d: preserve mesh tensor contract in batch mode --- comfy_extras/nodes_hunyuan3d.py | 61 ++++++++++++++++++++++--- comfy_extras/nodes_trellis2.py | 80 ++++++++++++++++++++++++++------- 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 8f58e85d9..0b7e17bb5 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -445,7 +445,7 @@ class VoxelToMeshBasic(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -483,7 +483,7 @@ class VoxelToMesh(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -632,6 +632,57 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + + class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): @@ -686,11 +737,11 @@ class SaveGLB(IO.ComfyNode): }) else: # Handle Mesh input - save vertices and faces as GLB - bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0] + bsz = mesh.vertices.shape[0] for i in range(bsz): f = f"{filename}_{counter:05}_.glb" - v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) + vertices, faces, v_colors = get_mesh_batch_item(mesh, i) + save_glb(vertices, faces, os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8ef3e8f5a..57732151b 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -8,6 +8,57 @@ import torch import scipy import copy + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -122,7 +173,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) else: - mesh = Types.MESH(vertices=vert_list, faces=face_list) + mesh = pack_variable_mesh_batch(vert_list, face_list) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -165,19 +216,19 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if isinstance(shape_mesh.vertices, list): + if hasattr(shape_mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - for i in range(len(shape_mesh.vertices)): + for i in range(shape_mesh.vertices.shape[0]): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] - item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0)) + item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) + item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) - out_mesh = Types.MESH(vertices=out_verts, faces=out_faces) - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) @@ -334,6 +385,10 @@ class Trellis2Conditioning(IO.ComfyNode): if mask.ndim == 2: mask = mask.unsqueeze(0) batch_size = image.shape[0] + if mask.shape[0] == 1 and batch_size > 1: + mask = mask.repeat(batch_size, 1, 1) + elif mask.shape[0] != batch_size: + raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") cond_512_list = [] cond_1024_list = [] @@ -691,13 +746,10 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - if isinstance(mesh.vertices, list): + if hasattr(mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None - for i in range(len(mesh.vertices)): - v_i = mesh.vertices[i] - f_i = mesh.faces[i] - c_i = colors_in[i] if colors_in is not None else None + for i in range(mesh.vertices.shape[0]): + v_i, f_i, c_i = get_mesh_batch_item(mesh, i) actual_face_count = f_i.shape[0] if fill_holes_perimeter > 0: v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter) @@ -708,9 +760,7 @@ class PostProcessMesh(IO.ComfyNode): out_faces.append(f_i) if c_i is not None: out_colors.append(c_i) - out_mesh = type(mesh)(vertices=out_verts, faces=out_faces) - if len(out_colors) == len(out_verts): - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors if len(out_colors) == len(out_verts) else None) return IO.NodeOutput(out_mesh) verts, faces = mesh.vertices, mesh.faces colors = None