mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
continued cleanup
This commit is contained in:
parent
9414c33157
commit
83c7ec69c7
@ -2,9 +2,9 @@
|
||||
Pure-PyTorch GPU-parallel QEM mesh simplification.
|
||||
|
||||
- Parallel greedy edge-matching collapse loop
|
||||
- Plane / line / feature-edge / boundary quadrics, memoryless accumulation
|
||||
- Plane/line/feature-edge/boundary quadrics, memoryless accumulation
|
||||
- Normal-flip prevention, link-condition, skinny penalties
|
||||
- Non-manifold / sliver handling without dropping faces
|
||||
- Non-manifold/sliver handling without dropping faces
|
||||
- Pre/post-clean pipeline (weld, degenerates, small components)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@ -25,17 +25,17 @@ import comfy.utils as _comfy_utils
|
||||
|
||||
@dataclass
|
||||
class QEMConfig:
|
||||
# Precision
|
||||
# Precision
|
||||
dtype: torch.dtype = torch.float32 # float64 much slower on consumer GPUs
|
||||
|
||||
# Numerical conditioning
|
||||
# Numerical conditioning
|
||||
stabilizer_scale: float = 1e-3 # Tikhonov reg: stabilizer = mesh_scale^2 * this
|
||||
wander_threshold: float = 2.0 # fall back to midpoint if v* lands > N×edge_length from an endpoint
|
||||
clamp_v_to_edge: bool = True # project v* onto the edge segment (qem mode only)
|
||||
|
||||
# Vertex placement mode (also selects the collapse driver)
|
||||
# "midpoint" (default): threshold-schedule driver, most stable. The defaults below match it.
|
||||
# "qem": sharpest, QEM-optimum placement + ratio driver.
|
||||
# Placement mode (also selects collapse driver):
|
||||
# "midpoint" = threshold-schedule driver, most stable (defaults below match it);
|
||||
# "qem" = sharpest, QEM-optimum placement + ratio driver.
|
||||
placement_mode: str = "midpoint"
|
||||
|
||||
flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal
|
||||
@ -81,8 +81,8 @@ class QEMConfig:
|
||||
lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables
|
||||
lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median
|
||||
|
||||
# Threshold-schedule driver (placement_mode == "midpoint")
|
||||
# Cost-threshold schedule: each round collapses a disjoint set with cost <= thresh, ×10 when < 1% removed.
|
||||
# Threshold-schedule driver (placement_mode == "midpoint"):
|
||||
# each round collapses a disjoint set with cost <= thresh, ×10 when < 1% removed.
|
||||
threshold_start: float = 1e-8
|
||||
memoryless_qem: bool = True # rebuild quadrics each round vs accumulate
|
||||
repair_nonmanifold: bool = True # final repair_non_manifold_edges pass
|
||||
@ -162,7 +162,7 @@ def _manifold_edge_pairs(
|
||||
empty = sorted_keys.new_empty(0)
|
||||
return empty, empty, empty
|
||||
pair_starts = torch.nonzero(pair_mask, as_tuple=True)[0]
|
||||
# manifold iff neither neighbour half-edge has the same key
|
||||
# manifold iff neither neighbour half-edge shares the key
|
||||
cur = sorted_keys[pair_starts]
|
||||
prev_ok = (pair_starts == 0) | (sorted_keys[(pair_starts - 1).clamp_min(0)] != cur)
|
||||
nxt_idx = (pair_starts + 2).clamp(max=sorted_keys.shape[0] - 1)
|
||||
@ -181,10 +181,9 @@ def _line_quadric_planes(
|
||||
elen = torch.norm(e, dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
e_unit = e / elen # (E, 3)
|
||||
m = 0.5 * (pa + pb) # (E, 3)
|
||||
# helper axis not parallel to e_unit
|
||||
# helper axis not parallel to e_unit, then Gram-Schmidt against e_unit
|
||||
helper = torch.zeros_like(e_unit)
|
||||
helper.scatter_(-1, e_unit.abs().argmin(dim=-1, keepdim=True), 1.0)
|
||||
# Gram-Schmidt against e_unit
|
||||
u = helper - (helper * e_unit).sum(-1, keepdim=True) * e_unit
|
||||
u = u / torch.norm(u, dim=-1, keepdim=True).clamp_min(1e-12)
|
||||
w = torch.cross(e_unit, u, dim=-1)
|
||||
@ -245,7 +244,7 @@ def _build_quadrics(
|
||||
n = torch.cross(e1, e2, dim=-1)
|
||||
area = torch.norm(n, dim=-1)
|
||||
mask = area > 1e-12
|
||||
# where() instead of boolean-index gather+scatter (fewer index kernels)
|
||||
# where() avoids boolean-index gather+scatter (fewer index kernels)
|
||||
n_norm = torch.where(mask.unsqueeze(-1),
|
||||
n / area.unsqueeze(-1).clamp_min(1e-12),
|
||||
n.new_zeros(()))
|
||||
@ -266,7 +265,7 @@ def _build_quadrics(
|
||||
skip_he_sharp = None
|
||||
if cfg.line_quadric_skip_opposite_normals_cos < 1.0:
|
||||
v_norm = torch.zeros((V, 3), dtype=dtype, device=device)
|
||||
n_weighted = n_norm * area.unsqueeze(-1) # face normal * 2× area
|
||||
n_weighted = n_norm * area.unsqueeze(-1) # normal * 2× area
|
||||
for corner in range(3):
|
||||
v_norm.scatter_add_(0, faces[:, corner].unsqueeze(-1).expand(-1, 3),
|
||||
n_weighted)
|
||||
@ -378,20 +377,20 @@ def _edge_errors(
|
||||
A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype) * stabilizer
|
||||
b = -Qe[:, :3, 3].unsqueeze(-1)
|
||||
|
||||
# stabilizer keeps A invertible; solve full batch and pick midpoint via where (no host sync)
|
||||
# stabilizer keeps A invertible; full-batch solve, midpoint fallback via where (no sync)
|
||||
sol = torch.linalg.solve(A, b)
|
||||
dets = torch.det(A)
|
||||
good = (dets.abs() > 1e-12).unsqueeze(-1)
|
||||
opt = torch.where(good, sol.squeeze(-1), midpoint)
|
||||
|
||||
if cfg.clamp_v_to_edge:
|
||||
# qem mode + clamp: project v* onto the edge segment (subsumes the wander check)
|
||||
# project v* onto the edge segment (subsumes the wander check)
|
||||
edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20
|
||||
t = ((opt - pa) * edge_vec).sum(dim=-1) / edge_len_sq
|
||||
t = t.clamp(0.0, 1.0).unsqueeze(-1)
|
||||
opt = torch.lerp(pa, pb, t)
|
||||
else:
|
||||
# qem mode + no clamp: fall back to midpoint when v* wanders from both endpoints
|
||||
# fall back to midpoint when v* wanders from both endpoints
|
||||
dist_a = torch.norm(opt - pa, dim=-1)
|
||||
dist_b = torch.norm(opt - pb, dim=-1)
|
||||
wander_bad = ((dist_a > cfg.wander_threshold * el) |
|
||||
@ -401,7 +400,7 @@ def _edge_errors(
|
||||
v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1)
|
||||
err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4))
|
||||
|
||||
# mesh_scale_sq may be Python float or 0-d tensor
|
||||
# mesh_scale_sq: Python float or 0-d tensor
|
||||
if torch.is_tensor(mesh_scale_sq):
|
||||
length_ok = el * el > mesh_scale_sq * 1e-10
|
||||
else:
|
||||
@ -523,9 +522,12 @@ def _normal_flip_mask(
|
||||
|
||||
a_b = a.view(Ec, 1)
|
||||
b_b = b.view(Ec, 1)
|
||||
s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b
|
||||
s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b
|
||||
s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b
|
||||
s0_a = fv[..., 0] == a_b
|
||||
s0_b = fv[..., 0] == b_b
|
||||
s1_a = fv[..., 1] == a_b
|
||||
s1_b = fv[..., 1] == b_b
|
||||
s2_a = fv[..., 2] == a_b
|
||||
s2_b = fv[..., 2] == b_b
|
||||
contains_a = s0_a | s1_a | s2_a
|
||||
contains_b = s0_b | s1_b | s2_b
|
||||
# affected: face contains exactly one of {a, b} and slot is non-pad
|
||||
@ -549,7 +551,7 @@ def _normal_flip_mask(
|
||||
|
||||
nlen_old = torch.norm(n_old, dim=-1)
|
||||
nlen_new = torch.norm(n_new, dim=-1)
|
||||
# degenerate-before faces can't meaningfully flip; treat as OK
|
||||
# degenerate-before faces can't flip; treat as OK
|
||||
denom = nlen_old * nlen_new
|
||||
safe = denom > 1e-20
|
||||
cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20),
|
||||
@ -581,17 +583,20 @@ def _link_condition_mask(
|
||||
|
||||
for s in range(0, E, chunk_size):
|
||||
e = min(s + chunk_size, E)
|
||||
a = a_all[s:e]; b = b_all[s:e]
|
||||
a = a_all[s:e]
|
||||
b = b_all[s:e]
|
||||
Ec = a.shape[0]
|
||||
|
||||
fa = vert_to_faces[a] # (Ec, D)
|
||||
fb = vert_to_faces[b]
|
||||
fa_ok = fa >= 0; fb_ok = fb >= 0
|
||||
fa_ok = fa >= 0
|
||||
fb_ok = fb >= 0
|
||||
fav = faces[fa.clamp(min=0)] # (Ec, D, 3)
|
||||
fbv = faces[fb.clamp(min=0)]
|
||||
|
||||
# neighbour verts of a/b: take the 2 non-anchor verts per incident face → (Ec, 2D)
|
||||
a_b = a[:, None]; b_b = b[:, None]
|
||||
a_b = a[:, None]
|
||||
b_b = b[:, None]
|
||||
an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0])
|
||||
an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2])
|
||||
bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0])
|
||||
@ -707,7 +712,8 @@ def _quality_checks_fused(
|
||||
want_link: bool = False,
|
||||
chunk_size: int = 100_000,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""Fused 1-ring checks (flip count / skinny / link) sharing one faces[v_to_f] gather; returns (flip_count|None, skinny|None, link_safe|None)."""
|
||||
"""Fused 1-ring checks (flip count / skinny / link) sharing one faces gather.
|
||||
Returns (flip_count|None, skinny|None, link_safe|None)."""
|
||||
E = edges.shape[0]
|
||||
device = verts.device
|
||||
flip_out = torch.zeros(E, dtype=torch.int32, device=device) if want_flip else None
|
||||
@ -728,7 +734,7 @@ def _quality_checks_fused(
|
||||
a = a_all[start:stop]
|
||||
b = b_all[start:stop]
|
||||
|
||||
# shared gather: a's and b's incident faces (the expensive part)
|
||||
# shared gather of a's and b's incident faces (the expensive part)
|
||||
fa = vert_to_faces[a]
|
||||
fb = vert_to_faces[b]
|
||||
all_f = torch.cat([fa, fb], dim=1) # (Ec, 2D)
|
||||
@ -739,9 +745,12 @@ def _quality_checks_fused(
|
||||
|
||||
if need_geom:
|
||||
oc = opt[start:stop]
|
||||
s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b
|
||||
s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b
|
||||
s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b
|
||||
s0_a = fv[..., 0] == a_b
|
||||
s0_b = fv[..., 0] == b_b
|
||||
s1_a = fv[..., 1] == a_b
|
||||
s1_b = fv[..., 1] == b_b
|
||||
s2_a = fv[..., 2] == a_b
|
||||
s2_b = fv[..., 2] == b_b
|
||||
contains_a = s0_a | s1_a | s2_a
|
||||
contains_b = s0_b | s1_b | s2_b
|
||||
affected = (contains_a ^ contains_b) & valid_f
|
||||
@ -898,7 +907,9 @@ def _drop_degenerate_faces(
|
||||
return faces, 0
|
||||
idx_bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 0] == faces[:, 2])
|
||||
f_good = faces[~idx_bad]
|
||||
v0 = verts[f_good[:, 0]]; v1 = verts[f_good[:, 1]]; v2 = verts[f_good[:, 2]]
|
||||
v0 = verts[f_good[:, 0]]
|
||||
v1 = verts[f_good[:, 1]]
|
||||
v2 = verts[f_good[:, 2]]
|
||||
e0 = v1 - v0
|
||||
e2 = v0 - v2
|
||||
area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1)
|
||||
@ -918,7 +929,9 @@ def _collapse_slivers(
|
||||
return faces, 0
|
||||
|
||||
fl = faces.long()
|
||||
v0 = verts[fl[:, 0]]; v1 = verts[fl[:, 1]]; v2 = verts[fl[:, 2]]
|
||||
v0 = verts[fl[:, 0]]
|
||||
v1 = verts[fl[:, 1]]
|
||||
v2 = verts[fl[:, 2]]
|
||||
e0 = v1 - v0
|
||||
e1 = v2 - v1
|
||||
e2 = v0 - v2
|
||||
@ -1041,13 +1054,17 @@ def _repair_nonmanifold_edges(
|
||||
vmin = _np.minimum(va, vb).astype(_np.int64)
|
||||
vmax = _np.maximum(va, vb).astype(_np.int64)
|
||||
keys_l.append(vmin * (nv + 1) + vmax)
|
||||
ca_l.append(amin); cb_l.append(amax)
|
||||
ca_l.append(amin)
|
||||
cb_l.append(amax)
|
||||
keys = _np.concatenate(keys_l)
|
||||
ca = _np.concatenate(ca_l); cb = _np.concatenate(cb_l)
|
||||
ca = _np.concatenate(ca_l)
|
||||
cb = _np.concatenate(cb_l)
|
||||
order = _np.argsort(keys, kind="stable")
|
||||
keys = keys[order]; ca = ca[order]; cb = cb[order]
|
||||
keys = keys[order]
|
||||
ca = ca[order]
|
||||
cb = cb[order]
|
||||
uniq, start, cnt = _np.unique(keys, return_index=True, return_counts=True)
|
||||
man = start[cnt == 2] # manifold edges: exactly 2 incident faces
|
||||
man = start[cnt == 2] # manifold edges (exactly 2 incident faces)
|
||||
# union both endpoints' corners across each manifold edge
|
||||
rows = _np.concatenate([ca[man], cb[man]])
|
||||
cols = _np.concatenate([ca[man + 1], cb[man + 1]])
|
||||
@ -1205,7 +1222,7 @@ def qem_simplify(
|
||||
colors_w = colors.to(device=device, dtype=cfg.dtype, copy=True) if colors is not None else None
|
||||
normals_w = normals.to(device=device, dtype=cfg.dtype, copy=True) if normals is not None else None
|
||||
|
||||
# optional preclean: weld + drop degenerate/duplicate, attributes cluster-averaged
|
||||
# preclean: weld + drop degenerate/duplicate, attributes cluster-averaged
|
||||
if cfg.preclean:
|
||||
verts, faces, colors_w, normals_w, _cs = clean_mesh(
|
||||
verts, faces, colors_w, normals_w,
|
||||
@ -1242,7 +1259,7 @@ def qem_simplify(
|
||||
max_edge_length = mesh_scale * 2.0
|
||||
else:
|
||||
max_edge_length = torch.as_tensor(max_edge_length, dtype=cfg.dtype, device=device)
|
||||
# degenerate-mesh guard for tiny bbox (tensor-side, no sync)
|
||||
# tiny-bbox guard (tensor-side, no sync)
|
||||
max_edge_length = torch.where(
|
||||
max_edge_length < 1e-6,
|
||||
torch.ones((), dtype=max_edge_length.dtype, device=device),
|
||||
@ -1265,7 +1282,7 @@ def qem_simplify(
|
||||
iteration = 0
|
||||
total_collapses = 0
|
||||
|
||||
# progress bars (tqdm + optional comfy ProgressBar); best-effort
|
||||
# progress bars (tqdm + optional comfy ProgressBar), best-effort
|
||||
_start_faces = num_faces
|
||||
_prog_total = max(1, _start_faces - int(target_faces))
|
||||
try:
|
||||
@ -1280,7 +1297,8 @@ def qem_simplify(
|
||||
def _qreport():
|
||||
pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
|
||||
if _qtq is not None:
|
||||
_qtq.n = pct; _qtq.refresh()
|
||||
_qtq.n = pct
|
||||
_qtq.refresh()
|
||||
if _qpbar is not None:
|
||||
_qpbar.update_absolute(pct, 100)
|
||||
|
||||
@ -1300,7 +1318,7 @@ def qem_simplify(
|
||||
Q = _build_quadrics(verts, active_faces, cfg)
|
||||
|
||||
Q_for_iter = Q
|
||||
# edge extraction: pack each (min*V + max) so unique dedups in one go
|
||||
# edge extraction: pack (min*V + max) so unique dedups in one pass
|
||||
af_roll = torch.roll(active_faces, shifts=-1, dims=1)
|
||||
mn = torch.minimum(active_faces, af_roll)
|
||||
mx = torch.maximum(active_faces, af_roll)
|
||||
@ -1341,8 +1359,7 @@ def qem_simplify(
|
||||
cap_to_target = max(1, faces_to_remove // 2)
|
||||
|
||||
if cfg.threshold_driver:
|
||||
# threshold-schedule selection
|
||||
# candidate band = cost <= thresh (×10 until non-empty), quality-check then collapse a disjoint set
|
||||
# band = cost <= thresh (×10 until non-empty), quality-check, then collapse a disjoint set
|
||||
cand = err <= thresh
|
||||
esc = 0
|
||||
while not bool(cand.any()) and esc < 50:
|
||||
@ -1358,7 +1375,7 @@ def qem_simplify(
|
||||
and ce.shape[0] > 0):
|
||||
afq = faces[alive_f]
|
||||
v_to_f = _build_vert_to_faces_pad(afq, num_verts, cfg.flip_check_max_degree)
|
||||
# link + flip + skinny share one fused 1-ring pass on the same band
|
||||
# link + flip + skinny share one fused 1-ring pass
|
||||
fc, sk, link_safe = _quality_checks_fused(
|
||||
verts, afq, ce, copt, v_to_f, cos_threshold=cfg.flip_cos_threshold,
|
||||
want_flip=need_flip, want_skinny=(cfg.skinny_weight > 0),
|
||||
@ -1373,7 +1390,9 @@ def qem_simplify(
|
||||
del v_to_f, afq
|
||||
# penalties may push edges above thresh — re-gate the band
|
||||
keep = cerr <= thresh
|
||||
ce = ce[keep]; copt = copt[keep]; cerr = cerr[keep]
|
||||
ce = ce[keep]
|
||||
copt = copt[keep]
|
||||
cerr = cerr[keep]
|
||||
edges_orig = ce
|
||||
optimal = copt
|
||||
sel = _greedy_matching(ce, cerr, v_alive, cap_to_target)
|
||||
@ -1390,12 +1409,12 @@ def qem_simplify(
|
||||
max(cfg.max_collapses_floor, int(faces_to_remove * cfg.max_collapses_fraction)),
|
||||
)
|
||||
if cfg.max_collapses_relative_cap > 0:
|
||||
# hybrid tail: cap to a fraction of current mesh size (anti cascade-overshoot)
|
||||
# cap to a fraction of current mesh size (anti cascade-overshoot)
|
||||
rel_cap = max(1, int(py_n_faces * cfg.max_collapses_relative_cap))
|
||||
max_collapses = min(max_collapses, rel_cap)
|
||||
max_collapses = min(max_collapses, cap_to_target)
|
||||
|
||||
# soft quality penalties on top-K candidates: flip check + skinny, sharing one v_to_f build
|
||||
# soft quality penalties on top-K: flip + skinny, sharing one v_to_f build
|
||||
need_flip = cfg.flip_reject_hard
|
||||
need_quality = ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition)
|
||||
and edges_orig.shape[0] > 0)
|
||||
@ -1411,10 +1430,9 @@ def qem_simplify(
|
||||
cfg.flip_check_max_degree)
|
||||
err = err.clone()
|
||||
if cfg.enforce_link_condition:
|
||||
# reject link-condition violations on ALL candidate edges (not just top-K)
|
||||
# reject link-condition violations on ALL candidate edges, not just top-K
|
||||
link_safe = _link_condition_mask(active_for_quality, edges_orig, v_to_f)
|
||||
err[~link_safe] = float("inf")
|
||||
# flip + skinny share the same top-K 1-ring walk
|
||||
e_tk = edges_orig[topk]
|
||||
o_tk = optimal[topk]
|
||||
_do_flip = need_flip
|
||||
@ -1440,7 +1458,7 @@ def qem_simplify(
|
||||
err[topk],
|
||||
)
|
||||
if _do_skinny:
|
||||
# skinny_cost * edge_length² (match QEM's length² scaling)
|
||||
# skinny_cost * len² (match QEM's length² scaling)
|
||||
elen_sq = (verts[e_tk[:, 1]] - verts[e_tk[:, 0]]).pow(2).sum(dim=-1)
|
||||
err[topk] = torch.add(err[topk], skinny * elen_sq,
|
||||
alpha=cfg.skinny_weight)
|
||||
@ -1456,7 +1474,7 @@ def qem_simplify(
|
||||
v_b = ed_sel[:, 1]
|
||||
new_pos = optimal[sel]
|
||||
|
||||
# interpolate attributes by where new_pos lies on the [pa, pb] segment
|
||||
# interpolate attributes by new_pos's position along [pa, pb]
|
||||
if colors_w is not None or normals_w is not None:
|
||||
pa_sel = verts[v_a]
|
||||
pb_sel = verts[v_b]
|
||||
@ -1540,7 +1558,7 @@ def qem_simplify(
|
||||
first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True)
|
||||
final_f = final_f[first]
|
||||
|
||||
# repair_non_manifold_edges: split back fused surface sheets (after dedup, before pruning)
|
||||
# split back fused surface sheets (after dedup, before pruning)
|
||||
if cfg.repair_nonmanifold and final_f.numel() > 0:
|
||||
final_v, final_f, _src = _repair_nonmanifold_edges(final_v, final_f)
|
||||
if final_c is not None:
|
||||
@ -1610,7 +1628,9 @@ def simplify(
|
||||
c_in = colors[i] if colors is not None else None
|
||||
n_in = normals[i] if normals is not None else None
|
||||
v, f, c, n, s = qem_simplify(vertices[i], faces[i], target, c_in, n_in, max_edge_length, config)
|
||||
out_v.append(v); out_f.append(f); out_s.append(s)
|
||||
out_v.append(v)
|
||||
out_f.append(f)
|
||||
out_s.append(s)
|
||||
if c is not None: out_c.append(c)
|
||||
if n is not None: out_n.append(n)
|
||||
return (out_v, out_f,
|
||||
@ -1626,9 +1646,8 @@ def cluster_decimate(
|
||||
colors: Optional[torch.Tensor] = None,
|
||||
face_chunk: int = 4_000_000,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Vertex-cluster decimation (Rossignac-Borrel): bin verts into a ~target_verts grid,
|
||||
average per cell, remap faces (chunked), drop degenerate/duplicate. Fast O(V+F) prepass
|
||||
for huge meshes before QEM/remesh. Returns (verts, faces, colors)."""
|
||||
"""Vertex-cluster decimation (Rossignac-Borrel): grid-bin/average verts, remap faces,
|
||||
drop degenerate/duplicate. Fast O(V+F) prepass for huge meshes. Returns (verts, faces, colors)."""
|
||||
if vertices.shape[0] == 0 or faces.shape[0] == 0:
|
||||
return vertices, faces, colors
|
||||
|
||||
@ -1664,7 +1683,7 @@ def cluster_decimate(
|
||||
new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors)
|
||||
new_colors = new_colors / counts_div.to(colors.dtype)
|
||||
|
||||
# remap faces in chunks (face tensor can be huge); drop degenerates per chunk
|
||||
# remap faces in chunks (face tensor can be huge), drop degenerates per chunk
|
||||
out_chunks = []
|
||||
F = faces.shape[0]
|
||||
for fs in range(0, F, face_chunk):
|
||||
|
||||
@ -10,7 +10,7 @@ from server import PromptServer
|
||||
from comfy_extras.mesh3d.postprocess.qem_decimate import (
|
||||
simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate,
|
||||
)
|
||||
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc
|
||||
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest
|
||||
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
|
||||
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
|
||||
from comfy_extras.mesh3d.uv_unwrap import parameterize as _uv_param
|
||||
@ -530,45 +530,6 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
|
||||
return out
|
||||
|
||||
|
||||
def _closest_point_on_triangles(p, a, b, c):
|
||||
"""Vectorized exact closest point on triangles (Ericson §5.1.5). p/a/b/c [...,3] →
|
||||
[...,3]; all vertex/edge/face Voronoi regions, highest-priority-last via where."""
|
||||
ab = b - a
|
||||
ac = c - a
|
||||
ap = p - a
|
||||
d1 = (ab * ap).sum(-1)
|
||||
d2 = (ac * ap).sum(-1)
|
||||
bp = p - b
|
||||
d3 = (ab * bp).sum(-1)
|
||||
d4 = (ac * bp).sum(-1)
|
||||
cp = p - c
|
||||
d5 = (ab * cp).sum(-1)
|
||||
d6 = (ac * cp).sum(-1)
|
||||
va = d3 * d6 - d5 * d4
|
||||
vb = d5 * d2 - d1 * d6
|
||||
vc = d1 * d4 - d3 * d2
|
||||
|
||||
def u(x): # broadcast a scalar-per-element weight to [...,1]
|
||||
return x.unsqueeze(-1)
|
||||
|
||||
# face region (default)
|
||||
denom = 1.0 / (va + vb + vc).clamp_min(1e-20)
|
||||
v = vb * denom
|
||||
w = vc * denom
|
||||
res = a + ab * u(v) + ac * u(w)
|
||||
den_bc = (d4 - d3) + (d5 - d6)
|
||||
w_bc = (d4 - d3) / den_bc.clamp_min(1e-20)
|
||||
res = torch.where(u((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)), b + (c - b) * u(w_bc), res) # edge BC
|
||||
w_ac = d2 / (d2 - d6).clamp_min(1e-20)
|
||||
res = torch.where(u((vb <= 0) & (d2 >= 0) & (d6 <= 0)), a + ac * u(w_ac), res) # edge AC
|
||||
res = torch.where(u((d6 >= 0) & (d5 <= d6)), c, res) # vertex C
|
||||
v_ab = d1 / (d1 - d3).clamp_min(1e-20)
|
||||
res = torch.where(u((vc <= 0) & (d1 >= 0) & (d3 <= 0)), a + ab * u(v_ab), res) # edge AB
|
||||
res = torch.where(u((d3 >= 0) & (d4 <= d3)), b, res) # vertex B
|
||||
res = torch.where(u((d1 <= 0) & (d2 <= 0)), a, res) # vertex A
|
||||
return res
|
||||
|
||||
|
||||
def _msb_int64(x):
|
||||
"""floor(log2(x)) elementwise for int64 x >= 1 (bit-search, no float)."""
|
||||
r = torch.zeros_like(x)
|
||||
@ -722,8 +683,7 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
|
||||
if bool(lv.any()):
|
||||
ga = a[lv]
|
||||
tt = tri[order[node[lv] - LEAF]]
|
||||
cp = _closest_point_on_triangles(qa[lv], tt[:, 0], tt[:, 1], tt[:, 2])
|
||||
d2 = ((cp - qa[lv]) ** 2).sum(-1)
|
||||
cp, d2 = _point_tri_closest(qa[lv], tt)
|
||||
upd = d2 < best[ga]
|
||||
gu = ga[upd]
|
||||
best[gu] = d2[upd]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user