ComfyUI/comfy_extras/mesh3d/uv_unwrap/parameterize.py
2026-06-17 00:59:58 +03:00

388 lines
14 KiB
Python

"""Chart parameterization: ortho PCA projection, falling back to ABF/LSCM."""
from __future__ import annotations
import warnings
from typing import List, Tuple
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import torch
from torch import Tensor
from . import mesh as _mesh
def solve_least_squares(A: sp.csr_matrix, b: np.ndarray) -> np.ndarray:
    """Return the minimizer of ||A x - b||^2 via the sparse normal equations.

    Builds (A^T A) x = A^T b and hands it to ``spsolve`` (direct sparse
    factorization). If A^T A is singular, ``spsolve`` reports it through a
    warning rather than an exception.
    """
    a_transposed = A.T.tocsr()
    normal_matrix = (a_transposed @ A).tocsc()
    rhs = a_transposed @ b
    return spla.spsolve(normal_matrix, rhs)
def _triangle_local_2d(verts_3d: np.ndarray, faces: np.ndarray) -> np.ndarray:
"""Per-triangle 2D coords [F, 3, 2] with v0 at origin, v1 along +x."""
v0 = verts_3d[faces[:, 0]]
v1 = verts_3d[faces[:, 1]]
v2 = verts_3d[faces[:, 2]]
e01 = v1 - v0
e02 = v2 - v0
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
x_axis = e01 / L01[:, None]
n = np.cross(e01, e02)
n /= np.linalg.norm(n, axis=1, keepdims=True).clip(min=1e-20)
y_axis = np.cross(n, x_axis)
out = np.zeros((faces.shape[0], 3, 2), dtype=np.float64)
out[:, 1, 0] = L01
out[:, 2, 0] = (e02 * x_axis).sum(axis=1)
out[:, 2, 1] = (e02 * y_axis).sum(axis=1)
return out
def _pick_pins(loops: List[List[int]], verts_3d: np.ndarray) -> Tuple[int, int]:
"""Pick the longest-diameter axis-extremal boundary vertex pair across all boundary verts."""
if not loops:
# Closed surface: two far verts via two-pass farthest.
d2 = np.sum((verts_3d - verts_3d[0]) ** 2, axis=1)
a = int(np.argmax(d2))
d2 = np.sum((verts_3d - verts_3d[a]) ** 2, axis=1)
b = int(np.argmax(d2))
return a, b
boundary_verts: List[int] = []
for loop in loops:
boundary_verts.extend(loop)
seen = set()
uniq = []
for v in boundary_verts:
if v not in seen:
seen.add(v)
uniq.append(v)
bv = np.asarray(uniq, dtype=np.int64)
pts = verts_3d[bv]
pin_pairs = []
for axis in range(3):
i_min = int(bv[int(np.argmin(pts[:, axis]))])
i_max = int(bv[int(np.argmax(pts[:, axis]))])
d = float(np.linalg.norm(verts_3d[i_min] - verts_3d[i_max]))
pin_pairs.append((d, i_min, i_max))
d0, _, _ = pin_pairs[0]
d1, _, _ = pin_pairs[1]
d2, _, _ = pin_pairs[2]
if d0 > d1 and d0 > d2:
_, a, b = pin_pairs[0]
elif d1 > d2:
_, a, b = pin_pairs[1]
else:
_, a, b = pin_pairs[2]
return a, b
def _ortho_project(verts_3d: np.ndarray) -> np.ndarray:
"""PCA-fit plane normal, axis-aligned tangent, project verts to 2D."""
centroid = verts_3d.mean(axis=0)
pts = verts_3d - centroid
cov = pts.T @ pts
_w, ev = np.linalg.eigh(cov)
normal = ev[:, 0]
a = np.abs(normal)
if a[0] < a[1] and a[0] < a[2]:
t = np.array([1.0, 0.0, 0.0])
elif a[1] < a[2]:
t = np.array([0.0, 1.0, 0.0])
else:
t = np.array([0.0, 0.0, 1.0])
t = t - normal * float(np.dot(normal, t))
t /= max(float(np.linalg.norm(t)), 1e-20)
b = np.cross(normal, t)
return np.stack([verts_3d @ t, verts_3d @ b], axis=1)
def _stretch_metrics(verts_3d: np.ndarray, uvs: np.ndarray, faces: np.ndarray) -> Tuple[float, float, int, int]:
"""Sander's stretch metric. Returns (rms, max, n_flipped, n_zero_area)."""
p = verts_3d[faces]
t = uvs[faces]
parametric_area = 0.5 * (
(t[:, 1, 1] - t[:, 0, 1]) * (t[:, 2, 0] - t[:, 0, 0])
- (t[:, 2, 1] - t[:, 0, 1]) * (t[:, 1, 0] - t[:, 0, 0])
)
n_flipped = int((parametric_area < -1e-12).sum())
n_zero = int((np.abs(parametric_area) < 1e-12).sum())
pa = np.abs(parametric_area).clip(min=1e-20)
geom_area = 0.5 * np.linalg.norm(
np.cross(p[:, 1] - p[:, 0], p[:, 2] - p[:, 0]), axis=1
)
keep = (geom_area > 1e-12) & (np.abs(parametric_area) > 1e-12)
if not keep.any():
return float("inf"), float("inf"), n_flipped, n_zero
t1 = t[:, 0, 0]; s1 = t[:, 0, 1]
t2 = t[:, 1, 0]; s2 = t[:, 1, 1]
t3 = t[:, 2, 0]; s3 = t[:, 2, 1]
inv_2pa = 1.0 / (2.0 * pa)
Ss = (
p[:, 0] * (t2 - t3)[:, None]
+ p[:, 1] * (t3 - t1)[:, None]
+ p[:, 2] * (t1 - t2)[:, None]
) * inv_2pa[:, None]
St = (
p[:, 0] * (s3 - s2)[:, None]
+ p[:, 1] * (s1 - s3)[:, None]
+ p[:, 2] * (s2 - s1)[:, None]
) * inv_2pa[:, None]
a = (Ss * Ss).sum(axis=1)
bb = (Ss * St).sum(axis=1)
c = (St * St).sum(axis=1)
sigma2_sq = 0.5 * (a + c + np.sqrt(np.maximum(0.0, (a - c) ** 2 + 4 * bb ** 2)))
rms_sq = (a + c) * 0.5
rms_stretch_sq_sum = float((rms_sq[keep] * geom_area[keep]).sum())
total_geom = float(geom_area[keep].sum())
total_param = float(pa[keep].sum())
if total_geom <= 0.0:
return float("inf"), float("inf"), n_flipped, n_zero
norm_factor = np.sqrt(total_param / total_geom)
rms_stretch = float(np.sqrt(rms_stretch_sq_sum / total_geom)) * norm_factor
max_stretch = float(np.sqrt(sigma2_sq[keep].max())) * norm_factor
return rms_stretch, max_stretch, n_flipped, n_zero
def _uv_boundary_self_intersects(
uvs: np.ndarray, faces: np.ndarray, face_face: np.ndarray, eps: float = 1e-9
) -> bool:
"""True if any chart-boundary edge pair crosses in 2D (ortho folded the chart)."""
fi, ei = np.nonzero(face_face < 0)
n = fi.size
if n < 2:
return False
a = uvs[faces[fi, ei]].astype(np.float64)
b = uvs[faces[fi, (ei + 1) % 3]].astype(np.float64)
d = b - a
# Pairwise segment crossings, row-chunked to bound memory at chunk*n.
chunk = max(1, min(n, 1_000_000 // max(n, 1)))
for s in range(0, n, chunk):
e = min(s + chunk, n)
d1 = d[s:e, None, :]
denom = d1[:, :, 0] * d[None, :, 1] - d1[:, :, 1] * d[None, :, 0]
rx = a[None, :, 0] - a[s:e, None, 0]
ry = a[None, :, 1] - a[s:e, None, 1]
with np.errstate(divide="ignore", invalid="ignore"):
t = (rx * d[None, :, 1] - ry * d[None, :, 0]) / denom
u = (rx * d1[:, :, 1] - ry * d1[:, :, 0]) / denom
cross = (
(np.abs(denom) >= eps)
& (t > eps) & (t < 1.0 - eps)
& (u > eps) & (u < 1.0 - eps)
)
if bool(cross.any()):
return True
return False
def parametrize_chart(
    local_verts: Tensor, local_faces: Tensor, local_face_face: Tensor
) -> Tensor:
    """Parameterize one chart: orthographic (PCA-plane) projection first, ABF/LSCM fallback.

    Charts with fewer than 3 vertices or no faces get all-zero UVs. Charts
    with at most 5 faces always keep the orthographic projection.

    Larger charts keep the orthographic projection only if all of these hold:

    - every face has the same UV orientation (``_stretch_metrics`` counts none
      or all of them as flipped);
    - no face has near-zero UV area;
    - normalized RMS stretch <= 1.5 and max stretch <= 2.0;
    - no pair of chart-boundary edges crosses in UV space.

    Otherwise ``lscm_chart`` is solved with the orthographic UVs as pin
    positions. Its result is replaced by the orthographic one when its UV
    bounding box is collapsed (both extents <= 1e-12) or its aspect ratio
    exceeds 100:1.

    Returns a float32 [V, 2] tensor on ``local_verts.device``.
    """
    device = local_verts.device
    verts_np = local_verts.detach().cpu().numpy().astype(np.float64)
    faces_np = local_faces.detach().cpu().numpy().astype(np.int64)
    if verts_np.shape[0] < 3 or faces_np.shape[0] == 0:
        return torch.zeros((verts_np.shape[0], 2), dtype=torch.float32, device=device)

    ortho = _ortho_project(verts_np)

    def _ortho_tensor() -> Tensor:
        # Orthographic UVs as a float32 tensor on the caller's device.
        return torch.from_numpy(ortho.astype(np.float32)).to(device)

    n_faces = faces_np.shape[0]
    if n_faces <= 5:
        return _ortho_tensor()

    rms, mx, n_flip, n_zero = _stretch_metrics(verts_np, ortho, faces_np)
    consistent_orientation = n_flip == 0 or n_flip == n_faces
    if consistent_orientation and n_zero == 0 and rms <= 1.5 and mx <= 2.0:
        ff_np = local_face_face.detach().cpu().numpy().astype(np.int64)
        if not _uv_boundary_self_intersects(ortho, faces_np, ff_np):
            return _ortho_tensor()

    uvs_t = lscm_chart(local_verts, local_faces, local_face_face, pin_positions=ortho)

    # A collapsed UV island (zero-size, or aspect > 100:1) blows up packing scale; fall back to ortho.
    uvs_np = uvs_t.detach().cpu().numpy()
    extent = uvs_np.max(axis=0) - uvs_np.min(axis=0)
    long_side = float(extent.max())
    short_side = float(extent.min())
    if long_side <= 1e-12 or long_side / max(short_side, 1e-12) > 100.0:
        return _ortho_tensor()
    return uvs_t
def _abf_face_coefficients(
verts_3d: np.ndarray, faces: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Per-face ABF constraint (largest-sine vertex at local index 2); returns (faces_reordered, cosine, sine, valid_mask) with valid_mask False for degenerate tris."""
Fc = faces.shape[0]
p0 = verts_3d[faces[:, 0]]
p1 = verts_3d[faces[:, 1]]
p2 = verts_3d[faces[:, 2]]
e01 = p1 - p0
e12 = p2 - p1
e20 = p0 - p2
L01 = np.linalg.norm(e01, axis=1).clip(min=1e-20)
L12 = np.linalg.norm(e12, axis=1).clip(min=1e-20)
L20 = np.linalg.norm(e20, axis=1).clip(min=1e-20)
cos_a0 = ((-e20) * e01).sum(axis=1) / (L20 * L01)
cos_a1 = ((-e01) * e12).sum(axis=1) / (L01 * L12)
cos_a2 = ((-e12) * e20).sum(axis=1) / (L12 * L20)
cos_a0 = cos_a0.clip(-1.0, 1.0)
cos_a1 = cos_a1.clip(-1.0, 1.0)
cos_a2 = cos_a2.clip(-1.0, 1.0)
a = np.arccos(cos_a0)
b_ang = np.arccos(cos_a1)
c_ang = np.arccos(cos_a2)
angles = np.stack([a, b_ang, c_ang], axis=1)
sines = np.stack([np.sin(a), np.sin(b_ang), np.sin(c_ang)], axis=1)
valid = (angles > 1e-12).all(axis=1)
ids = faces.astype(np.int64).copy()
s0, s1, s2 = sines[:, 0], sines[:, 1], sines[:, 2]
pattA = (s1 > s0) & (s1 > s2)
pattB = (~pattA) & (s0 > s1) & (s0 > s2)
if pattA.any():
old_a = angles[pattA].copy()
old_s = sines[pattA].copy()
old_id = ids[pattA].copy()
angles[pattA] = old_a[:, [2, 0, 1]]
sines[pattA] = old_s[:, [2, 0, 1]]
ids[pattA] = old_id[:, [2, 0, 1]]
if pattB.any():
old_a = angles[pattB].copy()
old_s = sines[pattB].copy()
old_id = ids[pattB].copy()
angles[pattB] = old_a[:, [1, 2, 0]]
sines[pattB] = old_s[:, [1, 2, 0]]
ids[pattB] = old_id[:, [1, 2, 0]]
a0 = angles[:, 0]
s0 = sines[:, 0]
s1 = sines[:, 1]
s2 = sines[:, 2]
c0 = np.cos(a0)
ratio = np.where(s2 > 0.0, s1 / s2.clip(min=1e-20), 1.0)
cosine = c0 * ratio
sine = s0 * ratio
return ids, cosine, sine, valid
def lscm_chart(
    local_verts: Tensor,
    local_faces: Tensor,
    local_face_face: Tensor,
    pin_positions: "np.ndarray | None" = None,
) -> Tensor:
    """Angle-based (ABF-style) LSCM parameterization of one chart.

    Faces whose corner angles all exceed 1e-12 contribute two rows each from
    ``_abf_face_coefficients``. The remaining (degenerate) faces contribute
    plain LSCM rows built from their local 2D frame.

    Two pinned vertices from ``_pick_pins`` fix the similarity gauge:

    - when ``pin_positions`` is a [Vc, 2] array, the pins keep their
      coordinates from it;
    - otherwise, or when those two coordinates coincide (which would collapse
      the chart to a point), the pins go to (0, 0) and (1, 0).

    Returns:
        A float32 [Vc, 2] tensor on ``local_verts.device``, where:

        - charts with < 3 vertices or no faces get zeros;
        - if the pins are the same vertex, the sparse solve fails or reports a
          singular system, or the solution is not finite, the result is
          ``pin_positions`` (when usable) or a fresh orthographic projection.
    """
    device = local_verts.device
    verts_np = local_verts.detach().cpu().numpy().astype(np.float64)
    faces_np = local_faces.detach().cpu().numpy().astype(np.int64)
    Vc = verts_np.shape[0]
    Fc = faces_np.shape[0]
    if Vc < 3 or Fc == 0:
        return torch.zeros((Vc, 2), dtype=torch.float32, device=device)

    use_pin_positions = pin_positions is not None and pin_positions.shape == (Vc, 2)

    def _fallback() -> Tensor:
        # Best-effort layout when LSCM cannot be solved: caller's UVs if usable, else ortho.
        if use_pin_positions:
            fb = pin_positions.astype(np.float32)
        else:
            fb = _ortho_project(verts_np).astype(np.float32)
        return torch.from_numpy(fb).to(device)

    loops = _mesh.chart_boundary_loops(local_faces, local_face_face)
    pin_a, pin_b = _pick_pins(loops, verts_np)
    if pin_a == pin_b:
        # One pinned vertex leaves rotation and scale free: the system is singular.
        return _fallback()

    u_a, v_a, u_b, v_b = 0.0, 0.0, 1.0, 0.0  # default gauge
    if use_pin_positions:
        pa = pin_positions[pin_a]
        pb = pin_positions[pin_b]
        # Two pins at the same UV make "every vertex at that point" the zero-energy
        # solution (collapsed chart); keep the default gauge in that case.
        if float(np.hypot(pa[0] - pb[0], pa[1] - pb[1])) > 1e-12:
            u_a, v_a = float(pa[0]), float(pa[1])
            u_b, v_b = float(pb[0]), float(pb[1])

    abf_ids, abf_cos, abf_sin, abf_valid = _abf_face_coefficients(verts_np, faces_np)
    rows_list: List[np.ndarray] = []
    cols_list: List[np.ndarray] = []
    vals_list: List[np.ndarray] = []
    # Unknowns are laid out as [u_0..u_{Vc-1}, v_0..v_{Vc-1}]; face f owns rows 2f, 2f+1.

    # ABF rows for valid faces.
    valid_idx = np.nonzero(abf_valid)[0]
    if valid_idx.size:
        id0 = abf_ids[valid_idx, 0]
        id1 = abf_ids[valid_idx, 1]
        id2 = abf_ids[valid_idx, 2]
        cosf = abf_cos[valid_idx]
        sinf = abf_sin[valid_idx]
        r_real = valid_idx * 2
        r_imag = valid_idx * 2 + 1
        ones = np.ones(valid_idx.size, dtype=np.float64)
        rows_list.extend([r_real] * 5)
        cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2])
        vals_list.extend([cosf - 1.0, -sinf, -cosf, sinf, ones])
        rows_list.extend([r_imag] * 5)
        cols_list.extend([id0, id0 + Vc, id1, id1 + Vc, id2 + Vc])
        vals_list.extend([sinf, cosf - 1.0, -sinf, -cosf, ones])

    # Plain-LSCM rows for invalid (degenerate) faces, weighted by 1/sqrt(2*|2A|).
    invalid_idx = np.nonzero(~abf_valid)[0]
    if invalid_idx.size:
        tri2d_inv = _triangle_local_2d(verts_np, faces_np[invalid_idx])
        twice_area_inv = (
            tri2d_inv[:, 1, 0] * tri2d_inv[:, 2, 1]
            - tri2d_inv[:, 1, 1] * tri2d_inv[:, 2, 0]
        )
        weight_inv = 1.0 / np.sqrt(2.0 * np.abs(twice_area_inv).clip(min=1e-20))
        r_real_inv = invalid_idx * 2
        r_imag_inv = invalid_idx * 2 + 1
        for j in range(3):
            jp1 = (j + 1) % 3
            jp2 = (j + 2) % 3
            a_j = (tri2d_inv[:, jp1, 0] - tri2d_inv[:, jp2, 0]) * weight_inv
            b_j = (tri2d_inv[:, jp1, 1] - tri2d_inv[:, jp2, 1]) * weight_inv
            v_idx = faces_np[invalid_idx, j]
            rows_list.extend([r_real_inv, r_real_inv, r_imag_inv, r_imag_inv])
            cols_list.extend([v_idx, v_idx + Vc, v_idx, v_idx + Vc])
            vals_list.extend([a_j, -b_j, b_j, a_j])

    rows = np.concatenate(rows_list) if rows_list else np.empty(0, dtype=np.int64)
    cols = np.concatenate(cols_list) if cols_list else np.empty(0, dtype=np.int64)
    vals = np.concatenate(vals_list) if vals_list else np.empty(0, dtype=np.float64)
    A_full = sp.csr_matrix((vals, (rows, cols)), shape=(2 * Fc, 2 * Vc))

    # Move the four pinned unknowns to the right-hand side.
    pin_cols = np.array([pin_a, pin_b, pin_a + Vc, pin_b + Vc], dtype=np.int64)
    pin_vals = np.array([u_a, u_b, v_a, v_b], dtype=np.float64)
    free_mask = np.ones(2 * Vc, dtype=bool)
    free_mask[pin_cols] = False
    free_cols = np.nonzero(free_mask)[0]
    A_free = A_full[:, free_cols]
    b = -(A_full[:, pin_cols] @ pin_vals)

    try:
        with warnings.catch_warnings():
            # Promote spsolve's singular-matrix warning so it triggers the fallback.
            warnings.simplefilter("error", category=spla.MatrixRankWarning)
            x_free = solve_least_squares(A_free, b)
    except Exception:  # deliberate best-effort: any solver failure -> fallback layout
        return _fallback()

    full = np.zeros(2 * Vc, dtype=np.float64)
    full[free_cols] = x_free
    full[pin_cols] = pin_vals
    uvs = np.stack([full[:Vc], full[Vc:]], axis=1).astype(np.float32)
    # Catches NaN/inf from the solve as well as float32 overflow.
    if not np.all(np.isfinite(uvs)):
        return _fallback()
    return torch.from_numpy(uvs).to(device)