Remesh, UV unwrap

This commit is contained in:
kijai 2026-06-17 00:59:58 +03:00
parent 72ff035fe0
commit 6ef69849a0
7 changed files with 3741 additions and 1 deletions

View File

@ -1618,3 +1618,76 @@ def simplify(
out_n if out_n else None,
out_s)
return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config)
def cluster_decimate(
    vertices: torch.Tensor, faces: torch.Tensor,
    target_verts: int = 1_000_000,
    colors: Optional[torch.Tensor] = None,
    face_chunk: int = 4_000_000,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Vertex-cluster decimation (Rossignac-Borrel).

    Bins vertices into a uniform grid, replaces every occupied cell by the mean of its
    vertices, remaps faces (in chunks) and drops degenerate / duplicate faces. Fast
    O(V+F) prepass for huge meshes before QEM/remesh.

    Args:
        vertices: [V, 3] positions.
        faces: [F, 3] vertex indices (any integer dtype).
        target_verts: approximate output vertex count; the grid is sized for about
            3x this many cells (at least 1000) on the assumption that a surface
            occupies roughly a third of them.
        colors: optional [V, C] per-vertex attributes, averaged per cell.
        face_chunk: number of faces remapped per iteration.

    Returns:
        (verts, faces, colors). Faces keep the winding of their first occurrence and
        the input face dtype; colors is None when no colors were given. Empty input
        is returned unchanged.
    """
    if vertices.shape[0] == 0 or faces.shape[0] == 0:
        return vertices, faces, colors
    device = vertices.device
    # Reduce min/max once each and reuse the min for both the bbox and the origin.
    bbox_min = vertices.min(dim=0)[0]
    bbox = vertices.max(dim=0)[0] - bbox_min
    # cell size so the bbox holds ~3x target_verts cells (surface occupancy ~1/3)
    cell_count_target = max(target_verts * 3, 1000)
    extent_max = float(bbox.max().item())
    cells_per_axis = cell_count_target ** (1 / 3)
    cell_size = extent_max / max(1.0, cells_per_axis)
    scale = 1.0 / max(cell_size, 1e-20)
    q = ((vertices - bbox_min) * scale).floor().to(torch.int64)
    # +2 leaves headroom so a quantized coordinate can never alias the next axis.
    extent = (bbox * scale).floor().to(torch.int64) + 2
    Wy = extent[1]
    Wz = extent[2]
    key = (q[:, 0] * Wy + q[:, 1]) * Wz + q[:, 2]
    # inv[v] is the output vertex (cell) id of input vertex v.
    unique_key, inv = torch.unique(key, return_inverse=True)
    n_unique = unique_key.shape[0]

    counts = torch.zeros(n_unique, dtype=vertices.dtype, device=device)
    counts.scatter_add_(0, inv, torch.ones(vertices.shape[0], dtype=vertices.dtype, device=device))
    counts_div = counts.unsqueeze(-1).clamp_min(1.0)
    new_verts = torch.zeros((n_unique, 3), dtype=vertices.dtype, device=device)
    new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(vertices), vertices)
    new_verts = new_verts / counts_div

    new_colors = None
    if colors is not None:
        new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=colors.dtype, device=device)
        new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors)
        new_colors = new_colors / counts_div.to(colors.dtype)

    # remap faces in chunks (face tensor can be huge); drop degenerates per chunk
    out_chunks = []
    F = faces.shape[0]
    for fs in range(0, F, face_chunk):
        fe = min(fs + face_chunk, F)
        cf = inv[faces[fs:fe].long()]
        nondeg = (cf[:, 0] != cf[:, 1]) & (cf[:, 1] != cf[:, 2]) & (cf[:, 0] != cf[:, 2])
        if nondeg.any():
            out_chunks.append(cf[nondeg])
    if out_chunks:
        new_faces = torch.cat(out_chunks, dim=0)
    else:
        new_faces = torch.empty((0, 3), dtype=faces.dtype, device=device)

    # drop duplicate faces (same vertex set after clustering)
    if new_faces.numel() > 0:
        key_sorted = torch.sort(new_faces, dim=1)[0].long()
        P = n_unique + 1
        # Pack in two stages. A direct (a*P + b)*P + c key needs P**3 < 2**63 and wraps
        # silently once there are more than ~2M clustered vertices. Ranking the (a, b)
        # pair first keeps every intermediate below max(P**2, F*P) while preserving the
        # same lexicographic order, so the result is unchanged where the old key was valid.
        pair = key_sorted[:, 0] * P + key_sorted[:, 1]
        _, pair_id = torch.unique(pair, return_inverse=True)
        packed = pair_id * P + key_sorted[:, 2]
        _, first = torch.unique(packed, return_inverse=True)
        arange = torch.arange(packed.shape[0], device=device, dtype=torch.int64)
        # For every distinct face keep the smallest original row index.
        first_idx = torch.full((int(first.max().item()) + 1,), packed.shape[0],
                               dtype=torch.int64, device=device)
        first_idx.scatter_reduce_(0, first, arange, reduce="amin", include_self=True)
        new_faces = new_faces[first_idx]
    return new_verts.to(vertices.dtype), new_faces.to(faces.dtype), new_colors

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,158 @@
"""Mesh container, edge/face adjacency, manifold cleanup."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from torch import Tensor
# ---- Per-face / per-vertex geometry ----
def face_normals(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return [F,3] unit normals, one per triangle; a degenerate triangle yields a zero vector."""
    corners = vertices[faces]  # [F, 3 corners, 3 coords]
    origin = corners[:, 0]
    raw = torch.linalg.cross(corners[:, 1] - origin, corners[:, 2] - origin)
    # Clamp the length so a zero-area face divides 0 by a tiny number instead of by 0.
    length = raw.norm(dim=1, keepdim=True).clamp_min(1e-20)
    return raw / length
def face_areas(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return [F] triangle areas (half the magnitude of the edge cross product)."""
    corners = vertices[faces]  # [F, 3 corners, 3 coords]
    origin = corners[:, 0]
    cross = torch.linalg.cross(corners[:, 1] - origin, corners[:, 2] - origin)
    return 0.5 * cross.norm(dim=1)
def face_centroids(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return [F,3] centroids: the mean of each triangle's three corner positions."""
    corners = vertices[faces]  # [F, 3 corners, 3 coords]
    return torch.mean(corners, dim=1)
def face_edge_lengths(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return [F,3] float32 edge lengths; column e is the edge from corner e to corner (e+1)%3."""
    corners = vertices[faces]                       # [F, 3, 3]
    following = corners.roll(shifts=-1, dims=1)     # corner (e+1)%3 aligned with corner e
    lengths = (following - corners).norm(dim=-1)
    return lengths.to(torch.float32)
def chart_3d_areas(face_area: Tensor, face_chart: Tensor, n_charts: int) -> Tensor:
    """Return [n_charts] totals: face areas summed by the chart id each face is assigned to."""
    totals = torch.zeros(n_charts, dtype=face_area.dtype, device=face_area.device)
    # scatter_add_ returns the tensor it modified, so the accumulation can be returned directly.
    return totals.scatter_add_(0, face_chart, face_area)
@dataclass
class MeshData:
    """Triangle mesh bundled with per-face geometry and edge adjacency.

    face_face[f, i] is the id of the face across edge (faces[f, i], faces[f, (i+1)%3]),
    or -1 when that edge has no paired neighbor (boundary).
    """
    vertices: Tensor       # [V, 3] float positions
    faces: Tensor          # [F, 3] long vertex indices
    face_face: Tensor      # [F, 3] long neighbor face per edge, -1 = none
    face_normal: Tensor    # [F, 3] float unit normals
    face_area: Tensor      # [F] float triangle areas
    face_centroid: Tensor  # [F, 3] float triangle centroids
    component: Tensor      # [F] long connected-component label per face
    n_components: int      # number of connected components
def build_mesh(vertices: Tensor, faces: Tensor) -> MeshData:
    """Assemble a MeshData: face adjacency, connected components and per-face geometry.

    Only edges shared by exactly two face-edges are linked; any other edge (boundary or
    non-manifold with more than two incident faces) is left at -1 and acts as a boundary.
    Vertices are converted to float32 and faces to long if they are not already.
    """
    vertices = vertices if vertices.dtype == torch.float32 else vertices.to(torch.float32)
    faces = faces if faces.dtype == torch.long else faces.to(torch.long)
    device = faces.device
    n_verts = vertices.shape[0]
    n_faces = faces.shape[0]

    # One slot per face-edge, flat index p = f*3 + i, keyed by its undirected vertex pair.
    tail = faces.flatten()
    head = faces.roll(shifts=-1, dims=1).flatten()
    edge_key = torch.minimum(tail, head) * (n_verts + 1) + torch.maximum(tail, head)

    # Keep the slots whose key occurs exactly twice; after a stable sort by key these
    # slots sit next to each other, so even/odd positions form the two halves of a pair.
    _, inverse, counts = torch.unique(edge_key, return_inverse=True, return_counts=True)
    is_paired = counts[inverse] == 2
    by_key = torch.argsort(edge_key, stable=True)
    paired_slots = by_key[is_paired[by_key]]
    first_half = paired_slots[0::2]
    second_half = paired_slots[1::2]

    slot_face = torch.arange(n_faces, device=device).repeat_interleave(3)
    neighbor = torch.full((3 * n_faces,), -1, dtype=torch.long, device=device)
    neighbor[first_half] = slot_face[second_half]
    neighbor[second_half] = slot_face[first_half]
    face_face = neighbor.view(n_faces, 3)

    # Connected components over the face-adjacency graph (computed on CPU with scipy).
    ff_np = face_face.cpu().numpy()
    linked = ff_np >= 0
    if linked.any():
        src = np.broadcast_to(np.arange(n_faces)[:, None], (n_faces, 3))[linked]
        dst = ff_np[linked]
        adj = csr_matrix(
            (np.ones(src.size, dtype=np.int8), (src, dst)),
            shape=(n_faces, n_faces),
        )
    else:
        adj = csr_matrix((n_faces, n_faces), dtype=np.int8)
    n_components, labels = connected_components(adj, directed=False)

    return MeshData(
        vertices=vertices,
        faces=faces,
        face_face=face_face,
        face_normal=face_normals(vertices, faces),
        face_area=face_areas(vertices, faces),
        face_centroid=face_centroids(vertices, faces),
        component=torch.from_numpy(labels.astype(np.int64)).to(device),
        n_components=int(n_components),
    )
def chart_boundary_loops(
    faces_subset: Tensor, face_face_subset: Tensor
) -> List[List[int]]:
    """Trace the boundary of a chart submesh into ordered vertex chains.

    An edge i of face f is a boundary edge when face_face_subset[f, i] == -1. Each
    boundary edge contributes a "next vertex" link from its first to its second vertex
    (a later edge leaving the same vertex overwrites an earlier one). Chains are
    followed from each unvisited vertex until they return to the start, reach an
    already visited vertex, or run out of links; chains shorter than 3 are discarded.
    """
    n_faces = faces_subset.shape[0]
    tri = faces_subset.cpu().numpy()
    nbr = face_face_subset.cpu().numpy()

    successor: Dict[int, int] = {}
    for f in range(n_faces):
        for i in range(3):
            if nbr[f, i] != -1:
                continue
            successor[int(tri[f, i])] = int(tri[f, (i + 1) % 3])

    chains: List[List[int]] = []
    seen = set()
    for start in list(successor):
        if start in seen:
            continue
        chain = [start]
        seen.add(start)
        nxt = successor.get(start)
        while nxt is not None and nxt != start and nxt not in seen:
            chain.append(nxt)
            seen.add(nxt)
            nxt = successor.get(nxt)
        if len(chain) >= 3:
            chains.append(chain)
    return chains

View File

@ -0,0 +1,759 @@
"""Atlas packing via bitmap rasterize-and-place."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import torch
from torch import Tensor
# Optional numba acceleration. When numba is importable, `_njit` is numba's `njit` and
# `_HAVE_NUMBA_PACK` is True. Otherwise `_njit` is a no-op stand-in so the decorated
# functions below still define as plain Python, and `_HAVE_NUMBA_PACK` is False.
try:
    from numba import njit as _njit
    _HAVE_NUMBA_PACK = True
except ImportError:
    _HAVE_NUMBA_PACK = False
    def _njit(*args, **kwargs):
        # Supports both `@_njit` (the function arrives as args[0]) and
        # `@_njit(option=...)` (no positional args, so return a pass-through decorator).
        def deco(fn): return fn
        return deco if not args else args[0]
@dataclass
class ChartPlacement:
    """Where and how one chart was placed in the atlas (all lengths in texels)."""
    chart_id: int
    offset: Tuple[float, float]  # placement position in texels
    scale: float                 # texels per UV unit
    rotation: float = 0.0        # pre-rotation of the chart UVs, radians
    swap_xy: bool = False        # True when the 90-degree-rotated bitmap was the one placed
    chart_h: float = 0.0         # height in texels of the unswapped bitmap (pivot for the 90-degree turn)
@_njit(cache=True, boundscheck=False)
def _best_rotation_jit(uvs_np: np.ndarray, n_angles: int) -> float:
    """Return the angle in [0, pi/2) whose rotation gives the smallest axis-aligned bbox area.

    Tries n_angles evenly spaced angles; ties keep the earliest angle. Returns 0.0 for
    an empty point set.
    """
    n_pts = uvs_np.shape[0]
    if n_pts == 0:
        return 0.0
    quarter_turn = math.pi * 0.5
    winner = 0.0
    smallest = 1e30
    for step in range(n_angles):
        angle = quarter_turn * step / n_angles
        ca = math.cos(angle)
        sa = math.sin(angle)
        lo_x = 1e30
        hi_x = -1e30
        lo_y = 1e30
        hi_y = -1e30
        for j in range(n_pts):
            px = uvs_np[j, 0]
            py = uvs_np[j, 1]
            rx = px * ca - py * sa
            ry = px * sa + py * ca
            if rx < lo_x:
                lo_x = rx
            if rx > hi_x:
                hi_x = rx
            if ry < lo_y:
                lo_y = ry
            if ry > hi_y:
                hi_y = ry
        box_area = (hi_x - lo_x) * (hi_y - lo_y)
        if box_area < smallest:
            smallest = box_area
            winner = angle
    return winner
def _best_rotation(uvs_np: np.ndarray, n_angles: int = 36) -> float:
    """Python-side wrapper: convert the points to float64 and return the best angle as a float."""
    points = uvs_np.astype(np.float64)
    theta = _best_rotation_jit(points, n_angles)
    return float(theta)
def _rotate_xy(uv: np.ndarray, theta: float) -> np.ndarray:
if theta == 0.0:
return uv
c = math.cos(theta)
s = math.sin(theta)
return np.stack([uv[:, 0] * c - uv[:, 1] * s, uv[:, 0] * s + uv[:, 1] * c], axis=1)
@_njit(cache=True, boundscheck=False)
def _rasterize_chart_jit(
    uvs_tex: np.ndarray, faces: np.ndarray, w: int, h: int
) -> np.ndarray:
    """Rasterize triangles into an (h, w) bool bitmap.

    A texel is set when its center (px + 0.5, py + 0.5) passes the barycentric inside
    test with a small negative tolerance. Triangles with a near-zero denominator or a
    pixel bbox entirely outside the bitmap are skipped.
    """
    mask = np.zeros((h, w), dtype=np.bool_)
    tol = 1e-7
    for t in range(faces.shape[0]):
        ia = faces[t, 0]
        ib = faces[t, 1]
        ic = faces[t, 2]
        x0 = uvs_tex[ia, 0]
        y0 = uvs_tex[ia, 1]
        x1 = uvs_tex[ib, 0]
        y1 = uvs_tex[ib, 1]
        x2 = uvs_tex[ic, 0]
        y2 = uvs_tex[ic, 1]

        # Floating-point bounding box of the triangle.
        lo_x = x0
        if x1 < lo_x:
            lo_x = x1
        if x2 < lo_x:
            lo_x = x2
        hi_x = x0
        if x1 > hi_x:
            hi_x = x1
        if x2 > hi_x:
            hi_x = x2
        lo_y = y0
        if y1 < lo_y:
            lo_y = y1
        if y2 < lo_y:
            lo_y = y2
        hi_y = y0
        if y1 > hi_y:
            hi_y = y1
        if y2 > hi_y:
            hi_y = y2

        # Pixel range, clipped to the bitmap.
        col_lo = int(math.floor(lo_x))
        if col_lo < 0:
            col_lo = 0
        col_hi = int(math.ceil(hi_x))
        if col_hi > w - 1:
            col_hi = w - 1
        row_lo = int(math.floor(lo_y))
        if row_lo < 0:
            row_lo = 0
        row_hi = int(math.ceil(hi_y))
        if row_hi > h - 1:
            row_hi = h - 1
        if col_hi < col_lo or row_hi < row_lo:
            continue

        dy12 = y1 - y2
        dx21 = x2 - x1
        dy20 = y2 - y0
        dx02 = x0 - x2
        denom = dy12 * dx02 + dx21 * (y0 - y2)
        if abs(denom) < 1e-20:
            continue
        inv_denom = 1.0 / denom

        for row in range(row_lo, row_hi + 1):
            cy = row + 0.5
            for col in range(col_lo, col_hi + 1):
                cx = col + 0.5
                wa = (dy12 * (cx - x2) + dx21 * (cy - y2)) * inv_denom
                wb = (dy20 * (cx - x2) + dx02 * (cy - y2)) * inv_denom
                wc = 1.0 - wa - wb
                if wa >= -tol and wb >= -tol and wc >= -tol:
                    mask[row, col] = True
    return mask
def _rasterize_chart(
    uvs_tex: np.ndarray, faces: np.ndarray, w: int, h: int, padding: int
) -> np.ndarray:
    """Return the (h, w) bool coverage bitmap of a chart, grown by `padding` texels.

    A chart without faces gives an all-False bitmap. Inputs are converted to
    float64 / int64 before being handed to the rasterizer kernel.
    """
    if faces.size == 0:
        return np.zeros((h, w), dtype=bool)
    coverage = _rasterize_chart_jit(
        uvs_tex.astype(np.float64), faces.astype(np.int64), int(w), int(h)
    )
    return _dilate_bitmap(coverage, padding) if padding > 0 else coverage
def _dilate_bitmap(bm: np.ndarray, k: int) -> np.ndarray:
"""k-step Manhattan max-filter dilation."""
out = bm.copy()
for _ in range(k):
next_out = out.copy()
next_out[1:, :] |= out[:-1, :]
next_out[:-1, :] |= out[1:, :]
next_out[:, 1:] |= out[:, :-1]
next_out[:, :-1] |= out[:, 1:]
out = next_out
return out
@_njit(cache=True, boundscheck=False)
def _build_candidates_jit(
    skyline: np.ndarray,
    cur_w: int, cur_h: int,
    bw0: int, bh0: int, bw1: int, bh1: int,
    step: int,
) -> np.ndarray:
    """Enumerate candidate placements as rows (x, y, swap_flag) for both orientations.

    For each orientation (swap_flag 0 uses width bw0, 1 uses bw1) three families are
    emitted, x and y advancing by `step`:
      * skyline-flush: every x in [0, cur_w], y = max skyline height under the chart width;
      * horizontal sweeps along y = 0 and y = cur_h;
      * vertical sweeps along x = 0 and x = cur_w.
    bh0 / bh1 are accepted but not used here.
    """
    cols = (max(cur_w, 1) // step) + 2
    rows = (max(cur_h, 1) // step) + 2
    # Upper bound per orientation: one skyline row + two horizontal + two vertical sweeps.
    capacity = 2 * (cols + 2 * cols + 2 * rows)
    out = np.empty((capacity, 3), dtype=np.int64)
    sky_len = skyline.shape[0]
    count = 0
    for swap_flag in range(2):
        width = bw0 if swap_flag == 0 else bw1

        # Skyline-flush positions.
        x = 0
        while x <= cur_w:
            stop = x + width
            if stop > sky_len:
                stop = sky_len
            rest_y = 0
            for col in range(x, stop):
                if skyline[col] > rest_y:
                    rest_y = int(skyline[col])
            out[count, 0] = x
            out[count, 1] = rest_y
            out[count, 2] = swap_flag
            count += 1
            x += step

        # Sweeps along the bottom and top edges of the current extent.
        for y_fixed in (0, cur_h):
            x = 0
            while x <= cur_w:
                out[count, 0] = x
                out[count, 1] = y_fixed
                out[count, 2] = swap_flag
                count += 1
                x += step

        # Sweeps along the left and right edges of the current extent.
        for x_fixed in (0, cur_w):
            y = 0
            while y <= cur_h:
                out[count, 0] = x_fixed
                out[count, 1] = y
                out[count, 2] = swap_flag
                count += 1
                y += step
    return out[:count]
@_njit(cache=True, boundscheck=False)
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
                        x: int, y: int) -> None:
    """Raise the skyline under a chart placed at (x, y), in place.

    For each chart column that has any True texel, skyline[x + i] becomes at least
    y + (highest True row index) + 1. Columns outside the skyline array are ignored.
    """
    n_rows = chart.shape[0]
    n_cols = chart.shape[1]
    sky_len = skyline.shape[0]
    for i in range(n_cols):
        target = x + i
        if target < 0 or target >= sky_len:
            continue
        # Scan from the last row downwards to find the highest occupied row.
        highest = -1
        for j in range(n_rows - 1, -1, -1):
            if chart[j, i]:
                highest = j
                break
        if highest >= 0:
            lifted = y + highest + 1
            if lifted > skyline[target]:
                skyline[target] = lifted
@_njit(cache=True, boundscheck=False)
def _best_placement_jit(
    atlas: np.ndarray,
    bitmap: np.ndarray,
    bitmap_rot: np.ndarray,
    candidates: np.ndarray,
    cur_w: int,
    cur_h: int,
):
    """Return (x, y, score, swap) of the cheapest collision-free candidate, x = -1 if none.

    score = max(new_w, new_h)^2 + new_w * new_h, where new_w / new_h are the atlas
    extents after placing. Texels falling outside the atlas array count as free.
    Candidates are tried in order; only a strictly lower score replaces the current
    best, and the scan stops early once a placement fits inside the current extents.
    """
    rows0 = bitmap.shape[0]
    cols0 = bitmap.shape[1]
    rows1 = bitmap_rot.shape[0]
    cols1 = bitmap_rot.shape[1]
    atlas_h = atlas.shape[0]
    atlas_w = atlas.shape[1]

    win_x = -1
    win_y = -1
    win_score = -1
    win_swap = 0
    for idx in range(candidates.shape[0]):
        cx = candidates[idx, 0]
        cy = candidates[idx, 1]
        flag = candidates[idx, 2]
        if flag == 0:
            rows = rows0
            cols = cols0
        else:
            rows = rows1
            cols = cols1
        if cx < 0 or cy < 0:
            continue

        new_w = cur_w if cur_w > cx + cols else cx + cols
        new_h = cur_h if cur_h > cy + rows else cy + rows
        side = new_w if new_w > new_h else new_h
        cost = side * side + new_w * new_h
        # Cheap rejection before the expensive texel scan.
        if win_score >= 0 and cost >= win_score:
            continue

        collided = False
        for j in range(rows):
            ay = cy + j
            if ay >= atlas_h:
                continue
            for i in range(cols):
                occupied = bitmap[j, i] if flag == 0 else bitmap_rot[j, i]
                if not occupied:
                    continue
                ax = cx + i
                if ax >= atlas_w:
                    continue
                if atlas[ay, ax]:
                    collided = True
                    break
            if collided:
                break
        if collided:
            continue

        win_x = cx
        win_y = cy
        win_score = cost
        win_swap = flag
        # Fits inside the current extents: the score cannot get any lower.
        if cx + cols <= cur_w and cy + rows <= cur_h:
            break
    return win_x, win_y, win_score, win_swap
def _blit(atlas: np.ndarray, chart: np.ndarray, x: int, y: int) -> None:
ah, aw = atlas.shape
ch, cw = chart.shape
atlas[y: y + ch, x: x + cw] |= chart
@dataclass
class _PreparedChart:
chart_id: int
uvs_tex: np.ndarray # [V, 2] in texel coords (rotated, scaled, origin 0)
bitmap: np.ndarray # [h, w] bool, padded
bitmap_rot: np.ndarray # 90° rotated bitmap (for swap_xy placement)
bbox_w: int
bbox_h: int
rotation: float # radians, applied to UVs
s_tex: float # texels per UV unit
perimeter: float # for chart ordering
@_njit(cache=True, boundscheck=False)
def _chart_perimeter_jit(uvs: np.ndarray, faces: np.ndarray, V: int) -> float:
    """Sum the 2D length of every distinct undirected edge of the chart.

    Each edge is encoded as lo * V + hi; the keys are sorted so repeated edges are
    adjacent and counted once. V must exceed every vertex index.
    """
    n_faces = faces.shape[0]
    keys = np.empty(n_faces * 3, dtype=np.int64)
    for fi in range(n_faces):
        for j in range(3):
            a = faces[fi, j]
            b = faces[fi, (j + 1) % 3]
            slot = fi * 3 + j
            if a < b:
                keys[slot] = a * V + b
            else:
                keys[slot] = b * V + a
    keys = np.sort(keys)

    total = 0.0
    for i in range(keys.shape[0]):
        code = keys[i]
        if i == 0 or code != keys[i - 1]:
            lo = code // V
            hi = code % V
            du = uvs[lo, 0] - uvs[hi, 0]
            dv = uvs[lo, 1] - uvs[hi, 1]
            total += math.sqrt(du * du + dv * dv)
    return total
def _chart_perimeter(uvs: np.ndarray, faces: np.ndarray) -> float:
    """Total length of the chart's distinct edges; the key base is max face index + 1 (0 if no faces)."""
    if faces.size:
        n_vertex_ids = int(faces.max()) + 1
    else:
        n_vertex_ids = 0
    total = _chart_perimeter_jit(uvs.astype(np.float64), faces.astype(np.int64), n_vertex_ids)
    return float(total)
# ---- Torch fallback (used when numba is unavailable; runs on GPU if present) ----
def _dilate_local(x: Tensor, p: int) -> Tensor:
"""4-connectivity dilation by p, applied per-image over a batch of (cnt,g,g) bitmaps.
Matches the old per-chart _dilate_torch; dilation distributes over union so per-triangle
dilation OR-scattered equals dilating the assembled chart bitmap."""
for _ in range(p):
y = x.clone()
y[:, 1:, :] |= x[:, :-1, :]; y[:, :-1, :] |= x[:, 1:, :]
y[:, :, 1:] |= x[:, :, :-1]; y[:, :, :-1] |= x[:, :, 1:]
x = y
return x
def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device):
    """Batched rasterize EVERY chart at once into one flat bool buffer, replacing the per-chart
    loop. Returns (buf, cbase) where buf[cbase[i]:cbase[i+1]].view(bh,bw) is chart i's [y,x] bitmap.
    Triangles are bucketed by next-pow2 bbox size so each batch's local grid stays tiny (bounded
    memory) while collapsing ~N chart rasters into a handful of kernels.

    Inputs as used below: uvs_tex_pad is (n, vmax, 2) texel coordinates, faces_pad is
    (n, fmax, 3) per-chart vertex indices, fmask is (n, fmax) marking the real (non-padding)
    faces, bw_t / bh_t are per-chart bitmap widths / heights, padding is the dilation
    radius in texels."""
    n = uvs_tex_pad.shape[0]
    fmax = faces_pad.shape[1]
    bwL, bhL = bw_t.long(), bh_t.long()
    # cbase[i] = start offset of chart i in the flat buffer (exclusive prefix sum of bw*bh).
    cbase = torch.zeros(n + 1, dtype=torch.long, device=device)
    torch.cumsum(bwL * bhL, 0, out=cbase[1:])
    buf = torch.zeros(int(cbase[-1].item()), dtype=torch.bool, device=device)
    # gather all triangle coords, keep only valid faces -> (Ttot,3,2) + chart id per triangle
    fp = faces_pad.reshape(n, fmax * 3)
    tri = torch.gather(uvs_tex_pad, 1, fp[..., None].expand(-1, -1, 2)).reshape(n * fmax, 3, 2)
    fm = fmask.reshape(-1)
    tri_f = tri[fm]
    if tri_f.shape[0] == 0:
        return buf, cbase
    cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm]
    # per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim
    tmin = tri_f.amin(1); tmax = tri_f.amax(1)
    x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0)
    y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0)
    bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1
    bbh = (tmax[:, 1].ceil().long() + padding) - y0 + 1
    mxd = torch.maximum(bbw, bbh).clamp_min(1)
    bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long()
    # Barycentric setup relative to vertex a: edge vectors and their dot products.
    a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2]
    v0 = b - a; v1 = c - a
    d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1)
    # Clamped rather than skipped: near-degenerate triangles get a tiny positive denominator.
    den = (d00 * d11 - d01 * d01).clamp(min=1e-20)
    for g in sorted(set(bsz.tolist())): # one batch per pow2 grid
        sel = (bsz == g).nonzero(as_tuple=True)[0]
        m = sel.shape[0]
        xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1)
        cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1)
        gi = torch.arange(g, device=device)
        px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int
        # Sample at texel centers.
        pxf = px.float() + 0.5; pyf = py.float() + 0.5
        v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1)
        d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1)
        d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1)
        idn = den[sel].view(m, 1, 1).reciprocal()
        # vv = (d11*d20 - d01*d21) / den, ww = (d00*d21 - d01*d20) / den, uu = 1 - vv - ww.
        vv = torch.addcmul(d11[sel].view(m, 1, 1) * d20, d01[sel].view(m, 1, 1), d21, value=-1) * idn
        ww = torch.addcmul(d00[sel].view(m, 1, 1) * d21, d01[sel].view(m, 1, 1), d20, value=-1) * idn
        uu = 1.0 - vv - ww
        inside = (uu >= -1e-6) & (vv >= -1e-6) & (ww >= -1e-6)
        if padding > 0:
            inside = _dilate_local(inside, padding)
        # Drop grid cells that fall outside the owning chart's bitmap before scattering.
        valid = inside & (px < bwp) & (py < bhp)
        flat = (cbase[cc].view(m, 1, 1) + py * bwp + px)[valid]
        buf[flat] = True
    return buf, cbase
def _build_candidates_gpu(sky_t, cur_w, cur_h, bw0, bw1, step, rand_n, gen, device):
"""Skyline-flush + edge-sweep + random candidate (x,y) positions per orientation, built on the
GPU. Returns (cand0, cand1). Random samples find tight pockets the deterministic grid misses."""
xs = torch.arange(0, max(cur_w, 1) + 1, step, device=device)
ys = torch.arange(0, max(cur_h, 1) + 1, step, device=device)
# edge-sweep candidates are orientation-independent: build once, shared by both orientations
common = [torch.stack([xs, torch.full_like(xs, yf)], 1) for yf in (0, cur_h)]
common += [torch.stack([torch.full_like(ys, xf), ys], 1) for xf in (0, cur_w)]
common = torch.cat(common, 0)
out = []
for cw in (bw0, bw1): # skyline-flush + random differ
if cw > 0 and sky_t.shape[0] >= cw:
wmax = sky_t.unfold(0, cw, 1).amax(1)[xs.clamp(max=max(sky_t.shape[0] - cw, 0))]
else:
wmax = torch.zeros_like(xs)
parts = [torch.stack([xs, wmax], 1), common]
if rand_n > 0: # distinct draws keep density
rx = torch.randint(0, max(cur_w, 1) + 1, (rand_n,), generator=gen, device=device)
ry = torch.randint(0, max(cur_h, 1) + 1, (rand_n,), generator=gen, device=device)
parts.append(torch.stack([rx, ry], 1))
out.append(torch.cat(parts, 0))
return out[0], out[1]
def _col_top(b: Tensor) -> Tensor:
"""Topmost True row index per column of a bool bitmap (h,w); -1 for empty columns."""
h = b.shape[0]
rows = torch.arange(h, device=b.device)[:, None]
return torch.where(b, rows, torch.full_like(rows.expand_as(b), -1)).amax(0)
def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cur_h, device):
"""Lowest-score non-colliding candidate as a (3,) int tensor [x, y, swap] (x=-1 if none).
Collision tests only each bitmap's True-pixel offsets (pix), not the full window. Fully on-GPU;
the caller does the single sync (.tolist())."""
INF = 1 << 60
def best(cand, pix, dim): # -> (score, x, y) 0-d tensors
ch, cw = dim
cx, cy = cand[:, 0], cand[:, 1]
coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather
cx[:, None] + pix[:, 1][None, :]].any(dim=1)
nw = torch.clamp(cx + cw, min=cur_w); nh = torch.clamp(cy + ch, min=cur_h)
ext = torch.maximum(nw, nh)
score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh)
j = score.argmin()
return score[j], cx[j], cy[j]
s0, x0, y0 = best(cand0, pix0, dim0)
s1, x1, y1 = best(cand1, pix1, dim1)
take0 = s0 <= s1
bsc = torch.where(take0, s0, s1)
pick = torch.stack([torch.where(take0, x0, x1), torch.where(take0, y0, y1),
torch.where(take0, x0.new_zeros(()), x0.new_ones(()))])
return torch.where(bsc < INF, pick, torch.tensor([-1, -1, 0], device=device))
def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
                       texels_per_unit, padding_texels):
    """Torch rasterize-and-place packer (numba-free fallback). Returns (placements, atlas_w, atlas_h).

    Runs on CUDA when available, otherwise on CPU. Three phases: (1) batched choice of
    rotation, scale and bitmap size for all charts, (2) batched rasterization followed by
    per-chart trimming, (3) sequential placement, largest trimmed bitmap first.
    placements[i] describes chart i; atlas_w / atlas_h are the extents actually used."""
    n = len(chart_uvs)
    if n == 0:
        return [], 1, 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 36 candidate angles in [0, pi/2): 37 evenly spaced samples with the endpoint dropped.
    ang = torch.linspace(0.0, math.pi / 2.0, 37, device=device)[:-1]
    cos_a, sin_a = ang.cos(), ang.sin()
    # ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
    vcount = [int(u.shape[0]) for u in chart_uvs]
    fcount = [int(f.shape[0]) for f in chart_faces]
    vmax = max(vcount); fmax = max(fcount)
    # Pad every chart to a common vertex / face count; the masks mark the real entries.
    uvs_pad = torch.zeros(n, vmax, 2, device=device)
    vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
    faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
    fmask = torch.zeros(n, fmax, dtype=torch.bool, device=device)
    for i in range(n):
        uvs_pad[i, :vcount[i]] = chart_uvs[i].to(device=device, dtype=torch.float32)
        vmask[i, :vcount[i]] = True
        if fcount[i]:
            faces_pad[i, :fcount[i]] = chart_faces[i].to(device=device, dtype=torch.long)
            fmask[i, :fcount[i]] = True
    u0, u1 = uvs_pad[..., 0], uvs_pad[..., 1] # (N,Vmax)
    BIG = 1e30
    # Additive masks: padded vertices get +BIG before an amin and -BIG before an amax,
    # so they can never win either reduction.
    mlo = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), BIG))
    mhi = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), -BIG))
    xr = torch.addcmul(u0[:, :, None] * cos_a, u1[:, :, None], sin_a, value=-1) # (N,Vmax,A)
    yr = torch.addcmul(u0[:, :, None] * sin_a, u1[:, :, None], cos_a)
    xsp = (xr + mhi[:, :, None]).amax(1) - (xr + mlo[:, :, None]).amin(1) # (N,A) masked span
    ysp = (yr + mhi[:, :, None]).amax(1) - (yr + mlo[:, :, None]).amin(1)
    ti = (xsp * ysp).argmin(1) # (N,) best angle per chart
    cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1)
    rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax)
    ry = torch.addcmul(u0 * ss, u1, cc)
    rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1) # (N,)
    rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1)
    a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
    au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
    # Nominal scale sqrt(area_3d / area_uv) * texels_per_unit, capped so the larger bbox
    # side is at most max(8, 4 * sqrt(area_3d) * texels_per_unit) texels.
    base = (a3 / au).sqrt() * texels_per_unit
    maxb = (4.0 * a3.sqrt() * texels_per_unit).clamp_min(8.0)
    bbm = torch.maximum(rxmax - rxmin, rymax - rymin).clamp_min(1e-12)
    scale = torch.minimum(base, maxb / bbm) # (N,)
    uvs_tex_pad = torch.stack([(rx - rxmin[:, None]) * scale[:, None],
                               (ry - rymin[:, None]) * scale[:, None]], dim=-1) # (N,Vmax,2)
    bw_t = ((rxmax - rxmin) * scale).ceil().int() + padding_texels + 1
    bh_t = ((rymax - rymin) * scale).ceil().int() + padding_texels + 1
    # one sync: pull all per-chart scalars
    thetas = ang[ti].cpu().tolist()
    scales = scale.cpu().tolist()
    bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist()
    # ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
    buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
    cb = cbase.cpu().tolist()
    raw, bnd = [], []
    for i in range(n):
        bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
        raw.append(bm)
        rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device)
        rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty)
        cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
        bnd.append(torch.stack([rmax, cmax]))
    bnd_cpu = torch.stack(bnd).cpu().tolist() # one sync for all trim bounds
    # per-chart True-pixel offsets (sparse collision/blit), dims, col-tops (all kept on GPU)
    pix_l, pixr_l, dim_l, dimr_l, bm_h = [], [], [], [], []
    col_tops, col_tops_rot = [], []
    for i in range(n):
        rm, cm = bnd_cpu[i]
        # An empty bitmap is replaced by a single False texel.
        bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
              else torch.zeros((1, 1), dtype=torch.bool, device=device))
        # Transpose then reverse the columns: bm_rot[j, i] == bm[h - 1 - i, j].
        bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
        pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero())
        dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
        dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
        col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot))
        bm_h.append(int(bm.shape[0]))
    # Pad the per-column tops to a common width so they can be sliced per chart below.
    wmax = max(d[1] for d in dim_l + dimr_l)
    ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
    ctr_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
    for i in range(n):
        ct_pad[i, :col_tops[i].shape[0]] = col_tops[i]
        ctr_pad[i, :col_tops_rot[i].shape[0]] = col_tops_rot[i]
    del raw
    # ---- Placement: skyline bin-pack on GPU (1 sync/chart for the chosen position) ----
    order = sorted(range(n), key=lambda i: -(dim_l[i][0] * dim_l[i][1])) # biggest bitmap first
    max_b = max(max(d) for d in dim_l)
    # Slack kept beyond the current extents so a chart placed at the far edge stays in bounds.
    margin = max_b + 8
    side_guess = int(math.sqrt(sum(d[0] * d[1] for d in dim_l)) * 2) + 16
    cap = side_guess + margin
    atlas = torch.zeros((cap, cap), dtype=torch.bool, device=device)
    sky_t = torch.zeros(cap, dtype=torch.long, device=device)
    cur_w = cur_h = 0
    placements = [None] * n
    gen = torch.Generator(device=device).manual_seed(0)
    rand_n = 512 # random samples per orientation
    for ci in order:
        # Grow the (square) atlas and the skyline before placing if the margin no longer fits.
        if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
            ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
            na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
            na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na
            nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk
        dim, dimr = dim_l[ci], dimr_l[ci]
        step = max(1, min(dim[0], dim[1]) // 8)
        cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
        res = _best_placement_torch(atlas, pix_l[ci], dim, pixr_l[ci], dimr, cand0, cand1, cur_w, cur_h, device)
        bx, by, swap = (int(v) for v in res.tolist()) # the one sync/chart
        if bx < 0:
            # Every candidate collided: append to the right of everything placed so far.
            bx, by, swap = cur_w, 0, 0
        pix = pixr_l[ci] if swap else pix_l[ci]
        bh_, bw_ = (dimr if swap else dim)
        atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit
        cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_)
        ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift
        ix = torch.arange(bx, bx + bw_, device=device)
        sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])
        placements[ci] = ChartPlacement(chart_id=ci, offset=(float(bx), float(by)),
                                        scale=scales[ci], rotation=thetas[ci], swap_xy=bool(swap),
                                        chart_h=float(bm_h[ci]))
    return placements, cur_w, cur_h
def pack_bitmap(
    chart_uvs: List[Tensor],
    chart_3d_areas: List[float],
    chart_uv_areas: List[float],
    chart_faces: List[Tensor],
    texels_per_unit: float = 256.0,
    padding_texels: int = 2,
    attempts: int = 4096,
    rng_seed: int = 0,
) -> Tuple[List[ChartPlacement], int, int]:
    """Rasterize-and-place atlas packer. Returns (placements, atlas_w, atlas_h).

    Each chart is rotated to its smallest bounding box, scaled to the requested texel
    density (capped at 4x its nominal side), rasterized with padding, and then placed
    in order of decreasing edge-length sum at the cheapest collision-free candidate
    position. Candidates are the deterministic skyline / edge sweeps plus `attempts`
    random positions drawn from a generator seeded with `rng_seed`. Falls back to the
    torch implementation when numba is unavailable. placements[i] describes chart i.
    """
    n = len(chart_uvs)
    if n == 0:
        return [], 1, 1
    if not _HAVE_NUMBA_PACK:
        return _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas,
                                  chart_faces, texels_per_unit, padding_texels)

    rng = np.random.default_rng(rng_seed)
    skyline = np.zeros(4096, dtype=np.int64)

    # ---- Phase 1: per-chart rotation, scale and bitmap ----
    prepared: List[_PreparedChart] = []
    charts = zip(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces)
    for i, (uvs_t, area_3d, area_uv, faces_t) in enumerate(charts):
        uvs = uvs_t.detach().cpu().numpy().astype(np.float64)
        faces = faces_t.detach().cpu().numpy()
        theta = _best_rotation(uvs)
        rotated = _rotate_xy(uvs, theta)

        safe_3d = max(area_3d, 1e-12)
        scale = math.sqrt(safe_3d / max(area_uv, 1e-12)) * texels_per_unit
        # Cap the bbox at 4x the nominal side so a degenerate chart cannot span the atlas.
        nominal_side = math.sqrt(safe_3d) * float(texels_per_unit)
        max_bbox_texels = max(8.0, 4.0 * nominal_side)
        span = rotated.max(axis=0) - rotated.min(axis=0)
        longest = float(max(span[0], span[1], 1e-12))
        if scale * longest > max_bbox_texels:
            scale = max_bbox_texels / longest

        uvs_tex = rotated * scale
        uvs_tex = uvs_tex - uvs_tex.min(axis=0)
        raster_w = int(math.ceil(uvs_tex[:, 0].max())) + padding_texels + 1
        raster_h = int(math.ceil(uvs_tex[:, 1].max())) + padding_texels + 1
        bitmap = _rasterize_chart(uvs_tex, faces, raster_w, raster_h, padding_texels)

        # Trim to the last occupied row / column; an empty raster becomes a 1x1 False bitmap.
        used_rows = np.where(bitmap.any(axis=1))[0]
        used_cols = np.where(bitmap.any(axis=0))[0]
        if used_rows.size == 0 or used_cols.size == 0:
            bitmap = np.zeros((1, 1), dtype=bool)
            trimmed_h, trimmed_w = 1, 1
        else:
            bitmap = bitmap[: used_rows[-1] + 1, : used_cols[-1] + 1]
            trimmed_h, trimmed_w = bitmap.shape
        # True 90 deg rotation; plain transpose would mirror and flip winding.
        rotated_bitmap = bitmap.T[:, ::-1].copy()

        prepared.append(
            _PreparedChart(
                chart_id=i,
                uvs_tex=uvs_tex,
                bitmap=bitmap,
                bitmap_rot=rotated_bitmap,
                bbox_w=trimmed_w,
                bbox_h=trimmed_h,
                rotation=theta,
                s_tex=scale,
                perimeter=_chart_perimeter(uvs_tex, faces),
            )
        )

    # ---- Phase 2: placement ----
    order = sorted(range(n), key=lambda i: -prepared[i].perimeter)
    total_area = sum(p.bbox_w * p.bbox_h for p in prepared)
    side_guess = int(math.sqrt(total_area) * 2) + 16
    atlas = np.zeros((side_guess, side_guess), dtype=bool)
    cur_w = 0
    cur_h = 0
    placements: List[ChartPlacement] = [None] * n  # type: ignore
    for ci in order:
        p = prepared[ci]
        step = max(1, min(p.bbox_w, p.bbox_h) // 8)
        deterministic = _build_candidates_jit(
            skyline, cur_w, cur_h,
            p.bitmap.shape[1], p.bitmap.shape[0],
            p.bitmap_rot.shape[1], p.bitmap_rot.shape[0],
            step,
        )
        # Random candidates inside the current extents; x is drawn before y and the
        # orientation alternates with the sample index.
        rand_x = rng.integers(0, max(cur_w + 1, 1), size=attempts).astype(np.int64)
        rand_y = rng.integers(0, max(cur_h + 1, 1), size=attempts).astype(np.int64)
        rand_swap = (np.arange(attempts) & 1).astype(np.int64)
        randomized = np.stack([rand_x, rand_y, rand_swap], axis=1)
        if deterministic.size:
            candidates = np.concatenate([deterministic, randomized], axis=0)
        else:
            candidates = randomized

        best_x, best_y, _score, swap_flag = _best_placement_jit(
            atlas, p.bitmap, p.bitmap_rot, candidates, cur_w, cur_h,
        )
        best_swap = bool(swap_flag)
        if best_x < 0:
            # Fallback: place at extension corner, unswapped.
            best_x, best_y = cur_w, 0
            best_swap = False
        chosen = p.bitmap_rot if best_swap else p.bitmap

        # Grow the atlas when the placement reaches past it. The atlas never shrinks
        # below its initial side_guess, so max() with the current shape is sufficient.
        need_h = max(cur_h, best_y + chosen.shape[0])
        need_w = max(cur_w, best_x + chosen.shape[1])
        if atlas.shape[0] < need_h or atlas.shape[1] < need_w:
            grown = np.zeros(
                (max(atlas.shape[0], need_h), max(atlas.shape[1], need_w)), dtype=bool
            )
            grown[: atlas.shape[0], : atlas.shape[1]] = atlas
            atlas = grown

        _blit(atlas, chosen, best_x, best_y)
        cur_w = max(cur_w, best_x + chosen.shape[1])
        cur_h = max(cur_h, best_y + chosen.shape[0])
        if cur_w + 1 > skyline.shape[0]:
            longer = np.zeros(max(skyline.shape[0] * 2, cur_w + 1), dtype=np.int64)
            longer[: skyline.shape[0]] = skyline
            skyline = longer
        _update_skyline_jit(skyline, chosen, best_x, best_y)

        placements[ci] = ChartPlacement(
            chart_id=ci,
            offset=(float(best_x), float(best_y)),
            scale=p.s_tex,
            rotation=p.rotation,
            swap_xy=best_swap,
            chart_h=float(p.bitmap.shape[0]),
        )
    return placements, cur_w, cur_h
def apply_placements(
    chart_uvs: List[Tensor], placements: List[ChartPlacement], atlas_w: int, atlas_h: int
) -> List[Tensor]:
    """Move each chart's UVs into its packed atlas slot and normalize to [0, 1].

    For every (uvs, placement) pair the chart is optionally rotated, re-anchored
    so its minimum corner sits at the origin, scaled, optionally turned by 90
    degrees (``swap_xy``), translated by the placement offset and finally divided
    by the larger atlas side. Using one divisor for both axes keeps texel density
    uniform across charts.

    Returns one tensor per chart with the device and dtype of the input tensor.
    """
    norm = float(max(atlas_w, atlas_h, 1))
    packed: List[Tensor] = []
    for chart_uv, place in zip(chart_uvs, placements):
        pts = chart_uv.detach().cpu().numpy().astype(np.float64)
        if place.rotation != 0.0:
            pts = _rotate_xy(pts, place.rotation)
        # Anchor the chart's bounding-box minimum at the origin, then scale.
        pts = (pts - pts.min(axis=0)) * place.scale
        if place.swap_xy:
            # 90 deg rotation matching bm.T[:, ::-1]: (u, v) -> (chart_h - v, u).
            pts = np.stack([place.chart_h - pts[:, 1], pts[:, 0]], axis=1)
        pts[:, 0] += place.offset[0]
        pts[:, 1] += place.offset[1]
        pts /= norm
        # Slivers can poke sub-texel past the tracked extent; clamp them in place.
        np.clip(pts, 0.0, 1.0, out=pts)
        packed.append(
            torch.from_numpy(pts).to(device=chart_uv.device, dtype=chart_uv.dtype)
        )
    return packed

View File

@ -0,0 +1,387 @@
"""Chart parameterization: ortho PCA projection, falling back to ABF/LSCM."""
from __future__ import annotations
import warnings
from typing import List, Tuple
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import torch
from torch import Tensor
from . import mesh as _mesh
def solve_least_squares(A: sp.csr_matrix, b: np.ndarray) -> np.ndarray:
    """Minimize ||Ax - b||^2 through the normal equations (A^T A) x = A^T b.

    The normal matrix is converted to CSC before being handed to the sparse
    direct solver; the right-hand side is A^T b.
    """
    transposed = A.T.tocsr()
    normal_matrix = (transposed @ A).tocsc()
    rhs = transposed @ b
    return spla.spsolve(normal_matrix, rhs)
def _triangle_local_2d(verts_3d: np.ndarray, faces: np.ndarray) -> np.ndarray:
"""Per-triangle 2D coords [F, 3, 2] with v0 at origin, v1 along +x."""
v0 = verts_3d[faces[:, 0]]
v1 = verts_3d[faces[:, 1]]
v2 = verts_3d[faces[:, 2]]
e01 = v1 - v0
e02 = v2 - v0
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
x_axis = e01 / L01[:, None]
n = np.cross(e01, e02)
n /= np.linalg.norm(n, axis=1, keepdims=True).clip(min=1e-20)
y_axis = np.cross(n, x_axis)
out = np.zeros((faces.shape[0], 3, 2), dtype=np.float64)
out[:, 1, 0] = L01
out[:, 2, 0] = (e02 * x_axis).sum(axis=1)
out[:, 2, 1] = (e02 * y_axis).sum(axis=1)
return out
def _pick_pins(loops: List[List[int]], verts_3d: np.ndarray) -> Tuple[int, int]:
"""Pick the longest-diameter axis-extremal boundary vertex pair across all boundary verts."""
if not loops:
# Closed surface: two far verts via two-pass farthest.
d2 = np.sum((verts_3d - verts_3d[0]) ** 2, axis=1)
a = int(np.argmax(d2))
d2 = np.sum((verts_3d - verts_3d[a]) ** 2, axis=1)
b = int(np.argmax(d2))
return a, b
boundary_verts: List[int] = []
for loop in loops:
boundary_verts.extend(loop)
seen = set()
uniq = []
for v in boundary_verts:
if v not in seen:
seen.add(v)
uniq.append(v)
bv = np.asarray(uniq, dtype=np.int64)
pts = verts_3d[bv]
pin_pairs = []
for axis in range(3):
i_min = int(bv[int(np.argmin(pts[:, axis]))])
i_max = int(bv[int(np.argmax(pts[:, axis]))])
d = float(np.linalg.norm(verts_3d[i_min] - verts_3d[i_max]))
pin_pairs.append((d, i_min, i_max))
d0, _, _ = pin_pairs[0]
d1, _, _ = pin_pairs[1]
d2, _, _ = pin_pairs[2]
if d0 > d1 and d0 > d2:
_, a, b = pin_pairs[0]
elif d1 > d2:
_, a, b = pin_pairs[1]
else:
_, a, b = pin_pairs[2]
return a, b
def _ortho_project(verts_3d: np.ndarray) -> np.ndarray:
"""PCA-fit plane normal, axis-aligned tangent, project verts to 2D."""
centroid = verts_3d.mean(axis=0)
pts = verts_3d - centroid
cov = pts.T @ pts
_w, ev = np.linalg.eigh(cov)
normal = ev[:, 0]
a = np.abs(normal)
if a[0] < a[1] and a[0] < a[2]:
t = np.array([1.0, 0.0, 0.0])
elif a[1] < a[2]:
t = np.array([0.0, 1.0, 0.0])
else:
t = np.array([0.0, 0.0, 1.0])
t = t - normal * float(np.dot(normal, t))
t /= max(float(np.linalg.norm(t)), 1e-20)
b = np.cross(normal, t)
return np.stack([verts_3d @ t, verts_3d @ b], axis=1)
def _stretch_metrics(verts_3d: np.ndarray, uvs: np.ndarray, faces: np.ndarray) -> Tuple[float, float, int, int]:
"""Sander's stretch metric. Returns (rms, max, n_flipped, n_zero_area)."""
p = verts_3d[faces]
t = uvs[faces]
parametric_area = 0.5 * (
(t[:, 1, 1] - t[:, 0, 1]) * (t[:, 2, 0] - t[:, 0, 0])
- (t[:, 2, 1] - t[:, 0, 1]) * (t[:, 1, 0] - t[:, 0, 0])
)
n_flipped = int((parametric_area < -1e-12).sum())
n_zero = int((np.abs(parametric_area) < 1e-12).sum())
pa = np.abs(parametric_area).clip(min=1e-20)
geom_area = 0.5 * np.linalg.norm(
np.cross(p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]), axis=1
)
keep = (geom_area > 1e-12) & (np.abs(parametric_area) > 1e-12)
if not keep.any():
return float("inf"), float("inf"), n_flipped, n_zero
t1 = t[:, 0, 0]; s1 = t[:, 0, 1]
t2 = t[:, 1, 0]; s2 = t[:, 1, 1]
t3 = t[:, 2, 0]; s3 = t[:, 2, 1]
inv_2pa = 1.0 / (2.0 * pa)
Ss = (
p[:, 0] * (t2 - t3)[:, None]
+ p[:, 1] * (t3 - t1)[:, None]
+ p[:, 2] * (t1 - t2)[:, None]
) * inv_2pa[:, None]
St = (
p[:, 0] * (s3 - s2)[:, None]
+ p[:, 1] * (s1 - s3)[:, None]
+ p[:, 2] * (s2 - s1)[:, None]
) * inv_2pa[:, None]
a = (Ss * Ss).sum(axis=1)
bb = (Ss * St).sum(axis=1)
c = (St * St).sum(axis=1)
sigma2_sq = 0.5 * (a + c + np.sqrt(np.maximum(0.0, (a - c) ** 2 + 4 * bb ** 2)))
rms_sq = (a + c) * 0.5
rms_stretch_sq_sum = float((rms_sq[keep] * geom_area[keep]).sum())
total_geom = float(geom_area[keep].sum())
total_param = float(pa[keep].sum())
if total_geom <= 0.0:
return float("inf"), float("inf"), n_flipped, n_zero
norm_factor = np.sqrt(total_param / total_geom)
rms_stretch = float(np.sqrt(rms_stretch_sq_sum / total_geom)) * norm_factor
max_stretch = float(np.sqrt(sigma2_sq[keep].max())) * norm_factor
return rms_stretch, max_stretch, n_flipped, n_zero
def _uv_boundary_self_intersects(
uvs: np.ndarray, faces: np.ndarray, face_face: np.ndarray, eps: float = 1e-9
) -> bool:
"""True if any chart-boundary edge pair crosses in 2D (ortho folded the chart)."""
fi, ei = np.nonzero(face_face < 0)
n = fi.size
if n < 2:
return False
a = uvs[faces[fi, ei]].astype(np.float64)
b = uvs[faces[fi, (ei + 1) % 3]].astype(np.float64)
d = b - a
# Pairwise segment crossings, row-chunked to bound memory at chunk*n.
chunk = max(1, min(n, 1_000_000 // max(n, 1)))
for s in range(0, n, chunk):
e = min(s + chunk, n)
d1 = d[s:e, None, :]
denom = d1[:, :, 0] * d[None, :, 1] - d1[:, :, 1] * d[None, :, 0]
rx = a[None, :, 0] - a[s:e, None, 0]
ry = a[None, :, 1] - a[s:e, None, 1]
with np.errstate(divide="ignore", invalid="ignore"):
t = (rx * d[None, :, 1] - ry * d[None, :, 0]) / denom
u = (rx * d1[:, :, 1] - ry * d1[:, :, 0]) / denom
cross = (
(np.abs(denom) >= eps)
& (t > eps) & (t < 1.0 - eps)
& (u > eps) & (u < 1.0 - eps)
)
if bool(cross.any()):
return True
return False
def parametrize_chart(
    local_verts: Tensor, local_faces: Tensor, local_face_face: Tensor
) -> Tensor:
    """Compute UVs for one chart, preferring the cheap planar projection.

    The orthographic (PCA plane) projection is returned when the chart has at
    most 5 faces, or when it is low-distortion (no zero-area faces, consistent
    orientation, rms stretch <= 1.5, max stretch <= 2.0) and its UV boundary
    does not self-intersect. Otherwise ``lscm_chart`` is used, unless its result
    collapses to an aspect ratio above 100:1, in which case the projection is
    returned after all. Charts with fewer than 3 vertices or no faces get zeros.
    """
    device = local_verts.device
    pos = local_verts.detach().cpu().numpy().astype(np.float64)
    tris = local_faces.detach().cpu().numpy().astype(np.int64)
    if pos.shape[0] < 3 or tris.shape[0] == 0:
        return torch.zeros((pos.shape[0], 2), dtype=torch.float32, device=device)

    planar = _ortho_project(pos)

    def planar_tensor() -> Tensor:
        return torch.from_numpy(planar.astype(np.float32)).to(device)

    tri_count = tris.shape[0]
    if tri_count <= 5:
        return planar_tensor()

    rms, peak, flipped, collapsed = _stretch_metrics(pos, planar, tris)
    # All faces flipped is just a mirrored chart, which is as good as none flipped.
    same_orientation = flipped in (0, tri_count)
    if same_orientation and collapsed == 0 and rms <= 1.5 and peak <= 2.0:
        adjacency = local_face_face.detach().cpu().numpy().astype(np.int64)
        if not _uv_boundary_self_intersects(planar, tris, adjacency):
            return planar_tensor()

    solved = lscm_chart(local_verts, local_faces, local_face_face, pin_positions=planar)
    # Collapsed UV island (aspect > 100:1) blows up packing scale; fall back to ortho.
    solved_np = solved.detach().cpu().numpy()
    span = solved_np.max(axis=0) - solved_np.min(axis=0)
    long_side = float(max(span[0], span[1], 1e-12))
    short_side = float(max(min(span[0], span[1]), 1e-12))
    if long_side / short_side > 100.0:
        return planar_tensor()
    return solved
def _abf_face_coefficients(
verts_3d: np.ndarray, faces: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Per-face ABF constraint (largest-sine vertex at local index 2); returns (faces_reordered, cosine, sine, valid_mask) with valid_mask False for degenerate tris."""
Fc = faces.shape[0]
p0 = verts_3d[faces[:, 0]]
p1 = verts_3d[faces[:, 1]]
p2 = verts_3d[faces[:, 2]]
e01 = p1 - p0
e12 = p2 - p1
e20 = p0 - p2
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
L12 = np.linalg.norm(e12, axis=1).clip(min=1e-20)
L20 = np.linalg.norm(e20, axis=1).clip(min=1e-20)
cos_a0 = ((-e20) * e01).sum(axis=1) / (L20 * L01)
cos_a1 = ((-e01) * e12).sum(axis=1) / (L01 * L12)
cos_a2 = ((-e12) * e20).sum(axis=1) / (L12 * L20)
cos_a0 = cos_a0.clip(-1.0, 1.0)
cos_a1 = cos_a1.clip(-1.0, 1.0)
cos_a2 = cos_a2.clip(-1.0, 1.0)
a = np.arccos(cos_a0)
b_ang = np.arccos(cos_a1)
c_ang = np.arccos(cos_a2)
angles = np.stack([a, b_ang, c_ang], axis=1)
sines = np.stack([np.sin(a), np.sin(b_ang), np.sin(c_ang)], axis=1)
valid = (angles > 1e-12).all(axis=1)
ids = faces.astype(np.int64).copy()
s0, s1, s2 = sines[:, 0], sines[:, 1], sines[:, 2]
pattA = (s1 > s0) & (s1 > s2)
pattB = (~pattA) & (s0 > s1) & (s0 > s2)
if pattA.any():
old_a = angles[pattA].copy()
old_s = sines[pattA].copy()
old_id = ids[pattA].copy()
angles[pattA] = old_a[:, [2, 0, 1]]
sines[pattA] = old_s[:, [2, 0, 1]]
ids[pattA] = old_id[:, [2, 0, 1]]
if pattB.any():
old_a = angles[pattB].copy()
old_s = sines[pattB].copy()
old_id = ids[pattB].copy()
angles[pattB] = old_a[:, [1, 2, 0]]
sines[pattB] = old_s[:, [1, 2, 0]]
ids[pattB] = old_id[:, [1, 2, 0]]
a0 = angles[:, 0]
s0 = sines[:, 0]
s1 = sines[:, 1]
s2 = sines[:, 2]
c0 = np.cos(a0)
ratio = np.where(s2 > 0.0, s1 / s2.clip(min=1e-20), 1.0)
cosine = c0 * ratio
sine = s0 * ratio
return ids, cosine, sine, valid
def lscm_chart(
    local_verts: Tensor,
    local_faces: Tensor,
    local_face_face: Tensor,
    pin_positions: "np.ndarray | None" = None,
) -> Tensor:
    """Solve a pinned least-squares conformal-style parameterization for one chart.

    Non-degenerate faces contribute angle-based rows (coefficients from
    ``_abf_face_coefficients``); degenerate faces contribute plain LSCM rows.
    Two vertices chosen by ``_pick_pins`` are fixed, at their ``pin_positions``
    coordinates when that array has shape (V, 2), otherwise at (0, 0) and (1, 0).

    If the solve raises, triggers a ``MatrixRankWarning``, or yields non-finite
    values, the function returns ``pin_positions`` (when usable) or the
    orthographic projection instead. Charts with fewer than 3 vertices or no
    faces return zeros. Output is float32 [V, 2] on ``local_verts.device``.
    """
    device = local_verts.device
    pos = local_verts.detach().cpu().numpy().astype(np.float64)
    tris = local_faces.detach().cpu().numpy().astype(np.int64)
    n_verts = pos.shape[0]
    n_tris = tris.shape[0]
    if n_verts < 3 or n_tris == 0:
        return torch.zeros((n_verts, 2), dtype=torch.float32, device=device)

    have_pin_uv = pin_positions is not None and pin_positions.shape == (n_verts, 2)

    def _fallback() -> Tensor:
        # Best available non-solved UVs: caller-supplied layout, else planar projection.
        if have_pin_uv:
            flat = pin_positions.astype(np.float32)
        else:
            flat = _ortho_project(pos).astype(np.float32)
        return torch.from_numpy(flat).to(device)

    boundary = _mesh.chart_boundary_loops(local_faces, local_face_face)
    pin_a, pin_b = _pick_pins(boundary, pos)
    if have_pin_uv:
        uv_a = pin_positions[pin_a]
        uv_b = pin_positions[pin_b]
        u_a, v_a = float(uv_a[0]), float(uv_a[1])
        u_b, v_b = float(uv_b[0]), float(uv_b[1])
    else:
        u_a, v_a = 0.0, 0.0
        u_b, v_b = 1.0, 0.0

    abf_ids, abf_cos, abf_sin, abf_valid = _abf_face_coefficients(pos, tris)

    # COO triplets; unknowns are laid out as [u_0..u_{V-1}, v_0..v_{V-1}] and
    # face f owns rows 2f (real part) and 2f + 1 (imaginary part).
    rows_list: List[np.ndarray] = []
    cols_list: List[np.ndarray] = []
    vals_list: List[np.ndarray] = []

    def _emit(row: np.ndarray, col: np.ndarray, val: np.ndarray) -> None:
        rows_list.append(row)
        cols_list.append(col)
        vals_list.append(val)

    good = np.nonzero(abf_valid)[0]
    if good.size:
        i0 = abf_ids[good, 0]
        i1 = abf_ids[good, 1]
        i2 = abf_ids[good, 2]
        cos_f = abf_cos[good]
        sin_f = abf_sin[good]
        unit = np.ones(good.size, dtype=np.float64)
        real_row = good * 2
        imag_row = good * 2 + 1
        real_terms = (
            (i0, cos_f - 1.0), (i0 + n_verts, -sin_f),
            (i1, -cos_f), (i1 + n_verts, sin_f),
            (i2, unit),
        )
        imag_terms = (
            (i0, sin_f), (i0 + n_verts, cos_f - 1.0),
            (i1, -sin_f), (i1 + n_verts, -cos_f),
            (i2 + n_verts, unit),
        )
        for col, val in real_terms:
            _emit(real_row, col, val)
        for col, val in imag_terms:
            _emit(imag_row, col, val)

    bad = np.nonzero(~abf_valid)[0]
    if bad.size:
        flat_tri = _triangle_local_2d(pos, tris[bad])
        twice_area = (
            flat_tri[:, 1, 0] * flat_tri[:, 2, 1]
            - flat_tri[:, 1, 1] * flat_tri[:, 2, 0]
        )
        weight = 1.0 / np.sqrt(2.0 * np.abs(twice_area).clip(min=1e-20))
        real_row_bad = bad * 2
        imag_row_bad = bad * 2 + 1
        for corner in range(3):
            nxt = (corner + 1) % 3
            prv = (corner + 2) % 3
            coef_a = (flat_tri[:, nxt, 0] - flat_tri[:, prv, 0]) * weight
            coef_b = (flat_tri[:, nxt, 1] - flat_tri[:, prv, 1]) * weight
            vid = tris[bad, corner]
            _emit(real_row_bad, vid, coef_a)
            _emit(real_row_bad, vid + n_verts, -coef_b)
            _emit(imag_row_bad, vid, coef_b)
            _emit(imag_row_bad, vid + n_verts, coef_a)

    rows = np.concatenate(rows_list) if rows_list else np.empty(0, dtype=np.int64)
    cols = np.concatenate(cols_list) if cols_list else np.empty(0, dtype=np.int64)
    vals = np.concatenate(vals_list) if vals_list else np.empty(0, dtype=np.float64)
    system = sp.csr_matrix((vals, (rows, cols)), shape=(2 * n_tris, 2 * n_verts))

    # Move the pinned columns to the right-hand side.
    pin_cols = np.array([pin_a, pin_b, pin_a + n_verts, pin_b + n_verts], dtype=np.int64)
    pin_vals = np.array([u_a, u_b, v_a, v_b], dtype=np.float64)
    is_free = np.ones(2 * n_verts, dtype=bool)
    is_free[pin_cols] = False
    free_cols = np.nonzero(is_free)[0]
    rhs = -(system[:, pin_cols] @ pin_vals)
    free_system = system[:, free_cols]

    # Singular system (under-constrained chart) falls back to ortho.
    solution = None
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("error", category=sp.linalg.MatrixRankWarning)
            candidate = solve_least_squares(free_system, rhs)
        if np.all(np.isfinite(candidate)):
            solution = candidate
    except Exception:
        solution = None
    if solution is None:
        return _fallback()

    unknowns = np.zeros(2 * n_verts, dtype=np.float64)
    unknowns[free_cols] = solution
    unknowns[pin_cols] = pin_vals
    uv = np.stack([unknowns[:n_verts], unknowns[n_verts:]], axis=1).astype(np.float32)
    if not np.all(np.isfinite(uv)):
        # The float32 cast can still overflow; treat that like a failed solve.
        return _fallback()
    return torch.from_numpy(uv).to(device)

View File

@ -0,0 +1,638 @@
"""Adaptive cost-grow chart segmentation (CPU); numba optional, numpy path is nd-only."""
from __future__ import annotations
from typing import List, Tuple
import numpy as np
import torch
from torch import Tensor
# numba is optional: when it is missing, ``njit`` becomes a no-op decorator so
# the ``@njit``-decorated kernels below still import and run as plain Python.
try:
    from numba import njit
    _HAVE_NUMBA = True
except ImportError:
    _HAVE_NUMBA = False
    def njit(*args, **kwargs):  # noqa: ARG001
        """Stand-in for ``numba.njit`` that returns the function unchanged.

        Supports both ``@njit`` (the function arrives as ``args[0]``) and
        ``@njit(cache=True, ...)`` (keyword-only call that must return a decorator).
        """
        def deco(fn):
            return fn
        return deco if not args else args[0]
from .mesh import MeshData, face_edge_lengths
# Default cost weights / limits used by ``segment_charts``.
# Weight of the normal-deviation term (1 - dot(face normal, chart normal)).
DEFAULT_W_NORMAL_DEVIATION = 2.0
# Weight of the roundness (perimeter^2 / area change) term.
DEFAULT_W_ROUNDNESS = 0.01
# Weight of the straightness (boundary-length change) term.
DEFAULT_W_STRAIGHTNESS = 6.0
# Default upper bound on the combined cost passed by ``segment_charts``.
DEFAULT_MAX_COST = 2.0
# A chart is never considered for a face whose normal deviation reaches this value.
NORMAL_DEVIATION_HARD_CUTOFF = 0.707  # ~75°
@njit(cache=True, fastmath=False)
def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
    """Per-face curvature proxy: sum over existing neighbors of (1 - n_f . n_nb).

    ``face_face[f, e]`` is the neighbor across edge ``e`` of face ``f``; negative
    entries mean "no neighbor" and contribute nothing. Returns a float32 array
    of length F.
    """
    F = face_normal.shape[0]
    raw = np.zeros(F, dtype=np.float32)
    for f in range(F):
        nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
        s = np.float32(0.0)
        for e in range(3):
            nb = face_face[f, e]
            if nb < 0:
                # Boundary edge: no neighbor to compare against.
                continue
            mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2]
            d = nx*mx + ny*my + nz*mz
            s += np.float32(1.0) - d
        raw[f] = s
    return raw
def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
nb_safe = np.maximum(face_face, 0)
nb_normal = face_normal[nb_safe]
d = (face_normal[:, None, :] * nb_normal).sum(axis=-1)
contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0))
return contrib.sum(axis=1).astype(np.float32)
@njit(cache=True, fastmath=False)
def _farthest_point_seeds_jit(
    face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray,
    initial_seeds: np.ndarray, k_target: int,
):
    """Weighted farthest-point sampling of seed faces.

    Starts from the non-negative entries of ``initial_seeds`` (up to
    ``k_target``), then repeatedly adds the face maximizing
    ``min_squared_distance_to_seeds * face_weight`` until ``k_target`` seeds
    exist or no face has a strictly positive score. Requiring a positive score
    stops the loop once every face coincides with an existing seed centroid,
    instead of appending an already chosen face again; this matches the
    ``min_dist <= 0`` stop in ``_farthest_point_seeds_numpy``.

    ``face_area`` is accepted for signature compatibility but not read.
    Returns an int64 array with the selected face indices (length <= k_target).
    """
    F = face_centroid.shape[0]
    INF = np.float32(1e30)
    min_dist = np.full(F, INF, dtype=np.float32)
    seeds = np.empty(k_target, dtype=np.int64)
    n_seeds = 0
    for i in range(initial_seeds.shape[0]):
        s = initial_seeds[i]
        if s < 0 or n_seeds >= k_target:
            continue
        seeds[n_seeds] = s
        n_seeds += 1
        sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2]
        for f in range(F):
            dx = face_centroid[f, 0] - sx
            dy = face_centroid[f, 1] - sy
            dz = face_centroid[f, 2] - sz
            d2 = dx*dx + dy*dy + dz*dz
            if d2 < min_dist[f]:
                min_dist[f] = d2
    while n_seeds < k_target:
        best_f = -1
        # Start at 0 (not -1): a zero score means the face sits on an existing
        # seed centroid, so picking it would only duplicate a seed.
        best_score = np.float32(0.0)
        for f in range(F):
            d = min_dist[f]
            if d >= INF * np.float32(0.5):
                # Distance never initialized (no seed placed yet): skip.
                continue
            score = d * face_weight[f]
            if score > best_score:
                best_score = score
                best_f = f
        if best_f < 0:
            break
        seeds[n_seeds] = best_f
        n_seeds += 1
        sx = face_centroid[best_f, 0]
        sy = face_centroid[best_f, 1]
        sz = face_centroid[best_f, 2]
        for f in range(F):
            dx = face_centroid[f, 0] - sx
            dy = face_centroid[f, 1] - sy
            dz = face_centroid[f, 2] - sz
            d2 = dx*dx + dy*dy + dz*dz
            if d2 < min_dist[f]:
                min_dist[f] = d2
    return seeds[:n_seeds]
def _farthest_point_seeds_numpy(
face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
):
F = face_centroid.shape[0]
min_dist = np.full(F, np.inf, dtype=np.float32)
seeds: List[int] = []
for s in initial_seeds:
if s < 0 or len(seeds) >= k_target:
continue
seeds.append(int(s))
d = ((face_centroid - face_centroid[s])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
while len(seeds) < k_target:
best = int(np.argmax(min_dist))
if not np.isfinite(min_dist[best]) or min_dist[best] <= 0:
break
seeds.append(best)
d = ((face_centroid - face_centroid[best])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
return np.asarray(seeds, dtype=np.int64)
@njit(cache=True, fastmath=False)
def _cost_grow_iter_jit(
    face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
    face_area: np.ndarray, face_edge_len: np.ndarray,
    chart_basis: np.ndarray, chart_normal_sum: np.ndarray,
    chart_area: np.ndarray, chart_perim: np.ndarray,
    nd_cutoff: float, max_cost: float,
    w_nd: float, w_round: float, w_straight: float,
):
    """One grow iteration: each unassigned face joins its lowest-cost adjacent chart if cost <= max_cost.

    Pass 1 scores, for every unassigned face (``face_chart == -1``), each chart
    owning one of its edge neighbors and remembers the cheapest. Pass 2 walks
    the faces in index order and commits the remembered choice, updating
    ``face_chart``, ``chart_normal_sum``, ``chart_area``, ``chart_perim`` and
    ``chart_basis`` in place.

    cost = w_nd * nd + w_round * round_cost + w_straight * straight_cost, where
    nd is 1 - dot(face normal, chart basis) clamped to [0, 1]. Charts with
    nd >= nd_cutoff are never candidates.

    Returns the number of faces assigned in this iteration.
    """
    F = face_chart.shape[0]
    best_chart_per_face = np.full(F, -1, dtype=np.int64)
    best_cost_per_face = np.full(F, np.inf, dtype=np.float32)
    # ---- Pass 1: evaluate candidates against the chart state at entry. ----
    for f in range(F):
        if face_chart[f] != -1:
            continue
        nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
        af = face_area[f]
        for e0 in range(3):
            nb0 = face_face[f, e0]
            if nb0 < 0:
                continue
            c = face_chart[nb0]
            if c < 0:
                # Neighbor is itself unassigned: nothing to join through this edge.
                continue
            d = (nx * chart_basis[c, 0] + ny * chart_basis[c, 1] + nz * chart_basis[c, 2])
            nd = np.float32(1.0) - d
            if nd > np.float32(1.0):
                nd = np.float32(1.0)
            if nd < np.float32(0.0):
                nd = np.float32(0.0)
            if nd >= nd_cutoff:
                continue
            # l_in: edge length shared with chart c; l_out: everything else
            # (mesh boundary, other charts, unassigned neighbors).
            l_in = np.float32(0.0)
            l_out = np.float32(0.0)
            for e1 in range(3):
                nb1 = face_face[f, e1]
                el = face_edge_len[f, e1]
                if nb1 < 0:
                    l_out += el
                elif face_chart[nb1] == c:
                    l_in += el
                else:
                    l_out += el
            ca = chart_area[c]
            cp = chart_perim[c]
            new_perim = cp - l_in + l_out
            new_area = ca + af
            if cp <= np.float32(1e-20) or ca <= np.float32(1e-20):
                round_cost = np.float32(0.0)
            else:
                # Relative change of perimeter^2 / area caused by adding the face.
                old_r = (cp * cp) / ca
                new_r = (new_perim * new_perim) / new_area
                if new_r <= np.float32(1e-20):
                    round_cost = np.float32(0.0)
                else:
                    round_cost = np.float32(1.0) - old_r / new_r
            denom = l_out + l_in
            if denom <= np.float32(1e-20):
                straight_cost = np.float32(0.0)
            else:
                ratio = (l_out - l_in) / denom
                # Only a negative ratio (more shared than exposed edge length)
                # is kept; it lowers the cost, i.e. acts as a reward.
                if ratio < np.float32(0.0):
                    straight_cost = ratio
                else:
                    straight_cost = np.float32(0.0)
            cost = (w_nd * nd + w_round * round_cost + w_straight * straight_cost)
            if cost < best_cost_per_face[f]:
                best_cost_per_face[f] = cost
                best_chart_per_face[f] = c
    # ---- Pass 2: commit choices sequentially and refresh chart statistics. ----
    n_assigned = 0
    for f in range(F):
        if face_chart[f] != -1:
            continue
        if best_chart_per_face[f] < 0:
            continue
        if best_cost_per_face[f] > max_cost:
            continue
        c = best_chart_per_face[f]
        # Recompute l_in / l_out: faces committed earlier in this pass may have
        # changed which neighbors belong to chart c.
        l_in = np.float32(0.0)
        l_out = np.float32(0.0)
        for e1 in range(3):
            nb1 = face_face[f, e1]
            el = face_edge_len[f, e1]
            if nb1 < 0:
                l_out += el
            elif face_chart[nb1] == c:
                l_in += el
            else:
                l_out += el
        af = face_area[f]
        face_chart[f] = c
        chart_normal_sum[c, 0] += face_normal[f, 0] * af
        chart_normal_sum[c, 1] += face_normal[f, 1] * af
        chart_normal_sum[c, 2] += face_normal[f, 2] * af
        chart_area[c] += af
        chart_perim[c] = chart_perim[c] - l_in + l_out
        # Chart basis = normalized area-weighted normal sum (kept if degenerate).
        nx = chart_normal_sum[c, 0]
        ny = chart_normal_sum[c, 1]
        nz = chart_normal_sum[c, 2]
        nlen = np.sqrt(nx * nx + ny * ny + nz * nz)
        if nlen > np.float32(1e-20):
            chart_basis[c, 0] = nx / nlen
            chart_basis[c, 1] = ny / nlen
            chart_basis[c, 2] = nz / nlen
        n_assigned += 1
    return n_assigned
def _renumber(face_chart: np.ndarray, device) -> Tensor:
unique = np.unique(face_chart[face_chart >= 0])
if unique.size == 0:
return torch.from_numpy(face_chart).to(device)
remap = np.full(int(unique.max()) + 1, -1, dtype=np.int64)
remap[unique] = np.arange(unique.size)
out = face_chart.copy()
mask = out >= 0
out[mask] = remap[out[mask]]
return torch.from_numpy(out).to(device)
def _segment_charts_fast(
    mesh: MeshData,
    max_cost: float,
    w_normal_deviation: float,
    w_roundness: float = DEFAULT_W_ROUNDNESS,
    w_straightness: float = DEFAULT_W_STRAIGHTNESS,
    target_chart_count: int = 0,
) -> Tensor:
    """Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds.

    Steps:
      1. Seed charts: one face per connected component (adaptive mode), or
         farthest-point-sampled seeds (fixed-count mode).
      2. Grow charts. With numba, ``_cost_grow_iter_jit`` is run under a rising
         cost threshold and, in adaptive mode, a new seed is planted on the
         unassigned face farthest from all seeds whenever growth stalls.
         Without numba, a normal-deviation-only vectorized grow is used.
      3. Orphan cleanup: leftover faces join the neighbor chart with the
         closest normal; faces with no assigned neighbor become singletons.

    Returns a long tensor [F] of compact chart ids on ``mesh.faces.device``.
    """
    F = mesh.faces.shape[0]
    device = mesh.faces.device
    if F == 0:
        return torch.zeros(0, dtype=torch.long, device=device)
    face_normal = mesh.face_normal.detach().cpu().numpy().astype(np.float32)
    face_area = mesh.face_area.detach().cpu().numpy().astype(np.float32)
    face_centroid = mesh.face_centroid.detach().cpu().numpy().astype(np.float32)
    face_face = mesh.face_face.detach().cpu().numpy()
    face_chart = np.full(F, -1, dtype=np.int64)
    nd_cutoff = np.float32(NORMAL_DEVIATION_HARD_CUTOFF)
    # Acceptance threshold for the numpy fallback only (nd-only cost), kept
    # just below the hard cutoff.
    nd_threshold = np.float32(min(max_cost / max(w_normal_deviation, 1e-6),
                                  NORMAL_DEVIATION_HARD_CUTOFF * 0.99))
    component = (mesh.component.detach().cpu().numpy()
                 if hasattr(mesh.component, "detach") else np.asarray(mesh.component))
    if component.size:
        # First face index of every distinct component label.
        _, first_idx = np.unique(component, return_index=True)
        initial_seeds = first_idx.astype(np.int64)
    else:
        initial_seeds = np.empty(0, dtype=np.int64)
    adaptive_seeding = target_chart_count <= 0
    if adaptive_seeding:
        seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
        if not seed_faces:
            seed_faces = [0]
    else:
        if _HAVE_NUMBA:
            curvature_raw = _face_curvature_jit(face_normal, face_face)
        else:
            curvature_raw = _face_curvature_numpy(face_normal, face_face)
        cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0
        if cmax > 1e-6:
            # Weight in [1, 51]: high-curvature faces are preferred as seeds.
            face_weight = (np.float32(1.0) + np.float32(50.0) *
                           (curvature_raw / np.float32(cmax))).astype(np.float32)
        else:
            face_weight = np.ones(F, dtype=np.float32)
        n_comp = int(initial_seeds.size)
        if n_comp < int(target_chart_count):
            target_seeds = int(target_chart_count)
        else:
            # Already more components than requested charts: add a few extra seeds.
            target_seeds = n_comp + max(int(target_chart_count) // 4, 8)
        target_seeds = min(target_seeds, F)
        if _HAVE_NUMBA:
            seeds_arr = _farthest_point_seeds_jit(
                face_centroid, face_area, face_weight, initial_seeds, target_seeds,
            )
        else:
            # NOTE(review): the numpy sampler takes no weights, so curvature is
            # ignored on this path.
            seeds_arr = _farthest_point_seeds_numpy(
                face_centroid, initial_seeds, target_seeds,
            )
        seed_faces = [int(s) for s in seeds_arr.tolist()]
    K = len(seed_faces)
    chart_basis = np.zeros((K, 3), dtype=np.float32)
    chart_normal_sum = np.zeros((K, 3), dtype=np.float32)
    chart_area = np.zeros(K, dtype=np.float32)
    chart_perim = np.zeros(K, dtype=np.float32)
    face_edge_len = (
        face_edge_lengths(mesh.vertices, mesh.faces)
        .detach().cpu().numpy()
    )
    # Initialize every chart from its single seed face.
    for cid, sf in enumerate(seed_faces):
        face_chart[sf] = cid
        n = face_normal[sf]
        a = face_area[sf]
        chart_basis[cid] = n.astype(np.float32)
        chart_normal_sum[cid] = (n * a).astype(np.float32)
        chart_area[cid] = float(a)
        chart_perim[cid] = float(face_edge_len[sf].sum())
    if K == 0:
        return _renumber(face_chart, device)
    # Squared distance from each face centroid to its nearest seed; only
    # maintained (and read) in adaptive mode to place additional seeds.
    min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
    if adaptive_seeding:
        for sf in seed_faces:
            d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
            min_dist_to_seed = np.minimum(min_dist_to_seed, d)
    if _HAVE_NUMBA:
        # Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
        tau_final = min(max_cost * 0.25, 0.5)
        thresholds = [t for t in (0.05, 0.1, 0.25) if t < tau_final] + [tau_final]
        max_inner = max(64, int(np.sqrt(F)) * 2)
        max_total_charts = max(F, 8000)
        outer_iter = 0
        while True:
            outer_iter += 1
            # Safety bound: at most one new seed is added per outer iteration.
            if outer_iter > F + 16:
                break
            for tau in thresholds:
                for _ in range(max_inner):
                    n_added = _cost_grow_iter_jit(
                        face_chart, face_face, face_normal, face_area, face_edge_len,
                        chart_basis, chart_normal_sum, chart_area, chart_perim,
                        nd_cutoff, np.float32(tau),
                        np.float32(w_normal_deviation),
                        np.float32(w_roundness),
                        np.float32(w_straightness),
                    )
                    if n_added == 0:
                        break
            if (face_chart == -1).sum() == 0:
                break
            if not adaptive_seeding:
                # Fixed seed count: leftovers are handled by the orphan cleanup.
                break
            if chart_basis.shape[0] >= max_total_charts:
                break
            # Plant a new seed on the unassigned face farthest from every seed.
            unassigned_mask = face_chart == -1
            cand = np.where(unassigned_mask, min_dist_to_seed, np.float32(-np.inf))
            new_seed = int(np.argmax(cand))
            n = face_normal[new_seed]
            a = face_area[new_seed]
            chart_basis = np.vstack([chart_basis, n[None, :].astype(np.float32)])
            chart_normal_sum = np.vstack(
                [chart_normal_sum, (n * a)[None, :].astype(np.float32)]
            )
            chart_area = np.concatenate([chart_area, np.array([a], dtype=np.float32)])
            chart_perim = np.concatenate(
                [chart_perim, np.array([face_edge_len[new_seed].sum()], dtype=np.float32)]
            )
            face_chart[new_seed] = chart_basis.shape[0] - 1
            new_d = ((face_centroid - face_centroid[new_seed]) ** 2).sum(axis=-1)
            min_dist_to_seed = np.minimum(min_dist_to_seed, new_d)
    else:
        # Numpy fallback: nd-only adaptive grow.
        for _ in range(max(64, int(np.sqrt(F)) + 32)):
            unassigned = face_chart == -1
            if not unassigned.any():
                break
            u_idx = np.nonzero(unassigned)[0]
            nbs = face_face[u_idx]
            nbs_safe = np.where(nbs >= 0, nbs, 0)
            nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1)
            valid = (nb_charts >= 0)
            if not valid.any():
                break
            nb_charts_safe = np.where(valid, nb_charts, 0)
            nb_basis = chart_basis[nb_charts_safe]
            d = (face_normal[u_idx][:, None, :] * nb_basis).sum(axis=-1)
            # clip(max=1.0) also turns the inf placeholders into 1.0; the cutoff
            # test on the next line maps them back to inf (1.0 >= nd_cutoff).
            nd = np.where(valid, np.float32(1.0) - d, np.inf).clip(max=1.0)
            nd = np.where(nd >= nd_cutoff, np.inf, nd)
            best_e = np.argmin(nd, axis=1)
            best_cost = nd[np.arange(u_idx.size), best_e]
            best_c = nb_charts_safe[np.arange(u_idx.size), best_e]
            accept = (best_cost <= nd_threshold) & np.isfinite(best_cost)
            if not accept.any():
                break
            pick_u = u_idx[accept]
            pick_c = best_c[accept]
            face_chart[pick_u] = pick_c
            # NOTE(review): only the sums are accumulated here; chart_basis is
            # not refreshed on this path, so it stays at the seed normal.
            for f, c in zip(pick_u, pick_c):
                chart_normal_sum[c] += face_normal[f] * face_area[f]
                chart_area[c] += face_area[f]
    # Orphan cleanup: leftover faces join their best-matching neighbor's chart.
    if (face_chart == -1).any() and chart_basis.shape[0] > 0:
        while True:
            orphans = np.nonzero(face_chart == -1)[0]
            if orphans.size == 0:
                break
            nbs = face_face[orphans]
            nbs_safe = np.where(nbs >= 0, nbs, 0)
            nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1)
            valid = (nb_charts >= 0)
            if not valid.any():
                # Remaining orphans have no assigned neighbor at all.
                break
            nb_charts_safe = np.where(valid, nb_charts, 0)
            nb_basis = chart_basis[nb_charts_safe]
            d = (face_normal[orphans][:, None, :] * nb_basis).sum(axis=-1)
            nd = np.where(valid, np.float32(1.0) - d, np.inf)
            best_e = np.argmin(nd, axis=1)
            best_c = nb_charts_safe[np.arange(orphans.size), best_e]
            assignable = valid.any(axis=1)
            if not assignable.any():
                break
            assign_idx = orphans[assignable]
            assign_c = best_c[assignable]
            face_chart[assign_idx] = assign_c
    # Anything still unassigned becomes its own single-face chart.
    if (face_chart == -1).any():
        new_singletons = np.nonzero(face_chart == -1)[0]
        for f in new_singletons:
            face_chart[int(f)] = chart_basis.shape[0]
            chart_basis = np.concatenate(
                [chart_basis, face_normal[int(f)].astype(np.float32)[None, :]],
                axis=0,
            )
    return _renumber(face_chart, device)
def segment_charts(
    mesh: MeshData,
    max_cost: float = DEFAULT_MAX_COST,
    w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
    w_roundness: float = DEFAULT_W_ROUNDNESS,
    w_straightness: float = DEFAULT_W_STRAIGHTNESS,
    target_chart_count: int = 0,
) -> Tensor:
    """Public entry point for chart segmentation.

    Forwards every argument unchanged to ``_segment_charts_fast`` and returns
    its result: a per-face chart id tensor.
    """
    face_to_chart = _segment_charts_fast(
        mesh,
        max_cost,
        w_normal_deviation,
        w_roundness,
        w_straightness,
        target_chart_count,
    )
    return face_to_chart
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
def _combine_normal_cones(
axis_a: Tensor, half_a: Tensor,
axis_b: Tensor, half_b: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Merge two normal cones along the great circle from axis_a; returns (combined_axis, combined_half_angle, axis_angle)."""
cos_angle = (axis_a * axis_b).sum(dim=-1).clamp(-1.0, 1.0)
axis_angle = torch.acos(cos_angle)
new_low = torch.minimum(-half_a, axis_angle - half_b)
new_high = torch.maximum(half_a, axis_angle + half_b)
new_half = (new_high - new_low) * 0.5
rot_angle = (new_high + new_low) * 0.5
b_perp = axis_b - axis_a * cos_angle.unsqueeze(-1)
b_perp_norm = b_perp.norm(dim=-1, keepdim=True).clamp_min(1e-12)
b_perp_unit = b_perp / b_perp_norm
new_axis = (
axis_a * torch.cos(rot_angle).unsqueeze(-1)
+ b_perp_unit * torch.sin(rot_angle).unsqueeze(-1)
)
new_axis_norm = new_axis.norm(dim=-1, keepdim=True).clamp_min(1e-12)
new_axis = new_axis / new_axis_norm
return new_axis, new_half, axis_angle
def _build_chart_edges(
face_face: Tensor,
chart_id: Tensor,
face_edge_len: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Build chart-edge list (chart_pairs[E,2] with a<b, edge_length[E]); same-chart edges dropped, duplicates summed."""
F = face_face.shape[0]
device = face_face.device
f_idx = torch.arange(F, device=device).repeat_interleave(3)
nb = face_face.flatten()
valid = nb >= 0
f_idx = f_idx[valid]
nb = nb[valid]
el = face_edge_len.flatten()[valid]
ca = chart_id[f_idx]
cb = chart_id[nb]
diff = ca != cb
ca = ca[diff]
cb = cb[diff]
el = el[diff]
if ca.numel() == 0:
return (
torch.empty((0, 2), dtype=torch.long, device=device),
torch.empty(0, device=device),
)
lo = torch.minimum(ca, cb)
hi = torch.maximum(ca, cb)
V = int(chart_id.max().item()) + 1
key = lo * V + hi
sort_idx = torch.argsort(key)
sorted_key = key[sort_idx]
sorted_lo = lo[sort_idx]
sorted_hi = hi[sort_idx]
sorted_el = el[sort_idx]
unique_key, inverse, counts = torch.unique(
sorted_key, return_inverse=True, return_counts=True
)
n_unique = unique_key.shape[0]
reduced_el = torch.zeros(n_unique, device=device, dtype=el.dtype)
reduced_el.scatter_add_(0, inverse, sorted_el)
first_idx = torch.cat([
torch.zeros(1, dtype=torch.long, device=device),
counts.cumsum(0)[:-1],
])
pair_lo = sorted_lo[first_idx]
pair_hi = sorted_hi[first_idx]
chart_pairs = torch.stack([pair_lo, pair_hi], dim=1)
return chart_pairs, reduced_el
def cluster_charts_pec(
    mesh: MeshData,
    target_chart_count: int = 0,
    max_cost: float = 0.7,
    area_penalty_weight: float = 0.0,
    roundness_weight: float = 0.0,
    max_iters: int = 1024,
) -> Tensor:
    """Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg).

    Every face starts as its own chart with a zero-width normal cone. Each
    iteration scores all chart-to-chart edges by the half-angle of the merged
    cone (plus optional area / roundness penalties) and merges every edge that
    is the cheapest edge of *both* its charts and costs at most ``max_cost``.
    Iteration stops when no edge qualifies, no chart edges remain, or
    ``max_iters`` is reached. Chart ids are compacted before returning.

    NOTE(review): ``target_chart_count`` is not read anywhere in this function,
    so it currently has no effect on the result.
    """
    device = mesh.faces.device
    F = mesh.faces.shape[0]
    faces = mesh.faces.to(torch.long)
    vertices = mesh.vertices.to(torch.float32)
    face_normal = mesh.face_normal.to(torch.float32)
    face_area = mesh.face_area.to(torch.float32)
    face_face = mesh.face_face.to(torch.long)
    face_edge_len = face_edge_lengths(vertices, faces)
    # Per-chart state, indexed by chart id (initially one chart per face).
    chart_id = torch.arange(F, dtype=torch.long, device=device)
    chart_axis = face_normal.clone()
    chart_half = torch.zeros(F, dtype=torch.float32, device=device)
    chart_area = face_area.clone()
    chart_perim = face_edge_len.sum(dim=1).clone()
    for it in range(max_iters):  # `it` itself is unused
        edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len)
        if edges.shape[0] == 0:
            break
        a = edges[:, 0]
        b = edges[:, 1]
        axis_a = chart_axis[a]
        axis_b = chart_axis[b]
        half_a = chart_half[a]
        half_b = chart_half[b]
        # Base cost: half-angle of the cone enclosing both charts' normal cones.
        _, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
        cost = new_half.clone()
        if area_penalty_weight > 0.0:
            new_area = chart_area[a] + chart_area[b]
            cost = cost + area_penalty_weight * new_area
        if roundness_weight > 0.0:
            new_area_r = chart_area[a] + chart_area[b]
            new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
            cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)
        # Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
        E = edges.shape[0]
        N = int(chart_id.max().item()) + 1
        edge_ids = torch.arange(E, dtype=torch.long, device=device)
        # Fixed-point cost (1e-6 resolution) in the high 32 bits, edge id in the
        # low 32 bits; despite the name this tensor is int64.
        cost_i32 = torch.clamp(cost * 1e6, max=2e9).to(torch.int64)
        key = (cost_i32 << 32) | edge_ids
        chart_min = torch.full((N,), (2**62), dtype=torch.long, device=device)
        chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
        chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
        # Mutual-min collapse: each chart in at most one merge per iter.
        is_a_min = chart_min[a] == key
        is_b_min = chart_min[b] == key
        within = cost <= max_cost
        winners = is_a_min & is_b_min & within
        n_merge = int(winners.sum().item())
        if n_merge == 0:
            break
        win_a = a[winners]
        win_b = b[winners]
        win_el = edge_len[winners]
        axis_a_w = chart_axis[win_a]
        half_a_w = chart_half[win_a]
        axis_b_w = chart_axis[win_b]
        half_b_w = chart_half[win_b]
        new_axis, new_half_w, _ = _combine_normal_cones(
            axis_a_w, half_a_w, axis_b_w, half_b_w,
        )
        # Chart a (the lower id of the pair) absorbs chart b.
        chart_axis[win_a] = new_axis
        chart_half[win_a] = new_half_w
        chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
        chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el
        remap = torch.arange(N, dtype=torch.long, device=device)
        remap[win_b] = win_a
        chart_id = remap[chart_id]
    # Compact surviving chart ids to 0..K-1.
    _, inverse = torch.unique(chart_id, sorted=True, return_inverse=True)
    return inverse

View File

@ -6,9 +6,20 @@ from comfy_api.latest import ComfyExtension, IO, Types
import copy
import comfy.utils
import comfy.model_management
from comfy_extras.qem_decimate.qem_core import simplify as qem_decimate_simplify, QEMConfig
from server import PromptServer
from comfy_extras.mesh3d.postprocess.qem_decimate import (
simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate,
)
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
from comfy_extras.mesh3d.uv_unwrap import parameterize as _uv_param
from comfy_extras.mesh3d.uv_unwrap import pack as _uv_pack
import warnings
import logging
import scipy
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
def get_mesh_batch_item(mesh, index):
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
@ -2162,6 +2173,566 @@ class DecimateMesh(IO.ComfyNode):
return result
class RemeshMesh(IO.ComfyNode):
    # Node wrapper around remesh_narrow_band_dc: moves each mesh of the batch to the
    # compute device, optionally cluster-decimates it, remeshes it, and moves the
    # result back to the device the mesh came from.

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs/outputs, including the mode-dependent sub-widgets."""
        # sign_mode picks the scalar field, and exposes only the knobs relevant to it
        # (DynamicCombo: udf sub-widgets show for 'udf', sdf sub-widgets for 'sdf').
        sign_mode_options = [
            IO.DynamicCombo.Option(key="udf", inputs=[
                IO.Boolean.Input("qef", default=False,
                    tooltip="Experimental: place dual vertices via QEF (closest-triangle normals) "
                            "instead of edge-crossing centroid. QEF is sign-agnostic so it works "
                            "in UDF too — pulls the ±eps surface back onto the planes for sharper "
                            "edges. May misbehave near the UDF double shell; compare with it off."),
                IO.Boolean.Input("drop_inverted_components", default=True,
                    tooltip="Drop closed components with inward normals (negative signed volume) — "
                            "the inner shell UDF produces on closed regions."),
                IO.Boolean.Input("drop_enclosed_components", default=True,
                    tooltip="Drop components whose bbox is inside the largest's AND fail a raycast "
                            "point-in-mesh test. Disable if you have legitimate parts inside others."),
            ]),
            IO.DynamicCombo.Option(key="sdf", inputs=[
                IO.Boolean.Input("qef", default=True,
                    tooltip="Place dual vertices via QEF solve from closest-triangle normals "
                            "(recovers sharp features) vs edge-crossing centroid."),
                IO.Boolean.Input("manifold", default=False,
                    tooltip="Manifold Dual Contouring: emit 1-4 dual verts per voxel for "
                            "multi-sheet (thin/touching) cases. Slower; guarantees manifold output."),
            ]),
        ]
        return IO.Schema(
            node_id="RemeshMesh",
            display_name="Remesh Mesh (Narrow-Band DC)",
            category="latent/3d",
            description=(
                "Re-extracts a uniformly tessellated mesh by sampling a distance field on a "
                "narrow-band voxel grid and contouring it with Dual Contouring, on the active "
                "compute device. Normalizes topology of messy / non-manifold / self-intersecting "
                "input; run before DecimateMesh to hit an exact face count. Output stays welded."
            ),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Int.Input("target_faces", default=0, min=0, max=50_000_000,
                    tooltip="0 = use 'resolution'. >0 = auto-pick resolution to roughly hit this "
                            "face count (±30-50%); usually overshoot then DecimateMesh to exact."),
                IO.Int.Input("resolution", default=256, min=32, max=1024,
                    tooltip="Voxel grid resolution (used when target_faces=0). Higher = more detail, "
                            "slower. 256 ~ 100k faces, 512 ~ 1M."),
                IO.DynamicCombo.Input("sign_mode", options=sign_mode_options, display_name="sign_mode",
                    tooltip="udf: robust to messy/non-manifold input (double shell cleaned by "
                            "the inner-shell filters). sdf: clean single surface with optional "
                            "QEF sharp-feature recovery, but needs consistent winding."),
                IO.Float.Input("band", default=1.0, min=0.5, max=4.0, step=0.1,
                    tooltip="Narrow-band width in voxel units (which voxels are sampled). In UDF "
                            "mode also offsets the surface by this many voxels."),
                IO.Float.Input("project_back", default=0.0, min=0.0, max=1.0, step=0.05,
                    tooltip="Lerp output verts toward the closest point on the original surface "
                            "(0 = pure DC, 1 = snapped). Recovers voxelization-lost detail."),
                IO.Boolean.Input("fix_poles", default=False,
                    tooltip="Collapse valence-3 vertex pairs (DC T-junction artifact). Cheap; "
                            "improves shading and downstream simplification."),
                IO.Int.Input("smooth_iters", default=0, min=0, max=20,
                    tooltip="Taubin λ|μ smoothing iterations (0 = off). Volume-preserving; cleans DC "
                            "stairstepping. 2-3 is enough; higher rounds off QEF sharp features."),
                IO.Float.Input("drop_small_components", default=0.01, min=0.0, max=0.5, step=0.005,
                    tooltip="Drop components with fewer than this fraction of the largest component's "
                            "faces (inner-shell fragments, noise). 0 disables."),
                IO.Int.Input("precluster_max_verts", default=0, min=0, max=50_000_000,
                    tooltip="Safety fallback: if input has more verts than this (>0), cluster-decimate "
                            "it down first so the distance-field queries don't OOM on huge inputs. "
                            "0 = off; 1-2M is reasonable for very large meshes."),
            ],
            outputs=[IO.Mesh.Output("mesh")],
            hidden=[IO.Hidden.unique_id],
        )

    @classmethod
    def execute(cls, mesh, target_faces, resolution, sign_mode, band,
                project_back, fix_poles, smooth_iters,
                drop_small_components, precluster_max_verts):
        """Remesh every item of `mesh`; an item whose remesh raises is passed through unchanged.

        `sign_mode` is the DynamicCombo payload: a mapping holding the selected key under
        "sign_mode" plus the sub-widget values of that option. Returns whatever
        `_process_mesh_batch` returns for the per-item callback defined below.
        """
        mode = sign_mode.get("sign_mode", "udf")
        # Mode-specific sub-widgets; absent ones fall back to the schema defaults.
        # 'qef' is the only widget whose default differs per mode (udf: False, sdf: True),
        # so its fallback must follow the selected mode rather than being a constant.
        qef = bool(sign_mode.get("qef", mode == "sdf"))
        manifold = bool(sign_mode.get("manifold", False))
        drop_inverted_components = bool(sign_mode.get("drop_inverted_components", True))
        drop_enclosed_components = bool(sign_mode.get("drop_enclosed_components", True))
        # ComfyUI passes meshes on CPU; remesh is far faster on GPU. Run on the
        # selected compute device and return on the mesh's original device.
        compute_device = comfy.model_management.get_torch_device()
        counts = {"in": 0, "out": 0}

        def _fn(v, f, c):
            # Per-item callback: (verts, faces, colors-or-None) -> same triple, remeshed.
            counts["in"] += int(f.shape[0])
            try:
                src_device = v.device
                vv = v.to(compute_device).float()
                ff = f.to(compute_device).to(torch.int64)
                cc = c.to(compute_device).float() if c is not None else None
                # safety fallback: cluster-decimate very large inputs before the field queries
                if precluster_max_verts > 0 and vv.shape[0] > precluster_max_verts:
                    vv, ff, cc = qem_cluster_decimate(
                        vv, ff, target_verts=int(precluster_max_verts), colors=cc)
                # Fixed [-0.5,0.5] cube domain (matches cumesh / TRELLIS2). scale ≈ 1.0
                # for any resolution, so this is consistent in target_faces auto mode too.
                rs_scale = (resolution + 3.0 * band) / resolution
                rs_center = torch.zeros(3, dtype=vv.dtype, device=compute_device)
                rv, rf, rc = remesh_narrow_band_dc(
                    vv, ff,
                    resolution=int(resolution), target_faces=int(target_faces),
                    band=float(band), project_back=float(project_back),
                    qef=qef, sign_mode=mode,
                    manifold=manifold, fix_poles=bool(fix_poles),
                    smooth_iters=int(smooth_iters),
                    drop_small_components=float(drop_small_components),
                    drop_inverted_components=drop_inverted_components,
                    drop_enclosed_components=drop_enclosed_components,
                    scale=rs_scale, center=rs_center, colors=cc)
                # Build the whole result first and rebind v/f/c in one statement, so a
                # failure while copying back cannot leave a half-replaced (v, f, c) triple.
                v, f, c = (
                    rv.to(src_device),
                    rf.to(src_device),
                    rc.to(src_device) if rc is not None else None,
                )
            except Exception as e:
                # Deliberate best-effort: keep the workflow running with the input mesh.
                logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}")
            counts["out"] += int(f.shape[0])
            return v, f, c

        result = _process_mesh_batch(mesh, _fn)
        # Send progress text to display the face change on the node
        if cls.hidden.unique_id:
            PromptServer.instance.send_progress_text(
                f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
        return result
def _pack_uv_meshes(vs, fs, uvs, colors):
    """Assemble per-item tensors into one MESH.

    A single item is simply given a leading batch axis. Several items are
    zero-padded to the largest vertex/face count and the true sizes are stored
    in `vertex_counts` / `face_counts`. `colors` may be None; otherwise it is
    attached as `vertex_colors` using the same layout.
    """
    n_items = len(vs)
    if n_items == 1:
        single = Types.MESH(vertices=vs[0].unsqueeze(0), faces=fs[0].unsqueeze(0), uvs=uvs[0].unsqueeze(0))
        if colors is not None:
            single.vertex_colors = colors[0].unsqueeze(0)
        return single

    device = vs[0].device
    vert_sizes = [item.shape[0] for item in vs]
    face_sizes = [item.shape[0] for item in fs]
    most_verts = max(vert_sizes)
    most_faces = max(face_sizes)

    # Padded buffers inherit dtype/device from the first item of each list.
    verts_pad = vs[0].new_zeros((n_items, most_verts, 3))
    faces_pad = fs[0].new_zeros((n_items, most_faces, 3))
    uvs_pad = uvs[0].new_zeros((n_items, most_verts, 2))
    for slot, (item_v, item_f, item_uv) in enumerate(zip(vs, fs, uvs)):
        verts_pad[slot, :item_v.shape[0]] = item_v
        faces_pad[slot, :item_f.shape[0]] = item_f
        uvs_pad[slot, :item_uv.shape[0]] = item_uv

    packed = Types.MESH(
        vertices=verts_pad, faces=faces_pad, uvs=uvs_pad,
        vertex_counts=torch.tensor(vert_sizes, device=device, dtype=torch.int64),
        face_counts=torch.tensor(face_sizes, device=device, dtype=torch.int64))
    if colors is None:
        return packed

    channels = colors[0].shape[1]
    colors_pad = colors[0].new_zeros((n_items, most_verts, channels))
    for slot, item_c in enumerate(colors):
        colors_pad[slot, :item_c.shape[0]] = item_c
    packed.vertex_colors = colors_pad
    return packed
def _uv_weld_vertices(v, f, weld_distance):
"""Merge coincident verts; returns (welded_v, welded_f, welded_to_orig) (last None if no welding)."""
v_np = v.cpu().numpy()
f_np = f.cpu().numpy()
if v_np.size == 0:
return v, f, None
extent = float(np.linalg.norm(v_np.max(axis=0) - v_np.min(axis=0)))
tol = weld_distance if weld_distance > 0.0 else 1e-5 * extent
if tol <= 0.0:
return v, f, None
keys = np.round(v_np / tol).astype(np.int64)
_, inv = np.unique(keys, axis=0, return_inverse=True)
n_unique = int(inv.max()) + 1
if n_unique >= v_np.shape[0]:
return v, f, None
v_welded = np.zeros((n_unique, 3), dtype=np.float32)
counts = np.zeros(n_unique, dtype=np.int64)
np.add.at(v_welded, inv, v_np)
np.add.at(counts, inv, 1)
v_welded /= counts[:, None]
welded_to_orig = np.empty(n_unique, dtype=np.int64)
welded_to_orig[inv] = np.arange(v_np.shape[0], dtype=np.int64)
v_new = torch.from_numpy(v_welded).to(v.dtype).to(v.device)
f_new = torch.from_numpy(inv[f_np]).to(f.dtype).to(f.device)
return v_new, f_new, welded_to_orig
def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance):
    """UV-unwrap a single mesh.

    Pipeline: weld coincident vertices, drop degenerate faces, build the mesh
    adjacency, segment faces into charts, parameterize each chart on CPU, pack
    the charts into an atlas, then emit one output vertex per (chart, vertex)
    pair so vertices shared by several charts are duplicated.

    Args:
        positions: vertex tensor; cast to float32.
        indices: face index tensor; cast to long and reshaped to (-1, 3).
        segmenter: "pec" (raises RuntimeError unless the built mesh is on CUDA)
            or "adaptive"; anything else raises ValueError.
        resolution: target atlas size used to derive texel density; values <= 0
            fall back to 1024.
        padding: forwarded to the packer as `padding_texels`.
        weld_distance: tolerance forwarded to `_uv_weld_vertices`.

    Returns:
        (vmapping, indices, uvs) as numpy arrays. `vmapping[k]` is the input
        vertex that output vertex k was copied from; `indices` has one row per
        face that survived the degenerate-face filter; `uvs` holds the packed
        coordinates as returned by `_uv_pack.apply_placements`.
    """
    v_in = positions.to(torch.float32)
    f_in = indices.to(torch.long).reshape(-1, 3)
    v_in, f_in, welded_to_orig = _uv_weld_vertices(v_in, f_in, weld_distance)
    # drop degenerate faces (repeated index) — they corrupt edge adjacency
    degen = ((f_in[:, 0] == f_in[:, 1]) | (f_in[:, 1] == f_in[:, 2]) | (f_in[:, 2] == f_in[:, 0]))
    if bool(degen.any()):
        f_in = f_in[~degen]
    mesh = _uv_mesh.build_mesh(v_in, f_in)
    ff = mesh.face_face
    # Entries >= 0 are treated as real neighbours; if fewer than a quarter of the
    # slots are filled the input is most likely unwelded.
    if ff.numel() and float((ff >= 0).float().mean().item()) < 0.25:
        warnings.warn("[uv_unwrap] mesh face-adjacency < 25% — vertices appear un-welded "
                      "(triangle soup); UV charts will be per-face. Raise weld_distance.")
    if segmenter == "pec":
        if mesh.faces.device.type != "cuda":
            raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
        face_chart = _uv_seg.cluster_charts_pec(mesh, target_chart_count=0, max_cost=1.0)
    elif segmenter == "adaptive":
        face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0, target_chart_count=0)
    else:
        raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
    # face_chart is used as a per-face chart id; ids are assumed to be 0..n_charts-1
    # (NOTE(review): confirm both segmenters return compact labels).
    n_charts = int(face_chart.max().item()) + 1 if face_chart.numel() else 0
    areas_cpu = _uv_mesh.chart_3d_areas(mesh.face_area, face_chart, n_charts).detach().cpu()
    # per-chart loop runs on CPU/numpy to avoid per-chart GPU sync
    face_chart_np = face_chart.cpu().numpy()
    faces_np = mesh.faces.cpu().numpy()
    vertices_np = mesh.vertices.cpu().numpy()
    face_face_np = mesh.face_face.cpu().numpy()
    # Bucket faces by chart: a stable argsort keeps face ids ascending inside each
    # chart, which the searchsorted lookup on gfi_np below relies on.
    sorted_face_idx_np = np.argsort(face_chart_np, kind="stable")
    chart_counts_np = np.bincount(face_chart_np, minlength=n_charts)
    # chart_offsets_np[c]:chart_offsets_np[c + 1] is chart c's slice of the sorted list
    chart_offsets_np = np.empty(n_charts + 1, dtype=np.int64)
    chart_offsets_np[0] = 0
    np.cumsum(chart_counts_np, out=chart_offsets_np[1:])
    all_chart_uvs, all_chart_3d_areas, all_chart_uv_areas, all_chart_faces = [], [], [], []
    chart_records = []
    for c in range(n_charts):
        # global face ids of chart c (ascending)
        gfi_np = sorted_face_idx_np[chart_offsets_np[c]:chart_offsets_np[c + 1]]
        chart_faces_global = faces_np[gfi_np]
        # compact the chart's vertices: used_verts_np is sorted, so searchsorted
        # converts global vertex ids into chart-local ones
        used_verts_np = np.unique(chart_faces_global)
        local_faces_np = np.searchsorted(used_verts_np, chart_faces_global)
        local_verts_np = vertices_np[used_verts_np]
        # chart-local face adjacency: keep only neighbours that exist and belong to
        # this chart; everything else becomes -1
        ff_global = face_face_np[gfi_np]
        ff_safe = np.maximum(ff_global, 0)  # clamp so the gather below never sees a negative index
        nb_chart = np.where(ff_global >= 0, face_chart_np[ff_safe], -1)
        keep = (ff_global >= 0) & (nb_chart == c)
        # positions are only meaningful where `keep` is true; the rest is masked out
        local_neighbor = np.searchsorted(gfi_np, ff_safe)
        local_ff_np = np.where(keep, local_neighbor, -1)
        lf = torch.from_numpy(local_faces_np)
        uvs = _uv_param.parametrize_chart(
            torch.from_numpy(local_verts_np), lf, torch.from_numpy(local_ff_np))
        # total unsigned triangle area of the chart in UV space (2-D cross product / 2)
        ua, ub, uc = uvs[lf[:, 0]], uvs[lf[:, 1]], uvs[lf[:, 2]]
        uv_area_sum = float(0.5 * (
            (ub[:, 0] - ua[:, 0]) * (uc[:, 1] - ua[:, 1])
            - (uc[:, 0] - ua[:, 0]) * (ub[:, 1] - ua[:, 1])).abs().sum().item())
        chart_records.append({"local_faces": lf, "vmap": torch.from_numpy(used_verts_np),
                              "global_face_idx": torch.from_numpy(gfi_np)})
        all_chart_uvs.append(uvs)
        all_chart_3d_areas.append(float(areas_cpu[c].item()))
        all_chart_uv_areas.append(uv_area_sum)
        all_chart_faces.append(lf)
    # auto-tune texel density to land near `resolution` (assumes ~0.62 pack fill)
    total_3d_area = sum(all_chart_3d_areas) or 1.0  # guard against a zero-area mesh
    target_dim = float(resolution) if resolution > 0 else 1024.0
    tex_per_unit = math.sqrt((target_dim * target_dim) * 0.62 / total_3d_area)
    placements, atlas_w, atlas_h = _uv_pack.pack_bitmap(
        all_chart_uvs, all_chart_3d_areas, all_chart_uv_areas, all_chart_faces,
        texels_per_unit=tex_per_unit, padding_texels=padding)
    placed = _uv_pack.apply_placements(all_chart_uvs, placements, atlas_w, atlas_h)
    # Stitch charts back together: each chart contributes its own vertex range
    # starting at v_cursor, so chart-local face indices are simply offset.
    n_in_faces = mesh.faces.shape[0]
    out_indices = np.zeros((n_in_faces, 3), dtype=np.int64)
    out_uvs_list, out_vmap_list, v_cursor = [], [], 0
    for c, rec in enumerate(chart_records):
        vmap_np = rec["vmap"].cpu().numpy()
        local_faces_np = rec["local_faces"].cpu().numpy()
        global_face_idx = rec["global_face_idx"].cpu().numpy()
        out_uvs_list.append(placed[c].cpu().numpy())
        # map welded vertex ids back to ids of the caller's original vertices
        if welded_to_orig is not None:
            vmap_np = welded_to_orig[vmap_np]
        out_vmap_list.append(vmap_np)
        out_indices[global_face_idx] = local_faces_np + v_cursor
        v_cursor += vmap_np.shape[0]
    vmapping_out = np.concatenate(out_vmap_list) if out_vmap_list else np.empty(0, dtype=np.int64)
    uvs_out = np.concatenate(out_uvs_list) if out_uvs_list else np.empty((0, 2), dtype=np.float32)
    return vmapping_out, out_indices, uvs_out
class UnwrapMesh(IO.ComfyNode):
    # Node that runs _uv_unwrap on every item of a mesh and repacks the results
    # (positions, faces, UVs and optional vertex colours) into a single MESH.

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        segmenter_input = IO.Combo.Input(
            "segmenter", options=["pec", "adaptive"], default="pec",
            tooltip="pec: fast parallel-edge-collapse charting (CUDA; falls back to "
                    "adaptive on CPU). adaptive: CPU charting, slower.")
        resolution_input = IO.Int.Input(
            "resolution", default=1024, min=0, max=8192, step=256,
            tooltip="Target atlas resolution used to auto-scale texel density (0 = fit-to-content).")
        padding_input = IO.Int.Input(
            "padding", default=1, min=0, max=16,
            tooltip="Texel padding between charts in the packed atlas.")
        weld_input = IO.Float.Input(
            "weld_distance", default=0.0, min=0.0, max=1.0, step=0.0001,
            tooltip="Merge radius for coincident verts as a fraction of mesh extent "
                    "(0 = auto, 1e-5). Raise to ~0.001 if you get per-triangle charts "
                    "(unwelded / triangle-soup input).")
        return IO.Schema(
            node_id="UnwrapMesh",
            display_name="Unwrap Mesh UVs",
            category="latent/3d",
            description=(
                "Generates a UV atlas (pure-torch, no xatlas dependency): segments the surface into "
                "charts, parameterizes each, and packs them into a [0,1] atlas. Verts on chart seams "
                "are duplicated. Run after DecimateMesh/RemeshMesh, before texture baking."
            ),
            inputs=[
                IO.Mesh.Input("mesh"),
                segmenter_input,
                resolution_input,
                padding_input,
                weld_input,
            ],
            outputs=[IO.Mesh.Output("mesh")],
            hidden=[IO.Hidden.unique_id],
        )

    @classmethod
    def execute(cls, mesh, segmenter, resolution, padding, weld_distance):
        """Unwrap each item of `mesh` and return a MESH carrying the generated UVs."""
        compute_device = comfy.model_management.get_torch_device()
        # 'pec' needs CUDA; silently use the CPU segmenter otherwise.
        pec_unavailable = segmenter == "pec" and compute_device.type != "cuda"
        seg = "adaptive" if pec_unavailable else segmenter
        seg_device = compute_device if seg == "pec" else torch.device("cpu")

        # A mesh is either a list of per-item tensors, a (B, ...) batch, or one item.
        is_list = isinstance(mesh.vertices, list)
        is_batched = not is_list and mesh.vertices.ndim == 3
        if is_list:
            bsz = len(mesh.vertices)
        elif is_batched:
            bsz = mesh.vertices.shape[0]
        else:
            bsz = 1
        indexed = is_list or is_batched

        bar = comfy.utils.ProgressBar(bsz)
        out_v, out_f, out_uv, out_c = [], [], [], []
        for i in range(bsz):
            all_colors = getattr(mesh, "vertex_colors", None)
            if indexed:
                vi, fi = mesh.vertices[i], mesh.faces[i]
                if all_colors is None:
                    ci = None
                elif isinstance(all_colors, list) or all_colors.ndim == 3:
                    ci = all_colors[i]
                else:
                    ci = all_colors
            else:
                vi, fi = mesh.vertices, mesh.faces
                ci = all_colors

            src_device = vi.device
            vnp = vi.detach().cpu().numpy().astype(np.float32)
            # weld_distance is a fraction of the bbox diagonal; convert it to absolute units
            extent = 0.0
            if vnp.shape[0]:
                extent = float(np.linalg.norm(vnp.max(0) - vnp.min(0)))
            weld_abs = weld_distance * extent if weld_distance > 0.0 else 0.0

            vmapping, indices, uvs = _uv_unwrap(
                vi.to(seg_device).float(), fi.to(seg_device).long(),
                seg, int(resolution), int(padding), weld_abs)
            uvs = uvs.copy()
            uvs[:, 1] = 1.0 - uvs[:, 1]  # UV y flipped vs trimesh

            # seam vertices are duplicated: gather positions/colours through vmapping
            out_v.append(torch.from_numpy(vnp[vmapping]).to(src_device))
            out_f.append(torch.from_numpy(indices).to(device=src_device, dtype=torch.long))
            out_uv.append(torch.from_numpy(uvs.astype(np.float32)).to(src_device))
            if ci is not None:
                cnp = ci.detach().cpu().numpy()
                out_c.append(torch.from_numpy(np.ascontiguousarray(cnp[vmapping])).to(src_device))
            bar.update(1)

        out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
        source_texture = getattr(mesh, "texture", None)
        if source_texture is not None:
            out_mesh.texture = mesh.texture
        if cls.hidden.unique_id:
            PromptServer.instance.send_progress_text(
                f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px",
                cls.hidden.unique_id)
        return IO.NodeOutput(out_mesh)
def _uv_sorted_edge_keys(indices: np.ndarray):
"""Undirected edge keys per face-edge, sorted; returns (sorted_keys, face_id, lo, hi, first_mask)."""
a = indices.ravel().astype(np.int64)
b = np.roll(indices, -1, axis=1).ravel().astype(np.int64)
lo = np.minimum(a, b)
hi = np.maximum(a, b)
V = int(indices.max()) + 1
key = lo * V + hi
order = np.argsort(key, kind="stable")
sk = key[order]
fid = (np.arange(a.size, dtype=np.int64) // 3)[order]
first = np.ones(sk.size, dtype=bool)
first[1:] = sk[1:] != sk[:-1]
return sk, fid, lo[order], hi[order], first
def _uv_faces_to_chart_ids(indices: np.ndarray) -> np.ndarray:
    """Label each face with the connected component (chart) it belongs to.

    Two faces are linked when they share an undirected edge in the given index
    buffer; seam vertices are duplicated there, so faces of different charts
    never share an edge. Returns an int64 array with one label per face.
    """
    n_faces = indices.shape[0]
    if n_faces == 0:
        return np.empty(0, dtype=np.int64)

    _, edge_face, _, _, run_head = _uv_sorted_edge_keys(indices)
    run_tail = ~run_head
    # index of the run (group of equal keys) every sorted edge falls into
    run_of_edge = np.cumsum(run_head) - 1
    head_pos = np.nonzero(run_head)[0]

    # link the face owning the first edge of a run to the faces owning the repeats
    src_faces = edge_face[head_pos[run_of_edge[run_tail]]]
    dst_faces = edge_face[run_tail]
    if src_faces.size == 0:
        # no shared edges at all: every face is its own chart
        return np.arange(n_faces, dtype=np.int64)

    weights = np.ones(src_faces.size, dtype=np.int8)
    links = csr_matrix((weights, (src_faces, dst_faces)), shape=(n_faces, n_faces))
    labels = connected_components(links, directed=False)[1]
    return labels.astype(np.int64)
# 20-entry RGB lookup table (float32, components in [0, 1]) used by _uv_palette to
# colour charts. The values appear to be matplotlib's "tab20" qualitative colormap
# (first row is #1f77b4) — NOTE(review): confirm the provenance if that matters.
_UV_TAB20 = np.array([
    [0.121568627, 0.466666667, 0.705882353], [0.682352941, 0.780392157, 0.909803922],
    [1.000000000, 0.498039216, 0.054901961], [1.000000000, 0.733333333, 0.470588235],
    [0.172549020, 0.627450980, 0.172549020], [0.596078431, 0.874509804, 0.541176471],
    [0.839215686, 0.152941176, 0.156862745], [1.000000000, 0.596078431, 0.588235294],
    [0.580392157, 0.403921569, 0.741176471], [0.772549020, 0.690196078, 0.835294118],
    [0.549019608, 0.337254902, 0.294117647], [0.768627451, 0.611764706, 0.580392157],
    [0.890196078, 0.466666667, 0.760784314], [0.968627451, 0.713725490, 0.823529412],
    [0.498039216, 0.498039216, 0.498039216], [0.780392157, 0.780392157, 0.780392157],
    [0.737254902, 0.741176471, 0.133333333], [0.858823529, 0.858823529, 0.552941176],
    [0.090196078, 0.745098039, 0.811764706], [0.619607843, 0.854901961, 0.898039216],
], dtype=np.float32)
def _uv_palette(n: int) -> np.ndarray:
    """Return an (n, 3) float32 array of chart colours.

    Colours come from _UV_TAB20, visited in a fixed pseudo-random order
    (seed 42) and wrapped modulo 20, so the result is deterministic for a
    given n.
    """
    shuffled = np.random.RandomState(42).permutation(max(1, n))
    palette = np.empty((n, 3), dtype=np.float32)
    # the permutation has at least n entries, so its first n are the per-chart picks
    palette[...] = _UV_TAB20[shuffled[:n] % 20]
    return palette
def _uv_render_atlas(uvs_np, indices_np, resolution, device,
                     bg=(0.13, 0.13, 0.13), edge=(0.0, 0.0, 0.0)):
    """Rasterize a UV layout into a square (H, W, 3) float32 image on `device`.

    Triangles are filled with a per-chart colour using a tiled point-in-triangle
    test; edges used by exactly one face are then drawn in the `edge` colour.
    `bg` is the colour of texels no triangle covers.
    """
    width = height = int(resolution)
    face_chart_np = _uv_faces_to_chart_ids(indices_np)
    uv_t = torch.from_numpy(uvs_np).to(device=device, dtype=torch.float32)
    faces_t = torch.from_numpy(indices_np).to(device=device, dtype=torch.long)
    face_chart = torch.from_numpy(face_chart_np).to(device=device, dtype=torch.long)

    canvas = torch.zeros((height, width, 3), dtype=torch.float32, device=device)
    for channel in range(3):
        canvas[..., channel] = bg[channel]
    if faces_t.numel() == 0:
        return canvas

    chart_total = int(face_chart.max().item()) + 1 if face_chart.numel() else 1
    palette = torch.from_numpy(_uv_palette(chart_total)).to(device=device, dtype=torch.float32)

    # UVs clamped to [0, 1] and scaled to pixel coordinates
    uv_px = uv_t.clone()
    uv_px[:, 0] = uv_px[:, 0].clamp(0.0, 1.0) * (width - 1)
    uv_px[:, 1] = uv_px[:, 1].clamp(0.0, 1.0) * (height - 1)
    corners = uv_px[faces_t]
    ax, ay = corners[:, 0, 0], corners[:, 0, 1]
    bx, by = corners[:, 1, 0], corners[:, 1, 1]
    cx, cy = corners[:, 2, 0], corners[:, 2, 1]
    # barycentric denominator; (near-)zero means a degenerate triangle
    denom = (by - cy) * (ax - cx) + (cx - bx) * (ay - cy)
    usable = denom.abs() > 1e-20

    # integer pixel bounding box of every triangle
    left = torch.minimum(torch.minimum(ax, bx), cx).floor().clamp_(0, width - 1).long()
    right = torch.maximum(torch.maximum(ax, bx), cx).ceil().clamp_(0, width - 1).long()
    top = torch.minimum(torch.minimum(ay, by), cy).floor().clamp_(0, height - 1).long()
    bottom = torch.maximum(torch.maximum(ay, by), cy).ceil().clamp_(0, height - 1).long()

    # full point-in-triangle over all (pixel, tri) pairs is O(H*W*F); tile and test only bbox-overlapping tris
    tile = 64
    tolerance = 1e-6
    for row0 in range(0, height, tile):
        row1 = min(row0 + tile, height)
        for col0 in range(0, width, tile):
            col1 = min(col0 + tile, width)
            overlaps = (usable & (left < col1) & (right >= col0)
                        & (top < row1) & (bottom >= row0))
            if not overlaps.any():
                continue
            picked = torch.nonzero(overlaps, as_tuple=True)[0]
            # sample at pixel centres
            row_centres = torch.arange(row0, row1, dtype=torch.float32, device=device) + 0.5
            col_centres = torch.arange(col0, col1, dtype=torch.float32, device=device) + 0.5
            yy, xx = torch.meshgrid(row_centres, col_centres, indexing="ij")
            # per-triangle scalars shaped (T, 1, 1) to broadcast against the (h, w) grid
            tax = ax[picked][:, None, None]
            tay = ay[picked][:, None, None]
            tbx = bx[picked][:, None, None]
            tby = by[picked][:, None, None]
            tcx = cx[picked][:, None, None]
            tcy = cy[picked][:, None, None]
            tden = denom[picked][:, None, None]
            w0 = ((tby - tcy) * (xx - tcx) + (tcx - tbx) * (yy - tcy)) / tden
            w1 = ((tcy - tay) * (xx - tcx) + (tax - tcx) * (yy - tcy)) / tden
            w2 = 1.0 - w0 - w1
            covered = (w0 >= -tolerance) & (w1 >= -tolerance) & (w2 >= -tolerance)
            if not covered.any():
                continue
            pixel_hit = covered.any(dim=0)
            # first covering triangle (in face order) wins the pixel
            winner = picked[covered.int().argmax(dim=0)]
            winner_rgb = palette[face_chart[winner]]
            # basic slicing yields a view, so the masked write lands in `canvas`
            canvas[row0:row1, col0:col1][pixel_hit] = winner_rgb[pixel_hit]

    # chart outlines: a chart border is an open boundary in UV space (seam verts duplicated) → edges with 1 incident face
    _, _, edge_lo, edge_hi, run_head = _uv_sorted_edge_keys(indices_np)
    head_pos = np.nonzero(run_head)[0]
    run_len = np.diff(np.append(head_pos, run_head.size))
    border_pos = head_pos[run_len == 1]
    uv_host = uv_px.cpu().numpy()
    line_cols, line_rows = [], []
    for v_from, v_to in zip(edge_lo[border_pos], edge_hi[border_pos]):
        start = uv_host[v_from]
        span = uv_host[v_to] - start
        n_samples = int(max(abs(span[0]), abs(span[1])) + 1)
        if n_samples > 1:
            t = np.linspace(0.0, 1.0, n_samples)
            cols = (start[0] + span[0] * t).astype(np.int32)
            rows = (start[1] + span[1] * t).astype(np.int32)
            on_canvas = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
            line_cols.append(cols[on_canvas])
            line_rows.append(rows[on_canvas])
    if line_cols:
        col_idx = torch.from_numpy(np.concatenate(line_cols)).to(device=device, dtype=torch.long)
        row_idx = torch.from_numpy(np.concatenate(line_rows)).to(device=device, dtype=torch.long)
        canvas[row_idx, col_idx] = torch.tensor(edge, dtype=torch.float32, device=device)
    return canvas
class RenderUVAtlas(IO.ComfyNode):
    # Preview node: draws the UV layout of the first mesh item as an image.

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and outputs."""
        return IO.Schema(
            node_id="RenderUVAtlas",
            display_name="Render UV Atlas",
            category="latent/3d",
            description=("Renders a mesh's UV layout as an image — each chart a distinct color, "
                         "outlined where it borders other charts. Run UnwrapMesh first."),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Int.Input("resolution", default=1024, min=64, max=4096, step=64),
            ],
            outputs=[IO.Image.Output("image")],
        )

    @classmethod
    def execute(cls, mesh, resolution):
        """Render item 0 of `mesh` and return it as a (1, H, W, 3) image.

        Raises:
            RuntimeError: if the mesh carries no `uvs`.
        """
        uvs_t = getattr(mesh, "uvs", None)
        if uvs_t is None:
            raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
        uvs_np = uvs_t.detach().cpu().numpy()
        if uvs_np.ndim == 3:
            uvs_np = uvs_np[0]
            # Padded batches record each item's real size; drop the zero padding.
            vertex_counts = getattr(mesh, "vertex_counts", None)
            if vertex_counts is not None:
                uvs_np = uvs_np[:int(vertex_counts[0])]
        f = mesh.faces
        if torch.is_tensor(f):
            f = f.detach().cpu().numpy()
        if f.ndim == 3:
            f = f[0]
            # Padding faces are all-zero rows; left in, they would be counted as an
            # extra chart and shift the colours assigned to the real charts.
            face_counts = getattr(mesh, "face_counts", None)
            if face_counts is not None:
                f = f[:int(face_counts[0])]
        f = np.ascontiguousarray(f, dtype=np.int64)
        uvs_np = np.ascontiguousarray(uvs_np, dtype=np.float32)
        device = comfy.model_management.get_torch_device()
        img = _uv_render_atlas(uvs_np, f, int(resolution), device)
        return IO.NodeOutput(img.detach().cpu().unsqueeze(0))
class FillHoles(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -2379,6 +2950,9 @@ class PostProcessMeshExtension(ComfyExtension):
FillHolesV2,
WeldVertices,
DecimateMesh,
RemeshMesh,
UnwrapMesh,
RenderUVAtlas,
PaintMesh,
BakeTextureFromVoxel,
MeshTextureToImage,