Update simplify_fn: replace vertex-clustering decimation with robust QEM edge-collapse simplification

This commit is contained in:
Yousef Rafat 2026-05-08 15:13:07 +03:00
parent 2727c4a48c
commit 94adce93ab

View File

@ -911,12 +911,12 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
output["batch_index"] = sample_indices output["batch_index"] = sample_indices
return IO.NodeOutput(output) return IO.NodeOutput(output)
def simplify_fn(vertices, faces, colors=None, target=100000): def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
if vertices.ndim == 3: if vertices.ndim == 3:
v_list, f_list, c_list = [], [], [] v_list, f_list, c_list = [], [], []
for i in range(vertices.shape[0]): for i in range(vertices.shape[0]):
c_in = colors[i] if colors is not None else None c_in = colors[i] if colors is not None else None
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length)
v_list.append(v_i) v_list.append(v_i)
f_list.append(f_i) f_list.append(f_i)
if c_i is not None: if c_i is not None:
@ -929,60 +929,292 @@ def simplify_fn(vertices, faces, colors=None, target=100000):
return vertices, faces, colors return vertices, faces, colors
device = vertices.device device = vertices.device
target_v = max(target / 4.0, 1.0) dtype = vertices.dtype
min_v = vertices.min(dim=0)[0] verts_np = vertices.detach().cpu().numpy().astype(np.float64)
max_v = vertices.max(dim=0)[0] faces_np = faces.detach().cpu().numpy().astype(np.int64)
extent = max_v - min_v colors_np = (
colors.detach().cpu().numpy().astype(np.float64)
if colors is not None
else None
)
volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) out_v, out_f, out_c = _qem_simplify_robust(
cell_size = (volume / target_v) ** (1/3.0) verts_np, faces_np, colors_np, target, device, max_edge_length
)
# Use CPU-side ordered reductions here so repeated runs produce identical final_v = out_v.to(device=device, dtype=dtype)
# simplified meshes instead of relying on GPU scatter-add accumulation order. final_f = out_f.to(device=device, dtype=faces.dtype)
vertices_np = vertices.detach().cpu().numpy() final_c = (
faces_np = faces.detach().cpu().numpy() out_c.to(device=device, dtype=colors.dtype)
colors_np = colors.detach().cpu().numpy() if colors is not None else None if out_c is not None
min_v_np = min_v.detach().cpu().numpy() else None
cell_size_value = float(cell_size.detach().cpu()) )
return final_v, final_f, final_c
quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
num_cells = unique_coords.shape[0] faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
colors = (
torch.from_numpy(colors_np).to(device=device, dtype=torch.float64)
if colors_np is not None
else None
)
new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) num_verts = verts.shape[0]
np.add.at(new_vertices_np, inverse_indices, vertices_np) num_faces = faces.shape[0]
counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)
new_colors = None Q = _build_quadrics_fast(verts, faces)
if colors_np is not None:
new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype)
np.add.at(new_colors_np, inverse_indices, colors_np)
new_colors = new_colors_np / np.clip(counts_np, 1, None)
new_faces = inverse_indices[faces_np] # Mesh scale for relative thresholds
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0]
(new_faces[:, 1] != new_faces[:, 2]) & \ mesh_scale = torch.norm(bbox).item()
(new_faces[:, 2] != new_faces[:, 0])
new_faces = new_faces[valid_mask]
if new_faces.size == 0: # Default max_edge_length: 2x bounding box diagonal (MeshLib-style)
final_vertices_np = new_vertices_np[:0] if max_edge_length is None or max_edge_length <= 0:
final_faces_np = np.empty((0, 3), dtype=np.int64) max_edge_length = mesh_scale * 2.0
final_colors_np = new_colors[:0] if new_colors is not None else None
else:
unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True)
final_vertices_np = new_vertices_np[unique_face_indices]
final_faces_np = inv_face.reshape(-1, 3).astype(np.int64)
final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None
final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) # Stabilizer: regularization to prevent extreme vertex movement
final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) stabilizer = mesh_scale * mesh_scale * 0.001 # MeshLib default ~0.001 * scale^2
final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None
return final_vertices, final_faces, final_colors iteration = 0
while True:
n_faces = int(f_alive.sum().item())
if n_faces <= target_faces:
break
alive_v = torch.nonzero(v_alive, as_tuple=True)[0]
alive_f = torch.nonzero(f_alive, as_tuple=True)[0]
if alive_v.numel() <= 4 or alive_f.numel() == 0:
break
# ---- compact active mesh -------------------------------------------
vmap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
vmap[alive_v] = torch.arange(alive_v.numel(), device=device)
active_faces = faces[alive_f]
remapped = vmap[active_faces]
# ---- extract edges --------------------------------------------------
e0 = remapped[:, [0, 1]]
e1 = remapped[:, [1, 2]]
e2 = remapped[:, [2, 0]]
edges = torch.cat([e0, e1, e2], dim=0)
edges = torch.sort(edges, dim=1)[0]
edges = edges[(edges >= 0).all(dim=1)]
edges = edges[edges[:, 0] != edges[:, 1]]
if edges.shape[0] == 0:
break
edges_orig = alive_v[edges]
# ---- MeshLib-style: only process edges longer than maxEdgeLen ------
pa = verts[edges_orig[:, 0]]
pb = verts[edges_orig[:, 1]]
el = torch.norm(pb - pa, dim=-1)
long_enough = el > max_edge_length * 0.1 # Allow some tolerance
if not long_enough.any():
# If no long edges, lower threshold
long_enough = el > max_edge_length * 0.01
edges_orig = edges_orig[long_enough]
if edges_orig.shape[0] == 0:
break
# subsample so we never chew on >300 k edges
if edges_orig.shape[0] > 300_000:
step = edges_orig.shape[0] // 300_000 + 1
edges_orig = edges_orig[::step]
n_edges = edges_orig.shape[0]
if n_edges == 0:
break
# chunking the qem
Q0 = Q[edges_orig[:, 0]]
Q1 = Q[edges_orig[:, 1]]
Qe = Q0 + Q1
A = Qe[:, :3, :3]
b = -Qe[:, :3, 3]
optimal = torch.zeros((n_edges, 3), dtype=torch.float64, device=device)
SOLVE_CHUNK = 50_000
for i in range(0, n_edges, SOLVE_CHUNK):
sl = slice(i, min(i + SOLVE_CHUNK, n_edges))
A_c = A[sl]
b_c = b[sl].unsqueeze(-1)
# Add stabilizer to prevent extreme solutions
A_reg = A_c + torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * stabilizer
dets = torch.det(A_reg)
good = dets.abs() > 1e-12
if good.any():
try:
sol = torch.linalg.solve(A_reg[good], b_c[good])
good_idx = torch.nonzero(good, as_tuple=True)[0] + i
optimal[good_idx] = sol.squeeze(-1)
except RuntimeError:
good = torch.zeros_like(good)
if (~good).any():
bad_idx = torch.nonzero(~good, as_tuple=True)[0] + i
va = edges_orig[bad_idx, 0]
vb = edges_orig[bad_idx, 1]
optimal[bad_idx] = (verts[va] + verts[vb]) * 0.5
# ---- error = v^T Q v (homogeneous) --------------------------------
v4 = torch.cat([
optimal,
torch.ones((n_edges, 1), device=device, dtype=torch.float64)
], dim=1)
err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))
# geometeric guards
pa = verts[edges_orig[:, 0]]
pb = verts[edges_orig[:, 1]]
el = torch.norm(pb - pa, dim=-1)
# reject near zero edges
length_ok = el > mesh_scale * 1e-5
# moderate wander: stabilizer keeps optimal close, so we can be looser
dist_a = torch.norm(optimal - pa, dim=-1)
dist_b = torch.norm(optimal - pb, dim=-1)
wander_ok = (dist_a <= 4.0 * el) & (dist_b <= 4.0 * el)
nan_ok = ~torch.isnan(optimal).any(dim=-1)
# MAX ERROR CAP: hard limit on quadric error (MeshLib-style)
# Prevents collapses that would remove too much detail
max_error = max_edge_length * max_edge_length
error_ok = err < max_error
valid = length_ok & wander_ok & nan_ok & error_ok
if not valid.any():
break
valid_idx = torch.nonzero(valid, as_tuple=True)[0]
edges_orig = edges_orig[valid_idx]
optimal = optimal[valid_idx]
err = err[valid_idx]
# ---- vectorized greedy independent set ------------------------------
sorted_idx = torch.argsort(err)
used = torch.zeros(num_verts, dtype=torch.bool, device=device)
used[~v_alive] = True
max_collapses = max(2_000, (n_faces - target_faces) // 5)
selected_edges = []
n_selected = 0
GREEDY_CHUNK = 100_000
for start in range(0, sorted_idx.numel(), GREEDY_CHUNK):
chunk = sorted_idx[start:start + GREEDY_CHUNK]
va = edges_orig[chunk, 0]
vb = edges_orig[chunk, 1]
valid_mask = ~used[va] & ~used[vb]
if not valid_mask.any():
continue
sel = chunk[valid_mask]
selected_edges.append(sel)
used[edges_orig[sel, 0]] = True
used[edges_orig[sel, 1]] = True
n_selected += sel.numel()
if n_selected >= max_collapses:
break
if n_selected == 0:
break
sel = torch.cat(selected_edges)
# ---- apply collapses ------------------------------------------------
v_a = edges_orig[sel, 0]
v_b = edges_orig[sel, 1]
verts[v_a] = optimal[sel]
v_alive[v_b] = False
Q[v_a] += Q[v_b]
if colors is not None:
colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5
merge_map = torch.arange(num_verts, device=device)
merge_map[v_b] = v_a
faces = merge_map[faces]
bad = (
(faces[:, 0] == faces[:, 1])
| (faces[:, 1] == faces[:, 2])
| (faces[:, 2] == faces[:, 0])
)
f_alive &= ~bad
iteration += 1
if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
faces = faces[f_alive]
f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
num_faces = faces.shape[0]
final_v = verts[v_alive]
final_c = colors[v_alive] if colors is not None else None
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
final_f = remap[faces[f_alive]]
if final_f.numel() > 0:
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
return final_v, final_f, final_c
def _build_quadrics_fast(verts, faces):
"""GPU quadric build. Fast; non-deterministic on CUDA."""
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
e1 = v1 - v0
e2 = v2 - v0
n = torch.cross(e1, e2, dim=-1)
area = torch.norm(n, dim=-1)
mask = area > 1e-12
n_norm = torch.zeros_like(n)
n_norm[mask] = n[mask] / area[mask].unsqueeze(-1)
d = -(n_norm * v0).sum(dim=-1, keepdim=True)
p = torch.cat([n_norm, d], dim=-1)
K = torch.einsum("fi,fj->fij", p, p)
K = K * area[:, None, None]
V = verts.shape[0]
Q = torch.zeros((V, 4, 4), dtype=torch.float64, device=verts.device)
K_flat = K.reshape(-1, 16)
Q_flat = Q.reshape(V, 16)
for corner in range(3):
idx = faces[:, corner].unsqueeze(1).expand(-1, 16)
Q_flat.scatter_add_(0, idx, K_flat)
return Q_flat.reshape(V, 4, 4)
def fill_holes_fn(vertices, faces, max_perimeter=0.03): def fill_holes_fn(vertices, faces, max_perimeter=0.03):
is_batched = vertices.ndim == 3 is_batched = vertices.ndim == 3