diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 08e746c97..e7be8cbfe 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -425,10 +425,10 @@ def _gpu_greedy_matching_fast(edges, err, v_alive, max_select): return sel -def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_edge_length=None): +def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None): # Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x # faster than FP64, and QEM only needs the stabilizer for conditioning. - # Always copy=True so we can safely mutate verts/colors in-place. + # Always copy=True so we can safely mutate verts/colors/normals in-place. verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True) faces = faces_in.detach().to(device=device, dtype=torch.int64) colors = ( @@ -436,6 +436,12 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ if colors_in is not None else None ) + # ADDED: Initialize normals + normals = ( + normals_in.detach().to(device=device, dtype=torch.float32, copy=True) + if normals_in is not None + else None + ) num_verts = verts.shape[0] num_faces = faces.shape[0] @@ -495,8 +501,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ if edges.shape[0] == 0: break - # Deduplicate edges (each interior edge appears in two adjacent faces). - # Pack (low, high) into a single int64 key and call torch.unique once. 
+ # Deduplicate edges num_compact = alive_v.numel() packed = edges[:, 0].long() * num_compact + edges[:, 1].long() packed = torch.unique(packed) @@ -523,10 +528,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ # Sample edges for processing n_edges_total = edges_orig.shape[0] - max_edges_to_process = 10_000_000 # 10M edges per iteration + max_edges_to_process = 10_000_000 if n_edges_total > max_edges_to_process: - # Random sample without building a full permutation. perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device) edges_orig = edges_orig[perm] n_edges = max_edges_to_process @@ -545,8 +549,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ optimal = optimal[valid_idx] err = err[valid_idx] - # Vectorized greedy matching can safely collapse many independent - # edges per iteration, so allow a much larger batch. faces_to_remove = n_faces - target_faces max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4)) @@ -558,12 +560,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ v_a = edges_orig[sel, 0] v_b = edges_orig[sel, 1] - # NOTE: original PostProcessMesh built a CSR (face_indices/vert_ptrs) and - # iterated v_a/v_b in Python here to populate pair_edge_idx/pair_face_idx - # arrays that were then discarded (keep_mask was unconditionally all-True). - # That dead block was the dominant cost on Trellis-sized meshes and has - # been removed in this fast variant. 
- # Apply collapses verts[v_a] = optimal[sel] v_alive[v_b] = False @@ -572,6 +568,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ if colors is not None: colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5 + if normals is not None: + normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5 + merge_map = torch.arange(num_verts, device=device) merge_map[v_b] = v_a faces = merge_map[faces] @@ -586,12 +585,10 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ total_collapses += v_a.numel() iteration += 1 - # Log only every 50 iterations to reduce sync overhead if iteration % 50 == 0 or n_faces < last_faces * 0.9: logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}") last_faces = n_faces - # Periodic compaction if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5: faces = faces[f_alive] f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device) @@ -603,6 +600,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ # Finalize final_v = verts[v_alive] final_c = colors[v_alive] if colors is not None else None + final_n = normals[v_alive] if normals is not None else None remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device) @@ -617,35 +615,59 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_ if final_f.numel() > 0: final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0) + if final_n is not None and final_f.numel() > 0: + v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]] + + # calculate the actual normal of the simplified faces + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + + # Get the average reference normal for each face + n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]] + 
ref_face_normals = (n0 + n1 + n2) / 3.0 + + # Dot product to check if they point in the same direction + dot_products = (face_normals * ref_face_normals).sum(dim=-1) + + # Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2) + wrong_way_mask = dot_products < 0 + final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]] + final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0) - return final_v, final_f, final_c + return final_v, final_f, final_c, final_n -def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_length=None): +def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None): if vertices.ndim == 3: - v_list, f_list, c_list = [], [], [] + v_list, f_list, c_list, n_list = [], [], [], [] for i in range(vertices.shape[0]): c_in = colors[i] if colors is not None else None - v_i, f_i, c_i = simplify_fn_fast(vertices[i], faces[i], c_in, target, max_edge_length) + n_in = normals[i] if normals is not None else None + v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length) v_list.append(v_i) f_list.append(f_i) if c_i is not None: c_list.append(c_i) + if n_i is not None: + n_list.append(n_i) + c_out = torch.stack(c_list) if len(c_list) > 0 else None - return torch.stack(v_list), torch.stack(f_list), c_out + n_out = torch.stack(n_list) if len(n_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out, n_out if faces.shape[0] <= target: - return vertices, faces, colors + return vertices, faces, colors, normals device = vertices.device dtype = vertices.dtype face_dtype = faces.dtype color_dtype = colors.dtype if colors is not None else None + # ADDED: Normal dtype + normal_dtype = normals.dtype if normals is not None else None # Pass tensors directly; _qem_simplify_fast handles dtype/device + copy. 
def compute_vertex_normals(verts, faces):
    """Compute area-weighted, unit-length vertex normals for a triangle mesh.

    Works on a single mesh as well as on a batch of meshes, matching the
    batched (``ndim == 3``) inputs that ``simplify_fn_fast`` accepts.

    Args:
        verts: Floating point tensor of shape ``(..., V, 3)`` with vertex
            positions.
        faces: Integer tensor of shape ``(..., F, 3)`` with vertex indices of
            each triangle. Any integer dtype is accepted; leading (batch)
            dimensions must match those of ``verts``.

    Returns:
        Tensor with the same shape, dtype and device as ``verts``. Vertices
        that are not referenced by any face (or whose accumulated normal is
        degenerate) get a zero vector.
    """
    # scatter_add_/gather require int64 indices living on the same device as
    # the tensor they index into.
    faces_long = faces.to(device=verts.device, dtype=torch.int64)

    # One index tensor per triangle corner, broadcast over the xyz axis so it
    # can drive both the gather (read positions) and the scatter (accumulate).
    index_shape = (*faces_long.shape[:-1], verts.shape[-1])
    corner_indices = [
        faces_long[..., corner].unsqueeze(-1).expand(index_shape)
        for corner in range(3)
    ]
    p0, p1, p2 = (torch.gather(verts, -2, idx) for idx in corner_indices)

    # Unnormalized face normals: their magnitude is proportional to the face
    # area, which is what provides the area weighting below.
    face_normals = torch.cross(p1 - p0, p2 - p0, dim=-1)

    # Accumulate each face normal onto its three corner vertices.
    vertex_normals = torch.zeros_like(verts)
    for idx in corner_indices:
        vertex_normals.scatter_add_(-2, idx, face_normals)

    return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
target_face_count: - v, f, c = simplify_fn_fast(v, f, colors=c, target=target_face_count) + v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) bar.update(1) - v, f = make_double_sided(v, f) + #v, f = make_double_sided(v, f) bar.update(1) return v, f, c