diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bf457135c..817769d08 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -394,11 +394,13 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): visited_edges = set() processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: continue + if start_node in processed_nodes: + continue current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: break + if (curr, next_node) in visited_edges: + break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) @@ -406,10 +408,12 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): if curr == start_node: loops.append(current_loop) break - if len(current_loop) > len(edges_np): break + if len(current_loop) > len(edges_np): + break if not loops: - if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) return orig_vertices, orig_faces new_faces = [] @@ -417,13 +421,15 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: continue + if len(loop_indices) < 3: + continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: continue + if perimeter > max_hole_perimeter: + continue center = loop_verts.mean(dim=0) valid_new_verts.append(center)