fixed normals

This commit is contained in:
Yousef Rafat 2026-05-18 11:38:23 +03:00
parent 7fd083494e
commit c67ce7df3b

View File

@ -425,10 +425,10 @@ def _gpu_greedy_matching_fast(edges, err, v_alive, max_select):
return sel
def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_edge_length=None):
def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None):
# Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x
# faster than FP64, and QEM only needs the stabilizer for conditioning.
# Always copy=True so we can safely mutate verts/colors in-place.
# Always copy=True so we can safely mutate verts/colors/normals in-place.
verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True)
faces = faces_in.detach().to(device=device, dtype=torch.int64)
colors = (
@ -436,6 +436,12 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if colors_in is not None
else None
)
# ADDED: Initialize normals
normals = (
normals_in.detach().to(device=device, dtype=torch.float32, copy=True)
if normals_in is not None
else None
)
num_verts = verts.shape[0]
num_faces = faces.shape[0]
@ -495,8 +501,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if edges.shape[0] == 0:
break
# Deduplicate edges (each interior edge appears in two adjacent faces).
# Pack (low, high) into a single int64 key and call torch.unique once.
# Deduplicate edges
num_compact = alive_v.numel()
packed = edges[:, 0].long() * num_compact + edges[:, 1].long()
packed = torch.unique(packed)
@ -523,10 +528,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
# Sample edges for processing
n_edges_total = edges_orig.shape[0]
max_edges_to_process = 10_000_000 # 10M edges per iteration
max_edges_to_process = 10_000_000
if n_edges_total > max_edges_to_process:
# Random sample without building a full permutation.
perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device)
edges_orig = edges_orig[perm]
n_edges = max_edges_to_process
@ -545,8 +549,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
optimal = optimal[valid_idx]
err = err[valid_idx]
# Vectorized greedy matching can safely collapse many independent
# edges per iteration, so allow a much larger batch.
faces_to_remove = n_faces - target_faces
max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4))
@ -558,12 +560,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
v_a = edges_orig[sel, 0]
v_b = edges_orig[sel, 1]
# NOTE: original PostProcessMesh built a CSR (face_indices/vert_ptrs) and
# iterated v_a/v_b in Python here to populate pair_edge_idx/pair_face_idx
# arrays that were then discarded (keep_mask was unconditionally all-True).
# That dead block was the dominant cost on Trellis-sized meshes and has
# been removed in this fast variant.
# Apply collapses
verts[v_a] = optimal[sel]
v_alive[v_b] = False
@ -572,6 +568,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if colors is not None:
colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5
if normals is not None:
normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5
merge_map = torch.arange(num_verts, device=device)
merge_map[v_b] = v_a
faces = merge_map[faces]
@ -586,12 +585,10 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
total_collapses += v_a.numel()
iteration += 1
# Log only every 50 iterations to reduce sync overhead
if iteration % 50 == 0 or n_faces < last_faces * 0.9:
logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}")
last_faces = n_faces
# Periodic compaction
if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
faces = faces[f_alive]
f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
@ -603,6 +600,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
# Finalize
final_v = verts[v_alive]
final_c = colors[v_alive] if colors is not None else None
final_n = normals[v_alive] if normals is not None else None
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
@ -617,35 +615,59 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if final_f.numel() > 0:
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
if final_n is not None and final_f.numel() > 0:
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
# calculate the actual normal of the simplified faces
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
# Get the average reference normal for each face
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
ref_face_normals = (n0 + n1 + n2) / 3.0
# Dot product to check if they point in the same direction
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
wrong_way_mask = dot_products < 0
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
return final_v, final_f, final_c
return final_v, final_f, final_c, final_n
def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_length=None):
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
if vertices.ndim == 3:
v_list, f_list, c_list = [], [], []
v_list, f_list, c_list, n_list = [], [], [], []
for i in range(vertices.shape[0]):
c_in = colors[i] if colors is not None else None
v_i, f_i, c_i = simplify_fn_fast(vertices[i], faces[i], c_in, target, max_edge_length)
n_in = normals[i] if normals is not None else None
v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length)
v_list.append(v_i)
f_list.append(f_i)
if c_i is not None:
c_list.append(c_i)
if n_i is not None:
n_list.append(n_i)
c_out = torch.stack(c_list) if len(c_list) > 0 else None
return torch.stack(v_list), torch.stack(f_list), c_out
n_out = torch.stack(n_list) if len(n_list) > 0 else None
return torch.stack(v_list), torch.stack(f_list), c_out, n_out
if faces.shape[0] <= target:
return vertices, faces, colors
return vertices, faces, colors, normals
device = vertices.device
dtype = vertices.dtype
face_dtype = faces.dtype
color_dtype = colors.dtype if colors is not None else None
# ADDED: Normal dtype
normal_dtype = normals.dtype if normals is not None else None
# Pass tensors directly; _qem_simplify_fast handles dtype/device + copy.
out_v, out_f, out_c = _qem_simplify_fast(
vertices, faces, colors, target, device, max_edge_length
out_v, out_f, out_c, out_n = _qem_simplify_fast(
vertices, faces, colors, normals, target, device, max_edge_length
)
final_v = out_v.to(device=device, dtype=dtype)
@ -655,19 +677,31 @@ def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_lengt
if out_c is not None
else None
)
return final_v, final_f, final_c
final_n = (
out_n.to(device=device, dtype=normal_dtype)
if out_n is not None
else None
)
return final_v, final_f, final_c, final_n
def compute_vertex_normals(vertices, faces):
v0 = vertices[faces[:, 0]]
v1 = vertices[faces[:, 1]]
v2 = vertices[faces[:, 2]]
def compute_vertex_normals(verts, faces):
"""Computes area-weighted vertex normals."""
# QUICK FIX: Ensure indices are int64 for scatter_add_
faces_long = faces.to(torch.int64)
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
# calculate unnormalized face normals (magnitude is proportional to area)
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
n = torch.zeros_like(vertices)
for i in range(3):
n.index_add_(0, faces[:, i], face_normals)
# accumulate face normals to vertices
vertex_normals = torch.zeros_like(verts)
vertex_normals.scatter_add_(0, i0.unsqueeze(-1).expand_as(face_normals), face_normals)
vertex_normals.scatter_add_(0, i1.unsqueeze(-1).expand_as(face_normals), face_normals)
vertex_normals.scatter_add_(0, i2.unsqueeze(-1).expand_as(face_normals), face_normals)
return torch.nn.functional.normalize(n, dim=-1)
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
class PostProcessMesh(IO.ComfyNode):
@classmethod
@ -704,11 +738,12 @@ class PostProcessMesh(IO.ComfyNode):
v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter)
bar.update(1)
n = compute_vertex_normals(v, f)
if target_face_count > 0 and f.shape[0] > target_face_count:
v, f, c = simplify_fn_fast(v, f, colors=c, target=target_face_count)
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
bar.update(1)
v, f = make_double_sided(v, f)
#v, f = make_double_sided(v, f)
bar.update(1)
return v, f, c