diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 57732151b..7a72b2824 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -135,6 +135,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): return out_mesh + +def paint_mesh_default_colors(mesh): + out_mesh = copy.deepcopy(mesh) + vertex_count = mesh.vertices.shape[1] + out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) + return out_mesh + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -216,21 +223,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if hasattr(shape_mesh, "vertex_counts"): + mesh_batch_size = shape_mesh.vertices.shape[0] + if mesh_batch_size > 1: out_verts, out_faces, out_colors = [], [], [] - for i in range(shape_mesh.vertices.shape[0]): + for i in range(mesh_batch_size): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) - painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) + if item_coords.shape[0] == 0: + painted = paint_mesh_default_colors(item_mesh) + else: + painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: - out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + if voxel_coords.shape[0] == 0: + out_mesh = paint_mesh_default_colors(shape_mesh) + else: + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode):