From 955c00ee38356fd35d4c347f617ef15cb10659ec Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:54:27 +0200 Subject: [PATCH] post-process node --- comfy/ldm/trellis2/cumesh.py | 127 ---------------------- comfy/ldm/trellis2/vae.py | 22 ---- comfy_extras/nodes_trellis2.py | 187 ++++++++++++++++++++++++++++++++- 3 files changed, 186 insertions(+), 150 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index fe7e80e15..972fb13c3 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -469,133 +469,6 @@ class Mesh: def cpu(self): return self.to('cpu') - # could make this into a new node - def fill_holes(self, max_hole_perimeter=3e-2): - - device = self.vertices.device - vertices = self.vertices - faces = self.faces - - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) - - edges_sorted, _ = torch.sort(edges, dim=1) - - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - - boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] - - if boundary_edges_sorted.shape[0] == 0: - return - max_idx = vertices.shape[0] - - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) - - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] - - active_boundary_edges = edges[is_boundary_edge] - - adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v - - loops = [] - visited_edges = set() - - possible_starts = list(adj.keys()) - - processed_nodes = set() - - for start_node in possible_starts: - if start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - - while curr in adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) - current_loop.append(curr) - - curr = next_node - - if curr == start_node: - loops.append(current_loop) - break - - if len(current_loop) > len(edges_np): - break - - if not loops: - return - - new_faces = [] - - v_offset = vertices.shape[0] - - valid_new_verts = [] - - for loop_indices in loops: - if len(loop_indices) < 3: - continue - - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() - - if perimeter > max_hole_perimeter: - continue - - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - - c_idx = v_offset - v_offset += 1 - - num_v = len(loop_indices) - for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] - new_faces.append([v_curr, v_next, c_idx]) - - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - - self.vertices = torch.cat([self.vertices, added_vertices], dim=0) - self.faces = torch.cat([self.faces, added_faces], dim=0) - - - # TODO could be an option - def simplify(self, target=1000000, verbose: bool=False, options: dict={}): - import cumesh - vertices = self.vertices.cuda() - faces = self.faces.cuda() - - mesh = cumesh.CuMesh() - mesh.init(vertices, faces) - mesh.simplify(target, verbose=verbose, options=options) - new_vertices, new_faces = mesh.read() - - self.vertices = new_vertices.to(self.device) - self.faces = new_faces.to(self.device) - class MeshWithVoxel(Mesh, Voxel): def __init__(self, vertices: torch.Tensor, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index d997bbc41..1d26986cc 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -231,27 +231,6 @@ class config: CONV = "flexgemm" FLEX_GEMM_HASHMAP_RATIO = 2.0 -# TODO post processing -def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): - - num_face = self.cu_mesh.num_faces() - if num_face <= target_num_faces: - return - - thresh = options.get('thresh', 1e-8) - lambda_edge_length = options.get('lambda_edge_length', 1e-2) - lambda_skinny = options.get('lambda_skinny', 1e-3) - while True: - new_num_vert, new_num_face = self.cu_mesh.simplify_step(lambda_edge_length, lambda_skinny, thresh, False) - - if new_num_face <= target_num_faces: - break - - del_num_face = num_face - new_num_face - if del_num_face / num_face < 1e-2: - thresh *= 10 - num_face = new_num_face - class VarLenTensor: def __init__(self, feats: torch.Tensor, layout: List[slice]=None): @@ -1530,7 +1509,6 @@ class Vae(nn.Module): tex_voxels = self.decode_tex_slat(tex_slat, subs) out_mesh = [] for m, v in zip(meshes, tex_voxels): - m.fill_holes() # TODO out_mesh.append( MeshWithVoxel( m.vertices, m.faces, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4a36e2fee..8497b83e2 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -281,6 +281,190 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) +def simplify_fn(vertices, faces, target=100000): + + if vertices.shape[0] <= target: + return + + min_feat = vertices.min(dim=0)[0] + max_feat = vertices.max(dim=0)[0] + extent = (max_feat - min_feat).max() + + grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) + voxel_size = extent / grid_resolution + + quantized_coords = ((vertices - min_feat) / voxel_size).long() + + unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + + num_new_verts = unique_coords.shape[0] + new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) + + counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + + new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) + counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) + + new_vertices = new_vertices / counts.clamp(min=1) + + new_faces = inverse_indices[faces] + + v0 = new_faces[:, 0] + v1 = new_faces[:, 1] + v2 = new_faces[:, 2] + + valid_mask = (v0 != v1) & (v1 != v2) & (v2 != v0) + new_faces = new_faces[valid_mask] + + unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices = new_vertices[unique_face_indices] + final_faces = inv_face.reshape(-1, 3) + + return final_vertices, final_faces + +def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + + device = vertices.device + orig_vertices = vertices + orig_faces = faces + + edges = torch.cat([ + faces[:, [0, 1]], + faces[:, [1, 2]], + faces[:, [2, 0]] + ], dim=0) + + edges_sorted, _ = torch.sort(edges, dim=1) + + unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + boundary_mask = counts == 1 + boundary_edges_sorted = unique_edges[boundary_mask] + + if boundary_edges_sorted.shape[0] == 0: + return + max_idx = vertices.shape[0] + + _, inverse_indices, counts_packed = torch.unique( + torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], + return_inverse=True, return_counts=True + ) + + boundary_packed_mask = counts_packed == 1 + is_boundary_edge = boundary_packed_mask[inverse_indices] + + active_boundary_edges = edges[is_boundary_edge] + + adj = {} + edges_np = active_boundary_edges.cpu().numpy() + for u, v in edges_np: + adj[u] = v + + loops = [] + visited_edges = set() + + possible_starts = list(adj.keys()) + + processed_nodes = set() + + for start_node in possible_starts: + if start_node in processed_nodes: + continue + + current_loop = [] + curr = start_node + + while curr in adj: + next_node = adj[curr] + if (curr, next_node) in visited_edges: + break + + visited_edges.add((curr, next_node)) + processed_nodes.add(curr) + current_loop.append(curr) + + curr = next_node + + if curr == start_node: + loops.append(current_loop) + break + + if len(current_loop) > len(edges_np): + break + + if not loops: + return + + new_faces = [] + + v_offset = vertices.shape[0] + + valid_new_verts = [] + + for loop_indices in loops: + if len(loop_indices) < 3: + continue + + loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) + loop_verts = vertices[loop_tensor] + + diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum() + + if perimeter > max_hole_perimeter: + continue + + center = loop_verts.mean(dim=0) + valid_new_verts.append(center) + + c_idx = v_offset + v_offset += 1 + + num_v = len(loop_indices) + for i in range(num_v): + v_curr = loop_indices[i] + v_next = loop_indices[(i + 1) % num_v] + new_faces.append([v_curr, v_next, c_idx]) + + if len(valid_new_verts) > 0: + added_vertices = torch.stack(valid_new_verts, dim=0) + added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + + vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) + faces_f = torch.cat([orig_faces, added_faces], dim=0) + + return vertices_f, faces_f + +class PostProcessMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PostProcessMesh", + category="latent/3d", + inputs=[ + IO.Mesh.Input("mesh"), + IO.Int.Input("simplify", default=100_000, min=0), # max? + IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + ], + outputs=[ + IO.Mesh.Output("output_mesh"), + ] + ) + @classmethod + def execute(cls, mesh, simplify, fill_holes_perimeter): + verts, faces = mesh.vertices, mesh.faces + + if fill_holes_perimeter != 0.0: + verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + + if simplify != 0: + verts, faces = simplify_fn(verts, faces, simplify) + + + mesh.vertices = verts + mesh.faces = faces + + return mesh class Trellis2Extension(ComfyExtension): @override @@ -292,7 +476,8 @@ class Trellis2Extension(ComfyExtension): EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, - VaeDecodeStructureTrellis2 + VaeDecodeStructureTrellis2, + PostProcessMesh ]