"""Narrow-band Dual Contouring remeshing. Re-extracts a mesh from a sparse narrow-band voxel grid around the input surface (pure-PyTorch approximation of CuMesh's remesh_narrow_band_dc). Coarse-to-fine voxelise the band, sample SDF/UDF at voxel corners, dual contour (optionally QEF / Manifold DC), then optionally project back, filter components, fix poles, smooth, and interpolate vertex colors. """ from __future__ import annotations import functools import math from typing import Optional, Tuple import numpy as np import torch import scipy.spatial import comfy.utils from tqdm import tqdm as _tqdm from comfy.model_management import throw_exception_if_processing_interrupted from .qem_decimate import _sorted_edge_halfedges # Point-to-triangle distance (exact, vectorised) def _point_tri_closest(points: torch.Tensor, tris: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Exact closest point + squared distance per (point, triangle) pair; points (N,3), tris (N,3,3).""" a = tris[:, 0] b = tris[:, 1] c = tris[:, 2] ab = b - a ac = c - a ap = points - a d1 = (ab * ap).sum(-1) d2 = (ac * ap).sum(-1) region_A = (d1 <= 0) & (d2 <= 0) bp = points - b d3 = (ab * bp).sum(-1) d4 = (ac * bp).sum(-1) region_B = (d3 >= 0) & (d4 <= d3) cp = points - c d5 = (ab * cp).sum(-1) d6 = (ac * cp).sum(-1) region_C = (d6 >= 0) & (d5 <= d6) # Edge AB vc = d1 * d4 - d3 * d2 region_AB = (vc <= 0) & (d1 >= 0) & (d3 <= 0) v_ab = d1 / (d1 - d3 + 1e-20) closest_AB = a + v_ab.unsqueeze(-1) * ab # Edge AC vb = d5 * d2 - d1 * d6 region_AC = (vb <= 0) & (d2 >= 0) & (d6 <= 0) v_ac = d2 / (d2 - d6 + 1e-20) closest_AC = a + v_ac.unsqueeze(-1) * ac # Edge BC va = d3 * d6 - d5 * d4 region_BC = (va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0) v_bc = (d4 - d3) / ((d4 - d3) + (d5 - d6) + 1e-20) closest_BC = b + v_bc.unsqueeze(-1) * (c - b) # Face interior (barycentric) denom = va + vb + vc + 1e-20 v_face = vb / denom w_face = vc / denom closest_face = a + v_face.unsqueeze(-1) * ab + w_face.unsqueeze(-1) * ac # Combine by mask via in-place where (out= aliases input, no per-step alloc) closest = closest_face # fresh; safe to mutate torch.where(region_BC.unsqueeze(-1), closest_BC, closest, out=closest) torch.where(region_AC.unsqueeze(-1), closest_AC, closest, out=closest) torch.where(region_AB.unsqueeze(-1), closest_AB, closest, out=closest) torch.where(region_C .unsqueeze(-1), c, closest, out=closest) torch.where(region_B .unsqueeze(-1), b, closest, out=closest) torch.where(region_A .unsqueeze(-1), a, closest, out=closest) diff = points - closest return closest, (diff * diff).sum(-1) def _build_centroid_tree(tri_verts: torch.Tensor): """scipy cKDTree over triangle centroids; build once and reuse across _udf_exact calls. balanced_tree/compact_nodes off: ~2.4x faster build (and faster queries on near-uniform centroid clouds) with identical exact-kNN results.""" return scipy.spatial.cKDTree(tri_verts.mean(dim=1).detach().cpu().numpy(), balanced_tree=False, compact_nodes=False) def _udf_exact(query_points: torch.Tensor, tri_verts: torch.Tensor, k: int = 8, chunk: int = 262144, tree=None): """Exact UDF (no max_dist cap) via centroid kNN; returns (dist [N], closest [N,3], tri_idx [N]). Pass prebuilt `tree` to skip rebuild. k=8 nearest centroids before the exact point-triangle test: on dense meshes the true closest triangle is essentially always within the first few neighbours. Measured vs k=16: bit-identical topology, ~0.003-voxel RMS sub-voxel drift, ~15% faster overall.""" device = query_points.device F = tri_verts.shape[0] kq = int(min(k, F)) if tree is None: tree = _build_centroid_tree(tri_verts) _, cand = tree.query(query_points.detach().cpu().numpy(), k=kq, workers=-1) if cand.ndim == 1: cand = cand[:, None] cand = np.ascontiguousarray(cand) N = query_points.shape[0] out_d = torch.empty(N, device=device, dtype=query_points.dtype) out_c = torch.empty(N, 3, device=device, dtype=query_points.dtype) out_t = torch.empty(N, dtype=torch.long, device=device) for s in range(0, N, chunk): e = min(s + chunk, N) n = e - s ci = torch.from_numpy(cand[s:e]).to(device).long() tri = tri_verts[ci].reshape(n * kq, 3, 3) P = query_points[s:e][:, None, :].expand(-1, kq, -1).reshape(n * kq, 3) closest, d2 = _point_tri_closest(P, tri) d2 = d2.reshape(n, kq) closest = closest.reshape(n, kq, 3) best = d2.argmin(dim=1) ar = torch.arange(n, device=device) out_d[s:e] = d2[ar, best].sqrt() out_c[s:e] = closest[ar, best] out_t[s:e] = ci[ar, best] return out_d, out_c, out_t # UDF query via spatial hash on triangle AABBs def _build_tri_spatial_hash(centroids: torch.Tensor, tri_radii: torch.Tensor, cell_size: torch.Tensor): """Bucket triangles into `cell_size` cells (each tri into every cell its AABB touches); returns hash tuple.""" device = centroids.device aabb_lo = (centroids - tri_radii.unsqueeze(-1)) aabb_hi = (centroids + tri_radii.unsqueeze(-1)) origin = aabb_lo.min(0)[0] extent = aabb_hi.max(0)[0] - origin dims = (extent / cell_size).long() + 2 cell_lo = ((aabb_lo - origin) / cell_size).long().clamp(min=0) cell_hi = ((aabb_hi - origin) / cell_size).long() cell_hi = torch.minimum(cell_hi, dims - 1) # Cap span at 3 cells/axis to bound memory spans = (cell_hi - cell_lo + 1).clamp(max=3) n_per_tri = spans.prod(dim=-1) total = int(n_per_tri.sum().item()) # Per-insertion local offset within each tri's cell box rep = torch.repeat_interleave(torch.arange(centroids.shape[0], device=device), n_per_tri) cum = torch.cat([torch.zeros(1, device=device, dtype=n_per_tri.dtype), n_per_tri.cumsum(0)[:-1]]) local = torch.arange(total, device=device) - cum[rep] sx = spans[rep, 0] sy = spans[rep, 1] lx = local % sx ly = (local // sx) % sy lz = local // (sx * sy) cx = cell_lo[rep, 0] + lx cy = cell_lo[rep, 1] + ly cz = cell_lo[rep, 2] + lz keys = (cx * dims[1] + cy) * dims[2] + cz sort_idx = keys.argsort() sorted_keys = keys[sort_idx] tri_per_cell = rep[sort_idx] unique_keys, counts = torch.unique_consecutive(sorted_keys, return_counts=True) cell_starts = torch.cat([torch.zeros(1, dtype=counts.dtype, device=device), counts.cumsum(0)]) return origin, dims, unique_keys, tri_per_cell, cell_starts, centroids, tri_radii def _udf_query(query_points: torch.Tensor, tri_verts: torch.Tensor, hash_data, cell_size: torch.Tensor, max_dist: float, chunk_max: int = 4096, return_closest: bool = False, return_tri_idx: bool = False): """Capped UDF to nearest triangle (<= max_dist), optionally with closest point and/or tri index; chunk size is adaptive to hash density.""" origin, dims, unique_keys, tri_per_cell, cell_starts, tri_centroids, tri_radii = hash_data device = query_points.device Q = query_points.shape[0] # Adaptive chunk: bound per-chunk candidate-gather memory by hash density avg_per_cell = tri_per_cell.numel() / max(1, unique_keys.numel()) est_cands_per_query = max(1.0, avg_per_cell * 27) chunk = max(256, min(chunk_max, int(50_000_000 / est_cands_per_query))) out_d2 = torch.full((Q,), float(max_dist) ** 2, dtype=query_points.dtype, device=device) # Default closest_pt = query_pt itself, so a missed query's lerp is a no-op out_closest = (query_points.clone() if return_closest else None) out_tri = (torch.full((Q,), -1, dtype=torch.long, device=device) if return_tri_idx else None) rng = torch.tensor([-1, 0, 1], device=device, dtype=torch.long) offs = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), dim=-1).reshape(-1, 3) # (27, 3) for cs in range(0, Q, chunk): ce = min(cs + chunk, Q) qp = query_points[cs:ce] q_cell = ((qp - origin) / cell_size).long() # Look up 27 neighbour cells per query n_cell = q_cell.unsqueeze(1) + offs.unsqueeze(0) # (q, 27, 3) n_valid = ((n_cell >= 0) & (n_cell < dims)).all(-1) n_key = (n_cell[..., 0] * dims[1] + n_cell[..., 1]) * dims[2] + n_cell[..., 2] flat_key = n_key.reshape(-1).contiguous() ins = torch.searchsorted(unique_keys, flat_key) ins_c = ins.clamp(max=unique_keys.numel() - 1) found = (ins < unique_keys.numel()) & (unique_keys[ins_c] == flat_key) & n_valid.reshape(-1) cell_idx = torch.where(found, ins_c, torch.zeros_like(ins_c)) c_starts = cell_starts[cell_idx] c_ends = cell_starts[cell_idx + 1] c_counts = (c_ends - c_starts) * found.long() rep_q = torch.repeat_interleave( torch.arange(qp.shape[0] * 27, device=device) // 27, c_counts) if rep_q.numel() == 0: continue total = rep_q.numel() slot_starts_per_pair = torch.cumsum(c_counts, dim=0) - c_counts per_pair_start = torch.repeat_interleave(c_starts, c_counts) slot_within = torch.arange(total, device=device) - torch.repeat_interleave(slot_starts_per_pair, c_counts) tri_indices = tri_per_cell[per_pair_start + slot_within] pts = qp[rep_q] # Centroid pre-cull (squared): drop where ||pts-centroid||-radius > max_dist diff = pts - tri_centroids[tri_indices] d2_cand = (diff * diff).sum(-1) thresh = max_dist + tri_radii[tri_indices] cull_keep = d2_cand < thresh * thresh rep_q = rep_q[cull_keep] pts = pts[cull_keep] tri_indices = tri_indices[cull_keep] if rep_q.numel() == 0: continue tri = tri_verts[tri_indices] closest, d2 = _point_tri_closest(pts, tri) # Min per query for this chunk. local_min = torch.full((qp.shape[0],), float(max_dist) ** 2, dtype=query_points.dtype, device=device) local_min.scatter_reduce_(0, rep_q, d2, reduce="amin", include_self=True) # Only update where this chunk improved; ties may overwrite (any is valid) better = local_min < out_d2[cs:ce] out_d2[cs:ce] = torch.where(better, local_min, out_d2[cs:ce]) if return_closest or return_tri_idx: ties = (d2 == local_min[rep_q]) & better[rep_q] if return_closest: out_closest[cs + rep_q[ties]] = closest[ties] if return_tri_idx: out_tri[cs + rep_q[ties]] = tri_indices[ties] out_d = out_d2.sqrt() extras = [] if return_closest: extras.append(out_closest) if return_tri_idx: extras.append(out_tri) if extras: return (out_d, *extras) return out_d # Sparse coarse-to-fine voxel grid in narrow band def _build_narrow_band_voxels(verts: torch.Tensor, faces: torch.Tensor, center: torch.Tensor, scale: float, resolution: int, eps: float, progress_callback=None) -> torch.Tensor: """Voxel coords (Nv,3) in 0..resolution-1 whose centre is within ~0.87 cell_size of the surface; also returns the kept cKDTree.""" device = verts.device tri_verts = verts[faces.long()] # Exact UDF; build the centroid cKDTree once and reuse across refinement levels tree = _build_centroid_tree(tri_verts) base_resolution = resolution while base_resolution > 32 and base_resolution % 2 == 0: base_resolution //= 2 rng = torch.arange(base_resolution, device=device, dtype=torch.long) coords = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), dim=-1).reshape(-1, 3) OFFSETS = torch.tensor([ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], ], dtype=torch.long, device=device) current_res = base_resolution while True: throw_exception_if_processing_interrupted() cell_size = scale / current_res pts = ((coords.float() + 0.5) / current_res - 0.5) * scale + center dists, _, _ = _udf_exact(pts, tri_verts, tree=tree) keep = dists < 0.87 * cell_size + eps coords = coords[keep] if progress_callback is not None: progress_callback() if current_res >= resolution: break current_res *= 2 coords = coords * 2 coords = (coords.unsqueeze(1) + OFFSETS.unsqueeze(0)).reshape(-1, 3) return coords, tree # Dual Contouring def _dual_contour(voxel_coords: torch.Tensor, corner_udf: torch.Tensor, corner_keys: torch.Tensor, resolution: int, scale: float, center: torch.Tensor, tri_face_normals: Optional[torch.Tensor] = None, qef_query=None, corner_valid: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Dual contour active voxels; returns (Nv,3) dual verts and (M,3) faces into them. QEF placement when tri_face_normals+qef_query given, else centroid of crossings.""" device = voxel_coords.device Nv = voxel_coords.shape[0] # 8 corners per voxel, packed into a 1d key CORNER_OFFS = torch.tensor([ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], ], dtype=torch.long, device=device) corner_pos_per_voxel = voxel_coords.unsqueeze(1) + CORNER_OFFS.unsqueeze(0) # (Nv, 8, 3) R1 = resolution + 1 keys_per_voxel = (corner_pos_per_voxel[..., 0] * R1 + corner_pos_per_voxel[..., 1]) * R1 + corner_pos_per_voxel[..., 2] # Look up SDF index per corner; missing corners default to +1 (outside) idx_per_voxel = torch.searchsorted(corner_keys, keys_per_voxel.reshape(-1)) idx_clamped = idx_per_voxel.clamp(max=corner_keys.numel() - 1) found = (idx_per_voxel < corner_keys.numel()) & (corner_keys[idx_clamped] == keys_per_voxel.reshape(-1)) sd = torch.where(found, corner_udf[idx_clamped], torch.full_like(corner_udf[idx_clamped], 1.0)) sd = sd.reshape(Nv, 8) # surface at sign = 0 # 12 voxel edges as (corner_a, corner_b) pairs (indices into the 8 corners above) EDGES = torch.tensor([ [0, 1], [2, 3], [4, 5], [6, 7], # x-axis edges [0, 2], [1, 3], [4, 6], [5, 7], # y-axis [0, 4], [1, 5], [2, 6], [3, 7], # z-axis ], dtype=torch.long, device=device) a_sd = sd[:, EDGES[:, 0]] # (Nv, 12) b_sd = sd[:, EDGES[:, 1]] crosses = (a_sd * b_sd) < 0 # (Nv, 12) bool # Skip crossings touching an invalid corner (avoids fake faces at band edge) if corner_valid is not None: cv_per_voxel = torch.where(found, corner_valid[idx_clamped], torch.zeros_like(found)).reshape(Nv, 8) edge_valid = cv_per_voxel[:, EDGES[:, 0]] & cv_per_voxel[:, EDGES[:, 1]] crosses = crosses & edge_valid # Zero-crossing interp factor per edge t = a_sd / (a_sd - b_sd + 1e-20) t = t.clamp(0.0, 1.0).unsqueeze(-1) corner_world = (corner_pos_per_voxel.float() / resolution - 0.5) * scale + center.unsqueeze(0).unsqueeze(0) # (Nv, 8, 3) a_pos = corner_world[:, EDGES[:, 0]] # (Nv, 12, 3) b_pos = corner_world[:, EDGES[:, 1]] crossing_pts = torch.lerp(a_pos, b_pos, t) # (Nv, 12, 3) # Default dual vert: centroid of crossings (also QEF/no-crossing fallback) crosses_f = crosses.float().unsqueeze(-1) crossing_sum = (crossing_pts * crosses_f).sum(dim=1) n_cross = crosses.float().sum(dim=1, keepdim=True).clamp_min(1.0) centroid_verts = crossing_sum / n_cross centre_world = ((voxel_coords.float() + 0.5) / resolution - 0.5) * scale + center.unsqueeze(0) has_cross = crosses.any(dim=1, keepdim=True) dual_verts = torch.where(has_cross, centroid_verts, centre_world) # QEF placement: minimise sum_i (n_i·(x-p_i))² via Tikhonov-regularised # normal equations (A+reg I)x=b; clamp to voxel bbox, else fall back to centroid. if tri_face_normals is not None and qef_query is not None: Nv = voxel_coords.shape[0] flat_pts = crossing_pts.reshape(-1, 3) flat_mask = crosses.reshape(-1) if flat_mask.any(): query_pts = flat_pts[flat_mask] _, _, qef_tri_idx = qef_query(query_pts) # Missed queries get a zero normal (null constraint, ignored by solver) valid_q = qef_tri_idx >= 0 normals_at_q = torch.zeros_like(query_pts) normals_at_q[valid_q] = tri_face_normals[qef_tri_idx[valid_q]] full_normals = torch.zeros((Nv * 12, 3), dtype=query_pts.dtype, device=device) full_normals[flat_mask] = normals_at_q n_per_edge = full_normals.reshape(Nv, 12, 3) # einsum sums into the 3x3 directly, skipping a big intermediate A = torch.einsum('vec,ved->vcd', n_per_edge, n_per_edge) # (Nv, 3, 3) n_dot_p = (n_per_edge * crossing_pts).sum(dim=-1) # (Nv, 12) b = torch.einsum('ve,vec->vc', n_dot_p, n_per_edge) # (Nv, 3) # Tikhonov regularisation in-place (A, b are fresh einsum outputs) reg = 1e-2 A.diagonal(dim1=-2, dim2=-1).add_(reg) b.add_(centroid_verts, alpha=reg) try: qef_solution = torch.linalg.solve(A, b.unsqueeze(-1)).squeeze(-1) except Exception: qef_solution = centroid_verts # Clamp QEF output to the voxel bbox lo = corner_world[:, 0] # (Nv, 3) min corner hi = corner_world[:, 7] # (Nv, 3) max corner in_box = (qef_solution >= lo).all(dim=-1) & (qef_solution <= hi).all(dim=-1) qef_solution = torch.where(in_box.unsqueeze(-1), qef_solution, centroid_verts) dual_verts = torch.where(has_cross, qef_solution, centre_world) # Topology: each crossing grid edge is shared by 4 voxels -> quad -> 2 tris. # NEIGHBOUR_OFFS lays out the 4 sharing voxels per axis; y-axis order is # reversed vs x/z to keep manifold winding around each shared edge. NEIGHBOUR_OFFS = torch.tensor([ [[0, 0, 0], [0, -1, 0], [0, -1, -1], [0, 0, -1]], [[0, 0, 0], [0, 0, -1], [-1, 0, -1], [-1, 0, 0]], [[0, 0, 0], [-1, 0, 0], [-1, -1, 0], [0, -1, 0]], ], dtype=torch.long, device=device) # Min-corner +axis edge index per axis (slots 0/4/8 in EDGES) EDGE_OF_AXIS = torch.tensor([0, 4, 8], dtype=torch.long, device=device) # Sorted voxel-coord keys for neighbour lookup vox_dims = voxel_coords.max(dim=0)[0] + 2 vox_key = (voxel_coords[:, 0] * vox_dims[1] + voxel_coords[:, 1]) * vox_dims[2] + voxel_coords[:, 2] sort_v = vox_key.argsort() sorted_vox_key = vox_key[sort_v] tris = [] for axis in range(3): edge_idx = EDGE_OF_AXIS[axis] owner_mask = crosses[:, edge_idx] # (Nv,) bool if not owner_mask.any(): continue owner_voxels = voxel_coords[owner_mask] # (No, 3) a_sign = a_sd[owner_mask, edge_idx] # (No,) sign at corner a nbrs = owner_voxels.unsqueeze(1) + NEIGHBOUR_OFFS[axis].unsqueeze(0) # (No, 4, 3) nbr_keys = (nbrs[..., 0] * vox_dims[1] + nbrs[..., 1]) * vox_dims[2] + nbrs[..., 2] flat = nbr_keys.reshape(-1).contiguous() ins = torch.searchsorted(sorted_vox_key, flat) ins_c = ins.clamp(max=sorted_vox_key.numel() - 1) valid = (ins < sorted_vox_key.numel()) & (sorted_vox_key[ins_c] == flat) valid = valid.reshape(-1, 4).all(dim=1) if not valid.any(): continue dual_indices = sort_v[ins_c].reshape(-1, 4)[valid] # (Mv, 4) sign_a = a_sign[valid] # Winding: flip when corner a is outside (sign_a > 0) so normal points out d0 = dual_indices[:, 0] d1 = dual_indices[:, 1] d2 = dual_indices[:, 2] d3 = dual_indices[:, 3] flip = sign_a > 0 t1a = torch.stack([d0, d1, d2], dim=1) t2a = torch.stack([d0, d2, d3], dim=1) t1b = torch.stack([d0, d2, d1], dim=1) t2b = torch.stack([d0, d3, d2], dim=1) t1 = torch.where(flip.unsqueeze(-1), t1b, t1a) t2 = torch.where(flip.unsqueeze(-1), t2b, t2a) tris.append(t1) tris.append(t2) if not tris: return dual_verts, torch.empty((0, 3), dtype=torch.long, device=device) new_faces = torch.cat(tris, dim=0) return dual_verts, new_faces # Manifold Dual Contouring (Schaefer, Ju, Warren 2007) @functools.lru_cache(maxsize=None) def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]: """Per 8-corner sign pattern: K (256,) patch count and group (256,12) patch id per edge (-1 if non-crossing).""" EDGE_PAIRS = [ (0, 1), (2, 3), (4, 5), (6, 7), # x-axis edges (0, 2), (1, 3), (4, 6), (5, 7), # y-axis edges (0, 4), (1, 5), (2, 6), (3, 7), # z-axis edges ] K = torch.zeros(256, dtype=torch.int64) group = torch.full((256, 12), -1, dtype=torch.int64) for pat in range(256): signs = [(pat >> i) & 1 for i in range(8)] # 1=outside, 0=inside parent = list(range(8)) def find(x: int) -> int: r = x while parent[r] != r: r = parent[r] while parent[x] != r: nxt = parent[x] parent[x] = r x = nxt return r # Union same-sign corners (not separated by the surface) for a, b in EDGE_PAIRS: if signs[a] == signs[b]: ra, rb = find(a), find(b) if ra != rb: parent[ra] = rb # Distinct (interior_root, exterior_root) pairs are distinct patches group_map: dict[tuple[int, int], int] = {} for ei, (a, b) in enumerate(EDGE_PAIRS): if signs[a] == signs[b]: continue in_c = a if signs[a] == 0 else b ex_c = b if signs[a] == 0 else a key = (find(in_c), find(ex_c)) if key not in group_map: group_map[key] = len(group_map) group[pat, ei] = group_map[key] K[pat] = len(group_map) return K, group @functools.lru_cache(maxsize=None) def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: K, g = _build_mdc_lut() return K.to(device), g.to(device) def _dual_contour_manifold(voxel_coords: torch.Tensor, corner_udf: torch.Tensor, corner_keys: torch.Tensor, resolution: int, scale: float, center: torch.Tensor, corner_valid: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Manifold DC: like _dual_contour but emits 1-4 dual verts per voxel via the patch LUT (centroid placement only).""" device = voxel_coords.device Nv = voxel_coords.shape[0] CORNER_OFFS = torch.tensor([ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], ], dtype=torch.long, device=device) corner_pos = voxel_coords.unsqueeze(1) + CORNER_OFFS.unsqueeze(0) # (Nv, 8, 3) R1 = resolution + 1 keys = (corner_pos[..., 0] * R1 + corner_pos[..., 1]) * R1 + corner_pos[..., 2] flat_keys = keys.reshape(-1) idx = torch.searchsorted(corner_keys, flat_keys) idx_c = idx.clamp(max=corner_keys.numel() - 1) found = (idx < corner_keys.numel()) & (corner_keys[idx_c] == flat_keys) sd = torch.where(found, corner_udf[idx_c], torch.full_like(corner_udf[idx_c], 1.0)).reshape(Nv, 8) # Sign pattern: bit i = (sd[i] > 0), matching the LUT convention sign_bits = (sd > 0).to(torch.int64) # (Nv, 8) weights = (1 << torch.arange(8, device=device, dtype=torch.int64)) pat_per_voxel = (sign_bits * weights).sum(dim=-1) # (Nv,) in 0..255 K_lut, group_lut = _mdc_lut(device) K_per_voxel = K_lut[pat_per_voxel] # (Nv,) total_verts = int(K_per_voxel.sum().item()) if total_verts == 0: return (torch.empty((0, 3), dtype=voxel_coords.dtype, device=device), torch.empty((0, 3), dtype=torch.long, device=device)) vert_offset = (torch.cumsum(K_per_voxel, dim=0) - K_per_voxel) # (Nv,) voxel_per_subvol = torch.repeat_interleave( torch.arange(Nv, device=device), K_per_voxel) # (total_verts,) EDGES = torch.tensor([ [0, 1], [2, 3], [4, 5], [6, 7], [0, 2], [1, 3], [4, 6], [5, 7], [0, 4], [1, 5], [2, 6], [3, 7], ], dtype=torch.long, device=device) sb_a = sign_bits[:, EDGES[:, 0]] # (Nv, 12) sb_b = sign_bits[:, EDGES[:, 1]] crosses = sb_a != sb_b # (Nv, 12) if corner_valid is not None: cv = torch.where(found, corner_valid[idx_c], torch.zeros_like(found)).reshape(Nv, 8) edge_valid = cv[:, EDGES[:, 0]] & cv[:, EDGES[:, 1]] crosses = crosses & edge_valid edge_group_per_voxel = group_lut[pat_per_voxel] # (Nv, 12), -1 if not crossing # LUT already gives -1 for non-crossing edges; only re-mask for corner_valid if corner_valid is not None: edge_group_per_voxel = torch.where(crosses, edge_group_per_voxel, torch.full_like(edge_group_per_voxel, -1)) a_sd = sd[:, EDGES[:, 0]] # (Nv, 12) b_sd = sd[:, EDGES[:, 1]] denom = a_sd - b_sd t = torch.where(denom.abs() > 1e-20, a_sd / denom, torch.zeros_like(a_sd)) t = t.clamp(0.0, 1.0).unsqueeze(-1) corner_world = (corner_pos.float() / resolution - 0.5) * scale + center.unsqueeze(0).unsqueeze(0) a_pos = corner_world[:, EDGES[:, 0]] b_pos = corner_world[:, EDGES[:, 1]] crossing_pts = torch.lerp(a_pos, b_pos, t) # (Nv, 12, 3) # Aggregate crossing positions per (voxel, subvolume) into global dual verts flat_group = edge_group_per_voxel.reshape(-1) valid_mask = flat_group >= 0 flat_voxel = torch.arange(Nv, device=device).unsqueeze(-1).expand(Nv, 12).reshape(-1) flat_pos = crossing_pts.reshape(-1, 3) v_idx = flat_voxel[valid_mask] g_idx = flat_group[valid_mask] pos = flat_pos[valid_mask] global_idx = vert_offset[v_idx] + g_idx # (Nvalid,) pos_dtype = crossing_pts.dtype sums = torch.zeros((total_verts, 3), dtype=pos_dtype, device=device) counts = torch.zeros(total_verts, dtype=pos_dtype, device=device) sums.scatter_add_(0, global_idx.unsqueeze(-1).expand(-1, 3), pos) counts.scatter_add_(0, global_idx, torch.ones_like(g_idx, dtype=pos_dtype)) # Fully-masked subvolumes default to the voxel centre (unreferenced) voxel_centre = ((voxel_coords.float() + 0.5) / resolution - 0.5) * scale + center.unsqueeze(0) dual_verts = torch.where( counts.unsqueeze(-1) > 0, sums / counts.clamp_min(1.0).unsqueeze(-1), voxel_centre[voxel_per_subvol].to(pos_dtype), ) # Face emission. SHARED_LOCAL_EDGE[axis,k] = the k-th neighbour's local edge # slot corresponding to the shared grid edge (owner's slot = EDGE_OF_AXIS[axis]). NEIGHBOUR_OFFS = torch.tensor([ [[0, 0, 0], [0, -1, 0], [0, -1, -1], [0, 0, -1]], [[0, 0, 0], [0, 0, -1], [-1, 0, -1], [-1, 0, 0]], [[0, 0, 0], [-1, 0, 0], [-1, -1, 0], [0, -1, 0]], ], dtype=torch.long, device=device) SHARED_LOCAL_EDGE = torch.tensor([ [0, 1, 3, 2], # x-axis [4, 6, 7, 5], # y-axis [8, 9, 11, 10], # z-axis ], dtype=torch.long, device=device) EDGE_OF_AXIS = torch.tensor([0, 4, 8], dtype=torch.long, device=device) vox_dims = voxel_coords.max(dim=0)[0] + 2 vox_key = (voxel_coords[:, 0] * vox_dims[1] + voxel_coords[:, 1]) * vox_dims[2] + voxel_coords[:, 2] sort_v = vox_key.argsort() sorted_vox_key = vox_key[sort_v] tris_out = [] for axis in range(3): edge_idx = EDGE_OF_AXIS[axis] owner_mask = crosses[:, edge_idx] if not owner_mask.any(): continue owner_voxels = voxel_coords[owner_mask] sign_a_at_owner = sb_a[owner_mask, edge_idx] # (No,) — 0 inside, 1 outside nbrs = owner_voxels.unsqueeze(1) + NEIGHBOUR_OFFS[axis].unsqueeze(0) # (No, 4, 3) nbr_keys = (nbrs[..., 0] * vox_dims[1] + nbrs[..., 1]) * vox_dims[2] + nbrs[..., 2] flat = nbr_keys.reshape(-1).contiguous() ins = torch.searchsorted(sorted_vox_key, flat) ins_c = ins.clamp(max=sorted_vox_key.numel() - 1) valid_nbr = (ins < sorted_vox_key.numel()) & (sorted_vox_key[ins_c] == flat) valid_quad = valid_nbr.reshape(-1, 4).all(dim=1) if not valid_quad.any(): continue nbr_orig = sort_v[ins_c].reshape(-1, 4)[valid_quad] # (Mv, 4) voxel idx nbr_pat = pat_per_voxel[nbr_orig] # (Mv, 4) local_e = SHARED_LOCAL_EDGE[axis].unsqueeze(0).expand_as(nbr_pat) nbr_subvol = group_lut[nbr_pat, local_e] # (Mv, 4) # Every neighbour must agree the shared edge is crossing ok = (nbr_subvol >= 0).all(dim=1) if not ok.any(): continue nbr_subvol = nbr_subvol[ok] nbr_orig = nbr_orig[ok] dual_indices = vert_offset[nbr_orig] + nbr_subvol # (Mv', 4) sign_a = sign_a_at_owner[valid_quad][ok] # 0 = inside, 1 = outside # Winding: flip when corner a is outside (same as _dual_contour) flip = sign_a > 0 d0, d1, d2, d3 = dual_indices.unbind(dim=1) t1a = torch.stack([d0, d1, d2], dim=1) t2a = torch.stack([d0, d2, d3], dim=1) t1b = torch.stack([d0, d2, d1], dim=1) t2b = torch.stack([d0, d3, d2], dim=1) tris_out.append(torch.where(flip.unsqueeze(-1), t1b, t1a)) tris_out.append(torch.where(flip.unsqueeze(-1), t2b, t2a)) if not tris_out: return dual_verts, torch.empty((0, 3), dtype=torch.long, device=device) return dual_verts, torch.cat(tris_out, dim=0) # Main entry def _filter_components(verts: torch.Tensor, faces: torch.Tensor, min_fraction: float = 0.01, drop_inverted: bool = True, drop_enclosed: bool = True) -> torch.Tensor: """Drop tiny / inverted-volume / bbox-enclosed connected components; returns filtered faces.""" device = faces.device V = verts.shape[0] # Connected components via min-label propagation across faces (200-iter max) label = torch.arange(V, dtype=torch.long, device=device) for _ in range(200): f_min = torch.minimum(torch.minimum(label[faces[:, 0]], label[faces[:, 1]]), label[faces[:, 2]]) new_label = label.clone() new_label.scatter_reduce_(0, faces[:, 0], f_min, reduce="amin", include_self=True) new_label.scatter_reduce_(0, faces[:, 1], f_min, reduce="amin", include_self=True) new_label.scatter_reduce_(0, faces[:, 2], f_min, reduce="amin", include_self=True) new_label = new_label[new_label] # path compression if torch.equal(new_label, label): break label = new_label face_label = label[faces[:, 0]] # (F,) unique_labels, inv = torch.unique(face_label, return_inverse=True) C = unique_labels.shape[0] counts = torch.bincount(inv, minlength=C) max_count = int(counts.max().item()) keep = torch.ones(C, dtype=torch.bool, device=device) if min_fraction > 0: threshold = max(1, int(max_count * min_fraction)) keep = keep & (counts >= threshold) if drop_inverted: # Drop components with negative signed volume, but always keep the largest v0 = verts[faces[:, 0]] v1 = verts[faces[:, 1]] v2 = verts[faces[:, 2]] face_vol = (v0 * torch.cross(v1, v2, dim=-1)).sum(dim=-1) # (F,) comp_vol = torch.zeros(C, dtype=face_vol.dtype, device=device) comp_vol.scatter_add_(0, inv, face_vol) if C > 1: large = counts.argmax() vol_ok = (comp_vol >= 0) vol_ok[large] = True keep = keep & vol_ok if drop_enclosed and C > 1: # Two-pass: (1) bbox-inside-largest test, then (2) +X raycast point-in-mesh large = counts.argmax() face_v = verts[faces] face_min = face_v.min(dim=1).values face_max = face_v.max(dim=1).values comp_min = torch.full((C, 3), float("inf"), dtype=verts.dtype, device=device) comp_max = torch.full((C, 3), float("-inf"), dtype=verts.dtype, device=device) comp_min.scatter_reduce_(0, inv[:, None].expand(-1, 3), face_min, reduce="amin", include_self=True) comp_max.scatter_reduce_(0, inv[:, None].expand(-1, 3), face_max, reduce="amax", include_self=True) big_min = comp_min[large] big_max = comp_max[large] enclosed = ((comp_min >= big_min).all(dim=-1) & (comp_max <= big_max).all(dim=-1)) enclosed[large] = False # Per-component centroid for the raycast test face_centroid = face_v.mean(dim=1) # (F, 3) comp_centroid = torch.zeros((C, 3), dtype=verts.dtype, device=device) comp_centroid.scatter_add_(0, inv[:, None].expand(-1, 3), face_centroid) comp_centroid = comp_centroid / counts.to(verts.dtype).unsqueeze(-1).clamp_min(1.0) # Raycast surviving non-largest candidates (small loop) big_faces = faces[inv == large] bv0 = verts[big_faces[:, 0]] bv1 = verts[big_faces[:, 1]] bv2 = verts[big_faces[:, 2]] candidates = torch.nonzero((keep & ~enclosed) & (torch.arange(C, device=device) != large), as_tuple=True)[0] for ci in candidates.tolist(): origin = comp_centroid[ci] # 2D point-in-triangle in YZ for the ray origin's (y, z) oy, oz = origin[1], origin[2] s12 = (bv1[:, 1] - oy) * (bv2[:, 2] - oz) - (bv1[:, 2] - oz) * (bv2[:, 1] - oy) s20 = (bv2[:, 1] - oy) * (bv0[:, 2] - oz) - (bv2[:, 2] - oz) * (bv0[:, 1] - oy) s01 = (bv0[:, 1] - oy) * (bv1[:, 2] - oz) - (bv0[:, 2] - oz) * (bv1[:, 1] - oy) total = s12 + s20 + s01 inside_yz = (((s12 >= 0) & (s20 >= 0) & (s01 >= 0)) | ((s12 <= 0) & (s20 <= 0) & (s01 <= 0))) inside_yz = inside_yz & (total.abs() > 1e-20) inv_t = 1.0 / total.where(total.abs() > 1e-20, torch.ones_like(total)) hit_x = (s12 * bv0[:, 0] + s20 * bv1[:, 0] + s01 * bv2[:, 0]) * inv_t crossings = int((inside_yz & (hit_x > origin[0])).sum().item()) if crossings % 2 == 1: enclosed[ci] = True keep = keep & ~enclosed if keep.all(): return faces face_keep = keep[inv] return faces[face_keep] def _taubin_smooth(verts: torch.Tensor, faces: torch.Tensor, iters: int, lam: float = 0.5, mu: float = -0.53, progress_callback=None) -> torch.Tensor: """Taubin lambda|mu low-pass smoothing (volume-preserving); boundary verts are no-ops.""" if iters <= 0 or verts.numel() == 0 or faces.numel() == 0: return verts device = verts.device V = verts.shape[0] sorted_keys, _, _ = _sorted_edge_halfedges(faces, V) uniq_keys, _ = torch.unique_consecutive(sorted_keys, return_counts=True) P = V + 1 a = uniq_keys // P b = uniq_keys % P ones = torch.ones_like(a, dtype=verts.dtype) counts = torch.zeros(V, dtype=verts.dtype, device=device) counts.scatter_add_(0, a, ones) counts.scatter_add_(0, b, ones) counts_safe = counts.clamp_min(1.0).unsqueeze(-1) has_nb = (counts > 0).unsqueeze(-1) a_exp = a.unsqueeze(-1).expand(-1, 3) b_exp = b.unsqueeze(-1).expand(-1, 3) out = verts for _ in range(iters): throw_exception_if_processing_interrupted() for w in (lam, mu): sums = torch.zeros_like(out) sums.scatter_add_(0, a_exp, out[b]) sums.scatter_add_(0, b_exp, out[a]) delta = (sums / counts_safe - out) * has_nb out = out + w * delta if progress_callback is not None: progress_callback() return out def _fix_poles(verts: torch.Tensor, faces: torch.Tensor, colors: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Midpoint-collapse edge-sharing valence-3 vertex pairs (DC T-junction poles); boundary verts excluded.""" device = verts.device V = verts.shape[0] if V == 0 or faces.numel() == 0: return verts, faces, colors sorted_keys, _, _ = _sorted_edge_halfedges(faces, V) uniq_keys, key_counts = torch.unique_consecutive(sorted_keys, return_counts=True) P = V + 1 a = uniq_keys // P b = uniq_keys % P # Boundary verts (endpoints of single-face edges) are excluded from poles boundary_v = torch.zeros(V, dtype=torch.bool, device=device) bnd_mask = key_counts == 1 if bnd_mask.any(): boundary_v[a[bnd_mask]] = True boundary_v[b[bnd_mask]] = True ones = torch.ones_like(a) valence = torch.zeros(V, dtype=torch.long, device=device) valence.scatter_add_(0, a, ones) valence.scatter_add_(0, b, ones) is_pole = (valence == 3) & ~boundary_v if int(is_pole.sum().item()) < 2: return verts, faces, colors pp_edge = is_pole[a] & is_pole[b] if not pp_edge.any(): return verts, faces, colors cand_a = a[pp_edge] cand_b = b[pp_edge] # Greedy maximal matching: accept candidates whose endpoints are still free used = torch.zeros(V, dtype=torch.bool, device="cpu") cand_a_cpu = cand_a.cpu().tolist() cand_b_cpu = cand_b.cpu().tolist() pairs: list[tuple[int, int]] = [] for ai, bi in zip(cand_a_cpu, cand_b_cpu): if not used[ai] and not used[bi]: pairs.append((ai, bi)) used[ai] = True used[bi] = True if not pairs: return verts, faces, colors pairs_t = torch.tensor(pairs, dtype=torch.long, device=device) # (P, 2) keep_i = torch.minimum(pairs_t[:, 0], pairs_t[:, 1]) drop_i = torch.maximum(pairs_t[:, 0], pairs_t[:, 1]) new_verts = verts.clone() new_verts[keep_i] = 0.5 * (verts[pairs_t[:, 0]] + verts[pairs_t[:, 1]]) new_colors = None if colors is not None: new_colors = colors.clone() new_colors[keep_i] = 0.5 * (colors[pairs_t[:, 0]] + colors[pairs_t[:, 1]]) remap = torch.arange(V, dtype=torch.long, device=device) remap[drop_i] = keep_i new_faces = remap[faces.long()] degen = ((new_faces[:, 0] == new_faces[:, 1]) | (new_faces[:, 1] == new_faces[:, 2]) | (new_faces[:, 0] == new_faces[:, 2])) new_faces = new_faces[~degen] used_mask = torch.zeros(V, dtype=torch.bool, device=device) used_mask[new_faces.reshape(-1)] = True if not used_mask.all(): compact = used_mask.long().cumsum(0) - 1 new_verts = new_verts[used_mask] if new_colors is not None: new_colors = new_colors[used_mask] new_faces = compact[new_faces] return new_verts, new_faces.to(faces.dtype), new_colors def remesh_narrow_band_dc( vertices: torch.Tensor, faces: torch.Tensor, resolution: int = 256, target_faces: int = 0, # 0 = use `resolution`; >0 = auto-derive resolution band: float = 1.0, project_back: float = 0.0, qef: bool = True, sign_mode: str = "udf", # "sdf" | "udf" drop_small_components: float = 0.01, # drop components below this fraction of max drop_inverted_components: bool = True, # drop closed components with negative signed volume drop_enclosed_components: bool = True, # drop components whose bbox is inside the largest's bbox fix_poles: bool = False, # collapse 3-3 valence vertex pairs (DC T-junction artifact) smooth_iters: int = 0, # Taubin smoothing iterations (low-pass, volume-preserving) smooth_lambda: float = 0.5, smooth_mu: float = -0.53, manifold: bool = False, # Manifold DC: emit 1-4 dual verts per voxel for multi-sheet cases colors: Optional[torch.Tensor] = None, scale: Optional[float] = None, center: Optional[torch.Tensor] = None, ): """Narrow-band Dual Contouring re-extraction; returns (new_vertices, new_faces, new_colors), new_colors None unless `colors` given. Key params: target_faces>0 auto-derives resolution; sign_mode sdf/udf (UDF disables qef and may need component filters); project_back lerps verts toward the closest surface point; scale/center default to bbox. """ assert vertices.ndim == 2 and vertices.shape[1] == 3 assert faces.ndim == 2 and faces.shape[1] == 3 device = vertices.device if center is None: center = 0.5 * (vertices.max(dim=0)[0] + vertices.min(dim=0)[0]) else: center = center.to(device=device, dtype=vertices.dtype) if scale is None: bbox = vertices.max(dim=0)[0] - vertices.min(dim=0)[0] scale = float(bbox.max().item()) * 1.1 # Auto-derive resolution from target_faces (~3 tris/crossing-voxel; +-30%) if target_faces > 0: tv = vertices[faces.long()] cross_v = torch.cross(tv[:, 1] - tv[:, 0], tv[:, 2] - tv[:, 0], dim=-1) surface_area = 0.5 * cross_v.norm(dim=-1).sum().item() relative_area = max(surface_area / (scale * scale), 1e-6) derived = int(math.sqrt(target_faces / (3.0 * relative_area))) # Round to a multiple of 32 (builder doubles from a <=32 base) derived = ((derived + 31) // 32) * 32 derived = max(32, min(1024, derived)) resolution = derived eps = band * scale / resolution # progress: one tick per narrow-band level + 3 stages (SDF/DC/post) + each smoothing iter n_levels, _b = 1, resolution while _b > 32 and _b % 2 == 0: _b //= 2 while _b < resolution: _b *= 2 n_levels += 1 _total_ticks = n_levels + 3 + int(smooth_iters) _pbar = comfy.utils.ProgressBar(_total_ticks) try: _tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False) except Exception: _tq = None def tick(): _pbar.update(1) if _tq is not None: _tq.update(1) # Step 1: sparse narrow-band voxel grid (coarse-to-fine) voxel_coords, _band_tree = _build_narrow_band_voxels( vertices, faces, center, scale, resolution, eps, progress_callback=tick) if voxel_coords.numel() == 0: return (torch.empty((0, 3), dtype=vertices.dtype, device=device), torch.empty((0, 3), dtype=faces.dtype, device=device), None if colors is None else torch.empty((0, colors.shape[1]), dtype=colors.dtype, device=device)) # Step 2: collect unique corner positions of all active voxels CORNER_OFFS = torch.tensor([ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], ], dtype=torch.long, device=device) corners = (voxel_coords.unsqueeze(1) + CORNER_OFFS.unsqueeze(0)).reshape(-1, 3) R1 = resolution + 1 corner_keys = (corners[:, 0] * R1 + corners[:, 1]) * R1 + corners[:, 2] unique_corner_keys, corner_inv = torch.unique(corner_keys, return_inverse=True) unique_corners = torch.zeros((unique_corner_keys.shape[0], 3), dtype=torch.long, device=device) unique_corners[corner_inv] = corners if sign_mode == "sdf": use_sdf = True elif sign_mode == "udf": use_sdf = False else: raise ValueError(f"sign_mode must be 'sdf'|'udf', got {sign_mode!r}") # Step 3: distance field at every unique corner. tri_verts_g = vertices[faces.long()] centroids = tri_verts_g.mean(dim=1) tri_radii = (tri_verts_g - centroids.unsqueeze(1)).norm(dim=-1).max(dim=-1).values # face normals: needed for the SDF sign AND for QEF placement (QEF is sign-agnostic, # so it works in UDF mode too — (n·(x-p))² is unchanged by normal orientation) if use_sdf or qef: tri_face_normals_all = torch.nn.functional.normalize( torch.cross(tri_verts_g[:, 1] - tri_verts_g[:, 0], tri_verts_g[:, 2] - tri_verts_g[:, 0], dim=-1), p=2, dim=-1, eps=1e-12) cell_size = scale / resolution corner_world = (unique_corners.float() / resolution - 0.5) * scale + center.unsqueeze(0) # Exact corner UDF (no max_dist cap) so DC crossings keep fine detail udf, corner_closest, corner_tri = _udf_exact(corner_world, tri_verts_g, tree=_band_tree) corner_valid = corner_tri >= 0 if use_sdf: sign = torch.ones_like(udf) n_for_corner = tri_face_normals_all[corner_tri.clamp(min=0)] offset = corner_world - corner_closest sign_dot = (offset * n_for_corner).sum(-1) sign = torch.where(corner_valid & (sign_dot < 0), -sign, sign) sdf = sign * udf else: # UDF mode: iso at UDF=eps; double surface on closed meshes, weld after sdf = udf - eps tick() # SDF done # Short-range hash reused by project_back / colors sampling (max_dist up to 4*cell) short_hash_cell_t = torch.tensor(2.0 * cell_size, dtype=vertices.dtype, device=device) short_hash = _build_tri_spatial_hash(centroids, tri_radii, short_hash_cell_t) # Step 4 + 5: dual contouring + topology. QEF works in both modes (sign-agnostic); # in UDF it pulls the ±eps crossing back onto the triangle planes → sharper edges. if qef: tri_face_normals = tri_face_normals_all # QEF needs the nearest triangle per crossing point. The centroid cKDTree # (_band_tree) is already built, and its exact k-NN query is markedly faster # here than a spatial-hash gather (which builds ~100-triangle candidate lists # per query on a dense input) — and it's exact. So reuse it directly. def _qef_query(pts): return _udf_exact(pts, tri_verts_g, tree=_band_tree) else: tri_face_normals = None _qef_query = None if manifold and use_sdf: # MDC ignores qef / tri_face_normals — centroid placement only. dual_verts, new_faces = _dual_contour_manifold( voxel_coords, sdf, unique_corner_keys, resolution, scale, center, corner_valid=corner_valid) else: dual_verts, new_faces = _dual_contour( voxel_coords, sdf, unique_corner_keys, resolution, scale, center, tri_face_normals=tri_face_normals, qef_query=_qef_query, # corner_valid filter only matters in SDF mode corner_valid=corner_valid if use_sdf else None) tick() # DC done # Step 6: project_back and / or color sampling share one closest-point query need_query = (project_back > 0 or colors is not None) and dual_verts.numel() > 0 out_colors = None if need_query: result = _udf_query( dual_verts, tri_verts_g, short_hash, short_hash_cell_t, max_dist=4.0 * cell_size, return_closest=True, return_tri_idx=(colors is not None)) if colors is not None: _, closest_pts, closest_tri = result else: _, closest_pts = result if project_back > 0: dual_verts = torch.lerp(dual_verts, closest_pts, float(project_back)) if colors is not None: # Barycentric-interpolate input colors at the closest point safe_tri = closest_tri.clamp(min=0) tri_v_idx = faces[safe_tri].long() # (N, 3) tri_v = vertices[tri_v_idx] # (N, 3, 3) v0 = tri_v[:, 0] v1 = tri_v[:, 1] v2 = tri_v[:, 2] e0 = v1 - v0 e1 = v2 - v0 e2 = closest_pts - v0 d00 = (e0 * e0).sum(-1) d01 = (e0 * e1).sum(-1) d11 = (e1 * e1).sum(-1) d20 = (e2 * e0).sum(-1) d21 = (e2 * e1).sum(-1) denom = d00 * d11 - d01 * d01 + 1e-20 bv = ((d11 * d20 - d01 * d21) / denom).clamp(0.0, 1.0) bw = ((d00 * d21 - d01 * d20) / denom).clamp(0.0, 1.0) bu = (1.0 - bv - bw).clamp(0.0, 1.0) tri_c = colors[tri_v_idx] # (N, 3, C) out_colors = (bu.unsqueeze(-1) * tri_c[:, 0] + bv.unsqueeze(-1) * tri_c[:, 1] + bw.unsqueeze(-1) * tri_c[:, 2]) # Zero out failed-query rows (their barycentric used bogus triangle 0) invalid = closest_tri < 0 if invalid.any(): out_colors[invalid] = 0 # Filter spurious components (tiny pieces, inverted inner shells) if (new_faces.numel() > 0 and (drop_small_components > 0 or drop_inverted_components or drop_enclosed_components)): new_faces = _filter_components( dual_verts, new_faces, min_fraction=drop_small_components if drop_small_components > 0 else 0.0, drop_inverted=drop_inverted_components, drop_enclosed=drop_enclosed_components) if fix_poles and new_faces.numel() > 0: dual_verts, new_faces, out_colors = _fix_poles( dual_verts, new_faces, out_colors) tick() # post-process done if smooth_iters > 0 and dual_verts.numel() > 0 and new_faces.numel() > 0: dual_verts = _taubin_smooth(dual_verts, new_faces, iters=int(smooth_iters), lam=float(smooth_lambda), mu=float(smooth_mu), progress_callback=tick) # Drop unused verts (non-crossing voxels' dual verts) and compact faces if dual_verts.numel() > 0 and new_faces.numel() > 0: used = torch.zeros(dual_verts.shape[0], dtype=torch.bool, device=device) used[new_faces[:, 0]] = True used[new_faces[:, 1]] = True used[new_faces[:, 2]] = True remap = used.long().cumsum(0) - 1 dual_verts = dual_verts[used] new_faces = remap[new_faces.long()] if out_colors is not None: out_colors = out_colors[used] return (dual_verts.to(vertices.dtype), new_faces.to(faces.dtype), out_colors.to(colors.dtype) if (out_colors is not None and colors is not None) else None)