fix: stabilize Trellis2 mesh simplification

This commit is contained in:
John Pollock 2026-04-20 17:22:31 -05:00
parent 880d7823e8
commit 597adfce3f

View File

@ -109,15 +109,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
# map voxels
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
verts = mesh.vertices.to(device).squeeze(0)
voxel_colors = voxel_colors.to(device)
voxel_colors = voxel_colors.cpu()
voxel_pos_np = voxel_pos.numpy()
verts_np = verts.numpy()
voxel_pos_np = voxel_pos.cpu().numpy()
verts_np = verts.cpu().numpy()
tree = scipy.spatial.cKDTree(voxel_pos_np)
# nearest neighbour k=1
_, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
_, nearest_idx_np = tree.query(verts_np, k=1, workers=1)
nearest_idx = torch.from_numpy(nearest_idx_np).long()
v_colors = voxel_colors[nearest_idx]
@ -194,6 +194,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"),
IO.Combo.Input("resolution", options=["512", "1024"], default="1024")
],
outputs=[
IO.Mesh.Output("mesh"),
@ -201,9 +202,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, shape_mesh, samples, vae, shape_subs):
def execute(cls, shape_mesh, samples, vae, shape_subs, resolution):
resolution = 1024
resolution = int(resolution)
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
@ -617,34 +618,49 @@ def simplify_fn(vertices, faces, colors=None, target=100000):
volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8)
cell_size = (volume / target_v) ** (1/3.0)
quantized = ((vertices - min_v) / cell_size).round().long()
unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True)
# Use CPU-side ordered reductions here so repeated runs produce identical
# simplified meshes instead of relying on GPU scatter-add accumulation order.
vertices_np = vertices.detach().cpu().numpy()
faces_np = faces.detach().cpu().numpy()
colors_np = colors.detach().cpu().numpy() if colors is not None else None
min_v_np = min_v.detach().cpu().numpy()
cell_size_value = float(cell_size.detach().cpu())
quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64)
unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True)
num_cells = unique_coords.shape[0]
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
new_vertices = new_vertices / counts.clamp(min=1)
new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype)
np.add.at(new_vertices_np, inverse_indices, vertices_np)
counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1)
new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None)
new_colors = None
if colors is not None:
new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device)
new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors)
new_colors = new_colors / counts.clamp(min=1)
if colors_np is not None:
new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype)
np.add.at(new_colors_np, inverse_indices, colors_np)
new_colors = new_colors_np / np.clip(counts_np, 1, None)
new_faces = inverse_indices[faces]
new_faces = inverse_indices[faces_np]
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
(new_faces[:, 1] != new_faces[:, 2]) & \
(new_faces[:, 2] != new_faces[:, 0])
new_faces = new_faces[valid_mask]
unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True)
final_vertices = new_vertices[unique_face_indices]
final_faces = inv_face.reshape(-1, 3)
if new_faces.size == 0:
final_vertices_np = new_vertices_np[:0]
final_faces_np = np.empty((0, 3), dtype=np.int64)
final_colors_np = new_colors[:0] if new_colors is not None else None
else:
unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True)
final_vertices_np = new_vertices_np[unique_face_indices]
final_faces_np = inv_face.reshape(-1, 3).astype(np.int64)
final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None
# assign colors
final_colors = new_colors[unique_face_indices] if new_colors is not None else None
final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype)
final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype)
final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None
return final_vertices, final_faces, final_colors