diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 79da2e808..3c61a5d77 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -809,6 +809,10 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False + if x.size(-1) == 16 and x.size(-2) == 16: + mode = "structure_generation" + not_struct_mode = False + if not not_struct_mode: bsz = x.size(0) x = x[:, :8] diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 3243ff869..38447d88b 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -168,11 +168,16 @@ def paint_mesh_default_colors(mesh): def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 if is_batched: - v_list, f_list = [],[] + v_list, f_list = [], [] for i in range(vertices.shape[0]): v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter) v_list.append(v_i) f_list.append(f_i) + max_v = max(v.shape[0] for v in v_list) + for i in range(len(v_list)): + if v_list[i].shape[0] < max_v: + pad = torch.zeros(max_v - v_list[i].shape[0], 3, device=v_list[i].device, dtype=v_list[i].dtype) + v_list[i] = torch.cat([v_list[i], pad], dim=0) return torch.stack(v_list), torch.stack(f_list) device = vertices.device @@ -194,13 +199,19 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): if boundary_packed.numel() == 0: return v, f - packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long() - is_boundary = torch.isin(packed_directed_sorted, boundary_packed) - b_edges = edges[is_boundary] + # Build undirected boundary edge adjacency + boundary_mask = torch.isin(packed_undirected, boundary_packed) + b_edges = edges_sorted[boundary_mask] - adj = {u.item(): v_idx.item() for u, v_idx in b_edges} + adj = {} + for i in range(b_edges.shape[0]): + a = b_edges[i, 0].item() + b = b_edges[i, 1].item() + adj.setdefault(a, []).append(b) + adj.setdefault(b, []).append(a) - loops =[] + # Trace all boundary loops + loops = [] visited = set() for start_node in adj.keys(): @@ -208,40 +219,84 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): continue curr = start_node + prev = -1 loop = [] while curr not in visited: visited.add(curr) loop.append(curr) - curr = adj.get(curr, -1) - - if curr == -1: + neighbors = adj[curr] + candidates = [n for n in neighbors if n != prev] + if not candidates: loop = [] break + next_node = candidates[0] + prev, curr = curr, next_node if curr == start_node: loops.append(loop) break - new_verts =[] - new_faces = [] - v_idx = v.shape[0] + if not loops: + return v, f + # Compute mesh normal for orientation + face_normals = torch.linalg.cross( + v[f[:, 1]] - v[f[:, 0]], + v[f[:, 2]] - v[f[:, 0]], + dim=-1 + ) + mesh_normal = face_normals.mean(dim=0) + mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8) + + # Classify loops: keep only holes (normal aligns with mesh_normal), discard outer boundary + hole_loops = [] for loop in loops: loop_t = torch.tensor(loop, device=device, dtype=torch.long) loop_v = v[loop_t] + next_v = torch.roll(loop_v, -1, dims=0) + cross = torch.linalg.cross(loop_v, next_v, dim=-1) + loop_normal = cross.sum(dim=0) + loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8) + # Hole: loop normal points same way as mesh normal (both "up" for a hole) + # Outer boundary: loop normal points opposite + if torch.dot(loop_normal, mesh_normal) > 0: + hole_loops.append(loop) - diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0) + new_verts = [] + new_faces = [] + v_idx = v.shape[0] + + for loop in hole_loops: + loop_t = torch.tensor(loop, device=device, dtype=torch.long) + loop_v = v[loop_t] + + # Perimeter check + next_v = torch.roll(loop_v, -1, dims=0) + diffs = loop_v - next_v perimeter = torch.norm(diffs, dim=1).sum().item() - if perimeter <= max_perimeter: - new_verts.append(loop_v.mean(dim=0)) + if perimeter > max_perimeter: + continue + # Ensure CCW winding consistent with mesh + cross = torch.linalg.cross(loop_v, next_v, dim=-1) + loop_normal = cross.sum(dim=0) + loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8) + if torch.dot(loop_normal, mesh_normal) < 0: + loop = loop[::-1] + + if len(loop) == 3: + new_faces.append([loop[0], loop[1], loop[2]]) + else: + centroid = loop_v.mean(dim=0) + new_verts.append(centroid) for i in range(len(loop)): - new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx]) + new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) v_idx += 1 if new_verts: v = torch.cat([v, torch.stack(new_verts)], dim=0) + if new_faces: f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) return v, f