def cluster_decimate(
    vertices: torch.Tensor, faces: torch.Tensor,
    target_verts: int = 1_000_000,
    colors: Optional[torch.Tensor] = None,
    face_chunk: int = 4_000_000,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Vertex-cluster decimation (Rossignac-Borrel) used as a fast prepass for huge meshes.

    Vertices are binned into a uniform grid sized from ``target_verts``, every
    occupied cell is replaced by the mean of its vertices (and colors), faces are
    remapped chunk by chunk, and degenerate / duplicate faces are dropped.

    Args:
        vertices: (V, 3) floating-point positions.
        faces: (F, 3) integer vertex indices.
        target_verts: rough output vertex budget; the grid is sized to hold about
            3x this many cells (never fewer than 1000).
        colors: optional (V, C) per-vertex attributes averaged per cell.
        face_chunk: number of faces remapped per iteration, bounding peak memory.

    Returns:
        ``(verts, faces, colors)``. ``verts`` and ``faces`` keep the input dtypes.
        ``colors`` keeps its dtype when floating point; integer colors come back
        in the default float dtype (the mean is not rounded). Inputs are returned
        unchanged when there are no vertices or no faces.
    """
    if vertices.shape[0] == 0 or faces.shape[0] == 0:
        return vertices, faces, colors

    device = vertices.device
    bbox_min = vertices.min(dim=0)[0]
    bbox = vertices.max(dim=0)[0] - bbox_min
    # cell size so the bbox holds ~3x target_verts cells (surface occupancy ~1/3)
    cell_count_target = max(target_verts * 3, 1000)
    extent_max = float(bbox.max().item())
    cells_per_axis = cell_count_target ** (1 / 3)
    cell_size = extent_max / max(1.0, cells_per_axis)
    # Guard against a zero-extent bbox (all vertices coincident)
    scale = 1.0 / max(cell_size, 1e-20)

    # Integer cell coordinate per vertex, flattened into one sortable key
    q = ((vertices - bbox_min) * scale).floor().to(torch.int64)
    extent = (bbox * scale).floor().to(torch.int64) + 2
    key = (q[:, 0] * extent[1] + q[:, 1]) * extent[2] + q[:, 2]

    # inv maps every input vertex to its cluster (row of the outputs)
    unique_key, inv = torch.unique(key, return_inverse=True)
    n_unique = unique_key.shape[0]
    counts = torch.zeros(n_unique, dtype=vertices.dtype, device=device)
    counts.scatter_add_(0, inv, torch.ones(vertices.shape[0], dtype=vertices.dtype, device=device))
    counts_div = counts.unsqueeze(-1).clamp_min(1.0)

    new_verts = torch.zeros((n_unique, 3), dtype=vertices.dtype, device=device)
    new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(vertices), vertices)
    new_verts = new_verts / counts_div

    new_colors = None
    if colors is not None:
        # Integer colors (e.g. uint8) would wrap around while being summed, so
        # accumulate them in float; floating-point colors keep their own dtype.
        acc_dtype = colors.dtype if colors.is_floating_point() else torch.get_default_dtype()
        color_vals = colors.to(acc_dtype)
        new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=acc_dtype, device=device)
        new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(color_vals), color_vals)
        new_colors = new_colors / counts_div.to(acc_dtype)

    # remap faces in chunks (face tensor can be huge); drop degenerates per chunk
    out_chunks = []
    n_faces_in = faces.shape[0]
    for fs in range(0, n_faces_in, face_chunk):
        fe = min(fs + face_chunk, n_faces_in)
        cf = inv[faces[fs:fe].long()]
        nondeg = (cf[:, 0] != cf[:, 1]) & (cf[:, 1] != cf[:, 2]) & (cf[:, 0] != cf[:, 2])
        if nondeg.any():
            out_chunks.append(cf[nondeg])
    if out_chunks:
        new_faces = torch.cat(out_chunks, dim=0)
    else:
        new_faces = torch.empty((0, 3), dtype=faces.dtype, device=device)

    # drop duplicate faces (same vertex set after clustering), keeping the first
    # occurrence of each set with its original winding
    if new_faces.numel() > 0:
        key_sorted = torch.sort(new_faces, dim=1)[0]
        P = n_unique + 1
        if P ** 3 < 2 ** 63:
            # Fast path: pack the sorted triple into a single int64 key
            packed = (key_sorted[:, 0].long() * P + key_sorted[:, 1].long()) * P + key_sorted[:, 2].long()
            uniq, first = torch.unique(packed, return_inverse=True)
        else:
            # The packed key would overflow int64; compare whole rows instead
            uniq, first = torch.unique(key_sorted, dim=0, return_inverse=True)
        n_faces = new_faces.shape[0]
        arange = torch.arange(n_faces, device=device, dtype=torch.int64)
        first_idx = torch.full((uniq.shape[0],), n_faces, dtype=torch.int64, device=device)
        first_idx.scatter_reduce_(0, first, arange, reduce="amin", include_self=True)
        new_faces = new_faces[first_idx]

    return new_verts.to(vertices.dtype), new_faces.to(faces.dtype), new_colors
+ +Re-extracts a mesh from a sparse narrow-band voxel grid around the input +surface (pure-PyTorch approximation of CuMesh's remesh_narrow_band_dc). +Coarse-to-fine voxelise the band, sample SDF/UDF at voxel corners, dual +contour (optionally QEF / Manifold DC), then optionally project back, +filter components, fix poles, smooth, and interpolate vertex colors. +""" +from __future__ import annotations + +import functools +import math +from typing import Optional, Tuple + +import numpy as np +import torch +import scipy.spatial +import comfy.utils +from tqdm import tqdm as _tqdm +from comfy.model_management import throw_exception_if_processing_interrupted + +from .qem_decimate import _sorted_edge_halfedges + + +# Point-to-triangle distance (exact, vectorised) + +def _point_tri_closest(points: torch.Tensor, tris: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Exact closest point + squared distance per (point, triangle) pair; points (N,3), tris (N,3,3).""" + a = tris[:, 0] + b = tris[:, 1] + c = tris[:, 2] + ab = b - a + ac = c - a + ap = points - a + + d1 = (ab * ap).sum(-1) + d2 = (ac * ap).sum(-1) + + region_A = (d1 <= 0) & (d2 <= 0) + + bp = points - b + d3 = (ab * bp).sum(-1) + d4 = (ac * bp).sum(-1) + region_B = (d3 >= 0) & (d4 <= d3) + + cp = points - c + d5 = (ab * cp).sum(-1) + d6 = (ac * cp).sum(-1) + region_C = (d6 >= 0) & (d5 <= d6) + + # Edge AB + vc = d1 * d4 - d3 * d2 + region_AB = (vc <= 0) & (d1 >= 0) & (d3 <= 0) + v_ab = d1 / (d1 - d3 + 1e-20) + closest_AB = a + v_ab.unsqueeze(-1) * ab + + # Edge AC + vb = d5 * d2 - d1 * d6 + region_AC = (vb <= 0) & (d2 >= 0) & (d6 <= 0) + v_ac = d2 / (d2 - d6 + 1e-20) + closest_AC = a + v_ac.unsqueeze(-1) * ac + + # Edge BC + va = d3 * d6 - d5 * d4 + region_BC = (va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0) + v_bc = (d4 - d3) / ((d4 - d3) + (d5 - d6) + 1e-20) + closest_BC = b + v_bc.unsqueeze(-1) * (c - b) + + # Face interior (barycentric) + denom = va + vb + vc + 1e-20 + v_face = vb / denom + w_face = vc 
/ denom + closest_face = a + v_face.unsqueeze(-1) * ab + w_face.unsqueeze(-1) * ac + + # Combine by mask via in-place where (out= aliases input, no per-step alloc) + closest = closest_face # fresh; safe to mutate + torch.where(region_BC.unsqueeze(-1), closest_BC, closest, out=closest) + torch.where(region_AC.unsqueeze(-1), closest_AC, closest, out=closest) + torch.where(region_AB.unsqueeze(-1), closest_AB, closest, out=closest) + torch.where(region_C .unsqueeze(-1), c, closest, out=closest) + torch.where(region_B .unsqueeze(-1), b, closest, out=closest) + torch.where(region_A .unsqueeze(-1), a, closest, out=closest) + + diff = points - closest + return closest, (diff * diff).sum(-1) + + +def _build_centroid_tree(tri_verts: torch.Tensor): + """scipy cKDTree over triangle centroids; build once and reuse across _udf_exact calls. + balanced_tree/compact_nodes off: ~2.4x faster build (and faster queries on near-uniform + centroid clouds) with identical exact-kNN results.""" + return scipy.spatial.cKDTree(tri_verts.mean(dim=1).detach().cpu().numpy(), + balanced_tree=False, compact_nodes=False) + + +def _udf_exact(query_points: torch.Tensor, tri_verts: torch.Tensor, + k: int = 8, chunk: int = 262144, tree=None): + """Exact UDF (no max_dist cap) via centroid kNN; returns (dist [N], closest [N,3], tri_idx [N]). Pass prebuilt `tree` to skip rebuild. + + k=8 nearest centroids before the exact point-triangle test: on dense meshes the true + closest triangle is essentially always within the first few neighbours. 
Measured vs k=16: + bit-identical topology, ~0.003-voxel RMS sub-voxel drift, ~15% faster overall.""" + device = query_points.device + F = tri_verts.shape[0] + kq = int(min(k, F)) + if tree is None: + tree = _build_centroid_tree(tri_verts) + _, cand = tree.query(query_points.detach().cpu().numpy(), k=kq, workers=-1) + if cand.ndim == 1: + cand = cand[:, None] + cand = np.ascontiguousarray(cand) + + N = query_points.shape[0] + out_d = torch.empty(N, device=device, dtype=query_points.dtype) + out_c = torch.empty(N, 3, device=device, dtype=query_points.dtype) + out_t = torch.empty(N, dtype=torch.long, device=device) + for s in range(0, N, chunk): + e = min(s + chunk, N) + n = e - s + ci = torch.from_numpy(cand[s:e]).to(device).long() + tri = tri_verts[ci].reshape(n * kq, 3, 3) + P = query_points[s:e][:, None, :].expand(-1, kq, -1).reshape(n * kq, 3) + closest, d2 = _point_tri_closest(P, tri) + d2 = d2.reshape(n, kq) + closest = closest.reshape(n, kq, 3) + best = d2.argmin(dim=1) + ar = torch.arange(n, device=device) + out_d[s:e] = d2[ar, best].sqrt() + out_c[s:e] = closest[ar, best] + out_t[s:e] = ci[ar, best] + return out_d, out_c, out_t + + +# UDF query via spatial hash on triangle AABBs + +def _build_tri_spatial_hash(centroids: torch.Tensor, tri_radii: torch.Tensor, + cell_size: torch.Tensor): + """Bucket triangles into `cell_size` cells (each tri into every cell its AABB touches); returns hash tuple.""" + device = centroids.device + aabb_lo = (centroids - tri_radii.unsqueeze(-1)) + aabb_hi = (centroids + tri_radii.unsqueeze(-1)) + origin = aabb_lo.min(0)[0] + extent = aabb_hi.max(0)[0] - origin + dims = (extent / cell_size).long() + 2 + + cell_lo = ((aabb_lo - origin) / cell_size).long().clamp(min=0) + cell_hi = ((aabb_hi - origin) / cell_size).long() + cell_hi = torch.minimum(cell_hi, dims - 1) + + # Cap span at 3 cells/axis to bound memory + spans = (cell_hi - cell_lo + 1).clamp(max=3) + n_per_tri = spans.prod(dim=-1) + total = int(n_per_tri.sum().item()) + + # 
Per-insertion local offset within each tri's cell box + rep = torch.repeat_interleave(torch.arange(centroids.shape[0], device=device), n_per_tri) + cum = torch.cat([torch.zeros(1, device=device, dtype=n_per_tri.dtype), + n_per_tri.cumsum(0)[:-1]]) + local = torch.arange(total, device=device) - cum[rep] + sx = spans[rep, 0] + sy = spans[rep, 1] + sz = spans[rep, 2] + lx = local % sx + ly = (local // sx) % sy + lz = local // (sx * sy) + cx = cell_lo[rep, 0] + lx + cy = cell_lo[rep, 1] + ly + cz = cell_lo[rep, 2] + lz + keys = (cx * dims[1] + cy) * dims[2] + cz + + sort_idx = keys.argsort() + sorted_keys = keys[sort_idx] + tri_per_cell = rep[sort_idx] + + unique_keys, counts = torch.unique_consecutive(sorted_keys, return_counts=True) + cell_starts = torch.cat([torch.zeros(1, dtype=counts.dtype, device=device), + counts.cumsum(0)]) + return origin, dims, unique_keys, tri_per_cell, cell_starts, centroids, tri_radii + + +def _udf_query(query_points: torch.Tensor, + tri_verts: torch.Tensor, + hash_data, + cell_size: torch.Tensor, + max_dist: float, + chunk_max: int = 4096, + return_closest: bool = False, + return_tri_idx: bool = False): + """Capped UDF to nearest triangle (<= max_dist), optionally with closest point and/or tri index; chunk size is adaptive to hash density.""" + origin, dims, unique_keys, tri_per_cell, cell_starts, tri_centroids, tri_radii = hash_data + device = query_points.device + Q = query_points.shape[0] + # Adaptive chunk: bound per-chunk candidate-gather memory by hash density + avg_per_cell = tri_per_cell.numel() / max(1, unique_keys.numel()) + est_cands_per_query = max(1.0, avg_per_cell * 27) + chunk = max(256, min(chunk_max, int(50_000_000 / est_cands_per_query))) + out_d2 = torch.full((Q,), float(max_dist) ** 2, dtype=query_points.dtype, device=device) + # Default closest_pt = query_pt itself, so a missed query's lerp is a no-op + out_closest = (query_points.clone() if return_closest else None) + out_tri = (torch.full((Q,), -1, dtype=torch.long, 
device=device) + if return_tri_idx else None) + + rng = torch.tensor([-1, 0, 1], device=device, dtype=torch.long) + offs = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), dim=-1).reshape(-1, 3) # (27, 3) + + for cs in range(0, Q, chunk): + ce = min(cs + chunk, Q) + qp = query_points[cs:ce] + q_cell = ((qp - origin) / cell_size).long() + # Look up 27 neighbour cells per query + n_cell = q_cell.unsqueeze(1) + offs.unsqueeze(0) # (q, 27, 3) + n_valid = ((n_cell >= 0) & (n_cell < dims)).all(-1) + n_key = (n_cell[..., 0] * dims[1] + n_cell[..., 1]) * dims[2] + n_cell[..., 2] + flat_key = n_key.reshape(-1).contiguous() + ins = torch.searchsorted(unique_keys, flat_key) + ins_c = ins.clamp(max=unique_keys.numel() - 1) + found = (ins < unique_keys.numel()) & (unique_keys[ins_c] == flat_key) & n_valid.reshape(-1) + cell_idx = torch.where(found, ins_c, torch.zeros_like(ins_c)) + c_starts = cell_starts[cell_idx] + c_ends = cell_starts[cell_idx + 1] + c_counts = (c_ends - c_starts) * found.long() + rep_q = torch.repeat_interleave( + torch.arange(qp.shape[0] * 27, device=device) // 27, c_counts) + if rep_q.numel() == 0: + continue + total = rep_q.numel() + slot_starts_per_pair = torch.cumsum(c_counts, dim=0) - c_counts + per_pair_start = torch.repeat_interleave(c_starts, c_counts) + slot_within = torch.arange(total, device=device) - torch.repeat_interleave(slot_starts_per_pair, c_counts) + tri_indices = tri_per_cell[per_pair_start + slot_within] + + pts = qp[rep_q] + # Centroid pre-cull (squared): drop where ||pts-centroid||-radius > max_dist + diff = pts - tri_centroids[tri_indices] + d2_cand = (diff * diff).sum(-1) + thresh = max_dist + tri_radii[tri_indices] + cull_keep = d2_cand < thresh * thresh + rep_q = rep_q[cull_keep] + pts = pts[cull_keep] + tri_indices = tri_indices[cull_keep] + if rep_q.numel() == 0: + continue + tri = tri_verts[tri_indices] + closest, d2 = _point_tri_closest(pts, tri) + + # Min per query for this chunk. 
def _build_narrow_band_voxels(verts: torch.Tensor, faces: torch.Tensor,
                              center: torch.Tensor, scale: float,
                              resolution: int, eps: float,
                              progress_callback=None) -> Tuple[torch.Tensor, "scipy.spatial.cKDTree"]:
    """Coarse-to-fine selection of the voxels lying in a narrow band around the surface.

    Starts from a dense grid at a coarse resolution (``resolution`` halved while
    it is even and above 32), keeps the voxels whose centre is within
    ``0.87 * cell_size + eps`` of the mesh, then repeatedly splits every kept
    voxel into its 8 children and re-tests until ``resolution`` is reached.

    Args:
        verts: vertex positions, indexed by ``faces``.
        faces: triangle vertex indices.
        center: world-space centre of the voxel grid.
        scale: world-space edge length of the whole grid.
        resolution: target number of voxels per axis.
        eps: extra distance tolerance added to the band threshold.
        progress_callback: optional zero-argument callable invoked once per
            refinement level.

    Returns:
        ``(coords, tree)``: ``coords`` is (Nv, 3) integer voxel coordinates at
        the final resolution and ``tree`` is the triangle-centroid cKDTree built
        here, returned so callers can reuse it.
    """
    device = verts.device
    tri_verts = verts[faces.long()]
    # Exact UDF; build the centroid cKDTree once and reuse across refinement levels
    tree = _build_centroid_tree(tri_verts)

    base_resolution = resolution
    while base_resolution > 32 and base_resolution % 2 == 0:
        base_resolution //= 2

    # Dense coordinate list for the coarsest level
    rng = torch.arange(base_resolution, device=device, dtype=torch.long)
    coords = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), dim=-1).reshape(-1, 3)

    # Child offsets used when one voxel is split into 8 at the next level
    OFFSETS = torch.tensor([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
        [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=torch.long, device=device)

    current_res = base_resolution
    while True:
        throw_exception_if_processing_interrupted()
        cell_size = scale / current_res
        # World-space voxel centres at the current level
        pts = ((coords.float() + 0.5) / current_res - 0.5) * scale + center
        dists, _, _ = _udf_exact(pts, tri_verts, tree=tree)
        # 0.87 is presumably ~sqrt(3)/2, the half-diagonal of a unit cell
        keep = dists < 0.87 * cell_size + eps
        coords = coords[keep]
        if progress_callback is not None:
            progress_callback()
        if current_res >= resolution:
            break
        current_res *= 2
        coords = coords * 2
        coords = (coords.unsqueeze(1) + OFFSETS.unsqueeze(0)).reshape(-1, 3)

    return coords, tree
[4, 6], [5, 7], # y-axis + [0, 4], [1, 5], [2, 6], [3, 7], # z-axis + ], dtype=torch.long, device=device) + + a_sd = sd[:, EDGES[:, 0]] # (Nv, 12) + b_sd = sd[:, EDGES[:, 1]] + crosses = (a_sd * b_sd) < 0 # (Nv, 12) bool + # Skip crossings touching an invalid corner (avoids fake faces at band edge) + if corner_valid is not None: + cv_per_voxel = torch.where(found, corner_valid[idx_clamped], + torch.zeros_like(found)).reshape(Nv, 8) + edge_valid = cv_per_voxel[:, EDGES[:, 0]] & cv_per_voxel[:, EDGES[:, 1]] + crosses = crosses & edge_valid + # Zero-crossing interp factor per edge + t = a_sd / (a_sd - b_sd + 1e-20) + t = t.clamp(0.0, 1.0).unsqueeze(-1) + + corner_world = (corner_pos_per_voxel.float() / resolution - 0.5) * scale + center.unsqueeze(0).unsqueeze(0) # (Nv, 8, 3) + a_pos = corner_world[:, EDGES[:, 0]] # (Nv, 12, 3) + b_pos = corner_world[:, EDGES[:, 1]] + crossing_pts = torch.lerp(a_pos, b_pos, t) # (Nv, 12, 3) + + # Default dual vert: centroid of crossings (also QEF/no-crossing fallback) + crosses_f = crosses.float().unsqueeze(-1) + crossing_sum = (crossing_pts * crosses_f).sum(dim=1) + n_cross = crosses.float().sum(dim=1, keepdim=True).clamp_min(1.0) + centroid_verts = crossing_sum / n_cross + centre_world = ((voxel_coords.float() + 0.5) / resolution - 0.5) * scale + center.unsqueeze(0) + has_cross = crosses.any(dim=1, keepdim=True) + dual_verts = torch.where(has_cross, centroid_verts, centre_world) + + # QEF placement: minimise sum_i (n_i·(x-p_i))² via Tikhonov-regularised + # normal equations (A+reg I)x=b; clamp to voxel bbox, else fall back to centroid. 
+ if tri_face_normals is not None and qef_query is not None: + Nv = voxel_coords.shape[0] + flat_pts = crossing_pts.reshape(-1, 3) + flat_mask = crosses.reshape(-1) + if flat_mask.any(): + query_pts = flat_pts[flat_mask] + _, _, qef_tri_idx = qef_query(query_pts) + # Missed queries get a zero normal (null constraint, ignored by solver) + valid_q = qef_tri_idx >= 0 + normals_at_q = torch.zeros_like(query_pts) + normals_at_q[valid_q] = tri_face_normals[qef_tri_idx[valid_q]] + full_normals = torch.zeros((Nv * 12, 3), dtype=query_pts.dtype, device=device) + full_normals[flat_mask] = normals_at_q + n_per_edge = full_normals.reshape(Nv, 12, 3) + + # einsum sums into the 3x3 directly, skipping a big intermediate + A = torch.einsum('vec,ved->vcd', n_per_edge, n_per_edge) # (Nv, 3, 3) + n_dot_p = (n_per_edge * crossing_pts).sum(dim=-1) # (Nv, 12) + b = torch.einsum('ve,vec->vc', n_dot_p, n_per_edge) # (Nv, 3) + + # Tikhonov regularisation in-place (A, b are fresh einsum outputs) + reg = 1e-2 + A.diagonal(dim1=-2, dim2=-1).add_(reg) + b.add_(centroid_verts, alpha=reg) + try: + qef_solution = torch.linalg.solve(A, b.unsqueeze(-1)).squeeze(-1) + except Exception: + qef_solution = centroid_verts + + # Clamp QEF output to the voxel bbox + lo = corner_world[:, 0] # (Nv, 3) min corner + hi = corner_world[:, 7] # (Nv, 3) max corner + in_box = (qef_solution >= lo).all(dim=-1) & (qef_solution <= hi).all(dim=-1) + qef_solution = torch.where(in_box.unsqueeze(-1), qef_solution, centroid_verts) + + dual_verts = torch.where(has_cross, qef_solution, centre_world) + + # Topology: each crossing grid edge is shared by 4 voxels -> quad -> 2 tris. + # NEIGHBOUR_OFFS lays out the 4 sharing voxels per axis; y-axis order is + # reversed vs x/z to keep manifold winding around each shared edge. 
+ NEIGHBOUR_OFFS = torch.tensor([ + [[0, 0, 0], [0, -1, 0], [0, -1, -1], [0, 0, -1]], + [[0, 0, 0], [0, 0, -1], [-1, 0, -1], [-1, 0, 0]], + [[0, 0, 0], [-1, 0, 0], [-1, -1, 0], [0, -1, 0]], + ], dtype=torch.long, device=device) + + # Min-corner +axis edge index per axis (slots 0/4/8 in EDGES) + EDGE_OF_AXIS = torch.tensor([0, 4, 8], dtype=torch.long, device=device) + + # Sorted voxel-coord keys for neighbour lookup + vox_dims = voxel_coords.max(dim=0)[0] + 2 + vox_key = (voxel_coords[:, 0] * vox_dims[1] + voxel_coords[:, 1]) * vox_dims[2] + voxel_coords[:, 2] + sort_v = vox_key.argsort() + sorted_vox_key = vox_key[sort_v] + + tris = [] + for axis in range(3): + edge_idx = EDGE_OF_AXIS[axis] + owner_mask = crosses[:, edge_idx] # (Nv,) bool + if not owner_mask.any(): + continue + owner_voxels = voxel_coords[owner_mask] # (No, 3) + a_sign = a_sd[owner_mask, edge_idx] # (No,) sign at corner a + nbrs = owner_voxels.unsqueeze(1) + NEIGHBOUR_OFFS[axis].unsqueeze(0) # (No, 4, 3) + nbr_keys = (nbrs[..., 0] * vox_dims[1] + nbrs[..., 1]) * vox_dims[2] + nbrs[..., 2] + flat = nbr_keys.reshape(-1).contiguous() + ins = torch.searchsorted(sorted_vox_key, flat) + ins_c = ins.clamp(max=sorted_vox_key.numel() - 1) + valid = (ins < sorted_vox_key.numel()) & (sorted_vox_key[ins_c] == flat) + valid = valid.reshape(-1, 4).all(dim=1) + if not valid.any(): + continue + dual_indices = sort_v[ins_c].reshape(-1, 4)[valid] # (Mv, 4) + sign_a = a_sign[valid] + # Winding: flip when corner a is outside (sign_a > 0) so normal points out + d0 = dual_indices[:, 0] + d1 = dual_indices[:, 1] + d2 = dual_indices[:, 2] + d3 = dual_indices[:, 3] + flip = sign_a > 0 + t1a = torch.stack([d0, d1, d2], dim=1) + t2a = torch.stack([d0, d2, d3], dim=1) + t1b = torch.stack([d0, d2, d1], dim=1) + t2b = torch.stack([d0, d3, d2], dim=1) + t1 = torch.where(flip.unsqueeze(-1), t1b, t1a) + t2 = torch.where(flip.unsqueeze(-1), t2b, t2a) + tris.append(t1) + tris.append(t2) + + if not tris: + return dual_verts, 
torch.empty((0, 3), dtype=torch.long, device=device) + new_faces = torch.cat(tris, dim=0) + return dual_verts, new_faces + + +# Manifold Dual Contouring (Schaefer, Ju, Warren 2007) + +@functools.lru_cache(maxsize=None) +def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]: + """Per 8-corner sign pattern: K (256,) patch count and group (256,12) patch id per edge (-1 if non-crossing).""" + EDGE_PAIRS = [ + (0, 1), (2, 3), (4, 5), (6, 7), # x-axis edges + (0, 2), (1, 3), (4, 6), (5, 7), # y-axis edges + (0, 4), (1, 5), (2, 6), (3, 7), # z-axis edges + ] + K = torch.zeros(256, dtype=torch.int64) + group = torch.full((256, 12), -1, dtype=torch.int64) + + for pat in range(256): + signs = [(pat >> i) & 1 for i in range(8)] # 1=outside, 0=inside + + parent = list(range(8)) + + def find(x: int) -> int: + r = x + while parent[r] != r: + r = parent[r] + while parent[x] != r: + nxt = parent[x] + parent[x] = r + x = nxt + return r + + # Union same-sign corners (not separated by the surface) + for a, b in EDGE_PAIRS: + if signs[a] == signs[b]: + ra, rb = find(a), find(b) + if ra != rb: + parent[ra] = rb + + # Distinct (interior_root, exterior_root) pairs are distinct patches + group_map: dict[tuple[int, int], int] = {} + for ei, (a, b) in enumerate(EDGE_PAIRS): + if signs[a] == signs[b]: + continue + in_c = a if signs[a] == 0 else b + ex_c = b if signs[a] == 0 else a + key = (find(in_c), find(ex_c)) + if key not in group_map: + group_map[key] = len(group_map) + group[pat, ei] = group_map[key] + K[pat] = len(group_map) + + return K, group + + +@functools.lru_cache(maxsize=None) +def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + K, g = _build_mdc_lut() + return K.to(device), g.to(device) + + +def _dual_contour_manifold(voxel_coords: torch.Tensor, corner_udf: torch.Tensor, + corner_keys: torch.Tensor, + resolution: int, scale: float, center: torch.Tensor, + corner_valid: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + 
"""Manifold DC: like _dual_contour but emits 1-4 dual verts per voxel via the patch LUT (centroid placement only).""" + device = voxel_coords.device + Nv = voxel_coords.shape[0] + + CORNER_OFFS = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], + ], dtype=torch.long, device=device) + corner_pos = voxel_coords.unsqueeze(1) + CORNER_OFFS.unsqueeze(0) # (Nv, 8, 3) + R1 = resolution + 1 + keys = (corner_pos[..., 0] * R1 + corner_pos[..., 1]) * R1 + corner_pos[..., 2] + flat_keys = keys.reshape(-1) + idx = torch.searchsorted(corner_keys, flat_keys) + idx_c = idx.clamp(max=corner_keys.numel() - 1) + found = (idx < corner_keys.numel()) & (corner_keys[idx_c] == flat_keys) + sd = torch.where(found, corner_udf[idx_c], + torch.full_like(corner_udf[idx_c], 1.0)).reshape(Nv, 8) + + # Sign pattern: bit i = (sd[i] > 0), matching the LUT convention + sign_bits = (sd > 0).to(torch.int64) # (Nv, 8) + weights = (1 << torch.arange(8, device=device, dtype=torch.int64)) + pat_per_voxel = (sign_bits * weights).sum(dim=-1) # (Nv,) in 0..255 + + K_lut, group_lut = _mdc_lut(device) + K_per_voxel = K_lut[pat_per_voxel] # (Nv,) + total_verts = int(K_per_voxel.sum().item()) + if total_verts == 0: + return (torch.empty((0, 3), dtype=voxel_coords.dtype, device=device), + torch.empty((0, 3), dtype=torch.long, device=device)) + + vert_offset = (torch.cumsum(K_per_voxel, dim=0) - K_per_voxel) # (Nv,) + voxel_per_subvol = torch.repeat_interleave( + torch.arange(Nv, device=device), K_per_voxel) # (total_verts,) + + EDGES = torch.tensor([ + [0, 1], [2, 3], [4, 5], [6, 7], + [0, 2], [1, 3], [4, 6], [5, 7], + [0, 4], [1, 5], [2, 6], [3, 7], + ], dtype=torch.long, device=device) + sb_a = sign_bits[:, EDGES[:, 0]] # (Nv, 12) + sb_b = sign_bits[:, EDGES[:, 1]] + crosses = sb_a != sb_b # (Nv, 12) + + if corner_valid is not None: + cv = torch.where(found, corner_valid[idx_c], + torch.zeros_like(found)).reshape(Nv, 8) + edge_valid = cv[:, EDGES[:, 0]] 
& cv[:, EDGES[:, 1]] + crosses = crosses & edge_valid + + edge_group_per_voxel = group_lut[pat_per_voxel] # (Nv, 12), -1 if not crossing + # LUT already gives -1 for non-crossing edges; only re-mask for corner_valid + if corner_valid is not None: + edge_group_per_voxel = torch.where(crosses, edge_group_per_voxel, + torch.full_like(edge_group_per_voxel, -1)) + + a_sd = sd[:, EDGES[:, 0]] # (Nv, 12) + b_sd = sd[:, EDGES[:, 1]] + denom = a_sd - b_sd + t = torch.where(denom.abs() > 1e-20, a_sd / denom, torch.zeros_like(a_sd)) + t = t.clamp(0.0, 1.0).unsqueeze(-1) + corner_world = (corner_pos.float() / resolution - 0.5) * scale + center.unsqueeze(0).unsqueeze(0) + a_pos = corner_world[:, EDGES[:, 0]] + b_pos = corner_world[:, EDGES[:, 1]] + crossing_pts = torch.lerp(a_pos, b_pos, t) # (Nv, 12, 3) + + # Aggregate crossing positions per (voxel, subvolume) into global dual verts + flat_group = edge_group_per_voxel.reshape(-1) + valid_mask = flat_group >= 0 + flat_voxel = torch.arange(Nv, device=device).unsqueeze(-1).expand(Nv, 12).reshape(-1) + flat_pos = crossing_pts.reshape(-1, 3) + v_idx = flat_voxel[valid_mask] + g_idx = flat_group[valid_mask] + pos = flat_pos[valid_mask] + global_idx = vert_offset[v_idx] + g_idx # (Nvalid,) + + pos_dtype = crossing_pts.dtype + sums = torch.zeros((total_verts, 3), dtype=pos_dtype, device=device) + counts = torch.zeros(total_verts, dtype=pos_dtype, device=device) + sums.scatter_add_(0, global_idx.unsqueeze(-1).expand(-1, 3), pos) + counts.scatter_add_(0, global_idx, torch.ones_like(g_idx, dtype=pos_dtype)) + # Fully-masked subvolumes default to the voxel centre (unreferenced) + voxel_centre = ((voxel_coords.float() + 0.5) / resolution - 0.5) * scale + center.unsqueeze(0) + dual_verts = torch.where( + counts.unsqueeze(-1) > 0, + sums / counts.clamp_min(1.0).unsqueeze(-1), + voxel_centre[voxel_per_subvol].to(pos_dtype), + ) + + # Face emission. 
SHARED_LOCAL_EDGE[axis,k] = the k-th neighbour's local edge + # slot corresponding to the shared grid edge (owner's slot = EDGE_OF_AXIS[axis]). + NEIGHBOUR_OFFS = torch.tensor([ + [[0, 0, 0], [0, -1, 0], [0, -1, -1], [0, 0, -1]], + [[0, 0, 0], [0, 0, -1], [-1, 0, -1], [-1, 0, 0]], + [[0, 0, 0], [-1, 0, 0], [-1, -1, 0], [0, -1, 0]], + ], dtype=torch.long, device=device) + SHARED_LOCAL_EDGE = torch.tensor([ + [0, 1, 3, 2], # x-axis + [4, 6, 7, 5], # y-axis + [8, 9, 11, 10], # z-axis + ], dtype=torch.long, device=device) + EDGE_OF_AXIS = torch.tensor([0, 4, 8], dtype=torch.long, device=device) + + vox_dims = voxel_coords.max(dim=0)[0] + 2 + vox_key = (voxel_coords[:, 0] * vox_dims[1] + voxel_coords[:, 1]) * vox_dims[2] + voxel_coords[:, 2] + sort_v = vox_key.argsort() + sorted_vox_key = vox_key[sort_v] + + tris_out = [] + for axis in range(3): + edge_idx = EDGE_OF_AXIS[axis] + owner_mask = crosses[:, edge_idx] + if not owner_mask.any(): + continue + owner_voxels = voxel_coords[owner_mask] + sign_a_at_owner = sb_a[owner_mask, edge_idx] # (No,) — 0 inside, 1 outside + + nbrs = owner_voxels.unsqueeze(1) + NEIGHBOUR_OFFS[axis].unsqueeze(0) # (No, 4, 3) + nbr_keys = (nbrs[..., 0] * vox_dims[1] + nbrs[..., 1]) * vox_dims[2] + nbrs[..., 2] + flat = nbr_keys.reshape(-1).contiguous() + ins = torch.searchsorted(sorted_vox_key, flat) + ins_c = ins.clamp(max=sorted_vox_key.numel() - 1) + valid_nbr = (ins < sorted_vox_key.numel()) & (sorted_vox_key[ins_c] == flat) + valid_quad = valid_nbr.reshape(-1, 4).all(dim=1) + if not valid_quad.any(): + continue + + nbr_orig = sort_v[ins_c].reshape(-1, 4)[valid_quad] # (Mv, 4) voxel idx + nbr_pat = pat_per_voxel[nbr_orig] # (Mv, 4) + local_e = SHARED_LOCAL_EDGE[axis].unsqueeze(0).expand_as(nbr_pat) + nbr_subvol = group_lut[nbr_pat, local_e] # (Mv, 4) + # Every neighbour must agree the shared edge is crossing + ok = (nbr_subvol >= 0).all(dim=1) + if not ok.any(): + continue + nbr_subvol = nbr_subvol[ok] + nbr_orig = nbr_orig[ok] + 
def _filter_components(verts: torch.Tensor, faces: torch.Tensor,
                       min_fraction: float = 0.01,
                       drop_inverted: bool = True,
                       drop_enclosed: bool = True) -> torch.Tensor:
    """Drop tiny / inverted-volume / enclosed connected components; returns filtered faces.

    Args:
        verts: vertex positions indexed by ``faces``.
        faces: (F, 3) integer triangle indices into ``verts``.
        min_fraction: components with fewer faces than this fraction of the
            largest component's face count are dropped (``<= 0`` disables).
        drop_inverted: drop components whose summed signed volume is negative.
        drop_enclosed: drop components whose bbox lies inside the largest
            component's bbox, or whose centroid passes a +X ray parity test
            against the largest component.

    Returns:
        ``faces`` itself when nothing is dropped, otherwise the kept rows.
        The component with the most faces is never dropped.
    """
    # An empty face list has no components; the reductions below
    # (counts.max / counts.argmax) would raise on an empty tensor.
    if faces.shape[0] == 0:
        return faces

    device = faces.device
    V = verts.shape[0]

    # Connected components via min-label propagation across faces (200-iter max)
    label = torch.arange(V, dtype=torch.long, device=device)
    for _ in range(200):
        f_min = torch.minimum(torch.minimum(label[faces[:, 0]], label[faces[:, 1]]),
                              label[faces[:, 2]])
        new_label = label.clone()
        new_label.scatter_reduce_(0, faces[:, 0], f_min, reduce="amin", include_self=True)
        new_label.scatter_reduce_(0, faces[:, 1], f_min, reduce="amin", include_self=True)
        new_label.scatter_reduce_(0, faces[:, 2], f_min, reduce="amin", include_self=True)
        new_label = new_label[new_label]  # path compression
        if torch.equal(new_label, label):
            break
        label = new_label

    face_label = label[faces[:, 0]]  # (F,)
    unique_labels, inv = torch.unique(face_label, return_inverse=True)
    C = unique_labels.shape[0]
    counts = torch.bincount(inv, minlength=C)
    max_count = int(counts.max().item())
    large = counts.argmax()  # index of the largest component, shared by both filters
    keep = torch.ones(C, dtype=torch.bool, device=device)

    if min_fraction > 0:
        threshold = max(1, int(max_count * min_fraction))
        keep = keep & (counts >= threshold)

    # With a single component the largest is always kept, so the volume pass
    # would be computed and then discarded; skip it entirely.
    if drop_inverted and C > 1:
        # Drop components with negative signed volume, but always keep the largest
        v0 = verts[faces[:, 0]]
        v1 = verts[faces[:, 1]]
        v2 = verts[faces[:, 2]]
        face_vol = (v0 * torch.cross(v1, v2, dim=-1)).sum(dim=-1)  # (F,)
        comp_vol = torch.zeros(C, dtype=face_vol.dtype, device=device)
        comp_vol.scatter_add_(0, inv, face_vol)
        vol_ok = (comp_vol >= 0)
        vol_ok[large] = True
        keep = keep & vol_ok

    if drop_enclosed and C > 1:
        # Pass 1: per-component bbox fully inside the largest component's bbox
        face_v = verts[faces]
        face_min = face_v.min(dim=1).values
        face_max = face_v.max(dim=1).values
        comp_min = torch.full((C, 3), float("inf"), dtype=verts.dtype, device=device)
        comp_max = torch.full((C, 3), float("-inf"), dtype=verts.dtype, device=device)
        comp_min.scatter_reduce_(0, inv[:, None].expand(-1, 3), face_min,
                                 reduce="amin", include_self=True)
        comp_max.scatter_reduce_(0, inv[:, None].expand(-1, 3), face_max,
                                 reduce="amax", include_self=True)
        big_min = comp_min[large]
        big_max = comp_max[large]
        enclosed = ((comp_min >= big_min).all(dim=-1)
                    & (comp_max <= big_max).all(dim=-1))
        enclosed[large] = False

        # Per-component centroid (mean of face centroids) for the raycast test
        face_centroid = face_v.mean(dim=1)  # (F, 3)
        comp_centroid = torch.zeros((C, 3), dtype=verts.dtype, device=device)
        comp_centroid.scatter_add_(0, inv[:, None].expand(-1, 3), face_centroid)
        comp_centroid = comp_centroid / counts.to(verts.dtype).unsqueeze(-1).clamp_min(1.0)

        # Pass 2: +X ray parity test against the largest component, only for
        # components still kept and not already flagged by the bbox test.
        big_faces = faces[inv == large]
        bv0 = verts[big_faces[:, 0]]
        bv1 = verts[big_faces[:, 1]]
        bv2 = verts[big_faces[:, 2]]
        candidates = torch.nonzero((keep & ~enclosed)
                                   & (torch.arange(C, device=device) != large),
                                   as_tuple=True)[0]
        for ci in candidates.tolist():
            origin = comp_centroid[ci]
            # 2D point-in-triangle in YZ for the ray origin's (y, z)
            oy, oz = origin[1], origin[2]
            s12 = (bv1[:, 1] - oy) * (bv2[:, 2] - oz) - (bv1[:, 2] - oz) * (bv2[:, 1] - oy)
            s20 = (bv2[:, 1] - oy) * (bv0[:, 2] - oz) - (bv2[:, 2] - oz) * (bv0[:, 1] - oy)
            s01 = (bv0[:, 1] - oy) * (bv1[:, 2] - oz) - (bv0[:, 2] - oz) * (bv1[:, 1] - oy)
            total = s12 + s20 + s01
            inside_yz = (((s12 >= 0) & (s20 >= 0) & (s01 >= 0))
                         | ((s12 <= 0) & (s20 <= 0) & (s01 <= 0)))
            # triangles that project to a sliver in YZ carry no usable hit
            inside_yz = inside_yz & (total.abs() > 1e-20)
            inv_t = 1.0 / total.where(total.abs() > 1e-20, torch.ones_like(total))
            # barycentric interpolation of x at the ray's (y, z)
            hit_x = (s12 * bv0[:, 0] + s20 * bv1[:, 0] + s01 * bv2[:, 0]) * inv_t
            crossings = int((inside_yz & (hit_x > origin[0])).sum().item())
            if crossings % 2 == 1:
                enclosed[ci] = True
        keep = keep & ~enclosed

    if keep.all():
        return faces
    face_keep = keep[inv]
    return faces[face_keep]
def _fix_poles(verts: torch.Tensor, faces: torch.Tensor,
               colors: Optional[torch.Tensor] = None
               ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Midpoint-collapse edge-sharing valence-3 vertex pairs (DC T-junction poles).

    Boundary vertices (endpoints of edges used by exactly one face) are never
    treated as poles. Each accepted pair is merged into its lower-index vertex,
    placed at the pair's midpoint (colors are averaged the same way); faces
    that become degenerate are removed and unreferenced vertices compacted out.

    Returns:
        ``(verts, faces, colors)``; the inputs are returned unchanged when there
        is nothing to collapse.
    """
    device = verts.device
    V = verts.shape[0]
    if V == 0 or faces.numel() == 0:
        return verts, faces, colors

    sorted_keys, _, _ = _sorted_edge_halfedges(faces, V)
    uniq_keys, key_counts = torch.unique_consecutive(sorted_keys, return_counts=True)
    # unique edge key decodes as a * (V + 1) + b
    P = V + 1
    a = uniq_keys // P
    b = uniq_keys % P
    # Boundary verts (endpoints of single-face edges) are excluded from poles
    boundary_v = torch.zeros(V, dtype=torch.bool, device=device)
    bnd_mask = key_counts == 1
    if bnd_mask.any():
        boundary_v[a[bnd_mask]] = True
        boundary_v[b[bnd_mask]] = True
    ones = torch.ones_like(a)
    valence = torch.zeros(V, dtype=torch.long, device=device)
    valence.scatter_add_(0, a, ones)
    valence.scatter_add_(0, b, ones)
    is_pole = (valence == 3) & ~boundary_v
    if int(is_pole.sum().item()) < 2:
        return verts, faces, colors

    pp_edge = is_pole[a] & is_pole[b]
    if not pp_edge.any():
        return verts, faces, colors

    # Greedy maximal matching: accept candidates whose endpoints are still free.
    # A plain Python set replaces per-element reads/writes on a torch tensor,
    # which cost a dispatcher round-trip for every candidate edge.
    taken = set()
    pairs: list[tuple[int, int]] = []
    for ai, bi in zip(a[pp_edge].tolist(), b[pp_edge].tolist()):
        if ai in taken or bi in taken:
            continue
        pairs.append((ai, bi))
        taken.add(ai)
        taken.add(bi)
    if not pairs:
        return verts, faces, colors

    pairs_t = torch.tensor(pairs, dtype=torch.long, device=device)  # (n_pairs, 2)
    keep_i = torch.minimum(pairs_t[:, 0], pairs_t[:, 1])
    drop_i = torch.maximum(pairs_t[:, 0], pairs_t[:, 1])

    # survivor moves to the midpoint of the collapsed edge
    new_verts = verts.clone()
    new_verts[keep_i] = 0.5 * (verts[pairs_t[:, 0]] + verts[pairs_t[:, 1]])
    new_colors = None
    if colors is not None:
        new_colors = colors.clone()
        new_colors[keep_i] = 0.5 * (colors[pairs_t[:, 0]] + colors[pairs_t[:, 1]])

    # redirect every reference to a dropped vertex onto its survivor
    remap = torch.arange(V, dtype=torch.long, device=device)
    remap[drop_i] = keep_i
    new_faces = remap[faces.long()]
    degen = ((new_faces[:, 0] == new_faces[:, 1])
             | (new_faces[:, 1] == new_faces[:, 2])
             | (new_faces[:, 0] == new_faces[:, 2]))
    new_faces = new_faces[~degen]

    # compact away vertices no surviving face references
    used_mask = torch.zeros(V, dtype=torch.bool, device=device)
    used_mask[new_faces.reshape(-1)] = True
    if not used_mask.all():
        compact = used_mask.long().cumsum(0) - 1
        new_verts = new_verts[used_mask]
        if new_colors is not None:
            new_colors = new_colors[used_mask]
        new_faces = compact[new_faces]
    return new_verts, new_faces.to(faces.dtype), new_colors
re-extraction; returns (new_vertices, new_faces, new_colors), new_colors None unless `colors` given. + + Key params: target_faces>0 auto-derives resolution; sign_mode sdf/udf + (UDF disables qef and may need component filters); project_back lerps verts + toward the closest surface point; scale/center default to bbox. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert faces.ndim == 2 and faces.shape[1] == 3 + device = vertices.device + + if center is None: + center = 0.5 * (vertices.max(dim=0)[0] + vertices.min(dim=0)[0]) + else: + center = center.to(device=device, dtype=vertices.dtype) + if scale is None: + bbox = vertices.max(dim=0)[0] - vertices.min(dim=0)[0] + scale = float(bbox.max().item()) * 1.1 + + # Auto-derive resolution from target_faces (~3 tris/crossing-voxel; +-30%) + if target_faces > 0: + tv = vertices[faces.long()] + cross_v = torch.cross(tv[:, 1] - tv[:, 0], tv[:, 2] - tv[:, 0], dim=-1) + surface_area = 0.5 * cross_v.norm(dim=-1).sum().item() + relative_area = max(surface_area / (scale * scale), 1e-6) + derived = int(math.sqrt(target_faces / (3.0 * relative_area))) + # Round to a multiple of 32 (builder doubles from a <=32 base) + derived = ((derived + 31) // 32) * 32 + derived = max(32, min(1024, derived)) + resolution = derived + + eps = band * scale / resolution + + # progress: one tick per narrow-band level + 3 stages (SDF/DC/post) + each smoothing iter + n_levels, _b = 1, resolution + while _b > 32 and _b % 2 == 0: + _b //= 2 + while _b < resolution: + _b *= 2 + n_levels += 1 + _total_ticks = n_levels + 3 + int(smooth_iters) + _pbar = comfy.utils.ProgressBar(_total_ticks) + try: + _tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False) + except Exception: + _tq = None + + def tick(): + _pbar.update(1) + if _tq is not None: + _tq.update(1) + + # Step 1: sparse narrow-band voxel grid (coarse-to-fine) + voxel_coords, _band_tree = _build_narrow_band_voxels( + vertices, faces, center, scale, resolution, eps, + 
progress_callback=tick) + if voxel_coords.numel() == 0: + return (torch.empty((0, 3), dtype=vertices.dtype, device=device), + torch.empty((0, 3), dtype=faces.dtype, device=device), + None if colors is None else torch.empty((0, colors.shape[1]), + dtype=colors.dtype, device=device)) + + # Step 2: collect unique corner positions of all active voxels + CORNER_OFFS = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1], + ], dtype=torch.long, device=device) + corners = (voxel_coords.unsqueeze(1) + CORNER_OFFS.unsqueeze(0)).reshape(-1, 3) + R1 = resolution + 1 + corner_keys = (corners[:, 0] * R1 + corners[:, 1]) * R1 + corners[:, 2] + unique_corner_keys, corner_inv = torch.unique(corner_keys, return_inverse=True) + unique_corners = torch.zeros((unique_corner_keys.shape[0], 3), dtype=torch.long, device=device) + unique_corners[corner_inv] = corners + + if sign_mode == "sdf": + use_sdf = True + elif sign_mode == "udf": + use_sdf = False + else: + raise ValueError(f"sign_mode must be 'sdf'|'udf', got {sign_mode!r}") + + # Step 3: distance field at every unique corner. 
+ tri_verts_g = vertices[faces.long()] + centroids = tri_verts_g.mean(dim=1) + tri_radii = (tri_verts_g - centroids.unsqueeze(1)).norm(dim=-1).max(dim=-1).values + # face normals: needed for the SDF sign AND for QEF placement (QEF is sign-agnostic, + # so it works in UDF mode too — (n·(x-p))² is unchanged by normal orientation) + if use_sdf or qef: + tri_face_normals_all = torch.nn.functional.normalize( + torch.cross(tri_verts_g[:, 1] - tri_verts_g[:, 0], + tri_verts_g[:, 2] - tri_verts_g[:, 0], dim=-1), + p=2, dim=-1, eps=1e-12) + cell_size = scale / resolution + corner_world = (unique_corners.float() / resolution - 0.5) * scale + center.unsqueeze(0) + # Exact corner UDF (no max_dist cap) so DC crossings keep fine detail + udf, corner_closest, corner_tri = _udf_exact(corner_world, tri_verts_g, tree=_band_tree) + corner_valid = corner_tri >= 0 + if use_sdf: + sign = torch.ones_like(udf) + n_for_corner = tri_face_normals_all[corner_tri.clamp(min=0)] + offset = corner_world - corner_closest + sign_dot = (offset * n_for_corner).sum(-1) + sign = torch.where(corner_valid & (sign_dot < 0), -sign, sign) + sdf = sign * udf + else: + # UDF mode: iso at UDF=eps; double surface on closed meshes, weld after + sdf = udf - eps + tick() # SDF done + + # Short-range hash reused by project_back / colors sampling (max_dist up to 4*cell) + short_hash_cell_t = torch.tensor(2.0 * cell_size, dtype=vertices.dtype, device=device) + short_hash = _build_tri_spatial_hash(centroids, tri_radii, short_hash_cell_t) + + # Step 4 + 5: dual contouring + topology. QEF works in both modes (sign-agnostic); + # in UDF it pulls the ±eps crossing back onto the triangle planes → sharper edges. + if qef: + tri_face_normals = tri_face_normals_all + # QEF needs the nearest triangle per crossing point. 
The centroid cKDTree + # (_band_tree) is already built, and its exact k-NN query is markedly faster + # here than a spatial-hash gather (which builds ~100-triangle candidate lists + # per query on a dense input) — and it's exact. So reuse it directly. + def _qef_query(pts): + return _udf_exact(pts, tri_verts_g, tree=_band_tree) + else: + tri_face_normals = None + _qef_query = None + + if manifold and use_sdf: + # MDC ignores qef / tri_face_normals — centroid placement only. + dual_verts, new_faces = _dual_contour_manifold( + voxel_coords, sdf, unique_corner_keys, + resolution, scale, center, + corner_valid=corner_valid) + else: + dual_verts, new_faces = _dual_contour( + voxel_coords, sdf, unique_corner_keys, + resolution, scale, center, + tri_face_normals=tri_face_normals, qef_query=_qef_query, + # corner_valid filter only matters in SDF mode + corner_valid=corner_valid if use_sdf else None) + tick() # DC done + + # Step 6: project_back and / or color sampling share one closest-point query + need_query = (project_back > 0 or colors is not None) and dual_verts.numel() > 0 + out_colors = None + if need_query: + result = _udf_query( + dual_verts, tri_verts_g, short_hash, short_hash_cell_t, + max_dist=4.0 * cell_size, + return_closest=True, + return_tri_idx=(colors is not None)) + if colors is not None: + _, closest_pts, closest_tri = result + else: + _, closest_pts = result + + if project_back > 0: + dual_verts = torch.lerp(dual_verts, closest_pts, float(project_back)) + + if colors is not None: + # Barycentric-interpolate input colors at the closest point + safe_tri = closest_tri.clamp(min=0) + tri_v_idx = faces[safe_tri].long() # (N, 3) + tri_v = vertices[tri_v_idx] # (N, 3, 3) + v0 = tri_v[:, 0]; v1 = tri_v[:, 1]; v2 = tri_v[:, 2] + e0 = v1 - v0 + e1 = v2 - v0 + e2 = closest_pts - v0 + d00 = (e0 * e0).sum(-1) + d01 = (e0 * e1).sum(-1) + d11 = (e1 * e1).sum(-1) + d20 = (e2 * e0).sum(-1) + d21 = (e2 * e1).sum(-1) + denom = d00 * d11 - d01 * d01 + 1e-20 + bv = ((d11 * 
d20 - d01 * d21) / denom).clamp(0.0, 1.0) + bw = ((d00 * d21 - d01 * d20) / denom).clamp(0.0, 1.0) + bu = (1.0 - bv - bw).clamp(0.0, 1.0) + tri_c = colors[tri_v_idx] # (N, 3, C) + out_colors = (bu.unsqueeze(-1) * tri_c[:, 0] + + bv.unsqueeze(-1) * tri_c[:, 1] + + bw.unsqueeze(-1) * tri_c[:, 2]) + # Zero out failed-query rows (their barycentric used bogus triangle 0) + invalid = closest_tri < 0 + if invalid.any(): + out_colors[invalid] = 0 + + # Filter spurious components (tiny pieces, inverted inner shells) + if (new_faces.numel() > 0 + and (drop_small_components > 0 or drop_inverted_components + or drop_enclosed_components)): + new_faces = _filter_components( + dual_verts, new_faces, + min_fraction=drop_small_components if drop_small_components > 0 else 0.0, + drop_inverted=drop_inverted_components, + drop_enclosed=drop_enclosed_components) + + if fix_poles and new_faces.numel() > 0: + dual_verts, new_faces, out_colors = _fix_poles( + dual_verts, new_faces, out_colors) + tick() # post-process done + + if smooth_iters > 0 and dual_verts.numel() > 0 and new_faces.numel() > 0: + dual_verts = _taubin_smooth(dual_verts, new_faces, + iters=int(smooth_iters), + lam=float(smooth_lambda), + mu=float(smooth_mu), + progress_callback=tick) + + # Drop unused verts (non-crossing voxels' dual verts) and compact faces + if dual_verts.numel() > 0 and new_faces.numel() > 0: + used = torch.zeros(dual_verts.shape[0], dtype=torch.bool, device=device) + used[new_faces[:, 0]] = True + used[new_faces[:, 1]] = True + used[new_faces[:, 2]] = True + remap = used.long().cumsum(0) - 1 + dual_verts = dual_verts[used] + new_faces = remap[new_faces.long()] + if out_colors is not None: + out_colors = out_colors[used] + + return (dual_verts.to(vertices.dtype), + new_faces.to(faces.dtype), + out_colors.to(colors.dtype) if (out_colors is not None and colors is not None) else None) diff --git a/comfy_extras/mesh3d/uv_unwrap/mesh.py b/comfy_extras/mesh3d/uv_unwrap/mesh.py new file mode 100644 index 
def face_normals(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return [F,3] unit normals, one per triangle; zero-area faces yield a zero vector."""
    origin = vertices[faces[:, 0]]
    edge_ab = vertices[faces[:, 1]] - origin
    edge_ac = vertices[faces[:, 2]] - origin
    raw = torch.linalg.cross(edge_ab, edge_ac)
    # clamp keeps the division finite when the cross product vanishes
    length = raw.norm(dim=1, keepdim=True).clamp_min(1e-20)
    return raw / length
# [F, 3] float + face_area: Tensor # [F] float + face_centroid: Tensor # [F, 3] float + component: Tensor # [F] long, connected-component id + n_components: int + + +def build_mesh(vertices: Tensor, faces: Tensor) -> MeshData: + """Build adjacency; non-manifold edges (>2 incident faces) get no neighbor and act as boundary.""" + if vertices.dtype != torch.float32: + vertices = vertices.to(torch.float32) + if faces.dtype != torch.long: + faces = faces.to(torch.long) + + device = faces.device + V = vertices.shape[0] + F = faces.shape[0] + + # Per directed face-edge; flat layout p = f*3+i. + a = faces.flatten() + b = faces.roll(shifts=-1, dims=1).flatten() + lo = torch.minimum(a, b) + hi = torch.maximum(a, b) + edge_key = lo * (V + 1) + hi + + # Pair manifold (count==2) face-edges; others get no neighbor. + _, inverse, counts = torch.unique(edge_key, return_inverse=True, return_counts=True) + edge_count = counts[inverse] + manifold_mask = edge_count == 2 + + sort_idx = torch.argsort(edge_key, stable=True) + sorted_manifold = manifold_mask[sort_idx] + pair_positions = sort_idx[sorted_manifold] + pair_a = pair_positions[0::2] + pair_b = pair_positions[1::2] + + face_id_flat = torch.arange(F, device=device).repeat_interleave(3) + face_face_flat = torch.full((3 * F,), -1, dtype=torch.long, device=device) + face_face_flat[pair_a] = face_id_flat[pair_b] + face_face_flat[pair_b] = face_id_flat[pair_a] + face_face = face_face_flat.view(F, 3) + + face_face_np = face_face.cpu().numpy() + rows_mask = face_face_np >= 0 + if rows_mask.any(): + rows = np.broadcast_to(np.arange(F)[:, None], (F, 3))[rows_mask] + cols = face_face_np[rows_mask] + adj = csr_matrix( + (np.ones(rows.size, dtype=np.int8), (rows, cols)), + shape=(F, F), + ) + else: + adj = csr_matrix((F, F), dtype=np.int8) + n_components, labels = connected_components(adj, directed=False) + + face_normal = face_normals(vertices, faces) + face_area = face_areas(vertices, faces) + face_centroid = face_centroids(vertices, 
def chart_boundary_loops(
    faces_subset: Tensor, face_face_subset: Tensor
) -> List[List[int]]:
    """Return ordered boundary vertex loops for a chart submesh.

    ``face_face_subset[f, i] == -1`` marks the directed edge
    ``faces_subset[f, i] -> faces_subset[f, (i + 1) % 3]`` as a boundary edge.
    Loops are traced by following those directed edges; chains shorter than
    three vertices are discarded. When a vertex starts several boundary edges,
    the last one in face order wins.
    """
    if faces_subset.shape[0] == 0:
        return []
    faces_np = faces_subset.cpu().numpy()
    ff = face_face_subset.cpu().numpy()

    # Gather every boundary half-edge in one masked pass. Boolean indexing
    # walks (face, corner) in row-major order, i.e. the same order as a nested
    # face/corner loop, so dict insertion and overwrite order are unchanged.
    boundary = ff == -1
    tails = faces_np[boundary]
    heads = np.roll(faces_np, -1, axis=1)[boundary]
    next_v: Dict[int, int] = dict(zip(tails.tolist(), heads.tolist()))

    loops: List[List[int]] = []
    visited = set()
    for start in next_v:
        if start in visited:
            continue
        loop = [start]
        visited.add(start)
        cur = next_v.get(start)
        # stop when the chain closes, dead-ends, or runs into an earlier walk
        while cur is not None and cur != start and cur not in visited:
            loop.append(cur)
            visited.add(cur)
            cur = next_v.get(cur)
        if len(loop) >= 3:
            loops.append(loop)
    return loops
def _rotate_xy(uv: np.ndarray, theta: float) -> np.ndarray:
    """Rotate [N,2] points counter-clockwise by ``theta`` radians; ``theta == 0`` returns ``uv`` itself."""
    if theta == 0.0:
        return uv
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    x = uv[:, 0]
    y = uv[:, 1]
    return np.column_stack((x * cos_t - y * sin_t, x * sin_t + y * cos_t))
def _dilate_bitmap(bm: np.ndarray, k: int) -> np.ndarray:
    """Grow the True region of a 2D bitmap by k 4-neighbour steps (Manhattan radius k)."""
    whole = slice(None)
    head = slice(None, -1)
    tail = slice(1, None)
    # (destination window, source window) for the four unit shifts
    shifts = (
        ((tail, whole), (head, whole)),
        ((head, whole), (tail, whole)),
        ((whole, tail), (whole, head)),
        ((whole, head), (whole, tail)),
    )
    grown = bm.copy()
    for _ in range(k):
        previous = grown
        grown = previous.copy()
        for dst, src in shifts:
            grown[dst] |= previous[src]
    return grown
@_njit(cache=True, boundscheck=False)
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
                        x: int, y: int) -> None:
    """Raise skyline[x + col] to y + (highest True row in that chart column) + 1, in place.

    Columns that fall outside the skyline or contain no True pixel are left
    untouched; the skyline is never lowered.
    """
    n_rows = chart.shape[0]
    n_cols = chart.shape[1]
    sky_len = skyline.shape[0]
    for col in range(n_cols):
        target = x + col
        if target < 0 or target >= sky_len:
            continue
        # scan from the last row back to the first set pixel in this column
        row = n_rows - 1
        while row >= 0 and not chart[row, col]:
            row -= 1
        if row < 0:
            continue
        lifted = y + row + 1
        if skyline[target] < lifted:
            skyline[target] = lifted
def _blit(atlas: np.ndarray, chart: np.ndarray, x: int, y: int) -> None:
    """OR ``chart`` into ``atlas`` in place with its top-left corner at column x, row y.

    The write is clipped to the atlas bounds: any part of the chart that
    overhangs the atlas (including negative offsets) is ignored rather than
    raising a broadcast error or wrapping around via negative slice indices.
    A chart that lies entirely inside the atlas is blitted in full.
    """
    ah, aw = atlas.shape
    ch, cw = chart.shape
    # overlap rectangle in atlas coordinates
    y0 = max(y, 0)
    x0 = max(x, 0)
    y1 = min(y + ch, ah)
    x1 = min(x + cw, aw)
    if y1 <= y0 or x1 <= x0:
        return
    atlas[y0:y1, x0:x1] |= chart[y0 - y:y1 - y, x0 - x:x1 - x]
def _dilate_local(x: Tensor, p: int) -> Tensor:
    """Dilate each (g, g) bitmap of a (cnt, g, g) batch by p 4-neighbour steps.

    Every step ORs each image with its four unit shifts along the two trailing
    axes; the batch axis is never mixed. With ``p == 0`` the input tensor is
    returned as is.
    """
    whole = slice(None)
    head = slice(None, -1)
    tail = slice(1, None)
    # (destination window, source window) for the four unit shifts
    shifts = (
        ((whole, tail, whole), (whole, head, whole)),
        ((whole, head, whole), (whole, tail, whole)),
        ((whole, whole, tail), (whole, whole, head)),
        ((whole, whole, head), (whole, whole, tail)),
    )
    current = x
    for _ in range(p):
        grown = current.clone()
        for dst, src in shifts:
            grown[dst] |= current[src]
        current = grown
    return current
padding).clamp_min(0) + bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1 + bbh = (tmax[:, 1].ceil().long() + padding) - y0 + 1 + mxd = torch.maximum(bbw, bbh).clamp_min(1) + bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long() + + a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2] + v0 = b - a; v1 = c - a + d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1) + den = (d00 * d11 - d01 * d01).clamp(min=1e-20) + + for g in sorted(set(bsz.tolist())): # one batch per pow2 grid + sel = (bsz == g).nonzero(as_tuple=True)[0] + m = sel.shape[0] + xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1) + cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1) + gi = torch.arange(g, device=device) + px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int + pxf = px.float() + 0.5; pyf = py.float() + 0.5 + v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1) + d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1) + d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1) + idn = den[sel].view(m, 1, 1).reciprocal() + vv = torch.addcmul(d11[sel].view(m, 1, 1) * d20, d01[sel].view(m, 1, 1), d21, value=-1) * idn + ww = torch.addcmul(d00[sel].view(m, 1, 1) * d21, d01[sel].view(m, 1, 1), d20, value=-1) * idn + uu = 1.0 - vv - ww + inside = (uu >= -1e-6) & (vv >= -1e-6) & (ww >= -1e-6) + if padding > 0: + inside = _dilate_local(inside, padding) + valid = inside & (px < bwp) & (py < bhp) + flat = (cbase[cc].view(m, 1, 1) + py * bwp + px)[valid] + buf[flat] = True + return buf, cbase + + +def _build_candidates_gpu(sky_t, cur_w, cur_h, bw0, bw1, step, rand_n, gen, device): + """Skyline-flush + edge-sweep + random candidate (x,y) positions per orientation, built on the + GPU. Returns (cand0, cand1). 
Random samples find tight pockets the deterministic grid misses.""" + xs = torch.arange(0, max(cur_w, 1) + 1, step, device=device) + ys = torch.arange(0, max(cur_h, 1) + 1, step, device=device) + # edge-sweep candidates are orientation-independent: build once, shared by both orientations + common = [torch.stack([xs, torch.full_like(xs, yf)], 1) for yf in (0, cur_h)] + common += [torch.stack([torch.full_like(ys, xf), ys], 1) for xf in (0, cur_w)] + common = torch.cat(common, 0) + out = [] + for cw in (bw0, bw1): # skyline-flush + random differ + if cw > 0 and sky_t.shape[0] >= cw: + wmax = sky_t.unfold(0, cw, 1).amax(1)[xs.clamp(max=max(sky_t.shape[0] - cw, 0))] + else: + wmax = torch.zeros_like(xs) + parts = [torch.stack([xs, wmax], 1), common] + if rand_n > 0: # distinct draws keep density + rx = torch.randint(0, max(cur_w, 1) + 1, (rand_n,), generator=gen, device=device) + ry = torch.randint(0, max(cur_h, 1) + 1, (rand_n,), generator=gen, device=device) + parts.append(torch.stack([rx, ry], 1)) + out.append(torch.cat(parts, 0)) + return out[0], out[1] + + +def _col_top(b: Tensor) -> Tensor: + """Topmost True row index per column of a bool bitmap (h,w); -1 for empty columns.""" + h = b.shape[0] + rows = torch.arange(h, device=b.device)[:, None] + return torch.where(b, rows, torch.full_like(rows.expand_as(b), -1)).amax(0) + + +def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cur_h, device): + """Lowest-score non-colliding candidate as a (3,) int tensor [x, y, swap] (x=-1 if none). + Collision tests only each bitmap's True-pixel offsets (pix), not the full window. 
def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
                       texels_per_unit, padding_texels):
    """Torch rasterize-and-place packer (numba-free fallback).

    Pipeline: (1) pick a per-chart rotation and texel scale for all charts in one
    batch, (2) rasterize every chart into a bool bitmap and trim it to its occupied
    extent, (3) place charts one by one (largest bitmap area first) into a growing
    bool atlas using skyline / edge / random candidates.

    Args:
        chart_uvs: per-chart UV tensors; indexed as [V, 2].
        chart_3d_areas: per-chart 3D surface areas (floored at 1e-12).
        chart_uv_areas: per-chart UV areas (floored at 1e-12).
        chart_faces: per-chart face index tensors; indexed as [F, 3].
        texels_per_unit: texel density multiplier applied to sqrt(area_3d / area_uv).
        padding_texels: padding passed to the rasterizer and added to each bbox.

    Returns:
        (placements, atlas_w, atlas_h); ``placements[i]`` is the ChartPlacement of
        chart i, and the atlas size is the extent actually covered by placed charts.
        An empty input returns ([], 1, 1).
    """
    n = len(chart_uvs)
    if n == 0:
        return [], 1, 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 36 trial angles covering [0, pi/2): 37 linspace samples with the endpoint dropped.
    ang = torch.linspace(0.0, math.pi / 2.0, 37, device=device)[:-1]
    cos_a, sin_a = ang.cos(), ang.sin()

    # ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
    vcount = [int(u.shape[0]) for u in chart_uvs]
    fcount = [int(f.shape[0]) for f in chart_faces]
    vmax = max(vcount); fmax = max(fcount)
    # Charts are padded to a common vertex/face count; vmask/fmask mark the real entries.
    uvs_pad = torch.zeros(n, vmax, 2, device=device)
    vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
    faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
    fmask = torch.zeros(n, fmax, dtype=torch.bool, device=device)
    for i in range(n):
        uvs_pad[i, :vcount[i]] = chart_uvs[i].to(device=device, dtype=torch.float32)
        vmask[i, :vcount[i]] = True
        if fcount[i]:
            faces_pad[i, :fcount[i]] = chart_faces[i].to(device=device, dtype=torch.long)
            fmask[i, :fcount[i]] = True
    u0, u1 = uvs_pad[..., 0], uvs_pad[..., 1]  # (N,Vmax)
    BIG = 1e30
    # Additive masks: padded slots get +BIG before amin and -BIG before amax,
    # so they can never win either reduction.
    mlo = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), BIG))
    mhi = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), -BIG))
    # Rotate every vertex by every trial angle: x' = u*cos - v*sin, y' = u*sin + v*cos.
    xr = torch.addcmul(u0[:, :, None] * cos_a, u1[:, :, None], sin_a, value=-1)  # (N,Vmax,A)
    yr = torch.addcmul(u0[:, :, None] * sin_a, u1[:, :, None], cos_a)
    xsp = (xr + mhi[:, :, None]).amax(1) - (xr + mlo[:, :, None]).amin(1)  # (N,A) masked span
    ysp = (yr + mhi[:, :, None]).amax(1) - (yr + mlo[:, :, None]).amin(1)
    # The chosen angle minimizes the axis-aligned bounding-box area.
    ti = (xsp * ysp).argmin(1)  # (N,) best angle per chart
    cc, ss = cos_a[ti][:, None], sin_a[ti][:, None]  # (N,1)
    rx = torch.addcmul(u0 * cc, u1, ss, value=-1)  # (N,Vmax)
    ry = torch.addcmul(u0 * ss, u1, cc)
    rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1)  # (N,)
    rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1)
    a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
    au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
    # Nominal scale maps UV area onto 3D area at the requested texel density ...
    base = (a3 / au).sqrt() * texels_per_unit
    # ... capped so the larger bbox side stays <= max(8, 4 * sqrt(area_3d) * texels_per_unit).
    maxb = (4.0 * a3.sqrt() * texels_per_unit).clamp_min(8.0)
    bbm = torch.maximum(rxmax - rxmin, rymax - rymin).clamp_min(1e-12)
    scale = torch.minimum(base, maxb / bbm)  # (N,)
    # Texel-space UVs with each chart's bbox minimum moved to the origin.
    uvs_tex_pad = torch.stack([(rx - rxmin[:, None]) * scale[:, None],
                               (ry - rymin[:, None]) * scale[:, None]], dim=-1)  # (N,Vmax,2)
    bw_t = ((rxmax - rxmin) * scale).ceil().int() + padding_texels + 1
    bh_t = ((rymax - rymin) * scale).ceil().int() + padding_texels + 1

    # one sync: pull all per-chart scalars
    thetas = ang[ti].cpu().tolist()
    scales = scale.cpu().tolist()
    bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist()

    # ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
    buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
    cb = cbase.cpu().tolist()
    raw, bnd = [], []
    for i in range(n):
        # Chart i's bitmap is a [bh, bw] view into the flat raster buffer.
        bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
        raw.append(bm)
        # NOTE: `cc` is rebound here; the cosine tensor of the same name is no longer needed.
        rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device)
        rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax()  # last occupied row / col (-1 if empty)
        cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
        bnd.append(torch.stack([rmax, cmax]))
    bnd_cpu = torch.stack(bnd).cpu().tolist()  # one sync for all trim bounds

    # per-chart True-pixel offsets (sparse collision/blit), dims, col-tops (all kept on GPU)
    pix_l, pixr_l, dim_l, dimr_l, bm_h = [], [], [], [], []
    col_tops, col_tops_rot = [], []
    for i in range(n):
        rm, cm = bnd_cpu[i]
        # An empty raster degrades to a 1x1 all-False bitmap.
        bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
              else torch.zeros((1, 1), dtype=torch.bool, device=device))
        # Transpose followed by a column flip: the swap_xy orientation of the bitmap.
        bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
        pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero())
        dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
        dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
        col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot))
        bm_h.append(int(bm.shape[0]))
    # Pad per-column tops to a common width so they can be row-indexed by chart id.
    wmax = max(d[1] for d in dim_l + dimr_l)
    ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
    ctr_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
    for i in range(n):
        ct_pad[i, :col_tops[i].shape[0]] = col_tops[i]
        ctr_pad[i, :col_tops_rot[i].shape[0]] = col_tops_rot[i]
    del raw

    # ---- Placement: skyline bin-pack on GPU (1 sync/chart for the chosen position) ----
    order = sorted(range(n), key=lambda i: -(dim_l[i][0] * dim_l[i][1]))  # biggest bitmap first
    max_b = max(max(d) for d in dim_l)
    # Every candidate origin is <= (cur_w, cur_h) and every pixel offset is < max_b, so
    # keeping `margin` free cells beyond the used extent keeps all gathers/blits in range.
    margin = max_b + 8
    side_guess = int(math.sqrt(sum(d[0] * d[1] for d in dim_l)) * 2) + 16
    cap = side_guess + margin
    atlas = torch.zeros((cap, cap), dtype=torch.bool, device=device)
    sky_t = torch.zeros(cap, dtype=torch.long, device=device)
    cur_w = cur_h = 0
    placements = [None] * n
    gen = torch.Generator(device=device).manual_seed(0)
    rand_n = 512  # random samples per orientation

    for ci in order:
        # Grow the square atlas (and the skyline with it) before the margin would be violated.
        if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
            ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
            na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
            na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na
            nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk
        dim, dimr = dim_l[ci], dimr_l[ci]
        step = max(1, min(dim[0], dim[1]) // 8)
        cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
        res = _best_placement_torch(atlas, pix_l[ci], dim, pixr_l[ci], dimr, cand0, cand1, cur_w, cur_h, device)
        bx, by, swap = (int(v) for v in res.tolist())  # the one sync/chart
        if bx < 0:
            # No collision-free candidate: fall back to the right of the used extent, unrotated.
            bx, by, swap = cur_w, 0, 0
        pix = pixr_l[ci] if swap else pix_l[ci]
        bh_, bw_ = (dimr if swap else dim)
        atlas[by + pix[:, 0], bx + pix[:, 1]] = True  # sparse blit
        cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_)
        ct = (ctr_pad if swap else ct_pad)[ci, :bw_]  # GPU skyline lift
        ix = torch.arange(bx, bx + bw_, device=device)
        # Raise the skyline only under columns the chart occupies (ct >= 0).
        sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])
        # chart_h is the height of the unrotated trimmed bitmap, even when swap is set.
        placements[ci] = ChartPlacement(chart_id=ci, offset=(float(bx), float(by)),
                                        scale=scales[ci], rotation=thetas[ci], swap_xy=bool(swap),
                                        chart_h=float(bm_h[ci]))
    return placements, cur_w, cur_h
+ nominal_side = math.sqrt(max(area_3d, 1e-12)) * float(texels_per_unit) + max_bbox_texels = max(8.0, 4.0 * nominal_side) + bbox_uv = (rotated.max(axis=0) - rotated.min(axis=0)) + bbox_uv_max = float(max(bbox_uv[0], bbox_uv[1], 1e-12)) + if scale * bbox_uv_max > max_bbox_texels: + scale = max_bbox_texels / bbox_uv_max + uvs_tex = rotated * scale + uvs_tex = uvs_tex - uvs_tex.min(axis=0) + bbox_w = int(math.ceil(uvs_tex[:, 0].max())) + padding_texels + 1 + bbox_h = int(math.ceil(uvs_tex[:, 1].max())) + padding_texels + 1 + + bm = _rasterize_chart(uvs_tex, faces, bbox_w, bbox_h, padding_texels) + nz_rows = np.where(bm.any(axis=1))[0] + nz_cols = np.where(bm.any(axis=0))[0] + if nz_rows.size == 0 or nz_cols.size == 0: + bm = np.zeros((1, 1), dtype=bool) + bbox_h, bbox_w = 1, 1 + else: + bm = bm[: nz_rows[-1] + 1, : nz_cols[-1] + 1] + bbox_h, bbox_w = bm.shape + # True 90 deg rotation; plain transpose would mirror and flip winding. + bm_rot = bm.T[:, ::-1].copy() + + perim = _chart_perimeter(uvs_tex, faces) + prepared.append( + _PreparedChart( + chart_id=i, + uvs_tex=uvs_tex, + bitmap=bm, + bitmap_rot=bm_rot, + bbox_w=bbox_w, + bbox_h=bbox_h, + rotation=theta, + s_tex=scale, + perimeter=perim, + ) + ) + + order = sorted(range(n), key=lambda i: -prepared[i].perimeter) + + total_area = sum(p.bbox_w * p.bbox_h for p in prepared) + side_guess = int(math.sqrt(total_area) * 2) + 16 + atlas = np.zeros((side_guess, side_guess), dtype=bool) + cur_w = 0 + cur_h = 0 + + placements: List[ChartPlacement] = [None] * n # type: ignore + + for ci in order: + p = prepared[ci] + + step = max(1, min(p.bbox_w, p.bbox_h) // 8) + det_arr = _build_candidates_jit( + skyline, cur_w, cur_h, + p.bitmap.shape[1], p.bitmap.shape[0], + p.bitmap_rot.shape[1], p.bitmap_rot.shape[0], + step, + ) + + x_range = max(cur_w + 1, 1) + y_range = max(cur_h + 1, 1) + rand_x = rng.integers(0, x_range, size=attempts).astype(np.int64) + rand_y = rng.integers(0, y_range, size=attempts).astype(np.int64) + rand_swap 
def apply_placements(
    chart_uvs: List[Tensor], placements: List[ChartPlacement], atlas_w: int, atlas_h: int
) -> List[Tensor]:
    """Map each chart's UVs into normalized atlas space.

    Per chart, in order: rotate by ``rotation`` (skipped when it is exactly 0),
    shift the bbox minimum to the origin, multiply by ``scale``, optionally apply
    the 90-degree swap ``(u, v) -> (chart_h - v, u)``, add ``offset``, divide by
    the larger atlas side (one shared divisor keeps texel density uniform), and
    clamp to [0, 1]. Outputs keep each input's device and dtype.
    """
    norm_side = float(max(atlas_w, atlas_h, 1))

    def _to_atlas(uv_t, place):
        pts = uv_t.detach().cpu().numpy().astype(np.float64)
        if place.rotation != 0.0:
            pts = _rotate_xy(pts, place.rotation)
        pts = (pts - pts.min(axis=0)) * place.scale
        if place.swap_xy:
            # Same 90-degree turn as the packer's rotated bitmap (bm.T[:, ::-1]).
            pts = np.stack([place.chart_h - pts[:, 1], pts[:, 0]], axis=1)
        off_x, off_y = place.offset[0], place.offset[1]
        pts[:, 0] += off_x
        pts[:, 1] += off_y
        pts /= norm_side
        # Slivers can poke a sub-texel past the tracked atlas extent; clamp them back in.
        np.clip(pts, 0.0, 1.0, out=pts)
        return torch.from_numpy(pts).to(device=uv_t.device, dtype=uv_t.dtype)

    return [_to_atlas(uv_t, place) for uv_t, place in zip(chart_uvs, placements)]
import mesh as _mesh + + +def solve_least_squares(A: sp.csr_matrix, b: np.ndarray) -> np.ndarray: + """Solve ||Ax - b||^2 by factorizing AtA.""" + At = A.T.tocsr() + AtA = (At @ A).tocsc() + Atb = At @ b + return spla.spsolve(AtA, Atb) + + +def _triangle_local_2d(verts_3d: np.ndarray, faces: np.ndarray) -> np.ndarray: + """Per-triangle 2D coords [F, 3, 2] with v0 at origin, v1 along +x.""" + v0 = verts_3d[faces[:, 0]] + v1 = verts_3d[faces[:, 1]] + v2 = verts_3d[faces[:, 2]] + e01 = v1 - v0 + e02 = v2 - v0 + L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20) + x_axis = e01 / L01[:, None] + n = np.cross(e01, e02) + n /= np.linalg.norm(n, axis=1, keepdims=True).clip(min=1e-20) + y_axis = np.cross(n, x_axis) + + out = np.zeros((faces.shape[0], 3, 2), dtype=np.float64) + out[:, 1, 0] = L01 + out[:, 2, 0] = (e02 * x_axis).sum(axis=1) + out[:, 2, 1] = (e02 * y_axis).sum(axis=1) + return out + + +def _pick_pins(loops: List[List[int]], verts_3d: np.ndarray) -> Tuple[int, int]: + """Pick the longest-diameter axis-extremal boundary vertex pair across all boundary verts.""" + if not loops: + # Closed surface: two far verts via two-pass farthest. 
+ d2 = np.sum((verts_3d - verts_3d[0]) ** 2, axis=1) + a = int(np.argmax(d2)) + d2 = np.sum((verts_3d - verts_3d[a]) ** 2, axis=1) + b = int(np.argmax(d2)) + return a, b + boundary_verts: List[int] = [] + for loop in loops: + boundary_verts.extend(loop) + seen = set() + uniq = [] + for v in boundary_verts: + if v not in seen: + seen.add(v) + uniq.append(v) + bv = np.asarray(uniq, dtype=np.int64) + pts = verts_3d[bv] + pin_pairs = [] + for axis in range(3): + i_min = int(bv[int(np.argmin(pts[:, axis]))]) + i_max = int(bv[int(np.argmax(pts[:, axis]))]) + d = float(np.linalg.norm(verts_3d[i_min] - verts_3d[i_max])) + pin_pairs.append((d, i_min, i_max)) + d0, _, _ = pin_pairs[0] + d1, _, _ = pin_pairs[1] + d2, _, _ = pin_pairs[2] + if d0 > d1 and d0 > d2: + _, a, b = pin_pairs[0] + elif d1 > d2: + _, a, b = pin_pairs[1] + else: + _, a, b = pin_pairs[2] + return a, b + + +def _ortho_project(verts_3d: np.ndarray) -> np.ndarray: + """PCA-fit plane normal, axis-aligned tangent, project verts to 2D.""" + centroid = verts_3d.mean(axis=0) + pts = verts_3d - centroid + cov = pts.T @ pts + _w, ev = np.linalg.eigh(cov) + normal = ev[:, 0] + a = np.abs(normal) + if a[0] < a[1] and a[0] < a[2]: + t = np.array([1.0, 0.0, 0.0]) + elif a[1] < a[2]: + t = np.array([0.0, 1.0, 0.0]) + else: + t = np.array([0.0, 0.0, 1.0]) + t = t - normal * float(np.dot(normal, t)) + t /= max(float(np.linalg.norm(t)), 1e-20) + b = np.cross(normal, t) + return np.stack([verts_3d @ t, verts_3d @ b], axis=1) + + +def _stretch_metrics(verts_3d: np.ndarray, uvs: np.ndarray, faces: np.ndarray) -> Tuple[float, float, int, int]: + """Sander's stretch metric. 
Returns (rms, max, n_flipped, n_zero_area).""" + p = verts_3d[faces] + t = uvs[faces] + parametric_area = 0.5 * ( + (t[:, 1, 1] - t[:, 0, 1]) * (t[:, 2, 0] - t[:, 0, 0]) + - (t[:, 2, 1] - t[:, 0, 1]) * (t[:, 1, 0] - t[:, 0, 0]) + ) + n_flipped = int((parametric_area < -1e-12).sum()) + n_zero = int((np.abs(parametric_area) < 1e-12).sum()) + pa = np.abs(parametric_area).clip(min=1e-20) + geom_area = 0.5 * np.linalg.norm( + np.cross(p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]), axis=1 + ) + keep = (geom_area > 1e-12) & (np.abs(parametric_area) > 1e-12) + if not keep.any(): + return float("inf"), float("inf"), n_flipped, n_zero + t1 = t[:, 0, 0]; s1 = t[:, 0, 1] + t2 = t[:, 1, 0]; s2 = t[:, 1, 1] + t3 = t[:, 2, 0]; s3 = t[:, 2, 1] + inv_2pa = 1.0 / (2.0 * pa) + Ss = ( + p[:, 0] * (t2 - t3)[:, None] + + p[:, 1] * (t3 - t1)[:, None] + + p[:, 2] * (t1 - t2)[:, None] + ) * inv_2pa[:, None] + St = ( + p[:, 0] * (s3 - s2)[:, None] + + p[:, 1] * (s1 - s3)[:, None] + + p[:, 2] * (s2 - s1)[:, None] + ) * inv_2pa[:, None] + a = (Ss * Ss).sum(axis=1) + bb = (Ss * St).sum(axis=1) + c = (St * St).sum(axis=1) + sigma2_sq = 0.5 * (a + c + np.sqrt(np.maximum(0.0, (a - c) ** 2 + 4 * bb ** 2))) + rms_sq = (a + c) * 0.5 + rms_stretch_sq_sum = float((rms_sq[keep] * geom_area[keep]).sum()) + total_geom = float(geom_area[keep].sum()) + total_param = float(pa[keep].sum()) + if total_geom <= 0.0: + return float("inf"), float("inf"), n_flipped, n_zero + norm_factor = np.sqrt(total_param / total_geom) + rms_stretch = float(np.sqrt(rms_stretch_sq_sum / total_geom)) * norm_factor + max_stretch = float(np.sqrt(sigma2_sq[keep].max())) * norm_factor + return rms_stretch, max_stretch, n_flipped, n_zero + + +def _uv_boundary_self_intersects( + uvs: np.ndarray, faces: np.ndarray, face_face: np.ndarray, eps: float = 1e-9 +) -> bool: + """True if any chart-boundary edge pair crosses in 2D (ortho folded the chart).""" + fi, ei = np.nonzero(face_face < 0) + n = fi.size + if n < 2: + return False + a = 
uvs[faces[fi, ei]].astype(np.float64) + b = uvs[faces[fi, (ei + 1) % 3]].astype(np.float64) + d = b - a + # Pairwise segment crossings, row-chunked to bound memory at chunk*n. + chunk = max(1, min(n, 1_000_000 // max(n, 1))) + for s in range(0, n, chunk): + e = min(s + chunk, n) + d1 = d[s:e, None, :] + denom = d1[:, :, 0] * d[None, :, 1] - d1[:, :, 1] * d[None, :, 0] + rx = a[None, :, 0] - a[s:e, None, 0] + ry = a[None, :, 1] - a[s:e, None, 1] + with np.errstate(divide="ignore", invalid="ignore"): + t = (rx * d[None, :, 1] - ry * d[None, :, 0]) / denom + u = (rx * d1[:, :, 1] - ry * d1[:, :, 0]) / denom + cross = ( + (np.abs(denom) >= eps) + & (t > eps) & (t < 1.0 - eps) + & (u > eps) & (u < 1.0 - eps) + ) + if bool(cross.any()): + return True + return False + + +def parametrize_chart( + local_verts: Tensor, local_faces: Tensor, local_face_face: Tensor +) -> Tensor: + """Parameterize one chart: ortho first, ABF/LSCM fallback; charts <=5 faces stay ortho.""" + verts_np = local_verts.detach().cpu().numpy().astype(np.float64) + faces_np = local_faces.detach().cpu().numpy().astype(np.int64) + if verts_np.shape[0] < 3 or faces_np.shape[0] == 0: + return torch.zeros((verts_np.shape[0], 2), dtype=torch.float32, device=local_verts.device) + + ortho = _ortho_project(verts_np) + n_faces = faces_np.shape[0] + if n_faces <= 5: + return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device) + rms, mx, n_flip, n_zero = _stretch_metrics(verts_np, ortho, faces_np) + flip_ok = n_flip == 0 or n_flip == n_faces + if flip_ok and n_zero == 0 and rms <= 1.5 and mx <= 2.0: + ff_np = local_face_face.detach().cpu().numpy().astype(np.int64) + if not _uv_boundary_self_intersects(ortho, faces_np, ff_np): + return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device) + uvs_t = lscm_chart(local_verts, local_faces, local_face_face, pin_positions=ortho) + # Collapsed UV island (aspect > 100:1) blows up packing scale; fall back to ortho. 
+ uvs_np = uvs_t.detach().cpu().numpy() + bbox = uvs_np.max(axis=0) - uvs_np.min(axis=0) + bbox_max = float(max(bbox[0], bbox[1], 1e-12)) + bbox_min = float(max(min(bbox[0], bbox[1]), 1e-12)) + if bbox_max / bbox_min > 100.0: + return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device) + return uvs_t + + +def _abf_face_coefficients( + verts_3d: np.ndarray, faces: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Per-face ABF constraint (largest-sine vertex at local index 2); returns (faces_reordered, cosine, sine, valid_mask) with valid_mask False for degenerate tris.""" + Fc = faces.shape[0] + p0 = verts_3d[faces[:, 0]] + p1 = verts_3d[faces[:, 1]] + p2 = verts_3d[faces[:, 2]] + e01 = p1 - p0 + e12 = p2 - p1 + e20 = p0 - p2 + L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20) + L12 = np.linalg.norm(e12, axis=1).clip(min=1e-20) + L20 = np.linalg.norm(e20, axis=1).clip(min=1e-20) + cos_a0 = ((-e20) * e01).sum(axis=1) / (L20 * L01) + cos_a1 = ((-e01) * e12).sum(axis=1) / (L01 * L12) + cos_a2 = ((-e12) * e20).sum(axis=1) / (L12 * L20) + cos_a0 = cos_a0.clip(-1.0, 1.0) + cos_a1 = cos_a1.clip(-1.0, 1.0) + cos_a2 = cos_a2.clip(-1.0, 1.0) + a = np.arccos(cos_a0) + b_ang = np.arccos(cos_a1) + c_ang = np.arccos(cos_a2) + angles = np.stack([a, b_ang, c_ang], axis=1) + sines = np.stack([np.sin(a), np.sin(b_ang), np.sin(c_ang)], axis=1) + valid = (angles > 1e-12).all(axis=1) + ids = faces.astype(np.int64).copy() + + s0, s1, s2 = sines[:, 0], sines[:, 1], sines[:, 2] + pattA = (s1 > s0) & (s1 > s2) + pattB = (~pattA) & (s0 > s1) & (s0 > s2) + + if pattA.any(): + old_a = angles[pattA].copy() + old_s = sines[pattA].copy() + old_id = ids[pattA].copy() + angles[pattA] = old_a[:, [2, 0, 1]] + sines[pattA] = old_s[:, [2, 0, 1]] + ids[pattA] = old_id[:, [2, 0, 1]] + if pattB.any(): + old_a = angles[pattB].copy() + old_s = sines[pattB].copy() + old_id = ids[pattB].copy() + angles[pattB] = old_a[:, [1, 2, 0]] + sines[pattB] = old_s[:, [1, 2, 0]] + 
ids[pattB] = old_id[:, [1, 2, 0]] + + a0 = angles[:, 0] + s0 = sines[:, 0] + s1 = sines[:, 1] + s2 = sines[:, 2] + c0 = np.cos(a0) + ratio = np.where(s2 > 0.0, s1 / s2.clip(min=1e-20), 1.0) + cosine = c0 * ratio + sine = s0 * ratio + return ids, cosine, sine, valid + + +def lscm_chart( + local_verts: Tensor, + local_faces: Tensor, + local_face_face: Tensor, + pin_positions: "np.ndarray | None" = None, +) -> Tensor: + """ABF parameterization on one chart (degenerate faces use plain LSCM rows; two pins fix gauge at pin_positions).""" + verts_np = local_verts.detach().cpu().numpy().astype(np.float64) + faces_np = local_faces.detach().cpu().numpy().astype(np.int64) + Vc = verts_np.shape[0] + Fc = faces_np.shape[0] + + if Vc < 3 or Fc == 0: + return torch.zeros((Vc, 2), dtype=torch.float32, device=local_verts.device) + + loops = _mesh.chart_boundary_loops(local_faces, local_face_face) + pin_a, pin_b = _pick_pins(loops, verts_np) + + if pin_positions is not None and pin_positions.shape == (Vc, 2): + pa = pin_positions[pin_a] + pb = pin_positions[pin_b] + u_a, v_a = float(pa[0]), float(pa[1]) + u_b, v_b = float(pb[0]), float(pb[1]) + else: + u_a, v_a = 0.0, 0.0 + u_b, v_b = 1.0, 0.0 + + abf_ids, abf_cos, abf_sin, abf_valid = _abf_face_coefficients(verts_np, faces_np) + + rows_list: List[np.ndarray] = [] + cols_list: List[np.ndarray] = [] + vals_list: List[np.ndarray] = [] + + # ABF rows for valid faces. 
+ valid_idx = np.nonzero(abf_valid)[0] + if valid_idx.size: + Nv = valid_idx.size + id0 = abf_ids[valid_idx, 0] + id1 = abf_ids[valid_idx, 1] + id2 = abf_ids[valid_idx, 2] + cosf = abf_cos[valid_idx] + sinf = abf_sin[valid_idx] + r_real = valid_idx * 2 + r_imag = valid_idx * 2 + 1 + ones = np.ones(Nv, dtype=np.float64) + rows_list.extend([r_real] * 5) + cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2]) + vals_list.extend([cosf - 1.0, -sinf, -cosf, sinf, ones]) + rows_list.extend([r_imag] * 5) + cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2 + Vc]) + vals_list.extend([sinf, cosf - 1.0, -sinf, -cosf, ones]) + + # Plain-LSCM rows for invalid (degenerate) faces. + invalid_idx = np.nonzero(~abf_valid)[0] + if invalid_idx.size: + tri2d_inv = _triangle_local_2d(verts_np, faces_np[invalid_idx]) + twice_area_inv = ( + tri2d_inv[:, 1, 0] * tri2d_inv[:, 2, 1] + - tri2d_inv[:, 1, 1] * tri2d_inv[:, 2, 0] + ) + weight_inv = 1.0 / np.sqrt(2.0 * np.abs(twice_area_inv).clip(min=1e-20)) + r_real_inv = invalid_idx * 2 + r_imag_inv = invalid_idx * 2 + 1 + for j in range(3): + jp1 = (j + 1) % 3 + jp2 = (j + 2) % 3 + a_j = (tri2d_inv[:, jp1, 0] - tri2d_inv[:, jp2, 0]) * weight_inv + b_j = (tri2d_inv[:, jp1, 1] - tri2d_inv[:, jp2, 1]) * weight_inv + v_idx = faces_np[invalid_idx, j] + rows_list.extend([r_real_inv, r_real_inv, r_imag_inv, r_imag_inv]) + cols_list.extend([v_idx, v_idx + Vc, v_idx, v_idx + Vc]) + vals_list.extend([a_j, -b_j, b_j, a_j]) + + rows = np.concatenate(rows_list) if rows_list else np.empty(0, dtype=np.int64) + cols = np.concatenate(cols_list) if cols_list else np.empty(0, dtype=np.int64) + vals = np.concatenate(vals_list) if vals_list else np.empty(0, dtype=np.float64) + + A_full = sp.csr_matrix((vals, (rows, cols)), shape=(2 * Fc, 2 * Vc)) + + pin_cols = np.array([pin_a, pin_b, pin_a + Vc, pin_b + Vc], dtype=np.int64) + pin_vals = np.array([u_a, u_b, v_a, v_b], dtype=np.float64) + + free_mask = np.ones(2 * Vc, dtype=bool) + free_mask[pin_cols] = False + 
free_cols = np.nonzero(free_mask)[0] + + A_pinned = A_full[:, pin_cols] + A_free = A_full[:, free_cols] + b = -(A_pinned @ pin_vals) + + # Singular system (under-constrained chart) falls back to ortho. + fallback_to_ortho = False + try: + with warnings.catch_warnings(): + warnings.simplefilter("error", category=sp.linalg.MatrixRankWarning) + x_free = solve_least_squares(A_free, b) + if not np.all(np.isfinite(x_free)): + fallback_to_ortho = True + except Exception: + fallback_to_ortho = True + + if fallback_to_ortho: + if pin_positions is not None and pin_positions.shape == (Vc, 2): + uvs = pin_positions.astype(np.float32) + else: + uvs = _ortho_project(verts_np).astype(np.float32) + return torch.from_numpy(uvs).to(local_verts.device) + + full = np.zeros(2 * Vc, dtype=np.float64) + full[free_cols] = x_free + full[pin_cols] = pin_vals + uvs = np.stack([full[:Vc], full[Vc:]], axis=1).astype(np.float32) + if not np.all(np.isfinite(uvs)): + if pin_positions is not None and pin_positions.shape == (Vc, 2): + uvs = pin_positions.astype(np.float32) + else: + uvs = _ortho_project(verts_np).astype(np.float32) + + return torch.from_numpy(uvs).to(local_verts.device) diff --git a/comfy_extras/mesh3d/uv_unwrap/segment.py b/comfy_extras/mesh3d/uv_unwrap/segment.py new file mode 100644 index 000000000..48dc82ab9 --- /dev/null +++ b/comfy_extras/mesh3d/uv_unwrap/segment.py @@ -0,0 +1,638 @@ +"""Adaptive cost-grow chart segmentation (CPU); numba optional, numpy path is nd-only.""" +from __future__ import annotations + +from typing import List, Tuple + +import numpy as np +import torch +from torch import Tensor + +try: + from numba import njit + _HAVE_NUMBA = True +except ImportError: + _HAVE_NUMBA = False + def njit(*args, **kwargs): # noqa: ARG001 + def deco(fn): + return fn + return deco if not args else args[0] + + +from .mesh import MeshData, face_edge_lengths + + +DEFAULT_W_NORMAL_DEVIATION = 2.0 +DEFAULT_W_ROUNDNESS = 0.01 +DEFAULT_W_STRAIGHTNESS = 6.0 +DEFAULT_MAX_COST = 2.0 
+NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75° + + +@njit(cache=True, fastmath=False) +def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray: + F = face_normal.shape[0] + raw = np.zeros(F, dtype=np.float32) + for f in range(F): + nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2] + s = np.float32(0.0) + for e in range(3): + nb = face_face[f, e] + if nb < 0: + continue + mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2] + d = nx*mx + ny*my + nz*mz + s += np.float32(1.0) - d + raw[f] = s + return raw + + +def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray: + nb_safe = np.maximum(face_face, 0) + nb_normal = face_normal[nb_safe] + d = (face_normal[:, None, :] * nb_normal).sum(axis=-1) + contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0)) + return contrib.sum(axis=1).astype(np.float32) + + +@njit(cache=True, fastmath=False) +def _farthest_point_seeds_jit( + face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray, + initial_seeds: np.ndarray, k_target: int, +): + F = face_centroid.shape[0] + INF = np.float32(1e30) + min_dist = np.full(F, INF, dtype=np.float32) + seeds = np.empty(k_target, dtype=np.int64) + n_seeds = 0 + for i in range(initial_seeds.shape[0]): + s = initial_seeds[i] + if s < 0 or n_seeds >= k_target: + continue + seeds[n_seeds] = s + n_seeds += 1 + sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2] + for f in range(F): + dx = face_centroid[f, 0] - sx + dy = face_centroid[f, 1] - sy + dz = face_centroid[f, 2] - sz + d2 = dx*dx + dy*dy + dz*dz + if d2 < min_dist[f]: + min_dist[f] = d2 + while n_seeds < k_target: + best_f = -1 + best_score = np.float32(-1.0) + for f in range(F): + d = min_dist[f] + if d >= INF * np.float32(0.5): + continue + score = d * face_weight[f] + if score > best_score: + best_score = score + best_f = f + if best_f < 0: + break + seeds[n_seeds] = best_f + n_seeds 
def _farthest_point_seeds_numpy(
    face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
):
    """Pick up to ``k_target`` seed faces by greedy farthest-point sampling.

    Non-negative entries of ``initial_seeds`` are accepted first (in order,
    until ``k_target`` is reached). Afterwards the face whose squared centroid
    distance to the nearest chosen seed is largest is added repeatedly.
    Sampling stops early when that largest distance is non-finite (no seed
    yet) or not positive (every remaining face coincides with a seed).

    Returns the chosen face indices as an int64 array.
    """
    n_faces = face_centroid.shape[0]
    nearest_sq = np.full(n_faces, np.inf, dtype=np.float32)
    picked = []

    def _sq_dist_to(face_idx):
        # Squared Euclidean distance from every centroid to the given face.
        return ((face_centroid - face_centroid[face_idx]) ** 2).sum(axis=-1)

    for candidate in initial_seeds:
        if candidate >= 0 and len(picked) < k_target:
            picked.append(int(candidate))
            nearest_sq = np.minimum(nearest_sq, _sq_dist_to(candidate))

    while len(picked) < k_target:
        farthest = int(np.argmax(nearest_sq))
        gap = nearest_sq[farthest]
        if gap <= 0 or not np.isfinite(gap):
            break
        picked.append(farthest)
        nearest_sq = np.minimum(nearest_sq, _sq_dist_to(farthest))

    return np.asarray(picked, dtype=np.int64)
def _renumber(face_chart: np.ndarray, device) -> Tensor:
    """Compact the non-negative chart ids in ``face_chart`` to ``0..K-1``.

    Ids keep their relative (sorted) order; negative entries are left as they
    are. When no face carries a non-negative id the input array is wrapped
    unchanged. The result is returned as a tensor on ``device``.
    """
    assigned = face_chart >= 0
    # The inverse index of each id within the sorted unique ids is exactly its
    # compacted chart number.
    used_ids, dense_ids = np.unique(face_chart[assigned], return_inverse=True)
    if used_ids.size == 0:
        return torch.from_numpy(face_chart).to(device)
    compact = face_chart.copy()
    compact[assigned] = dense_ids
    return torch.from_numpy(compact).to(device)
int(target_chart_count) + else: + target_seeds = n_comp + max(int(target_chart_count) // 4, 8) + target_seeds = min(target_seeds, F) + if _HAVE_NUMBA: + seeds_arr = _farthest_point_seeds_jit( + face_centroid, face_area, face_weight, initial_seeds, target_seeds, + ) + else: + seeds_arr = _farthest_point_seeds_numpy( + face_centroid, initial_seeds, target_seeds, + ) + seed_faces = [int(s) for s in seeds_arr.tolist()] + + K = len(seed_faces) + chart_basis = np.zeros((K, 3), dtype=np.float32) + chart_normal_sum = np.zeros((K, 3), dtype=np.float32) + chart_area = np.zeros(K, dtype=np.float32) + chart_perim = np.zeros(K, dtype=np.float32) + face_edge_len = ( + face_edge_lengths(mesh.vertices, mesh.faces) + .detach().cpu().numpy() + ) + for cid, sf in enumerate(seed_faces): + face_chart[sf] = cid + n = face_normal[sf] + a = face_area[sf] + chart_basis[cid] = n.astype(np.float32) + chart_normal_sum[cid] = (n * a).astype(np.float32) + chart_area[cid] = float(a) + chart_perim[cid] = float(face_edge_len[sf].sum()) + + if K == 0: + return _renumber(face_chart, device) + + min_dist_to_seed = np.full(F, np.inf, dtype=np.float32) + if adaptive_seeding: + for sf in seed_faces: + d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1) + min_dist_to_seed = np.minimum(min_dist_to_seed, d) + + if _HAVE_NUMBA: + # Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg. 
+ tau_final = min(max_cost * 0.25, 0.5) + thresholds = [t for t in (0.05, 0.1, 0.25) if t < tau_final] + [tau_final] + max_inner = max(64, int(np.sqrt(F)) * 2) + max_total_charts = max(F, 8000) + outer_iter = 0 + while True: + outer_iter += 1 + if outer_iter > F + 16: + break + for tau in thresholds: + for _ in range(max_inner): + n_added = _cost_grow_iter_jit( + face_chart, face_face, face_normal, face_area, face_edge_len, + chart_basis, chart_normal_sum, chart_area, chart_perim, + nd_cutoff, np.float32(tau), + np.float32(w_normal_deviation), + np.float32(w_roundness), + np.float32(w_straightness), + ) + if n_added == 0: + break + if (face_chart == -1).sum() == 0: + break + if not adaptive_seeding: + break + if chart_basis.shape[0] >= max_total_charts: + break + unassigned_mask = face_chart == -1 + cand = np.where(unassigned_mask, min_dist_to_seed, np.float32(-np.inf)) + new_seed = int(np.argmax(cand)) + n = face_normal[new_seed] + a = face_area[new_seed] + chart_basis = np.vstack([chart_basis, n[None, :].astype(np.float32)]) + chart_normal_sum = np.vstack( + [chart_normal_sum, (n * a)[None, :].astype(np.float32)] + ) + chart_area = np.concatenate([chart_area, np.array([a], dtype=np.float32)]) + chart_perim = np.concatenate( + [chart_perim, np.array([face_edge_len[new_seed].sum()], dtype=np.float32)] + ) + face_chart[new_seed] = chart_basis.shape[0] - 1 + new_d = ((face_centroid - face_centroid[new_seed]) ** 2).sum(axis=-1) + min_dist_to_seed = np.minimum(min_dist_to_seed, new_d) + else: + # Numpy fallback: nd-only adaptive grow. 
+ for _ in range(max(64, int(np.sqrt(F)) + 32)): + unassigned = face_chart == -1 + if not unassigned.any(): + break + u_idx = np.nonzero(unassigned)[0] + nbs = face_face[u_idx] + nbs_safe = np.where(nbs >= 0, nbs, 0) + nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1) + valid = (nb_charts >= 0) + if not valid.any(): + break + nb_charts_safe = np.where(valid, nb_charts, 0) + nb_basis = chart_basis[nb_charts_safe] + d = (face_normal[u_idx][:, None, :] * nb_basis).sum(axis=-1) + nd = np.where(valid, np.float32(1.0) - d, np.inf).clip(max=1.0) + nd = np.where(nd >= nd_cutoff, np.inf, nd) + best_e = np.argmin(nd, axis=1) + best_cost = nd[np.arange(u_idx.size), best_e] + best_c = nb_charts_safe[np.arange(u_idx.size), best_e] + accept = (best_cost <= nd_threshold) & np.isfinite(best_cost) + if not accept.any(): + break + pick_u = u_idx[accept] + pick_c = best_c[accept] + face_chart[pick_u] = pick_c + for f, c in zip(pick_u, pick_c): + chart_normal_sum[c] += face_normal[f] * face_area[f] + chart_area[c] += face_area[f] + + # Orphan cleanup: leftover faces join their best-matching neighbor's chart. 
+ if (face_chart == -1).any() and chart_basis.shape[0] > 0: + while True: + orphans = np.nonzero(face_chart == -1)[0] + if orphans.size == 0: + break + nbs = face_face[orphans] + nbs_safe = np.where(nbs >= 0, nbs, 0) + nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1) + valid = (nb_charts >= 0) + if not valid.any(): + break + nb_charts_safe = np.where(valid, nb_charts, 0) + nb_basis = chart_basis[nb_charts_safe] + d = (face_normal[orphans][:, None, :] * nb_basis).sum(axis=-1) + nd = np.where(valid, np.float32(1.0) - d, np.inf) + best_e = np.argmin(nd, axis=1) + best_c = nb_charts_safe[np.arange(orphans.size), best_e] + assignable = valid.any(axis=1) + if not assignable.any(): + break + assign_idx = orphans[assignable] + assign_c = best_c[assignable] + face_chart[assign_idx] = assign_c + if (face_chart == -1).any(): + new_singletons = np.nonzero(face_chart == -1)[0] + for f in new_singletons: + face_chart[int(f)] = chart_basis.shape[0] + chart_basis = np.concatenate( + [chart_basis, face_normal[int(f)].astype(np.float32)[None, :]], + axis=0, + ) + + return _renumber(face_chart, device) + + +def segment_charts( + mesh: MeshData, + max_cost: float = DEFAULT_MAX_COST, + w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION, + w_roundness: float = DEFAULT_W_ROUNDNESS, + w_straightness: float = DEFAULT_W_STRAIGHTNESS, + target_chart_count: int = 0, +) -> Tensor: + """Segment mesh into charts. 
def _combine_normal_cones(
    axis_a: Tensor, half_a: Tensor,
    axis_b: Tensor, half_b: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Merge two normal cones; returns (combined_axis, combined_half_angle, axis_angle).

    Both cones are laid out as angular intervals on the great circle that
    starts at ``axis_a`` and runs towards ``axis_b``: cone A covers
    ``[-half_a, half_a]`` and cone B covers ``[angle - half_b, angle + half_b]``.
    The merged cone is the hull of the two intervals; its axis is ``axis_a``
    rotated towards ``axis_b`` by the hull's midpoint.
    """
    def _unit(vec: Tensor) -> Tensor:
        # Normalise along the last dim, guarding against zero-length vectors.
        return vec / vec.norm(dim=-1, keepdim=True).clamp_min(1e-12)

    cos_between = torch.clamp(torch.sum(axis_a * axis_b, dim=-1), -1.0, 1.0)
    axis_angle = torch.acos(cos_between)

    hull_lo = torch.minimum(-half_a, axis_angle - half_b)
    hull_hi = torch.maximum(half_a, axis_angle + half_b)
    merged_half = (hull_hi - hull_lo) * 0.5
    midpoint = (hull_hi + hull_lo) * 0.5

    # Direction perpendicular to axis_a inside the plane spanned by both axes.
    toward_b = _unit(axis_b - axis_a * cos_between.unsqueeze(-1))
    merged_axis = _unit(
        axis_a * torch.cos(midpoint).unsqueeze(-1)
        + toward_b * torch.sin(midpoint).unsqueeze(-1)
    )
    return merged_axis, merged_half, axis_angle
def cluster_charts_pec(
    mesh: MeshData,
    target_chart_count: int = 0,
    max_cost: float = 0.7,
    area_penalty_weight: float = 0.0,
    roundness_weight: float = 0.0,
    max_iters: int = 1024,
) -> Tensor:
    """Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg).

    Every face starts as its own chart. Each iteration builds the chart
    adjacency edges, scores each edge by the half-angle of the merged normal
    cone (plus optional area / roundness penalties), and collapses every edge
    that is the cheapest edge of BOTH of its charts and costs at most
    ``max_cost``. Iteration stops when no chart edges remain, no edge
    qualifies, or ``max_iters`` is reached.

    Args:
        mesh: Mesh providing ``faces``, ``vertices``, ``face_normal``,
            ``face_area`` and ``face_face``.
        target_chart_count: Not used by this implementation; kept so the
            signature stays compatible with callers that pass it.
        max_cost: Per-merge cost cutoff.
        area_penalty_weight: If > 0, adds ``weight * merged_area`` to the cost.
        roundness_weight: If > 0, adds ``weight * perimeter**2 / area`` of the
            merged chart to the cost.
        max_iters: Upper bound on collapse iterations.

    Returns:
        Long tensor of length F with chart ids compacted to ``0..K-1``.
    """
    device = mesh.faces.device
    F = mesh.faces.shape[0]
    faces = mesh.faces.to(torch.long)
    vertices = mesh.vertices.to(torch.float32)
    face_normal = mesh.face_normal.to(torch.float32)
    face_area = mesh.face_area.to(torch.float32)
    face_face = mesh.face_face.to(torch.long)

    face_edge_len = face_edge_lengths(vertices, faces)

    chart_id = torch.arange(F, dtype=torch.long, device=device)
    chart_axis = face_normal.clone()
    chart_half = torch.zeros(F, dtype=torch.float32, device=device)
    chart_area = face_area.clone()
    chart_perim = face_edge_len.sum(dim=1).clone()

    # "No edge seen" marker for the per-chart minimum. It must compare greater
    # than every packed key: keys reach (2e9 << 32) ~ 8.6e18, which exceeds
    # 2**62, so the int64 maximum is used.
    no_edge_key = torch.iinfo(torch.int64).max

    for _ in range(max_iters):
        edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len)
        if edges.shape[0] == 0:
            break

        a = edges[:, 0]
        b = edges[:, 1]
        axis_a = chart_axis[a]
        axis_b = chart_axis[b]
        half_a = chart_half[a]
        half_b = chart_half[b]
        _, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
        cost = new_half.clone()
        if area_penalty_weight > 0.0:
            new_area = chart_area[a] + chart_area[b]
            cost = cost + area_penalty_weight * new_area
        if roundness_weight > 0.0:
            new_area_r = chart_area[a] + chart_area[b]
            # The shared boundary disappears from both perimeters on merge.
            new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
            cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)

        # Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
        E = edges.shape[0]
        N = int(chart_id.max().item()) + 1
        edge_ids = torch.arange(E, dtype=torch.long, device=device)
        cost_i32 = torch.clamp(cost * 1e6, max=2e9).to(torch.int64)
        key = (cost_i32 << 32) | edge_ids
        chart_min = torch.full((N,), no_edge_key, dtype=torch.long, device=device)
        chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
        chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)

        # Mutual-min collapse: each chart in at most one merge per iter.
        is_a_min = chart_min[a] == key
        is_b_min = chart_min[b] == key
        within = cost <= max_cost
        winners = is_a_min & is_b_min & within

        n_merge = int(winners.sum().item())
        if n_merge == 0:
            break

        win_a = a[winners]
        win_b = b[winners]
        win_el = edge_len[winners]

        axis_a_w = chart_axis[win_a]
        half_a_w = chart_half[win_a]
        axis_b_w = chart_axis[win_b]
        half_b_w = chart_half[win_b]
        new_axis, new_half_w, _ = _combine_normal_cones(
            axis_a_w, half_a_w, axis_b_w, half_b_w,
        )
        # Chart b is folded into chart a; a's slot carries the merged stats.
        chart_axis[win_a] = new_axis
        chart_half[win_a] = new_half_w
        chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
        chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el

        remap = torch.arange(N, dtype=torch.long, device=device)
        remap[win_b] = win_a
        chart_id = remap[chart_id]

    # Surviving ids are sparse; compact them to consecutive integers.
    _, inverse = torch.unique(chart_id, sorted=True, return_inverse=True)
    return inverse
class RemeshMesh(IO.ComfyNode):
    """Node that re-extracts a mesh with narrow-band Dual Contouring."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs/outputs, including the mode-dependent sub-widgets."""
        # sign_mode picks the scalar field, and exposes only the knobs relevant to it
        # (DynamicCombo: udf sub-widgets show for 'udf', sdf sub-widgets for 'sdf').
        sign_mode_options = [
            IO.DynamicCombo.Option(key="udf", inputs=[
                IO.Boolean.Input("qef", default=False,
                                 tooltip="Experimental: place dual vertices via QEF (closest-triangle normals) "
                                         "instead of edge-crossing centroid. QEF is sign-agnostic so it works "
                                         "in UDF too — pulls the ±eps surface back onto the planes for sharper "
                                         "edges. May misbehave near the UDF double shell; compare with it off."),
                IO.Boolean.Input("drop_inverted_components", default=True,
                                 tooltip="Drop closed components with inward normals (negative signed volume) — "
                                         "the inner shell UDF produces on closed regions."),
                IO.Boolean.Input("drop_enclosed_components", default=True,
                                 tooltip="Drop components whose bbox is inside the largest's AND fail a raycast "
                                         "point-in-mesh test. Disable if you have legitimate parts inside others."),
            ]),
            IO.DynamicCombo.Option(key="sdf", inputs=[
                IO.Boolean.Input("qef", default=True,
                                 tooltip="Place dual vertices via QEF solve from closest-triangle normals "
                                         "(recovers sharp features) vs edge-crossing centroid."),
                IO.Boolean.Input("manifold", default=False,
                                 tooltip="Manifold Dual Contouring: emit 1-4 dual verts per voxel for "
                                         "multi-sheet (thin/touching) cases. Slower; guarantees manifold output."),
            ]),
        ]
        return IO.Schema(
            node_id="RemeshMesh",
            display_name="Remesh Mesh (Narrow-Band DC)",
            category="latent/3d",
            description=(
                "Re-extracts a uniformly tessellated mesh by sampling a distance field on a "
                "narrow-band voxel grid and contouring it with Dual Contouring, on the active "
                "compute device. Normalizes topology of messy / non-manifold / self-intersecting "
                "input; run before DecimateMesh to hit an exact face count. Output stays welded."
            ),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Int.Input("target_faces", default=0, min=0, max=50_000_000,
                             tooltip="0 = use 'resolution'. >0 = auto-pick resolution to roughly hit this "
                                     "face count (±30-50%); usually overshoot then DecimateMesh to exact."),
                IO.Int.Input("resolution", default=256, min=32, max=1024,
                             tooltip="Voxel grid resolution (used when target_faces=0). Higher = more detail, "
                                     "slower. 256 ~ 100k faces, 512 ~ 1M."),
                IO.DynamicCombo.Input("sign_mode", options=sign_mode_options, display_name="sign_mode",
                                      tooltip="udf: robust to messy/non-manifold input (double shell cleaned by "
                                              "the inner-shell filters). sdf: clean single surface with optional "
                                              "QEF sharp-feature recovery, but needs consistent winding."),
                IO.Float.Input("band", default=1.0, min=0.5, max=4.0, step=0.1,
                               tooltip="Narrow-band width in voxel units (which voxels are sampled). In UDF "
                                       "mode also offsets the surface by this many voxels."),
                IO.Float.Input("project_back", default=0.0, min=0.0, max=1.0, step=0.05,
                               tooltip="Lerp output verts toward the closest point on the original surface "
                                       "(0 = pure DC, 1 = snapped). Recovers voxelization-lost detail."),
                IO.Boolean.Input("fix_poles", default=False,
                                 tooltip="Collapse valence-3 vertex pairs (DC T-junction artifact). Cheap; "
                                         "improves shading and downstream simplification."),
                IO.Int.Input("smooth_iters", default=0, min=0, max=20,
                             tooltip="Taubin λ|μ smoothing iterations (0 = off). Volume-preserving; cleans DC "
                                     "stairstepping. 2-3 is enough; higher rounds off QEF sharp features."),
                IO.Float.Input("drop_small_components", default=0.01, min=0.0, max=0.5, step=0.005,
                               tooltip="Drop components with fewer than this fraction of the largest component's "
                                       "faces (inner-shell fragments, noise). 0 disables."),
                IO.Int.Input("precluster_max_verts", default=0, min=0, max=50_000_000,
                             tooltip="Safety fallback: if input has more verts than this (>0), cluster-decimate "
                                     "it down first so the distance-field queries don't OOM on huge inputs. "
                                     "0 = off; 1-2M is reasonable for very large meshes."),
            ],
            outputs=[IO.Mesh.Output("mesh")],
            hidden=[IO.Hidden.unique_id],
        )

    @classmethod
    def execute(cls, mesh, target_faces, resolution, sign_mode, band,
                project_back, fix_poles, smooth_iters,
                drop_small_components, precluster_max_verts):
        """Remesh every item of ``mesh`` and report the face-count change.

        A failure while remeshing an item is logged and that item is passed
        through unchanged (deliberate best-effort behaviour).
        """
        mode = sign_mode.get("sign_mode", "udf")
        # mode-specific sub-widgets (absent ones fall back to defaults).
        # qef's schema default differs per mode (udf: False, sdf: True), so the
        # fallback has to follow the selected mode.
        qef = bool(sign_mode.get("qef", mode == "sdf"))
        manifold = bool(sign_mode.get("manifold", False))
        drop_inverted_components = bool(sign_mode.get("drop_inverted_components", True))
        drop_enclosed_components = bool(sign_mode.get("drop_enclosed_components", True))

        # ComfyUI passes meshes on CPU; remesh is far faster on GPU. Run on the
        # selected compute device and return on the mesh's original device.
        compute_device = comfy.model_management.get_torch_device()
        counts = {"in": 0, "out": 0}

        def _fn(v, f, c):
            counts["in"] += int(f.shape[0])
            try:
                src_device = v.device
                vv = v.to(compute_device).float()
                ff = f.to(compute_device).to(torch.int64)
                cc = c.to(compute_device).float() if c is not None else None

                # safety fallback: cluster-decimate very large inputs before the field queries
                if precluster_max_verts > 0 and vv.shape[0] > precluster_max_verts:
                    vv, ff, cc = qem_cluster_decimate(
                        vv, ff, target_verts=int(precluster_max_verts), colors=cc)

                # Fixed [-0.5,0.5] cube domain (matches cumesh / TRELLIS2). scale ≈ 1.0
                # for any resolution, so this is consistent in target_faces auto mode too.
                rs_scale = (resolution + 3.0 * band) / resolution
                rs_center = torch.zeros(3, dtype=vv.dtype, device=compute_device)

                rv, rf, rc = remesh_narrow_band_dc(
                    vv, ff,
                    resolution=int(resolution), target_faces=int(target_faces),
                    band=float(band), project_back=float(project_back),
                    qef=qef, sign_mode=mode,
                    manifold=manifold, fix_poles=bool(fix_poles),
                    smooth_iters=int(smooth_iters),
                    drop_small_components=float(drop_small_components),
                    drop_inverted_components=drop_inverted_components,
                    drop_enclosed_components=drop_enclosed_components,
                    scale=rs_scale, center=rs_center, colors=cc)

                v = rv.to(src_device)
                f = rf.to(src_device)
                c = rc.to(src_device) if rc is not None else None
            except Exception as e:
                logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}")
            counts["out"] += int(f.shape[0])
            return v, f, c

        result = _process_mesh_batch(mesh, _fn)

        # Send progress text to display the face change on the node
        if cls.hidden.unique_id:
            PromptServer.instance.send_progress_text(
                f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)

        return result
def _uv_weld_vertices(v, f, weld_distance):
    """Merge coincident verts; returns (welded_v, welded_f, welded_to_orig) (last None if no welding).

    Vertices are snapped to a grid of cell size ``tol`` and all vertices that
    land on the same grid point are replaced by their average.

    Args:
        v: Vertex positions tensor, indexed as (N, 3).
        f: Face index tensor referring to rows of ``v``.
        weld_distance: Grid cell size used as the weld tolerance. Values <= 0
            select an automatic tolerance of 1e-5 times the bounding-box
            diagonal.

    Returns:
        ``(v, f, None)`` unchanged when there are no vertices, the tolerance is
        not positive, or nothing was merged. Otherwise the welded vertices
        (dtype/device of ``v``), the remapped faces (dtype/device of ``f``)
        and an int64 NumPy array giving, per welded vertex, the index of one
        original vertex that was merged into it.
    """
    # detach() so tensors that require grad can still be converted to NumPy.
    v_np = v.detach().cpu().numpy()
    f_np = f.cpu().numpy()
    if v_np.size == 0:
        return v, f, None
    extent = float(np.linalg.norm(v_np.max(axis=0) - v_np.min(axis=0)))
    tol = weld_distance if weld_distance > 0.0 else 1e-5 * extent
    if tol <= 0.0:
        return v, f, None
    keys = np.round(v_np / tol).astype(np.int64)
    _, inv = np.unique(keys, axis=0, return_inverse=True)
    # Some NumPy 2.0 releases return the inverse with an extra axis when
    # `axis` is given; flatten so it is always one index per input vertex.
    inv = np.asarray(inv).reshape(-1)
    n_unique = int(inv.max()) + 1
    if n_unique >= v_np.shape[0]:
        return v, f, None
    v_welded = np.zeros((n_unique, 3), dtype=np.float32)
    counts = np.zeros(n_unique, dtype=np.int64)
    np.add.at(v_welded, inv, v_np)
    np.add.at(counts, inv, 1)
    v_welded /= counts[:, None]
    # With duplicate targets the last assignment wins, so each welded vertex
    # maps back to the highest-indexed original vertex merged into it.
    welded_to_orig = np.empty(n_unique, dtype=np.int64)
    welded_to_orig[inv] = np.arange(v_np.shape[0], dtype=np.int64)
    v_new = torch.from_numpy(v_welded).to(v.dtype).to(v.device)
    f_new = torch.from_numpy(inv[f_np]).to(f.dtype).to(f.device)
    return v_new, f_new, welded_to_orig
valid: pec, adaptive") + + n_charts = int(face_chart.max().item()) + 1 if face_chart.numel() else 0 + areas_cpu = _uv_mesh.chart_3d_areas(mesh.face_area, face_chart, n_charts).detach().cpu() + + # per-chart loop runs on CPU/numpy to avoid per-chart GPU sync + face_chart_np = face_chart.cpu().numpy() + faces_np = mesh.faces.cpu().numpy() + vertices_np = mesh.vertices.cpu().numpy() + face_face_np = mesh.face_face.cpu().numpy() + sorted_face_idx_np = np.argsort(face_chart_np, kind="stable") + chart_counts_np = np.bincount(face_chart_np, minlength=n_charts) + chart_offsets_np = np.empty(n_charts + 1, dtype=np.int64) + chart_offsets_np[0] = 0 + np.cumsum(chart_counts_np, out=chart_offsets_np[1:]) + + all_chart_uvs, all_chart_3d_areas, all_chart_uv_areas, all_chart_faces = [], [], [], [] + chart_records = [] + for c in range(n_charts): + gfi_np = sorted_face_idx_np[chart_offsets_np[c]:chart_offsets_np[c + 1]] + chart_faces_global = faces_np[gfi_np] + used_verts_np = np.unique(chart_faces_global) + local_faces_np = np.searchsorted(used_verts_np, chart_faces_global) + local_verts_np = vertices_np[used_verts_np] + ff_global = face_face_np[gfi_np] + ff_safe = np.maximum(ff_global, 0) + nb_chart = np.where(ff_global >= 0, face_chart_np[ff_safe], -1) + keep = (ff_global >= 0) & (nb_chart == c) + local_neighbor = np.searchsorted(gfi_np, ff_safe) + local_ff_np = np.where(keep, local_neighbor, -1) + + lf = torch.from_numpy(local_faces_np) + uvs = _uv_param.parametrize_chart( + torch.from_numpy(local_verts_np), lf, torch.from_numpy(local_ff_np)) + ua, ub, uc = uvs[lf[:, 0]], uvs[lf[:, 1]], uvs[lf[:, 2]] + uv_area_sum = float(0.5 * ( + (ub[:, 0] - ua[:, 0]) * (uc[:, 1] - ua[:, 1]) + - (uc[:, 0] - ua[:, 0]) * (ub[:, 1] - ua[:, 1])).abs().sum().item()) + chart_records.append({"local_faces": lf, "vmap": torch.from_numpy(used_verts_np), + "global_face_idx": torch.from_numpy(gfi_np)}) + all_chart_uvs.append(uvs) + all_chart_3d_areas.append(float(areas_cpu[c].item())) + 
class UnwrapMesh(IO.ComfyNode):
    """Node that generates a UV atlas for a mesh, or for every mesh of a batch."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, user-facing inputs, mesh output and hidden fields."""
        about = (
            "Generates a UV atlas (pure-torch, no xatlas dependency): segments the surface into "
            "charts, parameterizes each, and packs them into a [0,1] atlas. Verts on chart seams "
            "are duplicated. Run after DecimateMesh/RemeshMesh, before texture baking."
        )
        segmenter_tip = ("pec: fast parallel-edge-collapse charting (CUDA; falls back to "
                         "adaptive on CPU). adaptive: CPU charting, slower.")
        resolution_tip = "Target atlas resolution used to auto-scale texel density (0 = fit-to-content)."
        padding_tip = "Texel padding between charts in the packed atlas."
        weld_tip = ("Merge radius for coincident verts as a fraction of mesh extent "
                    "(0 = auto, 1e-5). Raise to ~0.001 if you get per-triangle charts "
                    "(unwelded / triangle-soup input).")
        node_inputs = [
            IO.Mesh.Input("mesh"),
            IO.Combo.Input("segmenter", options=["pec", "adaptive"], default="pec",
                           tooltip=segmenter_tip),
            IO.Int.Input("resolution", default=1024, min=0, max=8192, step=256,
                         tooltip=resolution_tip),
            IO.Int.Input("padding", default=1, min=0, max=16, tooltip=padding_tip),
            IO.Float.Input("weld_distance", default=0.0, min=0.0, max=1.0, step=0.0001,
                           tooltip=weld_tip),
        ]
        return IO.Schema(
            node_id="UnwrapMesh",
            display_name="Unwrap Mesh UVs",
            category="latent/3d",
            description=about,
            inputs=node_inputs,
            outputs=[IO.Mesh.Output("mesh")],
            hidden=[IO.Hidden.unique_id],
        )

    @classmethod
    def execute(cls, mesh, segmenter, resolution, padding, weld_distance):
        """Unwrap each mesh and return a new mesh carrying seam-split verts, faces and UVs.

        `mesh.vertices` may be a list of per-mesh tensors, a 3-D stacked tensor, or a
        single 2-D tensor. Vertex colors (when present) are re-indexed through the
        unwrap's vertex mapping, and an existing `mesh.texture` is carried over.
        """
        torch_device = comfy.model_management.get_torch_device()
        # "pec" is only honoured on CUDA; on any other device it is swapped for "adaptive".
        if segmenter == "pec" and torch_device.type != "cuda":
            chart_method = "adaptive"
        else:
            chart_method = segmenter
        work_device = torch_device if chart_method == "pec" else torch.device("cpu")

        all_verts = mesh.vertices
        per_mesh_list = isinstance(all_verts, list)
        stacked = (not per_mesh_list) and all_verts.ndim == 3
        if per_mesh_list:
            n_meshes = len(all_verts)
        elif stacked:
            n_meshes = all_verts.shape[0]
        else:
            n_meshes = 1
        progress = comfy.utils.ProgressBar(n_meshes)
        all_colors = getattr(mesh, "vertex_colors", None)

        verts_out, faces_out, uvs_out, colors_out = [], [], [], []
        for k in range(n_meshes):
            if per_mesh_list or stacked:
                verts_k, faces_k = mesh.vertices[k], mesh.faces[k]
                colors_k = all_colors
                # Colors may be stored per mesh (list / 3-D) or shared by the whole batch.
                if all_colors is not None and (isinstance(all_colors, list) or all_colors.ndim == 3):
                    colors_k = all_colors[k]
            else:
                verts_k, faces_k, colors_k = mesh.vertices, mesh.faces, all_colors

            home_device = verts_k.device
            verts_np = verts_k.detach().cpu().numpy().astype(np.float32)
            if verts_np.shape[0]:
                diagonal = float(np.linalg.norm(verts_np.max(0) - verts_np.min(0)))
            else:
                diagonal = 0.0
            # weld_distance is a fraction of the bbox diagonal; 0 is forwarded as 0.
            weld_abs = weld_distance * diagonal if weld_distance > 0.0 else 0.0

            vmapping, indices, uvs = _uv_unwrap(
                verts_k.to(work_device).float(), faces_k.to(work_device).long(),
                chart_method, int(resolution), int(padding), weld_abs)
            flipped = uvs.copy()
            flipped[:, 1] = 1.0 - flipped[:, 1]  # UV y flipped vs trimesh

            verts_out.append(torch.from_numpy(verts_np[vmapping]).to(home_device))
            faces_out.append(torch.from_numpy(indices).to(device=home_device, dtype=torch.long))
            uvs_out.append(torch.from_numpy(flipped.astype(np.float32)).to(home_device))
            if colors_k is not None:
                colors_np = colors_k.detach().cpu().numpy()
                colors_out.append(
                    torch.from_numpy(np.ascontiguousarray(colors_np[vmapping])).to(home_device))
            progress.update(1)

        result = _pack_uv_meshes(verts_out, faces_out, uvs_out, colors_out or None)
        source_texture = getattr(mesh, "texture", None)
        if source_texture is not None:
            result.texture = source_texture

        node_id = cls.hidden.unique_id
        if node_id:
            # Only the first mesh of the batch is summarised in the progress text.
            PromptServer.instance.send_progress_text(
                f"UV: {verts_out[0].shape[0]}v / {faces_out[0].shape[0]}f, atlas ~{resolution}px",
                node_id)
        return IO.NodeOutput(result)
def _uv_sorted_edge_keys(indices: np.ndarray):
    """Build a key-sorted table of the undirected edges of a triangle index array.

    Every face contributes its three edges (v0-v1, v1-v2, v2-v0). Each edge is
    reduced to the order-independent key ``lo * V + hi`` (``V`` = max index + 1)
    and the table is stably sorted by key, so repeats of one edge are adjacent.

    Args:
        indices: (F, 3) integer array of per-face vertex indices. May be empty.

    Returns:
        ``(sorted_keys, face_id, lo, hi, first_mask)``, all of length ``3 * F``:
        the sorted keys, the face each entry came from, the smaller and larger
        endpoint of each entry, and a bool mask that is True on the first entry
        of every run of equal keys.
    """
    a = indices.ravel().astype(np.int64)
    b = np.roll(indices, -1, axis=1).ravel().astype(np.int64)
    lo = np.minimum(a, b)
    hi = np.maximum(a, b)
    # max() of an empty array raises; any positive V works when there are no edges.
    V = int(indices.max()) + 1 if indices.size else 1
    key = lo * V + hi
    order = np.argsort(key, kind="stable")
    sk = key[order]
    fid = (np.arange(a.size, dtype=np.int64) // 3)[order]
    first = np.ones(sk.size, dtype=bool)
    first[1:] = sk[1:] != sk[:-1]
    return sk, fid, lo[order], hi[order], first


def _uv_faces_to_chart_ids(indices: np.ndarray) -> np.ndarray:
    """Label each face with the id of its connected component (its chart).

    Two faces are linked iff they share an undirected edge, i.e. the same pair
    of vertex indices. Faces that only touch at a single index are NOT linked.

    Args:
        indices: (F, 3) integer array of per-face vertex indices.

    Returns:
        int64 array of length F with one component label per face.
    """
    F = indices.shape[0]
    if F == 0:
        return np.empty(0, dtype=np.int64)
    _sk, fid, _lo, _hi, first = _uv_sorted_edge_keys(indices)
    # Link every repeated occurrence of an edge to the face holding its first occurrence.
    group_id = np.cumsum(first) - 1
    starts = np.nonzero(first)[0]
    rows = fid[starts[group_id[~first]]]
    cols = fid[~first]
    if rows.size == 0:
        # No edge is shared by two face-edges: every face is its own chart.
        return np.arange(F, dtype=np.int64)
    adj = csr_matrix((np.ones(rows.size, dtype=np.int8), (rows, cols)), shape=(F, F))
    _, labels = connected_components(adj, directed=False)
    return labels.astype(np.int64)
# 20-entry categorical RGB table (components in [0, 1]) used to color charts.
_UV_TAB20 = np.array([
    [0.121568627, 0.466666667, 0.705882353], [0.682352941, 0.780392157, 0.909803922],
    [1.000000000, 0.498039216, 0.054901961], [1.000000000, 0.733333333, 0.470588235],
    [0.172549020, 0.627450980, 0.172549020], [0.596078431, 0.874509804, 0.541176471],
    [0.839215686, 0.152941176, 0.156862745], [1.000000000, 0.596078431, 0.588235294],
    [0.580392157, 0.403921569, 0.741176471], [0.772549020, 0.690196078, 0.835294118],
    [0.549019608, 0.337254902, 0.294117647], [0.768627451, 0.611764706, 0.580392157],
    [0.890196078, 0.466666667, 0.760784314], [0.968627451, 0.713725490, 0.823529412],
    [0.498039216, 0.498039216, 0.498039216], [0.780392157, 0.780392157, 0.780392157],
    [0.737254902, 0.741176471, 0.133333333], [0.858823529, 0.858823529, 0.552941176],
    [0.090196078, 0.745098039, 0.811764706], [0.619607843, 0.854901961, 0.898039216],
], dtype=np.float32)


def _uv_palette(n: int) -> np.ndarray:
    """Return ``n`` chart colors drawn from ``_UV_TAB20`` in a fixed shuffled order.

    A permutation of ``range(n)`` seeded with 42 is reduced modulo 20 to pick
    table rows, so the result is deterministic and colors repeat once ``n``
    exceeds 20.

    Args:
        n: number of colors requested (>= 0).

    Returns:
        (n, 3) float32 array of RGB rows.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("negative dimensions are not allowed")
    # Legacy RandomState keeps the color assignment identical across runs.
    rng = np.random.RandomState(42)
    perm = rng.permutation(max(1, n))
    # One fancy-index gather replaces the per-row Python loop; perm has at least n entries.
    return _UV_TAB20[perm[:n] % 20]
def _uv_render_atlas(uvs_np, indices_np, resolution, device,
                     bg=(0.13, 0.13, 0.13), edge=(0.0, 0.0, 0.0)):
    """Tile-based torch rasterizer of the UV atlas (charts colored, borders outlined).

    Args:
        uvs_np: numpy array of UV coordinates; columns 0 and 1 are clamped to [0, 1]
            before being scaled to pixels. (Presumably shape (V, 2) - confirm with caller.)
        indices_np: numpy array of triangle indices into ``uvs_np``, three per face.
        resolution: side length in pixels of the square output image.
        device: torch device on which the image is built.
        bg: RGB background color.
        edge: RGB color used for chart outlines.

    Returns:
        float32 tensor of shape (H, W, 3) on ``device`` with H == W == resolution.
    """
    w = h = int(resolution)
    chart_ids_np = _uv_faces_to_chart_ids(indices_np)
    uvs = torch.from_numpy(uvs_np).to(device=device, dtype=torch.float32)
    indices = torch.from_numpy(indices_np).to(device=device, dtype=torch.long)
    chart_ids = torch.from_numpy(chart_ids_np).to(device=device, dtype=torch.long)

    # Start from a solid background; an empty mesh returns just this.
    img = torch.zeros((h, w, 3), dtype=torch.float32, device=device)
    img[..., 0] = bg[0]; img[..., 1] = bg[1]; img[..., 2] = bg[2]
    if indices.numel() == 0:
        return img

    n_charts = int(chart_ids.max().item()) + 1 if chart_ids.numel() else 1
    colors = torch.from_numpy(_uv_palette(n_charts)).to(device=device, dtype=torch.float32)

    # UV -> pixel coordinates; out-of-range UVs are pinned to the image border.
    uv_px = uvs.clone()
    uv_px[:, 0] = uv_px[:, 0].clamp(0.0, 1.0) * (w - 1)
    uv_px[:, 1] = uv_px[:, 1].clamp(0.0, 1.0) * (h - 1)

    tri = uv_px[indices]
    x0 = tri[:, 0, 0]; y0 = tri[:, 0, 1]
    x1 = tri[:, 1, 0]; y1 = tri[:, 1, 1]
    x2 = tri[:, 2, 0]; y2 = tri[:, 2, 1]
    # Barycentric denominator (twice the signed area); near-zero means a degenerate
    # triangle, which is excluded so the divisions below stay finite.
    denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2)
    nondegen = denom.abs() > 1e-20

    # Integer pixel bounding box of every triangle, clamped to the image.
    xmin = torch.minimum(torch.minimum(x0, x1), x2).floor().clamp_(0, w - 1).long()
    xmax = torch.maximum(torch.maximum(x0, x1), x2).ceil().clamp_(0, w - 1).long()
    ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, h - 1).long()
    ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, h - 1).long()

    # full point-in-triangle over all (pixel, tri) pairs is O(H*W*F); tile and test only bbox-overlapping tris
    TILE = 64
    eps = 1e-6  # tolerance so pixels exactly on a triangle edge still count as inside
    for ty in range(0, h, TILE):
        ty_end = min(ty + TILE, h)
        for tx in range(0, w, TILE):
            tx_end = min(tx + TILE, w)
            # Triangles whose bounding box intersects this tile.
            tri_mask = (nondegen & (xmin < tx_end) & (xmax >= tx)
                        & (ymin < ty_end) & (ymax >= ty))
            if not tri_mask.any():
                continue
            idx = torch.nonzero(tri_mask, as_tuple=True)[0]
            # Sample at pixel centers (+0.5).
            ys = torch.arange(ty, ty_end, dtype=torch.float32, device=device) + 0.5
            xs = torch.arange(tx, tx_end, dtype=torch.float32, device=device) + 0.5
            yy, xx = torch.meshgrid(ys, xs, indexing="ij")
            # Reshape per-triangle scalars to (K, 1, 1) so they broadcast against the (th, tw) grid.
            sub_x0 = x0[idx][:, None, None]; sub_y0 = y0[idx][:, None, None]
            sub_x1 = x1[idx][:, None, None]; sub_y1 = y1[idx][:, None, None]
            sub_x2 = x2[idx][:, None, None]; sub_y2 = y2[idx][:, None, None]
            sub_den = denom[idx][:, None, None]
            # Barycentric coordinates of every pixel center w.r.t. every candidate triangle.
            bx = ((sub_y1 - sub_y2) * (xx - sub_x2) + (sub_x2 - sub_x1) * (yy - sub_y2)) / sub_den
            by = ((sub_y2 - sub_y0) * (xx - sub_x2) + (sub_x0 - sub_x2) * (yy - sub_y2)) / sub_den
            bz = 1.0 - bx - by
            inside = (bx >= -eps) & (by >= -eps) & (bz >= -eps)
            if not inside.any():
                continue
            hit_any = inside.any(dim=0)
            # One covering triangle per pixel, chosen by argmax over the 0/1 inside mask;
            # pixels with no hit get an arbitrary index but are masked out by hit_any.
            best_tri = idx[inside.int().argmax(dim=0)]
            tile_color = colors[chart_ids[best_tri]]
            tile_img = img[ty:ty_end, tx:tx_end]
            tile_img[hit_any] = tile_color[hit_any]
            img[ty:ty_end, tx:tx_end] = tile_img

    # chart outlines: a chart border is an open boundary in UV space (seam verts duplicated) -> edges with 1 incident face
    _sk, _fid, lo, hi, first = _uv_sorted_edge_keys(indices_np)
    starts = np.nonzero(first)[0]
    # Run length of each distinct edge key = number of face-edges using that edge.
    counts = np.diff(np.append(starts, first.size))
    boundary = counts == 1
    uv_cpu = uv_px.cpu().numpy()
    px_xs, px_ys = [], []
    for a, b in zip(lo[starts[boundary]], hi[starts[boundary]]):
        p0 = uv_cpu[a]; p1 = uv_cpu[b]
        # About one sample per pixel along the longer axis; edges shorter than a pixel are skipped.
        steps = int(max(abs(p1[0] - p0[0]), abs(p1[1] - p0[1])) + 1)
        if steps <= 1:
            continue
        ts = np.linspace(0.0, 1.0, steps)
        xs = (p0[0] + (p1[0] - p0[0]) * ts).astype(np.int32)
        ys = (p0[1] + (p1[1] - p0[1]) * ts).astype(np.int32)
        valid = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
        px_xs.append(xs[valid]); px_ys.append(ys[valid])
    if px_xs:
        # Paint all outline pixels in a single indexed write.
        xs_all = torch.from_numpy(np.concatenate(px_xs)).to(device=device, dtype=torch.long)
        ys_all = torch.from_numpy(np.concatenate(px_ys)).to(device=device, dtype=torch.long)
        img[ys_all, xs_all] = torch.tensor(edge, dtype=torch.float32, device=device)

    return img
class RenderUVAtlas(IO.ComfyNode):
    """Node that rasterizes a mesh's UV layout into a preview image."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, its mesh/resolution inputs and its image output."""
        return IO.Schema(
            node_id="RenderUVAtlas",
            display_name="Render UV Atlas",
            category="latent/3d",
            description=("Renders a mesh's UV layout as an image — each chart a distinct color, "
                         "outlined where it borders other charts. Run UnwrapMesh first."),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Int.Input("resolution", default=1024, min=64, max=4096, step=64),
            ],
            outputs=[IO.Image.Output("image")],
        )

    @classmethod
    def execute(cls, mesh, resolution):
        """Render the UV atlas of the first mesh in ``mesh`` as a (1, H, W, 3) image.

        ``mesh.uvs`` and ``mesh.faces`` may each be a single array/tensor, a
        stacked 3-D batch, or a list of per-mesh arrays; in the batched and
        list cases only the first mesh is rendered.

        Raises:
            RuntimeError: if the mesh has no UVs (attribute missing, None, or an empty list).
        """
        uvs_t = getattr(mesh, "uvs", None)
        # List-of-meshes layout (accepted by UnwrapMesh for vertices): preview the first
        # entry instead of failing on a list that has no .detach().
        if isinstance(uvs_t, (list, tuple)):
            uvs_t = uvs_t[0] if len(uvs_t) else None
        if uvs_t is None:
            raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
        uvs_np = uvs_t.detach().cpu().numpy() if torch.is_tensor(uvs_t) else np.asarray(uvs_t)
        if uvs_np.ndim == 3:
            uvs_np = uvs_np[0]

        f = mesh.faces
        if isinstance(f, (list, tuple)):
            # An empty list means there is nothing to draw; the renderer returns the background.
            f = f[0] if len(f) else np.empty((0, 3), dtype=np.int64)
        if torch.is_tensor(f):
            f = f.detach().cpu().numpy()
        f = np.asarray(f)
        if f.ndim == 3:
            f = f[0]

        # torch.from_numpy in the renderer needs contiguous arrays of a known dtype.
        f = np.ascontiguousarray(f, dtype=np.int64)
        uvs_np = np.ascontiguousarray(uvs_np, dtype=np.float32)
        device = comfy.model_management.get_torch_device()
        img = _uv_render_atlas(uvs_np, f, int(resolution), device)
        # Add a leading batch axis for the image output.
        return IO.NodeOutput(img.detach().cpu().unsqueeze(0))