Compare commits

...

3 Commits

Author SHA1 Message Date
Yousef Rafat
0b9d27ed46 update the double fn 2026-05-18 12:00:29 +03:00
Yousef Rafat
c67ce7df3b fixed normals 2026-05-18 11:38:23 +03:00
Yousef Rafat
7fd083494e fix 2026-05-18 09:56:59 +03:00

View File

@ -152,6 +152,9 @@ class PaintMesh(IO.ComfyNode):
out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors)
return IO.NodeOutput(out_mesh)
if coords.shape[-1] == 4:
coords = coords[:, 1:]
out_mesh = paint_mesh_with_voxels(mesh, coords, colors, resolution=resolution)
return IO.NodeOutput(out_mesh)
@ -244,17 +247,53 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
return v, f
def make_double_sided(vertices, faces):
def make_double_sided(vertices, faces, colors=None):
"""
Creates a double-sided mesh by duplicating vertices, faces, and colors.
Duplicating vertices prevents opposite faces from sharing the same index,
which stops the rendering engine from cancelling out the normals to [0,0,0].
"""
is_batched = vertices.ndim == 3
if is_batched:
f_list = []
for i in range(faces.shape[0]):
f_inv = faces[i][:, [0, 2, 1]]
f_list.append(torch.cat([faces[i], f_inv], dim=0))
return vertices, torch.stack(f_list)
v_list, f_list, c_list = [], [], []
for i in range(vertices.shape[0]):
num_v = vertices[i].shape[0]
faces_inv = faces[:, [0, 2, 1]]
return vertices, torch.cat([faces, faces_inv], dim=0)
# Duplicate vertices
v_dup = torch.cat([vertices[i], vertices[i]], dim=0)
# Invert faces AND shift indices to point to the new duplicated vertices
f_inv = faces[i][:, [0, 2, 1]] + num_v
f_dup = torch.cat([faces[i], f_inv], dim=0)
v_list.append(v_dup)
f_list.append(f_dup)
if colors is not None:
c_dup = torch.cat([colors[i], colors[i]], dim=0)
c_list.append(c_dup)
out_v = torch.stack(v_list)
out_f = torch.stack(f_list)
out_c = torch.stack(c_list) if colors is not None else None
return out_v, out_f, out_c
# --- Unbatched (Single Mesh) ---
num_v = vertices.shape[0]
# duplicate vertices
v_dup = torch.cat([vertices, vertices], dim=0)
# invert faces AND shift indices to point to the new duplicated vertices
faces_inv = faces[:, [0, 2, 1]] + num_v
f_dup = torch.cat([faces, faces_inv], dim=0)
# duplicate colors if they exist
if colors is not None:
c_dup = torch.cat([colors, colors], dim=0)
return v_dup, f_dup, c_dup
return v_dup, f_dup
def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0):
if faces.numel() == 0:
@ -422,10 +461,10 @@ def _gpu_greedy_matching_fast(edges, err, v_alive, max_select):
return sel
def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_edge_length=None):
def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None):
# Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x
# faster than FP64, and QEM only needs the stabilizer for conditioning.
# Always copy=True so we can safely mutate verts/colors in-place.
# Always copy=True so we can safely mutate verts/colors/normals in-place.
verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True)
faces = faces_in.detach().to(device=device, dtype=torch.int64)
colors = (
@ -433,6 +472,12 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if colors_in is not None
else None
)
# ADDED: Initialize normals
normals = (
normals_in.detach().to(device=device, dtype=torch.float32, copy=True)
if normals_in is not None
else None
)
num_verts = verts.shape[0]
num_faces = faces.shape[0]
@ -492,8 +537,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if edges.shape[0] == 0:
break
# Deduplicate edges (each interior edge appears in two adjacent faces).
# Pack (low, high) into a single int64 key and call torch.unique once.
# Deduplicate edges
num_compact = alive_v.numel()
packed = edges[:, 0].long() * num_compact + edges[:, 1].long()
packed = torch.unique(packed)
@ -520,10 +564,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
# Sample edges for processing
n_edges_total = edges_orig.shape[0]
max_edges_to_process = 10_000_000 # 10M edges per iteration
max_edges_to_process = 10_000_000
if n_edges_total > max_edges_to_process:
# Random sample without building a full permutation.
perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device)
edges_orig = edges_orig[perm]
n_edges = max_edges_to_process
@ -542,8 +585,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
optimal = optimal[valid_idx]
err = err[valid_idx]
# Vectorized greedy matching can safely collapse many independent
# edges per iteration, so allow a much larger batch.
faces_to_remove = n_faces - target_faces
max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4))
@ -555,12 +596,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
v_a = edges_orig[sel, 0]
v_b = edges_orig[sel, 1]
# NOTE: original PostProcessMesh built a CSR (face_indices/vert_ptrs) and
# iterated v_a/v_b in Python here to populate pair_edge_idx/pair_face_idx
# arrays that were then discarded (keep_mask was unconditionally all-True).
# That dead block was the dominant cost on Trellis-sized meshes and has
# been removed in this fast variant.
# Apply collapses
verts[v_a] = optimal[sel]
v_alive[v_b] = False
@ -569,6 +604,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if colors is not None:
colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5
if normals is not None:
normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5
merge_map = torch.arange(num_verts, device=device)
merge_map[v_b] = v_a
faces = merge_map[faces]
@ -583,12 +621,10 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
total_collapses += v_a.numel()
iteration += 1
# Log only every 50 iterations to reduce sync overhead
if iteration % 50 == 0 or n_faces < last_faces * 0.9:
logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}")
last_faces = n_faces
# Periodic compaction
if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
faces = faces[f_alive]
f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
@ -600,6 +636,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
# Finalize
final_v = verts[v_alive]
final_c = colors[v_alive] if colors is not None else None
final_n = normals[v_alive] if normals is not None else None
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
@ -614,35 +651,59 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
if final_f.numel() > 0:
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
if final_n is not None and final_f.numel() > 0:
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
# calculate the actual normal of the simplified faces
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
# Get the average reference normal for each face
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
ref_face_normals = (n0 + n1 + n2) / 3.0
# Dot product to check if they point in the same direction
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
wrong_way_mask = dot_products < 0
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
return final_v, final_f, final_c
return final_v, final_f, final_c, final_n
def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_length=None):
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
if vertices.ndim == 3:
v_list, f_list, c_list = [], [], []
v_list, f_list, c_list, n_list = [], [], [], []
for i in range(vertices.shape[0]):
c_in = colors[i] if colors is not None else None
v_i, f_i, c_i = simplify_fn_fast(vertices[i], faces[i], c_in, target, max_edge_length)
n_in = normals[i] if normals is not None else None
v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length)
v_list.append(v_i)
f_list.append(f_i)
if c_i is not None:
c_list.append(c_i)
if n_i is not None:
n_list.append(n_i)
c_out = torch.stack(c_list) if len(c_list) > 0 else None
return torch.stack(v_list), torch.stack(f_list), c_out
n_out = torch.stack(n_list) if len(n_list) > 0 else None
return torch.stack(v_list), torch.stack(f_list), c_out, n_out
if faces.shape[0] <= target:
return vertices, faces, colors
return vertices, faces, colors, normals
device = vertices.device
dtype = vertices.dtype
face_dtype = faces.dtype
color_dtype = colors.dtype if colors is not None else None
# ADDED: Normal dtype
normal_dtype = normals.dtype if normals is not None else None
# Pass tensors directly; _qem_simplify_fast handles dtype/device + copy.
out_v, out_f, out_c = _qem_simplify_fast(
vertices, faces, colors, target, device, max_edge_length
out_v, out_f, out_c, out_n = _qem_simplify_fast(
vertices, faces, colors, normals, target, device, max_edge_length
)
final_v = out_v.to(device=device, dtype=dtype)
@ -652,8 +713,31 @@ def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_lengt
if out_c is not None
else None
)
return final_v, final_f, final_c
final_n = (
out_n.to(device=device, dtype=normal_dtype)
if out_n is not None
else None
)
return final_v, final_f, final_c, final_n
def compute_vertex_normals(verts, faces):
"""Computes area-weighted vertex normals."""
# QUICK FIX: Ensure indices are int64 for scatter_add_
faces_long = faces.to(torch.int64)
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
# calculate unnormalized face normals (magnitude is proportional to area)
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
# accumulate face normals to vertices
vertex_normals = torch.zeros_like(verts)
vertex_normals.scatter_add_(0, i0.unsqueeze(-1).expand_as(face_normals), face_normals)
vertex_normals.scatter_add_(0, i1.unsqueeze(-1).expand_as(face_normals), face_normals)
vertex_normals.scatter_add_(0, i2.unsqueeze(-1).expand_as(face_normals), face_normals)
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
class PostProcessMesh(IO.ComfyNode):
@classmethod
@ -690,11 +774,12 @@ class PostProcessMesh(IO.ComfyNode):
v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter)
bar.update(1)
n = compute_vertex_normals(v, f)
if target_face_count > 0 and f.shape[0] > target_face_count:
v, f, c = simplify_fn_fast(v, f, colors=c, target=target_face_count)
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
bar.update(1)
v, f = make_double_sided(v, f)
v, f = make_double_sided(v, f, c)
bar.update(1)
return v, f, c