# NOTE(review): SOURCE arrived as a whitespace-mangled unified git diff of
# comfy_extras/nodes_trellis2.py.  The definitions below reconstruct the
# post-patch ("+") side as runnable Python.  A few lines fell between the two
# diff hunks and were not visible; those spots are hedged with NOTE(review).


def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
    """Simplify a triangle mesh down to roughly ``target`` faces via QEM edge collapse.

    Args:
        vertices: ``(N, 3)`` float tensor, or ``(B, N, 3)`` for a batch
            (handled by recursing per sample).
        faces: ``(F, 3)`` integer tensor of vertex indices (``(B, F, 3)`` batched).
        colors: optional per-vertex attribute tensor aligned with ``vertices``.
        target: face-count budget; meshes already at or below it pass through.
        max_edge_length: length scale for collapse candidates; ``None`` or a
            non-positive value selects a default derived from the mesh extent.

    Returns:
        ``(vertices, faces, colors)`` on the input's device/dtype; ``colors`` is
        ``None`` when no colors were supplied.
    """
    if vertices.ndim == 3:
        v_list, f_list, c_list = [], [], []
        for i in range(vertices.shape[0]):
            c_in = colors[i] if colors is not None else None
            v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length)
            v_list.append(v_i)
            f_list.append(f_i)
            if c_i is not None:
                c_list.append(c_i)
        # NOTE(review): the statements combining the per-sample results sat in
        # the gap between the two diff hunks and were not visible.  Stacking
        # mirrors the (B, N, 3) input convention, but per-sample simplification
        # can produce ragged sizes — confirm against the original file.
        if c_list:
            return torch.stack(v_list), torch.stack(f_list), torch.stack(c_list)
        return torch.stack(v_list), torch.stack(f_list), None

    # NOTE(review): the context line "return vertices, faces, colors" opened the
    # second diff hunk but its guard condition was outside the visible hunks;
    # "already at or below the face budget" is the natural reading.
    if faces.shape[0] <= target:
        return vertices, faces, colors

    device = vertices.device
    dtype = vertices.dtype

    # Run the collapse loop in float64 numpy-sourced tensors for stability.
    verts_np = vertices.detach().cpu().numpy().astype(np.float64)
    faces_np = faces.detach().cpu().numpy().astype(np.int64)
    colors_np = (
        colors.detach().cpu().numpy().astype(np.float64)
        if colors is not None
        else None
    )

    out_v, out_f, out_c = _qem_simplify_robust(
        verts_np, faces_np, colors_np, target, device, max_edge_length
    )

    # Restore the caller's device and dtypes.
    final_v = out_v.to(device=device, dtype=dtype)
    final_f = out_f.to(device=device, dtype=faces.dtype)
    final_c = (
        out_c.to(device=device, dtype=colors.dtype)
        if out_c is not None
        else None
    )
    return final_v, final_f, final_c


def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
    """Iterative quadric-error-metric edge collapse.

    Repeats: gather candidate edges, solve for each edge's optimal collapse
    position (chunked, stabilized linear solves), reject geometrically unsafe
    collapses, greedily pick an independent set ordered by quadric error, and
    apply the collapses — until the live face count reaches ``target_faces`` or
    no safe collapse remains.

    Args:
        verts_np: ``(N, 3)`` float64 array.
        faces_np: ``(F, 3)`` int64 array.
        colors_np: optional ``(N, C)`` float64 array; collapsed pairs average.
        target_faces: stop once this many faces (or fewer) remain.
        device: torch device the working tensors live on.
        max_edge_length: candidate-edge length scale; defaults from mesh extent.

    Returns:
        ``(verts, faces, colors)`` float64/int64 tensors on ``device``.
    """
    verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
    faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
    colors = (
        torch.from_numpy(colors_np).to(device=device, dtype=torch.float64)
        if colors_np is not None
        else None
    )

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
    f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)

    Q = _build_quadrics_fast(verts, faces)

    # Mesh scale for relative thresholds.
    bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0]
    mesh_scale = torch.norm(bbox).item()

    # Default max_edge_length: 2x bounding box diagonal (MeshLib-style).
    if max_edge_length is None or max_edge_length <= 0:
        max_edge_length = mesh_scale * 2.0

    # Stabilizer: regularization to prevent extreme vertex movement
    # (MeshLib default is roughly 0.001 * scale^2).
    stabilizer = mesh_scale * mesh_scale * 0.001

    iteration = 0
    while True:
        n_faces = int(f_alive.sum().item())
        if n_faces <= target_faces:
            break

        alive_v = torch.nonzero(v_alive, as_tuple=True)[0]
        alive_f = torch.nonzero(f_alive, as_tuple=True)[0]

        if alive_v.numel() <= 4 or alive_f.numel() == 0:
            break

        # ---- compact active mesh -------------------------------------------
        vmap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
        vmap[alive_v] = torch.arange(alive_v.numel(), device=device)

        active_faces = faces[alive_f]
        remapped = vmap[active_faces]

        # ---- extract edges --------------------------------------------------
        e0 = remapped[:, [0, 1]]
        e1 = remapped[:, [1, 2]]
        e2 = remapped[:, [2, 0]]
        edges = torch.cat([e0, e1, e2], dim=0)
        edges = torch.sort(edges, dim=1)[0]
        edges = edges[(edges >= 0).all(dim=1)]
        edges = edges[edges[:, 0] != edges[:, 1]]

        if edges.shape[0] == 0:
            break

        edges_orig = alive_v[edges]

        # ---- MeshLib-style: only process edges longer than maxEdgeLen ------
        pa = verts[edges_orig[:, 0]]
        pb = verts[edges_orig[:, 1]]
        el = torch.norm(pb - pa, dim=-1)

        long_enough = el > max_edge_length * 0.1  # allow some tolerance
        if not long_enough.any():
            # If no long edges remain, lower the threshold.
            long_enough = el > max_edge_length * 0.01

        edges_orig = edges_orig[long_enough]
        if edges_orig.shape[0] == 0:
            break

        # Subsample so we never chew on more than ~300k edges per iteration.
        if edges_orig.shape[0] > 300_000:
            step = edges_orig.shape[0] // 300_000 + 1
            edges_orig = edges_orig[::step]

        n_edges = edges_orig.shape[0]
        if n_edges == 0:
            break

        # ---- chunked QEM solves --------------------------------------------
        Q0 = Q[edges_orig[:, 0]]
        Q1 = Q[edges_orig[:, 1]]
        Qe = Q0 + Q1

        A = Qe[:, :3, :3]
        b = -Qe[:, :3, 3]

        # FIX(review): anchor the stabilizer at the edge midpoint.  The
        # reviewed code solved (A + s*I) x = b, which regularizes x toward the
        # coordinate ORIGIN whenever the quadric leaves directions
        # unconstrained (e.g. flat regions, where b has zero components).
        # Those solutions wander far off the edge, fail the wander guard, and
        # stall simplification.  Solving (A + s*I) x = b + s*mid keeps the
        # solution near the edge, which is what "prevent extreme vertex
        # movement" intends.
        mid = (verts[edges_orig[:, 0]] + verts[edges_orig[:, 1]]) * 0.5

        optimal = torch.zeros((n_edges, 3), dtype=torch.float64, device=device)
        SOLVE_CHUNK = 50_000

        for i in range(0, n_edges, SOLVE_CHUNK):
            sl = slice(i, min(i + SOLVE_CHUNK, n_edges))
            A_c = A[sl]
            b_c = (b[sl] + stabilizer * mid[sl]).unsqueeze(-1)

            # Add stabilizer to keep the system well-conditioned.
            A_reg = A_c + torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * stabilizer

            dets = torch.det(A_reg)
            good = dets.abs() > 1e-12

            if good.any():
                try:
                    sol = torch.linalg.solve(A_reg[good], b_c[good])
                    good_idx = torch.nonzero(good, as_tuple=True)[0] + i
                    optimal[good_idx] = sol.squeeze(-1)
                except RuntimeError:
                    # Singular batch: fall everything back to the midpoint.
                    good = torch.zeros_like(good)

            if (~good).any():
                bad_idx = torch.nonzero(~good, as_tuple=True)[0] + i
                va = edges_orig[bad_idx, 0]
                vb = edges_orig[bad_idx, 1]
                optimal[bad_idx] = (verts[va] + verts[vb]) * 0.5

        # ---- error = v^T Q v (homogeneous) ---------------------------------
        v4 = torch.cat([
            optimal,
            torch.ones((n_edges, 1), device=device, dtype=torch.float64)
        ], dim=1)
        err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))

        # ---- geometric guards ----------------------------------------------
        pa = verts[edges_orig[:, 0]]
        pb = verts[edges_orig[:, 1]]
        el = torch.norm(pb - pa, dim=-1)

        # Reject near-zero edges.
        length_ok = el > mesh_scale * 1e-5

        # Moderate wander: the stabilizer keeps the optimum close to the edge,
        # so this bound can stay loose.
        dist_a = torch.norm(optimal - pa, dim=-1)
        dist_b = torch.norm(optimal - pb, dim=-1)
        wander_ok = (dist_a <= 4.0 * el) & (dist_b <= 4.0 * el)

        nan_ok = ~torch.isnan(optimal).any(dim=-1)

        # MAX ERROR CAP: hard limit on quadric error (MeshLib-style);
        # prevents collapses that would remove too much detail.
        max_error = max_edge_length * max_edge_length
        error_ok = err < max_error

        valid = length_ok & wander_ok & nan_ok & error_ok
        if not valid.any():
            break

        valid_idx = torch.nonzero(valid, as_tuple=True)[0]
        edges_orig = edges_orig[valid_idx]
        optimal = optimal[valid_idx]
        err = err[valid_idx]

        # ---- vectorized greedy independent set -----------------------------
        sorted_idx = torch.argsort(err)
        used = torch.zeros(num_verts, dtype=torch.bool, device=device)
        used[~v_alive] = True

        max_collapses = max(2_000, (n_faces - target_faces) // 5)
        selected_edges = []
        n_selected = 0
        GREEDY_CHUNK = 100_000

        for start in range(0, sorted_idx.numel(), GREEDY_CHUNK):
            chunk = sorted_idx[start:start + GREEDY_CHUNK]
            va = edges_orig[chunk, 0]
            vb = edges_orig[chunk, 1]

            valid_mask = ~used[va] & ~used[vb]
            if not valid_mask.any():
                continue

            sel = chunk[valid_mask]
            selected_edges.append(sel)

            used[edges_orig[sel, 0]] = True
            used[edges_orig[sel, 1]] = True
            n_selected += sel.numel()

            if n_selected >= max_collapses:
                break

        if n_selected == 0:
            break

        sel = torch.cat(selected_edges)

        # ---- apply collapses ------------------------------------------------
        v_a = edges_orig[sel, 0]
        v_b = edges_orig[sel, 1]

        verts[v_a] = optimal[sel]
        v_alive[v_b] = False
        Q[v_a] += Q[v_b]

        if colors is not None:
            colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5

        # Redirect dead endpoints and cull degenerate faces.
        merge_map = torch.arange(num_verts, device=device)
        merge_map[v_b] = v_a
        faces = merge_map[faces]

        bad = (
            (faces[:, 0] == faces[:, 1])
            | (faces[:, 1] == faces[:, 2])
            | (faces[:, 2] == faces[:, 0])
        )
        f_alive &= ~bad

        # Periodically compact the face list once half of it is dead.
        iteration += 1
        if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5:
            faces = faces[f_alive]
            f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device)
            num_faces = faces.shape[0]

    final_v = verts[v_alive]
    final_c = colors[v_alive] if colors is not None else None

    remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
    remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
    final_f = remap[faces[f_alive]]

    # NOTE(review): sorting each face's indices before unique() deduplicates
    # faces but discards the winding order — confirm downstream consumers do
    # not rely on consistent orientation.
    if final_f.numel() > 0:
        final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)

    return final_v, final_f, final_c


def _build_quadrics_fast(verts, faces):
    """Per-vertex 4x4 QEM quadrics, area-weighted, accumulated on the GPU.

    Fast, but ``scatter_add_`` makes the accumulation order — and thus the
    low-order float bits — non-deterministic on CUDA.
    """
    v0 = verts[faces[:, 0]]
    v1 = verts[faces[:, 1]]
    v2 = verts[faces[:, 2]]

    e1 = v1 - v0
    e2 = v2 - v0
    n = torch.cross(e1, e2, dim=-1)
    area = torch.norm(n, dim=-1)

    # Skip degenerate faces (zero area → zero plane, contributes nothing).
    mask = area > 1e-12
    n_norm = torch.zeros_like(n)
    n_norm[mask] = n[mask] / area[mask].unsqueeze(-1)

    # Plane [n, d] with n·v0 + d = 0; K = area * p p^T is the face quadric.
    d = -(n_norm * v0).sum(dim=-1, keepdim=True)
    p = torch.cat([n_norm, d], dim=-1)

    K = torch.einsum("fi,fj->fij", p, p)
    K = K * area[:, None, None]

    V = verts.shape[0]
    Q = torch.zeros((V, 4, 4), dtype=torch.float64, device=verts.device)

    K_flat = K.reshape(-1, 16)
    Q_flat = Q.reshape(V, 16)

    # Each face adds its quadric to all three corner vertices.
    for corner in range(3):
        idx = faces[:, corner].unsqueeze(1).expand(-1, 16)
        Q_flat.scatter_add_(0, idx, K_flat)

    return Q_flat.reshape(V, 4, 4)


# NOTE(review): fill_holes_fn continues beyond the reviewed chunk; only its
# first statement was recoverable here.
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
    is_batched = vertices.ndim == 3