Compare commits

...

2 Commits

Author SHA1 Message Date
Yousef Rafat
9bf7bbb496 .. 2026-05-19 01:42:35 +03:00
Yousef Rafat
f29fb04fa2 normal fix 2026-05-19 01:26:54 +03:00

View File

@ -189,17 +189,14 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
edges_sorted, _ = torch.sort(edges, dim=1)
max_v = v.shape[0]
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
boundary_packed = unique_packed[counts == 1]
if boundary_packed.numel() == 0:
return v, f
# Build undirected boundary edge adjacency
boundary_mask = torch.isin(packed_undirected, boundary_packed)
b_edges = edges_sorted[boundary_mask]
@ -213,15 +210,12 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
# Trace all boundary loops
loops = []
visited = set()
for start_node in adj.keys():
if start_node in visited:
continue
curr = start_node
prev = -1
loop = []
while curr not in visited:
visited.add(curr)
loop.append(curr)
@ -239,7 +233,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
if not loops:
return v, f
# Compute mesh normal for orientation
# Mesh normal for winding orientation only
face_normals = torch.linalg.cross(
v[f[:, 1]] - v[f[:, 0]],
v[f[:, 2]] - v[f[:, 0]],
@ -248,25 +242,12 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
mesh_normal = face_normals.mean(dim=0)
mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8)
# Classify loops: keep only holes (normal aligns with mesh_normal), discard outer boundary
hole_loops = []
for loop in loops:
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
loop_v = v[loop_t]
next_v = torch.roll(loop_v, -1, dims=0)
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
loop_normal = cross.sum(dim=0)
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
# Hole: loop normal points same way as mesh normal (both "up" for a hole)
# Outer boundary: loop normal points opposite
if torch.dot(loop_normal, mesh_normal) > 0:
hole_loops.append(loop)
# === FIX: Fill ALL boundary loops below perimeter threshold ===
new_verts = []
new_faces = []
v_idx = v.shape[0]
for loop in hole_loops:
for loop in loops:
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
loop_v = v[loop_t]
@ -284,6 +265,8 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
if torch.dot(loop_normal, mesh_normal) < 0:
loop = loop[::-1]
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
loop_v = v[loop_t]
if len(loop) == 3:
new_faces.append([loop[0], loop[1], loop[2]])
@ -302,53 +285,78 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
return v, f
def make_double_sided(vertices, faces, colors=None):
def make_double_sided(vertices, faces, colors=None, normals=None, z_offset=1e-4):
"""
Creates a double-sided mesh by duplicating vertices, faces, and colors.
Duplicating vertices prevents opposite faces from sharing the same index,
which stops the rendering engine from cancelling out the normals to [0,0,0].
Creates double-sided mesh using PER-FACE normals for offset.
This avoids pole singularities completely.
"""
is_batched = vertices.ndim == 3
if is_batched:
v_list, f_list, c_list = [], [], []
for i in range(vertices.shape[0]):
num_v = vertices[i].shape[0]
# Compute face normals for this mesh
v0 = vertices[i][faces[i][:, 0]]
v1 = vertices[i][faces[i][:, 1]]
v2 = vertices[i][faces[i][:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)
# Duplicate vertices
v_dup = torch.cat([vertices[i], vertices[i]], dim=0)
# Offset each face's vertices along its face normal
front = torch.stack([v0, v1, v2], dim=1) + fn.unsqueeze(1) * z_offset
back = torch.stack([v0, v1, v2], dim=1) - fn.unsqueeze(1) * z_offset
# Invert faces AND shift indices to point to the new duplicated vertices
f_inv = faces[i][:, [0, 2, 1]] + num_v
f_dup = torch.cat([faces[i], f_inv], dim=0)
front = front.reshape(-1, 3)
back = back.reshape(-1, 3)
v_list.append(v_dup)
f_list.append(f_dup)
f_front = torch.arange(faces[i].shape[0] * 3, device=vertices.device).reshape(-1, 3)
f_back = f_front + faces[i].shape[0] * 3
f_back = f_back[:, [0, 2, 1]] # flip winding for back faces
v_list.append(torch.cat([front, back], dim=0))
f_list.append(torch.cat([f_front, f_back], dim=0))
if colors is not None:
c_dup = torch.cat([colors[i], colors[i]], dim=0)
c_list.append(c_dup)
c_faces = colors[i][faces[i]]
c_front = c_faces.reshape(-1, colors[i].shape[-1])
c_back = c_front.clone()
c_list.append(torch.cat([c_front, c_back], dim=0))
out_v = torch.stack(v_list)
out_f = torch.stack(f_list)
out_c = torch.stack(c_list) if colors is not None else None
return out_v, out_f, out_c
if colors is not None:
return out_v, out_f, torch.stack(c_list)
return out_v, out_f
# --- Unbatched (Single Mesh) ---
num_v = vertices.shape[0]
# --- Unbatched ---
v0 = vertices[faces[:, 0]]
v1 = vertices[faces[:, 1]]
v2 = vertices[faces[:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)
# duplicate vertices
v_dup = torch.cat([vertices, vertices], dim=0)
# Offset each face's vertices along its face normal
front = torch.stack([v0, v1, v2], dim=1) + fn.unsqueeze(1) * z_offset
back = torch.stack([v0, v1, v2], dim=1) - fn.unsqueeze(1) * z_offset
# invert faces AND shift indices to point to the new duplicated vertices
faces_inv = faces[:, [0, 2, 1]] + num_v
f_dup = torch.cat([faces, faces_inv], dim=0)
front = front.reshape(-1, 3)
back = back.reshape(-1, 3)
f_front = torch.arange(faces.shape[0] * 3, device=vertices.device).reshape(-1, 3)
f_back = f_front + faces.shape[0] * 3
f_back = f_back[:, [0, 2, 1]] # flip winding for back faces
v_dup = torch.cat([front, back], dim=0)
f_dup = torch.cat([f_front, f_back], dim=0)
# duplicate colors if they exist
if colors is not None:
c_dup = torch.cat([colors, colors], dim=0)
c_faces = colors[faces]
c_front = c_faces.reshape(-1, colors.shape[-1])
c_back = c_front.clone()
c_dup = torch.cat([c_front, c_back], dim=0)
return v_dup, f_dup, c_dup
return v_dup, f_dup
return v_dup, f_dup, None
def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0):
if faces.numel() == 0:
@ -834,7 +842,7 @@ class PostProcessMesh(IO.ComfyNode):
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
bar.update(1)
v, f = make_double_sided(v, f, c)
v, f, c = make_double_sided(v, f, c)
bar.update(1)
return v, f, c