diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 6ee51e0ab..cad870177 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -302,53 +302,78 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f -def make_double_sided(vertices, faces, colors=None): +def make_double_sided(vertices, faces, colors=None, normals=None, z_offset=1e-4): """ - Creates a double-sided mesh by duplicating vertices, faces, and colors. - Duplicating vertices prevents opposite faces from sharing the same index, - which stops the rendering engine from cancelling out the normals to [0,0,0]. + Creates double-sided mesh using PER-FACE normals for offset. + This avoids pole singularities completely. """ is_batched = vertices.ndim == 3 + if is_batched: v_list, f_list, c_list = [], [], [] for i in range(vertices.shape[0]): - num_v = vertices[i].shape[0] + # Compute face normals for this mesh + v0 = vertices[i][faces[i][:, 0]] + v1 = vertices[i][faces[i][:, 1]] + v2 = vertices[i][faces[i][:, 2]] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) - # Duplicate vertices - v_dup = torch.cat([vertices[i], vertices[i]], dim=0) + # Offset each face's vertices along its face normal + front = torch.stack([v0, v1, v2], dim=1) + fn.unsqueeze(1) * z_offset + back = torch.stack([v0, v1, v2], dim=1) - fn.unsqueeze(1) * z_offset - # Invert faces AND shift indices to point to the new duplicated vertices - f_inv = faces[i][:, [0, 2, 1]] + num_v - f_dup = torch.cat([faces[i], f_inv], dim=0) + front = front.reshape(-1, 3) + back = back.reshape(-1, 3) - v_list.append(v_dup) - f_list.append(f_dup) + f_front = torch.arange(faces[i].shape[0] * 3, device=vertices.device).reshape(-1, 3) + f_back = f_front + faces[i].shape[0] * 3 + f_back = f_back[:, [0, 2, 1]] # flip winding for back faces + + v_list.append(torch.cat([front, back], dim=0)) + f_list.append(torch.cat([f_front, f_back], dim=0)) if colors is not None: - c_dup = torch.cat([colors[i], colors[i]], dim=0) - c_list.append(c_dup) + c_faces = colors[i][faces[i]] + c_front = c_faces.reshape(-1, colors[i].shape[-1]) + c_back = c_front.clone() + c_list.append(torch.cat([c_front, c_back], dim=0)) out_v = torch.stack(v_list) out_f = torch.stack(f_list) - out_c = torch.stack(c_list) if colors is not None else None - return out_v, out_f, out_c + if colors is not None: + return out_v, out_f, torch.stack(c_list) + return out_v, out_f - # --- Unbatched (Single Mesh) --- - num_v = vertices.shape[0] + # --- Unbatched --- + v0 = vertices[faces[:, 0]] + v1 = vertices[faces[:, 1]] + v2 = vertices[faces[:, 2]] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) - # duplicate vertices - v_dup = torch.cat([vertices, vertices], dim=0) + # Offset each face's vertices along its face normal + front = torch.stack([v0, v1, v2], dim=1) + fn.unsqueeze(1) * z_offset + back = torch.stack([v0, v1, v2], dim=1) - fn.unsqueeze(1) * z_offset - # invert faces AND shift indices to point to the new duplicated vertices - faces_inv = faces[:, [0, 2, 1]] + num_v - f_dup = torch.cat([faces, faces_inv], dim=0) + front = front.reshape(-1, 3) + back = back.reshape(-1, 3) + + f_front = torch.arange(faces.shape[0] * 3, device=vertices.device).reshape(-1, 3) + f_back = f_front + faces.shape[0] * 3 + f_back = f_back[:, [0, 2, 1]] # flip winding for back faces + + v_dup = torch.cat([front, back], dim=0) + f_dup = torch.cat([f_front, f_back], dim=0) - # duplicate colors if they exist if colors is not None: - c_dup = torch.cat([colors, colors], dim=0) + c_faces = colors[faces] + c_front = c_faces.reshape(-1, colors.shape[-1]) + c_back = c_front.clone() + c_dup = torch.cat([c_front, c_back], dim=0) return v_dup, f_dup, c_dup - return v_dup, f_dup + return v_dup, f_dup, None def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0): if faces.numel() == 0: @@ -834,7 +859,7 @@ class PostProcessMesh(IO.ComfyNode): v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) bar.update(1) - v, f = make_double_sided(v, f, c) + v, f, c = make_double_sided(v, f, c) bar.update(1) return v, f, c