mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 21:39:45 +08:00
bug fix in fill_holes
This commit is contained in:
parent
0b9d27ed46
commit
faeb47b3c8
@ -809,6 +809,10 @@ class Trellis2(nn.Module):
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if x.size(-1) == 16 and x.size(-2) == 16:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if not not_struct_mode:
|
||||
bsz = x.size(0)
|
||||
x = x[:, :8]
|
||||
|
||||
@ -168,11 +168,16 @@ def paint_mesh_default_colors(mesh):
|
||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
v_list, f_list = [],[]
|
||||
v_list, f_list = [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
max_v = max(v.shape[0] for v in v_list)
|
||||
for i in range(len(v_list)):
|
||||
if v_list[i].shape[0] < max_v:
|
||||
pad = torch.zeros(max_v - v_list[i].shape[0], 3, device=v_list[i].device, dtype=v_list[i].dtype)
|
||||
v_list[i] = torch.cat([v_list[i], pad], dim=0)
|
||||
return torch.stack(v_list), torch.stack(f_list)
|
||||
|
||||
device = vertices.device
|
||||
@ -194,13 +199,19 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
if boundary_packed.numel() == 0:
|
||||
return v, f
|
||||
|
||||
packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long()
|
||||
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
|
||||
b_edges = edges[is_boundary]
|
||||
# Build undirected boundary edge adjacency
|
||||
boundary_mask = torch.isin(packed_undirected, boundary_packed)
|
||||
b_edges = edges_sorted[boundary_mask]
|
||||
|
||||
adj = {u.item(): v_idx.item() for u, v_idx in b_edges}
|
||||
adj = {}
|
||||
for i in range(b_edges.shape[0]):
|
||||
a = b_edges[i, 0].item()
|
||||
b = b_edges[i, 1].item()
|
||||
adj.setdefault(a, []).append(b)
|
||||
adj.setdefault(b, []).append(a)
|
||||
|
||||
loops =[]
|
||||
# Trace all boundary loops
|
||||
loops = []
|
||||
visited = set()
|
||||
|
||||
for start_node in adj.keys():
|
||||
@ -208,40 +219,84 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
continue
|
||||
|
||||
curr = start_node
|
||||
prev = -1
|
||||
loop = []
|
||||
|
||||
while curr not in visited:
|
||||
visited.add(curr)
|
||||
loop.append(curr)
|
||||
curr = adj.get(curr, -1)
|
||||
|
||||
if curr == -1:
|
||||
neighbors = adj[curr]
|
||||
candidates = [n for n in neighbors if n != prev]
|
||||
if not candidates:
|
||||
loop = []
|
||||
break
|
||||
next_node = candidates[0]
|
||||
prev, curr = curr, next_node
|
||||
if curr == start_node:
|
||||
loops.append(loop)
|
||||
break
|
||||
|
||||
new_verts =[]
|
||||
new_faces = []
|
||||
v_idx = v.shape[0]
|
||||
if not loops:
|
||||
return v, f
|
||||
|
||||
# Compute mesh normal for orientation
|
||||
face_normals = torch.linalg.cross(
|
||||
v[f[:, 1]] - v[f[:, 0]],
|
||||
v[f[:, 2]] - v[f[:, 0]],
|
||||
dim=-1
|
||||
)
|
||||
mesh_normal = face_normals.mean(dim=0)
|
||||
mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8)
|
||||
|
||||
# Classify loops: keep only holes (normal aligns with mesh_normal), discard outer boundary
|
||||
hole_loops = []
|
||||
for loop in loops:
|
||||
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||
loop_v = v[loop_t]
|
||||
next_v = torch.roll(loop_v, -1, dims=0)
|
||||
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
|
||||
loop_normal = cross.sum(dim=0)
|
||||
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
|
||||
# Hole: loop normal points same way as mesh normal (both "up" for a hole)
|
||||
# Outer boundary: loop normal points opposite
|
||||
if torch.dot(loop_normal, mesh_normal) > 0:
|
||||
hole_loops.append(loop)
|
||||
|
||||
diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0)
|
||||
new_verts = []
|
||||
new_faces = []
|
||||
v_idx = v.shape[0]
|
||||
|
||||
for loop in hole_loops:
|
||||
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||
loop_v = v[loop_t]
|
||||
|
||||
# Perimeter check
|
||||
next_v = torch.roll(loop_v, -1, dims=0)
|
||||
diffs = loop_v - next_v
|
||||
perimeter = torch.norm(diffs, dim=1).sum().item()
|
||||
|
||||
if perimeter <= max_perimeter:
|
||||
new_verts.append(loop_v.mean(dim=0))
|
||||
if perimeter > max_perimeter:
|
||||
continue
|
||||
|
||||
# Ensure CCW winding consistent with mesh
|
||||
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
|
||||
loop_normal = cross.sum(dim=0)
|
||||
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
|
||||
if torch.dot(loop_normal, mesh_normal) < 0:
|
||||
loop = loop[::-1]
|
||||
|
||||
if len(loop) == 3:
|
||||
new_faces.append([loop[0], loop[1], loop[2]])
|
||||
else:
|
||||
centroid = loop_v.mean(dim=0)
|
||||
new_verts.append(centroid)
|
||||
for i in range(len(loop)):
|
||||
new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx])
|
||||
new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx])
|
||||
v_idx += 1
|
||||
|
||||
if new_verts:
|
||||
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
||||
if new_faces:
|
||||
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
||||
|
||||
return v, f
|
||||
|
||||
Loading…
Reference in New Issue
Block a user