mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
update the simplify function
This commit is contained in:
parent
2727c4a48c
commit
94adce93ab
@ -911,12 +911,12 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
output["batch_index"] = sample_indices
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
    """Simplify a triangle mesh down to roughly ``target`` faces.

    Batched input (``vertices.ndim == 3``) is handled by recursing per item.
    The single-mesh path converts to float64 numpy, runs the robust QEM
    simplifier, and converts the result back to the input device/dtypes.

    Args:
        vertices: ``(V, 3)`` or ``(B, V, 3)`` vertex tensor.
        faces: ``(F, 3)`` or ``(B, F, 3)`` integer face-index tensor.
        colors: optional per-vertex color tensor matching ``vertices``.
        target: desired face count after simplification.
        max_edge_length: optional edge-length threshold forwarded to the QEM
            simplifier; ``None``/non-positive lets it pick a mesh-scale default.

    Returns:
        ``(vertices, faces, colors)`` — simplified mesh; ``colors`` is ``None``
        when no colors were supplied.
    """
    if vertices.ndim == 3:
        v_list, f_list, c_list = [], [], []
        for i in range(vertices.shape[0]):
            c_in = colors[i] if colors is not None else None
            v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length)
            v_list.append(v_i)
            f_list.append(f_i)
            if c_i is not None:
                c_list.append(c_i)
        # NOTE(review): the diff this was recovered from elides the batch
        # re-assembly lines; stacking per-item results is the natural
        # completion given the `return vertices, faces, colors` below —
        # confirm against upstream before relying on the batched path.
        vertices = torch.stack(v_list)
        faces = torch.stack(f_list)
        colors = torch.stack(c_list) if c_list else None
        return vertices, faces, colors

    device = vertices.device
    dtype = vertices.dtype

    # Work in float64 numpy for a deterministic, precision-stable pipeline.
    verts_np = vertices.detach().cpu().numpy().astype(np.float64)
    faces_np = faces.detach().cpu().numpy().astype(np.int64)
    colors_np = (
        colors.detach().cpu().numpy().astype(np.float64)
        if colors is not None
        else None
    )

    out_v, out_f, out_c = _qem_simplify_robust(
        verts_np, faces_np, colors_np, target, device, max_edge_length
    )

    # Restore caller-visible device and dtypes.
    final_v = out_v.to(device=device, dtype=dtype)
    final_f = out_f.to(device=device, dtype=faces.dtype)
    final_c = (
        out_c.to(device=device, dtype=colors.dtype)
        if out_c is not None
        else None
    )
    return final_v, final_f, final_c
|
||||
def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
    """Iterative QEM edge-collapse simplification with MeshLib-style guards.

    Repeatedly collapses a greedy independent set of low-error edges until the
    live face count drops to ``target_faces`` or no safe collapse remains.

    Args:
        verts_np: ``(V, 3)`` float64 vertex array.
        faces_np: ``(F, 3)`` int64 face-index array.
        colors_np: optional ``(V, C)`` float64 per-vertex color array.
        target_faces: stop once the live face count is at or below this.
        device: torch device to run the collapse loop on.
        max_edge_length: edge-length scale for the candidate filter and the
            error cap; ``None`` or non-positive defaults to 2x the bbox diagonal.

    Returns:
        ``(final_vertices, final_faces, final_colors)`` torch tensors on
        ``device`` (colors ``None`` when ``colors_np`` is ``None``).
    """
    verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
    faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
    colors = (
        torch.from_numpy(colors_np).to(device=device, dtype=torch.float64)
        if colors_np is not None
        else None
    )

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
    f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)

    Q = _build_quadrics_fast(verts, faces)

    # Mesh scale for relative thresholds
    bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0]
    mesh_scale = torch.norm(bbox).item()

    # Default max_edge_length: 2x bounding box diagonal (MeshLib-style)
    if max_edge_length is None or max_edge_length <= 0:
        max_edge_length = mesh_scale * 2.0

    # Stabilizer: regularization to prevent extreme vertex movement
    stabilizer = mesh_scale * mesh_scale * 0.001  # MeshLib default ~0.001 * scale^2

    iteration = 0
    while True:
        n_faces = int(f_alive.sum().item())
        if n_faces <= target_faces:
            break

        alive_v = torch.nonzero(v_alive, as_tuple=True)[0]
        alive_f = torch.nonzero(f_alive, as_tuple=True)[0]

        if alive_v.numel() <= 4 or alive_f.numel() == 0:
            break

        # ---- compact active mesh -------------------------------------------
        vmap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
        vmap[alive_v] = torch.arange(alive_v.numel(), device=device)

        active_faces = faces[alive_f]
        remapped = vmap[active_faces]

        # ---- extract edges --------------------------------------------------
        e0 = remapped[:, [0, 1]]
        e1 = remapped[:, [1, 2]]
        e2 = remapped[:, [2, 0]]
        edges = torch.cat([e0, e1, e2], dim=0)
        edges = torch.sort(edges, dim=1)[0]
        edges = edges[(edges >= 0).all(dim=1)]
        edges = edges[edges[:, 0] != edges[:, 1]]

        if edges.shape[0] == 0:
            break

        edges_orig = alive_v[edges]

        # ---- MeshLib-style: only process edges longer than maxEdgeLen ------
        pa = verts[edges_orig[:, 0]]
        pb = verts[edges_orig[:, 1]]
        el = torch.norm(pb - pa, dim=-1)

        long_enough = el > max_edge_length * 0.1  # Allow some tolerance
        if not long_enough.any():
            # If no long edges, lower threshold
            long_enough = el > max_edge_length * 0.01

        edges_orig = edges_orig[long_enough]
        if edges_orig.shape[0] == 0:
            break

        # subsample so we never chew on >300 k edges
        if edges_orig.shape[0] > 300_000:
            step = edges_orig.shape[0] // 300_000 + 1
            edges_orig = edges_orig[::step]

        n_edges = edges_orig.shape[0]
        if n_edges == 0:
            break

        # ---- chunked QEM solve for optimal collapse positions ---------------
        Q0 = Q[edges_orig[:, 0]]
        Q1 = Q[edges_orig[:, 1]]
        Qe = Q0 + Q1

        A = Qe[:, :3, :3]
        b = -Qe[:, :3, 3]

        optimal = torch.zeros((n_edges, 3), dtype=torch.float64, device=device)
        SOLVE_CHUNK = 50_000

        for i in range(0, n_edges, SOLVE_CHUNK):
            sl = slice(i, min(i + SOLVE_CHUNK, n_edges))
            A_c = A[sl]
            b_c = b[sl].unsqueeze(-1)

            # Add stabilizer to prevent extreme solutions
            A_reg = A_c + torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * stabilizer

            dets = torch.det(A_reg)
            good = dets.abs() > 1e-12

            if good.any():
                try:
                    sol = torch.linalg.solve(A_reg[good], b_c[good])
                    good_idx = torch.nonzero(good, as_tuple=True)[0] + i
                    optimal[good_idx] = sol.squeeze(-1)
                except RuntimeError:
                    # Solve failed for the chunk: fall back to midpoints below.
                    good = torch.zeros_like(good)

            if (~good).any():
                bad_idx = torch.nonzero(~good, as_tuple=True)[0] + i
                va = edges_orig[bad_idx, 0]
                vb = edges_orig[bad_idx, 1]
                optimal[bad_idx] = (verts[va] + verts[vb]) * 0.5

        # ---- error = v^T Q v (homogeneous) --------------------------------
        v4 = torch.cat([
            optimal,
            torch.ones((n_edges, 1), device=device, dtype=torch.float64)
        ], dim=1)
        err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))

        # geometric guards
        pa = verts[edges_orig[:, 0]]
        pb = verts[edges_orig[:, 1]]
        el = torch.norm(pb - pa, dim=-1)

        # reject near zero edges
        length_ok = el > mesh_scale * 1e-5

        # moderate wander: stabilizer keeps optimal close, so we can be looser
        dist_a = torch.norm(optimal - pa, dim=-1)
        dist_b = torch.norm(optimal - pb, dim=-1)
        wander_ok = (dist_a <= 4.0 * el) & (dist_b <= 4.0 * el)

        nan_ok = ~torch.isnan(optimal).any(dim=-1)

        # MAX ERROR CAP: hard limit on quadric error (MeshLib-style)
        # Prevents collapses that would remove too much detail
        max_error = max_edge_length * max_edge_length
        error_ok = err < max_error

        valid = length_ok & wander_ok & nan_ok & error_ok
        if not valid.any():
            break

        valid_idx = torch.nonzero(valid, as_tuple=True)[0]
        edges_orig = edges_orig[valid_idx]
        optimal = optimal[valid_idx]
        err = err[valid_idx]

        # ---- vectorized greedy independent set ------------------------------
        sorted_idx = torch.argsort(err)
        used = torch.zeros(num_verts, dtype=torch.bool, device=device)
        used[~v_alive] = True

        max_collapses = max(2_000, (n_faces - target_faces) // 5)
        selected_edges = []
        n_selected = 0
        GREEDY_CHUNK = 100_000

        for start in range(0, sorted_idx.numel(), GREEDY_CHUNK):
            chunk = sorted_idx[start:start + GREEDY_CHUNK]
            va = edges_orig[chunk, 0]
            vb = edges_orig[chunk, 1]

            valid_mask = ~used[va] & ~used[vb]
            if not valid_mask.any():
                continue

            sel = chunk[valid_mask]
            selected_edges.append(sel)

            used[edges_orig[sel, 0]] = True
            used[edges_orig[sel, 1]] = True
            n_selected += sel.numel()

            if n_selected >= max_collapses:
                break

        if n_selected == 0:
            break

        sel = torch.cat(selected_edges)

        # ---- apply collapses ------------------------------------------------
        v_a = edges_orig[sel, 0]
        v_b = edges_orig[sel, 1]

        verts[v_a] = optimal[sel]
        v_alive[v_b] = False
        Q[v_a] += Q[v_b]

        if colors is not None:
            colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5

        merge_map = torch.arange(num_verts, device=device)
        merge_map[v_b] = v_a
        faces = merge_map[faces]

        bad = (
            (faces[:, 0] == faces[:, 1])
            | (faces[:, 1] == faces[:, 2])
            | (faces[:, 2] == faces[:, 0])
        )
        f_alive &= ~bad

        iteration += 1
        # Periodically compact the face list once at least half are dead.
        if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
            faces = faces[f_alive]
            f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
            num_faces = faces.shape[0]

    final_v = verts[v_alive]
    final_c = colors[v_alive] if colors is not None else None

    remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
    remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
    final_f = remap[faces[f_alive]]

    if final_f.numel() > 0:
        final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)

    return final_v, final_f, final_c
|
||||
|
||||
|
||||
def _build_quadrics_fast(verts, faces):
|
||||
"""GPU quadric build. Fast; non-deterministic on CUDA."""
|
||||
v0 = verts[faces[:, 0]]
|
||||
v1 = verts[faces[:, 1]]
|
||||
v2 = verts[faces[:, 2]]
|
||||
|
||||
e1 = v1 - v0
|
||||
e2 = v2 - v0
|
||||
n = torch.cross(e1, e2, dim=-1)
|
||||
area = torch.norm(n, dim=-1)
|
||||
|
||||
mask = area > 1e-12
|
||||
n_norm = torch.zeros_like(n)
|
||||
n_norm[mask] = n[mask] / area[mask].unsqueeze(-1)
|
||||
|
||||
d = -(n_norm * v0).sum(dim=-1, keepdim=True)
|
||||
p = torch.cat([n_norm, d], dim=-1)
|
||||
|
||||
K = torch.einsum("fi,fj->fij", p, p)
|
||||
K = K * area[:, None, None]
|
||||
|
||||
V = verts.shape[0]
|
||||
Q = torch.zeros((V, 4, 4), dtype=torch.float64, device=verts.device)
|
||||
|
||||
K_flat = K.reshape(-1, 16)
|
||||
Q_flat = Q.reshape(V, 16)
|
||||
|
||||
for corner in range(3):
|
||||
idx = faces[:, corner].unsqueeze(1).expand(-1, 16)
|
||||
Q_flat.scatter_add_(0, idx, K_flat)
|
||||
|
||||
return Q_flat.reshape(V, 4, 4)
|
||||
|
||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
is_batched = vertices.ndim == 3
|
||||
|
||||
Loading…
Reference in New Issue
Block a user