mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
1713 lines
71 KiB
Python
1713 lines
71 KiB
Python
"""
|
||
Pure-PyTorch GPU-parallel QEM mesh simplification.
|
||
|
||
- Parallel greedy edge-matching collapse loop
|
||
- Plane/line/feature-edge/boundary quadrics, memoryless accumulation
|
||
- Normal-flip prevention, link-condition, skinny penalties
|
||
- Non-manifold/sliver handling without dropping faces
|
||
- Pre/post-clean pipeline (weld, degenerates, small components)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Tuple
|
||
|
||
import math
|
||
import time as _time
|
||
|
||
import numpy as _np
|
||
import torch
|
||
from scipy.sparse import coo_matrix
|
||
from scipy.sparse.csgraph import connected_components
|
||
from tqdm import tqdm as _tqdm
|
||
import comfy.utils as _comfy_utils
|
||
|
||
|
||
@dataclass
class QEMConfig:
    """Settings bundle for the QEM simplifier (precision, placement, batching, quality, clean-up)."""

    # ---- Precision ----
    dtype: torch.dtype = torch.float32  # float64 is much slower on consumer GPUs

    # ---- Numerical conditioning ----
    stabilizer_scale: float = 1e-3  # Tikhonov regularisation: stabilizer = mesh_scale^2 * this
    wander_threshold: float = 2.0  # use the midpoint when v* ends up > N x edge_length from an endpoint
    clamp_v_to_edge: bool = True  # project v* back onto the edge segment (qem mode only)

    # ---- Placement mode (this also picks the collapse driver) ----
    #   "midpoint": threshold-schedule driver, most stable (the defaults below are tuned for it)
    #   "qem":      sharpest result, QEM-optimal placement + ratio driver
    placement_mode: str = "midpoint"

    flip_reject_hard: bool = True  # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal

    # ---- Per-iteration batch sizing ----
    sampling_cap: int = 10_000_000  # maximum number of edges processed per outer iteration
    max_collapses_fraction: float = 0.25  # fraction of the remaining faces-to-remove
    max_collapses_floor: int = 10_000
    max_collapses_ceiling: int = 1_000_000
    max_collapses_relative_cap: float = 0.10  # per-iter collapse cap as a fraction of current faces; 0 disables

    # ---- Loop control ----
    max_iterations: int = 5_000
    compaction_period: int = 5
    compaction_threshold: float = 0.85  # compact when alive_frac drops below this

    # ---- Quality knobs ----
    boundary_quadrics: bool = True
    boundary_weight: float = 1000.0
    recompute_normals_post: bool = True
    line_quadric_weight: float = 0.0  # penalise deviation perpendicular to the edge direction; 0 disables
    line_quadric_skip_opposite_normals_cos: float = 0.0  # skip line quadrics on edges whose endpoint-normal cos < this

    # Feature-edge quadrics on sharp interior edges (dihedral > min); weight 0 disables.
    feature_edge_quadric_weight: float = 0.0
    feature_edge_min_dihedral_deg: float = 30.0

    # ---- Flip check (FA-QEM section 3.3) ----
    quality_topk_multiplier: int = 4  # quality-check band size = this * max_collapses_per_iter
    flip_cos_threshold: float = 0.0  # 0 counts any sign reversal (dihedral > 90 degrees)
    flip_check_max_degree: int = 16  # vertex-degree cap for the flip-check table

    # ---- Triangle shape penalty ----
    skinny_weight: float = 1e-3  # penalise top-K collapses that create needle/sliver triangles; 0 disables

    # ---- Topology preservation ----
    enforce_link_condition: bool = True  # reject collapses violating the link condition

    # ---- Quadric area weighting ----
    area_weighted_quadrics: bool = False  # True: Garland-Heckbert area-weighted; False: un-weighted

    # ---- Edge-length cost regulariser ----
    lambda_edge_length: float = 1e-2  # adds lambda * len^2 to favour short edges; 0 disables
    lambda_edge_length_absolute: bool = True  # apply lambda absolutely instead of relative to the QEM median

    # ---- Threshold-schedule driver (placement_mode == "midpoint") ----
    # Each round collapses a disjoint set with cost <= thresh; thresh is x10 when < 1% was removed.
    threshold_start: float = 1e-8
    memoryless_qem: bool = True  # rebuild quadrics every round instead of accumulating them
    repair_nonmanifold: bool = True  # run a final repair_non_manifold_edges pass

    # ---- Pre-clean (input mesh) ----
    preclean: bool = True  # weld coincident verts, drop degenerate/duplicate/unused

    # ---- Post-clean (output mesh) ----
    postclean: bool = True  # remove slivers, tiny components and unused verts left by collapsing
    postclean_min_angle_deg: float = 0.5
    postclean_max_aspect_ratio: float = 100.0
    postclean_min_component_faces: int = 8  # components with fewer faces than this are dropped

    # ---- Pre-clean tuning ----
    preclean_weld_epsilon_rel: float = 1e-5  # weld tolerance as a fraction of the bbox diagonal
    preclean_min_component_faces: int = 0  # 0 keeps every component

    @property
    def threshold_driver(self) -> bool:
        """True when the cost-threshold collapse driver applies, i.e. in midpoint placement mode."""
        uses_midpoint = self.placement_mode == "midpoint"
        return uses_midpoint
|
||
|
||
|
||
def _sorted_edge_halfedges(
|
||
faces: torch.Tensor, num_verts: int,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
"""3F half-edges sorted by key min(a,b)*(V+1)+max(a,b); returns (sorted_keys, face_ids, slot_ids)."""
|
||
device = faces.device
|
||
F = faces.shape[0]
|
||
e_all = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0)
|
||
e_sorted, _ = torch.sort(e_all, dim=1)
|
||
P = num_verts + 1
|
||
key = e_sorted[:, 0].long() * P + e_sorted[:, 1].long()
|
||
face_per_he = torch.arange(F, device=device, dtype=torch.long).repeat(3)
|
||
slot_per_he = torch.arange(3, device=device, dtype=torch.long).repeat_interleave(F)
|
||
sort_idx = torch.argsort(key)
|
||
return key[sort_idx], face_per_he[sort_idx], slot_per_he[sort_idx]
|
||
|
||
|
||
def _vert_is_boundary_mask(faces: torch.Tensor, num_verts: int) -> torch.Tensor:
    """Return a (num_verts,) bool tensor that is True for every endpoint of a boundary edge."""
    boundary_edges = _detect_boundary_edges(faces, num_verts)
    on_boundary = torch.zeros(num_verts, dtype=torch.bool, device=faces.device)
    if boundary_edges.numel() != 0:
        # Flag both endpoint columns of every boundary edge.
        for column in (0, 1):
            on_boundary[boundary_edges[:, column]] = True
    return on_boundary
|
||
|
||
|
||
def _detect_boundary_edges(faces: torch.Tensor, num_verts: int) -> torch.Tensor:
|
||
"""Boundary edges as [N, 2] of vertex indices (each appearing in exactly one face)."""
|
||
if faces.numel() == 0:
|
||
return torch.empty((0, 2), dtype=torch.int64, device=faces.device)
|
||
sorted_keys, _, _ = _sorted_edge_halfedges(faces, num_verts)
|
||
unique_key, counts = torch.unique(sorted_keys, return_counts=True)
|
||
boundary_key = unique_key[counts == 1]
|
||
if boundary_key.numel() == 0:
|
||
return torch.empty((0, 2), dtype=torch.int64, device=faces.device)
|
||
P = num_verts + 1
|
||
bv0 = boundary_key // P
|
||
bv1 = boundary_key % P
|
||
return torch.stack([bv0, bv1], dim=1)
|
||
|
||
|
||
def _manifold_edge_pairs(
|
||
sorted_keys: torch.Tensor, sorted_faces: torch.Tensor,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
"""Edges shared by exactly 2 faces (filters >2-incident groups); returns (pair_keys, fa, fb)."""
|
||
if sorted_keys.shape[0] < 2:
|
||
empty = sorted_keys.new_empty(0)
|
||
return empty, empty, empty
|
||
pair_mask = sorted_keys[:-1] == sorted_keys[1:]
|
||
if not pair_mask.any():
|
||
empty = sorted_keys.new_empty(0)
|
||
return empty, empty, empty
|
||
pair_starts = torch.nonzero(pair_mask, as_tuple=True)[0]
|
||
# manifold iff neither neighbour half-edge shares the key
|
||
cur = sorted_keys[pair_starts]
|
||
prev_ok = (pair_starts == 0) | (sorted_keys[(pair_starts - 1).clamp_min(0)] != cur)
|
||
nxt_idx = (pair_starts + 2).clamp(max=sorted_keys.shape[0] - 1)
|
||
nxt_ok = (pair_starts + 2 >= sorted_keys.shape[0]) | (sorted_keys[nxt_idx] != cur)
|
||
pair_starts = pair_starts[prev_ok & nxt_ok]
|
||
return (sorted_keys[pair_starts],
|
||
sorted_faces[pair_starts],
|
||
sorted_faces[pair_starts + 1])
|
||
|
||
|
||
def _line_quadric_planes(
|
||
pa: torch.Tensor, pb: torch.Tensor
|
||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
"""Two plane equations (E,4) per edge whose squared-dist sum = squared ⟂ distance to the edge line."""
|
||
e = pb - pa # (E, 3)
|
||
elen = torch.norm(e, dim=-1, keepdim=True).clamp_min(1e-12)
|
||
e_unit = e / elen # (E, 3)
|
||
m = 0.5 * (pa + pb) # (E, 3)
|
||
# helper axis not parallel to e_unit, then Gram-Schmidt against e_unit
|
||
helper = torch.zeros_like(e_unit)
|
||
helper.scatter_(-1, e_unit.abs().argmin(dim=-1, keepdim=True), 1.0)
|
||
u = helper - (helper * e_unit).sum(-1, keepdim=True) * e_unit
|
||
u = u / torch.norm(u, dim=-1, keepdim=True).clamp_min(1e-12)
|
||
w = torch.cross(e_unit, u, dim=-1)
|
||
d_u = -(u * m).sum(-1, keepdim=True)
|
||
d_w = -(w * m).sum(-1, keepdim=True)
|
||
p_u = torch.cat([u, d_u], dim=-1) # (E, 4)
|
||
p_w = torch.cat([w, d_w], dim=-1)
|
||
return p_u, p_w, elen.squeeze(-1)
|
||
|
||
|
||
def _add_line_quadrics(
    verts: torch.Tensor,
    faces: torch.Tensor,
    face_areas: torch.Tensor,
    Q_flat: torch.Tensor,
    weight: float,
    skip_he_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Accumulate a line quadric for each of the 3F half-edges into `Q_flat` (in place).

    Each half-edge contributes with weight face_areas[face] * weight; positions
    where `skip_he_mask` is True get weight zero. The contribution is added to
    both endpoints of the half-edge. Returns `Q_flat` (rows are flattened 4x4).
    """
    # Slot-major half-edge endpoints: (v0->v1), (v1->v2), (v2->v0).
    tail = torch.cat([faces[:, 0], faces[:, 1], faces[:, 2]], dim=0).long()
    head = torch.cat([faces[:, 1], faces[:, 2], faces[:, 0]], dim=0).long()

    plane_u, plane_w, _ = _line_quadric_planes(verts[tail], verts[head])

    he_weight = face_areas.repeat(3) * weight
    if skip_he_mask is not None:
        he_weight = torch.where(skip_he_mask, torch.zeros_like(he_weight), he_weight)

    # Sum of the two plane outer products, scaled per half-edge.
    outer_sum = (
        plane_u[:, :, None] * plane_u[:, None, :]
        + plane_w[:, :, None] * plane_w[:, None, :]
    )
    contrib = (outer_sum * he_weight[:, None, None]).reshape(-1, 16)

    for endpoint in (tail, head):
        Q_flat.scatter_add_(0, endpoint.unsqueeze(1).expand(-1, 16), contrib)
    return Q_flat
|
||
|
||
|
||
def _build_quadrics(
|
||
verts: torch.Tensor,
|
||
faces: torch.Tensor,
|
||
cfg: QEMConfig,
|
||
) -> torch.Tensor:
|
||
"""Per-vertex area-weighted quadric (V, 4, 4)."""
|
||
V = verts.shape[0]
|
||
dtype = verts.dtype
|
||
device = verts.device
|
||
|
||
Q_flat = torch.zeros((V, 16), dtype=dtype, device=device)
|
||
|
||
if faces.numel() > 0:
|
||
v0 = verts[faces[:, 0]]
|
||
v1 = verts[faces[:, 1]]
|
||
v2 = verts[faces[:, 2]]
|
||
e1 = v1 - v0
|
||
e2 = v2 - v0
|
||
n = torch.cross(e1, e2, dim=-1)
|
||
area = torch.norm(n, dim=-1)
|
||
mask = area > 1e-12
|
||
# where() avoids boolean-index gather+scatter (fewer index kernels)
|
||
n_norm = torch.where(mask.unsqueeze(-1),
|
||
n / area.unsqueeze(-1).clamp_min(1e-12),
|
||
n.new_zeros(()))
|
||
d = -(n_norm * v0).sum(dim=-1, keepdim=True)
|
||
p = torch.cat([n_norm, d], dim=-1) # (F, 4)
|
||
K = torch.einsum("fi,fj->fij", p, p) # (F, 4, 4)
|
||
|
||
if cfg.area_weighted_quadrics:
|
||
K.mul_(area[:, None, None])
|
||
K_flat = K.reshape(-1, 16)
|
||
for corner in range(3):
|
||
idx = faces[:, corner].unsqueeze(1).expand(-1, 16)
|
||
Q_flat.scatter_add_(0, idx, K_flat)
|
||
|
||
# Line quadrics: squared ⟂ distance from v to the edge-midpoint line, all 3F half-edges in one pass.
|
||
if cfg.line_quadric_weight > 0 and faces.numel() > 0:
|
||
# skip thin-shell rim edges (endpoint normals oppose)
|
||
skip_he_sharp = None
|
||
if cfg.line_quadric_skip_opposite_normals_cos < 1.0:
|
||
v_norm = torch.zeros((V, 3), dtype=dtype, device=device)
|
||
n_weighted = n_norm * area.unsqueeze(-1) # normal * 2× area
|
||
for corner in range(3):
|
||
v_norm.scatter_add_(0, faces[:, corner].unsqueeze(-1).expand(-1, 3),
|
||
n_weighted)
|
||
v_norm = torch.nn.functional.normalize(v_norm, p=2, dim=-1, eps=1e-12)
|
||
a_he = torch.cat([faces[:, 0], faces[:, 1], faces[:, 2]], dim=0).long()
|
||
b_he = torch.cat([faces[:, 1], faces[:, 2], faces[:, 0]], dim=0).long()
|
||
cos_endpoints = (v_norm[a_he] * v_norm[b_he]).sum(dim=-1)
|
||
skip_he_sharp = cos_endpoints < cfg.line_quadric_skip_opposite_normals_cos
|
||
if not skip_he_sharp.any():
|
||
skip_he_sharp = None
|
||
Q_flat = _add_line_quadrics(verts, faces, area, Q_flat,
|
||
cfg.line_quadric_weight,
|
||
skip_he_mask=skip_he_sharp)
|
||
|
||
# Boundary line quadrics: pin boundary-edge endpoints to the boundary line.
|
||
if cfg.boundary_quadrics and faces.numel() > 0:
|
||
b_edges = _detect_boundary_edges(faces, V)
|
||
if b_edges.shape[0] > 0:
|
||
ba = b_edges[:, 0]
|
||
bb = b_edges[:, 1]
|
||
pa = verts[ba]
|
||
pb = verts[bb]
|
||
p_u, p_w, _ = _line_quadric_planes(pa, pb)
|
||
K_b = (torch.einsum("ei,ej->eij", p_u, p_u)
|
||
+ torch.einsum("ei,ej->eij", p_w, p_w)) * cfg.boundary_weight
|
||
K_b_flat = K_b.reshape(-1, 16)
|
||
Q_flat.scatter_add_(0, ba.unsqueeze(1).expand(-1, 16), K_b_flat)
|
||
Q_flat.scatter_add_(0, bb.unsqueeze(1).expand(-1, 16), K_b_flat)
|
||
|
||
# Feature-edge quadrics: line quadric on sharp interior edges weighted by (1 - cos(dihedral)).
|
||
if cfg.feature_edge_quadric_weight > 0 and faces.numel() > 0:
|
||
v0 = verts[faces[:, 0]]
|
||
v1 = verts[faces[:, 1]]
|
||
v2 = verts[faces[:, 2]]
|
||
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||
fn = torch.nn.functional.normalize(fn, p=2, dim=-1, eps=1e-12)
|
||
sorted_keys_fe, sorted_faces_fe, _ = _sorted_edge_halfedges(faces, V)
|
||
pair_keys, f1_idx, f2_idx = _manifold_edge_pairs(sorted_keys_fe, sorted_faces_fe)
|
||
if pair_keys.numel() > 0:
|
||
P = V + 1
|
||
edge_a = pair_keys // P
|
||
edge_b = pair_keys % P
|
||
cos_dihedral = (fn[f1_idx] * fn[f2_idx]).sum(dim=-1)
|
||
cos_thresh = math.cos(math.radians(cfg.feature_edge_min_dihedral_deg))
|
||
sharp = cos_dihedral < cos_thresh
|
||
if sharp.any():
|
||
fa = edge_a[sharp]
|
||
fb = edge_b[sharp]
|
||
p_u, p_w, _ = _line_quadric_planes(verts[fa], verts[fb])
|
||
sharpness = (1.0 - cos_dihedral[sharp]).clamp_min(0.0)
|
||
avg_area = 0.5 * (area[f1_idx[sharp]] + area[f2_idx[sharp]])
|
||
w = (avg_area * sharpness * cfg.feature_edge_quadric_weight) \
|
||
.unsqueeze(-1).unsqueeze(-1)
|
||
K_feat = (
|
||
p_u.unsqueeze(-1) * p_u.unsqueeze(-2)
|
||
+ p_w.unsqueeze(-1) * p_w.unsqueeze(-2)
|
||
) * w
|
||
K_flat = K_feat.reshape(-1, 16)
|
||
Q_flat.scatter_add_(0, fa.unsqueeze(1).expand(-1, 16), K_flat)
|
||
Q_flat.scatter_add_(0, fb.unsqueeze(1).expand(-1, 16), K_flat)
|
||
|
||
return Q_flat.reshape(V, 4, 4)
|
||
|
||
|
||
def _edge_errors(
|
||
verts: torch.Tensor,
|
||
Q: torch.Tensor,
|
||
edges: torch.Tensor,
|
||
stabilizer: float,
|
||
max_edge_length_sq: float,
|
||
mesh_scale_sq: float,
|
||
cfg: QEMConfig,
|
||
vert_is_boundary: Optional[torch.Tensor] = None,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
"""Returns (optimal_pos, error, valid_mask); vert_is_boundary enables boundary-aware midpoint."""
|
||
n_edges = edges.shape[0]
|
||
dtype = verts.dtype
|
||
device = verts.device
|
||
|
||
if n_edges == 0:
|
||
return (
|
||
torch.empty((0, 3), dtype=dtype, device=device),
|
||
torch.empty((0,), dtype=dtype, device=device),
|
||
torch.zeros((0,), dtype=torch.bool, device=device),
|
||
)
|
||
|
||
verts_pair = verts[edges] # (E, 2, 3)
|
||
pa = verts_pair[:, 0]
|
||
pb = verts_pair[:, 1]
|
||
edge_vec = pb - pa
|
||
el = torch.norm(edge_vec, dim=-1)
|
||
|
||
# boundary-aware midpoint: snap to the boundary endpoint when exactly one is boundary
|
||
if vert_is_boundary is not None:
|
||
ba = vert_is_boundary[edges[:, 0]]
|
||
bb = vert_is_boundary[edges[:, 1]]
|
||
w_a = torch.where(ba & ~bb, torch.ones_like(el),
|
||
torch.where(~ba & bb, torch.zeros_like(el),
|
||
torch.full_like(el, 0.5)))
|
||
midpoint = pa * w_a.unsqueeze(-1) + pb * (1.0 - w_a).unsqueeze(-1)
|
||
else:
|
||
midpoint = torch.lerp(pa, pb, 0.5)
|
||
|
||
Qe = Q[edges].sum(dim=1) # (E, 4, 4) — sum of Q[va] and Q[vb]
|
||
|
||
if cfg.placement_mode == "midpoint":
|
||
opt = midpoint
|
||
else:
|
||
A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype) * stabilizer
|
||
b = -Qe[:, :3, 3].unsqueeze(-1)
|
||
|
||
# stabilizer keeps A invertible; full-batch solve, midpoint fallback via where (no sync)
|
||
sol = torch.linalg.solve(A, b)
|
||
dets = torch.det(A)
|
||
good = (dets.abs() > 1e-12).unsqueeze(-1)
|
||
opt = torch.where(good, sol.squeeze(-1), midpoint)
|
||
|
||
if cfg.clamp_v_to_edge:
|
||
# project v* onto the edge segment (subsumes the wander check)
|
||
edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20
|
||
t = ((opt - pa) * edge_vec).sum(dim=-1) / edge_len_sq
|
||
t = t.clamp(0.0, 1.0).unsqueeze(-1)
|
||
opt = torch.lerp(pa, pb, t)
|
||
else:
|
||
# fall back to midpoint when v* wanders from both endpoints
|
||
dist_a = torch.norm(opt - pa, dim=-1)
|
||
dist_b = torch.norm(opt - pb, dim=-1)
|
||
wander_bad = ((dist_a > cfg.wander_threshold * el) |
|
||
(dist_b > cfg.wander_threshold * el)).unsqueeze(-1)
|
||
opt = torch.where(wander_bad, midpoint, opt)
|
||
|
||
v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1)
|
||
err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))
|
||
|
||
# mesh_scale_sq: Python float or 0-d tensor
|
||
if torch.is_tensor(mesh_scale_sq):
|
||
length_ok = el * el > mesh_scale_sq * 1e-10
|
||
else:
|
||
length_ok = el > math.sqrt(mesh_scale_sq) * 1e-5
|
||
error_ok = err < max_edge_length_sq
|
||
nan_ok = ~torch.isnan(opt).any(dim=-1) & ~torch.isnan(err)
|
||
valid = length_ok & error_ok & nan_ok
|
||
|
||
# edge-length regularizer: bias collapse order toward short edges (uniform sizing)
|
||
if cfg.lambda_edge_length > 0.0 and valid.any():
|
||
el2 = el * el
|
||
if cfg.lambda_edge_length_absolute:
|
||
err = err + cfg.lambda_edge_length * el2
|
||
else:
|
||
qem_med = err[valid].median()
|
||
len_med = el2[valid].median().clamp_min(1e-30)
|
||
err = err + cfg.lambda_edge_length * el2 * (qem_med / len_med)
|
||
return opt, err, valid
|
||
|
||
|
||
def _greedy_matching(
|
||
edges: torch.Tensor,
|
||
err: torch.Tensor,
|
||
v_alive: torch.Tensor,
|
||
max_select: int,
|
||
) -> torch.Tensor:
|
||
"""Vectorised independent edge-set selection: an edge wins iff it is the min-key edge at both endpoints."""
|
||
device = edges.device
|
||
n_edges = edges.shape[0]
|
||
if n_edges == 0:
|
||
return torch.empty(0, dtype=torch.int64, device=device)
|
||
|
||
va = edges[:, 0]
|
||
vb = edges[:, 1]
|
||
num_verts = v_alive.shape[0]
|
||
|
||
err32 = err.to(torch.float32).clamp(min=0).contiguous()
|
||
err_bits = err32.view(torch.int32).to(torch.int64) & 0xFFFFFFFF
|
||
edge_idx = torch.arange(n_edges, device=device, dtype=torch.int64)
|
||
key = (err_bits << 32) | edge_idx
|
||
|
||
INT64_MAX = torch.iinfo(torch.int64).max
|
||
best_key = torch.full((num_verts,), INT64_MAX, dtype=torch.int64, device=device)
|
||
best_key.scatter_reduce_(0, va, key, reduce="amin", include_self=True)
|
||
best_key.scatter_reduce_(0, vb, key, reduce="amin", include_self=True)
|
||
|
||
is_winner = (key == best_key[va]) & (key == best_key[vb]) & v_alive[va] & v_alive[vb]
|
||
sel = torch.nonzero(is_winner, as_tuple=True)[0]
|
||
|
||
if sel.numel() > max_select:
|
||
sel_err = err[sel]
|
||
top = torch.topk(sel_err, max_select, largest=False).indices
|
||
sel = sel[top]
|
||
return sel
|
||
|
||
|
||
def _build_vert_to_faces_pad(
|
||
faces: torch.Tensor,
|
||
num_verts: int,
|
||
max_deg: int,
|
||
) -> torch.Tensor:
|
||
"""Pad-CSR vertex-to-incident-faces table (V, max_deg) of face indices, -1 padded, degree truncated."""
|
||
device = faces.device
|
||
F = faces.shape[0]
|
||
if F == 0:
|
||
return torch.full((num_verts, max_deg), -1, dtype=torch.int64, device=device)
|
||
v_rep = faces.flatten().long()
|
||
f_rep = torch.arange(F, device=device, dtype=torch.int64).repeat_interleave(3)
|
||
sort_idx = v_rep.argsort()
|
||
sorted_v = v_rep[sort_idx]
|
||
sorted_f = f_rep[sort_idx]
|
||
offsets = torch.searchsorted(
|
||
sorted_v, torch.arange(num_verts + 1, device=device, dtype=sorted_v.dtype)
|
||
)
|
||
slot = torch.arange(sorted_v.shape[0], device=device, dtype=torch.int64) - offsets[sorted_v]
|
||
keep = slot < max_deg
|
||
table = torch.full((num_verts, max_deg), -1, dtype=torch.int64, device=device)
|
||
table[sorted_v[keep], slot[keep]] = sorted_f[keep]
|
||
return table
|
||
|
||
|
||
def _normal_flip_mask(
|
||
verts: torch.Tensor, # (V, 3)
|
||
faces: torch.Tensor, # (F, 3) — must be alive faces only
|
||
edges: torch.Tensor, # (E, 2) candidate collapse edges
|
||
opt: torch.Tensor, # (E, 3) proposed collapse positions
|
||
vert_to_faces: torch.Tensor, # (V, max_deg) face indices or -1
|
||
cos_threshold: float = 0.0,
|
||
chunk_size: int = 100_000,
|
||
return_count: bool = False,
|
||
) -> torch.Tensor:
|
||
"""(E,) bool mask (no adjacent-face flip), or int count of would-flip faces per edge if return_count."""
|
||
E = edges.shape[0]
|
||
device = verts.device
|
||
if return_count:
|
||
out = torch.zeros(E, dtype=torch.int32, device=device)
|
||
else:
|
||
out = torch.ones(E, dtype=torch.bool, device=device)
|
||
if E == 0:
|
||
return out
|
||
|
||
max_deg = vert_to_faces.shape[1]
|
||
a_all = edges[:, 0]
|
||
b_all = edges[:, 1]
|
||
|
||
for start in range(0, E, chunk_size):
|
||
stop = min(start + chunk_size, E)
|
||
Ec = stop - start
|
||
a = a_all[start:stop]
|
||
b = b_all[start:stop]
|
||
oc = opt[start:stop]
|
||
|
||
fa = vert_to_faces[a] # (Ec, max_deg)
|
||
fb = vert_to_faces[b]
|
||
all_f = torch.cat([fa, fb], dim=1) # (Ec, 2*max_deg)
|
||
valid_f = all_f >= 0
|
||
all_f_safe = all_f.clamp(min=0)
|
||
fv = faces[all_f_safe] # (Ec, 2*max_deg, 3)
|
||
|
||
a_b = a.view(Ec, 1)
|
||
b_b = b.view(Ec, 1)
|
||
s0_a = fv[..., 0] == a_b
|
||
s0_b = fv[..., 0] == b_b
|
||
s1_a = fv[..., 1] == a_b
|
||
s1_b = fv[..., 1] == b_b
|
||
s2_a = fv[..., 2] == a_b
|
||
s2_b = fv[..., 2] == b_b
|
||
contains_a = s0_a | s1_a | s2_a
|
||
contains_b = s0_b | s1_b | s2_b
|
||
# affected: face contains exactly one of {a, b} and slot is non-pad
|
||
affected = (contains_a ^ contains_b) & valid_f
|
||
if not affected.any():
|
||
continue
|
||
|
||
p0 = verts[fv[..., 0]] # (Ec, 2*max_deg, 3)
|
||
p1 = verts[fv[..., 1]]
|
||
p2 = verts[fv[..., 2]]
|
||
n_old = torch.cross(p1 - p0, p2 - p0, dim=-1)
|
||
|
||
opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * max_deg, -1)
|
||
rep0 = (s0_a | s0_b).unsqueeze(-1)
|
||
rep1 = (s1_a | s1_b).unsqueeze(-1)
|
||
rep2 = (s2_a | s2_b).unsqueeze(-1)
|
||
p0n = torch.where(rep0, opt_b, p0)
|
||
p1n = torch.where(rep1, opt_b, p1)
|
||
p2n = torch.where(rep2, opt_b, p2)
|
||
n_new = torch.cross(p1n - p0n, p2n - p0n, dim=-1)
|
||
|
||
nlen_old = torch.norm(n_old, dim=-1)
|
||
nlen_new = torch.norm(n_new, dim=-1)
|
||
# degenerate-before faces can't flip; treat as OK
|
||
denom = nlen_old * nlen_new
|
||
safe = denom > 1e-20
|
||
cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20),
|
||
torch.ones_like(denom))
|
||
flip = (cos < cos_threshold) & affected & safe
|
||
if return_count:
|
||
out[start:stop] = flip.sum(dim=-1).to(torch.int32)
|
||
else:
|
||
out[start:stop] = ~flip.any(dim=-1)
|
||
|
||
return out
|
||
|
||
|
||
def _link_condition_mask(
|
||
faces: torch.Tensor, # (F, 3) alive faces only
|
||
edges: torch.Tensor, # (E, 2) candidate collapse edges
|
||
vert_to_faces: torch.Tensor, # (V, max_deg) face idx or -1
|
||
chunk_size: int = 100_000,
|
||
) -> torch.Tensor:
|
||
"""(E,) bool mask — True where the collapse is topology-safe (link condition: common neighbours <= edge faces)."""
|
||
E = edges.shape[0]
|
||
device = faces.device
|
||
out = torch.ones(E, dtype=torch.bool, device=device)
|
||
if E == 0:
|
||
return out
|
||
D = vert_to_faces.shape[1]
|
||
a_all = edges[:, 0]
|
||
b_all = edges[:, 1]
|
||
|
||
for s in range(0, E, chunk_size):
|
||
e = min(s + chunk_size, E)
|
||
a = a_all[s:e]
|
||
b = b_all[s:e]
|
||
Ec = a.shape[0]
|
||
|
||
fa = vert_to_faces[a] # (Ec, D)
|
||
fb = vert_to_faces[b]
|
||
fa_ok = fa >= 0
|
||
fb_ok = fb >= 0
|
||
fav = faces[fa.clamp(min=0)] # (Ec, D, 3)
|
||
fbv = faces[fb.clamp(min=0)]
|
||
|
||
# neighbour verts of a/b: take the 2 non-anchor verts per incident face → (Ec, 2D)
|
||
a_b = a[:, None]
|
||
b_b = b[:, None]
|
||
an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0])
|
||
an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2])
|
||
bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0])
|
||
bn2 = torch.where(fbv[..., 2] == b_b, fbv[..., 1], fbv[..., 2])
|
||
na = torch.stack([an1, an2], dim=-1).reshape(Ec, 2 * D)
|
||
nb = torch.stack([bn1, bn2], dim=-1).reshape(Ec, 2 * D)
|
||
fa_okx = fa_ok.repeat_interleave(2, dim=1)
|
||
fb_okx = fb_ok.repeat_interleave(2, dim=1)
|
||
na[(na == a_b) | (na == b_b) | ~fa_okx] = -1
|
||
nb[(nb == a_b) | (nb == b_b) | ~fb_okx] = -1
|
||
|
||
# common neighbours: na entries also appearing in nb
|
||
in_b = (na[:, :, None] == nb[:, None, :]) & (na[:, :, None] >= 0)
|
||
na_common = torch.where(in_b.any(dim=2), na, torch.full_like(na, -1))
|
||
# distinct count of common neighbours per edge (sort + count transitions)
|
||
cs, _ = na_common.sort(dim=1)
|
||
count_common = ((cs[:, 1:] != cs[:, :-1]) & (cs[:, 1:] >= 0)).sum(dim=1) \
|
||
+ (cs[:, :1] >= 0).sum(dim=1)
|
||
|
||
# faces on the edge = a's faces also containing b
|
||
count_faces = ((fav == b[:, None, None]).any(dim=2) & fa_ok).sum(dim=1)
|
||
|
||
out[s:e] = count_common <= count_faces
|
||
|
||
return out
|
||
|
||
|
||
def _skinny_penalty(
|
||
verts: torch.Tensor, # (V, 3)
|
||
faces: torch.Tensor, # (F, 3) — alive faces only
|
||
edges: torch.Tensor, # (E, 2) candidate collapse edges
|
||
opt: torch.Tensor, # (E, 3) proposed collapse positions
|
||
vert_to_faces: torch.Tensor, # (V, max_deg)
|
||
chunk_size: int = 100_000,
|
||
) -> torch.Tensor:
|
||
"""Per-edge post-collapse triangle-shape penalty (lambda_skinny); mean of 1 - clamp(shape,0,1) over the 1-ring."""
|
||
E = edges.shape[0]
|
||
device = verts.device
|
||
out = torch.zeros(E, dtype=verts.dtype, device=device)
|
||
if E == 0:
|
||
return out
|
||
|
||
max_deg = vert_to_faces.shape[1]
|
||
a_all = edges[:, 0]
|
||
b_all = edges[:, 1]
|
||
sqrt3_4 = 4.0 * math.sqrt(3.0)
|
||
|
||
for start in range(0, E, chunk_size):
|
||
stop = min(start + chunk_size, E)
|
||
Ec = stop - start
|
||
a = a_all[start:stop]
|
||
b = b_all[start:stop]
|
||
oc = opt[start:stop]
|
||
|
||
fa = vert_to_faces[a]
|
||
fb = vert_to_faces[b]
|
||
all_f = torch.cat([fa, fb], dim=1)
|
||
valid_f = all_f >= 0
|
||
all_f_safe = all_f.clamp(min=0)
|
||
fv = faces[all_f_safe]
|
||
|
||
a_b = a.view(Ec, 1)
|
||
b_b = b.view(Ec, 1)
|
||
s0_a = fv[..., 0] == a_b
|
||
s0_b = fv[..., 0] == b_b
|
||
s1_a = fv[..., 1] == a_b
|
||
s1_b = fv[..., 1] == b_b
|
||
s2_a = fv[..., 2] == a_b
|
||
s2_b = fv[..., 2] == b_b
|
||
contains_a = s0_a | s1_a | s2_a
|
||
contains_b = s0_b | s1_b | s2_b
|
||
# affected: face contains exactly one of {a, b} and slot is non-pad
|
||
affected = (contains_a ^ contains_b) & valid_f
|
||
if not affected.any():
|
||
continue
|
||
|
||
p0 = verts[fv[..., 0]]
|
||
p1 = verts[fv[..., 1]]
|
||
p2 = verts[fv[..., 2]]
|
||
opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * max_deg, -1)
|
||
rep0 = (s0_a | s0_b).unsqueeze(-1)
|
||
rep1 = (s1_a | s1_b).unsqueeze(-1)
|
||
rep2 = (s2_a | s2_b).unsqueeze(-1)
|
||
p0n = torch.where(rep0, opt_b, p0)
|
||
p1n = torch.where(rep1, opt_b, p1)
|
||
p2n = torch.where(rep2, opt_b, p2)
|
||
|
||
e01 = p1n - p0n
|
||
e02 = p2n - p0n
|
||
e12 = p2n - p1n
|
||
two_area = torch.cross(e01, e02, dim=-1).norm(dim=-1)
|
||
edge_sum_sq = ((e01 * e01).sum(-1)
|
||
+ (e02 * e02).sum(-1)
|
||
+ (e12 * e12).sum(-1))
|
||
shape = (sqrt3_4 * 0.5 * two_area) / edge_sum_sq.clamp_min(1e-20)
|
||
term = 1.0 - shape.clamp(0.0, 1.0)
|
||
term = torch.where(affected, term, torch.zeros_like(term))
|
||
n_affected = affected.sum(dim=-1).clamp_min(1).to(term.dtype)
|
||
out[start:stop] = term.sum(dim=-1) / n_affected
|
||
|
||
return out
|
||
|
||
|
||
def _quality_checks_fused(
|
||
verts: torch.Tensor,
|
||
faces: torch.Tensor,
|
||
edges: torch.Tensor,
|
||
opt: torch.Tensor,
|
||
vert_to_faces: torch.Tensor,
|
||
cos_threshold: float = 0.0,
|
||
want_flip: bool = True,
|
||
want_skinny: bool = True,
|
||
want_link: bool = False,
|
||
chunk_size: int = 100_000,
|
||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||
"""Fused 1-ring checks (flip count / skinny / link) sharing one faces gather.
|
||
Returns (flip_count|None, skinny|None, link_safe|None)."""
|
||
E = edges.shape[0]
|
||
device = verts.device
|
||
flip_out = torch.zeros(E, dtype=torch.int32, device=device) if want_flip else None
|
||
skinny_out = torch.zeros(E, dtype=verts.dtype, device=device) if want_skinny else None
|
||
link_out = torch.ones(E, dtype=torch.bool, device=device) if want_link else None
|
||
if E == 0:
|
||
return flip_out, skinny_out, link_out
|
||
|
||
D = vert_to_faces.shape[1]
|
||
a_all = edges[:, 0]
|
||
b_all = edges[:, 1]
|
||
sqrt3_4 = 4.0 * math.sqrt(3.0)
|
||
need_geom = want_flip or want_skinny
|
||
|
||
for start in range(0, E, chunk_size):
|
||
stop = min(start + chunk_size, E)
|
||
Ec = stop - start
|
||
a = a_all[start:stop]
|
||
b = b_all[start:stop]
|
||
|
||
# shared gather of a's and b's incident faces (the expensive part)
|
||
fa = vert_to_faces[a]
|
||
fb = vert_to_faces[b]
|
||
all_f = torch.cat([fa, fb], dim=1) # (Ec, 2D)
|
||
valid_f = all_f >= 0
|
||
fv = faces[all_f.clamp(min=0)] # (Ec, 2D, 3)
|
||
a_b = a.view(Ec, 1)
|
||
b_b = b.view(Ec, 1)
|
||
|
||
if need_geom:
|
||
oc = opt[start:stop]
|
||
s0_a = fv[..., 0] == a_b
|
||
s0_b = fv[..., 0] == b_b
|
||
s1_a = fv[..., 1] == a_b
|
||
s1_b = fv[..., 1] == b_b
|
||
s2_a = fv[..., 2] == a_b
|
||
s2_b = fv[..., 2] == b_b
|
||
contains_a = s0_a | s1_a | s2_a
|
||
contains_b = s0_b | s1_b | s2_b
|
||
affected = (contains_a ^ contains_b) & valid_f
|
||
if affected.any():
|
||
p0 = verts[fv[..., 0]]
|
||
p1 = verts[fv[..., 1]]
|
||
p2 = verts[fv[..., 2]]
|
||
opt_b = oc.view(Ec, 1, 3).expand(-1, 2 * D, -1)
|
||
rep0 = (s0_a | s0_b).unsqueeze(-1)
|
||
rep1 = (s1_a | s1_b).unsqueeze(-1)
|
||
rep2 = (s2_a | s2_b).unsqueeze(-1)
|
||
p0n = torch.where(rep0, opt_b, p0)
|
||
p1n = torch.where(rep1, opt_b, p1)
|
||
p2n = torch.where(rep2, opt_b, p2)
|
||
|
||
# post-collapse normal (skinny's two_area == flip's ‖n_new‖)
|
||
e01 = p1n - p0n
|
||
e02 = p2n - p0n
|
||
n_new = torch.cross(e01, e02, dim=-1)
|
||
nlen_new = torch.norm(n_new, dim=-1)
|
||
|
||
if want_flip:
|
||
n_old = torch.cross(p1 - p0, p2 - p0, dim=-1)
|
||
nlen_old = torch.norm(n_old, dim=-1)
|
||
denom = nlen_old * nlen_new
|
||
safe = denom > 1e-20
|
||
cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20),
|
||
torch.ones_like(denom))
|
||
flip = (cos < cos_threshold) & affected & safe
|
||
flip_out[start:stop] = flip.sum(dim=-1).to(torch.int32)
|
||
|
||
if want_skinny:
|
||
e12 = p2n - p1n
|
||
edge_sum_sq = ((e01 * e01).sum(-1) + (e02 * e02).sum(-1) + (e12 * e12).sum(-1))
|
||
shape = (sqrt3_4 * 0.5 * nlen_new) / edge_sum_sq.clamp_min(1e-20)
|
||
term = 1.0 - shape.clamp(0.0, 1.0)
|
||
term = torch.where(affected, term, torch.zeros_like(term))
|
||
n_affected = affected.sum(dim=-1).clamp_min(1).to(term.dtype)
|
||
skinny_out[start:stop] = term.sum(dim=-1) / n_affected
|
||
|
||
if want_link:
|
||
# reuses fv / valid_f; matches _link_condition_mask
|
||
fa_ok = valid_f[:, :D]
|
||
fb_ok = valid_f[:, D:]
|
||
fav = fv[:, :D]
|
||
fbv = fv[:, D:]
|
||
an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0])
|
||
an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2])
|
||
bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0])
|
||
bn2 = torch.where(fbv[..., 2] == b_b, fbv[..., 1], fbv[..., 2])
|
||
na = torch.stack([an1, an2], dim=-1).reshape(Ec, 2 * D)
|
||
nb = torch.stack([bn1, bn2], dim=-1).reshape(Ec, 2 * D)
|
||
fa_okx = fa_ok.repeat_interleave(2, dim=1)
|
||
fb_okx = fb_ok.repeat_interleave(2, dim=1)
|
||
na[(na == a_b) | (na == b_b) | ~fa_okx] = -1
|
||
nb[(nb == a_b) | (nb == b_b) | ~fb_okx] = -1
|
||
in_b = (na[:, :, None] == nb[:, None, :]) & (na[:, :, None] >= 0)
|
||
na_common = torch.where(in_b.any(dim=2), na, torch.full_like(na, -1))
|
||
cs, _ = na_common.sort(dim=1)
|
||
count_common = ((cs[:, 1:] != cs[:, :-1]) & (cs[:, 1:] >= 0)).sum(dim=1) \
|
||
+ (cs[:, :1] >= 0).sum(dim=1)
|
||
count_faces = ((fav == b[:, None, None]).any(dim=2) & fa_ok).sum(dim=1)
|
||
link_out[start:stop] = count_common <= count_faces
|
||
|
||
return flip_out, skinny_out, link_out
|
||
|
||
|
||
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
|
||
if faces.numel() == 0:
|
||
return torch.zeros_like(verts)
|
||
faces_long = faces.to(torch.int64)
|
||
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
|
||
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
|
||
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||
vn = torch.zeros_like(verts)
|
||
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
|
||
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
|
||
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
|
||
return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6)
|
||
|
||
|
||
# Public API
|
||
|
||
@dataclass
class CleanStats:
    """Counters collected by one pass of the mesh clean-up pipeline."""

    # mesh size before / after cleaning
    in_verts: int = 0
    in_faces: int = 0
    out_verts: int = 0
    out_faces: int = 0
    # what each stage removed
    welded_verts: int = 0        # vertex IDs merged away by welding
    degenerate_faces: int = 0    # zero-area or repeated-index faces removed
    duplicate_faces: int = 0     # faces sharing a vertex set with an earlier face
    unused_verts: int = 0        # vertices referenced by no face
    components_dropped: int = 0  # disconnected components below the size threshold
    # wall-clock duration of the pass, in seconds
    seconds: float = 0.0

    def __str__(self):
        pieces = (
            f"clean: in={self.in_verts}v/{self.in_faces}f -> ",
            f"out={self.out_verts}v/{self.out_faces}f ",
            f"(welded {self.welded_verts}v, degen {self.degenerate_faces}f, ",
            f"dup {self.duplicate_faces}f, unused {self.unused_verts}v, ",
            f"comps {self.components_dropped}) {self.seconds*1000:.1f}ms",
        )
        return "".join(pieces)
|
||
|
||
|
||
def _weld_vertices(
|
||
verts: torch.Tensor, faces: torch.Tensor, epsilon,
|
||
colors: Optional[torch.Tensor] = None,
|
||
normals: Optional[torch.Tensor] = None,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
||
"""Merge vertices closer than epsilon (L_inf grid), cluster-averaging attributes; returns (v, f, colors, normals, n_welded)."""
|
||
if verts.shape[0] == 0:
|
||
return verts, faces, colors, normals, 0
|
||
device = verts.device
|
||
scale = 1.0 / epsilon
|
||
bbox_min = verts.min(dim=0)[0]
|
||
q = ((verts - bbox_min) * scale).round().to(torch.int64)
|
||
bbox = (verts.max(dim=0)[0] - bbox_min)
|
||
extent = (bbox * scale).round().to(torch.int64) + 2
|
||
key = (q[:, 0] * extent[1] + q[:, 1]) * extent[2] + q[:, 2] # pack 3D quantized pos to 1D key
|
||
unique_key, inv = torch.unique(key, return_inverse=True)
|
||
n_unique = unique_key.shape[0]
|
||
if n_unique == verts.shape[0]:
|
||
return verts, faces, colors, normals, 0
|
||
counts = torch.zeros(n_unique, dtype=verts.dtype, device=device)
|
||
counts.scatter_add_(0, inv, torch.ones(verts.shape[0], dtype=verts.dtype, device=device))
|
||
counts_div = counts.unsqueeze(-1).clamp_min(1.0)
|
||
|
||
new_verts = torch.zeros((n_unique, 3), dtype=verts.dtype, device=device)
|
||
new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(verts), verts)
|
||
new_verts = new_verts / counts_div
|
||
|
||
new_colors = None
|
||
if colors is not None:
|
||
new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=colors.dtype, device=device)
|
||
new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors)
|
||
new_colors = new_colors / counts_div.to(colors.dtype)
|
||
|
||
new_normals = None
|
||
if normals is not None:
|
||
new_normals = torch.zeros((n_unique, normals.shape[1]), dtype=normals.dtype, device=device)
|
||
new_normals.scatter_add_(0, inv.unsqueeze(-1).expand_as(normals), normals)
|
||
new_normals = torch.nn.functional.normalize(new_normals, p=2, dim=-1, eps=1e-6)
|
||
|
||
new_faces = inv[faces.long()] if faces.numel() > 0 else faces
|
||
return new_verts, new_faces, new_colors, new_normals, int(verts.shape[0] - n_unique)
|
||
|
||
|
||
def _drop_degenerate_faces(
|
||
verts: torch.Tensor, faces: torch.Tensor,
|
||
min_area: float = 1e-14,
|
||
) -> Tuple[torch.Tensor, int]:
|
||
"""Drop degenerate-by-construction faces (repeated indices or zero-area); slivers go to _collapse_slivers."""
|
||
if faces.numel() == 0:
|
||
return faces, 0
|
||
idx_bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 0] == faces[:, 2])
|
||
f_good = faces[~idx_bad]
|
||
v0 = verts[f_good[:, 0]]
|
||
v1 = verts[f_good[:, 1]]
|
||
v2 = verts[f_good[:, 2]]
|
||
e0 = v1 - v0
|
||
e2 = v0 - v2
|
||
area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1)
|
||
bad = area < min_area
|
||
kept = f_good[~bad]
|
||
n_dropped = idx_bad.sum() + bad.sum() # tensor-scalar; caller .item()s once
|
||
return kept, n_dropped
|
||
|
||
|
||
def _collapse_slivers(
|
||
verts: torch.Tensor, faces: torch.Tensor,
|
||
min_angle_deg: float = 0.0,
|
||
max_aspect_ratio: float = 0.0,
|
||
) -> Tuple[torch.Tensor, int]:
|
||
"""Resolve sliver triangles by collapsing each sliver's shortest edge (no holes); returns (faces, n_collapsed)."""
|
||
if faces.numel() == 0 or (min_angle_deg <= 0 and max_aspect_ratio <= 0):
|
||
return faces, 0
|
||
|
||
fl = faces.long()
|
||
v0 = verts[fl[:, 0]]
|
||
v1 = verts[fl[:, 1]]
|
||
v2 = verts[fl[:, 2]]
|
||
e0 = v1 - v0
|
||
e1 = v2 - v1
|
||
e2 = v0 - v2
|
||
l0 = torch.norm(e0, dim=-1)
|
||
l1 = torch.norm(e1, dim=-1)
|
||
l2 = torch.norm(e2, dim=-1)
|
||
area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1)
|
||
|
||
bad = torch.zeros(faces.shape[0], dtype=torch.bool, device=verts.device)
|
||
if max_aspect_ratio > 0:
|
||
max_edge = torch.maximum(torch.maximum(l0, l1), l2)
|
||
aspect = max_edge * max_edge / (2.0 * area + 1e-12)
|
||
bad = bad | (aspect > max_aspect_ratio)
|
||
if min_angle_deg > 0:
|
||
cos_a = (l1 * l1 + l2 * l2 - l0 * l0) / (2 * l1 * l2 + 1e-12)
|
||
cos_b = (l0 * l0 + l2 * l2 - l1 * l1) / (2 * l0 * l2 + 1e-12)
|
||
cos_c = (l0 * l0 + l1 * l1 - l2 * l2) / (2 * l0 * l1 + 1e-12)
|
||
cos_all = torch.stack([cos_a, cos_b, cos_c], dim=-1)
|
||
angles_deg = torch.acos(torch.clamp(cos_all, -1, 1)) * (180.0 / math.pi)
|
||
bad = bad | (angles_deg.min(dim=-1).values < min_angle_deg)
|
||
|
||
if not bad.any():
|
||
return faces, 0
|
||
|
||
# per sliver pick its shortest edge to collapse
|
||
edge_lens = torch.stack([l0, l1, l2], dim=-1) # (F, 3)
|
||
shortest_slot = edge_lens.argmin(dim=-1) # (F,) ∈ {0,1,2}
|
||
|
||
V = verts.shape[0]
|
||
# collapse higher-index endpoint into lower (min/max ordering avoids cycles)
|
||
merge_map = torch.arange(V, device=verts.device, dtype=torch.int64)
|
||
bad_idx = torch.nonzero(bad, as_tuple=True)[0]
|
||
for slot in range(3):
|
||
sel = bad_idx[shortest_slot[bad_idx] == slot]
|
||
if sel.numel() == 0:
|
||
continue
|
||
a = fl[sel, slot]
|
||
b = fl[sel, (slot + 1) % 3]
|
||
lo = torch.minimum(a, b)
|
||
hi = torch.maximum(a, b)
|
||
merge_map[hi] = lo # last-write-wins on conflict
|
||
|
||
# path-compress until stable
|
||
for _ in range(10):
|
||
new_map = merge_map[merge_map]
|
||
if torch.equal(new_map, merge_map):
|
||
break
|
||
merge_map = new_map
|
||
|
||
new_faces = merge_map[fl]
|
||
nondeg = ((new_faces[:, 0] != new_faces[:, 1]) &
|
||
(new_faces[:, 1] != new_faces[:, 2]) &
|
||
(new_faces[:, 0] != new_faces[:, 2]))
|
||
new_faces = new_faces[nondeg].to(dtype=faces.dtype)
|
||
return new_faces, bad.sum()
|
||
|
||
|
||
def _drop_duplicate_faces(faces: torch.Tensor, num_verts: int) -> Tuple[torch.Tensor, int]:
|
||
"""Remove duplicate faces (same vertex set), keeping the first occurrence (winding-preserving)."""
|
||
if faces.shape[0] <= 1:
|
||
return faces, 0
|
||
key_sorted = torch.sort(faces, dim=1)[0]
|
||
P = num_verts + 1
|
||
packed = (key_sorted[:, 0].long() * P + key_sorted[:, 1].long()) * P + key_sorted[:, 2].long()
|
||
unique_packed, inv = torch.unique(packed, return_inverse=True)
|
||
if unique_packed.shape[0] == faces.shape[0]:
|
||
return faces, 0
|
||
# first-occurrence index per unique key
|
||
arange = torch.arange(packed.shape[0], device=packed.device)
|
||
first = torch.full((unique_packed.shape[0],), packed.shape[0],
|
||
dtype=torch.int64, device=packed.device)
|
||
first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True)
|
||
kept = faces[first]
|
||
return kept, int(faces.shape[0] - kept.shape[0])
|
||
|
||
|
||
def _drop_unused_verts(
|
||
verts: torch.Tensor, faces: torch.Tensor,
|
||
colors: Optional[torch.Tensor] = None,
|
||
normals: Optional[torch.Tensor] = None,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
||
"""Remove vertices not referenced by any face; remap faces and filter attributes."""
|
||
if verts.shape[0] == 0 or faces.numel() == 0:
|
||
return verts, faces, colors, normals, 0
|
||
used = torch.zeros(verts.shape[0], dtype=torch.bool, device=verts.device)
|
||
used[faces[:, 0]] = True
|
||
used[faces[:, 1]] = True
|
||
used[faces[:, 2]] = True
|
||
# cumsum compact remap: 0..N-1 to used verts in order
|
||
remap = used.long().cumsum(0) - 1
|
||
new_verts = verts[used]
|
||
new_faces = remap[faces.long()]
|
||
new_colors = colors[used] if colors is not None else None
|
||
new_normals = normals[used] if normals is not None else None
|
||
n_dropped = verts.shape[0] - used.sum()
|
||
return new_verts, new_faces, new_colors, new_normals, n_dropped
|
||
|
||
|
||
def _repair_nonmanifold_edges(
|
||
verts: torch.Tensor, faces: torch.Tensor,
|
||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
"""repair_non_manifold_edges: explode corners, re-merge only across manifold edges; returns (verts, faces, src)."""
|
||
if faces.numel() == 0:
|
||
return verts, faces
|
||
dev, vdt, fdt = verts.device, verts.dtype, faces.dtype
|
||
F = faces.detach().cpu().numpy().astype(_np.int64)
|
||
V = verts.detach().cpu().numpy()
|
||
nf = F.shape[0]
|
||
nv = V.shape[0]
|
||
corner_vert = F.reshape(-1) # (3F,) original vertex per corner
|
||
|
||
# per-face edges keyed by (vmin,vmax)
|
||
keys_l, ca_l, cb_l = [], [], []
|
||
for (i, j) in ((0, 1), (1, 2), (2, 0)):
|
||
va, vb = F[:, i], F[:, j]
|
||
ci = 3 * _np.arange(nf) + i
|
||
cj = 3 * _np.arange(nf) + j
|
||
amin = _np.where(va <= vb, ci, cj) # corner of the smaller-id endpoint
|
||
amax = _np.where(va <= vb, cj, ci)
|
||
vmin = _np.minimum(va, vb).astype(_np.int64)
|
||
vmax = _np.maximum(va, vb).astype(_np.int64)
|
||
keys_l.append(vmin * (nv + 1) + vmax)
|
||
ca_l.append(amin)
|
||
cb_l.append(amax)
|
||
keys = _np.concatenate(keys_l)
|
||
ca = _np.concatenate(ca_l)
|
||
cb = _np.concatenate(cb_l)
|
||
order = _np.argsort(keys, kind="stable")
|
||
keys = keys[order]
|
||
ca = ca[order]
|
||
cb = cb[order]
|
||
uniq, start, cnt = _np.unique(keys, return_index=True, return_counts=True)
|
||
man = start[cnt == 2] # manifold edges (exactly 2 incident faces)
|
||
# union both endpoints' corners across each manifold edge
|
||
rows = _np.concatenate([ca[man], cb[man]])
|
||
cols = _np.concatenate([ca[man + 1], cb[man + 1]])
|
||
|
||
n = 3 * nf
|
||
g = coo_matrix((_np.ones(rows.shape[0], dtype=_np.int8), (rows, cols)), shape=(n, n))
|
||
_ncomp, labels = connected_components(g, directed=False)
|
||
|
||
new_faces = labels[3 * _np.arange(nf)[:, None] + _np.array([0, 1, 2])[None, :]]
|
||
nnv = int(labels.max()) + 1
|
||
# source original-vertex index per new vertex
|
||
src = _np.zeros(nnv, dtype=_np.int64)
|
||
src[labels] = corner_vert
|
||
new_verts = V[src]
|
||
src_t = torch.from_numpy(src).to(device=dev)
|
||
return (torch.from_numpy(new_verts).to(device=dev, dtype=vdt),
|
||
torch.from_numpy(new_faces.astype(_np.int64)).to(device=dev, dtype=fdt),
|
||
src_t)
|
||
|
||
|
||
def _drop_small_components(
|
||
verts: torch.Tensor, faces: torch.Tensor, min_faces: int,
|
||
max_propagation_iters: int = 200,
|
||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||
"""Label-propagation connected components; drop components below min_faces."""
|
||
if faces.numel() == 0 or min_faces <= 1:
|
||
return verts, faces, 0
|
||
device = verts.device
|
||
V = verts.shape[0]
|
||
labels = torch.arange(V, device=device, dtype=torch.int64)
|
||
for _ in range(max_propagation_iters):
|
||
v0, v1, v2 = faces[:, 0], faces[:, 1], faces[:, 2]
|
||
face_min = torch.minimum(torch.minimum(labels[v0], labels[v1]), labels[v2])
|
||
new_labels = labels.clone()
|
||
new_labels.scatter_reduce_(0, v0, face_min, reduce="amin", include_self=True)
|
||
new_labels.scatter_reduce_(0, v1, face_min, reduce="amin", include_self=True)
|
||
new_labels.scatter_reduce_(0, v2, face_min, reduce="amin", include_self=True)
|
||
new_labels = new_labels[new_labels] # path-compress
|
||
if torch.equal(new_labels, labels):
|
||
break
|
||
labels = new_labels
|
||
face_label = labels[faces[:, 0]]
|
||
unique_labels, counts = torch.unique(face_label, return_counts=True)
|
||
big_labels = unique_labels[counts >= min_faces]
|
||
if big_labels.shape[0] == unique_labels.shape[0]:
|
||
return verts, faces, 0
|
||
# safety: never drop every component (return the small mesh, not an empty one)
|
||
if big_labels.shape[0] == 0:
|
||
return verts, faces, 0
|
||
keep_face = torch.isin(face_label, big_labels)
|
||
kept_faces = faces[keep_face]
|
||
n_dropped = int(unique_labels.shape[0] - big_labels.shape[0])
|
||
return verts, kept_faces, n_dropped
|
||
|
||
|
||
def clean_mesh(
    verts: torch.Tensor, faces: torch.Tensor,
    colors: Optional[torch.Tensor] = None,
    normals: Optional[torch.Tensor] = None,
    weld_epsilon: float = 0.0,
    weld_epsilon_rel: float = 1e-6,
    drop_degenerate: bool = True,
    drop_duplicates: bool = True,
    drop_unused: bool = True,
    min_component_faces: int = 0,
    min_angle_deg: float = 0.0,
    max_aspect_ratio: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], CleanStats]:
    """Mesh hygiene pipeline; preserves per-vertex attributes through welding.

    Stages, in order, each optional:
      1. weld vertices (absolute ``weld_epsilon`` if > 0, otherwise
         ``weld_epsilon_rel`` times the bounding-box diagonal);
      2. drop degenerate faces, then collapse slivers selected by
         ``min_angle_deg`` / ``max_aspect_ratio``;
      3. drop duplicate faces;
      4. drop connected components with fewer than ``min_component_faces``
         faces (only when that value is > 1);
      5. drop vertices no face references.

    Returns (v, f, colors, normals, stats); all counters in ``stats`` are
    plain Python ints on return.
    """
    stats = CleanStats(in_verts=verts.shape[0], in_faces=faces.shape[0])
    t0 = _time.perf_counter()
    v = verts
    f = faces.long() if faces.numel() > 0 else faces
    c = colors
    n = normals

    # Weld only when there are vertices (the bbox reduction below fails on
    # an empty tensor) and a positive tolerance was requested; a negative
    # weld_epsilon with weld_epsilon_rel == 0 means "no welding", not eps=0.
    if v.shape[0] > 0 and (weld_epsilon > 0 or weld_epsilon_rel > 0):
        # eps stays a 0-d tensor (no sync)
        if weld_epsilon > 0:
            eps = torch.as_tensor(weld_epsilon, dtype=v.dtype, device=v.device)
        else:
            eps = torch.norm(v.max(dim=0)[0] - v.min(dim=0)[0]) * weld_epsilon_rel
        v, f, c, n, n_welded = _weld_vertices(v, f, eps, c, n)
        stats.welded_verts = n_welded

    if drop_degenerate:
        f_new, n_drop = _drop_degenerate_faces(v, f)
        stats.degenerate_faces = n_drop
        f = f_new
        # slivers get collapse-merged instead of dropped (preserves topology)
        # NOTE(review): sliver collapse is gated on drop_degenerate here —
        # confirm this matches the intended nesting.
        if min_angle_deg > 0 or max_aspect_ratio > 0:
            f_new, n_sliv = _collapse_slivers(
                v, f, min_angle_deg=min_angle_deg, max_aspect_ratio=max_aspect_ratio,
            )
            stats.degenerate_faces += n_sliv
            f = f_new

    if drop_duplicates:
        f_new, n_dup = _drop_duplicate_faces(f, v.shape[0])
        stats.duplicate_faces = n_dup
        f = f_new

    if min_component_faces > 1:
        v, f, n_comp = _drop_small_components(v, f, min_component_faces)
        stats.components_dropped = n_comp

    if drop_unused:
        v, f, c, n, n_unused = _drop_unused_verts(v, f, c, n)
        stats.unused_verts = n_unused

    stats.out_verts = v.shape[0]
    stats.out_faces = f.shape[0]
    stats.seconds = _time.perf_counter() - t0
    # materialize tensor-scalar counts to plain ints once at exit
    for field in ("welded_verts", "degenerate_faces", "duplicate_faces",
                  "unused_verts", "components_dropped"):
        val = getattr(stats, field)
        if torch.is_tensor(val):
            setattr(stats, field, int(val.item()))
    return v, f, c, n, stats
|
||
|
||
|
||
@dataclass
class SimplifyStats:
    """Summary of one qem_simplify run."""
    input_verts: int = 0      # vertex count when simplification starts (after the optional preclean)
    input_faces: int = 0      # face count when simplification starts (after the optional preclean)
    output_verts: int = 0     # vertex count of the returned mesh
    output_faces: int = 0     # face count of the returned mesh
    iterations: int = 0       # collapse-loop rounds executed, including rounds that selected nothing
    total_collapses: int = 0  # edge collapses applied over all rounds
    seconds: float = 0.0      # wall-clock time of the run; left at 0.0 on the early-return path
    peak_mem_mb: float = 0.0  # peak CUDA memory allocated, in MiB; only filled in on CUDA devices
|
||
|
||
|
||
def qem_simplify(
    vertices: torch.Tensor,
    faces: torch.Tensor,
    target_faces: int,
    colors: Optional[torch.Tensor] = None,
    normals: Optional[torch.Tensor] = None,
    max_edge_length: Optional[float] = None,
    config: Optional[QEMConfig] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], SimplifyStats]:
    """Single-mesh QEM simplification. Returns (v, f, colors, normals, stats).

    The inputs are copied (positions/attributes to ``cfg.dtype``, faces to
    int64) and the results are cast back to the input dtypes on return.

    Outline:
      1. optional preclean via clean_mesh;
      2. early return when the mesh already has <= ``target_faces`` faces
         or fewer than 4 vertices;
      3. collapse loop: extract unique edges, score them, pick a set of
         collapses with _greedy_matching and apply them, until the tracked
         face count reaches the target, no edge qualifies, or
         ``cfg.max_iterations`` is hit;
      4. finalize: compact vertices/faces, drop degenerate and duplicate
         faces, optionally split non-manifold edges, postclean, and either
         recompute normals or fix face winding against the kept normals.

    Args:
        vertices: vertex positions, indexed as (V, 3).
        faces: triangle vertex indices, indexed as (F, 3).
        target_faces: face count at which the collapse loop stops.
        colors: optional per-vertex attribute rows, interpolated on collapse.
        normals: optional per-vertex normals, interpolated on collapse.
        max_edge_length: edges at least this long are never collapsed;
            None or <= 0 selects twice the bounding-box diagonal.
        config: tuning knobs; a default QEMConfig is used when omitted.
    """
    cfg = config or QEMConfig()

    device = vertices.device
    # remember the caller's dtypes so the outputs can be cast back
    in_v_dtype = vertices.dtype
    in_f_dtype = faces.dtype
    in_c_dtype = colors.dtype if colors is not None else None
    in_n_dtype = normals.dtype if normals is not None else None

    # working copies: the caller's tensors are never written to
    verts = vertices.to(device=device, dtype=cfg.dtype, copy=True)
    faces = faces.to(device=device, dtype=torch.int64).clone()
    colors_w = colors.to(device=device, dtype=cfg.dtype, copy=True) if colors is not None else None
    normals_w = normals.to(device=device, dtype=cfg.dtype, copy=True) if normals is not None else None

    # preclean: weld + drop degenerate/duplicate, attributes cluster-averaged
    if cfg.preclean:
        verts, faces, colors_w, normals_w, _cs = clean_mesh(
            verts, faces, colors_w, normals_w,
            weld_epsilon_rel=cfg.preclean_weld_epsilon_rel,
            min_component_faces=cfg.preclean_min_component_faces,
        )

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    # note: the "input" counts are taken after the preclean
    stats = SimplifyStats(input_verts=num_verts, input_faces=num_faces)

    # nothing to do: already at/below target, or too few vertices
    if num_faces <= target_faces or num_verts < 4:
        stats.output_verts = num_verts
        stats.output_faces = num_faces
        return verts.to(in_v_dtype), faces.to(in_f_dtype), \
            (colors_w.to(in_c_dtype) if colors_w is not None else None), \
            (normals_w.to(in_n_dtype) if normals_w is not None else None), \
            stats

    if device.type == "cuda":
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)
    t0 = _time.perf_counter()

    # liveness masks: collapses mark entries dead instead of reallocating
    v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
    f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)

    Q = _build_quadrics(verts, faces, cfg)

    bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0]
    mesh_scale = torch.norm(bbox)  # 0-d tensor; never .item()'d
    if max_edge_length is None or max_edge_length <= 0:
        max_edge_length = mesh_scale * 2.0
    else:
        max_edge_length = torch.as_tensor(max_edge_length, dtype=cfg.dtype, device=device)
    # tiny-bbox guard (tensor-side, no sync)
    max_edge_length = torch.where(
        max_edge_length < 1e-6,
        torch.ones((), dtype=max_edge_length.dtype, device=device),
        max_edge_length,
    )

    stabilizer = mesh_scale * mesh_scale * cfg.stabilizer_scale
    max_edge_length_sq = max_edge_length * max_edge_length
    mesh_scale_sq = mesh_scale * mesh_scale

    # threshold scaled by mesh_scale² so the 1e-8 start is scale-robust
    # NOTE(review): float(mesh_scale_sq) reads the tensor back to the host
    # once here when the threshold driver is enabled.
    thresh = float(cfg.threshold_start) * float(mesh_scale_sq) if cfg.threshold_driver else 0.0

    # pre-allocated merge_map, reused each iter
    merge_map = torch.arange(num_verts, device=device)

    # py_n_faces: Python-int face count (no host sync in hot loop), re-synced at compaction
    py_n_faces = num_faces

    iteration = 0
    total_collapses = 0

    # progress bars (tqdm + optional comfy ProgressBar), best-effort
    _start_faces = num_faces
    _prog_total = max(1, _start_faces - int(target_faces))
    try:
        _qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
    except Exception:
        _qtq = None
    try:
        _qpbar = _comfy_utils.ProgressBar(100)
    except Exception:
        _qpbar = None

    def _qreport():
        # percentage of the requested face reduction achieved so far,
        # based on the tracked (estimated) face count
        pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
        if _qtq is not None:
            _qtq.n = pct
            _qtq.refresh()
        if _qpbar is not None:
            _qpbar.update_absolute(pct, 100)

    while True:
        # stop condition uses the tracked estimate, not an exact recount
        if py_n_faces <= target_faces:
            break
        _qreport()

        alive_f = torch.nonzero(f_alive, as_tuple=True)[0]
        if alive_f.numel() == 0:
            break

        active_faces = faces[alive_f]

        # memoryless QEM: rebuild Q from current geometry each iter
        if cfg.threshold_driver and cfg.memoryless_qem and iteration > 0:
            Q = _build_quadrics(verts, active_faces, cfg)

        Q_for_iter = Q
        # edge extraction: pack (min*V + max) so unique dedups in one pass
        af_roll = torch.roll(active_faces, shifts=-1, dims=1)
        mn = torch.minimum(active_faces, af_roll)
        mx = torch.maximum(active_faces, af_roll)
        packed = torch.add(mx, mn, alpha=num_verts).flatten()
        packed = torch.unique(packed)
        # unpack: column 0 is the smaller endpoint, column 1 the larger
        edges_orig = torch.stack([packed // num_verts, packed % num_verts], dim=1)

        # filter by edge length
        pab = verts[edges_orig]  # (E, 2, 3)
        el = torch.norm(pab[:, 1] - pab[:, 0], dim=-1)
        edges_orig = edges_orig[el < max_edge_length]
        if edges_orig.shape[0] == 0:
            break

        # sampling cap
        n_edges_total = edges_orig.shape[0]
        if n_edges_total > cfg.sampling_cap:
            # random subset, so results are not reproducible without a fixed seed
            perm = torch.randperm(n_edges_total, device=device)[: cfg.sampling_cap]
            edges_orig = edges_orig[perm]

        # boundary mask only needed for non-qem placement
        if cfg.placement_mode != "qem":
            vib = _vert_is_boundary_mask(active_faces, num_verts)
        else:
            vib = None
        optimal, err, valid = _edge_errors(
            verts, Q_for_iter, edges_orig, stabilizer, max_edge_length_sq,
            mesh_scale_sq, cfg, vert_is_boundary=vib,
        )
        # keep only the edges _edge_errors marked valid
        valid_idx = torch.nonzero(valid, as_tuple=True)[0]
        edges_orig = edges_orig[valid_idx]
        optimal = optimal[valid_idx]
        err = err[valid_idx]

        faces_to_remove = py_n_faces - target_faces
        n_faces_round_start = py_n_faces
        # ~2 faces removed per collapse, so cap the round at faces_to_remove//2
        cap_to_target = max(1, faces_to_remove // 2)

        if cfg.threshold_driver:
            # band = cost <= thresh (×10 until non-empty), quality-check, then collapse a disjoint set
            cand = err <= thresh
            esc = 0
            while not bool(cand.any()) and esc < 50:
                thresh *= 10.0
                cand = err <= thresh
                esc += 1
            cand_idx = torch.nonzero(cand, as_tuple=True)[0]
            ce = edges_orig[cand_idx]
            copt = optimal[cand_idx]
            cerr = err[cand_idx].clone()
            need_flip = cfg.flip_reject_hard
            if ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition)
                    and ce.shape[0] > 0):
                afq = faces[alive_f]
                v_to_f = _build_vert_to_faces_pad(afq, num_verts, cfg.flip_check_max_degree)
                # link + flip + skinny share one fused 1-ring pass
                fc, sk, link_safe = _quality_checks_fused(
                    verts, afq, ce, copt, v_to_f, cos_threshold=cfg.flip_cos_threshold,
                    want_flip=need_flip, want_skinny=(cfg.skinny_weight > 0),
                    want_link=cfg.enforce_link_condition)
                if link_safe is not None:
                    # link-condition violations are rejected outright
                    cerr[~link_safe] = float("inf")
                if fc is not None:
                    # any flipped 1-ring normal rejects the collapse
                    cerr = torch.where(fc > 0, torch.full_like(cerr, float("inf")), cerr)
                if sk is not None:
                    # soft penalty, scaled by squared edge length
                    el_sq = (verts[ce[:, 1]] - verts[ce[:, 0]]).pow(2).sum(dim=-1)
                    cerr = cerr + cfg.skinny_weight * sk * el_sq
                del v_to_f, afq
                # penalties may push edges above thresh — re-gate the band
                keep = cerr <= thresh
                ce = ce[keep]
                copt = copt[keep]
                cerr = cerr[keep]
            edges_orig = ce
            optimal = copt
            sel = _greedy_matching(ce, cerr, v_alive, cap_to_target)
            if sel.numel() == 0:
                # band fully rejected → raise thresh and retry
                thresh *= 10.0
                iteration += 1
                if iteration >= cfg.max_iterations:
                    break
                continue
        else:
            max_collapses = min(
                cfg.max_collapses_ceiling,
                max(cfg.max_collapses_floor, int(faces_to_remove * cfg.max_collapses_fraction)),
            )
            if cfg.max_collapses_relative_cap > 0:
                # cap to a fraction of current mesh size (anti cascade-overshoot)
                rel_cap = max(1, int(py_n_faces * cfg.max_collapses_relative_cap))
                max_collapses = min(max_collapses, rel_cap)
            max_collapses = min(max_collapses, cap_to_target)

            # soft quality penalties on top-K: flip + skinny, sharing one v_to_f build
            need_flip = cfg.flip_reject_hard
            need_quality = ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition)
                            and edges_orig.shape[0] > 0)
            if need_quality:
                # only the n_check cheapest edges get the per-edge quality pass
                n_check = min(edges_orig.shape[0],
                              max(1, cfg.quality_topk_multiplier * max_collapses))
                if n_check < edges_orig.shape[0]:
                    topk = torch.topk(err, n_check, largest=False).indices
                else:
                    topk = torch.arange(edges_orig.shape[0], device=device)
                active_for_quality = faces[alive_f]
                v_to_f = _build_vert_to_faces_pad(active_for_quality, num_verts,
                                                  cfg.flip_check_max_degree)
                err = err.clone()
                if cfg.enforce_link_condition:
                    # reject link-condition violations on ALL candidate edges, not just top-K
                    link_safe = _link_condition_mask(active_for_quality, edges_orig, v_to_f)
                    err[~link_safe] = float("inf")
                e_tk = edges_orig[topk]
                o_tk = optimal[topk]
                _do_flip = need_flip
                _do_skinny = cfg.skinny_weight > 0
                if _do_flip and _do_skinny:
                    flip_count, skinny, _ = _quality_checks_fused(
                        verts, active_for_quality, e_tk, o_tk, v_to_f,
                        cos_threshold=cfg.flip_cos_threshold, want_link=False)
                elif _do_flip:
                    flip_count = _normal_flip_mask(
                        verts, active_for_quality, e_tk, o_tk, v_to_f,
                        cos_threshold=cfg.flip_cos_threshold, return_count=True)
                    skinny = None
                else:
                    # NOTE(review): this branch also runs when only the link
                    # condition is enabled; the result is then unused.
                    skinny = _skinny_penalty(verts, active_for_quality, e_tk, o_tk, v_to_f)
                    flip_count = None
                if _do_flip:
                    # hard reject: any flipping top-K edge → +inf
                    flips = flip_count > 0
                    if flips.any():
                        err[topk] = torch.where(
                            flips, torch.full_like(err[topk], float("inf")),
                            err[topk],
                        )
                if _do_skinny:
                    # skinny_cost * len² (match QEM's length² scaling)
                    elen_sq = (verts[e_tk[:, 1]] - verts[e_tk[:, 0]]).pow(2).sum(dim=-1)
                    err[topk] = torch.add(err[topk], skinny * elen_sq,
                                          alpha=cfg.skinny_weight)
                del v_to_f, active_for_quality

            sel = _greedy_matching(edges_orig, err, v_alive, max_collapses)

        if sel.numel() == 0:
            break

        # each selected edge (v_a, v_b) collapses v_b into v_a
        ed_sel = edges_orig[sel]
        v_a = ed_sel[:, 0]
        v_b = ed_sel[:, 1]
        new_pos = optimal[sel]

        # interpolate attributes by new_pos's position along [pa, pb]
        if colors_w is not None or normals_w is not None:
            pa_sel = verts[v_a]
            pb_sel = verts[v_b]
            edge_vec = pb_sel - pa_sel
            edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20
            # t = normalised projection of new_pos onto the edge, clamped to [0, 1]
            t = ((new_pos - pa_sel) * edge_vec).sum(dim=-1) / edge_len_sq
            t = t.clamp(0.0, 1.0).unsqueeze(-1)
            if colors_w is not None:
                colors_w[v_a] = torch.lerp(colors_w[v_a], colors_w[v_b], t)
            if normals_w is not None:
                # interpolated normals are not re-normalised here
                normals_w[v_a] = torch.lerp(normals_w[v_a], normals_w[v_b], t)

        # apply collapse
        verts[v_a] = new_pos
        v_alive[v_b] = False
        if not (cfg.threshold_driver and cfg.memoryless_qem):
            # accumulate quadrics only when they are not rebuilt every round
            Q[v_a] += Q[v_b]

        merge_map[v_b] = v_a
        faces = merge_map[faces]
        merge_map[v_b] = v_b  # restore identity for next iter

        # faces that now repeat a vertex were consumed by a collapse
        bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
        f_alive.masked_fill_(bad, False)
        py_n_faces -= 2 * v_a.numel()  # ~2 faces/collapse estimate; re-synced at compaction

        # schedule: round removed < 1% → raise thresh ×10
        if cfg.threshold_driver:
            removed = n_faces_round_start - py_n_faces
            if removed < 0.01 * n_faces_round_start:
                thresh *= 10.0

        total_collapses += int(v_a.numel())
        iteration += 1

        # periodic compaction (resyncs py_n_faces exactly)
        if iteration % cfg.compaction_period == 0:
            alive_frac = py_n_faces / max(1, num_faces)
            if alive_frac < cfg.compaction_threshold:
                faces = faces[f_alive]
                num_faces = faces.shape[0]
                f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)
                py_n_faces = num_faces

        if iteration >= cfg.max_iterations:
            break

    _qreport()
    if _qtq is not None:
        _qtq.close()

    # finalize: compact verts and faces
    final_v = verts[v_alive]
    final_c = colors_w[v_alive] if colors_w is not None else None
    final_n = normals_w[v_alive] if normals_w is not None else None

    # old vertex id -> compacted id; dead vertices map to -1
    remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
    remap[v_alive] = v_alive.long().cumsum(0)[v_alive] - 1  # compact remap, no sync

    final_f_raw = faces[f_alive]
    alive_mask = v_alive[final_f_raw].all(dim=1)
    final_f_raw = final_f_raw[alive_mask]
    final_f = remap[final_f_raw]
    valid_faces = (final_f >= 0).all(dim=1)
    final_f = final_f[valid_faces]

    # drop degenerate faces (two indices equal)
    if final_f.numel() > 0:
        nondeg = (final_f[:, 0] != final_f[:, 1]) & (final_f[:, 1] != final_f[:, 2]) & (final_f[:, 0] != final_f[:, 2])
        final_f = final_f[nondeg]

    # dedup duplicate faces, winding-preserving
    # NOTE(review): the packed key is (V+1)^3-sized; it looks like it would
    # exceed int64 for more than ~2M output vertices — confirm.
    if final_f.numel() > 0:
        key = torch.sort(final_f, dim=1)[0]
        packed = (key[:, 0].long() * (final_v.shape[0] + 1) + key[:, 1].long()) \
            * (final_v.shape[0] + 1) + key[:, 2].long()
        unique_packed, inv = torch.unique(packed, return_inverse=True)
        arange = torch.arange(packed.shape[0], device=packed.device)
        first = torch.full((unique_packed.shape[0],), packed.shape[0],
                           dtype=torch.int64, device=packed.device)
        first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True)
        final_f = final_f[first]

    # split back fused surface sheets (after dedup, before pruning)
    if cfg.repair_nonmanifold and final_f.numel() > 0:
        final_v, final_f, _src = _repair_nonmanifold_edges(final_v, final_f)
        # _src maps each new vertex to the vertex it was copied from
        if final_c is not None:
            final_c = final_c[_src]
        if final_n is not None:
            final_n = final_n[_src]

    # post-clean: drop slivers, tiny components, unused verts
    # NOTE(review): drop_duplicates is off here, but the sliver collapse
    # inside clean_mesh can presumably create new duplicate faces — confirm.
    if cfg.postclean and final_f.numel() > 0:
        comp_threshold = cfg.postclean_min_component_faces
        final_v, final_f, final_c, final_n, _ps = clean_mesh(
            final_v, final_f, final_c, final_n,
            weld_epsilon=0.0, weld_epsilon_rel=0.0,  # already welded
            drop_degenerate=True,
            drop_duplicates=False,  # already done above
            drop_unused=True,
            min_component_faces=comp_threshold,
            min_angle_deg=cfg.postclean_min_angle_deg,
            max_aspect_ratio=cfg.postclean_max_aspect_ratio,
        )

    # post-simplify normals
    if cfg.recompute_normals_post and final_f.numel() > 0:
        final_n = _compute_vertex_normals(final_v, final_f)
    elif final_n is not None and final_f.numel() > 0:
        # keep supplied normals; flip face winding where it disagrees
        v0 = final_v[final_f[:, 0]]
        v1 = final_v[final_f[:, 1]]
        v2 = final_v[final_f[:, 2]]
        fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
        ref = (final_n[final_f[:, 0]] + final_n[final_f[:, 1]]
               + final_n[final_f[:, 2]]) / 3.0
        wrong = (fn * ref).sum(dim=-1) < 0
        # swap corners 1 and 2 of the faces whose geometric normal opposes ref
        final_f[wrong] = final_f[wrong][:, [0, 2, 1]]

    if device.type == "cuda":
        torch.cuda.synchronize(device)
        stats.peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024 * 1024)
    stats.seconds = _time.perf_counter() - t0
    stats.iterations = iteration
    stats.total_collapses = total_collapses
    stats.output_verts = final_v.shape[0]
    stats.output_faces = final_f.shape[0]

    # cast back to the caller's dtypes; recomputed normals keep the working
    # dtype when no normals were supplied
    return (
        final_v.to(in_v_dtype),
        final_f.to(in_f_dtype),
        final_c.to(in_c_dtype) if final_c is not None else None,
        final_n.to(in_n_dtype) if (final_n is not None and in_n_dtype is not None) else final_n,
        stats,
    )
|
||
|
||
|
||
def simplify(
    vertices: torch.Tensor,
    faces: torch.Tensor,
    target: int,
    colors: Optional[torch.Tensor] = None,
    normals: Optional[torch.Tensor] = None,
    max_edge_length: Optional[float] = None,
    config: Optional[QEMConfig] = None,
):
    """Run ``qem_simplify`` on a single mesh or on every mesh of a batch.

    A 3-D ``vertices`` tensor is treated as a batch: ``vertices[i]`` / ``faces[i]``
    (and ``colors[i]`` / ``normals[i]`` when supplied) are simplified one at a
    time and the per-item results are gathered into Python lists, returned as
    ``(verts_list, faces_list, colors_list_or_None, normals_list_or_None, stats_list)``.
    Colors/normals that come back as ``None`` for an item are not appended, so
    those two lists can be shorter than the batch; an empty list is reported
    as ``None``.

    Any other rank is forwarded unchanged to ``qem_simplify`` and its result is
    returned as-is.
    """
    # Unbatched input: plain pass-through.
    if vertices.ndim != 3:
        return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config)

    batch_verts, batch_faces, batch_stats = [], [], []
    batch_colors, batch_normals = [], []
    for idx in range(vertices.shape[0]):
        item_colors = None if colors is None else colors[idx]
        item_normals = None if normals is None else normals[idx]
        res_v, res_f, res_c, res_n, res_stats = qem_simplify(
            vertices[idx], faces[idx], target,
            item_colors, item_normals, max_edge_length, config,
        )
        batch_verts.append(res_v)
        batch_faces.append(res_f)
        batch_stats.append(res_stats)
        # optional attributes are only collected when present
        if res_c is not None:
            batch_colors.append(res_c)
        if res_n is not None:
            batch_normals.append(res_n)

    return (
        batch_verts,
        batch_faces,
        batch_colors or None,
        batch_normals or None,
        batch_stats,
    )
|
||
|
||
|
||
def cluster_decimate(
    vertices: torch.Tensor, faces: torch.Tensor,
    target_verts: int = 1_000_000,
    colors: Optional[torch.Tensor] = None,
    face_chunk: int = 4_000_000,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Vertex-cluster decimation (Rossignac-Borrel): grid-bin/average verts, remap faces,
    drop degenerate/duplicate. Fast prepass for huge meshes.

    Args:
        vertices: (V, 3) floating-point positions.
        faces: (F, 3) integer vertex indices.
        target_verts: rough vertex budget; the grid is sized so the bounding box holds
            about ``3 * target_verts`` cells (never fewer than 1000).
        colors: optional (V, C) per-vertex attributes, averaged per cell like positions.
            Integer colors are accumulated in the default float dtype so sums and
            counts cannot wrap around; the averaged result is floating point.
        face_chunk: number of faces remapped per step (bounds temporary memory);
            values below 1 are treated as 1.

    Returns:
        ``(verts, faces, colors)``: one averaged vertex per occupied cell, the remapped
        faces with degenerate and duplicate triangles removed (cast back to the input
        face dtype), and the averaged colors or ``None``. Empty inputs are returned
        untouched.
    """
    if vertices.shape[0] == 0 or faces.shape[0] == 0:
        return vertices, faces, colors

    device = vertices.device
    bbox_min = vertices.min(dim=0)[0]
    bbox = vertices.max(dim=0)[0] - bbox_min
    # cell size so the bbox holds ~3× target_verts cells (surface occupancy ~1/3)
    cell_count_target = max(target_verts * 3, 1000)
    extent_max = float(bbox.max().item())
    cells_per_axis = cell_count_target ** (1 / 3)
    cell_size = extent_max / max(1.0, cells_per_axis)
    scale = 1.0 / max(cell_size, 1e-20)  # guards the zero-extent (single point) case

    # integer cell coordinates, flattened to one linear key per vertex
    q = ((vertices - bbox_min) * scale).floor().to(torch.int64)
    extent = (bbox * scale).floor().to(torch.int64) + 2
    Wy = extent[1]
    Wz = extent[2]
    key = (q[:, 0] * Wy + q[:, 1]) * Wz + q[:, 2]

    # inv[i] = index of the occupied cell (= output vertex) that vertex i falls in
    unique_key, inv = torch.unique(key, return_inverse=True)
    n_unique = unique_key.shape[0]
    counts = torch.zeros(n_unique, dtype=vertices.dtype, device=device)
    counts.scatter_add_(0, inv, torch.ones(vertices.shape[0], dtype=vertices.dtype, device=device))
    counts_div = counts.unsqueeze(-1).clamp_min(1.0)

    # per-cell mean position
    new_verts = torch.zeros((n_unique, 3), dtype=vertices.dtype, device=device)
    new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(vertices), vertices)
    new_verts = new_verts / counts_div

    new_colors = None
    if colors is not None:
        # Accumulating in an integer dtype (e.g. uint8) would wrap both the sum and the
        # count, so integer colors are promoted to float before averaging.
        acc_dtype = colors.dtype if colors.is_floating_point() else torch.get_default_dtype()
        color_sum = torch.zeros((n_unique, colors.shape[1]), dtype=acc_dtype, device=device)
        color_sum.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors.to(acc_dtype))
        new_colors = color_sum / counts_div.to(acc_dtype)

    # remap faces in chunks (face tensor can be huge), drop degenerates per chunk
    out_chunks = []
    F = faces.shape[0]
    step = max(int(face_chunk), 1)  # a zero/negative step would raise or skip every face
    for fs in range(0, F, step):
        fe = min(fs + step, F)
        cf = inv[faces[fs:fe].long()]
        nondeg = (cf[:, 0] != cf[:, 1]) & (cf[:, 1] != cf[:, 2]) & (cf[:, 0] != cf[:, 2])
        if nondeg.any():
            out_chunks.append(cf[nondeg])
    if out_chunks:
        new_faces = torch.cat(out_chunks, dim=0)
    else:
        new_faces = torch.empty((0, 3), dtype=faces.dtype, device=device)

    # drop duplicate faces (same vertex set after clustering)
    if new_faces.numel() > 0:
        key_sorted = torch.sort(new_faces, dim=1)[0].long()
        P = n_unique + 1
        if P ** 3 <= 2 ** 63:
            # all three indices fit in a single int64 key
            packed = (key_sorted[:, 0] * P + key_sorted[:, 1]) * P + key_sorted[:, 2]
        else:
            # P^3 would overflow int64 (more than ~2.09M clusters) and could merge
            # distinct faces; rank the (a, b) pair first, then pack the rank with c.
            # Ranks follow sorted pair order, so the final ordering is unchanged.
            pair = key_sorted[:, 0] * P + key_sorted[:, 1]
            _, pair_rank = torch.unique(pair, return_inverse=True)
            packed = pair_rank * P + key_sorted[:, 2]
        uniq, first = torch.unique(packed, return_inverse=True)
        arange = torch.arange(packed.shape[0], device=device, dtype=torch.int64)
        # for every distinct face keep the smallest original row index
        first_idx = torch.full((uniq.shape[0],), packed.shape[0],
                               dtype=torch.int64, device=device)
        first_idx.scatter_reduce_(0, first, arange, reduce="amin", include_self=True)
        new_faces = new_faces[first_idx]

    return new_verts.to(vertices.dtype), new_faces.to(faces.dtype), new_colors
|