From 83c7ec69c7c00abcabcf192696e9e9435c8466bf Mon Sep 17 00:00:00 2001 From: kijai Date: Sat, 27 Jun 2026 02:08:12 +0300 Subject: [PATCH] continued cleanup --- .../mesh3d/postprocess/qem_decimate.py | 133 ++++++++++-------- comfy_extras/nodes_mesh_postprocess.py | 44 +----- 2 files changed, 78 insertions(+), 99 deletions(-) diff --git a/comfy_extras/mesh3d/postprocess/qem_decimate.py b/comfy_extras/mesh3d/postprocess/qem_decimate.py index 074c2a808..202f9dd3d 100644 --- a/comfy_extras/mesh3d/postprocess/qem_decimate.py +++ b/comfy_extras/mesh3d/postprocess/qem_decimate.py @@ -2,9 +2,9 @@ Pure-PyTorch GPU-parallel QEM mesh simplification. - Parallel greedy edge-matching collapse loop - - Plane / line / feature-edge / boundary quadrics, memoryless accumulation + - Plane/line/feature-edge/boundary quadrics, memoryless accumulation - Normal-flip prevention, link-condition, skinny penalties - - Non-manifold / sliver handling without dropping faces + - Non-manifold/sliver handling without dropping faces - Pre/post-clean pipeline (weld, degenerates, small components) """ from __future__ import annotations @@ -25,17 +25,17 @@ import comfy.utils as _comfy_utils @dataclass class QEMConfig: - # Precision + # Precision dtype: torch.dtype = torch.float32 # float64 much slower on consumer GPUs - # Numerical conditioning + # Numerical conditioning stabilizer_scale: float = 1e-3 # Tikhonov reg: stabilizer = mesh_scale^2 * this wander_threshold: float = 2.0 # fall back to midpoint if v* lands > N×edge_length from an endpoint clamp_v_to_edge: bool = True # project v* onto the edge segment (qem mode only) - # Vertex placement mode (also selects the collapse driver) - # "midpoint" (default): threshold-schedule driver, most stable. The defaults below match it. - # "qem": sharpest, QEM-optimum placement + ratio driver. + # Placement mode (also selects collapse driver): + # "midpoint" = threshold-schedule driver, most stable (defaults below match it); + # "qem" = sharpest, QEM-optimum placement + ratio driver. placement_mode: str = "midpoint" flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal @@ -81,8 +81,8 @@ class QEMConfig: lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median - # Threshold-schedule driver (placement_mode == "midpoint") - # Cost-threshold schedule: each round collapses a disjoint set with cost <= thresh, ×10 when < 1% removed. + # Threshold-schedule driver (placement_mode == "midpoint"): + # each round collapses a disjoint set with cost <= thresh, ×10 when < 1% removed. threshold_start: float = 1e-8 memoryless_qem: bool = True # rebuild quadrics each round vs accumulate repair_nonmanifold: bool = True # final repair_non_manifold_edges pass @@ -162,7 +162,7 @@ def _manifold_edge_pairs( empty = sorted_keys.new_empty(0) return empty, empty, empty pair_starts = torch.nonzero(pair_mask, as_tuple=True)[0] - # manifold iff neither neighbour half-edge has the same key + # manifold iff neither neighbour half-edge shares the key cur = sorted_keys[pair_starts] prev_ok = (pair_starts == 0) | (sorted_keys[(pair_starts - 1).clamp_min(0)] != cur) nxt_idx = (pair_starts + 2).clamp(max=sorted_keys.shape[0] - 1) @@ -181,10 +181,9 @@ def _line_quadric_planes( elen = torch.norm(e, dim=-1, keepdim=True).clamp_min(1e-12) e_unit = e / elen # (E, 3) m = 0.5 * (pa + pb) # (E, 3) - # helper axis not parallel to e_unit + # helper axis not parallel to e_unit, then Gram-Schmidt against e_unit helper = torch.zeros_like(e_unit) helper.scatter_(-1, e_unit.abs().argmin(dim=-1, keepdim=True), 1.0) - # Gram-Schmidt against e_unit u = helper - (helper * e_unit).sum(-1, keepdim=True) * e_unit u = u / torch.norm(u, dim=-1, keepdim=True).clamp_min(1e-12) w = torch.cross(e_unit, u, dim=-1) @@ -245,7 +244,7 @@ def _build_quadrics( n = torch.cross(e1, e2, dim=-1) area = torch.norm(n, dim=-1) mask = area > 1e-12 - # where() instead of boolean-index gather+scatter (fewer index kernels) + # where() avoids boolean-index gather+scatter (fewer index kernels) n_norm = torch.where(mask.unsqueeze(-1), n / area.unsqueeze(-1).clamp_min(1e-12), n.new_zeros(())) @@ -266,7 +265,7 @@ def _build_quadrics( skip_he_sharp = None if cfg.line_quadric_skip_opposite_normals_cos < 1.0: v_norm = torch.zeros((V, 3), dtype=dtype, device=device) - n_weighted = n_norm * area.unsqueeze(-1) # face normal * 2× area + n_weighted = n_norm * area.unsqueeze(-1) # normal * 2× area for corner in range(3): v_norm.scatter_add_(0, faces[:, corner].unsqueeze(-1).expand(-1, 3), n_weighted) @@ -378,20 +377,20 @@ def _edge_errors( A = Qe[:, :3, :3] + torch.eye(3, device=device, dtype=dtype) * stabilizer b = -Qe[:, :3, 3].unsqueeze(-1) - # stabilizer keeps A invertible; solve full batch and pick midpoint via where (no host sync) + # stabilizer keeps A invertible; full-batch solve, midpoint fallback via where (no sync) sol = torch.linalg.solve(A, b) dets = torch.det(A) good = (dets.abs() > 1e-12).unsqueeze(-1) opt = torch.where(good, sol.squeeze(-1), midpoint) if cfg.clamp_v_to_edge: - # qem mode + clamp: project v* onto the edge segment (subsumes the wander check) + # project v* onto the edge segment (subsumes the wander check) edge_len_sq = (edge_vec * edge_vec).sum(dim=-1) + 1e-20 t = ((opt - pa) * edge_vec).sum(dim=-1) / edge_len_sq t = t.clamp(0.0, 1.0).unsqueeze(-1) opt = torch.lerp(pa, pb, t) else: - # qem mode + no clamp: fall back to midpoint when v* wanders from both endpoints + # fall back to midpoint when v* wanders from both endpoints dist_a = torch.norm(opt - pa, dim=-1) dist_b = torch.norm(opt - pb, dim=-1) wander_bad = ((dist_a > cfg.wander_threshold * el) | @@ -401,7 +400,7 @@ def _edge_errors( v4 = torch.cat([opt, torch.ones((n_edges, 1), device=device, dtype=dtype)], dim=1) err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4)) - # mesh_scale_sq may be Python float or 0-d tensor + # mesh_scale_sq: Python float or 0-d tensor if torch.is_tensor(mesh_scale_sq): length_ok = el * el > mesh_scale_sq * 1e-10 else: @@ -523,9 +522,12 @@ def _normal_flip_mask( a_b = a.view(Ec, 1) b_b = b.view(Ec, 1) - s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b - s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b - s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b + s0_a = fv[..., 0] == a_b + s0_b = fv[..., 0] == b_b + s1_a = fv[..., 1] == a_b + s1_b = fv[..., 1] == b_b + s2_a = fv[..., 2] == a_b + s2_b = fv[..., 2] == b_b contains_a = s0_a | s1_a | s2_a contains_b = s0_b | s1_b | s2_b # affected: face contains exactly one of {a, b} and slot is non-pad @@ -549,7 +551,7 @@ def _normal_flip_mask( nlen_old = torch.norm(n_old, dim=-1) nlen_new = torch.norm(n_new, dim=-1) - # degenerate-before faces can't meaningfully flip; treat as OK + # degenerate-before faces can't flip; treat as OK denom = nlen_old * nlen_new safe = denom > 1e-20 cos = torch.where(safe, (n_old * n_new).sum(dim=-1) / denom.clamp_min(1e-20), @@ -581,17 +583,20 @@ def _link_condition_mask( for s in range(0, E, chunk_size): e = min(s + chunk_size, E) - a = a_all[s:e]; b = b_all[s:e] + a = a_all[s:e] + b = b_all[s:e] Ec = a.shape[0] fa = vert_to_faces[a] # (Ec, D) fb = vert_to_faces[b] - fa_ok = fa >= 0; fb_ok = fb >= 0 + fa_ok = fa >= 0 + fb_ok = fb >= 0 fav = faces[fa.clamp(min=0)] # (Ec, D, 3) fbv = faces[fb.clamp(min=0)] # neighbour verts of a/b: take the 2 non-anchor verts per incident face → (Ec, 2D) - a_b = a[:, None]; b_b = b[:, None] + a_b = a[:, None] + b_b = b[:, None] an1 = torch.where(fav[..., 0] == a_b, fav[..., 1], fav[..., 0]) an2 = torch.where(fav[..., 2] == a_b, fav[..., 1], fav[..., 2]) bn1 = torch.where(fbv[..., 0] == b_b, fbv[..., 1], fbv[..., 0]) @@ -707,7 +712,8 @@ def _quality_checks_fused( want_link: bool = False, chunk_size: int = 100_000, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - """Fused 1-ring checks (flip count / skinny / link) sharing one faces[v_to_f] gather; returns (flip_count|None, skinny|None, link_safe|None).""" + """Fused 1-ring checks (flip count / skinny / link) sharing one faces gather. + Returns (flip_count|None, skinny|None, link_safe|None).""" E = edges.shape[0] device = verts.device flip_out = torch.zeros(E, dtype=torch.int32, device=device) if want_flip else None @@ -728,7 +734,7 @@ def _quality_checks_fused( a = a_all[start:stop] b = b_all[start:stop] - # shared gather: a's and b's incident faces (the expensive part) + # shared gather of a's and b's incident faces (the expensive part) fa = vert_to_faces[a] fb = vert_to_faces[b] all_f = torch.cat([fa, fb], dim=1) # (Ec, 2D) @@ -739,9 +745,12 @@ def _quality_checks_fused( if need_geom: oc = opt[start:stop] - s0_a = fv[..., 0] == a_b; s0_b = fv[..., 0] == b_b - s1_a = fv[..., 1] == a_b; s1_b = fv[..., 1] == b_b - s2_a = fv[..., 2] == a_b; s2_b = fv[..., 2] == b_b + s0_a = fv[..., 0] == a_b + s0_b = fv[..., 0] == b_b + s1_a = fv[..., 1] == a_b + s1_b = fv[..., 1] == b_b + s2_a = fv[..., 2] == a_b + s2_b = fv[..., 2] == b_b contains_a = s0_a | s1_a | s2_a contains_b = s0_b | s1_b | s2_b affected = (contains_a ^ contains_b) & valid_f @@ -898,7 +907,9 @@ def _drop_degenerate_faces( return faces, 0 idx_bad = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 0] == faces[:, 2]) f_good = faces[~idx_bad] - v0 = verts[f_good[:, 0]]; v1 = verts[f_good[:, 1]]; v2 = verts[f_good[:, 2]] + v0 = verts[f_good[:, 0]] + v1 = verts[f_good[:, 1]] + v2 = verts[f_good[:, 2]] e0 = v1 - v0 e2 = v0 - v2 area = 0.5 * torch.norm(torch.cross(e0, -e2, dim=-1), dim=-1) @@ -918,7 +929,9 @@ def _collapse_slivers( return faces, 0 fl = faces.long() - v0 = verts[fl[:, 0]]; v1 = verts[fl[:, 1]]; v2 = verts[fl[:, 2]] + v0 = verts[fl[:, 0]] + v1 = verts[fl[:, 1]] + v2 = verts[fl[:, 2]] e0 = v1 - v0 e1 = v2 - v1 e2 = v0 - v2 @@ -1041,13 +1054,17 @@ def _repair_nonmanifold_edges( vmin = _np.minimum(va, vb).astype(_np.int64) vmax = _np.maximum(va, vb).astype(_np.int64) keys_l.append(vmin * (nv + 1) + vmax) - ca_l.append(amin); cb_l.append(amax) + ca_l.append(amin) + cb_l.append(amax) keys = _np.concatenate(keys_l) - ca = _np.concatenate(ca_l); cb = _np.concatenate(cb_l) + ca = _np.concatenate(ca_l) + cb = _np.concatenate(cb_l) order = _np.argsort(keys, kind="stable") - keys = keys[order]; ca = ca[order]; cb = cb[order] + keys = keys[order] + ca = ca[order] + cb = cb[order] uniq, start, cnt = _np.unique(keys, return_index=True, return_counts=True) - man = start[cnt == 2] # manifold edges: exactly 2 incident faces + man = start[cnt == 2] # manifold edges (exactly 2 incident faces) # union both endpoints' corners across each manifold edge rows = _np.concatenate([ca[man], cb[man]]) cols = _np.concatenate([ca[man + 1], cb[man + 1]]) @@ -1205,7 +1222,7 @@ def qem_simplify( colors_w = colors.to(device=device, dtype=cfg.dtype, copy=True) if colors is not None else None normals_w = normals.to(device=device, dtype=cfg.dtype, copy=True) if normals is not None else None - # optional preclean: weld + drop degenerate/duplicate, attributes cluster-averaged + # preclean: weld + drop degenerate/duplicate, attributes cluster-averaged if cfg.preclean: verts, faces, colors_w, normals_w, _cs = clean_mesh( verts, faces, colors_w, normals_w, @@ -1242,7 +1259,7 @@ def qem_simplify( max_edge_length = mesh_scale * 2.0 else: max_edge_length = torch.as_tensor(max_edge_length, dtype=cfg.dtype, device=device) - # degenerate-mesh guard for tiny bbox (tensor-side, no sync) + # tiny-bbox guard (tensor-side, no sync) max_edge_length = torch.where( max_edge_length < 1e-6, torch.ones((), dtype=max_edge_length.dtype, device=device), @@ -1265,7 +1282,7 @@ def qem_simplify( iteration = 0 total_collapses = 0 - # progress bars (tqdm + optional comfy ProgressBar); best-effort + # progress bars (tqdm + optional comfy ProgressBar), best-effort _start_faces = num_faces _prog_total = max(1, _start_faces - int(target_faces)) try: @@ -1280,7 +1297,8 @@ def qem_simplify( def _qreport(): pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total))) if _qtq is not None: - _qtq.n = pct; _qtq.refresh() + _qtq.n = pct + _qtq.refresh() if _qpbar is not None: _qpbar.update_absolute(pct, 100) @@ -1300,7 +1318,7 @@ def qem_simplify( Q = _build_quadrics(verts, active_faces, cfg) Q_for_iter = Q - # edge extraction: pack each (min*V + max) so unique dedups in one go + # edge extraction: pack (min*V + max) so unique dedups in one pass af_roll = torch.roll(active_faces, shifts=-1, dims=1) mn = torch.minimum(active_faces, af_roll) mx = torch.maximum(active_faces, af_roll) @@ -1341,8 +1359,7 @@ def qem_simplify( cap_to_target = max(1, faces_to_remove // 2) if cfg.threshold_driver: - # threshold-schedule selection - # candidate band = cost <= thresh (×10 until non-empty), quality-check then collapse a disjoint set + # band = cost <= thresh (×10 until non-empty), quality-check, then collapse a disjoint set cand = err <= thresh esc = 0 while not bool(cand.any()) and esc < 50: @@ -1358,7 +1375,7 @@ def qem_simplify( and ce.shape[0] > 0): afq = faces[alive_f] v_to_f = _build_vert_to_faces_pad(afq, num_verts, cfg.flip_check_max_degree) - # link + flip + skinny share one fused 1-ring pass on the same band + # link + flip + skinny share one fused 1-ring pass fc, sk, link_safe = _quality_checks_fused( verts, afq, ce, copt, v_to_f, cos_threshold=cfg.flip_cos_threshold, want_flip=need_flip, want_skinny=(cfg.skinny_weight > 0), @@ -1373,7 +1390,9 @@ def qem_simplify( del v_to_f, afq # penalties may push edges above thresh — re-gate the band keep = cerr <= thresh - ce = ce[keep]; copt = copt[keep]; cerr = cerr[keep] + ce = ce[keep] + copt = copt[keep] + cerr = cerr[keep] edges_orig = ce optimal = copt sel = _greedy_matching(ce, cerr, v_alive, cap_to_target) @@ -1390,12 +1409,12 @@ def qem_simplify( max(cfg.max_collapses_floor, int(faces_to_remove * cfg.max_collapses_fraction)), ) if cfg.max_collapses_relative_cap > 0: - # hybrid tail: cap to a fraction of current mesh size (anti cascade-overshoot) + # cap to a fraction of current mesh size (anti cascade-overshoot) rel_cap = max(1, int(py_n_faces * cfg.max_collapses_relative_cap)) max_collapses = min(max_collapses, rel_cap) max_collapses = min(max_collapses, cap_to_target) - # soft quality penalties on top-K candidates: flip check + skinny, sharing one v_to_f build + # soft quality penalties on top-K: flip + skinny, sharing one v_to_f build need_flip = cfg.flip_reject_hard need_quality = ((need_flip or cfg.skinny_weight > 0 or cfg.enforce_link_condition) and edges_orig.shape[0] > 0) @@ -1411,10 +1430,9 @@ def qem_simplify( cfg.flip_check_max_degree) err = err.clone() if cfg.enforce_link_condition: - # reject link-condition violations on ALL candidate edges (not just top-K) + # reject link-condition violations on ALL candidate edges, not just top-K link_safe = _link_condition_mask(active_for_quality, edges_orig, v_to_f) err[~link_safe] = float("inf") - # flip + skinny share the same top-K 1-ring walk e_tk = edges_orig[topk] o_tk = optimal[topk] _do_flip = need_flip @@ -1440,7 +1458,7 @@ def qem_simplify( err[topk], ) if _do_skinny: - # skinny_cost * edge_length² (match QEM's length² scaling) + # skinny_cost * len² (match QEM's length² scaling) elen_sq = (verts[e_tk[:, 1]] - verts[e_tk[:, 0]]).pow(2).sum(dim=-1) err[topk] = torch.add(err[topk], skinny * elen_sq, alpha=cfg.skinny_weight) @@ -1456,7 +1474,7 @@ def qem_simplify( v_b = ed_sel[:, 1] new_pos = optimal[sel] - # interpolate attributes by where new_pos lies on the [pa, pb] segment + # interpolate attributes by new_pos's position along [pa, pb] if colors_w is not None or normals_w is not None: pa_sel = verts[v_a] pb_sel = verts[v_b] @@ -1540,7 +1558,7 @@ def qem_simplify( first.scatter_reduce_(0, inv, arange, reduce="amin", include_self=True) final_f = final_f[first] - # repair_non_manifold_edges: split back fused surface sheets (after dedup, before pruning) + # split back fused surface sheets (after dedup, before pruning) if cfg.repair_nonmanifold and final_f.numel() > 0: final_v, final_f, _src = _repair_nonmanifold_edges(final_v, final_f) if final_c is not None: @@ -1610,7 +1628,9 @@ def simplify( c_in = colors[i] if colors is not None else None n_in = normals[i] if normals is not None else None v, f, c, n, s = qem_simplify(vertices[i], faces[i], target, c_in, n_in, max_edge_length, config) - out_v.append(v); out_f.append(f); out_s.append(s) + out_v.append(v) + out_f.append(f) + out_s.append(s) if c is not None: out_c.append(c) if n is not None: out_n.append(n) return (out_v, out_f, @@ -1626,9 +1646,8 @@ def cluster_decimate( colors: Optional[torch.Tensor] = None, face_chunk: int = 4_000_000, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Vertex-cluster decimation (Rossignac-Borrel): bin verts into a ~target_verts grid, - average per cell, remap faces (chunked), drop degenerate/duplicate. Fast O(V+F) prepass - for huge meshes before QEM/remesh. Returns (verts, faces, colors).""" + """Vertex-cluster decimation (Rossignac-Borrel): grid-bin/average verts, remap faces, + drop degenerate/duplicate. Fast O(V+F) prepass for huge meshes. Returns (verts, faces, colors).""" if vertices.shape[0] == 0 or faces.shape[0] == 0: return vertices, faces, colors @@ -1664,7 +1683,7 @@ def cluster_decimate( new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors) new_colors = new_colors / counts_div.to(colors.dtype) - # remap faces in chunks (face tensor can be huge); drop degenerates per chunk + # remap faces in chunks (face tensor can be huge), drop degenerates per chunk out_chunks = [] F = faces.shape[0] for fs in range(0, F, face_chunk): diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index b7f1f3f0b..9549fa11e 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -10,7 +10,7 @@ from server import PromptServer from comfy_extras.mesh3d.postprocess.qem_decimate import ( simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate, ) -from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc +from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg from comfy_extras.mesh3d.uv_unwrap import parameterize as _uv_param @@ -530,45 +530,6 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors return out -def _closest_point_on_triangles(p, a, b, c): - """Vectorized exact closest point on triangles (Ericson §5.1.5). p/a/b/c [...,3] → - [...,3]; all vertex/edge/face Voronoi regions, highest-priority-last via where.""" - ab = b - a - ac = c - a - ap = p - a - d1 = (ab * ap).sum(-1) - d2 = (ac * ap).sum(-1) - bp = p - b - d3 = (ab * bp).sum(-1) - d4 = (ac * bp).sum(-1) - cp = p - c - d5 = (ab * cp).sum(-1) - d6 = (ac * cp).sum(-1) - va = d3 * d6 - d5 * d4 - vb = d5 * d2 - d1 * d6 - vc = d1 * d4 - d3 * d2 - - def u(x): # broadcast a scalar-per-element weight to [...,1] - return x.unsqueeze(-1) - - # face region (default) - denom = 1.0 / (va + vb + vc).clamp_min(1e-20) - v = vb * denom - w = vc * denom - res = a + ab * u(v) + ac * u(w) - den_bc = (d4 - d3) + (d5 - d6) - w_bc = (d4 - d3) / den_bc.clamp_min(1e-20) - res = torch.where(u((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)), b + (c - b) * u(w_bc), res) # edge BC - w_ac = d2 / (d2 - d6).clamp_min(1e-20) - res = torch.where(u((vb <= 0) & (d2 >= 0) & (d6 <= 0)), a + ac * u(w_ac), res) # edge AC - res = torch.where(u((d6 >= 0) & (d5 <= d6)), c, res) # vertex C - v_ab = d1 / (d1 - d3).clamp_min(1e-20) - res = torch.where(u((vc <= 0) & (d1 >= 0) & (d3 <= 0)), a + ab * u(v_ab), res) # edge AB - res = torch.where(u((d3 >= 0) & (d4 <= d3)), b, res) # vertex B - res = torch.where(u((d1 <= 0) & (d2 <= 0)), a, res) # vertex A - return res - - def _msb_int64(x): """floor(log2(x)) elementwise for int64 x >= 1 (bit-search, no float).""" r = torch.zeros_like(x) @@ -722,8 +683,7 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): if bool(lv.any()): ga = a[lv] tt = tri[order[node[lv] - LEAF]] - cp = _closest_point_on_triangles(qa[lv], tt[:, 0], tt[:, 1], tt[:, 2]) - d2 = ((cp - qa[lv]) ** 2).sum(-1) + cp, d2 = _point_tri_closest(qa[lv], tt) upd = d2 < best[ga] gu = ga[upd] best[gu] = d2[upd]