Trellis2/Hunyuan3d: n>1 batched cascade support

Mesh-producing nodes (VoxelToMeshBasic, VoxelToMesh, VaeDecodeShapeTrellis)
previously stacked per-batch vertex/face tensors with torch.stack, which
failed under batch>1 because per-item meshes have variable shapes. Store
per-item tensors as lists when shapes differ; keep stacked-tensor fast
path when shapes match. Update SaveGLB, PostProcessMesh, and
VaeDecodeTextureTrellis consumers to iterate per-item when input is a
list.

Trellis2Conditioning.execute previously collapsed batched image/mask
input to index 0, yielding identical conditioning for every batch item.
Loop over the batch and produce per-image cond_512/cond_1024/neg_cond
tensors stacked along the batch dim, matching the latent batch size.

batch_size=1 behavior is unchanged. batch_size>1 runs now emit one GLB
per batch item per SaveGLB node and carry per-image conditioning
through the structure/shape/texture cascade.
This commit is contained in:
John Pollock 2026-04-17 22:42:42 -05:00
parent 45fcf0f9cc
commit 44043ee6e5
2 changed files with 112 additions and 55 deletions

View File

@ -443,7 +443,9 @@ class VoxelToMeshBasic(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
return IO.NodeOutput(Types.MESH(vertices, faces))
decode = execute # TODO: remove
@ -479,7 +481,9 @@ class VoxelToMesh(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
return IO.NodeOutput(Types.MESH(vertices, faces))
decode = execute # TODO: remove
@ -682,7 +686,8 @@ class SaveGLB(IO.ComfyNode):
})
else:
# Handle Mesh input - save vertices and faces as GLB
for i in range(mesh.vertices.shape[0]):
bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0]
for i in range(bsz):
f = f"{filename}_{counter:05}_.glb"
v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors)

View File

@ -117,9 +117,12 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
samples = shape_norm(samples, coords)
mesh, subs = vae.decode_shape_slat(samples, resolution)
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
face_list = [m.faces for m in mesh]
vert_list = [m.vertices for m in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
else:
mesh = Types.MESH(vertices=vert_list, faces=face_list)
return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode):
@ -160,8 +163,23 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
voxel = vae.decode_tex_slat(samples, shape_subs)
color_feats = voxel.feats[:, :3]
voxel_coords = voxel.coords[:, 1:]
voxel_batch_idx = voxel.coords[:, 0]
out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
if isinstance(shape_mesh.vertices, list):
out_verts, out_faces, out_colors = [], [], []
for i in range(len(shape_mesh.vertices)):
sel = voxel_batch_idx == i
item_coords = voxel_coords[sel]
item_colors = color_feats[sel]
item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0))
painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution)
out_verts.append(painted.vertices.squeeze(0))
out_faces.append(painted.faces.squeeze(0))
out_colors.append(painted.colors.squeeze(0))
out_mesh = Types.MESH(vertices=out_verts, faces=out_faces)
out_mesh.colors = out_colors
else:
out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
return IO.NodeOutput(out_mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@ -310,69 +328,83 @@ class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3:
image = image.unsqueeze(0)
if mask.ndim == 2:
mask = mask.unsqueeze(0)
batch_size = image.shape[0]
if image.ndim == 4:
image = image[0]
if mask.ndim == 3:
mask = mask[0]
cond_512_list = []
cond_1024_list = []
img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b]
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
max_size = max(pil_img.size)
scale = min(1.0, 1024 / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img)
rgba_np[:, :, 3] = np.array(pil_mask)
max_size = max(pil_img.size)
scale = min(1.0, 1024 / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img)
rgba_np[:, :, 3] = np.array(pil_mask)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
size = max(y_max - y_min, x_max - x_min)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
crop_x1 = int(center_x - size // 2)
crop_y1 = int(center_y - size // 2)
crop_x2 = int(center_x + size // 2)
crop_y2 = int(center_y + size // 2)
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
size = max(y_max - y_min, x_max - x_min)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
import logging
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
crop_x1 = int(center_x - size // 2)
crop_y1 = int(center_y - size // 2)
crop_x2 = int(center_x + size // 2)
crop_y2 = int(center_y + size // 2)
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
import logging
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
# to match trellis2 code (quantize -> dequantize)
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
cropped_pil = Image.fromarray(composite_uint8)
# to match trellis2 code (quantize -> dequantize)
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
cropped_pil = Image.fromarray(composite_uint8)
embeds = conditioning["cond_1024"]
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
cond_512_list.append(item_conditioning["cond_512"])
cond_1024_list.append(item_conditioning["cond_1024"])
cond_512_batched = torch.cat(cond_512_list, dim=0)
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
neg_cond_batched = torch.zeros_like(cond_512_batched)
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
positive = [[cond_512_batched, {"embeds": cond_1024_batched}]]
negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
return IO.NodeOutput(positive, negative)
class EmptyShapeLatentTrellis2(IO.ComfyNode):
@ -659,7 +691,27 @@ class PostProcessMesh(IO.ComfyNode):
@classmethod
def execute(cls, mesh, simplify, fill_holes_perimeter):
# TODO: batched mode may break
if isinstance(mesh.vertices, list):
out_verts, out_faces, out_colors = [], [], []
colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None
for i in range(len(mesh.vertices)):
v_i = mesh.vertices[i]
f_i = mesh.faces[i]
c_i = colors_in[i] if colors_in is not None else None
actual_face_count = f_i.shape[0]
if fill_holes_perimeter > 0:
v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter)
if simplify > 0 and actual_face_count > simplify:
v_i, f_i, c_i = simplify_fn(v_i, f_i, target=simplify, colors=c_i)
v_i, f_i = make_double_sided(v_i, f_i)
out_verts.append(v_i)
out_faces.append(f_i)
if c_i is not None:
out_colors.append(c_i)
out_mesh = type(mesh)(vertices=out_verts, faces=out_faces)
if len(out_colors) == len(out_verts):
out_mesh.colors = out_colors
return IO.NodeOutput(out_mesh)
verts, faces = mesh.vertices, mesh.faces
colors = None
if hasattr(mesh, "colors"):