"""Adaptive cost-grow chart segmentation (CPU); numba optional, numpy path is nd-only.""" from __future__ import annotations from typing import List, Tuple import numpy as np import torch from torch import Tensor try: from numba import njit _HAVE_NUMBA = True except ImportError: _HAVE_NUMBA = False def njit(*args, **kwargs): # noqa: ARG001 def deco(fn): return fn return deco if not args else args[0] from .mesh import MeshData, face_edge_lengths DEFAULT_W_NORMAL_DEVIATION = 2.0 DEFAULT_W_ROUNDNESS = 0.01 DEFAULT_W_STRAIGHTNESS = 6.0 DEFAULT_MAX_COST = 2.0 NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75° @njit(cache=True, fastmath=False) def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray: F = face_normal.shape[0] raw = np.zeros(F, dtype=np.float32) for f in range(F): nx = face_normal[f, 0] ny = face_normal[f, 1] nz = face_normal[f, 2] s = np.float32(0.0) for e in range(3): nb = face_face[f, e] if nb < 0: continue mx = face_normal[nb, 0] my = face_normal[nb, 1] mz = face_normal[nb, 2] d = nx*mx + ny*my + nz*mz s += np.float32(1.0) - d raw[f] = s return raw def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray: nb_safe = np.maximum(face_face, 0) nb_normal = face_normal[nb_safe] d = (face_normal[:, None, :] * nb_normal).sum(axis=-1) contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0)) return contrib.sum(axis=1).astype(np.float32) @njit(cache=True, fastmath=False) def _farthest_point_seeds_jit( face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray, initial_seeds: np.ndarray, k_target: int, ): F = face_centroid.shape[0] INF = np.float32(1e30) min_dist = np.full(F, INF, dtype=np.float32) seeds = np.empty(k_target, dtype=np.int64) n_seeds = 0 for i in range(initial_seeds.shape[0]): s = initial_seeds[i] if s < 0 or n_seeds >= k_target: continue seeds[n_seeds] = s n_seeds += 1 sx = face_centroid[s, 0] sy = face_centroid[s, 1] sz = face_centroid[s, 2] for f in range(F): dx = face_centroid[f, 0] - sx dy = face_centroid[f, 1] - sy dz = face_centroid[f, 2] - sz d2 = dx*dx + dy*dy + dz*dz if d2 < min_dist[f]: min_dist[f] = d2 while n_seeds < k_target: best_f = -1 best_score = np.float32(-1.0) for f in range(F): d = min_dist[f] if d >= INF * np.float32(0.5): continue score = d * face_weight[f] if score > best_score: best_score = score best_f = f if best_f < 0: break seeds[n_seeds] = best_f n_seeds += 1 sx = face_centroid[best_f, 0] sy = face_centroid[best_f, 1] sz = face_centroid[best_f, 2] for f in range(F): dx = face_centroid[f, 0] - sx dy = face_centroid[f, 1] - sy dz = face_centroid[f, 2] - sz d2 = dx*dx + dy*dy + dz*dz if d2 < min_dist[f]: min_dist[f] = d2 return seeds[:n_seeds] def _farthest_point_seeds_numpy( face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int, ): F = face_centroid.shape[0] min_dist = np.full(F, np.inf, dtype=np.float32) seeds: List[int] = [] for s in initial_seeds: if s < 0 or len(seeds) >= k_target: continue seeds.append(int(s)) d = ((face_centroid - face_centroid[s])**2).sum(axis=-1) min_dist = np.minimum(min_dist, d) while len(seeds) < k_target: best = int(np.argmax(min_dist)) if not np.isfinite(min_dist[best]) or min_dist[best] <= 0: break seeds.append(best) d = ((face_centroid - face_centroid[best])**2).sum(axis=-1) min_dist = np.minimum(min_dist, d) return np.asarray(seeds, dtype=np.int64) @njit(cache=True, fastmath=False) def _cost_grow_iter_jit( face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray, face_area: np.ndarray, face_edge_len: np.ndarray, chart_basis: np.ndarray, chart_normal_sum: np.ndarray, chart_area: np.ndarray, chart_perim: np.ndarray, nd_cutoff: float, max_cost: float, w_nd: float, w_round: float, w_straight: float, ): """One grow iter: each unassigned face joins its lowest-cost adjacent chart if cost np.float32(1.0): nd = np.float32(1.0) if nd < np.float32(0.0): nd = np.float32(0.0) if nd >= nd_cutoff: continue l_in = np.float32(0.0) l_out = np.float32(0.0) for e1 in range(3): nb1 = face_face[f, e1] el = face_edge_len[f, e1] if nb1 < 0: l_out += el elif face_chart[nb1] == c: l_in += el else: l_out += el ca = chart_area[c] cp = chart_perim[c] new_perim = cp - l_in + l_out new_area = ca + af if cp <= np.float32(1e-20) or ca <= np.float32(1e-20): round_cost = np.float32(0.0) else: old_r = (cp * cp) / ca new_r = (new_perim * new_perim) / new_area if new_r <= np.float32(1e-20): round_cost = np.float32(0.0) else: round_cost = np.float32(1.0) - old_r / new_r denom = l_out + l_in if denom <= np.float32(1e-20): straight_cost = np.float32(0.0) else: ratio = (l_out - l_in) / denom if ratio < np.float32(0.0): straight_cost = ratio else: straight_cost = np.float32(0.0) cost = (w_nd * nd + w_round * round_cost + w_straight * straight_cost) if cost < best_cost_per_face[f]: best_cost_per_face[f] = cost best_chart_per_face[f] = c n_assigned = 0 for f in range(F): if face_chart[f] != -1: continue if best_chart_per_face[f] < 0: continue if best_cost_per_face[f] > max_cost: continue c = best_chart_per_face[f] l_in = np.float32(0.0) l_out = np.float32(0.0) for e1 in range(3): nb1 = face_face[f, e1] el = face_edge_len[f, e1] if nb1 < 0: l_out += el elif face_chart[nb1] == c: l_in += el else: l_out += el af = face_area[f] face_chart[f] = c chart_normal_sum[c, 0] += face_normal[f, 0] * af chart_normal_sum[c, 1] += face_normal[f, 1] * af chart_normal_sum[c, 2] += face_normal[f, 2] * af chart_area[c] += af chart_perim[c] = chart_perim[c] - l_in + l_out nx = chart_normal_sum[c, 0] ny = chart_normal_sum[c, 1] nz = chart_normal_sum[c, 2] nlen = np.sqrt(nx * nx + ny * ny + nz * nz) if nlen > np.float32(1e-20): chart_basis[c, 0] = nx / nlen chart_basis[c, 1] = ny / nlen chart_basis[c, 2] = nz / nlen n_assigned += 1 return n_assigned def _renumber(face_chart: np.ndarray, device) -> Tensor: unique = np.unique(face_chart[face_chart >= 0]) if unique.size == 0: return torch.from_numpy(face_chart).to(device) remap = np.full(int(unique.max()) + 1, -1, dtype=np.int64) remap[unique] = np.arange(unique.size) out = face_chart.copy() mask = out >= 0 out[mask] = remap[out[mask]] return torch.from_numpy(out).to(device) def _segment_charts_fast( mesh: MeshData, max_cost: float, w_normal_deviation: float, w_roundness: float = DEFAULT_W_ROUNDNESS, w_straightness: float = DEFAULT_W_STRAIGHTNESS, target_chart_count: int = 0, ) -> Tensor: """Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds.""" F = mesh.faces.shape[0] device = mesh.faces.device if F == 0: return torch.zeros(0, dtype=torch.long, device=device) face_normal = mesh.face_normal.detach().cpu().numpy().astype(np.float32) face_area = mesh.face_area.detach().cpu().numpy().astype(np.float32) face_centroid = mesh.face_centroid.detach().cpu().numpy().astype(np.float32) face_face = mesh.face_face.detach().cpu().numpy() face_chart = np.full(F, -1, dtype=np.int64) nd_cutoff = np.float32(NORMAL_DEVIATION_HARD_CUTOFF) nd_threshold = np.float32(min(max_cost / max(w_normal_deviation, 1e-6), NORMAL_DEVIATION_HARD_CUTOFF * 0.99)) component = (mesh.component.detach().cpu().numpy() if hasattr(mesh.component, "detach") else np.asarray(mesh.component)) if component.size: _, first_idx = np.unique(component, return_index=True) initial_seeds = first_idx.astype(np.int64) else: initial_seeds = np.empty(0, dtype=np.int64) adaptive_seeding = target_chart_count <= 0 if adaptive_seeding: seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()] if not seed_faces: seed_faces = [0] else: if _HAVE_NUMBA: curvature_raw = _face_curvature_jit(face_normal, face_face) else: curvature_raw = _face_curvature_numpy(face_normal, face_face) cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0 if cmax > 1e-6: face_weight = (np.float32(1.0) + np.float32(50.0) * (curvature_raw / np.float32(cmax))).astype(np.float32) else: face_weight = np.ones(F, dtype=np.float32) n_comp = int(initial_seeds.size) if n_comp < int(target_chart_count): target_seeds = int(target_chart_count) else: target_seeds = n_comp + max(int(target_chart_count) // 4, 8) target_seeds = min(target_seeds, F) if _HAVE_NUMBA: seeds_arr = _farthest_point_seeds_jit( face_centroid, face_area, face_weight, initial_seeds, target_seeds, ) else: seeds_arr = _farthest_point_seeds_numpy( face_centroid, initial_seeds, target_seeds, ) seed_faces = [int(s) for s in seeds_arr.tolist()] K = len(seed_faces) chart_basis = np.zeros((K, 3), dtype=np.float32) chart_normal_sum = np.zeros((K, 3), dtype=np.float32) chart_area = np.zeros(K, dtype=np.float32) chart_perim = np.zeros(K, dtype=np.float32) face_edge_len = ( face_edge_lengths(mesh.vertices, mesh.faces) .detach().cpu().numpy() ) for cid, sf in enumerate(seed_faces): face_chart[sf] = cid n = face_normal[sf] a = face_area[sf] chart_basis[cid] = n.astype(np.float32) chart_normal_sum[cid] = (n * a).astype(np.float32) chart_area[cid] = float(a) chart_perim[cid] = float(face_edge_len[sf].sum()) if K == 0: return _renumber(face_chart, device) min_dist_to_seed = np.full(F, np.inf, dtype=np.float32) if adaptive_seeding: for sf in seed_faces: d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1) min_dist_to_seed = np.minimum(min_dist_to_seed, d) if _HAVE_NUMBA: # Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg. tau_final = min(max_cost * 0.25, 0.5) thresholds = [t for t in (0.05, 0.1, 0.25) if t < tau_final] + [tau_final] max_inner = max(64, int(np.sqrt(F)) * 2) max_total_charts = max(F, 8000) outer_iter = 0 while True: outer_iter += 1 if outer_iter > F + 16: break for tau in thresholds: for _ in range(max_inner): n_added = _cost_grow_iter_jit( face_chart, face_face, face_normal, face_area, face_edge_len, chart_basis, chart_normal_sum, chart_area, chart_perim, nd_cutoff, np.float32(tau), np.float32(w_normal_deviation), np.float32(w_roundness), np.float32(w_straightness), ) if n_added == 0: break if (face_chart == -1).sum() == 0: break if not adaptive_seeding: break if chart_basis.shape[0] >= max_total_charts: break unassigned_mask = face_chart == -1 cand = np.where(unassigned_mask, min_dist_to_seed, np.float32(-np.inf)) new_seed = int(np.argmax(cand)) n = face_normal[new_seed] a = face_area[new_seed] chart_basis = np.vstack([chart_basis, n[None, :].astype(np.float32)]) chart_normal_sum = np.vstack( [chart_normal_sum, (n * a)[None, :].astype(np.float32)] ) chart_area = np.concatenate([chart_area, np.array([a], dtype=np.float32)]) chart_perim = np.concatenate( [chart_perim, np.array([face_edge_len[new_seed].sum()], dtype=np.float32)] ) face_chart[new_seed] = chart_basis.shape[0] - 1 new_d = ((face_centroid - face_centroid[new_seed]) ** 2).sum(axis=-1) min_dist_to_seed = np.minimum(min_dist_to_seed, new_d) else: # Numpy fallback: nd-only adaptive grow. for _ in range(max(64, int(np.sqrt(F)) + 32)): unassigned = face_chart == -1 if not unassigned.any(): break u_idx = np.nonzero(unassigned)[0] nbs = face_face[u_idx] nbs_safe = np.where(nbs >= 0, nbs, 0) nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1) valid = (nb_charts >= 0) if not valid.any(): break nb_charts_safe = np.where(valid, nb_charts, 0) nb_basis = chart_basis[nb_charts_safe] d = (face_normal[u_idx][:, None, :] * nb_basis).sum(axis=-1) nd = np.where(valid, np.float32(1.0) - d, np.inf).clip(max=1.0) nd = np.where(nd >= nd_cutoff, np.inf, nd) best_e = np.argmin(nd, axis=1) best_cost = nd[np.arange(u_idx.size), best_e] best_c = nb_charts_safe[np.arange(u_idx.size), best_e] accept = (best_cost <= nd_threshold) & np.isfinite(best_cost) if not accept.any(): break pick_u = u_idx[accept] pick_c = best_c[accept] face_chart[pick_u] = pick_c for f, c in zip(pick_u, pick_c): chart_normal_sum[c] += face_normal[f] * face_area[f] chart_area[c] += face_area[f] # Orphan cleanup: leftover faces join their best-matching neighbor's chart. if (face_chart == -1).any() and chart_basis.shape[0] > 0: while True: orphans = np.nonzero(face_chart == -1)[0] if orphans.size == 0: break nbs = face_face[orphans] nbs_safe = np.where(nbs >= 0, nbs, 0) nb_charts = np.where(nbs >= 0, face_chart[nbs_safe], -1) valid = (nb_charts >= 0) if not valid.any(): break nb_charts_safe = np.where(valid, nb_charts, 0) nb_basis = chart_basis[nb_charts_safe] d = (face_normal[orphans][:, None, :] * nb_basis).sum(axis=-1) nd = np.where(valid, np.float32(1.0) - d, np.inf) best_e = np.argmin(nd, axis=1) best_c = nb_charts_safe[np.arange(orphans.size), best_e] assignable = valid.any(axis=1) if not assignable.any(): break assign_idx = orphans[assignable] assign_c = best_c[assignable] face_chart[assign_idx] = assign_c if (face_chart == -1).any(): new_singletons = np.nonzero(face_chart == -1)[0] for f in new_singletons: face_chart[int(f)] = chart_basis.shape[0] chart_basis = np.concatenate( [chart_basis, face_normal[int(f)].astype(np.float32)[None, :]], axis=0, ) return _renumber(face_chart, device) def segment_charts( mesh: MeshData, max_cost: float = DEFAULT_MAX_COST, w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION, w_roundness: float = DEFAULT_W_ROUNDNESS, w_straightness: float = DEFAULT_W_STRAIGHTNESS, target_chart_count: int = 0, ) -> Tensor: """Segment mesh into charts. Returns face -> chart_id.""" return _segment_charts_fast( mesh, max_cost=max_cost, w_normal_deviation=w_normal_deviation, w_roundness=w_roundness, w_straightness=w_straightness, target_chart_count=target_chart_count, ) # ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ---- def _combine_normal_cones( axis_a: Tensor, half_a: Tensor, axis_b: Tensor, half_b: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """Merge two normal cones along the great circle from axis_a; returns (combined_axis, combined_half_angle, axis_angle).""" cos_angle = (axis_a * axis_b).sum(dim=-1).clamp(-1.0, 1.0) axis_angle = torch.acos(cos_angle) new_low = torch.minimum(-half_a, axis_angle - half_b) new_high = torch.maximum(half_a, axis_angle + half_b) new_half = (new_high - new_low) * 0.5 rot_angle = (new_high + new_low) * 0.5 b_perp = axis_b - axis_a * cos_angle.unsqueeze(-1) b_perp_norm = b_perp.norm(dim=-1, keepdim=True).clamp_min(1e-12) b_perp_unit = b_perp / b_perp_norm new_axis = ( axis_a * torch.cos(rot_angle).unsqueeze(-1) + b_perp_unit * torch.sin(rot_angle).unsqueeze(-1) ) new_axis_norm = new_axis.norm(dim=-1, keepdim=True).clamp_min(1e-12) new_axis = new_axis / new_axis_norm return new_axis, new_half, axis_angle def _build_chart_edges( face_face: Tensor, chart_id: Tensor, face_edge_len: Tensor, ) -> Tuple[Tensor, Tensor]: """Build chart-edge list (chart_pairs[E,2] with a= 0 f_idx = f_idx[valid] nb = nb[valid] el = face_edge_len.flatten()[valid] ca = chart_id[f_idx] cb = chart_id[nb] diff = ca != cb ca = ca[diff] cb = cb[diff] el = el[diff] if ca.numel() == 0: return ( torch.empty((0, 2), dtype=torch.long, device=device), torch.empty(0, device=device), ) lo = torch.minimum(ca, cb) hi = torch.maximum(ca, cb) V = int(chart_id.max().item()) + 1 key = lo * V + hi sort_idx = torch.argsort(key) sorted_key = key[sort_idx] sorted_lo = lo[sort_idx] sorted_hi = hi[sort_idx] sorted_el = el[sort_idx] unique_key, inverse, counts = torch.unique( sorted_key, return_inverse=True, return_counts=True ) n_unique = unique_key.shape[0] reduced_el = torch.zeros(n_unique, device=device, dtype=el.dtype) reduced_el.scatter_add_(0, inverse, sorted_el) first_idx = torch.cat([ torch.zeros(1, dtype=torch.long, device=device), counts.cumsum(0)[:-1], ]) pair_lo = sorted_lo[first_idx] pair_hi = sorted_hi[first_idx] chart_pairs = torch.stack([pair_lo, pair_hi], dim=1) return chart_pairs, reduced_el def cluster_charts_pec( mesh: MeshData, target_chart_count: int = 0, max_cost: float = 0.7, area_penalty_weight: float = 0.0, roundness_weight: float = 0.0, max_iters: int = 1024, ) -> Tensor: """Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg).""" device = mesh.faces.device F = mesh.faces.shape[0] faces = mesh.faces.to(torch.long) vertices = mesh.vertices.to(torch.float32) face_normal = mesh.face_normal.to(torch.float32) face_area = mesh.face_area.to(torch.float32) face_face = mesh.face_face.to(torch.long) face_edge_len = face_edge_lengths(vertices, faces) chart_id = torch.arange(F, dtype=torch.long, device=device) chart_axis = face_normal.clone() chart_half = torch.zeros(F, dtype=torch.float32, device=device) chart_area = face_area.clone() chart_perim = face_edge_len.sum(dim=1).clone() for it in range(max_iters): edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len) if edges.shape[0] == 0: break a = edges[:, 0] b = edges[:, 1] axis_a = chart_axis[a] axis_b = chart_axis[b] half_a = chart_half[a] half_b = chart_half[b] _, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b) cost = new_half.clone() if area_penalty_weight > 0.0: new_area = chart_area[a] + chart_area[b] cost = cost + area_penalty_weight * new_area if roundness_weight > 0.0: new_area_r = chart_area[a] + chart_area[b] new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12) # Pack (cost, edge_id) so scatter_reduce amin picks the right edge. E = edges.shape[0] N = int(chart_id.max().item()) + 1 edge_ids = torch.arange(E, dtype=torch.long, device=device) cost_i32 = torch.clamp(cost * 1e6, max=2e9).to(torch.int64) key = (cost_i32 << 32) | edge_ids chart_min = torch.full((N,), (2**62), dtype=torch.long, device=device) chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True) chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True) # Mutual-min collapse: each chart in at most one merge per iter. is_a_min = chart_min[a] == key is_b_min = chart_min[b] == key within = cost <= max_cost winners = is_a_min & is_b_min & within n_merge = int(winners.sum().item()) if n_merge == 0: break win_a = a[winners] win_b = b[winners] win_el = edge_len[winners] axis_a_w = chart_axis[win_a] half_a_w = chart_half[win_a] axis_b_w = chart_axis[win_b] half_b_w = chart_half[win_b] new_axis, new_half_w, _ = _combine_normal_cones( axis_a_w, half_a_w, axis_b_w, half_b_w, ) chart_axis[win_a] = new_axis chart_half[win_a] = new_half_w chart_area[win_a] = chart_area[win_a] + chart_area[win_b] chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el remap = torch.arange(N, dtype=torch.long, device=device) remap[win_b] = win_a chart_id = remap[chart_id] _, inverse = torch.unique(chart_id, sorted=True, return_inverse=True) return inverse