mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-14 11:59:21 +08:00
Compare commits
3 Commits
739f4ebb20
...
0b9d27ed46
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b9d27ed46 | ||
|
|
c67ce7df3b | ||
|
|
7fd083494e |
@ -152,6 +152,9 @@ class PaintMesh(IO.ComfyNode):
|
||||
out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
if coords.shape[-1] == 4:
|
||||
coords = coords[:, 1:]
|
||||
|
||||
out_mesh = paint_mesh_with_voxels(mesh, coords, colors, resolution=resolution)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
@ -244,17 +247,53 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
return v, f
|
||||
|
||||
|
||||
def make_double_sided(vertices, faces):
|
||||
def make_double_sided(vertices, faces, colors=None):
|
||||
"""
|
||||
Creates a double-sided mesh by duplicating vertices, faces, and colors.
|
||||
Duplicating vertices prevents opposite faces from sharing the same index,
|
||||
which stops the rendering engine from cancelling out the normals to [0,0,0].
|
||||
"""
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
f_list = []
|
||||
for i in range(faces.shape[0]):
|
||||
f_inv = faces[i][:, [0, 2, 1]]
|
||||
f_list.append(torch.cat([faces[i], f_inv], dim=0))
|
||||
return vertices, torch.stack(f_list)
|
||||
v_list, f_list, c_list = [], [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
num_v = vertices[i].shape[0]
|
||||
|
||||
faces_inv = faces[:, [0, 2, 1]]
|
||||
return vertices, torch.cat([faces, faces_inv], dim=0)
|
||||
# Duplicate vertices
|
||||
v_dup = torch.cat([vertices[i], vertices[i]], dim=0)
|
||||
|
||||
# Invert faces AND shift indices to point to the new duplicated vertices
|
||||
f_inv = faces[i][:, [0, 2, 1]] + num_v
|
||||
f_dup = torch.cat([faces[i], f_inv], dim=0)
|
||||
|
||||
v_list.append(v_dup)
|
||||
f_list.append(f_dup)
|
||||
|
||||
if colors is not None:
|
||||
c_dup = torch.cat([colors[i], colors[i]], dim=0)
|
||||
c_list.append(c_dup)
|
||||
|
||||
out_v = torch.stack(v_list)
|
||||
out_f = torch.stack(f_list)
|
||||
out_c = torch.stack(c_list) if colors is not None else None
|
||||
return out_v, out_f, out_c
|
||||
|
||||
# --- Unbatched (Single Mesh) ---
|
||||
num_v = vertices.shape[0]
|
||||
|
||||
# duplicate vertices
|
||||
v_dup = torch.cat([vertices, vertices], dim=0)
|
||||
|
||||
# invert faces AND shift indices to point to the new duplicated vertices
|
||||
faces_inv = faces[:, [0, 2, 1]] + num_v
|
||||
f_dup = torch.cat([faces, faces_inv], dim=0)
|
||||
|
||||
# duplicate colors if they exist
|
||||
if colors is not None:
|
||||
c_dup = torch.cat([colors, colors], dim=0)
|
||||
return v_dup, f_dup, c_dup
|
||||
|
||||
return v_dup, f_dup
|
||||
|
||||
def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0):
|
||||
if faces.numel() == 0:
|
||||
@ -422,10 +461,10 @@ def _gpu_greedy_matching_fast(edges, err, v_alive, max_select):
|
||||
return sel
|
||||
|
||||
|
||||
def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_edge_length=None):
|
||||
def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None):
|
||||
# Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x
|
||||
# faster than FP64, and QEM only needs the stabilizer for conditioning.
|
||||
# Always copy=True so we can safely mutate verts/colors in-place.
|
||||
# Always copy=True so we can safely mutate verts/colors/normals in-place.
|
||||
verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True)
|
||||
faces = faces_in.detach().to(device=device, dtype=torch.int64)
|
||||
colors = (
|
||||
@ -433,6 +472,12 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
if colors_in is not None
|
||||
else None
|
||||
)
|
||||
# ADDED: Initialize normals
|
||||
normals = (
|
||||
normals_in.detach().to(device=device, dtype=torch.float32, copy=True)
|
||||
if normals_in is not None
|
||||
else None
|
||||
)
|
||||
|
||||
num_verts = verts.shape[0]
|
||||
num_faces = faces.shape[0]
|
||||
@ -492,8 +537,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
if edges.shape[0] == 0:
|
||||
break
|
||||
|
||||
# Deduplicate edges (each interior edge appears in two adjacent faces).
|
||||
# Pack (low, high) into a single int64 key and call torch.unique once.
|
||||
# Deduplicate edges
|
||||
num_compact = alive_v.numel()
|
||||
packed = edges[:, 0].long() * num_compact + edges[:, 1].long()
|
||||
packed = torch.unique(packed)
|
||||
@ -520,10 +564,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
|
||||
# Sample edges for processing
|
||||
n_edges_total = edges_orig.shape[0]
|
||||
max_edges_to_process = 10_000_000 # 10M edges per iteration
|
||||
max_edges_to_process = 10_000_000
|
||||
|
||||
if n_edges_total > max_edges_to_process:
|
||||
# Random sample without building a full permutation.
|
||||
perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device)
|
||||
edges_orig = edges_orig[perm]
|
||||
n_edges = max_edges_to_process
|
||||
@ -542,8 +585,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
optimal = optimal[valid_idx]
|
||||
err = err[valid_idx]
|
||||
|
||||
# Vectorized greedy matching can safely collapse many independent
|
||||
# edges per iteration, so allow a much larger batch.
|
||||
faces_to_remove = n_faces - target_faces
|
||||
max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4))
|
||||
|
||||
@ -555,12 +596,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
v_a = edges_orig[sel, 0]
|
||||
v_b = edges_orig[sel, 1]
|
||||
|
||||
# NOTE: original PostProcessMesh built a CSR (face_indices/vert_ptrs) and
|
||||
# iterated v_a/v_b in Python here to populate pair_edge_idx/pair_face_idx
|
||||
# arrays that were then discarded (keep_mask was unconditionally all-True).
|
||||
# That dead block was the dominant cost on Trellis-sized meshes and has
|
||||
# been removed in this fast variant.
|
||||
|
||||
# Apply collapses
|
||||
verts[v_a] = optimal[sel]
|
||||
v_alive[v_b] = False
|
||||
@ -569,6 +604,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
if colors is not None:
|
||||
colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5
|
||||
|
||||
if normals is not None:
|
||||
normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5
|
||||
|
||||
merge_map = torch.arange(num_verts, device=device)
|
||||
merge_map[v_b] = v_a
|
||||
faces = merge_map[faces]
|
||||
@ -583,12 +621,10 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
total_collapses += v_a.numel()
|
||||
iteration += 1
|
||||
|
||||
# Log only every 50 iterations to reduce sync overhead
|
||||
if iteration % 50 == 0 or n_faces < last_faces * 0.9:
|
||||
logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}")
|
||||
last_faces = n_faces
|
||||
|
||||
# Periodic compaction
|
||||
if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
|
||||
faces = faces[f_alive]
|
||||
f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
|
||||
@ -600,6 +636,7 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
# Finalize
|
||||
final_v = verts[v_alive]
|
||||
final_c = colors[v_alive] if colors is not None else None
|
||||
final_n = normals[v_alive] if normals is not None else None
|
||||
|
||||
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
|
||||
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
|
||||
@ -614,35 +651,59 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, target_faces, device, max_
|
||||
if final_f.numel() > 0:
|
||||
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
|
||||
|
||||
if final_n is not None and final_f.numel() > 0:
|
||||
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
|
||||
|
||||
# calculate the actual normal of the simplified faces
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
|
||||
# Get the average reference normal for each face
|
||||
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
|
||||
ref_face_normals = (n0 + n1 + n2) / 3.0
|
||||
|
||||
# Dot product to check if they point in the same direction
|
||||
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
|
||||
|
||||
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
|
||||
wrong_way_mask = dot_products < 0
|
||||
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
|
||||
|
||||
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
|
||||
|
||||
return final_v, final_f, final_c
|
||||
return final_v, final_f, final_c, final_n
|
||||
|
||||
|
||||
def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_length=None):
|
||||
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
|
||||
if vertices.ndim == 3:
|
||||
v_list, f_list, c_list = [], [], []
|
||||
v_list, f_list, c_list, n_list = [], [], [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
c_in = colors[i] if colors is not None else None
|
||||
v_i, f_i, c_i = simplify_fn_fast(vertices[i], faces[i], c_in, target, max_edge_length)
|
||||
n_in = normals[i] if normals is not None else None
|
||||
v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
if c_i is not None:
|
||||
c_list.append(c_i)
|
||||
if n_i is not None:
|
||||
n_list.append(n_i)
|
||||
|
||||
c_out = torch.stack(c_list) if len(c_list) > 0 else None
|
||||
return torch.stack(v_list), torch.stack(f_list), c_out
|
||||
n_out = torch.stack(n_list) if len(n_list) > 0 else None
|
||||
return torch.stack(v_list), torch.stack(f_list), c_out, n_out
|
||||
|
||||
if faces.shape[0] <= target:
|
||||
return vertices, faces, colors
|
||||
return vertices, faces, colors, normals
|
||||
|
||||
device = vertices.device
|
||||
dtype = vertices.dtype
|
||||
face_dtype = faces.dtype
|
||||
color_dtype = colors.dtype if colors is not None else None
|
||||
# ADDED: Normal dtype
|
||||
normal_dtype = normals.dtype if normals is not None else None
|
||||
|
||||
# Pass tensors directly; _qem_simplify_fast handles dtype/device + copy.
|
||||
out_v, out_f, out_c = _qem_simplify_fast(
|
||||
vertices, faces, colors, target, device, max_edge_length
|
||||
out_v, out_f, out_c, out_n = _qem_simplify_fast(
|
||||
vertices, faces, colors, normals, target, device, max_edge_length
|
||||
)
|
||||
|
||||
final_v = out_v.to(device=device, dtype=dtype)
|
||||
@ -652,8 +713,31 @@ def simplify_fn_fast(vertices, faces, colors=None, target=100000, max_edge_lengt
|
||||
if out_c is not None
|
||||
else None
|
||||
)
|
||||
return final_v, final_f, final_c
|
||||
final_n = (
|
||||
out_n.to(device=device, dtype=normal_dtype)
|
||||
if out_n is not None
|
||||
else None
|
||||
)
|
||||
return final_v, final_f, final_c, final_n
|
||||
|
||||
def compute_vertex_normals(verts, faces):
|
||||
"""Computes area-weighted vertex normals."""
|
||||
# QUICK FIX: Ensure indices are int64 for scatter_add_
|
||||
faces_long = faces.to(torch.int64)
|
||||
|
||||
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
|
||||
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
|
||||
|
||||
# calculate unnormalized face normals (magnitude is proportional to area)
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
|
||||
# accumulate face normals to vertices
|
||||
vertex_normals = torch.zeros_like(verts)
|
||||
vertex_normals.scatter_add_(0, i0.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
vertex_normals.scatter_add_(0, i1.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
vertex_normals.scatter_add_(0, i2.unsqueeze(-1).expand_as(face_normals), face_normals)
|
||||
|
||||
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
|
||||
|
||||
class PostProcessMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -690,11 +774,12 @@ class PostProcessMesh(IO.ComfyNode):
|
||||
v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter)
|
||||
bar.update(1)
|
||||
|
||||
n = compute_vertex_normals(v, f)
|
||||
if target_face_count > 0 and f.shape[0] > target_face_count:
|
||||
v, f, c = simplify_fn_fast(v, f, colors=c, target=target_face_count)
|
||||
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
|
||||
bar.update(1)
|
||||
|
||||
v, f = make_double_sided(v, f)
|
||||
v, f = make_double_sided(v, f, c)
|
||||
bar.update(1)
|
||||
return v, f, c
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user