From 597adfce3ffa96bf5b11c187404e5b492ea28cfc Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 17:22:31 -0500 Subject: [PATCH] fix: stabilize Trellis2 mesh simplification --- comfy_extras/nodes_trellis2.py | 62 +++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..8501ef128 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -109,15 +109,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - voxel_colors = voxel_colors.to(device) + voxel_colors = voxel_colors.cpu() - voxel_pos_np = voxel_pos.numpy() - verts_np = verts.numpy() + voxel_pos_np = voxel_pos.cpu().numpy() + verts_np = verts.cpu().numpy() tree = scipy.spatial.cKDTree(voxel_pos_np) # nearest neighbour k=1 - _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + _, nearest_idx_np = tree.query(verts_np, k=1, workers=1) nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] @@ -194,6 +194,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") ], outputs=[ IO.Mesh.Output("mesh"), @@ -201,9 +202,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, shape_mesh, samples, vae, shape_subs): + def execute(cls, shape_mesh, samples, vae, shape_subs, resolution): - resolution = 1024 + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -617,34 +618,49 @@ def simplify_fn(vertices, faces, colors=None, target=100000): volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) cell_size = (volume / target_v) ** (1/3.0) - quantized = ((vertices - min_v) / cell_size).round().long() - unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + # Use CPU-side ordered reductions here so repeated runs produce identical + # simplified meshes instead of relying on GPU scatter-add accumulation order. + vertices_np = vertices.detach().cpu().numpy() + faces_np = faces.detach().cpu().numpy() + colors_np = colors.detach().cpu().numpy() if colors is not None else None + min_v_np = min_v.detach().cpu().numpy() + cell_size_value = float(cell_size.detach().cpu()) + + quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) + unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) num_cells = unique_coords.shape[0] - new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) - counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) - new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) - counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) - new_vertices = new_vertices / counts.clamp(min=1) + new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) + np.add.at(new_vertices_np, inverse_indices, vertices_np) + + counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) + new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) new_colors = None - if colors is not None: - new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) - new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) - new_colors = new_colors / counts.clamp(min=1) + if colors_np is not None: + new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype) + np.add.at(new_colors_np, inverse_indices, colors_np) + new_colors = new_colors_np / np.clip(counts_np, 1, None) - new_faces = inverse_indices[faces] + new_faces = inverse_indices[faces_np] valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ (new_faces[:, 1] != new_faces[:, 2]) & \ (new_faces[:, 2] != new_faces[:, 0]) new_faces = new_faces[valid_mask] - unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) - final_vertices = new_vertices[unique_face_indices] - final_faces = inv_face.reshape(-1, 3) + if new_faces.size == 0: + final_vertices_np = new_vertices_np[:0] + final_faces_np = np.empty((0, 3), dtype=np.int64) + final_colors_np = new_colors[:0] if new_colors is not None else None + else: + unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices_np = new_vertices_np[unique_face_indices] + final_faces_np = inv_face.reshape(-1, 3).astype(np.int64) + final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None - # assign colors - final_colors = new_colors[unique_face_indices] if new_colors is not None else None + final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) + final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) + final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None return final_vertices, final_faces, final_colors