diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index 825425128..22e3df3ac 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -38,7 +38,8 @@ class MESH: metallic_roughness: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, face_counts: torch.Tensor | None = None, - unlit: bool = False): + unlit: bool = False, + normals: torch.Tensor | None = None): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -46,6 +47,9 @@ class MESH: self.faces = faces # faces: (B, M, 3) self.uvs = uvs # uvs: (B, N, 2) self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4) + # Optional per-vertex normals: (B, N, 3). When None, SaveGLB computes smooth + # area-weighted normals so viewers don't fall back to flat (per-face) shading. + self.normals = normals self.texture = texture # texture (baseColor): (B, H, W, 3) # glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic self.metallic_roughness = metallic_roughness diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 5e4955695..1ad98c589 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -1,9 +1,12 @@ import torch import numpy as np +import math from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import copy import comfy.utils +import comfy.model_management +from comfy_extras.qem_decimate.qem_core import simplify as qem_decimate_simplify, QEMConfig import logging import scipy @@ -337,10 +340,134 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): gl.glDeleteProgram(prog) -def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution): - """For every masked texel, query the nearest voxel and return ALL its - attribute channels. Returns (H, W, C) float32 in [0, 1] where C is the - voxel feature width (3 for plain color, 6 for full PBR).""" +def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): + """Normalized trilinear interpolation of a SPARSE voxel attribute field. + + The official o_voxel.to_glb trilinear-samples a *dense* attribute volume; here + the field is sparse (only surface voxels carry values), so a plain trilinear + would bleed zeros from empty cells. Instead we accumulate, per query, only the + occupied corners among the 8 surrounding voxels and renormalize by their + weights — i.e. trilinear over the occupied subset. Voxel centres sit at integer + coords c with world position c/resolution - 0.5. + + Returns (vals [K, C] float64, ok [K] bool). `ok` is False where none of the 8 + corners is occupied (caller falls back to nearest there).""" + R = int(resolution) + origin = -0.5 + voxel_size = 1.0 / R + # Cell-CENTER convention: voxel coord c sits at world origin + (c+0.5)*voxel_size, + # matching the official flex_gemm grid_sample_3d (its trilinear weight centers + # integer coord c at query c+0.5). The `- 0.5` puts integer gc on voxel centres + # so the 8 trilinear corners bracket the query correctly. Omitting it samples + # half a voxel toward the corner — colour bleed at boundaries / thin features. + gc = (positions.astype(np.float64) - origin) / voxel_size - 0.5 # continuous voxel-index coords + base = np.floor(gc).astype(np.int64) # [K,3] lower corner + frac = gc - base # [K,3] in [0,1) + + vc = voxel_coords_np.astype(np.int64) + occ_keys = (vc[:, 0] * R + vc[:, 1]) * R + vc[:, 2] # linear key per occupied voxel + order = np.argsort(occ_keys) + occ_sorted = occ_keys[order] + + K = positions.shape[0] + C = color_np.shape[1] + acc = np.zeros((K, C), dtype=np.float64) + wsum = np.zeros((K, 1), dtype=np.float64) + for dx in (0, 1): + wx = frac[:, 0] if dx else 1.0 - frac[:, 0] + for dy in (0, 1): + wy = frac[:, 1] if dy else 1.0 - frac[:, 1] + for dz in (0, 1): + wz = frac[:, 2] if dz else 1.0 - frac[:, 2] + cx = base[:, 0] + dx + cy = base[:, 1] + dy + cz = base[:, 2] + dz + inb = (cx >= 0) & (cx < R) & (cy >= 0) & (cy < R) & (cz >= 0) & (cz < R) + key = (cx * R + cy) * R + cz + ins = np.clip(np.searchsorted(occ_sorted, key), 0, len(occ_sorted) - 1) + matched = inb & (occ_sorted[ins] == key) + idx = order[ins] # original voxel index (garbage where !matched) + w = np.where(matched, wx * wy * wz, 0.0)[:, None] + acc += w * color_np[idx] # w=0 cancels the garbage rows + wsum += w + ok = wsum[:, 0] > 1e-8 + vals = np.zeros((K, C), dtype=np.float64) + vals[ok] = acc[ok] / wsum[ok] + return vals, ok + + +def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): + """GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a + regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a + query is round((p+0.5)*R) plus a 3³ neighbour check — an O(1)-per-query grid + lookup (sorted-key binary search), ~10-30× faster than a cKDTree over millions + of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found` + is False for the rare query whose nearest occupied voxel is >1 cell away (the + caller falls back to a cKDTree on just those).""" + dev = "cuda" if torch.cuda.is_available() else "cpu" + R = int(resolution) + P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float() + VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long() + col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float() + M, K, C = VC.shape[0], P.shape[0], col.shape[1] + key = (VC[:, 0] * R + VC[:, 1]) * R + VC[:, 2] + skey, order = key.sort() + + def _search(idx, radius): + """Nearest occupied voxel within ±radius cells, for query subset P[idx].""" + Ps = P[idx] + # Cell-CENTER convention: voxel c is centred at (c+0.5)/R - 0.5 in world, + # so the coord nearest a point is round((p+0.5)*R - 0.5) (matches the + # official grid_sample_3d). The distance test below uses the same centre. + rc = ((Ps + 0.5) * R - 0.5).round().long() + n = idx.shape[0] + bd = torch.full((n,), 1e30, device=dev) + bi = torch.zeros(n, dtype=torch.long, device=dev) + fnd = torch.zeros(n, dtype=torch.bool, device=dev) + rng = range(-radius, radius + 1) + for dx in rng: + for dy in rng: + for dz in rng: + cc = rc + torch.tensor([dx, dy, dz], device=dev) + inb = ((cc >= 0) & (cc < R)).all(1) + qk = (cc[:, 0] * R + cc[:, 1]) * R + cc[:, 2] + ins = torch.searchsorted(skey, qk).clamp(max=M - 1) + match = inb & (skey[ins] == qk) + dd = (((cc.float() + 0.5) / R - 0.5 - Ps) ** 2).sum(1) + upd = match & (dd < bd) + bd = torch.where(upd, dd, bd) + bi = torch.where(upd, order[ins], bi) + fnd |= match + return bi, fnd + + all_idx = torch.arange(K, device=dev) + best_i = torch.zeros(K, dtype=torch.long, device=dev) + found = torch.zeros(K, dtype=torch.bool, device=dev) + # Pass 1: radius 1 (3³) over everything — catches ~all surface texels cheaply. + bi1, fnd1 = _search(all_idx, 1) + best_i[all_idx] = bi1 + found[all_idx] = fnd1 + # Pass 2: wider radius on ONLY the few misses (avoids ever building a cKDTree + # over millions of voxels just for a handful of >1-cell-away points). + miss = torch.nonzero(~found, as_tuple=True)[0] + if miss.numel() > 0: + bi2, fnd2 = _search(miss, 4) + best_i[miss] = bi2 + found[miss] = fnd2 + vals = col[best_i] + return vals.cpu().numpy(), found.cpu().numpy() + + +def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution, + mode="trilinear"): + """For every masked texel, sample the voxel field and return ALL its attribute + channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature + width (3 for plain color, 6 for full PBR). + + mode="trilinear" — normalized trilinear over occupied voxels (the default; matches + the official o_voxel.to_glb path), with nearest fallback for texels whose 8 + surrounding voxels are all empty. This is the only mode the nodes expose now. + mode="nearest" — nearest-voxel; kept as an internal/dev lever (blocky).""" H, W, _ = position_map.shape color_np = voxel_colors.detach().cpu().numpy().astype(np.float32) C = color_np.shape[-1] @@ -350,19 +477,324 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors origin = np.array([-0.5, -0.5, -0.5], dtype=np.float32) voxel_size = 1.0 / float(resolution) - voxel_pos = voxel_coords.detach().cpu().numpy().astype(np.float32) * voxel_size + origin - - tree = scipy.spatial.cKDTree(voxel_pos) + coords_np = voxel_coords.detach().cpu().numpy() + # Cell-CENTER convention (+0.5 voxel), matching the official grid_sample_3d and + # the _trilinear/_nearest paths above; this cKDTree only serves the rare + # >cell-radius nearest fallback but must use the same world mapping. + voxel_pos = (coords_np.astype(np.float32) + 0.5) * voxel_size + origin valid_positions = position_map[mask] - _, nearest_idx = tree.query(valid_positions, k=1, workers=-1) - out[mask] = np.clip(color_np[nearest_idx], 0.0, 1.0) + + def _nearest(query): + # GPU grid lookup; cKDTree only for the rare >1-cell miss. + vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution) + if not found.all(): + tree = scipy.spatial.cKDTree(voxel_pos) + _, nearest_idx = tree.query(query[~found], k=1, workers=-1) + vals[~found] = color_np[nearest_idx] + return vals + + if mode == "trilinear": + vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution) + if not ok.all(): + # Texels with no occupied neighbour fall back to nearest. + vals[~ok] = _nearest(valid_positions[~ok]) + out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32) + else: + out[mask] = np.clip(_nearest(valid_positions), 0.0, 1.0) return out +def _closest_point_on_triangles(p, a, b, c): + """Vectorized exact closest point on triangles (Ericson, Real-Time Collision + Detection §5.1.5). p/a/b/c are [..., 3]; returns [..., 3]. Handles all + vertex/edge/face Voronoi regions, applied highest-priority-last via where.""" + ab = b - a + ac = c - a + ap = p - a + d1 = (ab * ap).sum(-1) + d2 = (ac * ap).sum(-1) + bp = p - b + d3 = (ab * bp).sum(-1) + d4 = (ac * bp).sum(-1) + cp = p - c + d5 = (ab * cp).sum(-1) + d6 = (ac * cp).sum(-1) + va = d3 * d6 - d5 * d4 + vb = d5 * d2 - d1 * d6 + vc = d1 * d4 - d3 * d2 + + def u(x): # broadcast a scalar-per-element weight to [...,1] + return x.unsqueeze(-1) + + # face region (default) + denom = 1.0 / (va + vb + vc).clamp_min(1e-20) + v = vb * denom + w = vc * denom + res = a + ab * u(v) + ac * u(w) + # edge BC + den_bc = (d4 - d3) + (d5 - d6) + w_bc = (d4 - d3) / den_bc.clamp_min(1e-20) + res = torch.where(u((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)), + b + (c - b) * u(w_bc), res) + # edge AC + w_ac = d2 / (d2 - d6).clamp_min(1e-20) + res = torch.where(u((vb <= 0) & (d2 >= 0) & (d6 <= 0)), a + ac * u(w_ac), res) + # vertex C + res = torch.where(u((d6 >= 0) & (d5 <= d6)), c, res) + # edge AB + v_ab = d1 / (d1 - d3).clamp_min(1e-20) + res = torch.where(u((vc <= 0) & (d1 >= 0) & (d3 <= 0)), a + ab * u(v_ab), res) + # vertex B + res = torch.where(u((d3 >= 0) & (d4 <= d3)), b, res) + # vertex A + res = torch.where(u((d1 <= 0) & (d2 <= 0)), a, res) + return res + + +def _msb_int64(x): + """floor(log2(x)) elementwise for int64 x >= 1 (bit-search, no float).""" + r = torch.zeros_like(x); xx = x.clone() + for s in (32, 16, 8, 4, 2, 1): + sh = xx >> s; m = sh > 0 + r = torch.where(m, r + s, r); xx = torch.where(m, sh, xx) + return r + + +def _morton_expand21(v): + """Spread the low 21 bits of v across every 3rd bit (for a 63-bit Morton code).""" + v = v & 0x1fffff + v = (v | (v << 32)) & 0x1f00000000ffff + v = (v | (v << 16)) & 0x1f0000ff0000ff + v = (v | (v << 8)) & 0x100f00f00f00f00f + v = (v | (v << 4)) & 0x10c30c30c30c30c3 + v = (v | (v << 2)) & 0x1249249249249249 + return v + + +def _build_triangle_bvh(tri): + """Linear BVH (Karras 2012) over triangle AABBs, pure torch, NO external deps. + + 21-bit-per-axis Morton sort of triangle centroids -> parallel radix-tree + construction -> bottom-up node AABBs. Internal nodes are indexed 0..T-2, leaves + are encoded as LEAF+i (i in 0..T-1) where leaf i holds triangle `order[i]`. + Returns a dict with node AABBs (nmin,nmax over 2T entries), child links + (left,right), the leaf->triangle map `order`, LEAF offset and T. + + A real tree (not a uniform grid) is what makes the closest-point query prune + empty space and dense clusters, so it stays fast on huge, non-uniform references + where the grid's ring search blows up — i.e. the cuMesh BVH approach, in torch.""" + dev = tri.device; T = tri.shape[0] + amin = tri.amin(1); amax = tri.amax(1); cent = (amin + amax) * 0.5 + lo = cent.amin(0); hi = cent.amax(0); span = (hi - lo).clamp_min(1e-12) + q = (((cent - lo) / span).clamp(0, 1) * float((1 << 21) - 1)).long() + morton = (_morton_expand21(q[:, 0]) << 2 | _morton_expand21(q[:, 1]) << 1 | _morton_expand21(q[:, 2])).long() + order = torch.argsort(morton); msort = morton[order] + + # delta(i,j): length of the common prefix of the (morton, index) keys of leaves + # i and j (index breaks ties so duplicate Morton codes still split); -1 if OOB. + def delta(i, j): + ok = (j >= 0) & (j < T); jj = j.clamp(0, T - 1) + x = msort[i] ^ msort[jj]; same = x == 0 + cp = torch.where(same, torch.full_like(x, 63), 62 - _msb_int64(x.clamp_min(1))) + xi = i ^ jj + cpi = torch.where(xi == 0, torch.full_like(x, 32), 31 - _msb_int64(xi.clamp_min(1))) + return torch.where(ok, cp + torch.where(same, cpi, torch.zeros_like(cp)), torch.full_like(x, -1)) + + I = torch.arange(T - 1, device=dev) + dplus = delta(I, I + 1); dminus = delta(I, I - 1) + direction = torch.where(dplus >= dminus, torch.ones_like(I), -torch.ones_like(I)) + dmin = torch.minimum(dplus, dminus) + # range length: exponential probe then binary search + lmax = torch.full_like(I, 2) + while True: + cond = delta(I, I + lmax * direction) > dmin + if not bool(cond.any()): + break + lmax = torch.where(cond, lmax * 2, lmax) + if int(lmax.max()) > 2 * T: + break + l = torch.zeros_like(I); t = lmax.clone() + while True: + t = t // 2 + if int(t.max()) == 0: + break + cond = delta(I, I + (l + t) * direction) > dmin + l = torch.where(cond, l + t, l) + j = I + l * direction + first = torch.minimum(I, j); last = torch.maximum(I, j) + # split position: binary search on delta within [first, last] + dnode = delta(first, last) + s = torch.zeros_like(I); div = torch.full_like(I, 2); rng = last - first + while True: + step = (rng + div - 1) // div + cond = delta(first, (first + s + step).clamp(max=T - 1)) > dnode + s = torch.where(cond, s + step, s) + if int(step.max()) <= 1: + cond1 = delta(first, (first + s + 1).clamp(max=T - 1)) > dnode + s = torch.where(cond1, s + 1, s) + break + div = div * 2 + gamma = first + s; LEAF = T + left = torch.where(gamma == first, LEAF + gamma, gamma) + right = torch.where(gamma + 1 == last, LEAF + gamma + 1, gamma + 1) + + # node AABBs: leaves seeded, internal unioned bottom-up over a few passes (a + # balanced tree settles in ~log2(T) passes; the cap is a safety bound). + nmin = torch.empty((2 * T, 3), device=dev); nmax = torch.empty((2 * T, 3), device=dev) + nmin[LEAF:] = amin[order]; nmax[LEAF:] = amax[order] + setm = torch.zeros(2 * T, dtype=torch.bool, device=dev); setm[LEAF:] = True + for _ in range(128): + need = ~setm[:T - 1] + if not bool(need.any()): + break + idx = torch.nonzero(need, as_tuple=True)[0] + ii = idx[setm[left[idx]] & setm[right[idx]]] + if ii.numel() == 0: + break + nmin[ii] = torch.minimum(nmin[left[ii]], nmin[right[ii]]) + nmax[ii] = torch.maximum(nmax[left[ii]], nmax[right[ii]]) + setm[ii] = True + return dict(LEAF=LEAF, left=left, right=right, nmin=nmin, nmax=nmax, order=order, T=T) + + +def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): + """Exact closest surface point per query, via per-query stack traversal of the + triangle BVH (nearest-child-first for tight pruning), pure torch. Returns [N,3]. + + Each while-iteration advances all still-active queries by one node; the active + set shrinks fast, so even a few thousand iterations are cheap big GPU kernels. + `max_stack` bounds the per-query stack (= tree height); overflow is counted and + warned (a handful of texels could be slightly off) rather than silently wrong.""" + dev = Q.device; N = Q.shape[0] + LEAF = bvh['LEAF']; nmin = bvh['nmin']; nmax = bvh['nmax'] + left = bvh['left']; right = bvh['right']; order = bvh['order'] + stack = torch.full((N, max_stack), -1, dtype=torch.long, device=dev) + sp = torch.ones(N, dtype=torch.long, device=dev); stack[:, 0] = 0 + best = torch.full((N,), 1e30, device=dev); bestp = Q.clone() + active = torch.arange(N, device=dev); overflow = 0 + + def aabb_d2(node, q): + d = (nmin[node] - q).clamp_min(0) + (q - nmax[node]).clamp_min(0) + return (d * d).sum(-1) + + while active.numel() > 0: + a = active; qa = Q[a] + node = stack[a, sp[a] - 1]; sp[a] = sp[a] - 1 + within = aabb_d2(node, qa) < best[a] + isleaf = node >= LEAF + lv = within & isleaf + if bool(lv.any()): + ga = a[lv]; tt = tri[order[node[lv] - LEAF]] + cp = _closest_point_on_triangles(qa[lv], tt[:, 0], tt[:, 1], tt[:, 2]) + d2 = ((cp - qa[lv]) ** 2).sum(-1) + upd = d2 < best[ga]; gu = ga[upd]; best[gu] = d2[upd]; bestp[gu] = cp[upd] + iv = within & ~isleaf + if bool(iv.any()): + gi = a[iv]; qi = qa[iv]; lc = left[node[iv]]; rc = right[node[iv]] + dl = aabb_d2(lc, qi); dr = aabb_d2(rc, qi) + near = torch.where(dl <= dr, lc, rc); far = torch.where(dl <= dr, rc, lc) + s0 = sp[gi] + stack[gi, s0.clamp(max=max_stack - 1)] = far; sp[gi] = (s0 + 1).clamp(max=max_stack) + s1 = sp[gi]; overflow += int((s1 >= max_stack).sum()) + stack[gi, s1.clamp(max=max_stack - 1)] = near; sp[gi] = (s1 + 1).clamp(max=max_stack) + active = a[sp[a] > 0] + if overflow: + logging.warning(f"[back-project] BVH stack overflow on {overflow} pushes " + f"(max_stack={max_stack}); a few texels may be slightly off — " + f"raise max_stack if this is large.") + return bestp + + +def _back_project_positions(position_map, mask, ref_v, ref_f): + """Snap each covered texel's interpolated position onto the reference mesh's true + surface, so the voxel field is sampled at full surface detail instead of along + flat triangle chords (the cause of faceted/pixelized bakes on coarse meshes). + Mirrors o_voxel.to_glb step 7c but with NO cumesh/scipy/trimesh dependency: a + pure-torch linear BVH (`_build_triangle_bvh`) + exact closest-point traversal, + the same approach as cuMesh's cuBVH. Returns a new position_map with the covered + texels replaced.""" + valid = np.ascontiguousarray(position_map[mask].astype(np.float32)) + if valid.shape[0] == 0: + return position_map + + import time as _time + dev = "cuda" if torch.cuda.is_available() else "cpu" + rv = ref_v.detach().to(dev).float() + rf = ref_f.detach().to(dev).long() + tri = rv[rf] + Q = torch.from_numpy(valid).to(dev) + + _t = _time.perf_counter() + bvh = _build_triangle_bvh(tri) + _tb = _time.perf_counter() + bp = _closest_points_on_mesh_bvh(Q, tri, bvh) + logging.info(f"[back-project] BVH build {_tb - _t:.1f}s + traverse " + f"{_time.perf_counter() - _tb:.1f}s ({rf.shape[0]} ref tris, " + f"{valid.shape[0]} texels)") + + out = position_map.copy() + out[mask] = bp.detach().cpu().numpy().astype(position_map.dtype) + return out + + +def _jfa_fill_gpu(img01, mask): + """Fill every uncovered texel with its nearest covered texel's value via GPU + Jump Flooding (O(log n) passes) — a fast nearest-fill replacement for + cv2.inpaint on UV seam/gutter filling. img01 [H,W,C] float, mask [H,W] bool + (True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map.""" + if not mask.any(): + return img01 + dev = "cuda" + it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float() + mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev) + H, W = mm.shape + yy, xx = torch.meshgrid(torch.arange(H, device=dev), torch.arange(W, device=dev), indexing="ij") + by = torch.where(mm, yy, torch.full_like(yy, -1)) + bx = torch.where(mm, xx, torch.full_like(xx, -1)) + INF = torch.full_like(yy, 1 << 30) + step = 1 << ((max(H, W) - 1).bit_length() - 1) + while step >= 1: + for dy in (-step, 0, step): + for dx in (-step, 0, step): + if dy == 0 and dx == 0: + continue + ny = (yy + dy).clamp(0, H - 1) + nx = (xx + dx).clamp(0, W - 1) + cby = by[ny, nx] + cbx = bx[ny, nx] + valid = cby >= 0 + dc = torch.where(valid, (yy - cby) ** 2 + (xx - cbx) ** 2, INF) + db = torch.where(by >= 0, (yy - by) ** 2 + (xx - bx) ** 2, INF) + take = valid & (dc < db) + by = torch.where(take, cby, by) + bx = torch.where(take, cbx, bx) + step //= 2 + filled = it[by.clamp(0).long(), bx.clamp(0).long()] + return filled.cpu().numpy() + + +def _seam_fill(img01, mask, inpaint_radius): + """Fill the UV-gutter texels around covered charts so seam sampling doesn't + pull in black. GPU Jump Flooding (nearest fill) when CUDA is available, else + cv2 Telea inpaint. `inpaint_radius<=0` disables; the radius only affects the + cv2 fallback (JFA fills every uncovered texel by nearest).""" + if inpaint_radius <= 0: + return img01 + if torch.cuda.is_available(): + return _jfa_fill_gpu(img01, mask) + import cv2 + u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8) + u8 = cv2.inpaint(u8, ((~mask).astype(np.uint8)) * 255, int(inpaint_radius), cv2.INPAINT_TELEA) + if u8.ndim == 2: + u8 = u8[..., None] + return u8.astype(np.float32) / 255.0 + + def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, resolution, texture_size, inpaint_radius=3, fast_unwrap=True, existing_uvs=None, - normalize_uvs=True): + normalize_uvs=True, sample_mode="trilinear", + reference=None, pbar=None): """Bake a baseColor (+ optional metallicRoughness) texture for `vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each texel from the provided sparse colored voxel volume. @@ -375,9 +807,30 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, `fast_unwrap=True` configures xatlas with permissive chart options so it finishes in a reasonable time on large meshes — at the cost of less even - UV distribution. Set False to use xatlas defaults (slow on >100k faces).""" + UV distribution. Set False to use xatlas defaults (slow on >100k faces). + + Progress: drives a local tqdm over its 5 stages (unwrap → rasterize → + back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is + passed, ticks it once per stage too — so callers should size it as 5 per + bake.""" import time + # 5-stage progress: tqdm (console) + optional comfy ProgressBar (UI). _tick is + # called exactly once at each stage boundary, including no-op stages (e.g. no + # back-projection), so the comfy pbar stays aligned at 5 ticks per bake. + try: + from tqdm import tqdm as _tqdm + _tq = _tqdm(total=5, desc="Bake texture", leave=False) + except Exception: + _tq = None + + def _tick(name): + if _tq is not None: + _tq.set_postfix_str(name) + _tq.update(1) + if pbar is not None: + pbar.update(1) + v_np = vertices.detach().cpu().numpy().astype(np.float32) f_np = faces.detach().cpu().numpy().astype(np.uint32) fcount = int(f_np.shape[0]) @@ -485,17 +938,30 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, new_faces = indices.astype(np.uint32) new_uvs = uvs.astype(np.float32) + _tick("unwrap") + t1 = time.perf_counter() position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size) logging.info(f"[BakeTextureFromVoxel] GL rasterize {texture_size}² in {time.perf_counter() - t1:.1f}s " f"({int(mask.sum())}/{mask.size} texels covered)") + _tick("rasterize") + + if reference is not None: + # Back-project texel positions onto the original dense surface before + # sampling — the o_voxel.to_glb step that makes the bake smooth on coarse + # meshes (instead of sampling along flat triangle chords). + tb = time.perf_counter() + position_map = _back_project_positions(position_map, mask, reference[0], reference[1]) + logging.info(f"[BakeTextureFromVoxel] BVH back-project in {time.perf_counter() - tb:.1f}s") + _tick("back-project") t2 = time.perf_counter() attrs = _sample_voxel_attrs_per_texel( - position_map, mask, voxel_coords, voxel_colors, resolution, + position_map, mask, voxel_coords, voxel_colors, resolution, mode=sample_mode, ) logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s " f"({attrs.shape[-1]} channels)") + _tick("sample") # Split into PBR maps. Layout matches upstream pbr_attr_layout: # 0:3 base_color, 3 metallic, 4 roughness, 5 alpha. @@ -507,24 +973,13 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, # alpha channel exists at index 5 but we keep meshes opaque (upstream uses # alpha_mode=OPAQUE in the remesh path); plumb later if needed. - def _inpaint(img01, n_ch): - if inpaint_radius <= 0: - return img01 - import cv2 - u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8) - mask_inv = ((~mask).astype(np.uint8)) * 255 - u8 = cv2.inpaint(u8, mask_inv, int(inpaint_radius), cv2.INPAINT_TELEA) - if u8.ndim == 2: - u8 = u8[..., None] - return u8.astype(np.float32) / 255.0 - t3 = time.perf_counter() - base_color = _inpaint(np.ascontiguousarray(base_color), 3) + base_color = _seam_fill(np.ascontiguousarray(base_color), mask, inpaint_radius) mr_image = None if has_pbr: # glTF metallicRoughness: R unused, G=roughness, B=metallic. mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1) - mr_image = _inpaint(np.ascontiguousarray(mr), 3) + mr_image = _seam_fill(np.ascontiguousarray(mr), mask, inpaint_radius) if inpaint_radius > 0: logging.info(f"[BakeTextureFromVoxel] inpaint in {time.perf_counter() - t3:.1f}s") @@ -535,9 +990,83 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, out_tex = torch.from_numpy(np.ascontiguousarray(base_color)).to(device=device, dtype=torch.float32) out_mr = (torch.from_numpy(np.ascontiguousarray(mr_image)).to(device=device, dtype=torch.float32) if mr_image is not None else None) + _tick("finalize") + if _tq is not None: + _tq.close() return out_v, out_f, out_uvs, out_tex, out_mr +def _per_vertex_normals(verts_np, faces_np): + """Area-weighted per-vertex normals (unit length) for a triangle mesh.""" + v = verts_np.astype(np.float64) + f = faces_np.astype(np.int64) + # Un-normalized face normals are area-weighted (cross product magnitude = 2*area), + # so accumulating them onto vertices gives an area-weighted vertex normal. + fn = np.cross(v[f[:, 1]] - v[f[:, 0]], v[f[:, 2]] - v[f[:, 0]]) + vn = np.zeros_like(v) + for k in range(3): + np.add.at(vn, f[:, k], fn) + vn = vn / np.clip(np.linalg.norm(vn, axis=1, keepdims=True), 1e-12, None) + return vn.astype(np.float32) + + +def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution, + texture_size, views, blend_temperature=0.25, + inpaint_radius=3, fast_unwrap=True, existing_uvs=None, + normalize_uvs=True, sample_mode="trilinear"): + """Bake a baseColor texture by projecting view photos onto the mesh. + + Reuses bake_texture_from_voxel_fn for the xatlas unwrap + the nearest-voxel + fallback colour, then overlays photo colour on every covered+visible texel: + each texel's world position/normal is projected into each view, occlusion is + resolved with a texel z-buffer, and the views are blended weighted by how + directly each camera faces the surface. Texels seen by no view keep the voxel + colour. The seam inpaint runs last, over the composited result. + + `views`: list of dicts {image[H,W,3] in [0,1], azimuth_deg, transform_matrix[4,4], + camera_angle_x (scalar tensor), image_resolution}. All Pixal3D views share the + one front camera and differ only by azimuth. + + Returns (verts, faces, uvs, tex, mr) — same shape contract as + bake_texture_from_voxel_fn, so the node attaches them identically.""" + from comfy.ldm.trellis2 import multiview_bake as mvbake + + # Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end). + out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn( + vertices, faces, voxel_coords, voxel_colors, resolution=resolution, + texture_size=texture_size, inpaint_radius=0, fast_unwrap=fast_unwrap, + existing_uvs=existing_uvs, normalize_uvs=normalize_uvs, sample_mode=sample_mode) + + v_np = out_v.detach().cpu().numpy().astype(np.float32) + f_np = out_f.detach().cpu().numpy().astype(np.uint32) + uv_np = out_uvs.detach().cpu().numpy().astype(np.float32) + + # Per-texel world position + normal (the GL baker outputs any per-vertex vec3). + position_map, mask = _bake_position_map(v_np, f_np, uv_np, texture_size) + normal_map, _ = _bake_position_map(_per_vertex_normals(v_np, f_np), f_np, uv_np, texture_size) + + device = out_v.device + base = voxel_tex.detach().cpu().numpy().copy() + if mask.any() and views: + pos = torch.from_numpy(np.ascontiguousarray(position_map[mask])).to(device) + nrm = torch.from_numpy(np.ascontiguousarray(normal_map[mask])).to(device) + fallback = torch.from_numpy(np.ascontiguousarray(base[mask])).to(device) + view_objs = [{ + "image": vw["image"].to(device), + "azimuth_deg": vw["azimuth_deg"], + "transform_matrix": vw["transform_matrix"].to(device), + "camera_angle_x": vw["camera_angle_x"].to(device), + "image_resolution": vw["image_resolution"], + } for vw in views] + rgb, _seen = mvbake.composite_views(pos, nrm, view_objs, fallback, blend_temperature) + base[mask] = rgb.detach().cpu().numpy() + + base = _seam_fill(np.ascontiguousarray(base), mask, inpaint_radius) + + out_tex = torch.from_numpy(np.ascontiguousarray(base)).to(device=device, dtype=torch.float32) + return out_v, out_f, out_uvs, out_tex, voxel_mr + + class BakeTextureFromVoxel(IO.ComfyNode): @classmethod def define_schema(cls): @@ -546,58 +1075,46 @@ class BakeTextureFromVoxel(IO.ComfyNode): display_name="Bake Texture From Voxel", category="latent/3d", description=( - "Unwraps the mesh with xatlas, rasterizes it in UV space via OpenGL " - "(using ComfyUI's existing PyOpenGL backend), and bakes PBR textures " - "by nearest-voxel sampling of the input sparse voxel volume. Produces " - "a baseColor texture, plus a metallicRoughness texture when the voxel " - "field carries the full PBR set (6 channels). Returns a Mesh with `uvs`, " - "`texture`, and `metallic_roughness` attached — SaveGLB serializes them " - "as real baseColorTexture / metallicRoughnessTexture maps." + "Bakes PBR textures onto the mesh's existing UV layout by rasterizing it " + "in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling " + "the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node " + "(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Produces a " + "baseColor texture, plus a metallicRoughness texture when the voxel field " + "carries the full PBR set (6 channels). Returns a Mesh with `uvs`, `texture`, " + "and `metallic_roughness` attached — SaveGLB serializes them as real " + "baseColorTexture / metallicRoughnessTexture maps. UVs that spill outside " + "[0,1] are uniformly fit back into the unit square." ), inputs=[ IO.Mesh.Input("mesh"), IO.Voxel.Input("voxel_colors"), IO.Int.Input("texture_size", default=1024, min=64, max=8192, tooltip="Square texture resolution. Larger = sharper but slower / bigger file."), - IO.Int.Input("inpaint_radius", default=3, min=0, max=32, - tooltip="OpenCV inpaint radius for filling UV seam gutters. 0 disables."), - IO.Boolean.Input("fast_unwrap", default=True, - tooltip=( - "Use looser xatlas chart options to finish unwrap " - "much faster on large meshes (cost: less even UV " - "distribution). Off uses xatlas defaults, which can " - "take many minutes on >100k-face meshes." - )), - IO.Boolean.Input("use_existing_uvs", default=False, - tooltip=( - "Bake onto the mesh's existing UV layout instead of " - "re-unwrapping with xatlas. Requires the input mesh to " - "already carry UVs (e.g. from TorchXatlasUVWrap or a " - "retopologized mesh). Much faster and preserves your " - "UV layout. Ignored if the mesh has no UVs." - )), - IO.Boolean.Input("normalize_uvs", default=True, - tooltip=( - "When using existing UVs that spill outside [0,1] " - "(common with packers that overflow the unit square), " - "uniformly rescale them to fit. Without this, out-of-range " - "regions are clipped and don't bake. Disable only if your " - "UVs are already exactly in [0,1]." - )), + IO.Mesh.Input("reference_mesh", optional=True, + tooltip=( + "Optional original (dense, pre-decimation) mesh. If connected, each " + "texel is back-projected onto its true surface before sampling — the " + "o_voxel.to_glb step that removes faceted/pixelized baking on coarse " + "meshes. Pure scipy+torch, no extra deps.")), ], outputs=[IO.Mesh.Output("mesh")], ) @classmethod - def execute(cls, mesh, voxel_colors, texture_size, inpaint_radius, fast_unwrap, use_existing_uvs, normalize_uvs): + def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None): + # Seam-gutter inpaint radius is hardcoded to 3 (matches the official to_glb); + # it's an on/off-grade knob — Telea fills the whole gutter regardless of value. + inpaint_radius = 3 voxels = voxel_colors coords = voxels.data colors = voxels.voxel_colors resolution = voxels.resolution mesh_uvs = getattr(mesh, "uvs", None) - if use_existing_uvs and mesh_uvs is None: - logging.warning("BakeTextureFromVoxel: use_existing_uvs=True but mesh has no UVs; " - "falling back to xatlas unwrap.") + if mesh_uvs is None: + raise ValueError( + "BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the " + "mesh's existing UV layout and never unwraps — connect a UV unwrap node " + "(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) before it.") if coords.shape[-1] == 4: # Sparse coords have a batch column; bake per-item. @@ -605,7 +1122,9 @@ class BakeTextureFromVoxel(IO.ComfyNode): voxel_xyz = coords[:, 1:] mesh_batch_size = int(mesh.vertices.shape[0]) out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], [] - pbar = comfy.utils.ProgressBar(mesh_batch_size) + # 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items + # tick all 5 so the bar stays aligned. + pbar = comfy.utils.ProgressBar(mesh_batch_size * 5) for i in range(mesh_batch_size): sel = batch_idx == i item_coords = voxel_xyz[sel] @@ -613,18 +1132,21 @@ class BakeTextureFromVoxel(IO.ComfyNode): v_i, f_i, _ = get_mesh_batch_item(mesh, i) if item_coords.shape[0] == 0 or f_i.numel() == 0: logging.warning(f"BakeTextureFromVoxel: skipping batch {i} (empty voxel/mesh)") - pbar.update(1) + pbar.update(5) continue - ev_i = mesh_uvs[i, :v_i.shape[0]] if (use_existing_uvs and mesh_uvs is not None) else None + ev_i = mesh_uvs[i, :v_i.shape[0]] + ref_i = None + if reference_mesh is not None: + rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i) + ref_i = (rv_i, rf_i) bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( v_i, f_i, item_coords, item_colors, resolution=resolution, texture_size=texture_size, - inpaint_radius=inpaint_radius, fast_unwrap=fast_unwrap, - existing_uvs=ev_i, normalize_uvs=normalize_uvs, + inpaint_radius=inpaint_radius, + existing_uvs=ev_i, reference=ref_i, pbar=pbar, ) out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu) out_tex.append(bt); out_mr.append(bmr) - pbar.update(1) if not out_verts: return IO.NodeOutput(mesh) # Local pack_variable_mesh_batch doesn't take uvs/texture; build the @@ -643,12 +1165,16 @@ class BakeTextureFromVoxel(IO.ComfyNode): # Single-item path. v0 = mesh.vertices.squeeze(0) f0 = mesh.faces.squeeze(0) - ev0 = mesh_uvs.squeeze(0) if (use_existing_uvs and mesh_uvs is not None) else None + ev0 = mesh_uvs.squeeze(0) + ref0 = None + if reference_mesh is not None: + ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0)) + pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn) bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( v0, f0, coords, colors, resolution=resolution, texture_size=texture_size, - inpaint_radius=inpaint_radius, fast_unwrap=fast_unwrap, - existing_uvs=ev0, normalize_uvs=normalize_uvs, + inpaint_radius=inpaint_radius, + existing_uvs=ev0, reference=ref0, pbar=pbar, ) out_mesh = Types.MESH( vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0), @@ -1287,486 +1813,6 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi return out_v, out_f, colors -def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0): - if faces.numel() == 0: - return verts, faces - - v0 = verts[faces[:, 0]] - v1 = verts[faces[:, 1]] - v2 = verts[faces[:, 2]] - e0 = v1 - v0 - e1 = v2 - v1 - e2 = v0 - v2 - l0 = torch.norm(e0, dim=-1) - l1 = torch.norm(e1, dim=-1) - l2 = torch.norm(e2, dim=-1) - n = torch.cross(e0, e2, dim=-1) - area = torch.norm(n, dim=-1) - - max_edge = torch.max(torch.max(l0, l1), l2) - aspect = max_edge * max_edge / (2.0 * area + 1e-12) - - cos_a = (l1 * l1 + l2 * l2 - l0 * l0) / (2 * l1 * l2 + 1e-12) - cos_b = (l0 * l0 + l2 * l2 - l1 * l1) / (2 * l0 * l2 + 1e-12) - cos_c = (l0 * l0 + l1 * l1 - l2 * l2) / (2 * l0 * l1 + 1e-12) - cos_all = torch.stack([cos_a, cos_b, cos_c], dim=-1) - angles = torch.acos(torch.clamp(cos_all, -1, 1)) * 180 / np.pi - - good = (aspect < max_aspect) & (angles.min(dim=1)[0] > min_angle_deg) & (area > 1e-12) - faces = faces[good] - - if faces.numel() == 0: - return verts, faces - - used = torch.zeros(verts.shape[0], dtype=torch.bool, device=verts.device) - used[faces[:, 0]] = True - used[faces[:, 1]] = True - used[faces[:, 2]] = True - - remap = torch.full((verts.shape[0],), -1, dtype=torch.int64, device=verts.device) - remap[used] = torch.arange(used.sum().item(), device=verts.device) - verts = verts[used] - faces = remap[faces] - return verts, faces - -def _pytorch_edge_errors_fast(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq): - n_edges = edges.shape[0] - dtype = verts.dtype - if n_edges == 0: - return (torch.empty((0, 3), dtype=dtype, device=verts.device), - torch.empty((0,), dtype=dtype, device=verts.device), - torch.zeros((0,), dtype=torch.bool, device=verts.device)) - - device = verts.device - mesh_scale = (mesh_scale_sq) ** 0.5 - - va = edges[:, 0] - vb = edges[:, 1] - Q0 = Q[va] - Q1 = Q[vb] - Qe = Q0 + Q1 - - A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype).unsqueeze(0) * stabilizer - b = -Qe[:, :3, 3].unsqueeze(-1) - - dets = torch.det(A) - good = dets.abs() > 1e-12 - opt = torch.zeros((n_edges, 3), dtype=dtype, device=device) - - if good.any(): - try: - sol = torch.linalg.solve(A[good], b[good]) - opt[good] = sol.squeeze(-1) - except Exception: - good = torch.zeros_like(good) - - if (~good).any(): - bad_idx = torch.nonzero(~good, as_tuple=True)[0] - opt[bad_idx] = (verts[va[bad_idx]] + verts[vb[bad_idx]]) * 0.5 - - pa = verts[va] - pb = verts[vb] - el = torch.norm(pb - pa, dim=-1) - dist_a = torch.norm(opt - pa, dim=-1) - dist_b = torch.norm(opt - pb, dim=-1) - wander_bad = (dist_a > 4.0 * el) | (dist_b > 4.0 * el) - - if wander_bad.any(): - bad_idx = torch.nonzero(wander_bad, as_tuple=True)[0] - opt[bad_idx] = (verts[va[bad_idx]] + verts[vb[bad_idx]]) * 0.5 - - v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1) - err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4)) - - length_ok = el > mesh_scale * 1e-5 - error_ok = err < max_edge_length_sq - nan_ok = ~torch.isnan(opt).any(dim=-1) & ~torch.isnan(err) - valid = length_ok & error_ok & nan_ok - - return opt, err, valid - - -def _build_quadrics_fast(verts, faces): - v0 = verts[faces[:, 0]] - v1 = verts[faces[:, 1]] - v2 = verts[faces[:, 2]] - e1 = v1 - v0 - e2 = v2 - v0 - n = torch.cross(e1, e2, dim=-1) - area = torch.norm(n, dim=-1) - mask = area > 1e-12 - n_norm = torch.zeros_like(n) - n_norm[mask] = n[mask] / area[mask].unsqueeze(-1) - d = -(n_norm * v0).sum(dim=-1, keepdim=True) - p = torch.cat([n_norm, d], dim=-1) - K = torch.einsum("fi,fj->fij", p, p) - K = K * area[:, None, None] - V = verts.shape[0] - Q = torch.zeros((V, 4, 4), dtype=verts.dtype, device=verts.device) - K_flat = K.reshape(-1, 16) - Q_flat = Q.reshape(V, 16) - for corner in range(3): - idx = faces[:, corner].unsqueeze(1).expand(-1, 16) - Q_flat.scatter_add_(0, idx, K_flat) - return Q_flat.reshape(V, 4, 4) - - -def _gpu_greedy_matching_fast(edges, err, v_alive, max_select): - """Vectorized greedy matching. - - Selects an independent set of edges (no two share a vertex) preferring - lowest error. Replaces _gpu_greedy_sampled's Python per-edge loop with - two scatter_reduce calls. - """ - device = edges.device - n_edges = edges.shape[0] - if n_edges == 0: - return torch.empty(0, dtype=torch.int64, device=device) - - va = edges[:, 0] - vb = edges[:, 1] - num_verts = v_alive.shape[0] - - # Pack (error_bits, edge_idx) into one int64 so amin gives a unique winner. - # err is non-negative finite float32 -> IEEE bits are monotonic. - err32 = err.to(torch.float32).clamp(min=0).contiguous() - err_bits = err32.view(torch.int32).to(torch.int64) & 0xFFFFFFFF - edge_idx = torch.arange(n_edges, device=device, dtype=torch.int64) - key = (err_bits << 32) | edge_idx - - INT64_MAX = torch.iinfo(torch.int64).max - best_key = torch.full((num_verts,), INT64_MAX, dtype=torch.int64, device=device) - best_key.scatter_reduce_(0, va, key, reduce='amin', include_self=True) - best_key.scatter_reduce_(0, vb, key, reduce='amin', include_self=True) - - # An edge wins iff it is the min-key edge incident to BOTH its endpoints - # AND both endpoints are still alive. - is_winner = (key == best_key[va]) & (key == best_key[vb]) & v_alive[va] & v_alive[vb] - - sel = torch.nonzero(is_winner, as_tuple=True)[0] - - if sel.numel() > max_select: - sel_err = err[sel] - top = torch.topk(sel_err, max_select, largest=False).indices - sel = sel[top] - - return sel - - -def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, device, max_edge_length=None): - # Use float32 instead of float64. RTX-class consumer GPUs run FP32 ~32-64x - # faster than FP64, and QEM only needs the stabilizer for conditioning. - # Always copy=True so we can safely mutate verts/colors/normals in-place. - verts = vertices.detach().to(device=device, dtype=torch.float32, copy=True) - faces = faces_in.detach().to(device=device, dtype=torch.int64) - colors = ( - colors_in.detach().to(device=device, dtype=torch.float32, copy=True) - if colors_in is not None - else None - ) - # ADDED: Initialize normals - normals = ( - normals_in.detach().to(device=device, dtype=torch.float32, copy=True) - if normals_in is not None - else None - ) - - num_verts = verts.shape[0] - num_faces = faces.shape[0] - - logging.debug(f"[QEM-fast] Input: {num_verts} verts, {num_faces} faces, target={target_faces}") - - v_alive = torch.ones(num_verts, dtype=torch.bool, device=device) - f_alive = torch.ones(num_faces, dtype=torch.bool, device=device) - - Q = _build_quadrics_fast(verts, faces) - - bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0] - mesh_scale = torch.norm(bbox).item() - - if max_edge_length is None or max_edge_length <= 0: - max_edge_length = mesh_scale * 2.0 - - if max_edge_length < 1e-6: - max_edge_length = 1.0 - - stabilizer = mesh_scale * mesh_scale * 0.001 - max_edge_length_sq = max_edge_length * max_edge_length - mesh_scale_sq = mesh_scale * mesh_scale - - iteration = 0 - total_collapses = 0 - last_faces = num_faces - - while True: - n_faces = int(f_alive.sum().item()) - - if n_faces <= target_faces: - break - - alive_v = torch.nonzero(v_alive, as_tuple=True)[0] - alive_f = torch.nonzero(f_alive, as_tuple=True)[0] - - if alive_v.numel() <= 4 or alive_f.numel() == 0: - break - - # Compact active mesh - vmap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) - vmap[alive_v] = torch.arange(alive_v.numel(), device=device) - - active_faces = faces[alive_f] - remapped = vmap[active_faces] - - # Extract edges - e0 = remapped[:, [0, 1]] - e1 = remapped[:, [1, 2]] - e2 = remapped[:, [2, 0]] - edges = torch.cat([e0, e1, e2], dim=0) - edges = torch.sort(edges, dim=1)[0] - edges = edges[(edges >= 0).all(dim=1)] - edges = edges[edges[:, 0] != edges[:, 1]] - - if edges.shape[0] == 0: - break - - # Deduplicate edges - num_compact = alive_v.numel() - packed = edges[:, 0].long() * num_compact + edges[:, 1].long() - packed = torch.unique(packed) - edges = torch.stack([packed // num_compact, packed % num_compact], dim=1) - - edges_orig = alive_v[edges] - - # Filter by edge length - pa = verts[edges_orig[:, 0]] - pb = verts[edges_orig[:, 1]] - el = torch.norm(pb - pa, dim=-1) - short_enough = el < max_edge_length - - if not short_enough.any(): - max_edge_length = el.max().item() * 2.0 - max_edge_length_sq = max_edge_length * max_edge_length - short_enough = el < max_edge_length - if not short_enough.any(): - break - - edges_orig = edges_orig[short_enough] - if edges_orig.shape[0] == 0: - break - - # Sample edges for processing - n_edges_total = edges_orig.shape[0] - max_edges_to_process = 10_000_000 - - if n_edges_total > max_edges_to_process: - perm = torch.randint(0, n_edges_total, (max_edges_to_process,), device=device) - edges_orig = edges_orig[perm] - n_edges = max_edges_to_process - else: - n_edges = n_edges_total - - optimal, err, valid = _pytorch_edge_errors_fast( - verts, Q, edges_orig, stabilizer, max_edge_length_sq, mesh_scale_sq - ) - - if not valid.any(): - valid = torch.ones(n_edges, dtype=torch.bool, device=device) - - valid_idx = torch.nonzero(valid, as_tuple=True)[0] - edges_orig = edges_orig[valid_idx] - optimal = optimal[valid_idx] - err = err[valid_idx] - - faces_to_remove = n_faces - target_faces - max_collapses = min(1_000_000, max(10_000, faces_to_remove // 4)) - - sel = _gpu_greedy_matching_fast(edges_orig, err, v_alive, max_collapses) - - if sel.numel() == 0: - break - - v_a = edges_orig[sel, 0] - v_b = edges_orig[sel, 1] - - # Apply collapses - verts[v_a] = optimal[sel] - v_alive[v_b] = False - Q[v_a] += Q[v_b] - - if colors is not None: - colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5 - - if normals is not None: - normals[v_a] = (normals[v_a] + normals[v_b]) * 0.5 - - merge_map = torch.arange(num_verts, device=device) - merge_map[v_b] = v_a - faces = merge_map[faces] - - bad = ( - (faces[:, 0] == faces[:, 1]) - | (faces[:, 1] == faces[:, 2]) - | (faces[:, 2] == faces[:, 0]) - ) - f_alive &= ~bad - - total_collapses += v_a.numel() - iteration += 1 - - if iteration % 50 == 0 or n_faces < last_faces * 0.9: - logging.debug(f"[QEM-fast] Iter {iteration}: {total_collapses} collapses, {int(f_alive.sum().item())} faces, applied {v_a.numel()}") - last_faces = n_faces - - if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5: - faces = faces[f_alive] - f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device) - num_faces = faces.shape[0] - - if iteration > 5000: - break - - # Finalize - final_v = verts[v_alive] - final_c = colors[v_alive] if colors is not None else None - - remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) - remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device) - - final_f_raw = faces[f_alive] - alive_mask = v_alive[final_f_raw].all(dim=1) - final_f_raw = final_f_raw[alive_mask] - final_f = remap[final_f_raw] - valid_faces = (final_f >= 0).all(dim=1) - final_f = final_f[valid_faces] - - if final_f.numel() > 0: - final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0) - - final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0) - - return final_v, final_f, final_c, None - - -def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None): - if vertices.ndim == 3: - v_list, f_list, c_list, n_list = [], [], [], [] - for i in range(vertices.shape[0]): - c_in = colors[i] if colors is not None else None - n_in = normals[i] if normals is not None else None - v_i, f_i, c_i, n_i = simplify_fn_fast(vertices[i], faces[i], c_in, n_in, target, max_edge_length) - v_list.append(v_i) - f_list.append(f_i) - if c_i is not None: - c_list.append(c_i) - if n_i is not None: - n_list.append(n_i) - - c_out = torch.stack(c_list) if len(c_list) > 0 else None - n_out = torch.stack(n_list) if len(n_list) > 0 else None - return torch.stack(v_list), torch.stack(f_list), c_out, n_out - - if faces.shape[0] <= target: - return vertices, faces, colors, normals - - device = vertices.device - dtype = vertices.dtype - face_dtype = faces.dtype - color_dtype = colors.dtype if colors is not None else None - # ADDED: Normal dtype - normal_dtype = normals.dtype if normals is not None else None - - # Pass tensors directly; _qem_simplify_fast handles dtype/device + copy. - out_v, out_f, out_c, out_n = _qem_simplify_fast( - vertices, faces, colors, normals, target, device, max_edge_length - ) - - final_v = out_v.to(device=device, dtype=dtype) - final_f = out_f.to(device=device, dtype=face_dtype) - final_c = ( - out_c.to(device=device, dtype=color_dtype) - if out_c is not None - else None - ) - final_n = ( - out_n.to(device=device, dtype=normal_dtype) - if out_n is not None - else None - ) - return final_v, final_f, final_c, final_n - -def simplify_fn_vertex(vertices, faces, colors=None, target=100000): - if vertices.ndim == 3: - v_list, f_list, c_list = [], [], [] - for i in range(vertices.shape[0]): - c_in = colors[i] if colors is not None else None - v_i, f_i, c_i = simplify_fn_vertex(vertices[i], faces[i], c_in, target) - v_list.append(v_i) - f_list.append(f_i) - if c_i is not None: - c_list.append(c_i) - - c_out = torch.stack(c_list) if len(c_list) > 0 else None - return torch.stack(v_list), torch.stack(f_list), c_out - - if faces.shape[0] <= target: - return vertices, faces, colors - - device = vertices.device - target_v = max(target / 4.0, 1.0) - - min_v = vertices.min(dim=0)[0] - max_v = vertices.max(dim=0)[0] - extent = max_v - min_v - - volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) - cell_size = (volume / target_v) ** (1/3.0) - - # Use CPU-side ordered reductions here so repeated runs produce identical - # simplified meshes instead of relying on GPU scatter-add accumulation order. - vertices_np = vertices.detach().cpu().numpy() - faces_np = faces.detach().cpu().numpy() - colors_np = colors.detach().cpu().numpy() if colors is not None else None - min_v_np = min_v.detach().cpu().numpy() - cell_size_value = float(cell_size.detach().cpu()) - - quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) - unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) - num_cells = unique_coords.shape[0] - - new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) - np.add.at(new_vertices_np, inverse_indices, vertices_np) - - counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) - new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) - - new_colors = None - if colors_np is not None: - new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype) - np.add.at(new_colors_np, inverse_indices, colors_np) - new_colors = new_colors_np / np.clip(counts_np, 1, None) - - new_faces = inverse_indices[faces_np] - valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ - (new_faces[:, 1] != new_faces[:, 2]) & \ - (new_faces[:, 2] != new_faces[:, 0]) - new_faces = new_faces[valid_mask] - - if new_faces.size == 0: - final_vertices_np = new_vertices_np[:0] - final_faces_np = np.empty((0, 3), dtype=np.int64) - final_colors_np = new_colors[:0] if new_colors is not None else None - else: - unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True) - final_vertices_np = new_vertices_np[unique_face_indices] - final_faces_np = inv_face.reshape(-1, 3).astype(np.int64) - final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None - - final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) - final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) - final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None - - return final_vertices, final_faces, final_colors - def compute_vertex_normals(verts, faces): """Computes area-weighted vertex normals.""" # QUICK FIX: Ensure indices are int64 for scatter_add_ @@ -1846,111 +1892,65 @@ def fix_face_orientation(vertices, faces, reference_normals=None): device = faces.device corrected = faces.clone() - - idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device) - edges = corrected[:, idx] # (num_faces, 3, 2) - - edges_canon = torch.sort(edges, dim=2)[0] - edges_flat = edges_canon.view(-1, 2) - max_vert = vertices.shape[0] - edge_hash = edges_flat[:, 0] * max_vert + edges_flat[:, 1] + # Manifold edge adjacency: pair faces that share an edge (run length 2 after + # canonicalizing + sorting edge hashes). + idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device) + edges = corrected[:, idx] # (num_faces, 3, 2) directed + edges_canon = torch.sort(edges, dim=2)[0].view(-1, 2) + edge_hash = edges_canon[:, 0] * max_vert + edges_canon[:, 1] hash_sorted, sort_idx = torch.sort(edge_hash) - - hash_diff = hash_sorted[1:] != hash_sorted[:-1] - hash_diff = torch.cat([torch.tensor([True], device=device), hash_diff]) - unique_starts = torch.nonzero(hash_diff, as_tuple=True)[0] - unique_ends = torch.cat([unique_starts[1:], torch.tensor([len(hash_sorted)], device=device)]) - run_lengths = unique_ends - unique_starts - - manifold_mask = run_lengths == 2 - manifold_starts = unique_starts[manifold_mask] - - component_id_np = np.full(num_faces, -1, dtype=np.int64) + start = torch.cat([torch.ones(1, dtype=torch.bool, device=device), + hash_sorted[1:] != hash_sorted[:-1]]) + unique_starts = torch.nonzero(start, as_tuple=True)[0] + unique_ends = torch.cat([unique_starts[1:], + torch.tensor([hash_sorted.numel()], device=device)]) + manifold_starts = unique_starts[(unique_ends - unique_starts) == 2] if manifold_starts.numel() > 0: - # Replaces slow, nested element-wise matching with direct index mapping f_a = sort_idx[manifold_starts] // 3 f_b = sort_idx[manifold_starts + 1] // 3 - local_edge_a = sort_idx[manifold_starts] % 3 - local_edge_b = sort_idx[manifold_starts + 1] % 3 + le_a = sort_idx[manifold_starts] % 3 + le_b = sort_idx[manifold_starts + 1] % 3 + opposite = (edges[f_a, le_a] == edges[f_b, le_b].flip(dims=[1])).all(dim=1) - dir_edge_a = edges[f_a, local_edge_a] - dir_edge_b = edges[f_b, local_edge_b] + # Connected components via scipy (fast C), replacing a per-face Python BFS. + import scipy.sparse + import scipy.sparse.csgraph + fa_np = f_a.cpu().numpy(); fb_np = f_b.cpu().numpy() + graph = scipy.sparse.coo_matrix( + (np.ones(fa_np.shape[0] * 2, dtype=np.int8), + (np.concatenate([fa_np, fb_np]), np.concatenate([fb_np, fa_np]))), + shape=(num_faces, num_faces)) + num_components, comp = scipy.sparse.csgraph.connected_components(graph, directed=False) + component_id = torch.from_numpy(comp.astype(np.int64)).to(device) - opposite = (dir_edge_a == dir_edge_b.flip(dims=[1])).all(dim=1) - needs_flip_rel = ~opposite - - adj_faces = torch.cat([f_a, f_b]) - adj_neighbors = torch.cat([f_b, f_a]) - adj_flip = torch.cat([needs_flip_rel, needs_flip_rel]) - - adj_order = torch.argsort(adj_faces) - adj_faces_np = adj_faces[adj_order].cpu().numpy() - adj_neighbors_np = adj_neighbors[adj_order].cpu().numpy() - adj_flip_np = adj_flip[adj_order].cpu().numpy() - - # Build CSR-style adjacency on CPU using NumPy - adj_ptr_np = np.zeros(num_faces + 1, dtype=np.int64) - counts_np = np.bincount(adj_faces_np, minlength=num_faces) - adj_ptr_np[1:] = np.cumsum(counts_np) - - visited_np = np.zeros(num_faces, dtype=bool) - flip_state_np = np.zeros(num_faces, dtype=bool) - comp_counter = 0 - - queue_np = np.empty(num_faces, dtype=np.int64) - - for seed in range(num_faces): - if visited_np[seed]: - continue - - visited_np[seed] = True - component_id_np[seed] = comp_counter - q_head = 0 - q_tail = 1 - queue_np[0] = seed - - while q_head < q_tail: - current = queue_np[q_head] - q_head += 1 - - start = adj_ptr_np[current] - end = adj_ptr_np[current + 1] - if start == end: - continue - - nbrs = adj_neighbors_np[start:end] - flips = adj_flip_np[start:end] - src_flip = flip_state_np[current] - - unvisited_mask = ~visited_np[nbrs] - if not np.any(unvisited_mask): - continue - - nbrs_new = nbrs[unvisited_mask] - flips_new = flips[unvisited_mask] - - visited_np[nbrs_new] = True - component_id_np[nbrs_new] = comp_counter - - # NumPy bitwise XOR is fast and direct - flip_state_np[nbrs_new] = flips_new ^ src_flip - - n_new = len(nbrs_new) - queue_np[q_tail:q_tail + n_new] = nbrs_new - q_tail += n_new - - comp_counter += 1 - - flip_state = torch.from_numpy(flip_state_np).to(device=device) - component_id = torch.from_numpy(component_id_np).to(device=device) - - if flip_state.any(): - corrected[flip_state] = corrected[flip_state][:, [0, 2, 1]] + # Within-component consistent winding. A QEM output from a consistently wound + # source is already consistent (every shared edge is traversed oppositely) -> + # no flips needed, the common fast path. Otherwise propagate a parity flip + # across the dual graph by vectorized label relaxation (min-root carrying + # parity), instead of the old per-face CPU BFS. + if not bool(opposite.all()): + nf = ~opposite + src = torch.cat([f_a, f_b]); dst = torch.cat([f_b, f_a]); nfd = torch.cat([nf, nf]) + root = torch.arange(num_faces, device=device) + par = torch.zeros(num_faces, dtype=torch.bool, device=device) + for _ in range(num_faces + 8): # breaks at graph diameter; cap is a backstop + cand_root = root[src]; cand_par = par[src] ^ nfd + new_root = root.clone() + new_root.scatter_reduce_(0, dst, cand_root, reduce='amin', include_self=True) + changed = new_root < root + if not bool(changed.any()): + break + apply = changed[dst] & (cand_root == new_root[dst]) + par[dst[apply]] = cand_par[apply] + root = new_root + if bool(par.any()): + corrected[par] = corrected[par][:, [0, 2, 1]] else: component_id = torch.arange(num_faces, device=device) + num_components = num_faces v0 = vertices[corrected[:, 0]] v1 = vertices[corrected[:, 1]] @@ -1959,8 +1959,6 @@ def fix_face_orientation(vertices, faces, reference_normals=None): face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-8) - num_components = int(component_id.max().item()) + 1 if component_id.numel() > 0 else 0 - if reference_normals is not None: n0 = reference_normals[corrected[:, 0]] n1 = reference_normals[corrected[:, 1]] @@ -1990,15 +1988,15 @@ def fix_face_orientation(vertices, faces, reference_normals=None): coords = face_centroids[:, axis] # Double stable sort acts as a vectorized lexsort on (coords, component_id) - sort_idx = torch.argsort(coords, stable=True) - sort_idx = sort_idx[torch.argsort(component_id[sort_idx], stable=True)] + sort_idx2 = torch.argsort(coords, stable=True) + sort_idx2 = sort_idx2[torch.argsort(component_id[sort_idx2], stable=True)] # Find group boundaries to get the extreme outer face along this axis per component - comp_id_sorted = component_id[sort_idx] + comp_id_sorted = component_id[sort_idx2] group_ends = torch.nonzero(comp_id_sorted[1:] != comp_id_sorted[:-1], as_tuple=True)[0] group_ends = torch.cat([group_ends, torch.tensor([len(comp_id_sorted) - 1], device=device)]) - extreme_face_indices = sort_idx[group_ends] + extreme_face_indices = sort_idx2[group_ends] extreme_normals = face_normals[extreme_face_indices] # Normal's component along the respective axis should be positive @@ -2071,42 +2069,97 @@ def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4): class DecimateMesh(IO.ComfyNode): @classmethod def define_schema(cls): + # placement_mode picks how the merged vertex is positioned, and which extra + # quality knobs are surfaced (DynamicCombo: the qem sub-widgets only appear + # when 'qem' is selected). + placement_options = [ + IO.DynamicCombo.Option(key="midpoint", inputs=[]), + IO.DynamicCombo.Option(key="qem", inputs=[ + IO.Float.Input("line_quadric_weight", default=0.0, min=0.0, max=100.0, step=0.1, + tooltip="Weight of the per-edge line quadric (squared distance to the edge " + "line). Biases collapses to preserve sharp ridges/valleys. 0 = off."), + IO.Float.Input("feature_edge_quadric_weight", default=0.0, min=0.0, max=1000.0, step=1.0, + tooltip="Extra quadric weight on dihedral feature edges (creases). Higher = " + "more aggressively preserves hard edges. 0 = off."), + IO.Float.Input("feature_edge_min_dihedral_deg", default=30.0, min=0.0, max=180.0, step=1.0, + tooltip="Minimum dihedral angle (degrees) for an edge to count as a feature " + "edge for feature_edge_quadric_weight."), + IO.Boolean.Input("clamp_v_to_edge", default=True, + tooltip="Project the QEM-optimal position onto the collapsed edge segment. " + "Prevents inward-cascade drift on curved surfaces."), + ]), + ] return IO.Schema( node_id="DecimateMesh", display_name="Decimate Mesh", category="latent/3d", - description="Simplifies a mesh to a target face count using QEM.", + description=( + "Simplifies a mesh to a target face count using QEM, on the active compute " + "device. 'midpoint' placement uses the cumesh-faithful preset (best quality, " + "preserves thin features / hair). 'qem' places each merged vertex at the QEM " + "optimum and exposes line/feature-edge quadric controls. Output stays welded " + "so it smooth-shades." + ), inputs=[ IO.Mesh.Input("mesh"), IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000, tooltip="Target maximum number of faces. Set to 0 to disable."), + IO.DynamicCombo.Input("placement_mode", options=placement_options, + display_name="placement_mode", + tooltip="midpoint: cumesh-faithful preset (recommended). " + "qem: QEM-optimal placement with line/feature-edge controls."), ], outputs=[IO.Mesh.Output("mesh")], + hidden=[IO.Hidden.unique_id], ) @classmethod - def execute(cls, mesh, target_face_count): + def execute(cls, mesh, target_face_count, placement_mode): + mode = placement_mode.get("placement_mode", "midpoint") + if mode == "qem": + # QEM-optimum placement + ratio driver; everything else inherits the defaults. + cfg = QEMConfig( + placement_mode="qem", + line_quadric_weight=float(placement_mode.get("line_quadric_weight", 0.0)), + feature_edge_quadric_weight=float(placement_mode.get("feature_edge_quadric_weight", 0.0)), + feature_edge_min_dihedral_deg=float(placement_mode.get("feature_edge_min_dihedral_deg", 30.0)), + clamp_v_to_edge=bool(placement_mode.get("clamp_v_to_edge", True)), + ) + else: + cfg = QEMConfig() # midpoint placement + threshold driver (the defaults) + + # ComfyUI passes meshes on CPU; the QEM is ~30x slower there. Run on the + # selected compute device and return on the mesh's original device. + compute_device = comfy.model_management.get_torch_device() + + counts = {"in": 0, "out": 0} + def _fn(v, f, c): + counts["in"] += int(f.shape[0]) if target_face_count > 0 and f.shape[0] > target_face_count: try: - v0, v1, v2 = v[f[:, 0]], v[f[:, 1]], v[f[:, 2]] - fn = torch.cross(v1 - v0, v2 - v0, dim=-1) - fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) - - n = torch.zeros_like(v) - n.index_add_(0, f[:, 0], fn) - n.index_add_(0, f[:, 1], fn) - n.index_add_(0, f[:, 2], fn) - n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8) - - v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) - f = fix_face_orientation(v, f) - v, f, c = unweld_and_offset_mesh(v, f, colors=c, z_offset=1e-4) + src_device = v.device + rv, rf, rc, _rn, _rs = qem_decimate_simplify( + v.to(compute_device), f.to(compute_device), int(target_face_count), + colors=(c.to(compute_device) if c is not None else None), + config=cfg) + v = rv.to(src_device) + f = rf.to(src_device) + if rc is not None: + c = rc.to(src_device) except Exception as e: - logging.warning("Ran into an error while QEM Simplifying, falling back to vertex clustering:\n" + str(e)) - v, f, c = simplify_fn_vertex(v, f, c, target_face_count) + logging.warning(f"DecimateMesh: QEM simplify failed, passing mesh through unchanged: {e!r}") + counts["out"] += int(f.shape[0]) return v, f, c - return _process_mesh_batch(mesh, _fn) + + result = _process_mesh_batch(mesh, _fn) + + # Send progress text to display the face reduction on the node + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text( + f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id) + + return result class FillHoles(IO.ComfyNode): diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index 403f268d4..ef471eeee 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -19,7 +19,8 @@ from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types -def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False): +def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False, + normals=None, metallic_roughness=None): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # stashing per-item lengths as runtime attrs so consumers can recover the real slice. # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. @@ -55,9 +56,20 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non ) packed_uvs[i, :u.shape[0]] = u + packed_normals = None + if normals is not None: + packed_normals = normals[0].new_zeros((batch_size, max_vertices, normals[0].shape[1])) + for i, nrm in enumerate(normals): + assert nrm.shape[0] == vertices[i].shape[0], ( + f"normals[{i}] has {nrm.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)" + ) + packed_normals[i, :nrm.shape[0]] = nrm + return Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit) + metallic_roughness=metallic_roughness, + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit, + normals=packed_normals) def get_mesh_batch_item(mesh, index): @@ -65,6 +77,7 @@ def get_mesh_batch_item(mesh, index): # if the mesh carries per-item counts (variable-size batch). v_colors = getattr(mesh, "vertex_colors", None) v_uvs = getattr(mesh, "uvs", None) + v_normals = getattr(mesh, "normals", None) if getattr(mesh, "vertex_counts", None) is not None: vertex_count = int(mesh.vertex_counts[index].item()) face_count = int(mesh.face_counts[index].item()) @@ -72,16 +85,102 @@ def get_mesh_batch_item(mesh, index): faces = mesh.faces[index, :face_count] colors = v_colors[index, :vertex_count] if v_colors is not None else None uvs = v_uvs[index, :vertex_count] if v_uvs is not None else None - return vertices, faces, colors, uvs + normals = v_normals[index, :vertex_count] if v_normals is not None else None + return vertices, faces, colors, uvs, normals colors = v_colors[index] if v_colors is not None else None uvs = v_uvs[index] if v_uvs is not None else None - return mesh.vertices[index], mesh.faces[index], colors, uvs + normals = v_normals[index] if v_normals is not None else None + return mesh.vertices[index], mesh.faces[index], colors, uvs, normals + + +def _smooth_vertex_normals(vertices_np, faces_np): + """Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting. + + Un-normalized face normals (the raw cross product) have magnitude 2*area, so + accumulating them onto their vertices yields an area-weighted average.""" + tris = vertices_np[faces_np] # (M, 3, 3) + face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0]) + normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64) + for k in range(3): + np.add.at(normals, faces_np[:, k], face_n) + lens = np.linalg.norm(normals, axis=1, keepdims=True) + normals /= np.where(lens > 1e-12, lens, 1.0) + return normals.astype(np.float32) + + +def _compute_vertex_normals(vertices_np, faces_np, crease_angle=None): + """Compute per-vertex normals, returning (vertices, faces_uint32, normals, remap). + + crease_angle is None (or >= 180) -> fully smooth normals; vertices/faces are + returned unchanged and remap is None. + + Otherwise vertices are split along edges whose dihedral angle exceeds + crease_angle (degrees) so hard creases stay sharp while smooth regions still + interpolate. remap maps each output vertex back to its source index, so the + caller can duplicate any per-vertex attributes (uvs / colors) to match.""" + faces_i = faces_np.astype(np.int64) + if crease_angle is None or crease_angle >= 180.0: + return (vertices_np, faces_i.astype(np.uint32), + _smooth_vertex_normals(vertices_np, faces_i), None) + + M = faces_i.shape[0] + tris = vertices_np[faces_i] + face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0]) + areas = np.linalg.norm(face_n, axis=1, keepdims=True) + face_unit = face_n / np.where(areas > 1e-12, areas, 1.0) + cos_thresh = math.cos(math.radians(crease_angle)) + + # Union faces that share an edge whose dihedral angle is below the crease + # threshold; each connected component becomes one smoothing group. + parent = list(range(M)) + + def find(x): + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + edge_faces = {} + for fi in range(M): + a, b, c = int(faces_i[fi, 0]), int(faces_i[fi, 1]), int(faces_i[fi, 2]) + for u, v in ((a, b), (b, c), (c, a)): + edge_faces.setdefault((u, v) if u < v else (v, u), []).append(fi) + for fl in edge_faces.values(): + if len(fl) == 2 and float(np.dot(face_unit[fl[0]], face_unit[fl[1]])) >= cos_thresh: + ra, rb = find(fl[0]), find(fl[1]) + if ra != rb: + parent[ra] = rb + + # Emit one output vertex per (original vertex, smoothing group) pair. + new_index = {} + remap = [] + out_faces = np.empty((M, 3), dtype=np.int64) + for fi in range(M): + g = find(fi) + for k in range(3): + ov = int(faces_i[fi, k]) + key = (ov, g) + ni = new_index.get(key) + if ni is None: + ni = len(remap) + new_index[key] = ni + remap.append(ov) + out_faces[fi, k] = ni + + remap = np.asarray(remap, dtype=np.int64) + normals = np.zeros((remap.shape[0], 3), dtype=np.float64) + for k in range(3): + np.add.at(normals, out_faces[:, k], face_n) + lens = np.linalg.norm(normals, axis=1, keepdims=True) + normals /= np.where(lens > 1e-12, lens, 1.0) + return (vertices_np[remap], out_faces.astype(np.uint32), normals.astype(np.float32), remap) def save_glb(vertices, faces, filepath, metadata=None, uvs=None, vertex_colors=None, texture_image=None, - metallic_roughness_image=None, unlit=False): + metallic_roughness_image=None, unlit=False, + normals=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -95,6 +194,9 @@ def save_glb(vertices, faces, filepath, metadata=None, texture_image: PIL.Image - Optional baseColor texture, embedded as PNG metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture (R unused, G=roughness, B=metallic), embedded as PNG + normals: torch.Tensor of shape (N, 3) - Optional per-vertex normals, written as the + glTF NORMAL attribute. When omitted, NO normals are written and viewers fall back + to flat (per-face) shading — use the MeshSmoothNormals node to generate them. """ # Convert tensors to numpy arrays @@ -123,6 +225,12 @@ def save_glb(vertices, faces, filepath, metadata=None, raise ValueError( f"save_glb: vertex_colors has {colors_np.shape[0]} entries but vertex count is {n_verts}" ) + + normals_np = normals.cpu().numpy().astype(np.float32) if normals is not None else None + if normals_np is not None and normals_np.shape[0] != n_verts: + raise ValueError( + f"save_glb: normals has {normals_np.shape[0]} entries but vertex count is {n_verts}" + ) faces_np = faces_signed.astype(np.uint32) texture_png_bytes = None if texture_image is not None: @@ -139,6 +247,7 @@ def save_glb(vertices, faces, filepath, metadata=None, indices_buffer = faces_np.tobytes() uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" colors_buffer = colors_np.tobytes() if colors_np is not None else b"" + normals_buffer = normals_np.tobytes() if normals_np is not None else b"" texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" mr_buffer = mr_png_bytes if mr_png_bytes is not None else b"" @@ -150,6 +259,7 @@ def save_glb(vertices, faces, filepath, metadata=None, indices_buffer_padded = pad_to_4_bytes(indices_buffer) uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) colors_buffer_padded = pad_to_4_bytes(colors_buffer) + normals_buffer_padded = pad_to_4_bytes(normals_buffer) texture_buffer_padded = pad_to_4_bytes(texture_buffer) mr_buffer_padded = pad_to_4_bytes(mr_buffer) @@ -158,6 +268,7 @@ def save_glb(vertices, faces, filepath, metadata=None, indices_buffer_padded, uvs_buffer_padded, colors_buffer_padded, + normals_buffer_padded, texture_buffer_padded, mr_buffer_padded, ]) @@ -168,7 +279,8 @@ def save_glb(vertices, faces, filepath, metadata=None, indices_byte_offset = len(vertices_buffer_padded) uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) - texture_byte_offset = colors_byte_offset + len(colors_buffer_padded) + normals_byte_offset = colors_byte_offset + len(colors_buffer_padded) + texture_byte_offset = normals_byte_offset + len(normals_buffer_padded) mr_byte_offset = texture_byte_offset + len(texture_buffer_padded) buffer_views = [ @@ -239,6 +351,23 @@ def save_glb(vertices, faces, filepath, metadata=None, }) primitive_attributes["COLOR_0"] = accessor_idx + if normals_np is not None and len(normals_np) > 0: + buffer_views.append({ + "buffer": 0, + "byteOffset": normals_byte_offset, + "byteLength": len(normals_buffer), + "target": 34962 + }) + accessor_idx = len(accessors) + accessors.append({ + "bufferView": len(buffer_views) - 1, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(normals_np), + "type": "VEC3", + }) + primitive_attributes["NORMAL"] = accessor_idx + primitive = { "attributes": primitive_attributes, "indices": 1, @@ -428,7 +557,7 @@ class SaveGLB(IO.ComfyNode): f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}" ) for i in range(mesh.vertices.shape[0]): - vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i) + vertices_i, faces_i, v_colors, uvs_i, normals_i = get_mesh_batch_item(mesh, i) if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0: logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}") continue @@ -444,6 +573,7 @@ class SaveGLB(IO.ComfyNode): texture_image=tex_img, metallic_roughness_image=mr_img, unlit=getattr(mesh, "unlit", False), + normals=normals_i, ) results.append({ "filename": f, @@ -542,13 +672,89 @@ class RotateMesh(IO.ComfyNode): out.vertices = [rotate(v) for v in mesh.vertices] else: out.vertices = rotate(mesh.vertices) + # Normals are directions; rotate them too (R is orthogonal) so they stay valid. + nrm = getattr(mesh, "normals", None) + if nrm is not None: + out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm) + return IO.NodeOutput(out) + + +class MeshSmoothNormals(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshSmoothNormals", + display_name="Smooth Mesh Normals", + category="3d", + description=( + "Compute smooth per-vertex normals and attach them to the mesh. Meshes " + "without normals are shaded flat (per-face) by glTF viewers; this makes " + "them shade smoothly. With crease_angle below 180, edges sharper than the " + "threshold are kept hard by splitting vertices along them." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Float.Input("crease_angle", default=180.0, min=0.0, max=180.0, step=1.0, + tooltip="Edges whose dihedral angle exceeds this (degrees) stay " + "hard (vertices are split). 180 = fully smooth; lower " + "preserves sharp edges (e.g. ~30-60 for hard-surface)."), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh: Types.MESH, crease_angle: float) -> IO.NodeOutput: + crease = None if crease_angle >= 180.0 else float(crease_angle) + batch_size = mesh.vertices.shape[0] + + if crease is None: + # Fully smooth: topology is unchanged, so just attach a normals tensor that + # matches the existing (possibly zero-padded) vertex layout and keep all fields. + normals_padded = torch.zeros_like(mesh.vertices) + for i in range(batch_size): + v_i, f_i, _, _, _ = get_mesh_batch_item(mesh, i) + if v_i.shape[0] == 0 or f_i.shape[0] == 0: + continue + n_i = _smooth_vertex_normals(v_i.cpu().numpy().astype(np.float32), + f_i.cpu().numpy().astype(np.int64)) + normals_padded[i, :n_i.shape[0]] = torch.from_numpy(n_i).to(mesh.vertices) + out = copy.copy(mesh) + out.normals = normals_padded + return IO.NodeOutput(out) + + # Crease split changes per-item vertex counts -> rebuild as a variable-size batch. + v_list, f_list, n_list = [], [], [] + c_list = [] if mesh.vertex_colors is not None else None + u_list = [] if mesh.uvs is not None else None + for i in range(batch_size): + v_i, f_i, c_i, u_i, _ = get_mesh_batch_item(mesh, i) + if v_i.shape[0] == 0 or f_i.shape[0] == 0: + continue + dev = v_i.device + vo, fo, no, remap = _compute_vertex_normals( + v_i.cpu().numpy().astype(np.float32), + f_i.cpu().numpy().astype(np.int64), crease) + remap_t = torch.from_numpy(remap) + v_list.append(torch.from_numpy(vo).to(dev, mesh.vertices.dtype)) + f_list.append(torch.from_numpy(fo.astype(np.int64)).to(dev, mesh.faces.dtype)) + n_list.append(torch.from_numpy(no).to(dev, mesh.vertices.dtype)) + if c_list is not None: + c_list.append(c_i[remap_t.to(c_i.device)]) + if u_list is not None: + u_list.append(u_i[remap_t.to(u_i.device)]) + if not v_list: + return IO.NodeOutput(mesh) + out = pack_variable_mesh_batch( + v_list, f_list, colors=c_list, uvs=u_list, + texture=mesh.texture, unlit=getattr(mesh, "unlit", False), + normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None)) return IO.NodeOutput(out) class Save3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [SaveGLB, RotateMesh] + return [SaveGLB, RotateMesh, MeshSmoothNormals] async def comfy_entrypoint() -> Save3DExtension: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 04bc19f25..9ff444014 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -333,6 +333,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") model_frame = samples.get("model_frame", "y_up") + coord_resolution = samples.get("coord_resolution") samples = samples["samples"] samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) @@ -358,7 +359,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3: _calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords) - if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: + if coord_resolution is not None: + tex_resolution = int(coord_resolution) * 16 + elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords max_idx = int(spatial.max().item()) + 1 tex_resolution = next((r for r in (256, 512, 1024, 1536, 2048) if r >= max_idx), max_idx) diff --git a/comfy_extras/qem_decimate/qem_core.py b/comfy_extras/qem_decimate/qem_core.py new file mode 100644 index 000000000..7727a9273 --- /dev/null +++ b/comfy_extras/qem_decimate/qem_core.py @@ -0,0 +1,1620 @@ +""" +Pure-PyTorch GPU-parallel QEM mesh simplification. + + - Parallel greedy edge-matching collapse loop + - Plane / line / feature-edge / boundary quadrics, memoryless accumulation + - Normal-flip prevention, link-condition, skinny penalties + - Non-manifold / sliver handling without dropping faces + - Pre/post-clean pipeline (weld, degenerates, small components) +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import math +import time as _time + +import numpy as _np +import torch +from scipy.sparse import coo_matrix +from scipy.sparse.csgraph import connected_components +from tqdm import tqdm as _tqdm +import comfy.utils as _comfy_utils + + +@dataclass +class QEMConfig: + # Precision + dtype: torch.dtype = torch.float32 # float64 much slower on consumer GPUs + + # Numerical conditioning + stabilizer_scale: float = 1e-3 # Tikhonov reg: stabilizer = mesh_scale^2 * this + wander_threshold: float = 2.0 # fall back to midpoint if v* lands > N×edge_length from an endpoint + clamp_v_to_edge: bool = True # project v* onto the edge segment (qem mode only) + + # Vertex placement mode (also selects the collapse driver) + # "midpoint" (default): threshold-schedule driver, most stable. The defaults below match it. + # "qem": sharpest, QEM-optimum placement + ratio driver. + placement_mode: str = "midpoint" + + flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal + + # Per-iteration batch sizing + sampling_cap: int = 10_000_000 # max edges processed per outer iter + max_collapses_fraction: float = 0.25 # of remaining faces-to-remove + max_collapses_floor: int = 10_000 + max_collapses_ceiling: int = 1_000_000 + max_collapses_relative_cap: float = 0.10 # cap per-iter collapses as fraction of current faces; 0 disables + + # Loop control + max_iterations: int = 5_000 + compaction_period: int = 5 + compaction_threshold: float = 0.85 # compact when alive_frac < this + + # Quality knobs + boundary_quadrics: bool = True + boundary_weight: float = 1000.0 + recompute_normals_post: bool = True + line_quadric_weight: float = 0.0 # penalise deviation ⟂ to edge dir → more uniform verts; 0 disables + line_quadric_skip_opposite_normals_cos: float = 0.0 # skip line quadrics on edges with endpoint cos < this + + # Feature-edge quadrics on sharp interior edges (dihedral > min); 0 disables. + feature_edge_quadric_weight: float = 0.0 + feature_edge_min_dihedral_deg: float = 30.0 + + # Flip check (FA-QEM §3.3) + quality_topk_multiplier: int = 4 # quality-check band size = this * max_collapses_per_iter + flip_cos_threshold: float = 0.0 # 0 = count any sign reversal (dihedral > 90°) + flip_check_max_degree: int = 16 # cap on vertex degree for the flip-check table + + # Triangle shape penalty + skinny_weight: float = 1e-3 # penalise top-K collapses producing needle/sliver tris; 0 disables + + # Topology preservation + enforce_link_condition: bool = True # reject collapses that violate the link condition + + # Quadric area weighting + area_weighted_quadrics: bool = False # True: Garland-Heckbert area-weighted; False: un-weighted + + # edge-length cost regularizer + lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables + lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median + + # Threshold-schedule driver (placement_mode == "midpoint") + # Cost-threshold schedule: each round collapses a disjoint set with cost <= thresh, ×10 when < 1% removed. + threshold_start: float = 1e-8 + memoryless_qem: bool = True # rebuild quadrics each round vs accumulate + repair_nonmanifold: bool = True # final repair_non_manifold_edges pass + + # Pre-clean (input mesh) + preclean: bool = True # weld coincident verts, drop degenerate/duplicate/unused + + # Post-clean (output mesh) + postclean: bool = True # remove slivers, tiny components, unused verts left by collapse + postclean_min_angle_deg: float = 0.5 + postclean_max_aspect_ratio: float = 100.0 + postclean_min_component_faces: int = 8 # drop components with fewer faces than this + + # Preclean tuning + preclean_weld_epsilon_rel: float = 1e-5 # weld tolerance as fraction of bbox diagonal + preclean_min_component_faces: int = 0 # 0 = keep all components + + + @property + def threshold_driver(self) -> bool: + """The cost-threshold collapse driver is used by the midpoint placement mode.""" + return self.placement_mode == "midpoint" + + +def _sorted_edge_halfedges( + faces: torch.Tensor, num_verts: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """3F half-edges sorted by key min(a,b)*(V+1)+max(a,b); returns (sorted_keys, face_ids, slot_ids).""" + device = faces.device + F = faces.shape[0] + e_all = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + e_sorted, _ = torch.sort(e_all, dim=1) + P = num_verts + 1 + key = e_sorted[:, 0].long() * P + e_sorted[:, 1].long() + face_per_he = torch.arange(F, device=device, dtype=torch.long).repeat(3) + slot_per_he = torch.arange(3, device=device, dtype=torch.long).repeat_interleave(F) + sort_idx = torch.argsort(key) + return key[sort_idx], face_per_he[sort_idx], slot_per_he[sort_idx] + + +def _vert_is_boundary_mask(faces: torch.Tensor, num_verts: int) -> torch.Tensor: + """(V,) bool mask: True for verts incident to any boundary edge.""" + device = faces.device + out = torch.zeros(num_verts, dtype=torch.bool, device=device) + bedges = _detect_boundary_edges(faces, num_verts) + if bedges.numel() == 0: + return out + out[bedges[:, 0]] = True + out[bedges[:, 1]] = True + return out + + +def _detect_boundary_edges(faces: torch.Tensor, num_verts: int) -> torch.Tensor: + """Boundary edges as [N, 2] of vertex indices (each appearing in exactly one face).""" + if faces.numel() == 0: + return torch.empty((0, 2), dtype=torch.int64, device=faces.device) + sorted_keys, _, _ = _sorted_edge_halfedges(faces, num_verts) + unique_key, counts = torch.unique(sorted_keys, return_counts=True) + boundary_key = unique_key[counts == 1] + if boundary_key.numel() == 0: + return torch.empty((0, 2), dtype=torch.int64, device=faces.device) + P = num_verts + 1 + bv0 = boundary_key // P + bv1 = boundary_key % P + return torch.stack([bv0, bv1], dim=1) + + +def _manifold_edge_pairs( + sorted_keys: torch.Tensor, sorted_faces: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Edges shared by exactly 2 faces (filters >2-incident groups); returns (pair_keys, fa, fb).""" + if sorted_keys.shape[0] < 2: + empty = sorted_keys.new_empty(0) + return empty, empty, empty + pair_mask = sorted_keys[:-1] == sorted_keys[1:] + if not pair_mask.any(): + empty = sorted_keys.new_empty(0) + return empty, empty, empty + pair_starts = torch.nonzero(pair_mask, as_tuple=True)[0] + # manifold iff neither neighbour half-edge has the same key + cur = sorted_keys[pair_starts] + prev_ok = (pair_starts == 0) | (sorted_keys[(pair_starts - 1).clamp_min(0)] != cur) + nxt_idx = (pair_starts + 2).clamp(max=sorted_keys.shape[0] - 1) + nxt_ok = (pair_starts + 2 >= sorted_keys.shape[0]) | (sorted_keys[nxt_idx] != cur) + pair_starts = pair_starts[prev_ok & nxt_ok] + return (sorted_keys[pair_starts], + sorted_faces[pair_starts], + sorted_faces[pair_starts + 1]) + + +def _line_quadric_planes( + pa: torch.Tensor, pb: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Two plane equations (E,4) per edge whose squared-dist sum = squared ⟂ distance to the edge line.""" + e = pb - pa # (E, 3) + elen = torch.norm(e, dim=-1, keepdim=True).clamp_min(1e-12) + e_unit = e / elen # (E, 3) + m = 0.5 * (pa + pb) # (E, 3) + # helper axis not parallel to e_unit + helper = torch.zeros_like(e_unit) + helper.scatter_(-1, e_unit.abs().argmin(dim=-1, keepdim=True), 1.0) + # Gram-Schmidt against e_unit + u = helper - (helper * e_unit).sum(-1, keepdim=True) * e_unit + u = u / torch.norm(u, dim=-1, keepdim=True).clamp_min(1e-12) + w = torch.cross(e_unit, u, dim=-1) + d_u = -(u * m).sum(-1, keepdim=True) + d_w = -(w * m).sum(-1, keepdim=True) + p_u = torch.cat([u, d_u], dim=-1) # (E, 4) + p_w = torch.cat([w, d_w], dim=-1) + return p_u, p_w, elen.squeeze(-1) + + +def _add_line_quadrics( + verts: torch.Tensor, + faces: torch.Tensor, + face_areas: torch.Tensor, + Q_flat: torch.Tensor, + weight: float, + skip_he_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Add line quadrics for all 3F half-edges, weighted by face_area*weight; skip_he_mask zeroes True positions.""" + a_all = torch.cat([faces[:, 0], faces[:, 1], faces[:, 2]], dim=0).long() + b_all = torch.cat([faces[:, 1], faces[:, 2], faces[:, 0]], dim=0).long() + pa = verts[a_all] + pb = verts[b_all] + p_u, p_w, _ = _line_quadric_planes(pa, pb) + area_per_edge = face_areas.repeat(3) + w_per_edge = area_per_edge * weight + if skip_he_mask is not None: + w_per_edge = torch.where(skip_he_mask, torch.zeros_like(w_per_edge), w_per_edge) + w_per_edge = w_per_edge.unsqueeze(-1).unsqueeze(-1) + K_line = ( + p_u.unsqueeze(-1) * p_u.unsqueeze(-2) + + p_w.unsqueeze(-1) * p_w.unsqueeze(-2) + ) * w_per_edge + K_flat = K_line.reshape(-1, 16) + Q_flat.scatter_add_(0, a_all.unsqueeze(1).expand(-1, 16), K_flat) # scatter to both endpoints + Q_flat.scatter_add_(0, b_all.unsqueeze(1).expand(-1, 16), K_flat) + return Q_flat + + +def _build_quadrics( + verts: torch.Tensor, + faces: torch.Tensor, + cfg: QEMConfig, +) -> torch.Tensor: + """Per-vertex area-weighted quadric (V, 4, 4).""" + V = verts.shape[0] + dtype = verts.dtype + device = verts.device + + Q_flat = torch.zeros((V, 16), dtype=dtype, device=device) + + if faces.numel() > 0: + v0 = verts[faces[:, 0]] + v1 = verts[faces[:, 1]] + v2 = verts[faces[:, 2]] + e1 = v1 - v0 + e2 = v2 - v0 + n = torch.cross(e1, e2, dim=-1) + area = torch.norm(n, dim=-1) + mask = area > 1e-12 + # where() instead of boolean-index gather+scatter (fewer index kernels) + n_norm = torch.where(mask.unsqueeze(-1), + n / area.unsqueeze(-1).clamp_min(1e-12), + n.new_zeros(())) + d = -(n_norm * v0).sum(dim=-1, keepdim=True) + p = torch.cat([n_norm, d], dim=-1) # (F, 4) + K = torch.einsum("fi,fj->fij", p, p) # (F, 4, 4) + + if cfg.area_weighted_quadrics: + K.mul_(area[:, None, None]) + K_flat = K.reshape(-1, 16) + for corner in range(3): + idx = faces[:, corner].unsqueeze(1).expand(-1, 16) + Q_flat.scatter_add_(0, idx, K_flat) + + # Line quadrics: squared ⟂ distance from v to the edge-midpoint line, all 3F half-edges in one pass. + if cfg.line_quadric_weight > 0 and faces.numel() > 0: + # skip thin-shell rim edges (endpoint normals oppose) + skip_he_sharp = None + if cfg.line_quadric_skip_opposite_normals_cos < 1.0: + v_norm = torch.zeros((V, 3), dtype=dtype, device=device) + n_weighted = n_norm * area.unsqueeze(-1) # face normal * 2× area + for corner in range(3): + v_norm.scatter_add_(0, faces[:, corner].unsqueeze(-1).expand(-1, 3), + n_weighted) + v_norm = torch.nn.functional.normalize(v_norm, p=2, dim=-1, eps=1e-12) + a_he = torch.cat([faces[:, 0], faces[:, 1], faces[:, 2]], dim=0).long() + b_he = torch.cat([faces[:, 1], faces[:, 2], faces[:, 0]], dim=0).long() + cos_endpoints = (v_norm[a_he] * v_norm[b_he]).sum(dim=-1) + skip_he_sharp = cos_endpoints < cfg.line_quadric_skip_opposite_normals_cos + if not skip_he_sharp.any(): + skip_he_sharp = None + Q_flat = _add_line_quadrics(verts, faces, area, Q_flat, + cfg.line_quadric_weight, + skip_he_mask=skip_he_sharp) + + # Boundary line quadrics: pin boundary-edge endpoints to the boundary line. + if cfg.boundary_quadrics and faces.numel() > 0: + b_edges = _detect_boundary_edges(faces, V) + if b_edges.shape[0] > 0: + ba = b_edges[:, 0] + bb = b_edges[:, 1] + pa = verts[ba] + pb = verts[bb] + p_u, p_w, _ = _line_quadric_planes(pa, pb) + K_b = (torch.einsum("ei,ej->eij", p_u, p_u) + + torch.einsum("ei,ej->eij", p_w, p_w)) * cfg.boundary_weight + K_b_flat = K_b.reshape(-1, 16) + Q_flat.scatter_add_(0, ba.unsqueeze(1).expand(-1, 16), K_b_flat) + Q_flat.scatter_add_(0, bb.unsqueeze(1).expand(-1, 16), K_b_flat) + + # Feature-edge quadrics: line quadric on sharp interior edges weighted by (1 - cos(dihedral)). + if cfg.feature_edge_quadric_weight > 0 and faces.numel() > 0: + v0 = verts[faces[:, 0]] + v1 = verts[faces[:, 1]] + v2 = verts[faces[:, 2]] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = torch.nn.functional.normalize(fn, p=2, dim=-1, eps=1e-12) + sorted_keys_fe, sorted_faces_fe, _ = _sorted_edge_halfedges(faces, V) + pair_keys, f1_idx, f2_idx = _manifold_edge_pairs(sorted_keys_fe, sorted_faces_fe) + if pair_keys.numel() > 0: + P = V + 1 + edge_a = pair_keys // P + edge_b = pair_keys % P + cos_dihedral = (fn[f1_idx] * fn[f2_idx]).sum(dim=-1) + cos_thresh = math.cos(math.radians(cfg.feature_edge_min_dihedral_deg)) + sharp = cos_dihedral < cos_thresh + if sharp.any(): + fa = edge_a[sharp] + fb = edge_b[sharp] + p_u, p_w, _ = _line_quadric_planes(verts[fa], verts[fb]) + sharpness = (1.0 - cos_dihedral[sharp]).clamp_min(0.0) + avg_area = 0.5 * (area[f1_idx[sharp]] + area[f2_idx[sharp]]) + w = (avg_area * sharpness * cfg.feature_edge_quadric_weight) \ + .unsqueeze(-1).unsqueeze(-1) + K_feat = ( + p_u.unsqueeze(-1) * p_u.unsqueeze(-2) + + p_w.unsqueeze(-1) * p_w.unsqueeze(-2) + ) * w + K_flat = K_feat.reshape(-1, 16) + Q_flat.scatter_add_(0, fa.unsqueeze(1).expand(-1, 16), K_flat) + Q_flat.scatter_add_(0, fb.unsqueeze(1).expand(-1, 16), K_flat) + + return Q_flat.reshape(V, 4, 4) + + +def _edge_errors( + verts: torch.Tensor, + Q: torch.Tensor, + edges: torch.Tensor, + stabilizer: float, + max_edge_length_sq: float, + mesh_scale_sq: float, + cfg: QEMConfig, + vert_is_boundary: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (optimal_pos, error, valid_mask); vert_is_boundary enables boundary-aware midpoint.""" + n_edges = edges.shape[0] + dtype = verts.dtype + device = verts.device + + if n_edges == 0: + return ( + torch.empty((0, 3), dtype=dtype, device=device), + torch.empty((0,), dtype=dtype, device=device), + torch.zeros((0,), dtype=torch.bool, device=device), + ) + + verts_pair = verts[edges] # (E, 2, 3) + pa = verts_pair[:, 0] + pb = verts_pair[:, 1] + edge_vec = pb - pa + el = torch.norm(edge_vec, dim=-1) + + # boundary-aware midpoint: snap to the boundary endpoint when exactly one is boundary + if vert_is_boundary is not None: + ba = vert_is_boundary[edges[:, 0]] + bb = vert_is_boundary[edges[:, 1]] + w_a = torch.where(ba & ~bb, torch.ones_like(el), + torch.where(~ba & bb, torch.zeros_like(el), + torch.full_like(el, 0.5))) + midpoint = pa * w_a.unsqueeze(-1) + pb * (1.0 - w_a).unsqueeze(-1) + else: + midpoint = torch.lerp(pa, pb, 0.5) + + Qe = Q[edges].sum(dim=1) # (E, 4, 4) — sum of Q[va] and Q[vb] + + if cfg.placement_mode == "midpoint": + opt = midpoint + else: + A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype) * stabilizer + b = -Qe[:, :3, 3].unsqueeze(-1) + + # stabilizer keeps A invertible; solve full batch and pick midpoint via where (no host sync) + sol = torch.linalg.solve(A, b) + dets = torch.det(A) + good = (dets.abs() > 1e-12).unsqueeze(-1) + opt = torch.where(good, sol.squeeze(-1), midpoint) + + if cfg.clamp_v_to_edge: + # qem mode + clamp: project v* onto the edge segment (subsumes the wander check) + edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20 + t = ((opt - pa) * edge_vec).sum(dim=-1) / edge_len_sq + t = t.clamp(0.0, 1.0).unsqueeze(-1) + opt = torch.lerp(pa, pb, t) + else: + # qem mode + no clamp: fall back to midpoint when v* wanders from both endpoints + dist_a = torch.norm(opt - pa, dim=-1) + dist_b = torch.norm(opt - pb, dim=-1) + wander_bad = ((dist_a > cfg.wander_threshold * el) | + (dist_b > cfg.wander_threshold * el)).unsqueeze(-1) + opt = torch.where(wander_bad, midpoint, opt) + + v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1) + err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4)) + + # mesh_scale_sq may be Python float or 0-d tensor + if torch.is_tensor(mesh_scale_sq): + length_ok = el * el > mesh_scale_sq * 1e-10 + else: + length_ok = el > math.sqrt(mesh_scale_sq) * 1e-5 + error_ok = err < max_edge_length_sq + nan_ok = ~torch.isnan(opt).any(dim=-1) & ~torch.isnan(err) + valid = length_ok & error_ok & nan_ok + + # edge-length regularizer: bias collapse order toward short edges (uniform sizing) + if cfg.lambda_edge_length > 0.0 and valid.any(): + el2 = el * el + if cfg.lambda_edge_length_absolute: + err = err + cfg.lambda_edge_length * el2 + else: + qem_med = err[valid].median() + len_med = el2[valid].median().clamp_min(1e-30) + err = err + cfg.lambda_edge_length * el2 * (qem_med / len_med) + return opt, err, valid + + +def _greedy_matching( + edges: torch.Tensor, + err: torch.Tensor, + v_alive: torch.Tensor, + max_select: int, +) -> torch.Tensor: + """Vectorised independent edge-set selection: an edge wins iff it is the min-key edge at both endpoints.""" + device = edges.device + n_edges = edges.shape[0] + if n_edges == 0: + return torch.empty(0, dtype=torch.int64, device=device) + + va = edges[:, 0] + vb = edges[:, 1] + num_verts = v_alive.shape[0] + + err32 = err.to(torch.float32).clamp(min=0).contiguous() + err_bits = err32.view(torch.int32).to(torch.int64) & 0xFFFFFFFF + edge_idx = torch.arange(n_edges, device=device, dtype=torch.int64) + key = (err_bits << 32) | edge_idx + + INT64_MAX = torch.iinfo(torch.int64).max + best_key = torch.full((num_verts,), INT64_MAX, dtype=torch.int64, device=device) + best_key.scatter_reduce_(0, va, key, reduce="amin", include_self=True) + best_key.scatter_reduce_(0, vb, key, reduce="amin", include_self=True) + + is_winner = (key == best_key[va]) & (key == best_key[vb]) & v_alive[va] & v_alive[vb] + sel = torch.nonzero(is_winner, as_tuple=True)[0] + + if sel.numel() > max_select: + sel_err = err[sel] + top = torch.topk(sel_err, max_select, largest=False).indices + sel = sel[top] + return sel + + +def _build_vert_to_faces_pad( + faces: torch.Tensor, + num_verts: int, + max_deg: int, +) -> torch.Tensor: + """Pad-CSR vertex-to-incident-faces table (V, max_deg) of face indices, -1 padded, degree truncated.""" + device = faces.device + F = faces.shape[0] + if F == 0: + return torch.full((num_verts, max_deg), -1, dtype=torch.int64, device=device) + v_rep = faces.flatten().long() + f_rep = torch.arange(F, device=device, dtype=torch.int64).repeat_interleave(3) + sort_idx = v_rep.argsort() + sorted_v = v_rep[sort_idx] + sorted_f = f_rep[sort_idx] + offsets = torch.searchsorted( + sorted_v, torch.arange(num_verts + 1, device=device, dtype=sorted_v.dtype) + ) + slot = torch.arange(sorted_v.shape[0], device=device, dtype=torch.int64) - offsets[sorted_v] + keep = slot < max_deg + table = torch.full((num_verts, max_deg), -1, dtype=torch.int64, device=device) + table[sorted_v[keep], slot[keep]] = sorted_f[keep] + return table + + +def _normal_flip_mask( + verts: torch.Tensor, # (V, 3) + faces: torch.Tensor, # (F, 3) — must be alive faces only + edges: torch.Tensor, # (E, 2) candidate collapse edges + opt: torch.Tensor, # (E, 3) proposed collapse positions + vert_to_faces: torch.Tensor, # (V, max_deg) face indices or -1 + cos_threshold: float = 0.0, + chunk_size: int = 100_000, + return_count: bool = False, +) -> torch.Tensor: + """(E,) bool mask (no adjacent-face flip), or int count of would-flip faces per edge if return_count.""" + E = edges.shape[0] + device = verts.device + if return_count: + out = torch.zeros(E, dtype=torch.int32, device=device) + else: + out = torch.ones(E, dtype=torch.bool, device=device) + if E == 0: + return out + + max_deg = vert_to_faces.shape[1] + a_all = edges[:, 0] + b_all = edges[:, 1] + + for start in range(0, E, chunk_size): + stop = min(start + chunk_size, E) + Ec = stop - start + a = a_all[start:stop] + b = b_all[start:stop] + oc = opt[start:stop] + + fa = vert_to_faces[a] # (Ec, max_deg) + fb = vert_to_faces[b] + all_f = torch.cat([fa, fb], dim=1) # (Ec, 2*max_deg) + valid_f = all_f >= 0 + all_f_safe = all_f.clamp(min=0) + fv = faces[all_f_safe] # (Ec, 2*max_deg, 3) + + a_b = a.view(Ec, 1) + b_b = b.view(Ec, 1) + s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b + s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b + s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b + contains_a = s0_a | s1_a | s2_a + contains_b = s0_b | s1_b | s2_b + # affected: face contains exactly one of {a, b} and slot is non-pad + affected = (contains_a ^ contains_b) & valid_f + if not affected.any(): + continue + + p0 = verts[fv[..., 0]] # (Ec, 2*max_deg, 3) + p1 = verts[fv[..., 1]] + p2 = verts[fv[..., 2]] + n_old = torch.cross(p1 - p0, p2 - p0, dim=-1) + + opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * max_deg, -1) + rep0 = (s0_a | s0_b).unsqueeze(-1) + rep1 = (s1_a | s1_b).unsqueeze(-1) + rep2 = (s2_a | s2_b).unsqueeze(-1) + p0n = torch.where(rep0, opt_b, p0) + p1n = torch.where(rep1, opt_b, p1) + p2n = torch.where(rep2, opt_b, p2) + n_new = torch.cross(p1n - p0n, p2n - p0n, dim=-1) + + nlen_old = torch.norm(n_old, dim=-1) + nlen_new = torch.norm(n_new, dim=-1) + # degenerate-before faces can't meaningfully flip; treat as OK + denom = nlen_old * nlen_new + safe = denom > 1e-20 + cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20), + torch.ones_like(denom)) + flip = (cos < cos_threshold) & affected & safe + if return_count: + out[start:stop] = flip.sum(dim=-1).to(torch.int32) + else: + out[start:stop] = ~flip.any(dim=-1) + + return out + + +def _link_condition_mask( + faces: torch.Tensor, # (F, 3) alive faces only + edges: torch.Tensor, # (E, 2) candidate collapse edges + vert_to_faces: torch.Tensor, # (V, max_deg) face idx or -1 + chunk_size: int = 100_000, +) -> torch.Tensor: + """(E,) bool mask — True where the collapse is topology-safe (link condition: common neighbours <= edge faces).""" + E = edges.shape[0] + device = faces.device + out = torch.ones(E, dtype=torch.bool, device=device) + if E == 0: + return out + D = vert_to_faces.shape[1] + a_all = edges[:, 0] + b_all = edges[:, 1] + + for s in range(0, E, chunk_size): + e = min(s + chunk_size, E) + a = a_all[s:e]; b = b_all[s:e] + Ec = a.shape[0] + + fa = vert_to_faces[a] # (Ec, D) + fb = vert_to_faces[b] + fa_ok = fa >= 0; fb_ok = fb >= 0 + fav = faces[fa.clamp(min=0)] # (Ec, D, 3) + fbv = faces[fb.clamp(min=0)] + + # neighbour verts of a/b: take the 2 non-anchor verts per incident face → (Ec, 2D) + a_b = a[:, None]; b_b = b[:, None] + an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0]) + an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2]) + bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0]) + bn2 = torch.where(fbv[..., 2] == b_b, fbv[..., 1], fbv[..., 2]) + na = torch.stack([an1, an2], dim=-1).reshape(Ec, 2 * D) + nb = torch.stack([bn1, bn2], dim=-1).reshape(Ec, 2 * D) + fa_okx = fa_ok.repeat_interleave(2, dim=1) + fb_okx = fb_ok.repeat_interleave(2, dim=1) + na[(na == a_b) | (na == b_b) | ~fa_okx] = -1 + nb[(nb == a_b) | (nb == b_b) | ~fb_okx] = -1 + + # common neighbours: na entries also appearing in nb + in_b = (na[:, :, None] == nb[:, None, :]) & (na[:, :, None] >= 0) + na_common = torch.where(in_b.any(dim=2), na, torch.full_like(na, -1)) + # distinct count of common neighbours per edge (sort + count transitions) + cs, _ = na_common.sort(dim=1) + count_common = ((cs[:, 1:] != cs[:, :-1]) & (cs[:, 1:] >= 0)).sum(dim=1) \ + + (cs[:, :1] >= 0).sum(dim=1) + + # faces on the edge = a's faces also containing b + count_faces = ((fav == b[:, None, None]).any(dim=2) & fa_ok).sum(dim=1) + + out[s:e] = count_common <= count_faces + + return out + + +def _skinny_penalty( + verts: torch.Tensor, # (V, 3) + faces: torch.Tensor, # (F, 3) — alive faces only + edges: torch.Tensor, # (E, 2) candidate collapse edges + opt: torch.Tensor, # (E, 3) proposed collapse positions + vert_to_faces: torch.Tensor, # (V, max_deg) + chunk_size: int = 100_000, +) -> torch.Tensor: + """Per-edge post-collapse triangle-shape penalty (lambda_skinny); mean of 1 - clamp(shape,0,1) over the 1-ring.""" + E = edges.shape[0] + device = verts.device + out = torch.zeros(E, dtype=verts.dtype, device=device) + if E == 0: + return out + + max_deg = vert_to_faces.shape[1] + a_all = edges[:, 0] + b_all = edges[:, 1] + sqrt3_4 = 4.0 * math.sqrt(3.0) + + for start in range(0, E, chunk_size): + stop = min(start + chunk_size, E) + Ec = stop - start + a = a_all[start:stop] + b = b_all[start:stop] + oc = opt[start:stop] + + fa = vert_to_faces[a] + fb = vert_to_faces[b] + all_f = torch.cat([fa, fb], dim=1) + valid_f = all_f >= 0 + all_f_safe = all_f.clamp(min=0) + fv = faces[all_f_safe] + + a_b = a.view(Ec, 1) + b_b = b.view(Ec, 1) + s0_a = fv[..., 0] == a_b + s0_b = fv[..., 0] == b_b + s1_a = fv[..., 1] == a_b + s1_b = fv[..., 1] == b_b + s2_a = fv[..., 2] == a_b + s2_b = fv[..., 2] == b_b + contains_a = s0_a | s1_a | s2_a + contains_b = s0_b | s1_b | s2_b + # affected: face contains exactly one of {a, b} and slot is non-pad + affected = (contains_a ^ contains_b) & valid_f + if not affected.any(): + continue + + p0 = verts[fv[..., 0]] + p1 = verts[fv[..., 1]] + p2 = verts[fv[..., 2]] + opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * max_deg, -1) + rep0 = (s0_a | s0_b).unsqueeze(-1) + rep1 = (s1_a | s1_b).unsqueeze(-1) + rep2 = (s2_a | s2_b).unsqueeze(-1) + p0n = torch.where(rep0, opt_b, p0) + p1n = torch.where(rep1, opt_b, p1) + p2n = torch.where(rep2, opt_b, p2) + + e01 = p1n - p0n + e02 = p2n - p0n + e12 = p2n - p1n + two_area = torch.cross(e01, e02, dim=-1).norm(dim=-1) + edge_sum_sq = ((e01 * e01).sum(-1) + + (e02 * e02).sum(-1) + + (e12 * e12).sum(-1)) + shape = (sqrt3_4 * 0.5 * two_area) / edge_sum_sq.clamp_min(1e-20) + term = 1.0 - shape.clamp(0.0, 1.0) + term = torch.where(affected, term, torch.zeros_like(term)) + n_affected = affected.sum(dim=-1).clamp_min(1).to(term.dtype) + out[start:stop] = term.sum(dim=-1) / n_affected + + return out + + +def _quality_checks_fused( + verts: torch.Tensor, + faces: torch.Tensor, + edges: torch.Tensor, + opt: torch.Tensor, + vert_to_faces: torch.Tensor, + cos_threshold: float = 0.0, + want_flip: bool = True, + want_skinny: bool = True, + want_link: bool = False, + chunk_size: int = 100_000, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Fused 1-ring checks (flip count / skinny / link) sharing one faces[v_to_f] gather; returns (flip_count|None, skinny|None, link_safe|None).""" + E = edges.shape[0] + device = verts.device + flip_out = torch.zeros(E, dtype=torch.int32, device=device) if want_flip else None + skinny_out = torch.zeros(E, dtype=verts.dtype, device=device) if want_skinny else None + link_out = torch.ones(E, dtype=torch.bool, device=device) if want_link else None + if E == 0: + return flip_out, skinny_out, link_out + + D = vert_to_faces.shape[1] + a_all = edges[:, 0] + b_all = edges[:, 1] + sqrt3_4 = 4.0 * math.sqrt(3.0) + need_geom = want_flip or want_skinny + + for start in range(0, E, chunk_size): + stop = min(start + chunk_size, E) + Ec = stop - start + a = a_all[start:stop] + b = b_all[start:stop] + + # shared gather: a's and b's incident faces (the expensive part) + fa = vert_to_faces[a] + fb = vert_to_faces[b] + all_f = torch.cat([fa, fb], dim=1) # (Ec, 2D) + valid_f = all_f >= 0 + fv = faces[all_f.clamp(min=0)] # (Ec, 2D, 3) + a_b = a.view(Ec, 1) + b_b = b.view(Ec, 1) + + if need_geom: + oc = opt[start:stop] + s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b + s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b + s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b + contains_a = s0_a | s1_a | s2_a + contains_b = s0_b | s1_b | s2_b + affected = (contains_a ^ contains_b) & valid_f + if affected.any(): + p0 = verts[fv[..., 0]] + p1 = verts[fv[..., 1]] + p2 = verts[fv[..., 2]] + opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * D, -1) + rep0 = (s0_a | s0_b).unsqueeze(-1) + rep1 = (s1_a | s1_b).unsqueeze(-1) + rep2 = (s2_a | s2_b).unsqueeze(-1) + p0n = torch.where(rep0, opt_b, p0) + p1n = torch.where(rep1, opt_b, p1) + p2n = torch.where(rep2, opt_b, p2) + + # post-collapse normal (skinny's two_area == flip's ‖n_new‖) + e01 = p1n - p0n + e02 = p2n - p0n + n_new = torch.cross(e01, e02, dim=-1) + nlen_new = torch.norm(n_new, dim=-1) + + if want_flip: + n_old = torch.cross(p1 - p0, p2 - p0, dim=-1) + nlen_old = torch.norm(n_old, dim=-1) + denom = nlen_old * nlen_new + safe = denom > 1e-20 + cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20), + torch.ones_like(denom)) + flip = (cos < cos_threshold) & affected & safe + flip_out[start:stop] = flip.sum(dim=-1).to(torch.int32) + + if want_skinny: + e12 = p2n - p1n + edge_sum_sq = ((e01 * e01).sum(-1) + (e02 * e02).sum(-1) + (e12 * e12).sum(-1)) + shape = (sqrt3_4 * 0.5 * nlen_new) / edge_sum_sq.clamp_min(1e-20) + term = 1.0 - shape.clamp(0.0, 1.0) + term = torch.where(affected, term, torch.zeros_like(term)) + n_affected = affected.sum(dim=-1).clamp_min(1).to(term.dtype) + skinny_out[start:stop] = term.sum(dim=-1) / n_affected + + if want_link: + # reuses fv / valid_f; matches _link_condition_mask + fa_ok = valid_f[:, :D] + fb_ok = valid_f[:, D:] + fav = fv[:, :D] + fbv = fv[:, D:] + an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0]) + an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2]) + bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0]) + bn2 = torch.where(fbv[..., 2] == b_b, fbv[..., 1], fbv[..., 2]) + na = torch.stack([an1, an2], dim=-1).reshape(Ec, 2 * D) + nb = torch.stack([bn1, bn2], dim=-1).reshape(Ec, 2 * D) + fa_okx = fa_ok.repeat_interleave(2, dim=1) + fb_okx = fb_ok.repeat_interleave(2, dim=1) + na[(na == a_b) | (na == b_b) | ~fa_okx] = -1 + nb[(nb == a_b) | (nb == b_b) | ~fb_okx] = -1 + in_b = (na[:, :, None] == nb[:, None, :]) & (na[:, :, None] >= 0) + na_common = torch.where(in_b.any(dim=2), na, torch.full_like(na, -1)) + cs, _ = na_common.sort(dim=1) + count_common = ((cs[:, 1:] != cs[:, :-1]) & (cs[:, 1:] >= 0)).sum(dim=1) \ + + (cs[:, :1] >= 0).sum(dim=1) + count_faces = ((fav == b[:, None, None]).any(dim=2) & fa_ok).sum(dim=1) + link_out[start:stop] = count_common <= count_faces + + return flip_out, skinny_out, link_out + + +def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: + if faces.numel() == 0: + return torch.zeros_like(verts) + faces_long = faces.to(torch.int64) + i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2] + v0, v1, v2 = verts[i0], verts[i1], verts[i2] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + vn = torch.zeros_like(verts) + vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn) + vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn) + vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn) + return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6) + + +# Public API + +@dataclass +class CleanStats: + in_verts: int = 0 + in_faces: int = 0 + out_verts: int = 0 + out_faces: int = 0 + welded_verts: int = 0 # how many vertex IDs collapsed during welding + degenerate_faces: int = 0 # zero-area or repeated-index faces removed + duplicate_faces: int = 0 # same vertex-set removed + unused_verts: int = 0 # verts not in any face removed + components_dropped: int = 0 # disconnected components below threshold + seconds: float = 0.0 + + def __str__(self): + return (f"clean: in={self.in_verts}v/{self.in_faces}f -> " + f"out={self.out_verts}v/{self.out_faces}f " + f"(welded {self.welded_verts}v, degen {self.degenerate_faces}f, " + f"dup {self.duplicate_faces}f, unused {self.unused_verts}v, " + f"comps {self.components_dropped}) {self.seconds*1000:.1f}ms") + + +def _weld_vertices( + verts: torch.Tensor, faces: torch.Tensor, epsilon, + colors: Optional[torch.Tensor] = None, + normals: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], int]: + """Merge vertices closer than epsilon (L_inf grid), cluster-averaging attributes; returns (v, f, colors, normals, n_welded).""" + if verts.shape[0] == 0: + return verts, faces, colors, normals, 0 + device = verts.device + scale = 1.0 / epsilon + bbox_min = verts.min(dim=0)[0] + q = ((verts - bbox_min) * scale).round().to(torch.int64) + bbox = (verts.max(dim=0)[0] - bbox_min) + extent = (bbox * scale).round().to(torch.int64) + 2 + key = (q[:, 0] * extent[1] + q[:, 1]) * extent[2] + q[:, 2] # pack 3D quantized pos to 1D key + unique_key, inv = torch.unique(key, return_inverse=True) + n_unique = unique_key.shape[0] + if n_unique == verts.shape[0]: + return verts, faces, colors, normals, 0 + counts = torch.zeros(n_unique, dtype=verts.dtype, device=device) + counts.scatter_add_(0, inv, torch.ones(verts.shape[0], dtype=verts.dtype, device=device)) + counts_div = counts.unsqueeze(-1).clamp_min(1.0) + + new_verts = torch.zeros((n_unique, 3), dtype=verts.dtype, device=device) + new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(verts), verts) + new_verts = new_verts / counts_div + + new_colors = None + if colors is not None: + new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors) + new_colors = new_colors / counts_div.to(colors.dtype) + + new_normals = None + if normals is not None: + new_normals = torch.zeros((n_unique, normals.shape[1]), dtype=normals.dtype, device=device) + new_normals.scatter_add_(0, inv.unsqueeze(-1).expand_as(normals), normals) + new_normals = torch.nn.functional.normalize(new_normals, p=2, dim=-1, eps=1e-6) + + new_faces = inv[faces.long()] if faces.numel() > 0 else faces + return new_verts, new_faces, new_colors, new_normals, int(verts.shape[0] - n_unique) + + +def _drop_degenerate_faces( + verts: torch.Tensor, faces: torch.Tensor, + min_area: float = 1e-14, +) -> Tuple[torch.Tensor, int]: + """Drop degenerate-by-construction faces (repeated indices or zero-area); slivers go to _collapse_slivers.""" + if faces.numel() == 0: + return faces, 0 + idx_bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 0] == faces[:, 2]) + f_good = faces[~idx_bad] + v0 = verts[f_good[:, 0]]; v1 = verts[f_good[:, 1]]; v2 = verts[f_good[:, 2]] + e0 = v1 - v0 + e2 = v0 - v2 + area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1) + bad = area < min_area + kept = f_good[~bad] + n_dropped = idx_bad.sum() + bad.sum() # tensor-scalar; caller .item()s once + return kept, n_dropped + + +def _collapse_slivers( + verts: torch.Tensor, faces: torch.Tensor, + min_angle_deg: float = 0.0, + max_aspect_ratio: float = 0.0, +) -> Tuple[torch.Tensor, int]: + """Resolve sliver triangles by collapsing each sliver's shortest edge (no holes); returns (faces, n_collapsed).""" + if faces.numel() == 0 or (min_angle_deg <= 0 and max_aspect_ratio <= 0): + return faces, 0 + + fl = faces.long() + v0 = verts[fl[:, 0]]; v1 = verts[fl[:, 1]]; v2 = verts[fl[:, 2]] + e0 = v1 - v0 + e1 = v2 - v1 + e2 = v0 - v2 + l0 = torch.norm(e0, dim=-1) + l1 = torch.norm(e1, dim=-1) + l2 = torch.norm(e2, dim=-1) + area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1) + + bad = torch.zeros(faces.shape[0], dtype=torch.bool, device=verts.device) + if max_aspect_ratio > 0: + max_edge = torch.maximum(torch.maximum(l0, l1), l2) + aspect = max_edge * max_edge / (2.0 * area + 1e-12) + bad = bad | (aspect > max_aspect_ratio) + if min_angle_deg > 0: + cos_a = (l1 * l1 + l2 * l2 - l0 * l0) / (2 * l1 * l2 + 1e-12) + cos_b = (l0 * l0 + l2 * l2 - l1 * l1) / (2 * l0 * l2 + 1e-12) + cos_c = (l0 * l0 + l1 * l1 - l2 * l2) / (2 * l0 * l1 + 1e-12) + cos_all = torch.stack([cos_a, cos_b, cos_c], dim=-1) + angles_deg = torch.acos(torch.clamp(cos_all, -1, 1)) * (180.0 / math.pi) + bad = bad | (angles_deg.min(dim=-1).values < min_angle_deg) + + if not bad.any(): + return faces, 0 + + # per sliver pick its shortest edge to collapse + edge_lens = torch.stack([l0, l1, l2], dim=-1) # (F, 3) + shortest_slot = edge_lens.argmin(dim=-1) # (F,) ∈ {0,1,2} + + V = verts.shape[0] + # collapse higher-index endpoint into lower (min/max ordering avoids cycles) + merge_map = torch.arange(V, device=verts.device, dtype=torch.int64) + bad_idx = torch.nonzero(bad, as_tuple=True)[0] + for slot in range(3): + sel = bad_idx[shortest_slot[bad_idx] == slot] + if sel.numel() == 0: + continue + a = fl[sel, slot] + b = fl[sel, (slot + 1) % 3] + lo = torch.minimum(a, b) + hi = torch.maximum(a, b) + merge_map[hi] = lo # last-write-wins on conflict + + # path-compress until stable + for _ in range(10): + new_map = merge_map[merge_map] + if torch.equal(new_map, merge_map): + break + merge_map = new_map + + new_faces = merge_map[fl] + nondeg = ((new_faces[:, 0] != new_faces[:, 1]) & + (new_faces[:, 1] != new_faces[:, 2]) & + (new_faces[:, 0] != new_faces[:, 2])) + new_faces = new_faces[nondeg].to(dtype=faces.dtype) + return new_faces, bad.sum() + + +def _drop_duplicate_faces(faces: torch.Tensor, num_verts: int) -> Tuple[torch.Tensor, int]: + """Remove duplicate faces (same vertex set), keeping the first occurrence (winding-preserving).""" + if faces.shape[0] <= 1: + return faces, 0 + key_sorted = torch.sort(faces, dim=1)[0] + P = num_verts + 1 + packed = (key_sorted[:, 0].long() * P + key_sorted[:, 1].long()) * P + key_sorted[:, 2].long() + unique_packed, inv = torch.unique(packed, return_inverse=True) + if unique_packed.shape[0] == faces.shape[0]: + return faces, 0 + # first-occurrence index per unique key + arange = torch.arange(packed.shape[0], device=packed.device) + first = torch.full((unique_packed.shape[0],), packed.shape[0], + dtype=torch.int64, device=packed.device) + first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True) + kept = faces[first] + return kept, int(faces.shape[0] - kept.shape[0]) + + +def _drop_unused_verts( + verts: torch.Tensor, faces: torch.Tensor, + colors: Optional[torch.Tensor] = None, + normals: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], int]: + """Remove vertices not referenced by any face; remap faces and filter attributes.""" + if verts.shape[0] == 0 or faces.numel() == 0: + return verts, faces, colors, normals, 0 + used = torch.zeros(verts.shape[0], dtype=torch.bool, device=verts.device) + used[faces[:, 0]] = True + used[faces[:, 1]] = True + used[faces[:, 2]] = True + # cumsum compact remap: 0..N-1 to used verts in order + remap = used.long().cumsum(0) - 1 + new_verts = verts[used] + new_faces = remap[faces.long()] + new_colors = colors[used] if colors is not None else None + new_normals = normals[used] if normals is not None else None + n_dropped = verts.shape[0] - used.sum() + return new_verts, new_faces, new_colors, new_normals, n_dropped + + +def _repair_nonmanifold_edges( + verts: torch.Tensor, faces: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """repair_non_manifold_edges: explode corners, re-merge only across manifold edges; returns (verts, faces, src).""" + if faces.numel() == 0: + return verts, faces + dev, vdt, fdt = verts.device, verts.dtype, faces.dtype + F = faces.detach().cpu().numpy().astype(_np.int64) + V = verts.detach().cpu().numpy() + nf = F.shape[0] + nv = V.shape[0] + corner_vert = F.reshape(-1) # (3F,) original vertex per corner + + # per-face edges keyed by (vmin,vmax) + keys_l, ca_l, cb_l = [], [], [] + for (i, j) in ((0, 1), (1, 2), (2, 0)): + va, vb = F[:, i], F[:, j] + ci = 3 * _np.arange(nf) + i + cj = 3 * _np.arange(nf) + j + amin = _np.where(va <= vb, ci, cj) # corner of the smaller-id endpoint + amax = _np.where(va <= vb, cj, ci) + vmin = _np.minimum(va, vb).astype(_np.int64) + vmax = _np.maximum(va, vb).astype(_np.int64) + keys_l.append(vmin * (nv + 1) + vmax) + ca_l.append(amin); cb_l.append(amax) + keys = _np.concatenate(keys_l) + ca = _np.concatenate(ca_l); cb = _np.concatenate(cb_l) + order = _np.argsort(keys, kind="stable") + keys = keys[order]; ca = ca[order]; cb = cb[order] + uniq, start, cnt = _np.unique(keys, return_index=True, return_counts=True) + man = start[cnt == 2] # manifold edges: exactly 2 incident faces + # union both endpoints' corners across each manifold edge + rows = _np.concatenate([ca[man], cb[man]]) + cols = _np.concatenate([ca[man + 1], cb[man + 1]]) + + n = 3 * nf + g = coo_matrix((_np.ones(rows.shape[0], dtype=_np.int8), (rows, cols)), shape=(n, n)) + _ncomp, labels = connected_components(g, directed=False) + + new_faces = labels[3 * _np.arange(nf)[:, None] + _np.array([0, 1, 2])[None, :]] + nnv = int(labels.max()) + 1 + # source original-vertex index per new vertex + src = _np.zeros(nnv, dtype=_np.int64) + src[labels] = corner_vert + new_verts = V[src] + src_t = torch.from_numpy(src).to(device=dev) + return (torch.from_numpy(new_verts).to(device=dev, dtype=vdt), + torch.from_numpy(new_faces.astype(_np.int64)).to(device=dev, dtype=fdt), + src_t) + + +def _drop_small_components( + verts: torch.Tensor, faces: torch.Tensor, min_faces: int, + max_propagation_iters: int = 200, +) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Label-propagation connected components; drop components below min_faces.""" + if faces.numel() == 0 or min_faces <= 1: + return verts, faces, 0 + device = verts.device + V = verts.shape[0] + labels = torch.arange(V, device=device, dtype=torch.int64) + for _ in range(max_propagation_iters): + v0, v1, v2 = faces[:, 0], faces[:, 1], faces[:, 2] + face_min = torch.minimum(torch.minimum(labels[v0], labels[v1]), labels[v2]) + new_labels = labels.clone() + new_labels.scatter_reduce_(0, v0, face_min, reduce="amin", include_self=True) + new_labels.scatter_reduce_(0, v1, face_min, reduce="amin", include_self=True) + new_labels.scatter_reduce_(0, v2, face_min, reduce="amin", include_self=True) + new_labels = new_labels[new_labels] # path-compress + if torch.equal(new_labels, labels): + break + labels = new_labels + face_label = labels[faces[:, 0]] + unique_labels, counts = torch.unique(face_label, return_counts=True) + big_labels = unique_labels[counts >= min_faces] + if big_labels.shape[0] == unique_labels.shape[0]: + return verts, faces, 0 + # safety: never drop every component (return the small mesh, not an empty one) + if big_labels.shape[0] == 0: + return verts, faces, 0 + keep_face = torch.isin(face_label, big_labels) + kept_faces = faces[keep_face] + n_dropped = int(unique_labels.shape[0] - big_labels.shape[0]) + return verts, kept_faces, n_dropped + + +def clean_mesh( + verts: torch.Tensor, faces: torch.Tensor, + colors: Optional[torch.Tensor] = None, + normals: Optional[torch.Tensor] = None, + weld_epsilon: float = 0.0, + weld_epsilon_rel: float = 1e-6, + drop_degenerate: bool = True, + drop_duplicates: bool = True, + drop_unused: bool = True, + min_component_faces: int = 0, + min_angle_deg: float = 0.0, + max_aspect_ratio: float = 0.0, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], CleanStats]: + """Mesh hygiene pipeline; preserves per-vertex attributes through welding. Returns (v, f, colors, normals, stats).""" + stats = CleanStats(in_verts=verts.shape[0], in_faces=faces.shape[0]) + t0 = _time.perf_counter() + v = verts + f = faces.long() if faces.numel() > 0 else faces + c = colors + n = normals + + if weld_epsilon != 0.0 or weld_epsilon_rel > 0: + # eps stays a 0-d tensor (no sync) + if weld_epsilon > 0: + eps = torch.as_tensor(weld_epsilon, dtype=v.dtype, device=v.device) + else: + eps = torch.norm(v.max(dim=0)[0] - v.min(dim=0)[0]) * weld_epsilon_rel + v, f, c, n, n_welded = _weld_vertices(v, f, eps, c, n) + stats.welded_verts = n_welded + + if drop_degenerate: + f_new, n_drop = _drop_degenerate_faces(v, f) + stats.degenerate_faces = n_drop + f = f_new + # slivers get collapse-merged instead of dropped (preserves topology) + if min_angle_deg > 0 or max_aspect_ratio > 0: + f_new, n_sliv = _collapse_slivers( + v, f, min_angle_deg=min_angle_deg, max_aspect_ratio=max_aspect_ratio, + ) + stats.degenerate_faces += n_sliv + f = f_new + + if drop_duplicates: + f_new, n_dup = _drop_duplicate_faces(f, v.shape[0]) + stats.duplicate_faces = n_dup + f = f_new + + if min_component_faces > 1: + v, f, n_comp = _drop_small_components(v, f, min_component_faces) + stats.components_dropped = n_comp + + if drop_unused: + v, f, c, n, n_unused = _drop_unused_verts(v, f, c, n) + stats.unused_verts = n_unused + + stats.out_verts = v.shape[0] + stats.out_faces = f.shape[0] + stats.seconds = _time.perf_counter() - t0 + # materialize tensor-scalar counts to plain ints once at exit + for field in ("welded_verts", "degenerate_faces", "duplicate_faces", + "unused_verts", "components_dropped"): + val = getattr(stats, field) + if torch.is_tensor(val): + setattr(stats, field, int(val.item())) + return v, f, c, n, stats + + +@dataclass +class SimplifyStats: + input_verts: int = 0 + input_faces: int = 0 + output_verts: int = 0 + output_faces: int = 0 + iterations: int = 0 + total_collapses: int = 0 + seconds: float = 0.0 + peak_mem_mb: float = 0.0 + + +def qem_simplify( + vertices: torch.Tensor, + faces: torch.Tensor, + target_faces: int, + colors: Optional[torch.Tensor] = None, + normals: Optional[torch.Tensor] = None, + max_edge_length: Optional[float] = None, + config: Optional[QEMConfig] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], SimplifyStats]: + """Single-mesh QEM simplification. Returns (v, f, colors, normals, stats).""" + cfg = config or QEMConfig() + + device = vertices.device + in_v_dtype = vertices.dtype + in_f_dtype = faces.dtype + in_c_dtype = colors.dtype if colors is not None else None + in_n_dtype = normals.dtype if normals is not None else None + + verts = vertices.to(device=device, dtype=cfg.dtype, copy=True) + faces = faces.to(device=device, dtype=torch.int64).clone() + colors_w = colors.to(device=device, dtype=cfg.dtype, copy=True) if colors is not None else None + normals_w = normals.to(device=device, dtype=cfg.dtype, copy=True) if normals is not None else None + + # optional preclean: weld + drop degenerate/duplicate, attributes cluster-averaged + if cfg.preclean: + verts, faces, colors_w, normals_w, _cs = clean_mesh( + verts, faces, colors_w, normals_w, + weld_epsilon_rel=cfg.preclean_weld_epsilon_rel, + min_component_faces=cfg.preclean_min_component_faces, + ) + + num_verts = verts.shape[0] + num_faces = faces.shape[0] + + stats = SimplifyStats(input_verts=num_verts, input_faces=num_faces) + + if num_faces <= target_faces or num_verts < 4: + stats.output_verts = num_verts + stats.output_faces = num_faces + return verts.to(in_v_dtype), faces.to(in_f_dtype), \ + (colors_w.to(in_c_dtype) if colors_w is not None else None), \ + (normals_w.to(in_n_dtype) if normals_w is not None else None), \ + stats + + if device.type == "cuda": + torch.cuda.synchronize(device) + torch.cuda.reset_peak_memory_stats(device) + t0 = _time.perf_counter() + + v_alive = torch.ones(num_verts, dtype=torch.bool, device=device) + f_alive = torch.ones(num_faces, dtype=torch.bool, device=device) + + Q = _build_quadrics(verts, faces, cfg) + + bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0] + mesh_scale = torch.norm(bbox) # 0-d tensor; never .item()'d + if max_edge_length is None or max_edge_length <= 0: + max_edge_length = mesh_scale * 2.0 + else: + max_edge_length = torch.as_tensor(max_edge_length, dtype=cfg.dtype, device=device) + # degenerate-mesh guard for tiny bbox (tensor-side, no sync) + max_edge_length = torch.where( + max_edge_length < 1e-6, + torch.ones((), dtype=max_edge_length.dtype, device=device), + max_edge_length, + ) + + stabilizer = mesh_scale * mesh_scale * cfg.stabilizer_scale + max_edge_length_sq = max_edge_length * max_edge_length + mesh_scale_sq = mesh_scale * mesh_scale + + # threshold scaled by mesh_scale² so the 1e-8 start is scale-robust + thresh = float(cfg.threshold_start) * float(mesh_scale_sq) if cfg.threshold_driver else 0.0 + + # pre-allocated merge_map, reused each iter + merge_map = torch.arange(num_verts, device=device) + + # py_n_faces: Python-int face count (no host sync in hot loop), re-synced at compaction + py_n_faces = num_faces + + iteration = 0 + total_collapses = 0 + + # progress bars (tqdm + optional comfy ProgressBar); best-effort + _start_faces = num_faces + _prog_total = max(1, _start_faces - int(target_faces)) + try: + _qtq = _tqdm(total=100, desc="QEM simplify", leave=False) + except Exception: + _qtq = None + try: + _qpbar = _comfy_utils.ProgressBar(100) + except Exception: + _qpbar = None + + def _qreport(): + pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total))) + if _qtq is not None: + _qtq.n = pct; _qtq.refresh() + if _qpbar is not None: + _qpbar.update_absolute(pct, 100) + + while True: + if py_n_faces <= target_faces: + break + _qreport() + + alive_f = torch.nonzero(f_alive, as_tuple=True)[0] + if alive_f.numel() == 0: + break + + active_faces = faces[alive_f] + + # memoryless QEM: rebuild Q from current geometry each iter + if cfg.threshold_driver and cfg.memoryless_qem and iteration > 0: + Q = _build_quadrics(verts, active_faces, cfg) + + Q_for_iter = Q + # edge extraction: pack each (min*V + max) so unique dedups in one go + af_roll = torch.roll(active_faces, shifts=-1, dims=1) + mn = torch.minimum(active_faces, af_roll) + mx = torch.maximum(active_faces, af_roll) + packed = torch.add(mx, mn, alpha=num_verts).flatten() + packed = torch.unique(packed) + edges_orig = torch.stack([packed // num_verts, packed % num_verts], dim=1) + + # filter by edge length + pab = verts[edges_orig] # (E, 2, 3) + el = torch.norm(pab[:, 1] - pab[:, 0], dim=-1) + edges_orig = edges_orig[el < max_edge_length] + if edges_orig.shape[0] == 0: + break + + # sampling cap + n_edges_total = edges_orig.shape[0] + if n_edges_total > cfg.sampling_cap: + perm = torch.randperm(n_edges_total, device=device)[: cfg.sampling_cap] + edges_orig = edges_orig[perm] + + # boundary mask only needed for non-qem placement + if cfg.placement_mode != "qem": + vib = _vert_is_boundary_mask(active_faces, num_verts) + else: + vib = None + optimal, err, valid = _edge_errors( + verts, Q_for_iter, edges_orig, stabilizer, max_edge_length_sq, + mesh_scale_sq, cfg, vert_is_boundary=vib, + ) + valid_idx = torch.nonzero(valid, as_tuple=True)[0] + edges_orig = edges_orig[valid_idx] + optimal = optimal[valid_idx] + err = err[valid_idx] + + faces_to_remove = py_n_faces - target_faces + n_faces_round_start = py_n_faces + # ~2 faces removed per collapse, so cap the round at faces_to_remove//2 + cap_to_target = max(1, faces_to_remove // 2) + + if cfg.threshold_driver: + # threshold-schedule selection + # candidate band = cost <= thresh (×10 until non-empty), quality-check then collapse a disjoint set + cand = err <= thresh + esc = 0 + while not bool(cand.any()) and esc < 50: + thresh *= 10.0 + cand = err <= thresh + esc += 1 + cand_idx = torch.nonzero(cand, as_tuple=True)[0] + ce = edges_orig[cand_idx] + copt = optimal[cand_idx] + cerr = err[cand_idx].clone() + need_flip = cfg.flip_reject_hard + if ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition) + and ce.shape[0] > 0): + afq = faces[alive_f] + v_to_f = _build_vert_to_faces_pad(afq, num_verts, cfg.flip_check_max_degree) + # link + flip + skinny share one fused 1-ring pass on the same band + fc, sk, link_safe = _quality_checks_fused( + verts, afq, ce, copt, v_to_f, cos_threshold=cfg.flip_cos_threshold, + want_flip=need_flip, want_skinny=(cfg.skinny_weight > 0), + want_link=cfg.enforce_link_condition) + if link_safe is not None: + cerr[~link_safe] = float("inf") + if fc is not None: + cerr = torch.where(fc > 0, torch.full_like(cerr, float("inf")), cerr) + if sk is not None: + el_sq = (verts[ce[:, 1]] - verts[ce[:, 0]]).pow(2).sum(dim=-1) + cerr = cerr + cfg.skinny_weight * sk * el_sq + del v_to_f, afq + # penalties may push edges above thresh — re-gate the band + keep = cerr <= thresh + ce = ce[keep]; copt = copt[keep]; cerr = cerr[keep] + edges_orig = ce + optimal = copt + sel = _greedy_matching(ce, cerr, v_alive, cap_to_target) + if sel.numel() == 0: + # band fully rejected → raise thresh and retry + thresh *= 10.0 + iteration += 1 + if iteration >= cfg.max_iterations: + break + continue + else: + max_collapses = min( + cfg.max_collapses_ceiling, + max(cfg.max_collapses_floor, int(faces_to_remove * cfg.max_collapses_fraction)), + ) + if cfg.max_collapses_relative_cap > 0: + # hybrid tail: cap to a fraction of current mesh size (anti cascade-overshoot) + rel_cap = max(1, int(py_n_faces * cfg.max_collapses_relative_cap)) + max_collapses = min(max_collapses, rel_cap) + max_collapses = min(max_collapses, cap_to_target) + + # soft quality penalties on top-K candidates: flip check + skinny, sharing one v_to_f build + need_flip = cfg.flip_reject_hard + need_quality = ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition) + and edges_orig.shape[0] > 0) + if need_quality: + n_check = min(edges_orig.shape[0], + max(1, cfg.quality_topk_multiplier * max_collapses)) + if n_check < edges_orig.shape[0]: + topk = torch.topk(err, n_check, largest=False).indices + else: + topk = torch.arange(edges_orig.shape[0], device=device) + active_for_quality = faces[alive_f] + v_to_f = _build_vert_to_faces_pad(active_for_quality, num_verts, + cfg.flip_check_max_degree) + err = err.clone() + if cfg.enforce_link_condition: + # reject link-condition violations on ALL candidate edges (not just top-K) + link_safe = _link_condition_mask(active_for_quality, edges_orig, v_to_f) + err[~link_safe] = float("inf") + # flip + skinny share the same top-K 1-ring walk + e_tk = edges_orig[topk] + o_tk = optimal[topk] + _do_flip = need_flip + _do_skinny = cfg.skinny_weight > 0 + if _do_flip and _do_skinny: + flip_count, skinny, _ = _quality_checks_fused( + verts, active_for_quality, e_tk, o_tk, v_to_f, + cos_threshold=cfg.flip_cos_threshold, want_link=False) + elif _do_flip: + flip_count = _normal_flip_mask( + verts, active_for_quality, e_tk, o_tk, v_to_f, + cos_threshold=cfg.flip_cos_threshold, return_count=True) + skinny = None + else: + skinny = _skinny_penalty(verts, active_for_quality, e_tk, o_tk, v_to_f) + flip_count = None + if _do_flip: + # hard reject: any flipping top-K edge → +inf + flips = flip_count > 0 + if flips.any(): + err[topk] = torch.where( + flips, torch.full_like(err[topk], float("inf")), + err[topk], + ) + if _do_skinny: + # skinny_cost * edge_length² (match QEM's length² scaling) + elen_sq = (verts[e_tk[:, 1]] - verts[e_tk[:, 0]]).pow(2).sum(dim=-1) + err[topk] = torch.add(err[topk], skinny * elen_sq, + alpha=cfg.skinny_weight) + del v_to_f, active_for_quality + + sel = _greedy_matching(edges_orig, err, v_alive, max_collapses) + + if sel.numel() == 0: + break + + ed_sel = edges_orig[sel] + v_a = ed_sel[:, 0] + v_b = ed_sel[:, 1] + new_pos = optimal[sel] + + # interpolate attributes by where new_pos lies on the [pa, pb] segment + if colors_w is not None or normals_w is not None: + pa_sel = verts[v_a] + pb_sel = verts[v_b] + edge_vec = pb_sel - pa_sel + edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20 + t = ((new_pos - pa_sel) * edge_vec).sum(dim=-1) / edge_len_sq + t = t.clamp(0.0, 1.0).unsqueeze(-1) + if colors_w is not None: + colors_w[v_a] = torch.lerp(colors_w[v_a], colors_w[v_b], t) + if normals_w is not None: + normals_w[v_a] = torch.lerp(normals_w[v_a], normals_w[v_b], t) + + # apply collapse + verts[v_a] = new_pos + v_alive[v_b] = False + if not (cfg.threshold_driver and cfg.memoryless_qem): + Q[v_a] += Q[v_b] + + merge_map[v_b] = v_a + faces = merge_map[faces] + merge_map[v_b] = v_b # restore identity for next iter + + bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + f_alive.masked_fill_(bad, False) + py_n_faces -= 2 * v_a.numel() # ~2 faces/collapse estimate; re-synced at compaction + + # schedule: round removed < 1% → raise thresh ×10 + if cfg.threshold_driver: + removed = n_faces_round_start - py_n_faces + if removed < 0.01 * n_faces_round_start: + thresh *= 10.0 + + total_collapses += int(v_a.numel()) + iteration += 1 + + # periodic compaction (resyncs py_n_faces exactly) + if iteration % cfg.compaction_period == 0: + alive_frac = py_n_faces / max(1, num_faces) + if alive_frac < cfg.compaction_threshold: + faces = faces[f_alive] + num_faces = faces.shape[0] + f_alive = torch.ones(num_faces, dtype=torch.bool, device=device) + py_n_faces = num_faces + + if iteration >= cfg.max_iterations: + break + + _qreport() + if _qtq is not None: + _qtq.close() + + # finalize: compact verts and faces + final_v = verts[v_alive] + final_c = colors_w[v_alive] if colors_w is not None else None + final_n = normals_w[v_alive] if normals_w is not None else None + + remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) + remap[v_alive] = v_alive.long().cumsum(0)[v_alive] - 1 # compact remap, no sync + + final_f_raw = faces[f_alive] + alive_mask = v_alive[final_f_raw].all(dim=1) + final_f_raw = final_f_raw[alive_mask] + final_f = remap[final_f_raw] + valid_faces = (final_f >= 0).all(dim=1) + final_f = final_f[valid_faces] + + # drop degenerate faces (two indices equal) + if final_f.numel() > 0: + nondeg = (final_f[:, 0] != final_f[:, 1]) & (final_f[:, 1] != final_f[:, 2]) & (final_f[:, 0] != final_f[:, 2]) + final_f = final_f[nondeg] + + # dedup duplicate faces, winding-preserving + if final_f.numel() > 0: + key = torch.sort(final_f, dim=1)[0] + packed = (key[:, 0].long() * (final_v.shape[0] + 1) + key[:, 1].long()) \ + * (final_v.shape[0] + 1) + key[:, 2].long() + unique_packed, inv = torch.unique(packed, return_inverse=True) + arange = torch.arange(packed.shape[0], device=packed.device) + first = torch.full((unique_packed.shape[0],), packed.shape[0], + dtype=torch.int64, device=packed.device) + first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True) + final_f = final_f[first] + + # repair_non_manifold_edges: split back fused surface sheets (after dedup, before pruning) + if cfg.repair_nonmanifold and final_f.numel() > 0: + final_v, final_f, _src = _repair_nonmanifold_edges(final_v, final_f) + if final_c is not None: + final_c = final_c[_src] + if final_n is not None: + final_n = final_n[_src] + + # post-clean: drop slivers, tiny components, unused verts + if cfg.postclean and final_f.numel() > 0: + comp_threshold = cfg.postclean_min_component_faces + final_v, final_f, final_c, final_n, _ps = clean_mesh( + final_v, final_f, final_c, final_n, + weld_epsilon=0.0, weld_epsilon_rel=0.0, # already welded + drop_degenerate=True, + drop_duplicates=False, # already done above + drop_unused=True, + min_component_faces=comp_threshold, + min_angle_deg=cfg.postclean_min_angle_deg, + max_aspect_ratio=cfg.postclean_max_aspect_ratio, + ) + + # post-simplify normals + if cfg.recompute_normals_post and final_f.numel() > 0: + final_n = _compute_vertex_normals(final_v, final_f) + elif final_n is not None and final_f.numel() > 0: + # keep supplied normals; flip face winding where it disagrees + v0 = final_v[final_f[:, 0]] + v1 = final_v[final_f[:, 1]] + v2 = final_v[final_f[:, 2]] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + ref = (final_n[final_f[:, 0]] + final_n[final_f[:, 1]] + + final_n[final_f[:, 2]]) / 3.0 + wrong = (fn * ref).sum(dim=-1) < 0 + final_f[wrong] = final_f[wrong][:, [0, 2, 1]] + + if device.type == "cuda": + torch.cuda.synchronize(device) + stats.peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024 * 1024) + stats.seconds = _time.perf_counter() - t0 + stats.iterations = iteration + stats.total_collapses = total_collapses + stats.output_verts = final_v.shape[0] + stats.output_faces = final_f.shape[0] + + return ( + final_v.to(in_v_dtype), + final_f.to(in_f_dtype), + final_c.to(in_c_dtype) if final_c is not None else None, + final_n.to(in_n_dtype) if (final_n is not None and in_n_dtype is not None) else final_n, + stats, + ) + + +def simplify( + vertices: torch.Tensor, + faces: torch.Tensor, + target: int, + colors: Optional[torch.Tensor] = None, + normals: Optional[torch.Tensor] = None, + max_edge_length: Optional[float] = None, + config: Optional[QEMConfig] = None, +): + """Batched wrapper. Accepts (V,3)/(F,3) or (B,V,3)/(B,F,3).""" + if vertices.ndim == 3: + out_v, out_f, out_c, out_n, out_s = [], [], [], [], [] + for i in range(vertices.shape[0]): + c_in = colors[i] if colors is not None else None + n_in = normals[i] if normals is not None else None + v, f, c, n, s = qem_simplify(vertices[i], faces[i], target, c_in, n_in, max_edge_length, config) + out_v.append(v); out_f.append(f); out_s.append(s) + if c is not None: out_c.append(c) + if n is not None: out_n.append(n) + return (out_v, out_f, + out_c if out_c else None, + out_n if out_n else None, + out_s) + return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config)