mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 05:31:03 +08:00
Remesh, UV unwrap
This commit is contained in:
parent
72ff035fe0
commit
6ef69849a0
@ -1618,3 +1618,76 @@ def simplify(
|
||||
out_n if out_n else None,
|
||||
out_s)
|
||||
return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config)
|
||||
|
||||
|
||||
def cluster_decimate(
|
||||
vertices: torch.Tensor, faces: torch.Tensor,
|
||||
target_verts: int = 1_000_000,
|
||||
colors: Optional[torch.Tensor] = None,
|
||||
face_chunk: int = 4_000_000,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Vertex-cluster decimation (Rossignac-Borrel): bin verts into a ~target_verts grid,
|
||||
average per cell, remap faces (chunked), drop degenerate/duplicate. Fast O(V+F) prepass
|
||||
for huge meshes before QEM/remesh. Returns (verts, faces, colors)."""
|
||||
if vertices.shape[0] == 0 or faces.shape[0] == 0:
|
||||
return vertices, faces, colors
|
||||
|
||||
device = vertices.device
|
||||
bbox = vertices.max(dim=0)[0] - vertices.min(dim=0)[0]
|
||||
bbox_min = vertices.min(dim=0)[0]
|
||||
# cell size so the bbox holds ~3× target_verts cells (surface occupancy ~1/3)
|
||||
cell_count_target = max(target_verts * 3, 1000)
|
||||
extent_max = float(bbox.max().item())
|
||||
cells_per_axis = (cell_count_target ** (1 / 3))
|
||||
cell_size = extent_max / max(1.0, cells_per_axis)
|
||||
scale = 1.0 / max(cell_size, 1e-20)
|
||||
|
||||
q = ((vertices - bbox_min) * scale).floor().to(torch.int64)
|
||||
extent = (bbox * scale).floor().to(torch.int64) + 2
|
||||
Wy = extent[1]
|
||||
Wz = extent[2]
|
||||
key = (q[:, 0] * Wy + q[:, 1]) * Wz + q[:, 2]
|
||||
|
||||
unique_key, inv = torch.unique(key, return_inverse=True)
|
||||
n_unique = unique_key.shape[0]
|
||||
counts = torch.zeros(n_unique, dtype=vertices.dtype, device=device)
|
||||
counts.scatter_add_(0, inv, torch.ones(vertices.shape[0], dtype=vertices.dtype, device=device))
|
||||
counts_div = counts.unsqueeze(-1).clamp_min(1.0)
|
||||
|
||||
new_verts = torch.zeros((n_unique, 3), dtype=vertices.dtype, device=device)
|
||||
new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(vertices), vertices)
|
||||
new_verts = new_verts / counts_div
|
||||
|
||||
new_colors = None
|
||||
if colors is not None:
|
||||
new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=colors.dtype, device=device)
|
||||
new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors)
|
||||
new_colors = new_colors / counts_div.to(colors.dtype)
|
||||
|
||||
# remap faces in chunks (face tensor can be huge); drop degenerates per chunk
|
||||
out_chunks = []
|
||||
F = faces.shape[0]
|
||||
for fs in range(0, F, face_chunk):
|
||||
fe = min(fs + face_chunk, F)
|
||||
cf = inv[faces[fs:fe].long()]
|
||||
nondeg = ((cf[:, 0] != cf[:, 1]) & (cf[:, 1] != cf[:, 2]) & (cf[:, 0] != cf[:, 2]))
|
||||
if nondeg.any():
|
||||
out_chunks.append(cf[nondeg])
|
||||
if out_chunks:
|
||||
new_faces = torch.cat(out_chunks, dim=0)
|
||||
else:
|
||||
new_faces = torch.empty((0, 3), dtype=faces.dtype, device=device)
|
||||
|
||||
# drop duplicate faces (same vertex set after clustering)
|
||||
if new_faces.numel() > 0:
|
||||
key_sorted = torch.sort(new_faces, dim=1)[0]
|
||||
P = n_unique + 1
|
||||
packed = (key_sorted[:, 0].long() * P + key_sorted[:, 1].long()) * P + key_sorted[:, 2].long()
|
||||
_, first = torch.unique(packed, return_inverse=True)
|
||||
arange = torch.arange(packed.shape[0], device=device, dtype=torch.int64)
|
||||
first_idx = torch.full((int(first.max().item()) + 1,), packed.shape[0],
|
||||
dtype=torch.int64, device=device)
|
||||
first_idx.scatter_reduce_(0, first, arange, reduce="amin", include_self=True)
|
||||
new_faces = new_faces[first_idx]
|
||||
|
||||
return new_verts.to(vertices.dtype), new_faces.to(faces.dtype), new_colors
|
||||
1151
comfy_extras/mesh3d/postprocess/remesh.py
Normal file
1151
comfy_extras/mesh3d/postprocess/remesh.py
Normal file
File diff suppressed because it is too large
Load Diff
158
comfy_extras/mesh3d/uv_unwrap/mesh.py
Normal file
158
comfy_extras/mesh3d/uv_unwrap/mesh.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""Mesh container, edge/face adjacency, manifold cleanup."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
# ---- Per-face / per-vertex geometry ----
|
||||
|
||||
def face_normals(vertices: Tensor, faces: Tensor) -> Tensor:
|
||||
"""[F,3] unit face normals (degenerate faces -> zero)."""
|
||||
v0 = vertices[faces[:, 0]]; v1 = vertices[faces[:, 1]]; v2 = vertices[faces[:, 2]]
|
||||
n = torch.linalg.cross(v1 - v0, v2 - v0)
|
||||
return n / n.norm(dim=1, keepdim=True).clamp_min(1e-20)
|
||||
|
||||
|
||||
def face_areas(vertices: Tensor, faces: Tensor) -> Tensor:
|
||||
"""[F] triangle areas."""
|
||||
v0 = vertices[faces[:, 0]]; v1 = vertices[faces[:, 1]]; v2 = vertices[faces[:, 2]]
|
||||
return 0.5 * torch.linalg.cross(v1 - v0, v2 - v0).norm(dim=1)
|
||||
|
||||
|
||||
def face_centroids(vertices: Tensor, faces: Tensor) -> Tensor:
|
||||
"""[F,3] triangle centroids."""
|
||||
return vertices[faces].mean(dim=1)
|
||||
|
||||
|
||||
def face_edge_lengths(vertices: Tensor, faces: Tensor) -> Tensor:
|
||||
"""[F,3] edge lengths; column e = |v[faces[:,e]] - v[faces[:,(e+1)%3]]|."""
|
||||
va = vertices[faces]
|
||||
vb = vertices[faces.roll(shifts=-1, dims=1)]
|
||||
return (vb - va).norm(dim=-1).to(torch.float32)
|
||||
|
||||
|
||||
def chart_3d_areas(face_area: Tensor, face_chart: Tensor, n_charts: int) -> Tensor:
|
||||
"""[n_charts] sum of face areas per chart."""
|
||||
out = torch.zeros(n_charts, dtype=face_area.dtype, device=face_area.device)
|
||||
out.scatter_add_(0, face_chart, face_area)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeshData:
|
||||
"""Cleaned mesh with adjacency; face_face[f, i] = face sharing edge (faces[f,i], faces[f,(i+1)%3]) or -1 if boundary."""
|
||||
|
||||
vertices: Tensor # [V, 3] float
|
||||
faces: Tensor # [F, 3] long
|
||||
face_face: Tensor # [F, 3] long, neighbor face id or -1
|
||||
face_normal: Tensor # [F, 3] float
|
||||
face_area: Tensor # [F] float
|
||||
face_centroid: Tensor # [F, 3] float
|
||||
component: Tensor # [F] long, connected-component id
|
||||
n_components: int
|
||||
|
||||
|
||||
def build_mesh(vertices: Tensor, faces: Tensor) -> MeshData:
|
||||
"""Build adjacency; non-manifold edges (>2 incident faces) get no neighbor and act as boundary."""
|
||||
if vertices.dtype != torch.float32:
|
||||
vertices = vertices.to(torch.float32)
|
||||
if faces.dtype != torch.long:
|
||||
faces = faces.to(torch.long)
|
||||
|
||||
device = faces.device
|
||||
V = vertices.shape[0]
|
||||
F = faces.shape[0]
|
||||
|
||||
# Per directed face-edge; flat layout p = f*3+i.
|
||||
a = faces.flatten()
|
||||
b = faces.roll(shifts=-1, dims=1).flatten()
|
||||
lo = torch.minimum(a, b)
|
||||
hi = torch.maximum(a, b)
|
||||
edge_key = lo * (V + 1) + hi
|
||||
|
||||
# Pair manifold (count==2) face-edges; others get no neighbor.
|
||||
_, inverse, counts = torch.unique(edge_key, return_inverse=True, return_counts=True)
|
||||
edge_count = counts[inverse]
|
||||
manifold_mask = edge_count == 2
|
||||
|
||||
sort_idx = torch.argsort(edge_key, stable=True)
|
||||
sorted_manifold = manifold_mask[sort_idx]
|
||||
pair_positions = sort_idx[sorted_manifold]
|
||||
pair_a = pair_positions[0::2]
|
||||
pair_b = pair_positions[1::2]
|
||||
|
||||
face_id_flat = torch.arange(F, device=device).repeat_interleave(3)
|
||||
face_face_flat = torch.full((3 * F,), -1, dtype=torch.long, device=device)
|
||||
face_face_flat[pair_a] = face_id_flat[pair_b]
|
||||
face_face_flat[pair_b] = face_id_flat[pair_a]
|
||||
face_face = face_face_flat.view(F, 3)
|
||||
|
||||
face_face_np = face_face.cpu().numpy()
|
||||
rows_mask = face_face_np >= 0
|
||||
if rows_mask.any():
|
||||
rows = np.broadcast_to(np.arange(F)[:, None], (F, 3))[rows_mask]
|
||||
cols = face_face_np[rows_mask]
|
||||
adj = csr_matrix(
|
||||
(np.ones(rows.size, dtype=np.int8), (rows, cols)),
|
||||
shape=(F, F),
|
||||
)
|
||||
else:
|
||||
adj = csr_matrix((F, F), dtype=np.int8)
|
||||
n_components, labels = connected_components(adj, directed=False)
|
||||
|
||||
face_normal = face_normals(vertices, faces)
|
||||
face_area = face_areas(vertices, faces)
|
||||
face_centroid = face_centroids(vertices, faces)
|
||||
|
||||
return MeshData(
|
||||
vertices=vertices,
|
||||
faces=faces,
|
||||
face_face=face_face,
|
||||
face_normal=face_normal,
|
||||
face_area=face_area,
|
||||
face_centroid=face_centroid,
|
||||
component=torch.from_numpy(labels.astype(np.int64)).to(device),
|
||||
n_components=int(n_components),
|
||||
)
|
||||
|
||||
|
||||
def chart_boundary_loops(
|
||||
faces_subset: Tensor, face_face_subset: Tensor
|
||||
) -> List[List[int]]:
|
||||
"""Return ordered boundary vertex loops for a chart submesh (face_face_subset[f,i]==-1 marks a boundary edge)."""
|
||||
F = faces_subset.shape[0]
|
||||
faces_np = faces_subset.cpu().numpy()
|
||||
ff = face_face_subset.cpu().numpy()
|
||||
|
||||
next_v: Dict[int, int] = {}
|
||||
for f in range(F):
|
||||
for i in range(3):
|
||||
if ff[f, i] == -1:
|
||||
a = int(faces_np[f, i])
|
||||
b = int(faces_np[f, (i + 1) % 3])
|
||||
next_v[a] = b
|
||||
|
||||
loops: List[List[int]] = []
|
||||
visited = set()
|
||||
for start in list(next_v.keys()):
|
||||
if start in visited:
|
||||
continue
|
||||
loop = [start]
|
||||
visited.add(start)
|
||||
cur = next_v.get(start)
|
||||
while cur is not None and cur != start:
|
||||
if cur in visited:
|
||||
break
|
||||
loop.append(cur)
|
||||
visited.add(cur)
|
||||
cur = next_v.get(cur)
|
||||
if len(loop) >= 3:
|
||||
loops.append(loop)
|
||||
return loops
|
||||
759
comfy_extras/mesh3d/uv_unwrap/pack.py
Normal file
759
comfy_extras/mesh3d/uv_unwrap/pack.py
Normal file
@ -0,0 +1,759 @@
|
||||
"""Atlas packing via bitmap rasterize-and-place."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
try:
|
||||
from numba import njit as _njit
|
||||
_HAVE_NUMBA_PACK = True
|
||||
except ImportError:
|
||||
_HAVE_NUMBA_PACK = False
|
||||
def _njit(*args, **kwargs):
|
||||
def deco(fn): return fn
|
||||
return deco if not args else args[0]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChartPlacement:
|
||||
chart_id: int
|
||||
offset: Tuple[float, float] # in texels
|
||||
scale: float # texels per UV unit
|
||||
rotation: float = 0.0 # radians
|
||||
swap_xy: bool = False # extra 90° bitmap rotation chosen at place time
|
||||
chart_h: float = 0.0 # unswapped bitmap height in texels (rotation pivot)
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _best_rotation_jit(uvs_np: np.ndarray, n_angles: int) -> float:
|
||||
V = uvs_np.shape[0]
|
||||
best_area = 1e30
|
||||
best_theta = 0.0
|
||||
if V == 0:
|
||||
return 0.0
|
||||
half_pi = math.pi * 0.5
|
||||
for k in range(n_angles):
|
||||
theta = half_pi * k / n_angles
|
||||
c = math.cos(theta); s = math.sin(theta)
|
||||
xmin = 1e30; xmax = -1e30
|
||||
ymin = 1e30; ymax = -1e30
|
||||
for i in range(V):
|
||||
ux = uvs_np[i, 0]; uy = uvs_np[i, 1]
|
||||
xr = ux * c - uy * s
|
||||
yr = ux * s + uy * c
|
||||
if xr < xmin: xmin = xr
|
||||
if xr > xmax: xmax = xr
|
||||
if yr < ymin: ymin = yr
|
||||
if yr > ymax: ymax = yr
|
||||
area = (xmax - xmin) * (ymax - ymin)
|
||||
if area < best_area:
|
||||
best_area = area
|
||||
best_theta = theta
|
||||
return best_theta
|
||||
|
||||
|
||||
def _best_rotation(uvs_np: np.ndarray, n_angles: int = 36) -> float:
|
||||
return float(_best_rotation_jit(uvs_np.astype(np.float64), n_angles))
|
||||
|
||||
|
||||
def _rotate_xy(uv: np.ndarray, theta: float) -> np.ndarray:
|
||||
if theta == 0.0:
|
||||
return uv
|
||||
c = math.cos(theta)
|
||||
s = math.sin(theta)
|
||||
return np.stack([uv[:, 0] * c - uv[:, 1] * s, uv[:, 0] * s + uv[:, 1] * c], axis=1)
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _rasterize_chart_jit(
|
||||
uvs_tex: np.ndarray, faces: np.ndarray, w: int, h: int
|
||||
) -> np.ndarray:
|
||||
"""JIT-rasterize triangles into an (h, w) bool bitmap via barycentric test."""
|
||||
bm = np.zeros((h, w), dtype=np.bool_)
|
||||
F = faces.shape[0]
|
||||
eps = 1e-7
|
||||
for fi in range(F):
|
||||
i0 = faces[fi, 0]; i1 = faces[fi, 1]; i2 = faces[fi, 2]
|
||||
x0 = uvs_tex[i0, 0]; y0 = uvs_tex[i0, 1]
|
||||
x1 = uvs_tex[i1, 0]; y1 = uvs_tex[i1, 1]
|
||||
x2 = uvs_tex[i2, 0]; y2 = uvs_tex[i2, 1]
|
||||
xmin_f = x0
|
||||
if x1 < xmin_f: xmin_f = x1
|
||||
if x2 < xmin_f: xmin_f = x2
|
||||
xmax_f = x0
|
||||
if x1 > xmax_f: xmax_f = x1
|
||||
if x2 > xmax_f: xmax_f = x2
|
||||
ymin_f = y0
|
||||
if y1 < ymin_f: ymin_f = y1
|
||||
if y2 < ymin_f: ymin_f = y2
|
||||
ymax_f = y0
|
||||
if y1 > ymax_f: ymax_f = y1
|
||||
if y2 > ymax_f: ymax_f = y2
|
||||
xmin = int(math.floor(xmin_f))
|
||||
if xmin < 0: xmin = 0
|
||||
xmax = int(math.ceil(xmax_f))
|
||||
if xmax > w - 1: xmax = w - 1
|
||||
ymin = int(math.floor(ymin_f))
|
||||
if ymin < 0: ymin = 0
|
||||
ymax = int(math.ceil(ymax_f))
|
||||
if ymax > h - 1: ymax = h - 1
|
||||
if xmax < xmin or ymax < ymin:
|
||||
continue
|
||||
denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2)
|
||||
if abs(denom) < 1e-20:
|
||||
continue
|
||||
inv_denom = 1.0 / denom
|
||||
for py in range(ymin, ymax + 1):
|
||||
yc = py + 0.5
|
||||
for px in range(xmin, xmax + 1):
|
||||
xc = px + 0.5
|
||||
a = ((y1 - y2) * (xc - x2) + (x2 - x1) * (yc - y2)) * inv_denom
|
||||
b = ((y2 - y0) * (xc - x2) + (x0 - x2) * (yc - y2)) * inv_denom
|
||||
c = 1.0 - a - b
|
||||
if a >= -eps and b >= -eps and c >= -eps:
|
||||
bm[py, px] = True
|
||||
return bm
|
||||
|
||||
|
||||
def _rasterize_chart(
|
||||
uvs_tex: np.ndarray, faces: np.ndarray, w: int, h: int, padding: int
|
||||
) -> np.ndarray:
|
||||
"""Rasterize chart triangles into (h, w) bool bitmap, dilated by padding texels."""
|
||||
if faces.size == 0:
|
||||
return np.zeros((h, w), dtype=bool)
|
||||
bm = _rasterize_chart_jit(
|
||||
uvs_tex.astype(np.float64), faces.astype(np.int64), int(w), int(h)
|
||||
)
|
||||
if padding > 0:
|
||||
bm = _dilate_bitmap(bm, padding)
|
||||
return bm
|
||||
|
||||
|
||||
def _dilate_bitmap(bm: np.ndarray, k: int) -> np.ndarray:
|
||||
"""k-step Manhattan max-filter dilation."""
|
||||
out = bm.copy()
|
||||
for _ in range(k):
|
||||
next_out = out.copy()
|
||||
next_out[1:, :] |= out[:-1, :]
|
||||
next_out[:-1, :] |= out[1:, :]
|
||||
next_out[:, 1:] |= out[:, :-1]
|
||||
next_out[:, :-1] |= out[:, 1:]
|
||||
out = next_out
|
||||
return out
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _build_candidates_jit(
|
||||
skyline: np.ndarray,
|
||||
cur_w: int, cur_h: int,
|
||||
bw0: int, bh0: int, bw1: int, bh1: int,
|
||||
step: int,
|
||||
) -> np.ndarray:
|
||||
"""Build per-chart (x, y, swap_flag) candidate positions (skyline-flush + edge-sweep, both orientations)."""
|
||||
nx_skyline = (max(cur_w, 1) // step) + 2
|
||||
nx_edge = (max(cur_w, 1) // step) + 2
|
||||
ny_edge = (max(cur_h, 1) // step) + 2
|
||||
per_orient = nx_skyline + 2 * nx_edge + 2 * ny_edge
|
||||
out = np.empty((per_orient * 2, 3), dtype=np.int64)
|
||||
k = 0
|
||||
for swap_flag in range(2):
|
||||
cw = bw0 if swap_flag == 0 else bw1
|
||||
x = 0
|
||||
while x <= cur_w:
|
||||
y = 0
|
||||
x_end = x + cw
|
||||
if x_end > skyline.shape[0]:
|
||||
x_end = skyline.shape[0]
|
||||
for xs in range(x, x_end):
|
||||
if skyline[xs] > y:
|
||||
y = int(skyline[xs])
|
||||
out[k, 0] = x; out[k, 1] = y; out[k, 2] = swap_flag
|
||||
k += 1
|
||||
x += step
|
||||
for y_fixed in (0, cur_h):
|
||||
x = 0
|
||||
while x <= cur_w:
|
||||
out[k, 0] = x; out[k, 1] = y_fixed; out[k, 2] = swap_flag
|
||||
k += 1
|
||||
x += step
|
||||
for x_fixed in (0, cur_w):
|
||||
y = 0
|
||||
while y <= cur_h:
|
||||
out[k, 0] = x_fixed; out[k, 1] = y; out[k, 2] = swap_flag
|
||||
k += 1
|
||||
y += step
|
||||
return out[:k]
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
|
||||
x: int, y: int) -> None:
|
||||
"""Lift skyline[x+i] to y + topmost_True_row + 1 per chart column."""
|
||||
ch = chart.shape[0]; cw = chart.shape[1]
|
||||
sw = skyline.shape[0]
|
||||
for i in range(cw):
|
||||
col_x = x + i
|
||||
if col_x >= sw or col_x < 0:
|
||||
continue
|
||||
col_top = -1
|
||||
for j in range(ch - 1, -1, -1):
|
||||
if chart[j, i]:
|
||||
col_top = j
|
||||
break
|
||||
if col_top < 0:
|
||||
continue
|
||||
new_h = y + col_top + 1
|
||||
if new_h > skyline[col_x]:
|
||||
skyline[col_x] = new_h
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _best_placement_jit(
|
||||
atlas: np.ndarray,
|
||||
bitmap: np.ndarray,
|
||||
bitmap_rot: np.ndarray,
|
||||
candidates: np.ndarray,
|
||||
cur_w: int,
|
||||
cur_h: int,
|
||||
):
|
||||
"""Pick lowest-score non-colliding candidate (score = max(new_w,new_h)^2 + new_w*new_h); out-of-atlas treated as free."""
|
||||
n = candidates.shape[0]
|
||||
best_x = -1
|
||||
best_y = -1
|
||||
best_score = -1
|
||||
best_swap = 0
|
||||
bh0 = bitmap.shape[0]; bw0 = bitmap.shape[1]
|
||||
bh1 = bitmap_rot.shape[0]; bw1 = bitmap_rot.shape[1]
|
||||
ah = atlas.shape[0]; aw = atlas.shape[1]
|
||||
for k in range(n):
|
||||
x = candidates[k, 0]
|
||||
y = candidates[k, 1]
|
||||
swap = candidates[k, 2]
|
||||
if swap == 0:
|
||||
ch = bh0; cw = bw0
|
||||
else:
|
||||
ch = bh1; cw = bw1
|
||||
if x < 0 or y < 0:
|
||||
continue
|
||||
nw = cur_w if cur_w > x + cw else x + cw
|
||||
nh = cur_h if cur_h > y + ch else y + ch
|
||||
ext = nw if nw > nh else nh
|
||||
score = ext * ext + nw * nh
|
||||
if best_score >= 0 and score >= best_score:
|
||||
continue
|
||||
ok = True
|
||||
for j in range(ch):
|
||||
yy = y + j
|
||||
if yy >= ah:
|
||||
continue
|
||||
for i in range(cw):
|
||||
bit = bitmap[j, i] if swap == 0 else bitmap_rot[j, i]
|
||||
if not bit:
|
||||
continue
|
||||
xx = x + i
|
||||
if xx >= aw:
|
||||
continue
|
||||
if atlas[yy, xx]:
|
||||
ok = False
|
||||
break
|
||||
if not ok:
|
||||
break
|
||||
if not ok:
|
||||
continue
|
||||
best_x = x; best_y = y
|
||||
best_score = score; best_swap = swap
|
||||
if x + cw <= cur_w and y + ch <= cur_h:
|
||||
break
|
||||
return best_x, best_y, best_score, best_swap
|
||||
|
||||
|
||||
def _blit(atlas: np.ndarray, chart: np.ndarray, x: int, y: int) -> None:
|
||||
ah, aw = atlas.shape
|
||||
ch, cw = chart.shape
|
||||
atlas[y: y + ch, x: x + cw] |= chart
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PreparedChart:
|
||||
chart_id: int
|
||||
uvs_tex: np.ndarray # [V, 2] in texel coords (rotated, scaled, origin 0)
|
||||
bitmap: np.ndarray # [h, w] bool, padded
|
||||
bitmap_rot: np.ndarray # 90° rotated bitmap (for swap_xy placement)
|
||||
bbox_w: int
|
||||
bbox_h: int
|
||||
rotation: float # radians, applied to UVs
|
||||
s_tex: float # texels per UV unit
|
||||
perimeter: float # for chart ordering
|
||||
|
||||
|
||||
@_njit(cache=True, boundscheck=False)
|
||||
def _chart_perimeter_jit(uvs: np.ndarray, faces: np.ndarray, V: int) -> float:
|
||||
"""Sum unique-edge lengths via sorted int64 edge keys."""
|
||||
F = faces.shape[0]
|
||||
keys = np.empty(F * 3, dtype=np.int64)
|
||||
for fi in range(F):
|
||||
for j in range(3):
|
||||
a = faces[fi, j]
|
||||
b = faces[fi, (j + 1) % 3]
|
||||
if a < b:
|
||||
keys[fi * 3 + j] = a * V + b
|
||||
else:
|
||||
keys[fi * 3 + j] = b * V + a
|
||||
keys = np.sort(keys)
|
||||
p = 0.0
|
||||
for i in range(keys.shape[0]):
|
||||
if i > 0 and keys[i] == keys[i - 1]:
|
||||
continue
|
||||
a = keys[i] // V
|
||||
b = keys[i] % V
|
||||
dx = uvs[a, 0] - uvs[b, 0]
|
||||
dy = uvs[a, 1] - uvs[b, 1]
|
||||
p += math.sqrt(dx * dx + dy * dy)
|
||||
return p
|
||||
|
||||
|
||||
def _chart_perimeter(uvs: np.ndarray, faces: np.ndarray) -> float:
|
||||
V = int(faces.max()) + 1 if faces.size else 0
|
||||
return float(_chart_perimeter_jit(uvs.astype(np.float64), faces.astype(np.int64), V))
|
||||
|
||||
|
||||
# ---- Torch fallback (used when numba is unavailable; runs on GPU if present) ----
|
||||
|
||||
def _dilate_local(x: Tensor, p: int) -> Tensor:
|
||||
"""4-connectivity dilation by p, applied per-image over a batch of (cnt,g,g) bitmaps.
|
||||
Matches the old per-chart _dilate_torch; dilation distributes over union so per-triangle
|
||||
dilation OR-scattered equals dilating the assembled chart bitmap."""
|
||||
for _ in range(p):
|
||||
y = x.clone()
|
||||
y[:, 1:, :] |= x[:, :-1, :]; y[:, :-1, :] |= x[:, 1:, :]
|
||||
y[:, :, 1:] |= x[:, :, :-1]; y[:, :, :-1] |= x[:, :, 1:]
|
||||
x = y
|
||||
return x
|
||||
|
||||
|
||||
def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device):
|
||||
"""Batched rasterize EVERY chart at once into one flat bool buffer, replacing the per-chart
|
||||
loop. Returns (buf, cbase) where buf[cbase[i]:cbase[i+1]].view(bh,bw) is chart i's [y,x] bitmap.
|
||||
Triangles are bucketed by next-pow2 bbox size so each batch's local grid stays tiny (bounded
|
||||
memory) while collapsing ~N chart rasters into a handful of kernels."""
|
||||
n = uvs_tex_pad.shape[0]
|
||||
fmax = faces_pad.shape[1]
|
||||
bwL, bhL = bw_t.long(), bh_t.long()
|
||||
cbase = torch.zeros(n + 1, dtype=torch.long, device=device)
|
||||
torch.cumsum(bwL * bhL, 0, out=cbase[1:])
|
||||
buf = torch.zeros(int(cbase[-1].item()), dtype=torch.bool, device=device)
|
||||
|
||||
# gather all triangle coords, keep only valid faces -> (Ttot,3,2) + chart id per triangle
|
||||
fp = faces_pad.reshape(n, fmax * 3)
|
||||
tri = torch.gather(uvs_tex_pad, 1, fp[..., None].expand(-1, -1, 2)).reshape(n * fmax, 3, 2)
|
||||
fm = fmask.reshape(-1)
|
||||
tri_f = tri[fm]
|
||||
if tri_f.shape[0] == 0:
|
||||
return buf, cbase
|
||||
cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm]
|
||||
|
||||
# per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim
|
||||
tmin = tri_f.amin(1); tmax = tri_f.amax(1)
|
||||
x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0)
|
||||
y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0)
|
||||
bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1
|
||||
bbh = (tmax[:, 1].ceil().long() + padding) - y0 + 1
|
||||
mxd = torch.maximum(bbw, bbh).clamp_min(1)
|
||||
bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long()
|
||||
|
||||
a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2]
|
||||
v0 = b - a; v1 = c - a
|
||||
d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1)
|
||||
den = (d00 * d11 - d01 * d01).clamp(min=1e-20)
|
||||
|
||||
for g in sorted(set(bsz.tolist())): # one batch per pow2 grid
|
||||
sel = (bsz == g).nonzero(as_tuple=True)[0]
|
||||
m = sel.shape[0]
|
||||
xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1)
|
||||
cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1)
|
||||
gi = torch.arange(g, device=device)
|
||||
px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int
|
||||
pxf = px.float() + 0.5; pyf = py.float() + 0.5
|
||||
v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1)
|
||||
d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1)
|
||||
d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1)
|
||||
idn = den[sel].view(m, 1, 1).reciprocal()
|
||||
vv = torch.addcmul(d11[sel].view(m, 1, 1) * d20, d01[sel].view(m, 1, 1), d21, value=-1) * idn
|
||||
ww = torch.addcmul(d00[sel].view(m, 1, 1) * d21, d01[sel].view(m, 1, 1), d20, value=-1) * idn
|
||||
uu = 1.0 - vv - ww
|
||||
inside = (uu >= -1e-6) & (vv >= -1e-6) & (ww >= -1e-6)
|
||||
if padding > 0:
|
||||
inside = _dilate_local(inside, padding)
|
||||
valid = inside & (px < bwp) & (py < bhp)
|
||||
flat = (cbase[cc].view(m, 1, 1) + py * bwp + px)[valid]
|
||||
buf[flat] = True
|
||||
return buf, cbase
|
||||
|
||||
|
||||
def _build_candidates_gpu(sky_t, cur_w, cur_h, bw0, bw1, step, rand_n, gen, device):
|
||||
"""Skyline-flush + edge-sweep + random candidate (x,y) positions per orientation, built on the
|
||||
GPU. Returns (cand0, cand1). Random samples find tight pockets the deterministic grid misses."""
|
||||
xs = torch.arange(0, max(cur_w, 1) + 1, step, device=device)
|
||||
ys = torch.arange(0, max(cur_h, 1) + 1, step, device=device)
|
||||
# edge-sweep candidates are orientation-independent: build once, shared by both orientations
|
||||
common = [torch.stack([xs, torch.full_like(xs, yf)], 1) for yf in (0, cur_h)]
|
||||
common += [torch.stack([torch.full_like(ys, xf), ys], 1) for xf in (0, cur_w)]
|
||||
common = torch.cat(common, 0)
|
||||
out = []
|
||||
for cw in (bw0, bw1): # skyline-flush + random differ
|
||||
if cw > 0 and sky_t.shape[0] >= cw:
|
||||
wmax = sky_t.unfold(0, cw, 1).amax(1)[xs.clamp(max=max(sky_t.shape[0] - cw, 0))]
|
||||
else:
|
||||
wmax = torch.zeros_like(xs)
|
||||
parts = [torch.stack([xs, wmax], 1), common]
|
||||
if rand_n > 0: # distinct draws keep density
|
||||
rx = torch.randint(0, max(cur_w, 1) + 1, (rand_n,), generator=gen, device=device)
|
||||
ry = torch.randint(0, max(cur_h, 1) + 1, (rand_n,), generator=gen, device=device)
|
||||
parts.append(torch.stack([rx, ry], 1))
|
||||
out.append(torch.cat(parts, 0))
|
||||
return out[0], out[1]
|
||||
|
||||
|
||||
def _col_top(b: Tensor) -> Tensor:
|
||||
"""Topmost True row index per column of a bool bitmap (h,w); -1 for empty columns."""
|
||||
h = b.shape[0]
|
||||
rows = torch.arange(h, device=b.device)[:, None]
|
||||
return torch.where(b, rows, torch.full_like(rows.expand_as(b), -1)).amax(0)
|
||||
|
||||
|
||||
def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cur_h, device):
|
||||
"""Lowest-score non-colliding candidate as a (3,) int tensor [x, y, swap] (x=-1 if none).
|
||||
Collision tests only each bitmap's True-pixel offsets (pix), not the full window. Fully on-GPU;
|
||||
the caller does the single sync (.tolist())."""
|
||||
INF = 1 << 60
|
||||
|
||||
def best(cand, pix, dim): # -> (score, x, y) 0-d tensors
|
||||
ch, cw = dim
|
||||
cx, cy = cand[:, 0], cand[:, 1]
|
||||
coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather
|
||||
cx[:, None] + pix[:, 1][None, :]].any(dim=1)
|
||||
nw = torch.clamp(cx + cw, min=cur_w); nh = torch.clamp(cy + ch, min=cur_h)
|
||||
ext = torch.maximum(nw, nh)
|
||||
score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh)
|
||||
j = score.argmin()
|
||||
return score[j], cx[j], cy[j]
|
||||
|
||||
s0, x0, y0 = best(cand0, pix0, dim0)
|
||||
s1, x1, y1 = best(cand1, pix1, dim1)
|
||||
take0 = s0 <= s1
|
||||
bsc = torch.where(take0, s0, s1)
|
||||
pick = torch.stack([torch.where(take0, x0, x1), torch.where(take0, y0, y1),
|
||||
torch.where(take0, x0.new_zeros(()), x0.new_ones(()))])
|
||||
return torch.where(bsc < INF, pick, torch.tensor([-1, -1, 0], device=device))
|
||||
|
||||
|
||||
def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
||||
texels_per_unit, padding_texels):
|
||||
"""Torch rasterize-and-place packer (numba-free fallback). Returns (placements, atlas_w, atlas_h)."""
|
||||
n = len(chart_uvs)
|
||||
if n == 0:
|
||||
return [], 1, 1
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
ang = torch.linspace(0.0, math.pi / 2.0, 37, device=device)[:-1]
|
||||
cos_a, sin_a = ang.cos(), ang.sin()
|
||||
|
||||
# ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
|
||||
vcount = [int(u.shape[0]) for u in chart_uvs]
|
||||
fcount = [int(f.shape[0]) for f in chart_faces]
|
||||
vmax = max(vcount); fmax = max(fcount)
|
||||
uvs_pad = torch.zeros(n, vmax, 2, device=device)
|
||||
vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
|
||||
faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
|
||||
fmask = torch.zeros(n, fmax, dtype=torch.bool, device=device)
|
||||
for i in range(n):
|
||||
uvs_pad[i, :vcount[i]] = chart_uvs[i].to(device=device, dtype=torch.float32)
|
||||
vmask[i, :vcount[i]] = True
|
||||
if fcount[i]:
|
||||
faces_pad[i, :fcount[i]] = chart_faces[i].to(device=device, dtype=torch.long)
|
||||
fmask[i, :fcount[i]] = True
|
||||
u0, u1 = uvs_pad[..., 0], uvs_pad[..., 1] # (N,Vmax)
|
||||
BIG = 1e30
|
||||
mlo = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), BIG))
|
||||
mhi = torch.where(vmask, torch.zeros_like(u0), u0.new_full((), -BIG))
|
||||
xr = torch.addcmul(u0[:, :, None] * cos_a, u1[:, :, None], sin_a, value=-1) # (N,Vmax,A)
|
||||
yr = torch.addcmul(u0[:, :, None] * sin_a, u1[:, :, None], cos_a)
|
||||
xsp = (xr + mhi[:, :, None]).amax(1) - (xr + mlo[:, :, None]).amin(1) # (N,A) masked span
|
||||
ysp = (yr + mhi[:, :, None]).amax(1) - (yr + mlo[:, :, None]).amin(1)
|
||||
ti = (xsp * ysp).argmin(1) # (N,) best angle per chart
|
||||
cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1)
|
||||
rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax)
|
||||
ry = torch.addcmul(u0 * ss, u1, cc)
|
||||
rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1) # (N,)
|
||||
rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1)
|
||||
a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
|
||||
au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
|
||||
base = (a3 / au).sqrt() * texels_per_unit
|
||||
maxb = (4.0 * a3.sqrt() * texels_per_unit).clamp_min(8.0)
|
||||
bbm = torch.maximum(rxmax - rxmin, rymax - rymin).clamp_min(1e-12)
|
||||
scale = torch.minimum(base, maxb / bbm) # (N,)
|
||||
uvs_tex_pad = torch.stack([(rx - rxmin[:, None]) * scale[:, None],
|
||||
(ry - rymin[:, None]) * scale[:, None]], dim=-1) # (N,Vmax,2)
|
||||
bw_t = ((rxmax - rxmin) * scale).ceil().int() + padding_texels + 1
|
||||
bh_t = ((rymax - rymin) * scale).ceil().int() + padding_texels + 1
|
||||
|
||||
# one sync: pull all per-chart scalars
|
||||
thetas = ang[ti].cpu().tolist()
|
||||
scales = scale.cpu().tolist()
|
||||
bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist()
|
||||
|
||||
# ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
|
||||
buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
|
||||
cb = cbase.cpu().tolist()
|
||||
raw, bnd = [], []
|
||||
for i in range(n):
|
||||
bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
|
||||
raw.append(bm)
|
||||
rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device)
|
||||
rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty)
|
||||
cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
|
||||
bnd.append(torch.stack([rmax, cmax]))
|
||||
bnd_cpu = torch.stack(bnd).cpu().tolist() # one sync for all trim bounds
|
||||
|
||||
# per-chart True-pixel offsets (sparse collision/blit), dims, col-tops (all kept on GPU)
|
||||
pix_l, pixr_l, dim_l, dimr_l, bm_h = [], [], [], [], []
|
||||
col_tops, col_tops_rot = [], []
|
||||
for i in range(n):
|
||||
rm, cm = bnd_cpu[i]
|
||||
bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
|
||||
else torch.zeros((1, 1), dtype=torch.bool, device=device))
|
||||
bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
|
||||
pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero())
|
||||
dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
|
||||
dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
|
||||
col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot))
|
||||
bm_h.append(int(bm.shape[0]))
|
||||
wmax = max(d[1] for d in dim_l + dimr_l)
|
||||
ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
|
||||
ctr_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
|
||||
for i in range(n):
|
||||
ct_pad[i, :col_tops[i].shape[0]] = col_tops[i]
|
||||
ctr_pad[i, :col_tops_rot[i].shape[0]] = col_tops_rot[i]
|
||||
del raw
|
||||
|
||||
# ---- Placement: skyline bin-pack on GPU (1 sync/chart for the chosen position) ----
|
||||
order = sorted(range(n), key=lambda i: -(dim_l[i][0] * dim_l[i][1])) # biggest bitmap first
|
||||
max_b = max(max(d) for d in dim_l)
|
||||
margin = max_b + 8
|
||||
side_guess = int(math.sqrt(sum(d[0] * d[1] for d in dim_l)) * 2) + 16
|
||||
cap = side_guess + margin
|
||||
atlas = torch.zeros((cap, cap), dtype=torch.bool, device=device)
|
||||
sky_t = torch.zeros(cap, dtype=torch.long, device=device)
|
||||
cur_w = cur_h = 0
|
||||
placements = [None] * n
|
||||
gen = torch.Generator(device=device).manual_seed(0)
|
||||
rand_n = 512 # random samples per orientation
|
||||
|
||||
for ci in order:
|
||||
if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
|
||||
ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
|
||||
na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
|
||||
na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na
|
||||
nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk
|
||||
dim, dimr = dim_l[ci], dimr_l[ci]
|
||||
step = max(1, min(dim[0], dim[1]) // 8)
|
||||
cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
|
||||
res = _best_placement_torch(atlas, pix_l[ci], dim, pixr_l[ci], dimr, cand0, cand1, cur_w, cur_h, device)
|
||||
bx, by, swap = (int(v) for v in res.tolist()) # the one sync/chart
|
||||
if bx < 0:
|
||||
bx, by, swap = cur_w, 0, 0
|
||||
pix = pixr_l[ci] if swap else pix_l[ci]
|
||||
bh_, bw_ = (dimr if swap else dim)
|
||||
atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit
|
||||
cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_)
|
||||
ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift
|
||||
ix = torch.arange(bx, bx + bw_, device=device)
|
||||
sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])
|
||||
placements[ci] = ChartPlacement(chart_id=ci, offset=(float(bx), float(by)),
|
||||
scale=scales[ci], rotation=thetas[ci], swap_xy=bool(swap),
|
||||
chart_h=float(bm_h[ci]))
|
||||
return placements, cur_w, cur_h
|
||||
|
||||
|
||||
def pack_bitmap(
|
||||
chart_uvs: List[Tensor],
|
||||
chart_3d_areas: List[float],
|
||||
chart_uv_areas: List[float],
|
||||
chart_faces: List[Tensor],
|
||||
texels_per_unit: float = 256.0,
|
||||
padding_texels: int = 2,
|
||||
attempts: int = 4096,
|
||||
rng_seed: int = 0,
|
||||
) -> Tuple[List[ChartPlacement], int, int]:
|
||||
"""Rasterize-and-place packer. Returns (placements, atlas_w, atlas_h)."""
|
||||
n = len(chart_uvs)
|
||||
if n == 0:
|
||||
return [], 1, 1
|
||||
if not _HAVE_NUMBA_PACK:
|
||||
return _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas,
|
||||
chart_faces, texels_per_unit, padding_texels)
|
||||
|
||||
rng = np.random.default_rng(rng_seed)
|
||||
prepared: List[_PreparedChart] = []
|
||||
skyline_cap = 4096
|
||||
skyline = np.zeros(skyline_cap, dtype=np.int64)
|
||||
|
||||
for i, (uvs_t, area_3d, area_uv, faces_t) in enumerate(
|
||||
zip(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces)
|
||||
):
|
||||
uvs = uvs_t.detach().cpu().numpy().astype(np.float64)
|
||||
faces = faces_t.detach().cpu().numpy()
|
||||
|
||||
theta = _best_rotation(uvs)
|
||||
rotated = _rotate_xy(uvs, theta)
|
||||
scale = math.sqrt(max(area_3d, 1e-12) / max(area_uv, 1e-12)) * texels_per_unit
|
||||
# Cap per-chart bbox to 4x nominal so a degenerate chart can't span the atlas.
|
||||
nominal_side = math.sqrt(max(area_3d, 1e-12)) * float(texels_per_unit)
|
||||
max_bbox_texels = max(8.0, 4.0 * nominal_side)
|
||||
bbox_uv = (rotated.max(axis=0) - rotated.min(axis=0))
|
||||
bbox_uv_max = float(max(bbox_uv[0], bbox_uv[1], 1e-12))
|
||||
if scale * bbox_uv_max > max_bbox_texels:
|
||||
scale = max_bbox_texels / bbox_uv_max
|
||||
uvs_tex = rotated * scale
|
||||
uvs_tex = uvs_tex - uvs_tex.min(axis=0)
|
||||
bbox_w = int(math.ceil(uvs_tex[:, 0].max())) + padding_texels + 1
|
||||
bbox_h = int(math.ceil(uvs_tex[:, 1].max())) + padding_texels + 1
|
||||
|
||||
bm = _rasterize_chart(uvs_tex, faces, bbox_w, bbox_h, padding_texels)
|
||||
nz_rows = np.where(bm.any(axis=1))[0]
|
||||
nz_cols = np.where(bm.any(axis=0))[0]
|
||||
if nz_rows.size == 0 or nz_cols.size == 0:
|
||||
bm = np.zeros((1, 1), dtype=bool)
|
||||
bbox_h, bbox_w = 1, 1
|
||||
else:
|
||||
bm = bm[: nz_rows[-1] + 1, : nz_cols[-1] + 1]
|
||||
bbox_h, bbox_w = bm.shape
|
||||
# True 90 deg rotation; plain transpose would mirror and flip winding.
|
||||
bm_rot = bm.T[:, ::-1].copy()
|
||||
|
||||
perim = _chart_perimeter(uvs_tex, faces)
|
||||
prepared.append(
|
||||
_PreparedChart(
|
||||
chart_id=i,
|
||||
uvs_tex=uvs_tex,
|
||||
bitmap=bm,
|
||||
bitmap_rot=bm_rot,
|
||||
bbox_w=bbox_w,
|
||||
bbox_h=bbox_h,
|
||||
rotation=theta,
|
||||
s_tex=scale,
|
||||
perimeter=perim,
|
||||
)
|
||||
)
|
||||
|
||||
order = sorted(range(n), key=lambda i: -prepared[i].perimeter)
|
||||
|
||||
total_area = sum(p.bbox_w * p.bbox_h for p in prepared)
|
||||
side_guess = int(math.sqrt(total_area) * 2) + 16
|
||||
atlas = np.zeros((side_guess, side_guess), dtype=bool)
|
||||
cur_w = 0
|
||||
cur_h = 0
|
||||
|
||||
placements: List[ChartPlacement] = [None] * n # type: ignore
|
||||
|
||||
for ci in order:
|
||||
p = prepared[ci]
|
||||
|
||||
step = max(1, min(p.bbox_w, p.bbox_h) // 8)
|
||||
det_arr = _build_candidates_jit(
|
||||
skyline, cur_w, cur_h,
|
||||
p.bitmap.shape[1], p.bitmap.shape[0],
|
||||
p.bitmap_rot.shape[1], p.bitmap_rot.shape[0],
|
||||
step,
|
||||
)
|
||||
|
||||
x_range = max(cur_w + 1, 1)
|
||||
y_range = max(cur_h + 1, 1)
|
||||
rand_x = rng.integers(0, x_range, size=attempts).astype(np.int64)
|
||||
rand_y = rng.integers(0, y_range, size=attempts).astype(np.int64)
|
||||
rand_swap = (np.arange(attempts) & 1).astype(np.int64)
|
||||
rand_arr = np.stack([rand_x, rand_y, rand_swap], axis=1)
|
||||
candidates = np.concatenate([det_arr, rand_arr], axis=0) if det_arr.size else rand_arr
|
||||
|
||||
best_x, best_y, best_score_int, best_swap_int = _best_placement_jit(
|
||||
atlas, p.bitmap, p.bitmap_rot, candidates, cur_w, cur_h,
|
||||
)
|
||||
best_swap = bool(best_swap_int)
|
||||
|
||||
if best_x >= 0:
|
||||
bm_b = p.bitmap_rot if best_swap else p.bitmap
|
||||
need_h = max(cur_h, best_y + bm_b.shape[0])
|
||||
need_w = max(cur_w, best_x + bm_b.shape[1])
|
||||
if atlas.shape[0] < need_h or atlas.shape[1] < need_w:
|
||||
target_h = max(atlas.shape[0], need_h, side_guess)
|
||||
target_w = max(atlas.shape[1], need_w, side_guess)
|
||||
new_atlas = np.zeros((target_h, target_w), dtype=bool)
|
||||
new_atlas[: atlas.shape[0], : atlas.shape[1]] = atlas
|
||||
atlas = new_atlas
|
||||
|
||||
if best_x < 0:
|
||||
# Fallback: place at extension corner.
|
||||
best_x, best_y = cur_w, 0
|
||||
best_swap = False
|
||||
bm = p.bitmap
|
||||
need_h = max(cur_h, best_y + bm.shape[0])
|
||||
need_w = max(cur_w, best_x + bm.shape[1])
|
||||
if atlas.shape[0] < need_h or atlas.shape[1] < need_w:
|
||||
target_h = max(atlas.shape[0], need_h)
|
||||
target_w = max(atlas.shape[1], need_w)
|
||||
new_atlas = np.zeros((target_h, target_w), dtype=bool)
|
||||
new_atlas[: atlas.shape[0], : atlas.shape[1]] = atlas
|
||||
atlas = new_atlas
|
||||
|
||||
bm = p.bitmap_rot if best_swap else p.bitmap
|
||||
_blit(atlas, bm, best_x, best_y)
|
||||
cur_w = max(cur_w, best_x + bm.shape[1])
|
||||
cur_h = max(cur_h, best_y + bm.shape[0])
|
||||
if cur_w + 1 > skyline.shape[0]:
|
||||
new_sky = np.zeros(max(skyline.shape[0] * 2, cur_w + 1), dtype=np.int64)
|
||||
new_sky[: skyline.shape[0]] = skyline
|
||||
skyline = new_sky
|
||||
_update_skyline_jit(skyline, bm, best_x, best_y)
|
||||
|
||||
placements[ci] = ChartPlacement(
|
||||
chart_id=ci,
|
||||
offset=(float(best_x), float(best_y)),
|
||||
scale=p.s_tex,
|
||||
rotation=p.rotation,
|
||||
swap_xy=best_swap,
|
||||
chart_h=float(p.bitmap.shape[0]),
|
||||
)
|
||||
|
||||
return placements, cur_w, cur_h
|
||||
|
||||
|
||||
def apply_placements(
|
||||
chart_uvs: List[Tensor], placements: List[ChartPlacement], atlas_w: int, atlas_h: int
|
||||
) -> List[Tensor]:
|
||||
"""Apply per-chart (rotation, scale, swap_xy, offset) and normalize by the larger atlas side (shared scale keeps texel density uniform)."""
|
||||
out: List[Tensor] = []
|
||||
side = float(max(atlas_w, atlas_h, 1))
|
||||
for uvs, p in zip(chart_uvs, placements):
|
||||
device = uvs.device
|
||||
dtype = uvs.dtype
|
||||
uvs_np = uvs.detach().cpu().numpy().astype(np.float64)
|
||||
if p.rotation != 0.0:
|
||||
uvs_np = _rotate_xy(uvs_np, p.rotation)
|
||||
uvs_np = uvs_np - uvs_np.min(axis=0)
|
||||
uvs_np = uvs_np * p.scale
|
||||
if p.swap_xy:
|
||||
# 90 deg rotation matching bm.T[:, ::-1]: (u, v) -> (chart_h - v, u).
|
||||
u_old = uvs_np[:, 0].copy()
|
||||
uvs_np[:, 0] = p.chart_h - uvs_np[:, 1]
|
||||
uvs_np[:, 1] = u_old
|
||||
uvs_np[:, 0] += p.offset[0]
|
||||
uvs_np[:, 1] += p.offset[1]
|
||||
uvs_np /= side
|
||||
# Clamp into [0,1]; slivers can stick sub-texel past the tracked extent.
|
||||
np.clip(uvs_np, 0.0, 1.0, out=uvs_np)
|
||||
out.append(torch.from_numpy(uvs_np).to(device=device, dtype=dtype))
|
||||
return out
|
||||
387
comfy_extras/mesh3d/uv_unwrap/parameterize.py
Normal file
387
comfy_extras/mesh3d/uv_unwrap/parameterize.py
Normal file
@ -0,0 +1,387 @@
|
||||
"""Chart parameterization: ortho PCA projection, falling back to ABF/LSCM."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import scipy.sparse.linalg as spla
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from . import mesh as _mesh
|
||||
|
||||
|
||||
def solve_least_squares(A: sp.csr_matrix, b: np.ndarray) -> np.ndarray:
|
||||
"""Solve ||Ax - b||^2 by factorizing AtA."""
|
||||
At = A.T.tocsr()
|
||||
AtA = (At @ A).tocsc()
|
||||
Atb = At @ b
|
||||
return spla.spsolve(AtA, Atb)
|
||||
|
||||
|
||||
def _triangle_local_2d(verts_3d: np.ndarray, faces: np.ndarray) -> np.ndarray:
|
||||
"""Per-triangle 2D coords [F, 3, 2] with v0 at origin, v1 along +x."""
|
||||
v0 = verts_3d[faces[:, 0]]
|
||||
v1 = verts_3d[faces[:, 1]]
|
||||
v2 = verts_3d[faces[:, 2]]
|
||||
e01 = v1 - v0
|
||||
e02 = v2 - v0
|
||||
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
|
||||
x_axis = e01 / L01[:, None]
|
||||
n = np.cross(e01, e02)
|
||||
n /= np.linalg.norm(n, axis=1, keepdims=True).clip(min=1e-20)
|
||||
y_axis = np.cross(n, x_axis)
|
||||
|
||||
out = np.zeros((faces.shape[0], 3, 2), dtype=np.float64)
|
||||
out[:, 1, 0] = L01
|
||||
out[:, 2, 0] = (e02 * x_axis).sum(axis=1)
|
||||
out[:, 2, 1] = (e02 * y_axis).sum(axis=1)
|
||||
return out
|
||||
|
||||
|
||||
def _pick_pins(loops: List[List[int]], verts_3d: np.ndarray) -> Tuple[int, int]:
|
||||
"""Pick the longest-diameter axis-extremal boundary vertex pair across all boundary verts."""
|
||||
if not loops:
|
||||
# Closed surface: two far verts via two-pass farthest.
|
||||
d2 = np.sum((verts_3d - verts_3d[0]) ** 2, axis=1)
|
||||
a = int(np.argmax(d2))
|
||||
d2 = np.sum((verts_3d - verts_3d[a]) ** 2, axis=1)
|
||||
b = int(np.argmax(d2))
|
||||
return a, b
|
||||
boundary_verts: List[int] = []
|
||||
for loop in loops:
|
||||
boundary_verts.extend(loop)
|
||||
seen = set()
|
||||
uniq = []
|
||||
for v in boundary_verts:
|
||||
if v not in seen:
|
||||
seen.add(v)
|
||||
uniq.append(v)
|
||||
bv = np.asarray(uniq, dtype=np.int64)
|
||||
pts = verts_3d[bv]
|
||||
pin_pairs = []
|
||||
for axis in range(3):
|
||||
i_min = int(bv[int(np.argmin(pts[:, axis]))])
|
||||
i_max = int(bv[int(np.argmax(pts[:, axis]))])
|
||||
d = float(np.linalg.norm(verts_3d[i_min] - verts_3d[i_max]))
|
||||
pin_pairs.append((d, i_min, i_max))
|
||||
d0, _, _ = pin_pairs[0]
|
||||
d1, _, _ = pin_pairs[1]
|
||||
d2, _, _ = pin_pairs[2]
|
||||
if d0 > d1 and d0 > d2:
|
||||
_, a, b = pin_pairs[0]
|
||||
elif d1 > d2:
|
||||
_, a, b = pin_pairs[1]
|
||||
else:
|
||||
_, a, b = pin_pairs[2]
|
||||
return a, b
|
||||
|
||||
|
||||
def _ortho_project(verts_3d: np.ndarray) -> np.ndarray:
|
||||
"""PCA-fit plane normal, axis-aligned tangent, project verts to 2D."""
|
||||
centroid = verts_3d.mean(axis=0)
|
||||
pts = verts_3d - centroid
|
||||
cov = pts.T @ pts
|
||||
_w, ev = np.linalg.eigh(cov)
|
||||
normal = ev[:, 0]
|
||||
a = np.abs(normal)
|
||||
if a[0] < a[1] and a[0] < a[2]:
|
||||
t = np.array([1.0, 0.0, 0.0])
|
||||
elif a[1] < a[2]:
|
||||
t = np.array([0.0, 1.0, 0.0])
|
||||
else:
|
||||
t = np.array([0.0, 0.0, 1.0])
|
||||
t = t - normal * float(np.dot(normal, t))
|
||||
t /= max(float(np.linalg.norm(t)), 1e-20)
|
||||
b = np.cross(normal, t)
|
||||
return np.stack([verts_3d @ t, verts_3d @ b], axis=1)
|
||||
|
||||
|
||||
def _stretch_metrics(verts_3d: np.ndarray, uvs: np.ndarray, faces: np.ndarray) -> Tuple[float, float, int, int]:
|
||||
"""Sander's stretch metric. Returns (rms, max, n_flipped, n_zero_area)."""
|
||||
p = verts_3d[faces]
|
||||
t = uvs[faces]
|
||||
parametric_area = 0.5 * (
|
||||
(t[:, 1, 1] - t[:, 0, 1]) * (t[:, 2, 0] - t[:, 0, 0])
|
||||
- (t[:, 2, 1] - t[:, 0, 1]) * (t[:, 1, 0] - t[:, 0, 0])
|
||||
)
|
||||
n_flipped = int((parametric_area < -1e-12).sum())
|
||||
n_zero = int((np.abs(parametric_area) < 1e-12).sum())
|
||||
pa = np.abs(parametric_area).clip(min=1e-20)
|
||||
geom_area = 0.5 * np.linalg.norm(
|
||||
np.cross(p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]), axis=1
|
||||
)
|
||||
keep = (geom_area > 1e-12) & (np.abs(parametric_area) > 1e-12)
|
||||
if not keep.any():
|
||||
return float("inf"), float("inf"), n_flipped, n_zero
|
||||
t1 = t[:, 0, 0]; s1 = t[:, 0, 1]
|
||||
t2 = t[:, 1, 0]; s2 = t[:, 1, 1]
|
||||
t3 = t[:, 2, 0]; s3 = t[:, 2, 1]
|
||||
inv_2pa = 1.0 / (2.0 * pa)
|
||||
Ss = (
|
||||
p[:, 0] * (t2 - t3)[:, None]
|
||||
+ p[:, 1] * (t3 - t1)[:, None]
|
||||
+ p[:, 2] * (t1 - t2)[:, None]
|
||||
) * inv_2pa[:, None]
|
||||
St = (
|
||||
p[:, 0] * (s3 - s2)[:, None]
|
||||
+ p[:, 1] * (s1 - s3)[:, None]
|
||||
+ p[:, 2] * (s2 - s1)[:, None]
|
||||
) * inv_2pa[:, None]
|
||||
a = (Ss * Ss).sum(axis=1)
|
||||
bb = (Ss * St).sum(axis=1)
|
||||
c = (St * St).sum(axis=1)
|
||||
sigma2_sq = 0.5 * (a + c + np.sqrt(np.maximum(0.0, (a - c) ** 2 + 4 * bb ** 2)))
|
||||
rms_sq = (a + c) * 0.5
|
||||
rms_stretch_sq_sum = float((rms_sq[keep] * geom_area[keep]).sum())
|
||||
total_geom = float(geom_area[keep].sum())
|
||||
total_param = float(pa[keep].sum())
|
||||
if total_geom <= 0.0:
|
||||
return float("inf"), float("inf"), n_flipped, n_zero
|
||||
norm_factor = np.sqrt(total_param / total_geom)
|
||||
rms_stretch = float(np.sqrt(rms_stretch_sq_sum / total_geom)) * norm_factor
|
||||
max_stretch = float(np.sqrt(sigma2_sq[keep].max())) * norm_factor
|
||||
return rms_stretch, max_stretch, n_flipped, n_zero
|
||||
|
||||
|
||||
def _uv_boundary_self_intersects(
|
||||
uvs: np.ndarray, faces: np.ndarray, face_face: np.ndarray, eps: float = 1e-9
|
||||
) -> bool:
|
||||
"""True if any chart-boundary edge pair crosses in 2D (ortho folded the chart)."""
|
||||
fi, ei = np.nonzero(face_face < 0)
|
||||
n = fi.size
|
||||
if n < 2:
|
||||
return False
|
||||
a = uvs[faces[fi, ei]].astype(np.float64)
|
||||
b = uvs[faces[fi, (ei + 1) % 3]].astype(np.float64)
|
||||
d = b - a
|
||||
# Pairwise segment crossings, row-chunked to bound memory at chunk*n.
|
||||
chunk = max(1, min(n, 1_000_000 // max(n, 1)))
|
||||
for s in range(0, n, chunk):
|
||||
e = min(s + chunk, n)
|
||||
d1 = d[s:e, None, :]
|
||||
denom = d1[:, :, 0] * d[None, :, 1] - d1[:, :, 1] * d[None, :, 0]
|
||||
rx = a[None, :, 0] - a[s:e, None, 0]
|
||||
ry = a[None, :, 1] - a[s:e, None, 1]
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
t = (rx * d[None, :, 1] - ry * d[None, :, 0]) / denom
|
||||
u = (rx * d1[:, :, 1] - ry * d1[:, :, 0]) / denom
|
||||
cross = (
|
||||
(np.abs(denom) >= eps)
|
||||
& (t > eps) & (t < 1.0 - eps)
|
||||
& (u > eps) & (u < 1.0 - eps)
|
||||
)
|
||||
if bool(cross.any()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def parametrize_chart(
|
||||
local_verts: Tensor, local_faces: Tensor, local_face_face: Tensor
|
||||
) -> Tensor:
|
||||
"""Parameterize one chart: ortho first, ABF/LSCM fallback; charts <=5 faces stay ortho."""
|
||||
verts_np = local_verts.detach().cpu().numpy().astype(np.float64)
|
||||
faces_np = local_faces.detach().cpu().numpy().astype(np.int64)
|
||||
if verts_np.shape[0] < 3 or faces_np.shape[0] == 0:
|
||||
return torch.zeros((verts_np.shape[0], 2), dtype=torch.float32, device=local_verts.device)
|
||||
|
||||
ortho = _ortho_project(verts_np)
|
||||
n_faces = faces_np.shape[0]
|
||||
if n_faces <= 5:
|
||||
return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device)
|
||||
rms, mx, n_flip, n_zero = _stretch_metrics(verts_np, ortho, faces_np)
|
||||
flip_ok = n_flip == 0 or n_flip == n_faces
|
||||
if flip_ok and n_zero == 0 and rms <= 1.5 and mx <= 2.0:
|
||||
ff_np = local_face_face.detach().cpu().numpy().astype(np.int64)
|
||||
if not _uv_boundary_self_intersects(ortho, faces_np, ff_np):
|
||||
return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device)
|
||||
uvs_t = lscm_chart(local_verts, local_faces, local_face_face, pin_positions=ortho)
|
||||
# Collapsed UV island (aspect > 100:1) blows up packing scale; fall back to ortho.
|
||||
uvs_np = uvs_t.detach().cpu().numpy()
|
||||
bbox = uvs_np.max(axis=0) - uvs_np.min(axis=0)
|
||||
bbox_max = float(max(bbox[0], bbox[1], 1e-12))
|
||||
bbox_min = float(max(min(bbox[0], bbox[1]), 1e-12))
|
||||
if bbox_max / bbox_min > 100.0:
|
||||
return torch.from_numpy(ortho.astype(np.float32)).to(local_verts.device)
|
||||
return uvs_t
|
||||
|
||||
|
||||
def _abf_face_coefficients(
|
||||
verts_3d: np.ndarray, faces: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Per-face ABF constraint (largest-sine vertex at local index 2); returns (faces_reordered, cosine, sine, valid_mask) with valid_mask False for degenerate tris."""
|
||||
Fc = faces.shape[0]
|
||||
p0 = verts_3d[faces[:, 0]]
|
||||
p1 = verts_3d[faces[:, 1]]
|
||||
p2 = verts_3d[faces[:, 2]]
|
||||
e01 = p1 - p0
|
||||
e12 = p2 - p1
|
||||
e20 = p0 - p2
|
||||
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
|
||||
L12 = np.linalg.norm(e12, axis=1).clip(min=1e-20)
|
||||
L20 = np.linalg.norm(e20, axis=1).clip(min=1e-20)
|
||||
cos_a0 = ((-e20) * e01).sum(axis=1) / (L20 * L01)
|
||||
cos_a1 = ((-e01) * e12).sum(axis=1) / (L01 * L12)
|
||||
cos_a2 = ((-e12) * e20).sum(axis=1) / (L12 * L20)
|
||||
cos_a0 = cos_a0.clip(-1.0, 1.0)
|
||||
cos_a1 = cos_a1.clip(-1.0, 1.0)
|
||||
cos_a2 = cos_a2.clip(-1.0, 1.0)
|
||||
a = np.arccos(cos_a0)
|
||||
b_ang = np.arccos(cos_a1)
|
||||
c_ang = np.arccos(cos_a2)
|
||||
angles = np.stack([a, b_ang, c_ang], axis=1)
|
||||
sines = np.stack([np.sin(a), np.sin(b_ang), np.sin(c_ang)], axis=1)
|
||||
valid = (angles > 1e-12).all(axis=1)
|
||||
ids = faces.astype(np.int64).copy()
|
||||
|
||||
s0, s1, s2 = sines[:, 0], sines[:, 1], sines[:, 2]
|
||||
pattA = (s1 > s0) & (s1 > s2)
|
||||
pattB = (~pattA) & (s0 > s1) & (s0 > s2)
|
||||
|
||||
if pattA.any():
|
||||
old_a = angles[pattA].copy()
|
||||
old_s = sines[pattA].copy()
|
||||
old_id = ids[pattA].copy()
|
||||
angles[pattA] = old_a[:, [2, 0, 1]]
|
||||
sines[pattA] = old_s[:, [2, 0, 1]]
|
||||
ids[pattA] = old_id[:, [2, 0, 1]]
|
||||
if pattB.any():
|
||||
old_a = angles[pattB].copy()
|
||||
old_s = sines[pattB].copy()
|
||||
old_id = ids[pattB].copy()
|
||||
angles[pattB] = old_a[:, [1, 2, 0]]
|
||||
sines[pattB] = old_s[:, [1, 2, 0]]
|
||||
ids[pattB] = old_id[:, [1, 2, 0]]
|
||||
|
||||
a0 = angles[:, 0]
|
||||
s0 = sines[:, 0]
|
||||
s1 = sines[:, 1]
|
||||
s2 = sines[:, 2]
|
||||
c0 = np.cos(a0)
|
||||
ratio = np.where(s2 > 0.0, s1 / s2.clip(min=1e-20), 1.0)
|
||||
cosine = c0 * ratio
|
||||
sine = s0 * ratio
|
||||
return ids, cosine, sine, valid
|
||||
|
||||
|
||||
def lscm_chart(
|
||||
local_verts: Tensor,
|
||||
local_faces: Tensor,
|
||||
local_face_face: Tensor,
|
||||
pin_positions: "np.ndarray | None" = None,
|
||||
) -> Tensor:
|
||||
"""ABF parameterization on one chart (degenerate faces use plain LSCM rows; two pins fix gauge at pin_positions)."""
|
||||
verts_np = local_verts.detach().cpu().numpy().astype(np.float64)
|
||||
faces_np = local_faces.detach().cpu().numpy().astype(np.int64)
|
||||
Vc = verts_np.shape[0]
|
||||
Fc = faces_np.shape[0]
|
||||
|
||||
if Vc < 3 or Fc == 0:
|
||||
return torch.zeros((Vc, 2), dtype=torch.float32, device=local_verts.device)
|
||||
|
||||
loops = _mesh.chart_boundary_loops(local_faces, local_face_face)
|
||||
pin_a, pin_b = _pick_pins(loops, verts_np)
|
||||
|
||||
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
||||
pa = pin_positions[pin_a]
|
||||
pb = pin_positions[pin_b]
|
||||
u_a, v_a = float(pa[0]), float(pa[1])
|
||||
u_b, v_b = float(pb[0]), float(pb[1])
|
||||
else:
|
||||
u_a, v_a = 0.0, 0.0
|
||||
u_b, v_b = 1.0, 0.0
|
||||
|
||||
abf_ids, abf_cos, abf_sin, abf_valid = _abf_face_coefficients(verts_np, faces_np)
|
||||
|
||||
rows_list: List[np.ndarray] = []
|
||||
cols_list: List[np.ndarray] = []
|
||||
vals_list: List[np.ndarray] = []
|
||||
|
||||
# ABF rows for valid faces.
|
||||
valid_idx = np.nonzero(abf_valid)[0]
|
||||
if valid_idx.size:
|
||||
Nv = valid_idx.size
|
||||
id0 = abf_ids[valid_idx, 0]
|
||||
id1 = abf_ids[valid_idx, 1]
|
||||
id2 = abf_ids[valid_idx, 2]
|
||||
cosf = abf_cos[valid_idx]
|
||||
sinf = abf_sin[valid_idx]
|
||||
r_real = valid_idx * 2
|
||||
r_imag = valid_idx * 2 + 1
|
||||
ones = np.ones(Nv, dtype=np.float64)
|
||||
rows_list.extend([r_real] * 5)
|
||||
cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2])
|
||||
vals_list.extend([cosf - 1.0, -sinf, -cosf, sinf, ones])
|
||||
rows_list.extend([r_imag] * 5)
|
||||
cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2 + Vc])
|
||||
vals_list.extend([sinf, cosf - 1.0, -sinf, -cosf, ones])
|
||||
|
||||
# Plain-LSCM rows for invalid (degenerate) faces.
|
||||
invalid_idx = np.nonzero(~abf_valid)[0]
|
||||
if invalid_idx.size:
|
||||
tri2d_inv = _triangle_local_2d(verts_np, faces_np[invalid_idx])
|
||||
twice_area_inv = (
|
||||
tri2d_inv[:, 1, 0] * tri2d_inv[:, 2, 1]
|
||||
- tri2d_inv[:, 1, 1] * tri2d_inv[:, 2, 0]
|
||||
)
|
||||
weight_inv = 1.0 / np.sqrt(2.0 * np.abs(twice_area_inv).clip(min=1e-20))
|
||||
r_real_inv = invalid_idx * 2
|
||||
r_imag_inv = invalid_idx * 2 + 1
|
||||
for j in range(3):
|
||||
jp1 = (j + 1) % 3
|
||||
jp2 = (j + 2) % 3
|
||||
a_j = (tri2d_inv[:, jp1, 0] - tri2d_inv[:, jp2, 0]) * weight_inv
|
||||
b_j = (tri2d_inv[:, jp1, 1] - tri2d_inv[:, jp2, 1]) * weight_inv
|
||||
v_idx = faces_np[invalid_idx, j]
|
||||
rows_list.extend([r_real_inv, r_real_inv, r_imag_inv, r_imag_inv])
|
||||
cols_list.extend([v_idx, v_idx + Vc, v_idx, v_idx + Vc])
|
||||
vals_list.extend([a_j, -b_j, b_j, a_j])
|
||||
|
||||
rows = np.concatenate(rows_list) if rows_list else np.empty(0, dtype=np.int64)
|
||||
cols = np.concatenate(cols_list) if cols_list else np.empty(0, dtype=np.int64)
|
||||
vals = np.concatenate(vals_list) if vals_list else np.empty(0, dtype=np.float64)
|
||||
|
||||
A_full = sp.csr_matrix((vals, (rows, cols)), shape=(2 * Fc, 2 * Vc))
|
||||
|
||||
pin_cols = np.array([pin_a, pin_b, pin_a + Vc, pin_b + Vc], dtype=np.int64)
|
||||
pin_vals = np.array([u_a, u_b, v_a, v_b], dtype=np.float64)
|
||||
|
||||
free_mask = np.ones(2 * Vc, dtype=bool)
|
||||
free_mask[pin_cols] = False
|
||||
free_cols = np.nonzero(free_mask)[0]
|
||||
|
||||
A_pinned = A_full[:, pin_cols]
|
||||
A_free = A_full[:, free_cols]
|
||||
b = -(A_pinned @ pin_vals)
|
||||
|
||||
# Singular system (under-constrained chart) falls back to ortho.
|
||||
fallback_to_ortho = False
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error", category=sp.linalg.MatrixRankWarning)
|
||||
x_free = solve_least_squares(A_free, b)
|
||||
if not np.all(np.isfinite(x_free)):
|
||||
fallback_to_ortho = True
|
||||
except Exception:
|
||||
fallback_to_ortho = True
|
||||
|
||||
if fallback_to_ortho:
|
||||
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
||||
uvs = pin_positions.astype(np.float32)
|
||||
else:
|
||||
uvs = _ortho_project(verts_np).astype(np.float32)
|
||||
return torch.from_numpy(uvs).to(local_verts.device)
|
||||
|
||||
full = np.zeros(2 * Vc, dtype=np.float64)
|
||||
full[free_cols] = x_free
|
||||
full[pin_cols] = pin_vals
|
||||
uvs = np.stack([full[:Vc], full[Vc:]], axis=1).astype(np.float32)
|
||||
if not np.all(np.isfinite(uvs)):
|
||||
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
||||
uvs = pin_positions.astype(np.float32)
|
||||
else:
|
||||
uvs = _ortho_project(verts_np).astype(np.float32)
|
||||
|
||||
return torch.from_numpy(uvs).to(local_verts.device)
|
||||
638
comfy_extras/mesh3d/uv_unwrap/segment.py
Normal file
638
comfy_extras/mesh3d/uv_unwrap/segment.py
Normal file
@ -0,0 +1,638 @@
|
||||
"""Adaptive cost-grow chart segmentation (CPU); numba optional, numpy path is nd-only."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
try:
|
||||
from numba import njit
|
||||
_HAVE_NUMBA = True
|
||||
except ImportError:
|
||||
_HAVE_NUMBA = False
|
||||
def njit(*args, **kwargs): # noqa: ARG001
|
||||
def deco(fn):
|
||||
return fn
|
||||
return deco if not args else args[0]
|
||||
|
||||
|
||||
from .mesh import MeshData, face_edge_lengths
|
||||
|
||||
|
||||
DEFAULT_W_NORMAL_DEVIATION = 2.0
|
||||
DEFAULT_W_ROUNDNESS = 0.01
|
||||
DEFAULT_W_STRAIGHTNESS = 6.0
|
||||
DEFAULT_MAX_COST = 2.0
|
||||
NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75°
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False)
|
||||
def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
|
||||
F = face_normal.shape[0]
|
||||
raw = np.zeros(F, dtype=np.float32)
|
||||
for f in range(F):
|
||||
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
|
||||
s = np.float32(0.0)
|
||||
for e in range(3):
|
||||
nb = face_face[f, e]
|
||||
if nb < 0:
|
||||
continue
|
||||
mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2]
|
||||
d = nx*mx + ny*my + nz*mz
|
||||
s += np.float32(1.0) - d
|
||||
raw[f] = s
|
||||
return raw
|
||||
|
||||
|
||||
def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
|
||||
nb_safe = np.maximum(face_face, 0)
|
||||
nb_normal = face_normal[nb_safe]
|
||||
d = (face_normal[:, None, :] * nb_normal).sum(axis=-1)
|
||||
contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0))
|
||||
return contrib.sum(axis=1).astype(np.float32)
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False)
|
||||
def _farthest_point_seeds_jit(
|
||||
face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray,
|
||||
initial_seeds: np.ndarray, k_target: int,
|
||||
):
|
||||
F = face_centroid.shape[0]
|
||||
INF = np.float32(1e30)
|
||||
min_dist = np.full(F, INF, dtype=np.float32)
|
||||
seeds = np.empty(k_target, dtype=np.int64)
|
||||
n_seeds = 0
|
||||
for i in range(initial_seeds.shape[0]):
|
||||
s = initial_seeds[i]
|
||||
if s < 0 or n_seeds >= k_target:
|
||||
continue
|
||||
seeds[n_seeds] = s
|
||||
n_seeds += 1
|
||||
sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2]
|
||||
for f in range(F):
|
||||
dx = face_centroid[f, 0] - sx
|
||||
dy = face_centroid[f, 1] - sy
|
||||
dz = face_centroid[f, 2] - sz
|
||||
d2 = dx*dx + dy*dy + dz*dz
|
||||
if d2 < min_dist[f]:
|
||||
min_dist[f] = d2
|
||||
while n_seeds < k_target:
|
||||
best_f = -1
|
||||
best_score = np.float32(-1.0)
|
||||
for f in range(F):
|
||||
d = min_dist[f]
|
||||
if d >= INF * np.float32(0.5):
|
||||
continue
|
||||
score = d * face_weight[f]
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_f = f
|
||||
if best_f < 0:
|
||||
break
|
||||
seeds[n_seeds] = best_f
|
||||
n_seeds += 1
|
||||
sx = face_centroid[best_f, 0]
|
||||
sy = face_centroid[best_f, 1]
|
||||
sz = face_centroid[best_f, 2]
|
||||
for f in range(F):
|
||||
dx = face_centroid[f, 0] - sx
|
||||
dy = face_centroid[f, 1] - sy
|
||||
dz = face_centroid[f, 2] - sz
|
||||
d2 = dx*dx + dy*dy + dz*dz
|
||||
if d2 < min_dist[f]:
|
||||
min_dist[f] = d2
|
||||
return seeds[:n_seeds]
|
||||
|
||||
|
||||
def _farthest_point_seeds_numpy(
|
||||
face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
|
||||
):
|
||||
F = face_centroid.shape[0]
|
||||
min_dist = np.full(F, np.inf, dtype=np.float32)
|
||||
seeds: List[int] = []
|
||||
for s in initial_seeds:
|
||||
if s < 0 or len(seeds) >= k_target:
|
||||
continue
|
||||
seeds.append(int(s))
|
||||
d = ((face_centroid - face_centroid[s])**2).sum(axis=-1)
|
||||
min_dist = np.minimum(min_dist, d)
|
||||
while len(seeds) < k_target:
|
||||
best = int(np.argmax(min_dist))
|
||||
if not np.isfinite(min_dist[best]) or min_dist[best] <= 0:
|
||||
break
|
||||
seeds.append(best)
|
||||
d = ((face_centroid - face_centroid[best])**2).sum(axis=-1)
|
||||
min_dist = np.minimum(min_dist, d)
|
||||
return np.asarray(seeds, dtype=np.int64)
|
||||
|
||||
|
||||
@njit(cache=True, fastmath=False)
|
||||
def _cost_grow_iter_jit(
|
||||
face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
|
||||
face_area: np.ndarray, face_edge_len: np.ndarray,
|
||||
chart_basis: np.ndarray, chart_normal_sum: np.ndarray,
|
||||
chart_area: np.ndarray, chart_perim: np.ndarray,
|
||||
nd_cutoff: float, max_cost: float,
|
||||
w_nd: float, w_round: float, w_straight: float,
|
||||
):
|
||||
"""One grow iter: each unassigned face joins its lowest-cost adjacent chart if cost<max_cost."""
|
||||
F = face_chart.shape[0]
|
||||
best_chart_per_face = np.full(F, -1, dtype=np.int64)
|
||||
best_cost_per_face = np.full(F, np.inf, dtype=np.float32)
|
||||
|
||||
for f in range(F):
|
||||
if face_chart[f] != -1:
|
||||
continue
|
||||
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
|
||||
af = face_area[f]
|
||||
for e0 in range(3):
|
||||
nb0 = face_face[f, e0]
|
||||
if nb0 < 0:
|
||||
continue
|
||||
c = face_chart[nb0]
|
||||
if c < 0:
|
||||
continue
|
||||
d = (nx * chart_basis[c, 0] + ny * chart_basis[c, 1] + nz * chart_basis[c, 2])
|
||||
nd = np.float32(1.0) - d
|
||||
if nd > np.float32(1.0):
|
||||
nd = np.float32(1.0)
|
||||
if nd < np.float32(0.0):
|
||||
nd = np.float32(0.0)
|
||||
if nd >= nd_cutoff:
|
||||
continue
|
||||
l_in = np.float32(0.0)
|
||||
l_out = np.float32(0.0)
|
||||
for e1 in range(3):
|
||||
nb1 = face_face[f, e1]
|
||||
el = face_edge_len[f, e1]
|
||||
if nb1 < 0:
|
||||
l_out += el
|
||||
elif face_chart[nb1] == c:
|
||||
l_in += el
|
||||
else:
|
||||
l_out += el
|
||||
ca = chart_area[c]
|
||||
cp = chart_perim[c]
|
||||
new_perim = cp - l_in + l_out
|
||||
new_area = ca + af
|
||||
if cp <= np.float32(1e-20) or ca <= np.float32(1e-20):
|
||||
round_cost = np.float32(0.0)
|
||||
else:
|
||||
old_r = (cp * cp) / ca
|
||||
new_r = (new_perim * new_perim) / new_area
|
||||
if new_r <= np.float32(1e-20):
|
||||
round_cost = np.float32(0.0)
|
||||
else:
|
||||
round_cost = np.float32(1.0) - old_r / new_r
|
||||
denom = l_out + l_in
|
||||
if denom <= np.float32(1e-20):
|
||||
straight_cost = np.float32(0.0)
|
||||
else:
|
||||
ratio = (l_out - l_in) / denom
|
||||
if ratio < np.float32(0.0):
|
||||
straight_cost = ratio
|
||||
else:
|
||||
straight_cost = np.float32(0.0)
|
||||
cost = (w_nd * nd + w_round * round_cost + w_straight * straight_cost)
|
||||
if cost < best_cost_per_face[f]:
|
||||
best_cost_per_face[f] = cost
|
||||
best_chart_per_face[f] = c
|
||||
|
||||
n_assigned = 0
|
||||
for f in range(F):
|
||||
if face_chart[f] != -1:
|
||||
continue
|
||||
if best_chart_per_face[f] < 0:
|
||||
continue
|
||||
if best_cost_per_face[f] > max_cost:
|
||||
continue
|
||||
c = best_chart_per_face[f]
|
||||
l_in = np.float32(0.0)
|
||||
l_out = np.float32(0.0)
|
||||
for e1 in range(3):
|
||||
nb1 = face_face[f, e1]
|
||||
el = face_edge_len[f, e1]
|
||||
if nb1 < 0:
|
||||
l_out += el
|
||||
elif face_chart[nb1] == c:
|
||||
l_in += el
|
||||
else:
|
||||
l_out += el
|
||||
af = face_area[f]
|
||||
face_chart[f] = c
|
||||
chart_normal_sum[c, 0] += face_normal[f, 0] * af
|
||||
chart_normal_sum[c, 1] += face_normal[f, 1] * af
|
||||
chart_normal_sum[c, 2] += face_normal[f, 2] * af
|
||||
chart_area[c] += af
|
||||
chart_perim[c] = chart_perim[c] - l_in + l_out
|
||||
nx = chart_normal_sum[c, 0]
|
||||
ny = chart_normal_sum[c, 1]
|
||||
nz = chart_normal_sum[c, 2]
|
||||
nlen = np.sqrt(nx * nx + ny * ny + nz * nz)
|
||||
if nlen > np.float32(1e-20):
|
||||
chart_basis[c, 0] = nx / nlen
|
||||
chart_basis[c, 1] = ny / nlen
|
||||
chart_basis[c, 2] = nz / nlen
|
||||
n_assigned += 1
|
||||
return n_assigned
|
||||
|
||||
|
||||
def _renumber(face_chart: np.ndarray, device) -> Tensor:
|
||||
unique = np.unique(face_chart[face_chart >= 0])
|
||||
if unique.size == 0:
|
||||
return torch.from_numpy(face_chart).to(device)
|
||||
remap = np.full(int(unique.max()) + 1, -1, dtype=np.int64)
|
||||
remap[unique] = np.arange(unique.size)
|
||||
out = face_chart.copy()
|
||||
mask = out >= 0
|
||||
out[mask] = remap[out[mask]]
|
||||
return torch.from_numpy(out).to(device)
|
||||
|
||||
|
||||
def _segment_charts_fast(
|
||||
mesh: MeshData,
|
||||
max_cost: float,
|
||||
w_normal_deviation: float,
|
||||
w_roundness: float = DEFAULT_W_ROUNDNESS,
|
||||
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
|
||||
target_chart_count: int = 0,
|
||||
) -> Tensor:
|
||||
"""Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds."""
|
||||
F = mesh.faces.shape[0]
|
||||
device = mesh.faces.device
|
||||
if F == 0:
|
||||
return torch.zeros(0, dtype=torch.long, device=device)
|
||||
|
||||
face_normal = mesh.face_normal.detach().cpu().numpy().astype(np.float32)
|
||||
face_area = mesh.face_area.detach().cpu().numpy().astype(np.float32)
|
||||
face_centroid = mesh.face_centroid.detach().cpu().numpy().astype(np.float32)
|
||||
face_face = mesh.face_face.detach().cpu().numpy()
|
||||
|
||||
face_chart = np.full(F, -1, dtype=np.int64)
|
||||
nd_cutoff = np.float32(NORMAL_DEVIATION_HARD_CUTOFF)
|
||||
nd_threshold = np.float32(min(max_cost / max(w_normal_deviation, 1e-6),
|
||||
NORMAL_DEVIATION_HARD_CUTOFF * 0.99))
|
||||
|
||||
component = (mesh.component.detach().cpu().numpy()
|
||||
if hasattr(mesh.component, "detach") else np.asarray(mesh.component))
|
||||
if component.size:
|
||||
_, first_idx = np.unique(component, return_index=True)
|
||||
initial_seeds = first_idx.astype(np.int64)
|
||||
else:
|
||||
initial_seeds = np.empty(0, dtype=np.int64)
|
||||
|
||||
adaptive_seeding = target_chart_count <= 0
|
||||
if adaptive_seeding:
|
||||
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
|
||||
if not seed_faces:
|
||||
seed_faces = [0]
|
||||
else:
|
||||
if _HAVE_NUMBA:
|
||||
curvature_raw = _face_curvature_jit(face_normal, face_face)
|
||||
else:
|
||||
curvature_raw = _face_curvature_numpy(face_normal, face_face)
|
||||
cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0
|
||||
if cmax > 1e-6:
|
||||
face_weight = (np.float32(1.0) + np.float32(50.0) *
|
||||
(curvature_raw / np.float32(cmax))).astype(np.float32)
|
||||
else:
|
||||
face_weight = np.ones(F, dtype=np.float32)
|
||||
n_comp = int(initial_seeds.size)
|
||||
if n_comp < int(target_chart_count):
|
||||
target_seeds = int(target_chart_count)
|
||||
else:
|
||||
target_seeds = n_comp + max(int(target_chart_count) // 4, 8)
|
||||
target_seeds = min(target_seeds, F)
|
||||
if _HAVE_NUMBA:
|
||||
seeds_arr = _farthest_point_seeds_jit(
|
||||
face_centroid, face_area, face_weight, initial_seeds, target_seeds,
|
||||
)
|
||||
else:
|
||||
seeds_arr = _farthest_point_seeds_numpy(
|
||||
face_centroid, initial_seeds, target_seeds,
|
||||
)
|
||||
seed_faces = [int(s) for s in seeds_arr.tolist()]
|
||||
|
||||
K = len(seed_faces)
|
||||
chart_basis = np.zeros((K, 3), dtype=np.float32)
|
||||
chart_normal_sum = np.zeros((K, 3), dtype=np.float32)
|
||||
chart_area = np.zeros(K, dtype=np.float32)
|
||||
chart_perim = np.zeros(K, dtype=np.float32)
|
||||
face_edge_len = (
|
||||
face_edge_lengths(mesh.vertices, mesh.faces)
|
||||
.detach().cpu().numpy()
|
||||
)
|
||||
for cid, sf in enumerate(seed_faces):
|
||||
face_chart[sf] = cid
|
||||
n = face_normal[sf]
|
||||
a = face_area[sf]
|
||||
chart_basis[cid] = n.astype(np.float32)
|
||||
chart_normal_sum[cid] = (n * a).astype(np.float32)
|
||||
chart_area[cid] = float(a)
|
||||
chart_perim[cid] = float(face_edge_len[sf].sum())
|
||||
|
||||
if K == 0:
|
||||
return _renumber(face_chart, device)
|
||||
|
||||
min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
|
||||
if adaptive_seeding:
|
||||
for sf in seed_faces:
|
||||
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
|
||||
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
|
||||
|
||||
if _HAVE_NUMBA:
|
||||
# Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
|
||||
tau_final = min(max_cost * 0.25, 0.5)
|
||||
thresholds = [t for t in (0.05, 0.1, 0.25) if t < tau_final] + [tau_final]
|
||||
max_inner = max(64, int(np.sqrt(F)) * 2)
|
||||
max_total_charts = max(F, 8000)
|
||||
outer_iter = 0
|
||||
while True:
|
||||
outer_iter += 1
|
||||
if outer_iter > F + 16:
|
||||
break
|
||||
for tau in thresholds:
|
||||
for _ in range(max_inner):
|
||||
n_added = _cost_grow_iter_jit(
|
||||
face_chart, face_face, face_normal, face_area, face_edge_len,
|
||||
chart_basis, chart_normal_sum, chart_area, chart_perim,
|
||||
nd_cutoff, np.float32(tau),
|
||||
np.float32(w_normal_deviation),
|
||||
np.float32(w_roundness),
|
||||
np.float32(w_straightness),
|
||||
)
|
||||
if n_added == 0:
|
||||
break
|
||||
if (face_chart == -1).sum() == 0:
|
||||
break
|
||||
if not adaptive_seeding:
|
||||
break
|
||||
if chart_basis.shape[0] >= max_total_charts:
|
||||
break
|
||||
unassigned_mask = face_chart == -1
|
||||
cand = np.where(unassigned_mask, min_dist_to_seed, np.float32(-np.inf))
|
||||
new_seed = int(np.argmax(cand))
|
||||
n = face_normal[new_seed]
|
||||
a = face_area[new_seed]
|
||||
chart_basis = np.vstack([chart_basis, n[None, :].astype(np.float32)])
|
||||
chart_normal_sum = np.vstack(
|
||||
[chart_normal_sum, (n * a)[None, :].astype(np.float32)]
|
||||
)
|
||||
chart_area = np.concatenate([chart_area, np.array([a], dtype=np.float32)])
|
||||
chart_perim = np.concatenate(
|
||||
[chart_perim, np.array([face_edge_len[new_seed].sum()], dtype=np.float32)]
|
||||
)
|
||||
face_chart[new_seed] = chart_basis.shape[0] - 1
|
||||
new_d = ((face_centroid - face_centroid[new_seed]) ** 2).sum(axis=-1)
|
||||
min_dist_to_seed = np.minimum(min_dist_to_seed, new_d)
|
||||
else:
|
||||
# Numpy fallback: nd-only adaptive grow.
|
||||
for _ in range(max(64, int(np.sqrt(F)) + 32)):
|
||||
unassigned = face_chart == -1
|
||||
if not unassigned.any():
|
||||
break
|
||||
u_idx = np.nonzero(unassigned)[0]
|
||||
nbs = face_face[u_idx]
|
||||
nbs_safe = np.where(nbs >= 0, nbs, 0)
|
||||
nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1)
|
||||
valid = (nb_charts >= 0)
|
||||
if not valid.any():
|
||||
break
|
||||
nb_charts_safe = np.where(valid, nb_charts, 0)
|
||||
nb_basis = chart_basis[nb_charts_safe]
|
||||
d = (face_normal[u_idx][:, None, :] * nb_basis).sum(axis=-1)
|
||||
nd = np.where(valid, np.float32(1.0) - d, np.inf).clip(max=1.0)
|
||||
nd = np.where(nd >= nd_cutoff, np.inf, nd)
|
||||
best_e = np.argmin(nd, axis=1)
|
||||
best_cost = nd[np.arange(u_idx.size), best_e]
|
||||
best_c = nb_charts_safe[np.arange(u_idx.size), best_e]
|
||||
accept = (best_cost <= nd_threshold) & np.isfinite(best_cost)
|
||||
if not accept.any():
|
||||
break
|
||||
pick_u = u_idx[accept]
|
||||
pick_c = best_c[accept]
|
||||
face_chart[pick_u] = pick_c
|
||||
for f, c in zip(pick_u, pick_c):
|
||||
chart_normal_sum[c] += face_normal[f] * face_area[f]
|
||||
chart_area[c] += face_area[f]
|
||||
|
||||
# Orphan cleanup: leftover faces join their best-matching neighbor's chart.
|
||||
if (face_chart == -1).any() and chart_basis.shape[0] > 0:
|
||||
while True:
|
||||
orphans = np.nonzero(face_chart == -1)[0]
|
||||
if orphans.size == 0:
|
||||
break
|
||||
nbs = face_face[orphans]
|
||||
nbs_safe = np.where(nbs >= 0, nbs, 0)
|
||||
nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1)
|
||||
valid = (nb_charts >= 0)
|
||||
if not valid.any():
|
||||
break
|
||||
nb_charts_safe = np.where(valid, nb_charts, 0)
|
||||
nb_basis = chart_basis[nb_charts_safe]
|
||||
d = (face_normal[orphans][:, None, :] * nb_basis).sum(axis=-1)
|
||||
nd = np.where(valid, np.float32(1.0) - d, np.inf)
|
||||
best_e = np.argmin(nd, axis=1)
|
||||
best_c = nb_charts_safe[np.arange(orphans.size), best_e]
|
||||
assignable = valid.any(axis=1)
|
||||
if not assignable.any():
|
||||
break
|
||||
assign_idx = orphans[assignable]
|
||||
assign_c = best_c[assignable]
|
||||
face_chart[assign_idx] = assign_c
|
||||
if (face_chart == -1).any():
|
||||
new_singletons = np.nonzero(face_chart == -1)[0]
|
||||
for f in new_singletons:
|
||||
face_chart[int(f)] = chart_basis.shape[0]
|
||||
chart_basis = np.concatenate(
|
||||
[chart_basis, face_normal[int(f)].astype(np.float32)[None, :]],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
return _renumber(face_chart, device)
|
||||
|
||||
|
||||
def segment_charts(
|
||||
mesh: MeshData,
|
||||
max_cost: float = DEFAULT_MAX_COST,
|
||||
w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
|
||||
w_roundness: float = DEFAULT_W_ROUNDNESS,
|
||||
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
|
||||
target_chart_count: int = 0,
|
||||
) -> Tensor:
|
||||
"""Segment mesh into charts. Returns face -> chart_id."""
|
||||
return _segment_charts_fast(
|
||||
mesh, max_cost=max_cost,
|
||||
w_normal_deviation=w_normal_deviation,
|
||||
w_roundness=w_roundness,
|
||||
w_straightness=w_straightness,
|
||||
target_chart_count=target_chart_count,
|
||||
)
|
||||
|
||||
|
||||
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
|
||||
def _combine_normal_cones(
|
||||
axis_a: Tensor, half_a: Tensor,
|
||||
axis_b: Tensor, half_b: Tensor,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Merge two normal cones along the great circle from axis_a; returns (combined_axis, combined_half_angle, axis_angle)."""
|
||||
cos_angle = (axis_a * axis_b).sum(dim=-1).clamp(-1.0, 1.0)
|
||||
axis_angle = torch.acos(cos_angle)
|
||||
new_low = torch.minimum(-half_a, axis_angle - half_b)
|
||||
new_high = torch.maximum(half_a, axis_angle + half_b)
|
||||
new_half = (new_high - new_low) * 0.5
|
||||
rot_angle = (new_high + new_low) * 0.5
|
||||
b_perp = axis_b - axis_a * cos_angle.unsqueeze(-1)
|
||||
b_perp_norm = b_perp.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
b_perp_unit = b_perp / b_perp_norm
|
||||
new_axis = (
|
||||
axis_a * torch.cos(rot_angle).unsqueeze(-1)
|
||||
+ b_perp_unit * torch.sin(rot_angle).unsqueeze(-1)
|
||||
)
|
||||
new_axis_norm = new_axis.norm(dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
new_axis = new_axis / new_axis_norm
|
||||
return new_axis, new_half, axis_angle
|
||||
|
||||
|
||||
def _build_chart_edges(
|
||||
face_face: Tensor,
|
||||
chart_id: Tensor,
|
||||
face_edge_len: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Build chart-edge list (chart_pairs[E,2] with a<b, edge_length[E]); same-chart edges dropped, duplicates summed."""
|
||||
F = face_face.shape[0]
|
||||
device = face_face.device
|
||||
f_idx = torch.arange(F, device=device).repeat_interleave(3)
|
||||
nb = face_face.flatten()
|
||||
valid = nb >= 0
|
||||
f_idx = f_idx[valid]
|
||||
nb = nb[valid]
|
||||
el = face_edge_len.flatten()[valid]
|
||||
|
||||
ca = chart_id[f_idx]
|
||||
cb = chart_id[nb]
|
||||
diff = ca != cb
|
||||
ca = ca[diff]
|
||||
cb = cb[diff]
|
||||
el = el[diff]
|
||||
if ca.numel() == 0:
|
||||
return (
|
||||
torch.empty((0, 2), dtype=torch.long, device=device),
|
||||
torch.empty(0, device=device),
|
||||
)
|
||||
|
||||
lo = torch.minimum(ca, cb)
|
||||
hi = torch.maximum(ca, cb)
|
||||
V = int(chart_id.max().item()) + 1
|
||||
key = lo * V + hi
|
||||
sort_idx = torch.argsort(key)
|
||||
sorted_key = key[sort_idx]
|
||||
sorted_lo = lo[sort_idx]
|
||||
sorted_hi = hi[sort_idx]
|
||||
sorted_el = el[sort_idx]
|
||||
unique_key, inverse, counts = torch.unique(
|
||||
sorted_key, return_inverse=True, return_counts=True
|
||||
)
|
||||
n_unique = unique_key.shape[0]
|
||||
reduced_el = torch.zeros(n_unique, device=device, dtype=el.dtype)
|
||||
reduced_el.scatter_add_(0, inverse, sorted_el)
|
||||
first_idx = torch.cat([
|
||||
torch.zeros(1, dtype=torch.long, device=device),
|
||||
counts.cumsum(0)[:-1],
|
||||
])
|
||||
pair_lo = sorted_lo[first_idx]
|
||||
pair_hi = sorted_hi[first_idx]
|
||||
chart_pairs = torch.stack([pair_lo, pair_hi], dim=1)
|
||||
return chart_pairs, reduced_el
|
||||
|
||||
|
||||
def cluster_charts_pec(
|
||||
mesh: MeshData,
|
||||
target_chart_count: int = 0,
|
||||
max_cost: float = 0.7,
|
||||
area_penalty_weight: float = 0.0,
|
||||
roundness_weight: float = 0.0,
|
||||
max_iters: int = 1024,
|
||||
) -> Tensor:
|
||||
"""Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg)."""
|
||||
device = mesh.faces.device
|
||||
F = mesh.faces.shape[0]
|
||||
faces = mesh.faces.to(torch.long)
|
||||
vertices = mesh.vertices.to(torch.float32)
|
||||
face_normal = mesh.face_normal.to(torch.float32)
|
||||
face_area = mesh.face_area.to(torch.float32)
|
||||
face_face = mesh.face_face.to(torch.long)
|
||||
|
||||
face_edge_len = face_edge_lengths(vertices, faces)
|
||||
|
||||
chart_id = torch.arange(F, dtype=torch.long, device=device)
|
||||
chart_axis = face_normal.clone()
|
||||
chart_half = torch.zeros(F, dtype=torch.float32, device=device)
|
||||
chart_area = face_area.clone()
|
||||
chart_perim = face_edge_len.sum(dim=1).clone()
|
||||
|
||||
for it in range(max_iters):
|
||||
edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len)
|
||||
if edges.shape[0] == 0:
|
||||
break
|
||||
|
||||
a = edges[:, 0]
|
||||
b = edges[:, 1]
|
||||
axis_a = chart_axis[a]
|
||||
axis_b = chart_axis[b]
|
||||
half_a = chart_half[a]
|
||||
half_b = chart_half[b]
|
||||
_, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
|
||||
cost = new_half.clone()
|
||||
if area_penalty_weight > 0.0:
|
||||
new_area = chart_area[a] + chart_area[b]
|
||||
cost = cost + area_penalty_weight * new_area
|
||||
if roundness_weight > 0.0:
|
||||
new_area_r = chart_area[a] + chart_area[b]
|
||||
new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
|
||||
cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)
|
||||
|
||||
# Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
|
||||
E = edges.shape[0]
|
||||
N = int(chart_id.max().item()) + 1
|
||||
edge_ids = torch.arange(E, dtype=torch.long, device=device)
|
||||
cost_i32 = torch.clamp(cost * 1e6, max=2e9).to(torch.int64)
|
||||
key = (cost_i32 << 32) | edge_ids
|
||||
chart_min = torch.full((N,), (2**62), dtype=torch.long, device=device)
|
||||
chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
|
||||
chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
|
||||
|
||||
# Mutual-min collapse: each chart in at most one merge per iter.
|
||||
is_a_min = chart_min[a] == key
|
||||
is_b_min = chart_min[b] == key
|
||||
within = cost <= max_cost
|
||||
winners = is_a_min & is_b_min & within
|
||||
|
||||
n_merge = int(winners.sum().item())
|
||||
if n_merge == 0:
|
||||
break
|
||||
|
||||
win_a = a[winners]
|
||||
win_b = b[winners]
|
||||
win_el = edge_len[winners]
|
||||
|
||||
axis_a_w = chart_axis[win_a]
|
||||
half_a_w = chart_half[win_a]
|
||||
axis_b_w = chart_axis[win_b]
|
||||
half_b_w = chart_half[win_b]
|
||||
new_axis, new_half_w, _ = _combine_normal_cones(
|
||||
axis_a_w, half_a_w, axis_b_w, half_b_w,
|
||||
)
|
||||
chart_axis[win_a] = new_axis
|
||||
chart_half[win_a] = new_half_w
|
||||
chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
|
||||
chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el
|
||||
|
||||
remap = torch.arange(N, dtype=torch.long, device=device)
|
||||
remap[win_b] = win_a
|
||||
chart_id = remap[chart_id]
|
||||
|
||||
_, inverse = torch.unique(chart_id, sorted=True, return_inverse=True)
|
||||
return inverse
|
||||
@ -6,9 +6,20 @@ from comfy_api.latest import ComfyExtension, IO, Types
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_extras.qem_decimate.qem_core import simplify as qem_decimate_simplify, QEMConfig
|
||||
from server import PromptServer
|
||||
from comfy_extras.mesh3d.postprocess.qem_decimate import (
|
||||
simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate,
|
||||
)
|
||||
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc
|
||||
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
|
||||
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
|
||||
from comfy_extras.mesh3d.uv_unwrap import parameterize as _uv_param
|
||||
from comfy_extras.mesh3d.uv_unwrap import pack as _uv_pack
|
||||
import warnings
|
||||
import logging
|
||||
import scipy
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
|
||||
@ -2162,6 +2173,566 @@ class DecimateMesh(IO.ComfyNode):
|
||||
return result
|
||||
|
||||
|
||||
class RemeshMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
# sign_mode picks the scalar field, and exposes only the knobs relevant to it
|
||||
# (DynamicCombo: udf sub-widgets show for 'udf', sdf sub-widgets for 'sdf').
|
||||
sign_mode_options = [
|
||||
IO.DynamicCombo.Option(key="udf", inputs=[
|
||||
IO.Boolean.Input("qef", default=False,
|
||||
tooltip="Experimental: place dual vertices via QEF (closest-triangle normals) "
|
||||
"instead of edge-crossing centroid. QEF is sign-agnostic so it works "
|
||||
"in UDF too — pulls the ±eps surface back onto the planes for sharper "
|
||||
"edges. May misbehave near the UDF double shell; compare with it off."),
|
||||
IO.Boolean.Input("drop_inverted_components", default=True,
|
||||
tooltip="Drop closed components with inward normals (negative signed volume) — "
|
||||
"the inner shell UDF produces on closed regions."),
|
||||
IO.Boolean.Input("drop_enclosed_components", default=True,
|
||||
tooltip="Drop components whose bbox is inside the largest's AND fail a raycast "
|
||||
"point-in-mesh test. Disable if you have legitimate parts inside others."),
|
||||
]),
|
||||
IO.DynamicCombo.Option(key="sdf", inputs=[
|
||||
IO.Boolean.Input("qef", default=True,
|
||||
tooltip="Place dual vertices via QEF solve from closest-triangle normals "
|
||||
"(recovers sharp features) vs edge-crossing centroid."),
|
||||
IO.Boolean.Input("manifold", default=False,
|
||||
tooltip="Manifold Dual Contouring: emit 1-4 dual verts per voxel for "
|
||||
"multi-sheet (thin/touching) cases. Slower; guarantees manifold output."),
|
||||
]),
|
||||
]
|
||||
return IO.Schema(
|
||||
node_id="RemeshMesh",
|
||||
display_name="Remesh Mesh (Narrow-Band DC)",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Re-extracts a uniformly tessellated mesh by sampling a distance field on a "
|
||||
"narrow-band voxel grid and contouring it with Dual Contouring, on the active "
|
||||
"compute device. Normalizes topology of messy / non-manifold / self-intersecting "
|
||||
"input; run before DecimateMesh to hit an exact face count. Output stays welded."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Int.Input("target_faces", default=0, min=0, max=50_000_000,
|
||||
tooltip="0 = use 'resolution'. >0 = auto-pick resolution to roughly hit this "
|
||||
"face count (±30-50%); usually overshoot then DecimateMesh to exact."),
|
||||
IO.Int.Input("resolution", default=256, min=32, max=1024,
|
||||
tooltip="Voxel grid resolution (used when target_faces=0). Higher = more detail, "
|
||||
"slower. 256 ~ 100k faces, 512 ~ 1M."),
|
||||
IO.DynamicCombo.Input("sign_mode", options=sign_mode_options, display_name="sign_mode",
|
||||
tooltip="udf: robust to messy/non-manifold input (double shell cleaned by "
|
||||
"the inner-shell filters). sdf: clean single surface with optional "
|
||||
"QEF sharp-feature recovery, but needs consistent winding."),
|
||||
IO.Float.Input("band", default=1.0, min=0.5, max=4.0, step=0.1,
|
||||
tooltip="Narrow-band width in voxel units (which voxels are sampled). In UDF "
|
||||
"mode also offsets the surface by this many voxels."),
|
||||
IO.Float.Input("project_back", default=0.0, min=0.0, max=1.0, step=0.05,
|
||||
tooltip="Lerp output verts toward the closest point on the original surface "
|
||||
"(0 = pure DC, 1 = snapped). Recovers voxelization-lost detail."),
|
||||
IO.Boolean.Input("fix_poles", default=False,
|
||||
tooltip="Collapse valence-3 vertex pairs (DC T-junction artifact). Cheap; "
|
||||
"improves shading and downstream simplification."),
|
||||
IO.Int.Input("smooth_iters", default=0, min=0, max=20,
|
||||
tooltip="Taubin λ|μ smoothing iterations (0 = off). Volume-preserving; cleans DC "
|
||||
"stairstepping. 2-3 is enough; higher rounds off QEF sharp features."),
|
||||
IO.Float.Input("drop_small_components", default=0.01, min=0.0, max=0.5, step=0.005,
|
||||
tooltip="Drop components with fewer than this fraction of the largest component's "
|
||||
"faces (inner-shell fragments, noise). 0 disables."),
|
||||
IO.Int.Input("precluster_max_verts", default=0, min=0, max=50_000_000,
|
||||
tooltip="Safety fallback: if input has more verts than this (>0), cluster-decimate "
|
||||
"it down first so the distance-field queries don't OOM on huge inputs. "
|
||||
"0 = off; 1-2M is reasonable for very large meshes."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, target_faces, resolution, sign_mode, band,
|
||||
project_back, fix_poles, smooth_iters,
|
||||
drop_small_components, precluster_max_verts):
|
||||
mode = sign_mode.get("sign_mode", "udf")
|
||||
# mode-specific sub-widgets (absent ones fall back to defaults)
|
||||
qef = bool(sign_mode.get("qef", True))
|
||||
manifold = bool(sign_mode.get("manifold", False))
|
||||
drop_inverted_components = bool(sign_mode.get("drop_inverted_components", True))
|
||||
drop_enclosed_components = bool(sign_mode.get("drop_enclosed_components", True))
|
||||
|
||||
# ComfyUI passes meshes on CPU; remesh is far faster on GPU. Run on the
|
||||
# selected compute device and return on the mesh's original device.
|
||||
compute_device = comfy.model_management.get_torch_device()
|
||||
counts = {"in": 0, "out": 0}
|
||||
|
||||
def _fn(v, f, c):
|
||||
counts["in"] += int(f.shape[0])
|
||||
try:
|
||||
src_device = v.device
|
||||
vv = v.to(compute_device).float()
|
||||
ff = f.to(compute_device).to(torch.int64)
|
||||
cc = c.to(compute_device).float() if c is not None else None
|
||||
|
||||
# safety fallback: cluster-decimate very large inputs before the field queries
|
||||
if precluster_max_verts > 0 and vv.shape[0] > precluster_max_verts:
|
||||
vv, ff, cc = qem_cluster_decimate(
|
||||
vv, ff, target_verts=int(precluster_max_verts), colors=cc)
|
||||
|
||||
# Fixed [-0.5,0.5] cube domain (matches cumesh / TRELLIS2). scale ≈ 1.0
|
||||
# for any resolution, so this is consistent in target_faces auto mode too.
|
||||
rs_scale = (resolution + 3.0 * band) / resolution
|
||||
rs_center = torch.zeros(3, dtype=vv.dtype, device=compute_device)
|
||||
|
||||
rv, rf, rc = remesh_narrow_band_dc(
|
||||
vv, ff,
|
||||
resolution=int(resolution), target_faces=int(target_faces),
|
||||
band=float(band), project_back=float(project_back),
|
||||
qef=qef, sign_mode=mode,
|
||||
manifold=manifold, fix_poles=bool(fix_poles),
|
||||
smooth_iters=int(smooth_iters),
|
||||
drop_small_components=float(drop_small_components),
|
||||
drop_inverted_components=drop_inverted_components,
|
||||
drop_enclosed_components=drop_enclosed_components,
|
||||
scale=rs_scale, center=rs_center, colors=cc)
|
||||
|
||||
v = rv.to(src_device)
|
||||
f = rf.to(src_device)
|
||||
c = rc.to(src_device) if rc is not None else None
|
||||
except Exception as e:
|
||||
logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}")
|
||||
counts["out"] += int(f.shape[0])
|
||||
return v, f, c
|
||||
|
||||
result = _process_mesh_batch(mesh, _fn)
|
||||
|
||||
# Send progress text to display the face change on the node
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _pack_uv_meshes(vs, fs, uvs, colors):
|
||||
"""Pack per-item (verts, faces, uvs[, colors]) into a MESH; stack if single, else pad with counts."""
|
||||
if len(vs) == 1:
|
||||
m = Types.MESH(vertices=vs[0].unsqueeze(0), faces=fs[0].unsqueeze(0), uvs=uvs[0].unsqueeze(0))
|
||||
if colors is not None:
|
||||
m.vertex_colors = colors[0].unsqueeze(0)
|
||||
return m
|
||||
bsz = len(vs)
|
||||
dev = vs[0].device
|
||||
maxv = max(v.shape[0] for v in vs)
|
||||
maxf = max(f.shape[0] for f in fs)
|
||||
pv = vs[0].new_zeros((bsz, maxv, 3))
|
||||
pf = fs[0].new_zeros((bsz, maxf, 3))
|
||||
pu = uvs[0].new_zeros((bsz, maxv, 2))
|
||||
for i, (v, f, u) in enumerate(zip(vs, fs, uvs)):
|
||||
pv[i, :v.shape[0]] = v
|
||||
pf[i, :f.shape[0]] = f
|
||||
pu[i, :u.shape[0]] = u
|
||||
vc = torch.tensor([v.shape[0] for v in vs], device=dev, dtype=torch.int64)
|
||||
fc = torch.tensor([f.shape[0] for f in fs], device=dev, dtype=torch.int64)
|
||||
m = Types.MESH(vertices=pv, faces=pf, uvs=pu, vertex_counts=vc, face_counts=fc)
|
||||
if colors is not None:
|
||||
pc = colors[0].new_zeros((bsz, maxv, colors[0].shape[1]))
|
||||
for i, c in enumerate(colors):
|
||||
pc[i, :c.shape[0]] = c
|
||||
m.vertex_colors = pc
|
||||
return m
|
||||
|
||||
|
||||
def _uv_weld_vertices(v, f, weld_distance):
|
||||
"""Merge coincident verts; returns (welded_v, welded_f, welded_to_orig) (last None if no welding)."""
|
||||
v_np = v.cpu().numpy()
|
||||
f_np = f.cpu().numpy()
|
||||
if v_np.size == 0:
|
||||
return v, f, None
|
||||
extent = float(np.linalg.norm(v_np.max(axis=0) - v_np.min(axis=0)))
|
||||
tol = weld_distance if weld_distance > 0.0 else 1e-5 * extent
|
||||
if tol <= 0.0:
|
||||
return v, f, None
|
||||
keys = np.round(v_np / tol).astype(np.int64)
|
||||
_, inv = np.unique(keys, axis=0, return_inverse=True)
|
||||
n_unique = int(inv.max()) + 1
|
||||
if n_unique >= v_np.shape[0]:
|
||||
return v, f, None
|
||||
v_welded = np.zeros((n_unique, 3), dtype=np.float32)
|
||||
counts = np.zeros(n_unique, dtype=np.int64)
|
||||
np.add.at(v_welded, inv, v_np)
|
||||
np.add.at(counts, inv, 1)
|
||||
v_welded /= counts[:, None]
|
||||
welded_to_orig = np.empty(n_unique, dtype=np.int64)
|
||||
welded_to_orig[inv] = np.arange(v_np.shape[0], dtype=np.int64)
|
||||
v_new = torch.from_numpy(v_welded).to(v.dtype).to(v.device)
|
||||
f_new = torch.from_numpy(inv[f_np]).to(f.dtype).to(f.device)
|
||||
return v_new, f_new, welded_to_orig
|
||||
|
||||
|
||||
def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance):
|
||||
"""UV-unwrap a single mesh; returns (vmapping, indices, uvs) — vmapping maps each output
|
||||
vertex to an input vertex (seam verts duplicated)."""
|
||||
v_in = positions.to(torch.float32)
|
||||
f_in = indices.to(torch.long).reshape(-1, 3)
|
||||
v_in, f_in, welded_to_orig = _uv_weld_vertices(v_in, f_in, weld_distance)
|
||||
|
||||
# drop degenerate faces (repeated index) — they corrupt edge adjacency
|
||||
degen = ((f_in[:, 0] == f_in[:, 1]) | (f_in[:, 1] == f_in[:, 2]) | (f_in[:, 2] == f_in[:, 0]))
|
||||
if bool(degen.any()):
|
||||
f_in = f_in[~degen]
|
||||
|
||||
mesh = _uv_mesh.build_mesh(v_in, f_in)
|
||||
ff = mesh.face_face
|
||||
if ff.numel() and float((ff >= 0).float().mean().item()) < 0.25:
|
||||
warnings.warn("[uv_unwrap] mesh face-adjacency < 25% — vertices appear un-welded "
|
||||
"(triangle soup); UV charts will be per-face. Raise weld_distance.")
|
||||
|
||||
if segmenter == "pec":
|
||||
if mesh.faces.device.type != "cuda":
|
||||
raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
|
||||
face_chart = _uv_seg.cluster_charts_pec(mesh, target_chart_count=0, max_cost=1.0)
|
||||
elif segmenter == "adaptive":
|
||||
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0, target_chart_count=0)
|
||||
else:
|
||||
raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
|
||||
|
||||
n_charts = int(face_chart.max().item()) + 1 if face_chart.numel() else 0
|
||||
areas_cpu = _uv_mesh.chart_3d_areas(mesh.face_area, face_chart, n_charts).detach().cpu()
|
||||
|
||||
# per-chart loop runs on CPU/numpy to avoid per-chart GPU sync
|
||||
face_chart_np = face_chart.cpu().numpy()
|
||||
faces_np = mesh.faces.cpu().numpy()
|
||||
vertices_np = mesh.vertices.cpu().numpy()
|
||||
face_face_np = mesh.face_face.cpu().numpy()
|
||||
sorted_face_idx_np = np.argsort(face_chart_np, kind="stable")
|
||||
chart_counts_np = np.bincount(face_chart_np, minlength=n_charts)
|
||||
chart_offsets_np = np.empty(n_charts + 1, dtype=np.int64)
|
||||
chart_offsets_np[0] = 0
|
||||
np.cumsum(chart_counts_np, out=chart_offsets_np[1:])
|
||||
|
||||
all_chart_uvs, all_chart_3d_areas, all_chart_uv_areas, all_chart_faces = [], [], [], []
|
||||
chart_records = []
|
||||
for c in range(n_charts):
|
||||
gfi_np = sorted_face_idx_np[chart_offsets_np[c]:chart_offsets_np[c + 1]]
|
||||
chart_faces_global = faces_np[gfi_np]
|
||||
used_verts_np = np.unique(chart_faces_global)
|
||||
local_faces_np = np.searchsorted(used_verts_np, chart_faces_global)
|
||||
local_verts_np = vertices_np[used_verts_np]
|
||||
ff_global = face_face_np[gfi_np]
|
||||
ff_safe = np.maximum(ff_global, 0)
|
||||
nb_chart = np.where(ff_global >= 0, face_chart_np[ff_safe], -1)
|
||||
keep = (ff_global >= 0) & (nb_chart == c)
|
||||
local_neighbor = np.searchsorted(gfi_np, ff_safe)
|
||||
local_ff_np = np.where(keep, local_neighbor, -1)
|
||||
|
||||
lf = torch.from_numpy(local_faces_np)
|
||||
uvs = _uv_param.parametrize_chart(
|
||||
torch.from_numpy(local_verts_np), lf, torch.from_numpy(local_ff_np))
|
||||
ua, ub, uc = uvs[lf[:, 0]], uvs[lf[:, 1]], uvs[lf[:, 2]]
|
||||
uv_area_sum = float(0.5 * (
|
||||
(ub[:, 0] - ua[:, 0]) * (uc[:, 1] - ua[:, 1])
|
||||
- (uc[:, 0] - ua[:, 0]) * (ub[:, 1] - ua[:, 1])).abs().sum().item())
|
||||
chart_records.append({"local_faces": lf, "vmap": torch.from_numpy(used_verts_np),
|
||||
"global_face_idx": torch.from_numpy(gfi_np)})
|
||||
all_chart_uvs.append(uvs)
|
||||
all_chart_3d_areas.append(float(areas_cpu[c].item()))
|
||||
all_chart_uv_areas.append(uv_area_sum)
|
||||
all_chart_faces.append(lf)
|
||||
|
||||
# auto-tune texel density to land near `resolution` (assumes ~0.62 pack fill)
|
||||
total_3d_area = sum(all_chart_3d_areas) or 1.0
|
||||
target_dim = float(resolution) if resolution > 0 else 1024.0
|
||||
tex_per_unit = math.sqrt((target_dim * target_dim) * 0.62 / total_3d_area)
|
||||
|
||||
placements, atlas_w, atlas_h = _uv_pack.pack_bitmap(
|
||||
all_chart_uvs, all_chart_3d_areas, all_chart_uv_areas, all_chart_faces,
|
||||
texels_per_unit=tex_per_unit, padding_texels=padding)
|
||||
placed = _uv_pack.apply_placements(all_chart_uvs, placements, atlas_w, atlas_h)
|
||||
|
||||
n_in_faces = mesh.faces.shape[0]
|
||||
out_indices = np.zeros((n_in_faces, 3), dtype=np.int64)
|
||||
out_uvs_list, out_vmap_list, v_cursor = [], [], 0
|
||||
for c, rec in enumerate(chart_records):
|
||||
vmap_np = rec["vmap"].cpu().numpy()
|
||||
local_faces_np = rec["local_faces"].cpu().numpy()
|
||||
global_face_idx = rec["global_face_idx"].cpu().numpy()
|
||||
out_uvs_list.append(placed[c].cpu().numpy())
|
||||
if welded_to_orig is not None:
|
||||
vmap_np = welded_to_orig[vmap_np]
|
||||
out_vmap_list.append(vmap_np)
|
||||
out_indices[global_face_idx] = local_faces_np + v_cursor
|
||||
v_cursor += vmap_np.shape[0]
|
||||
|
||||
vmapping_out = np.concatenate(out_vmap_list) if out_vmap_list else np.empty(0, dtype=np.int64)
|
||||
uvs_out = np.concatenate(out_uvs_list) if out_uvs_list else np.empty((0, 2), dtype=np.float32)
|
||||
return vmapping_out, out_indices, uvs_out
|
||||
|
||||
|
||||
class UnwrapMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="UnwrapMesh",
|
||||
display_name="Unwrap Mesh UVs",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Generates a UV atlas (pure-torch, no xatlas dependency): segments the surface into "
|
||||
"charts, parameterizes each, and packs them into a [0,1] atlas. Verts on chart seams "
|
||||
"are duplicated. Run after DecimateMesh/RemeshMesh, before texture baking."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Combo.Input("segmenter", options=["pec", "adaptive"], default="pec",
|
||||
tooltip="pec: fast parallel-edge-collapse charting (CUDA; falls back to "
|
||||
"adaptive on CPU). adaptive: CPU charting, slower."),
|
||||
IO.Int.Input("resolution", default=1024, min=0, max=8192, step=256,
|
||||
tooltip="Target atlas resolution used to auto-scale texel density (0 = fit-to-content)."),
|
||||
IO.Int.Input("padding", default=1, min=0, max=16,
|
||||
tooltip="Texel padding between charts in the packed atlas."),
|
||||
IO.Float.Input("weld_distance", default=0.0, min=0.0, max=1.0, step=0.0001,
|
||||
tooltip="Merge radius for coincident verts as a fraction of mesh extent "
|
||||
"(0 = auto, 1e-5). Raise to ~0.001 if you get per-triangle charts "
|
||||
"(unwelded / triangle-soup input)."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, segmenter, resolution, padding, weld_distance):
|
||||
compute_device = comfy.model_management.get_torch_device()
|
||||
seg = segmenter
|
||||
if seg == "pec" and compute_device.type != "cuda":
|
||||
seg = "adaptive"
|
||||
seg_device = compute_device if seg == "pec" else torch.device("cpu")
|
||||
|
||||
is_list = isinstance(mesh.vertices, list)
|
||||
is_batched = not is_list and mesh.vertices.ndim == 3
|
||||
bsz = len(mesh.vertices) if is_list else (mesh.vertices.shape[0] if is_batched else 1)
|
||||
bar = comfy.utils.ProgressBar(bsz)
|
||||
|
||||
out_v, out_f, out_uv, out_c = [], [], [], []
|
||||
for i in range(bsz):
|
||||
if is_list or is_batched:
|
||||
vi, fi = mesh.vertices[i], mesh.faces[i]
|
||||
ci = None
|
||||
vc = getattr(mesh, "vertex_colors", None)
|
||||
if vc is not None:
|
||||
ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc
|
||||
else:
|
||||
vi, fi = mesh.vertices, mesh.faces
|
||||
ci = getattr(mesh, "vertex_colors", None)
|
||||
|
||||
src_device = vi.device
|
||||
vnp = vi.detach().cpu().numpy().astype(np.float32)
|
||||
extent = float(np.linalg.norm(vnp.max(0) - vnp.min(0))) if vnp.shape[0] else 0.0
|
||||
weld_abs = weld_distance * extent if weld_distance > 0.0 else 0.0
|
||||
|
||||
vmapping, indices, uvs = _uv_unwrap(
|
||||
vi.to(seg_device).float(), fi.to(seg_device).long(),
|
||||
seg, int(resolution), int(padding), weld_abs)
|
||||
uvs = uvs.copy()
|
||||
uvs[:, 1] = 1.0 - uvs[:, 1] # UV y flipped vs trimesh
|
||||
|
||||
out_v.append(torch.from_numpy(vnp[vmapping]).to(src_device))
|
||||
out_f.append(torch.from_numpy(indices).to(device=src_device, dtype=torch.long))
|
||||
out_uv.append(torch.from_numpy(uvs.astype(np.float32)).to(src_device))
|
||||
if ci is not None:
|
||||
cnp = ci.detach().cpu().numpy()
|
||||
out_c.append(torch.from_numpy(np.ascontiguousarray(cnp[vmapping])).to(src_device))
|
||||
bar.update(1)
|
||||
|
||||
out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
|
||||
if getattr(mesh, "texture", None) is not None:
|
||||
out_mesh.texture = mesh.texture
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px",
|
||||
cls.hidden.unique_id)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
|
||||
def _uv_sorted_edge_keys(indices: np.ndarray):
|
||||
"""Undirected edge keys per face-edge, sorted; returns (sorted_keys, face_id, lo, hi, first_mask)."""
|
||||
a = indices.ravel().astype(np.int64)
|
||||
b = np.roll(indices, -1, axis=1).ravel().astype(np.int64)
|
||||
lo = np.minimum(a, b)
|
||||
hi = np.maximum(a, b)
|
||||
V = int(indices.max()) + 1
|
||||
key = lo * V + hi
|
||||
order = np.argsort(key, kind="stable")
|
||||
sk = key[order]
|
||||
fid = (np.arange(a.size, dtype=np.int64) // 3)[order]
|
||||
first = np.ones(sk.size, dtype=bool)
|
||||
first[1:] = sk[1:] != sk[:-1]
|
||||
return sk, fid, lo[order], hi[order], first
|
||||
|
||||
|
||||
def _uv_faces_to_chart_ids(indices: np.ndarray) -> np.ndarray:
|
||||
"""Chart = connected component of faces adjacent iff they share a (non-seam-duplicated) UV vertex."""
|
||||
F = indices.shape[0]
|
||||
if F == 0:
|
||||
return np.empty(0, dtype=np.int64)
|
||||
_sk, fid, _lo, _hi, first = _uv_sorted_edge_keys(indices)
|
||||
group_id = np.cumsum(first) - 1
|
||||
starts = np.nonzero(first)[0]
|
||||
rows = fid[starts[group_id[~first]]]
|
||||
cols = fid[~first]
|
||||
if rows.size == 0:
|
||||
return np.arange(F, dtype=np.int64)
|
||||
adj = csr_matrix((np.ones(rows.size, dtype=np.int8), (rows, cols)), shape=(F, F))
|
||||
_, labels = connected_components(adj, directed=False)
|
||||
return labels.astype(np.int64)
|
||||
|
||||
|
||||
_UV_TAB20 = np.array([
|
||||
[0.121568627, 0.466666667, 0.705882353], [0.682352941, 0.780392157, 0.909803922],
|
||||
[1.000000000, 0.498039216, 0.054901961], [1.000000000, 0.733333333, 0.470588235],
|
||||
[0.172549020, 0.627450980, 0.172549020], [0.596078431, 0.874509804, 0.541176471],
|
||||
[0.839215686, 0.152941176, 0.156862745], [1.000000000, 0.596078431, 0.588235294],
|
||||
[0.580392157, 0.403921569, 0.741176471], [0.772549020, 0.690196078, 0.835294118],
|
||||
[0.549019608, 0.337254902, 0.294117647], [0.768627451, 0.611764706, 0.580392157],
|
||||
[0.890196078, 0.466666667, 0.760784314], [0.968627451, 0.713725490, 0.823529412],
|
||||
[0.498039216, 0.498039216, 0.498039216], [0.780392157, 0.780392157, 0.780392157],
|
||||
[0.737254902, 0.741176471, 0.133333333], [0.858823529, 0.858823529, 0.552941176],
|
||||
[0.090196078, 0.745098039, 0.811764706], [0.619607843, 0.854901961, 0.898039216],
|
||||
], dtype=np.float32)
|
||||
|
||||
|
||||
def _uv_palette(n: int) -> np.ndarray:
|
||||
rng = np.random.RandomState(42)
|
||||
perm = rng.permutation(max(1, n))
|
||||
out = np.empty((n, 3), dtype=np.float32)
|
||||
for i in range(n):
|
||||
out[i] = _UV_TAB20[perm[i % len(perm)] % 20]
|
||||
return out
|
||||
|
||||
|
||||
def _uv_render_atlas(uvs_np, indices_np, resolution, device,
|
||||
bg=(0.13, 0.13, 0.13), edge=(0.0, 0.0, 0.0)):
|
||||
"""Tile-based torch rasterizer of the UV atlas (charts colored, borders outlined); returns (H,W,3)."""
|
||||
w = h = int(resolution)
|
||||
chart_ids_np = _uv_faces_to_chart_ids(indices_np)
|
||||
uvs = torch.from_numpy(uvs_np).to(device=device, dtype=torch.float32)
|
||||
indices = torch.from_numpy(indices_np).to(device=device, dtype=torch.long)
|
||||
chart_ids = torch.from_numpy(chart_ids_np).to(device=device, dtype=torch.long)
|
||||
|
||||
img = torch.zeros((h, w, 3), dtype=torch.float32, device=device)
|
||||
img[..., 0] = bg[0]; img[..., 1] = bg[1]; img[..., 2] = bg[2]
|
||||
if indices.numel() == 0:
|
||||
return img
|
||||
|
||||
n_charts = int(chart_ids.max().item()) + 1 if chart_ids.numel() else 1
|
||||
colors = torch.from_numpy(_uv_palette(n_charts)).to(device=device, dtype=torch.float32)
|
||||
|
||||
uv_px = uvs.clone()
|
||||
uv_px[:, 0] = uv_px[:, 0].clamp(0.0, 1.0) * (w - 1)
|
||||
uv_px[:, 1] = uv_px[:, 1].clamp(0.0, 1.0) * (h - 1)
|
||||
|
||||
tri = uv_px[indices]
|
||||
x0 = tri[:, 0, 0]; y0 = tri[:, 0, 1]
|
||||
x1 = tri[:, 1, 0]; y1 = tri[:, 1, 1]
|
||||
x2 = tri[:, 2, 0]; y2 = tri[:, 2, 1]
|
||||
denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2)
|
||||
nondegen = denom.abs() > 1e-20
|
||||
|
||||
xmin = torch.minimum(torch.minimum(x0, x1), x2).floor().clamp_(0, w - 1).long()
|
||||
xmax = torch.maximum(torch.maximum(x0, x1), x2).ceil().clamp_(0, w - 1).long()
|
||||
ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, h - 1).long()
|
||||
ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, h - 1).long()
|
||||
|
||||
# full point-in-triangle over all (pixel, tri) pairs is O(H*W*F); tile and test only bbox-overlapping tris
|
||||
TILE = 64
|
||||
eps = 1e-6
|
||||
for ty in range(0, h, TILE):
|
||||
ty_end = min(ty + TILE, h)
|
||||
for tx in range(0, w, TILE):
|
||||
tx_end = min(tx + TILE, w)
|
||||
tri_mask = (nondegen & (xmin < tx_end) & (xmax >= tx)
|
||||
& (ymin < ty_end) & (ymax >= ty))
|
||||
if not tri_mask.any():
|
||||
continue
|
||||
idx = torch.nonzero(tri_mask, as_tuple=True)[0]
|
||||
ys = torch.arange(ty, ty_end, dtype=torch.float32, device=device) + 0.5
|
||||
xs = torch.arange(tx, tx_end, dtype=torch.float32, device=device) + 0.5
|
||||
yy, xx = torch.meshgrid(ys, xs, indexing="ij")
|
||||
sub_x0 = x0[idx][:, None, None]; sub_y0 = y0[idx][:, None, None]
|
||||
sub_x1 = x1[idx][:, None, None]; sub_y1 = y1[idx][:, None, None]
|
||||
sub_x2 = x2[idx][:, None, None]; sub_y2 = y2[idx][:, None, None]
|
||||
sub_den = denom[idx][:, None, None]
|
||||
bx = ((sub_y1 - sub_y2) * (xx - sub_x2) + (sub_x2 - sub_x1) * (yy - sub_y2)) / sub_den
|
||||
by = ((sub_y2 - sub_y0) * (xx - sub_x2) + (sub_x0 - sub_x2) * (yy - sub_y2)) / sub_den
|
||||
bz = 1.0 - bx - by
|
||||
inside = (bx >= -eps) & (by >= -eps) & (bz >= -eps)
|
||||
if not inside.any():
|
||||
continue
|
||||
hit_any = inside.any(dim=0)
|
||||
best_tri = idx[inside.int().argmax(dim=0)]
|
||||
tile_color = colors[chart_ids[best_tri]]
|
||||
tile_img = img[ty:ty_end, tx:tx_end]
|
||||
tile_img[hit_any] = tile_color[hit_any]
|
||||
img[ty:ty_end, tx:tx_end] = tile_img
|
||||
|
||||
# chart outlines: a chart border is an open boundary in UV space (seam verts duplicated) → edges with 1 incident face
|
||||
_sk, _fid, lo, hi, first = _uv_sorted_edge_keys(indices_np)
|
||||
starts = np.nonzero(first)[0]
|
||||
counts = np.diff(np.append(starts, first.size))
|
||||
boundary = counts == 1
|
||||
uv_cpu = uv_px.cpu().numpy()
|
||||
px_xs, px_ys = [], []
|
||||
for a, b in zip(lo[starts[boundary]], hi[starts[boundary]]):
|
||||
p0 = uv_cpu[a]; p1 = uv_cpu[b]
|
||||
steps = int(max(abs(p1[0] - p0[0]), abs(p1[1] - p0[1])) + 1)
|
||||
if steps <= 1:
|
||||
continue
|
||||
ts = np.linspace(0.0, 1.0, steps)
|
||||
xs = (p0[0] + (p1[0] - p0[0]) * ts).astype(np.int32)
|
||||
ys = (p0[1] + (p1[1] - p0[1]) * ts).astype(np.int32)
|
||||
valid = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
|
||||
px_xs.append(xs[valid]); px_ys.append(ys[valid])
|
||||
if px_xs:
|
||||
xs_all = torch.from_numpy(np.concatenate(px_xs)).to(device=device, dtype=torch.long)
|
||||
ys_all = torch.from_numpy(np.concatenate(px_ys)).to(device=device, dtype=torch.long)
|
||||
img[ys_all, xs_all] = torch.tensor(edge, dtype=torch.float32, device=device)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class RenderUVAtlas(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RenderUVAtlas",
|
||||
display_name="Render UV Atlas",
|
||||
category="latent/3d",
|
||||
description=("Renders a mesh's UV layout as an image — each chart a distinct color, "
|
||||
"outlined where it borders other charts. Run UnwrapMesh first."),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Int.Input("resolution", default=1024, min=64, max=4096, step=64),
|
||||
],
|
||||
outputs=[IO.Image.Output("image")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, resolution):
|
||||
uvs_t = getattr(mesh, "uvs", None)
|
||||
if uvs_t is None:
|
||||
raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
|
||||
uvs_np = uvs_t.detach().cpu().numpy()
|
||||
if uvs_np.ndim == 3:
|
||||
uvs_np = uvs_np[0]
|
||||
f = mesh.faces
|
||||
if torch.is_tensor(f):
|
||||
f = f.detach().cpu().numpy()
|
||||
if f.ndim == 3:
|
||||
f = f[0]
|
||||
f = np.ascontiguousarray(f, dtype=np.int64)
|
||||
uvs_np = np.ascontiguousarray(uvs_np, dtype=np.float32)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
img = _uv_render_atlas(uvs_np, f, int(resolution), device)
|
||||
return IO.NodeOutput(img.detach().cpu().unsqueeze(0))
|
||||
|
||||
|
||||
class FillHoles(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -2379,6 +2950,9 @@ class PostProcessMeshExtension(ComfyExtension):
|
||||
FillHolesV2,
|
||||
WeldVertices,
|
||||
DecimateMesh,
|
||||
RemeshMesh,
|
||||
UnwrapMesh,
|
||||
RenderUVAtlas,
|
||||
PaintMesh,
|
||||
BakeTextureFromVoxel,
|
||||
MeshTextureToImage,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user