From 44043ee6e5fadaa569c8a905915901b677dec65f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 17 Apr 2026 22:42:42 -0500 Subject: [PATCH 1/5] Trellis2/Hunyuan3d: n>1 batched cascade support Mesh-producing nodes (VoxelToMeshBasic, VoxelToMesh, VaeDecodeShapeTrellis) previously stacked per-batch vertex/face tensors with torch.stack, which failed under batch>1 because per-item meshes have variable shapes. Store per-item tensors as lists when shapes differ; keep stacked-tensor fast path when shapes match. Update SaveGLB, PostProcessMesh, and VaeDecodeTextureTrellis consumers to iterate per-item when input is a list. Trellis2Conditioning.execute previously collapsed batched image/mask input to index 0, yielding identical conditioning for every batch item. Loop over the batch and produce per-image cond_512/cond_1024/neg_cond tensors stacked along the batch dim, matching the latent batch size. batch_size=1 behavior is unchanged. batch_size>1 runs now emit one GLB per batch item per SaveGLB node and carry per-image conditioning through the structure/shape/texture cascade. --- comfy_extras/nodes_hunyuan3d.py | 11 ++- comfy_extras/nodes_trellis2.py | 156 +++++++++++++++++++++----------- 2 files changed, 112 insertions(+), 55 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index ac91fe0a7..8f58e85d9 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -443,7 +443,9 @@ class VoxelToMeshBasic(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(Types.MESH(vertices, faces)) decode = execute # TODO: remove @@ -479,7 +481,9 @@ class VoxelToMesh(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(Types.MESH(vertices, faces)) decode = execute # TODO: remove @@ -682,7 +686,8 @@ class SaveGLB(IO.ComfyNode): }) else: # Handle Mesh input - save vertices and faces as GLB - for i in range(mesh.vertices.shape[0]): + bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0] + for i in range(bsz): f = f"{filename}_{counter:05}_.glb" v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 61d3532a1..8ef3e8f5a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -117,9 +117,12 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): samples = shape_norm(samples, coords) mesh, subs = vae.decode_shape_slat(samples, resolution) - faces = torch.stack([m.faces for m in mesh]) - verts = torch.stack([m.vertices for m in mesh]) - mesh = Types.MESH(vertices=verts, faces=faces) + face_list = [m.faces for m in mesh] + vert_list = [m.vertices for m in mesh] + if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): + mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) + else: + mesh = Types.MESH(vertices=vert_list, faces=face_list) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -160,8 +163,23 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel = vae.decode_tex_slat(samples, shape_subs) color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords[:, 1:] + voxel_batch_idx = voxel.coords[:, 0] - out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + if isinstance(shape_mesh.vertices, list): + out_verts, out_faces, out_colors = [], [], [] + for i in range(len(shape_mesh.vertices)): + sel = voxel_batch_idx == i + item_coords = voxel_coords[sel] + item_colors = color_feats[sel] + item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0)) + painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) + out_verts.append(painted.vertices.squeeze(0)) + out_faces.append(painted.faces.squeeze(0)) + out_colors.append(painted.colors.squeeze(0)) + out_mesh = Types.MESH(vertices=out_verts, faces=out_faces) + out_mesh.colors = out_colors + else: + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -310,69 +328,83 @@ class Trellis2Conditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: + # Normalize to batched form so per-image conditioning loop below is uniform. + if image.ndim == 3: + image = image.unsqueeze(0) + if mask.ndim == 2: + mask = mask.unsqueeze(0) + batch_size = image.shape[0] - if image.ndim == 4: - image = image[0] - if mask.ndim == 3: - mask = mask[0] + cond_512_list = [] + cond_1024_list = [] - img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + for b in range(batch_size): + item_image = image[b] + item_mask = mask[b] - pil_img = Image.fromarray(img_np) - pil_mask = Image.fromarray(mask_np) + img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - max_size = max(pil_img.size) - scale = min(1.0, 1024 / max_size) - if scale < 1.0: - new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) - pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) - pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) - rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) - rgba_np[:, :, :3] = np.array(pil_img) - rgba_np[:, :, 3] = np.array(pil_mask) + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) - alpha = rgba_np[:, :, 3] - bbox_coords = np.argwhere(alpha > 0.8 * 255) + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) - if len(bbox_coords) > 0: - y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) - y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) - center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 - size = max(y_max - y_min, x_max - x_min) + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) - crop_x1 = int(center_x - size // 2) - crop_y1 = int(center_y - size // 2) - crop_x2 = int(center_x + size // 2) - crop_y2 = int(center_y + size // 2) + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) - rgba_pil = Image.fromarray(rgba_np) - cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) - cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 - else: - import logging - logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") - cropped_np = rgba_np.astype(np.float32) / 255.0 + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) - bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} - bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + rgba_pil = Image.fromarray(rgba_np) + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + import logging + logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 - fg = cropped_np[:, :, :3] - alpha_float = cropped_np[:, :, 3:4] - composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) - # to match trellis2 code (quantize -> dequantize) - composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - cropped_pil = Image.fromarray(composite_uint8) + # to match trellis2 code (quantize -> dequantize) + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + cropped_pil = Image.fromarray(composite_uint8) - embeds = conditioning["cond_1024"] - positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] + item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + cond_512_list.append(item_conditioning["cond_512"]) + cond_1024_list.append(item_conditioning["cond_1024"]) + + cond_512_batched = torch.cat(cond_512_list, dim=0) + cond_1024_batched = torch.cat(cond_1024_list, dim=0) + neg_cond_batched = torch.zeros_like(cond_512_batched) + neg_embeds_batched = torch.zeros_like(cond_1024_batched) + + positive = [[cond_512_batched, {"embeds": cond_1024_batched}]] + negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -659,7 +691,27 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - # TODO: batched mode may break + if isinstance(mesh.vertices, list): + out_verts, out_faces, out_colors = [], [], [] + colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None + for i in range(len(mesh.vertices)): + v_i = mesh.vertices[i] + f_i = mesh.faces[i] + c_i = colors_in[i] if colors_in is not None else None + actual_face_count = f_i.shape[0] + if fill_holes_perimeter > 0: + v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter) + if simplify > 0 and actual_face_count > simplify: + v_i, f_i, c_i = simplify_fn(v_i, f_i, target=simplify, colors=c_i) + v_i, f_i = make_double_sided(v_i, f_i) + out_verts.append(v_i) + out_faces.append(f_i) + if c_i is not None: + out_colors.append(c_i) + out_mesh = type(mesh)(vertices=out_verts, faces=out_faces) + if len(out_colors) == len(out_verts): + out_mesh.colors = out_colors + return IO.NodeOutput(out_mesh) verts, faces = mesh.vertices, mesh.faces colors = None if hasattr(mesh, "colors"): From 6d99b636c12315f340c668b05d97431b1b547b5c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 22:55:38 -0500 Subject: [PATCH 2/5] Trellis2/Hunyuan3d: preserve mesh tensor contract in batch mode --- comfy_extras/nodes_hunyuan3d.py | 61 ++++++++++++++++++++++--- comfy_extras/nodes_trellis2.py | 80 ++++++++++++++++++++++++++------- 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 8f58e85d9..0b7e17bb5 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -445,7 +445,7 @@ class VoxelToMeshBasic(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -483,7 +483,7 @@ class VoxelToMesh(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -632,6 +632,57 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + + class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): @@ -686,11 +737,11 @@ class SaveGLB(IO.ComfyNode): }) else: # Handle Mesh input - save vertices and faces as GLB - bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0] + bsz = mesh.vertices.shape[0] for i in range(bsz): f = f"{filename}_{counter:05}_.glb" - v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) + vertices, faces, v_colors = get_mesh_batch_item(mesh, i) + save_glb(vertices, faces, os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8ef3e8f5a..57732151b 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -8,6 +8,57 @@ import torch import scipy import copy + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -122,7 +173,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) else: - mesh = Types.MESH(vertices=vert_list, faces=face_list) + mesh = pack_variable_mesh_batch(vert_list, face_list) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -165,19 +216,19 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if isinstance(shape_mesh.vertices, list): + if hasattr(shape_mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - for i in range(len(shape_mesh.vertices)): + for i in range(shape_mesh.vertices.shape[0]): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] - item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0)) + item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) + item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) - out_mesh = Types.MESH(vertices=out_verts, faces=out_faces) - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) @@ -334,6 +385,10 @@ class Trellis2Conditioning(IO.ComfyNode): if mask.ndim == 2: mask = mask.unsqueeze(0) batch_size = image.shape[0] + if mask.shape[0] == 1 and batch_size > 1: + mask = mask.repeat(batch_size, 1, 1) + elif mask.shape[0] != batch_size: + raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") cond_512_list = [] cond_1024_list = [] @@ -691,13 +746,10 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - if isinstance(mesh.vertices, list): + if hasattr(mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None - for i in range(len(mesh.vertices)): - v_i = mesh.vertices[i] - f_i = mesh.faces[i] - c_i = colors_in[i] if colors_in is not None else None + for i in range(mesh.vertices.shape[0]): + v_i, f_i, c_i = get_mesh_batch_item(mesh, i) actual_face_count = f_i.shape[0] if fill_holes_perimeter > 0: v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter) @@ -708,9 +760,7 @@ class PostProcessMesh(IO.ComfyNode): out_faces.append(f_i) if c_i is not None: out_colors.append(c_i) - out_mesh = type(mesh)(vertices=out_verts, faces=out_faces) - if len(out_colors) == len(out_verts): - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors if len(out_colors) == len(out_verts) else None) return IO.NodeOutput(out_mesh) verts, faces = mesh.vertices, mesh.faces colors = None From c297a9f839a26c22c44a879aeeb3aed302055448 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:23:24 -0500 Subject: [PATCH 3/5] Trellis2: handle empty and batched texture paint paths --- comfy_extras/nodes_trellis2.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 57732151b..7a72b2824 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -135,6 +135,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): return out_mesh + +def paint_mesh_default_colors(mesh): + out_mesh = copy.deepcopy(mesh) + vertex_count = mesh.vertices.shape[1] + out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) + return out_mesh + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -216,21 +223,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if hasattr(shape_mesh, "vertex_counts"): + mesh_batch_size = shape_mesh.vertices.shape[0] + if mesh_batch_size > 1: out_verts, out_faces, out_colors = [], [], [] - for i in range(shape_mesh.vertices.shape[0]): + for i in range(mesh_batch_size): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) - painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) + if item_coords.shape[0] == 0: + painted = paint_mesh_default_colors(item_mesh) + else: + painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: - out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + if voxel_coords.shape[0] == 0: + out_mesh = paint_mesh_default_colors(shape_mesh) + else: + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): From 40219ab0fce492f8ff91f99f909e3b5060483e32 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:33:09 -0500 Subject: [PATCH 4/5] Trellis2: share batched mesh helpers --- comfy_extras/mesh_batch_utils.py | 53 +++++++++++++++++++++++++++++ comfy_extras/nodes_hunyuan3d.py | 53 +---------------------------- comfy_extras/nodes_trellis2.py | 58 +++----------------------------- 3 files changed, 58 insertions(+), 106 deletions(-) create mode 100644 comfy_extras/mesh_batch_utils.py diff --git a/comfy_extras/mesh_batch_utils.py b/comfy_extras/mesh_batch_utils.py new file mode 100644 index 000000000..841328776 --- /dev/null +++ b/comfy_extras/mesh_batch_utils.py @@ -0,0 +1,53 @@ +import torch +from comfy_api.latest import Types + + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 0b7e17bb5..78ab3b841 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -10,6 +10,7 @@ from comfy.cli_args import args from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @@ -631,58 +632,6 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - - class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 7a72b2824..cdac6f103 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,6 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item import comfy.model_management from PIL import Image import numpy as np @@ -8,57 +9,6 @@ import torch import scipy import copy - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -130,14 +80,14 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): final_colors = linear_colors.unsqueeze(0) - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) out_mesh.colors = final_colors return out_mesh def paint_mesh_default_colors(mesh): - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) vertex_count = mesh.vertices.shape[1] out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) return out_mesh @@ -400,7 +350,7 @@ class Trellis2Conditioning(IO.ComfyNode): mask = mask.unsqueeze(0) batch_size = image.shape[0] if mask.shape[0] == 1 and batch_size > 1: - mask = mask.repeat(batch_size, 1, 1) + mask = mask.expand(batch_size, -1, -1) elif mask.shape[0] != batch_size: raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") From 9cfa8f2c0171ca386c2702d482252a3d6cf64ce8 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:47:57 -0500 Subject: [PATCH 5/5] Trellis2: inline batched mesh helpers --- comfy_extras/mesh_batch_utils.py | 53 -------------------------------- comfy_extras/nodes_hunyuan3d.py | 52 ++++++++++++++++++++++++++++++- comfy_extras/nodes_trellis2.py | 52 ++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 55 deletions(-) delete mode 100644 comfy_extras/mesh_batch_utils.py diff --git a/comfy_extras/mesh_batch_utils.py b/comfy_extras/mesh_batch_utils.py deleted file mode 100644 index 841328776..000000000 --- a/comfy_extras/mesh_batch_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from comfy_api.latest import Types - - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 78ab3b841..7ae69db98 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -10,7 +10,6 @@ from comfy.cli_args import args from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa -from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @@ -632,6 +631,57 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index cdac6f103..8121e261b 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item import comfy.model_management from PIL import Image import numpy as np @@ -9,6 +8,57 @@ import torch import scipy import copy + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,