diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
index ac91fe0a7..8f58e85d9 100644
--- a/comfy_extras/nodes_hunyuan3d.py
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -443,7 +443,9 @@ class VoxelToMeshBasic(IO.ComfyNode):
             vertices.append(v)
             faces.append(f)
 
-        return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
+            return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        return IO.NodeOutput(Types.MESH(vertices, faces))
 
     decode = execute  # TODO: remove
 
@@ -479,7 +481,9 @@ class VoxelToMesh(IO.ComfyNode):
             vertices.append(v)
             faces.append(f)
 
-        return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
+            return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        return IO.NodeOutput(Types.MESH(vertices, faces))
 
     decode = execute  # TODO: remove
 
@@ -682,7 +686,8 @@
                 })
         else:
            # Handle Mesh input - save vertices and faces as GLB
-            for i in range(mesh.vertices.shape[0]):
+            bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0]
+            for i in range(bsz):
                 f = f"{filename}_{counter:05}_.glb"
                 v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None
                 save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors)
diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index 61d3532a1..8ef3e8f5a 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -117,9 +117,12 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
         samples = shape_norm(samples, coords)
         mesh, subs = vae.decode_shape_slat(samples, resolution)
 
-        faces = torch.stack([m.faces for m in mesh])
-        verts = torch.stack([m.vertices for m in mesh])
-        mesh = Types.MESH(vertices=verts, faces=faces)
+        face_list = [m.faces for m in mesh]
+        vert_list = [m.vertices for m in mesh]
+        if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
+            mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
+        else:
+            mesh = Types.MESH(vertices=vert_list, faces=face_list)
         return IO.NodeOutput(mesh, subs)
 
 class VaeDecodeTextureTrellis(IO.ComfyNode):
@@ -160,8 +163,23 @@
         voxel = vae.decode_tex_slat(samples, shape_subs)
         color_feats = voxel.feats[:, :3]
         voxel_coords = voxel.coords[:, 1:]
+        voxel_batch_idx = voxel.coords[:, 0]
 
-        out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
+        if isinstance(shape_mesh.vertices, list):
+            out_verts, out_faces, out_colors = [], [], []
+            for i in range(len(shape_mesh.vertices)):
+                sel = voxel_batch_idx == i
+                item_coords = voxel_coords[sel]
+                item_colors = color_feats[sel]
+                item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0))
+                painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution)
+                out_verts.append(painted.vertices.squeeze(0))
+                out_faces.append(painted.faces.squeeze(0))
+                out_colors.append(painted.colors.squeeze(0))
+            out_mesh = Types.MESH(vertices=out_verts, faces=out_faces)
+            out_mesh.colors = out_colors
+        else:
+            out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
         return IO.NodeOutput(out_mesh)
 
 class VaeDecodeStructureTrellis2(IO.ComfyNode):
@@ -310,69 +328,83 @@ class Trellis2Conditioning(IO.ComfyNode):
 
     @classmethod
     def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
+        # Normalize to batched form so per-image conditioning loop below is uniform.
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        batch_size = image.shape[0]
 
-        if image.ndim == 4:
-            image = image[0]
-        if mask.ndim == 3:
-            mask = mask[0]
+        cond_512_list = []
+        cond_1024_list = []
 
-        img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
-        mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+        for b in range(batch_size):
+            item_image = image[b]
+            item_mask = mask[b]
 
-        pil_img = Image.fromarray(img_np)
-        pil_mask = Image.fromarray(mask_np)
+            img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+            mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
 
-        max_size = max(pil_img.size)
-        scale = min(1.0, 1024 / max_size)
-        if scale < 1.0:
-            new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
-            pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
+            pil_img = Image.fromarray(img_np)
+            pil_mask = Image.fromarray(mask_np)
 
-        rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
-        rgba_np[:, :, :3] = np.array(pil_img)
-        rgba_np[:, :, 3] = np.array(pil_mask)
+            max_size = max(pil_img.size)
+            scale = min(1.0, 1024 / max_size)
+            if scale < 1.0:
+                new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
+                pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+                pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
 
-        alpha = rgba_np[:, :, 3]
-        bbox_coords = np.argwhere(alpha > 0.8 * 255)
+            rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
+            rgba_np[:, :, :3] = np.array(pil_img)
+            rgba_np[:, :, 3] = np.array(pil_mask)
 
-        if len(bbox_coords) > 0:
-            y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
-            y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
+            alpha = rgba_np[:, :, 3]
+            bbox_coords = np.argwhere(alpha > 0.8 * 255)
 
-            center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
-            size = max(y_max - y_min, x_max - x_min)
+            if len(bbox_coords) > 0:
+                y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
+                y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
 
-            crop_x1 = int(center_x - size // 2)
-            crop_y1 = int(center_y - size // 2)
-            crop_x2 = int(center_x + size // 2)
-            crop_y2 = int(center_y + size // 2)
+                center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
+                size = max(y_max - y_min, x_max - x_min)
 
-            rgba_pil = Image.fromarray(rgba_np)
-            cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
-            cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
-        else:
-            import logging
-            logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
-            cropped_np = rgba_np.astype(np.float32) / 255.0
+                crop_x1 = int(center_x - size // 2)
+                crop_y1 = int(center_y - size // 2)
+                crop_x2 = int(center_x + size // 2)
+                crop_y2 = int(center_y + size // 2)
 
-        bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
-        bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
+                rgba_pil = Image.fromarray(rgba_np)
+                cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
+                cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
+            else:
+                import logging
+                logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
+                cropped_np = rgba_np.astype(np.float32) / 255.0
 
-        fg = cropped_np[:, :, :3]
-        alpha_float = cropped_np[:, :, 3:4]
-        composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
+            bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
+            bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
 
-        # to match trellis2 code (quantize -> dequantize)
-        composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
+            fg = cropped_np[:, :, :3]
+            alpha_float = cropped_np[:, :, 3:4]
+            composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
 
-        cropped_pil = Image.fromarray(composite_uint8)
+            # to match trellis2 code (quantize -> dequantize)
+            composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
 
-        conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
+            cropped_pil = Image.fromarray(composite_uint8)
 
-        embeds = conditioning["cond_1024"]
-        positive = [[conditioning["cond_512"], {"embeds": embeds}]]
-        negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
+            item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
+            cond_512_list.append(item_conditioning["cond_512"])
+            cond_1024_list.append(item_conditioning["cond_1024"])
+
+        cond_512_batched = torch.cat(cond_512_list, dim=0)
+        cond_1024_batched = torch.cat(cond_1024_list, dim=0)
+        neg_cond_batched = torch.zeros_like(cond_512_batched)
+        neg_embeds_batched = torch.zeros_like(cond_1024_batched)
+
+        positive = [[cond_512_batched, {"embeds": cond_1024_batched}]]
+        negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
         return IO.NodeOutput(positive, negative)
 
 class EmptyShapeLatentTrellis2(IO.ComfyNode):
@@ -659,7 +691,27 @@
 
     @classmethod
     def execute(cls, mesh, simplify, fill_holes_perimeter):
-        # TODO: batched mode may break
+        if isinstance(mesh.vertices, list):
+            out_verts, out_faces, out_colors = [], [], []
+            colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None
+            for i in range(len(mesh.vertices)):
+                v_i = mesh.vertices[i]
+                f_i = mesh.faces[i]
+                c_i = colors_in[i] if colors_in is not None else None
+                actual_face_count = f_i.shape[0]
+                if fill_holes_perimeter > 0:
+                    v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter)
+                if simplify > 0 and actual_face_count > simplify:
+                    v_i, f_i, c_i = simplify_fn(v_i, f_i, target=simplify, colors=c_i)
+                v_i, f_i = make_double_sided(v_i, f_i)
+                out_verts.append(v_i)
+                out_faces.append(f_i)
+                if c_i is not None:
+                    out_colors.append(c_i)
+            out_mesh = type(mesh)(vertices=out_verts, faces=out_faces)
+            if len(out_colors) == len(out_verts):
+                out_mesh.colors = out_colors
+            return IO.NodeOutput(out_mesh)
         verts, faces = mesh.vertices, mesh.faces
         colors = None
         if hasattr(mesh, "colors"):