Fix normal smoothing and some cleanup

This commit is contained in:
kijai 2026-07-03 00:57:37 +03:00
parent d635cc412d
commit 429b13f97c
4 changed files with 132 additions and 258 deletions

View File

@ -818,17 +818,32 @@ def _quality_checks_fused(
return flip_out, skinny_out, link_out return flip_out, skinny_out, link_out
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor: def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor, weld: bool = True) -> torch.Tensor:
"""Area-weighted smooth vertex normals. `weld` averages face normals across vertices that
share a position (UV-seam duplicates from unwrapping) so both sides of a seam get one
identical normal otherwise a visible shading seam appears in the exported GLB."""
if faces.numel() == 0: if faces.numel() == 0:
return torch.zeros_like(verts) return torch.zeros_like(verts)
faces_long = faces.to(torch.int64) faces_long = faces.to(torch.int64)
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2] i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
v0, v1, v2 = verts[i0], verts[i1], verts[i2] v0, v1, v2 = verts[i0], verts[i1], verts[i2]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1) fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
vn = torch.zeros_like(verts) if weld and verts.shape[0]:
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn) # Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn) lo = verts.min(0).values
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn) inv_tol = 1.0 / (float((verts.max(0).values - lo).max().clamp_min(1e-9)) * 1e-5)
q = ((verts - lo) * inv_tol).round().to(torch.int64)
_, group = torch.unique(q, dim=0, return_inverse=True)
acc = torch.zeros((int(group.max()) + 1, 3), dtype=verts.dtype, device=verts.device)
acc.scatter_add_(0, group[i0].unsqueeze(-1).expand_as(fn), fn)
acc.scatter_add_(0, group[i1].unsqueeze(-1).expand_as(fn), fn)
acc.scatter_add_(0, group[i2].unsqueeze(-1).expand_as(fn), fn)
vn = acc[group]
else:
vn = torch.zeros_like(verts)
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6) return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6)

View File

@ -28,112 +28,6 @@ DEFAULT_MAX_COST = 2.0
NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75° NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75°
@njit(cache=True, fastmath=False)
def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
    """Per-face sum of (1 - dot(normal, neighbour normal)) over existing neighbours.

    `face_face[f, e]` is read as the neighbour across edge `e` of face `f`;
    negative entries are treated as "no neighbour" and contribute nothing.
    """
    n_faces = face_normal.shape[0]
    one = np.float32(1.0)
    curvature = np.zeros(n_faces, dtype=np.float32)
    for face in range(n_faces):
        ax = face_normal[face, 0]
        ay = face_normal[face, 1]
        az = face_normal[face, 2]
        total = np.float32(0.0)
        for edge in range(3):
            other = face_face[face, edge]
            if other >= 0:
                bx = face_normal[other, 0]
                by = face_normal[other, 1]
                bz = face_normal[other, 2]
                dot = ax*bx + ay*by + az*bz
                total += one - dot
        curvature[face] = total
    return curvature
def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
nb_safe = np.maximum(face_face, 0)
nb_normal = face_normal[nb_safe]
d = (face_normal[:, None, :] * nb_normal).sum(axis=-1)
contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0))
return contrib.sum(axis=1).astype(np.float32)
@njit(cache=True, fastmath=False)
def _relax_min_dist_jit(face_centroid: np.ndarray, min_dist: np.ndarray, seed: int) -> None:
    """Lower `min_dist[f]` in place to the squared centroid distance from face `seed` where smaller."""
    sx = face_centroid[seed, 0]
    sy = face_centroid[seed, 1]
    sz = face_centroid[seed, 2]
    for f in range(face_centroid.shape[0]):
        dx = face_centroid[f, 0] - sx
        dy = face_centroid[f, 1] - sy
        dz = face_centroid[f, 2] - sz
        d2 = dx*dx + dy*dy + dz*dz
        if d2 < min_dist[f]:
            min_dist[f] = d2


@njit(cache=True, fastmath=False)
def _farthest_point_seeds_jit(
    face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray,
    initial_seeds: np.ndarray, k_target: int,
):
    """Weighted farthest-point sampling of seed faces.

    Starts from the non-negative entries of `initial_seeds`, then repeatedly adds
    the face maximizing `min squared distance to any seed * face_weight` until
    `k_target` seeds exist or no face with a positive score remains.

    `face_area` is accepted for signature compatibility but is not read.

    Returns an int64 array of at most `k_target` seed face indices.
    """
    F = face_centroid.shape[0]
    INF = np.float32(1e30)
    min_dist = np.full(F, INF, dtype=np.float32)
    seeds = np.empty(k_target, dtype=np.int64)
    n_seeds = 0

    for i in range(initial_seeds.shape[0]):
        s = initial_seeds[i]
        if s < 0 or n_seeds >= k_target:
            continue
        seeds[n_seeds] = s
        n_seeds += 1
        _relax_min_dist_jit(face_centroid, min_dist, s)

    while n_seeds < k_target:
        best_f = -1
        # Start at 0 (not -1): a zero score means the face coincides with an
        # existing seed, and picking it would emit a duplicate seed. The numpy
        # fallback stops in the same situation.
        best_score = np.float32(0.0)
        for f in range(F):
            d = min_dist[f]
            if d >= INF * np.float32(0.5):
                # Still at the sentinel: no seed has been placed yet.
                continue
            score = d * face_weight[f]
            if score > best_score:
                best_score = score
                best_f = f
        if best_f < 0:
            break
        seeds[n_seeds] = best_f
        n_seeds += 1
        _relax_min_dist_jit(face_centroid, min_dist, best_f)

    return seeds[:n_seeds]
def _farthest_point_seeds_numpy(
face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
):
F = face_centroid.shape[0]
min_dist = np.full(F, np.inf, dtype=np.float32)
seeds: List[int] = []
for s in initial_seeds:
if s < 0 or len(seeds) >= k_target:
continue
seeds.append(int(s))
d = ((face_centroid - face_centroid[s])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
while len(seeds) < k_target:
best = int(np.argmax(min_dist))
if not np.isfinite(min_dist[best]) or min_dist[best] <= 0:
break
seeds.append(best)
d = ((face_centroid - face_centroid[best])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
return np.asarray(seeds, dtype=np.int64)
@njit(cache=True, fastmath=False) @njit(cache=True, fastmath=False)
def _cost_grow_iter_jit( def _cost_grow_iter_jit(
face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray, face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
@ -259,15 +153,14 @@ def _renumber(face_chart: np.ndarray, device) -> Tensor:
return torch.from_numpy(out).to(device) return torch.from_numpy(out).to(device)
def _segment_charts_fast( def segment_charts(
mesh: MeshData, mesh: MeshData,
max_cost: float, max_cost: float = DEFAULT_MAX_COST,
w_normal_deviation: float, w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
w_roundness: float = DEFAULT_W_ROUNDNESS, w_roundness: float = DEFAULT_W_ROUNDNESS,
w_straightness: float = DEFAULT_W_STRAIGHTNESS, w_straightness: float = DEFAULT_W_STRAIGHTNESS,
target_chart_count: int = 0,
) -> Tensor: ) -> Tensor:
"""Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds.""" """Segment mesh into charts (parallel batch cost-grow). Returns face -> chart_id."""
F = mesh.faces.shape[0] F = mesh.faces.shape[0]
device = mesh.faces.device device = mesh.faces.device
if F == 0: if F == 0:
@ -291,37 +184,9 @@ def _segment_charts_fast(
else: else:
initial_seeds = np.empty(0, dtype=np.int64) initial_seeds = np.empty(0, dtype=np.int64)
adaptive_seeding = target_chart_count <= 0 seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
if adaptive_seeding: if not seed_faces:
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()] seed_faces = [0]
if not seed_faces:
seed_faces = [0]
else:
if _HAVE_NUMBA:
curvature_raw = _face_curvature_jit(face_normal, face_face)
else:
curvature_raw = _face_curvature_numpy(face_normal, face_face)
cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0
if cmax > 1e-6:
face_weight = (np.float32(1.0) + np.float32(50.0) *
(curvature_raw / np.float32(cmax))).astype(np.float32)
else:
face_weight = np.ones(F, dtype=np.float32)
n_comp = int(initial_seeds.size)
if n_comp < int(target_chart_count):
target_seeds = int(target_chart_count)
else:
target_seeds = n_comp + max(int(target_chart_count) // 4, 8)
target_seeds = min(target_seeds, F)
if _HAVE_NUMBA:
seeds_arr = _farthest_point_seeds_jit(
face_centroid, face_area, face_weight, initial_seeds, target_seeds,
)
else:
seeds_arr = _farthest_point_seeds_numpy(
face_centroid, initial_seeds, target_seeds,
)
seed_faces = [int(s) for s in seeds_arr.tolist()]
K = len(seed_faces) K = len(seed_faces)
chart_basis = np.zeros((K, 3), dtype=np.float32) chart_basis = np.zeros((K, 3), dtype=np.float32)
@ -345,10 +210,9 @@ def _segment_charts_fast(
return _renumber(face_chart, device) return _renumber(face_chart, device)
min_dist_to_seed = np.full(F, np.inf, dtype=np.float32) min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
if adaptive_seeding: for sf in seed_faces:
for sf in seed_faces: d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1) min_dist_to_seed = np.minimum(min_dist_to_seed, d)
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
if _HAVE_NUMBA: if _HAVE_NUMBA:
# Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg. # Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
@ -375,8 +239,6 @@ def _segment_charts_fast(
break break
if (face_chart == -1).sum() == 0: if (face_chart == -1).sum() == 0:
break break
if not adaptive_seeding:
break
if chart_basis.shape[0] >= max_total_charts: if chart_basis.shape[0] >= max_total_charts:
break break
unassigned_mask = face_chart == -1 unassigned_mask = face_chart == -1
@ -462,24 +324,6 @@ def _segment_charts_fast(
return _renumber(face_chart, device) return _renumber(face_chart, device)
def segment_charts(
    mesh: MeshData,
    max_cost: float = DEFAULT_MAX_COST,
    w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
    w_roundness: float = DEFAULT_W_ROUNDNESS,
    w_straightness: float = DEFAULT_W_STRAIGHTNESS,
    target_chart_count: int = 0,
) -> Tensor:
    """Public entry point for chart segmentation; returns the face -> chart_id tensor.

    Forwards every option unchanged to `_segment_charts_fast`.
    """
    options = {
        "max_cost": max_cost,
        "w_normal_deviation": w_normal_deviation,
        "w_roundness": w_roundness,
        "w_straightness": w_straightness,
        "target_chart_count": target_chart_count,
    }
    return _segment_charts_fast(mesh, **options)
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ---- # ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
def _combine_normal_cones( def _combine_normal_cones(
axis_a: Tensor, half_a: Tensor, axis_a: Tensor, half_a: Tensor,
@ -558,10 +402,7 @@ def _build_chart_edges(
def cluster_charts_pec( def cluster_charts_pec(
mesh: MeshData, mesh: MeshData,
target_chart_count: int = 0,
max_cost: float = 0.7, max_cost: float = 0.7,
area_penalty_weight: float = 0.0,
roundness_weight: float = 0.0,
max_iters: int = 1024, max_iters: int = 1024,
) -> Tensor: ) -> Tensor:
"""Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg).""" """Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg)."""
@ -570,7 +411,6 @@ def cluster_charts_pec(
faces = mesh.faces.to(torch.long) faces = mesh.faces.to(torch.long)
vertices = mesh.vertices.to(torch.float32) vertices = mesh.vertices.to(torch.float32)
face_normal = mesh.face_normal.to(torch.float32) face_normal = mesh.face_normal.to(torch.float32)
face_area = mesh.face_area.to(torch.float32)
face_face = mesh.face_face.to(torch.long) face_face = mesh.face_face.to(torch.long)
face_edge_len = face_edge_lengths(vertices, faces) face_edge_len = face_edge_lengths(vertices, faces)
@ -578,11 +418,9 @@ def cluster_charts_pec(
chart_id = torch.arange(F, dtype=torch.long, device=device) chart_id = torch.arange(F, dtype=torch.long, device=device)
chart_axis = face_normal.clone() chart_axis = face_normal.clone()
chart_half = torch.zeros(F, dtype=torch.float32, device=device) chart_half = torch.zeros(F, dtype=torch.float32, device=device)
chart_area = face_area.clone()
chart_perim = face_edge_len.sum(dim=1).clone()
for it in range(max_iters): for it in range(max_iters):
edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len) edges, _ = _build_chart_edges(face_face, chart_id, face_edge_len)
if edges.shape[0] == 0: if edges.shape[0] == 0:
break break
@ -594,13 +432,6 @@ def cluster_charts_pec(
half_b = chart_half[b] half_b = chart_half[b]
_, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b) _, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
cost = new_half.clone() cost = new_half.clone()
if area_penalty_weight > 0.0:
new_area = chart_area[a] + chart_area[b]
cost = cost + area_penalty_weight * new_area
if roundness_weight > 0.0:
new_area_r = chart_area[a] + chart_area[b]
new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)
# Pack (cost, edge_id) so scatter_reduce amin picks the right edge. # Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
E = edges.shape[0] E = edges.shape[0]
@ -612,11 +443,12 @@ def cluster_charts_pec(
chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True) chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True) chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
# Mutual-min collapse: each chart in at most one merge per iter. # Mutual-min collapse: each chart in at most one merge per iter (winners are disjoint pairs).
is_a_min = chart_min[a] == key is_a_min = chart_min[a] == key
is_b_min = chart_min[b] == key is_b_min = chart_min[b] == key
mutual = is_a_min & is_b_min
within = cost <= max_cost within = cost <= max_cost
winners = is_a_min & is_b_min & within winners = mutual & within
n_merge = int(winners.sum().item()) n_merge = int(winners.sum().item())
if n_merge == 0: if n_merge == 0:
@ -624,7 +456,6 @@ def cluster_charts_pec(
win_a = a[winners] win_a = a[winners]
win_b = b[winners] win_b = b[winners]
win_el = edge_len[winners]
axis_a_w = chart_axis[win_a] axis_a_w = chart_axis[win_a]
half_a_w = chart_half[win_a] half_a_w = chart_half[win_a]
@ -635,8 +466,6 @@ def cluster_charts_pec(
) )
chart_axis[win_a] = new_axis chart_axis[win_a] = new_axis
chart_half[win_a] = new_half_w chart_half[win_a] = new_half_w
chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el
remap = torch.arange(N, dtype=torch.long, device=device) remap = torch.arange(N, dtype=torch.long, device=device)
remap[win_b] = win_a remap[win_b] = win_a

View File

@ -2,7 +2,7 @@ import torch
import numpy as np import numpy as np
import math import math
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest import ComfyExtension, IO, Types, io
import copy import copy
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
@ -18,6 +18,9 @@ from tqdm import tqdm
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import connected_components
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
import scipy.ndimage as ndi
MeshCameras = io.Custom("MESH_CAMERAS") # carries the camera set from RenderMeshViews → BakeViewsToTexture
def get_mesh_batch_item(mesh, index): def get_mesh_batch_item(mesh, index):
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None: if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
@ -460,7 +463,8 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
try: try:
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution) vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
except Exception as e: except Exception as e:
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU") comfy.model_management.raise_non_oom(e) # only fall back on OOM; surface real errors
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear ran out of memory ({e}); falling back to CPU")
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution) vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
if not ok.all(): if not ok.all():
vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour
@ -657,25 +661,46 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False):
return bestp return bestp
def _back_project_positions(position_map, mask, ref_v, ref_f): def _back_project_positions(position_map, mask, ref_v, ref_f, max_query_res=1024):
"""Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no """Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no
cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat
triangle chords. Returns a new position_map.""" triangle chords. Returns a new position_map.
valid = np.ascontiguousarray(position_map[mask].astype(np.float32)) """
if valid.shape[0] == 0: if not mask.any():
return position_map return position_map
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
rv = ref_v.detach().to(dev).float() rv = ref_v.detach().to(dev).float()
rf = ref_f.detach().to(dev).long() rf = ref_f.detach().to(dev).long()
tri = rv[rf] tri = rv[rf]
Q = torch.from_numpy(valid).to(dev)
bvh = _build_triangle_bvh(tri) bvh = _build_triangle_bvh(tri)
bp = _closest_points_on_mesh_bvh(Q, tri, bvh)
def _closest(pts_np):
return _closest_points_on_mesh_bvh(
torch.from_numpy(np.ascontiguousarray(pts_np.astype(np.float32))).to(dev), tri, bvh
).detach().cpu().numpy().astype(np.float32)
H, W, _ = position_map.shape
stride = max(1, int(math.ceil(max(H, W) / float(max_query_res))))
if stride == 1 or not mask[::stride, ::stride].any():
out = position_map.copy()
out[mask] = _closest(position_map[mask]).astype(position_map.dtype)
return out
# Low-res correction, then bilinear upsample to full resolution.
pos_lo = position_map[::stride, ::stride]
mask_lo = mask[::stride, ::stride]
Hl, Wl = mask_lo.shape
corr_lo = np.zeros((Hl, Wl, 3), dtype=np.float32)
corr_lo[mask_lo] = _closest(pos_lo[mask_lo]) - pos_lo[mask_lo].astype(np.float32)
inds = ndi.distance_transform_edt(~mask_lo, return_distances=False, return_indices=True)
corr_lo = corr_lo[tuple(inds)] # extrapolate into gutter (nearest)
corr = torch.nn.functional.interpolate(
torch.from_numpy(np.ascontiguousarray(corr_lo)).permute(2, 0, 1)[None].to(dev),
size=(H, W), mode="bilinear", align_corners=False,
)[0].permute(1, 2, 0).cpu().numpy()
out = position_map.copy() out = position_map.copy()
out[mask] = bp.detach().cpu().numpy().astype(position_map.dtype) out[mask] = position_map[mask] + corr[mask]
return out return out
@ -703,6 +728,7 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
nmin, nmax = bvh['nmin'], bvh['nmax'] nmin, nmax = bvh['nmin'], bvh['nmax']
left, right, order = bvh['left'], bvh['right'], bvh['order'] left, right, order = bvh['left'], bvh['right'], bvh['order']
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs) inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
tmaxN = tmax if torch.is_tensor(tmax) else torch.full((N,), float(tmax), device=dev) # per-ray far bound
hit = torch.zeros(N, dtype=torch.bool, device=dev) hit = torch.zeros(N, dtype=torch.bool, device=dev)
# int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory. # int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev) stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
@ -710,24 +736,24 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
stack[:, 0] = 0 stack[:, 0] = 0
active = torch.arange(N, device=dev) active = torch.arange(N, device=dev)
def slab(node, o, i): def slab(node, o, i, tmx):
t1 = (nmin[node] - o) * i t1 = (nmin[node] - o) * i
t2 = (nmax[node] - o) * i t2 = (nmax[node] - o) * i
tnear = torch.minimum(t1, t2).amax(-1) tnear = torch.minimum(t1, t2).amax(-1)
tfar = torch.maximum(t1, t2).amin(-1) tfar = torch.maximum(t1, t2).amin(-1)
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin) return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmx) & (tfar >= tmin)
while active.numel() > 0: while active.numel() > 0:
a = active a = active
node = stack[a, sp[a] - 1] node = stack[a, sp[a] - 1]
sp[a] = sp[a] - 1 sp[a] = sp[a] - 1
within = slab(node, orig[a], inv[a]) within = slab(node, orig[a], inv[a], tmaxN[a])
isleaf = node >= LEAF isleaf = node >= LEAF
lv = within & isleaf lv = within & isleaf
if bool(lv.any()): if bool(lv.any()):
ga = a[lv] ga = a[lv]
tt = tri[order[node[lv] - LEAF]] tt = tri[order[node[lv] - LEAF]]
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax) h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmaxN[ga])
hit[ga[h]] = True hit[ga[h]] = True
iv = within & ~isleaf iv = within & ~isleaf
if bool(iv.any()): if bool(iv.any()):
@ -838,7 +864,7 @@ def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n
# memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks. # memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks.
try: try:
free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30) free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
except Exception: except RuntimeError:
free = 2 << 30 free = 2 << 30
ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200)))) ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200))))
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution) face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
@ -1059,11 +1085,8 @@ def _jfa_fill_gpu(img01, mask):
return filled.cpu().numpy() return filled.cpu().numpy()
def _seam_fill(img01, mask):
    """Fill UV-gutter texels via JFA nearest-coverage so seams don't pull in black.

    Thin wrapper: delegates to `_jfa_fill_gpu` and returns its result unchanged.
    """
    filled = _jfa_fill_gpu(img01, mask)
    return filled
@ -1092,7 +1115,7 @@ def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
resolution, texture_size, uvs, inpaint_radius=3, resolution, texture_size, uvs,
normalize_uvs=True, reference=None, pbar=None): normalize_uvs=True, reference=None, pbar=None):
"""Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space, """Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space,
sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout, sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout,
@ -1109,7 +1132,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
v_np = vertices.detach().cpu().numpy().astype(np.float32) v_np = vertices.detach().cpu().numpy().astype(np.float32)
f_np = faces.detach().cpu().numpy().astype(np.uint32) f_np = faces.detach().cpu().numpy().astype(np.uint32)
fcount = int(f_np.shape[0])
uv_np = uvs.detach().cpu().numpy().astype(np.float32) uv_np = uvs.detach().cpu().numpy().astype(np.float32)
if uv_np.shape[0] != v_np.shape[0]: if uv_np.shape[0] != v_np.shape[0]:
@ -1117,13 +1139,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 " f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
f"with vertices ({v_np.shape[0]})." f"with vertices ({v_np.shape[0]})."
) )
uv_min = uv_np.min(axis=0)
uv_max = uv_np.max(axis=0)
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
f"{fcount} faces")
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ") uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
new_verts, new_faces, new_uvs = v_np, f_np, uv_np new_verts, new_faces, new_uvs = v_np, f_np, uv_np
@ -1151,12 +1166,12 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
roughness = attrs[..., 4:5] if C >= 5 else None roughness = attrs[..., 4:5] if C >= 5 else None
# alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode). # alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode).
base_color = _seam_fill(np.ascontiguousarray(base_color), mask, inpaint_radius) base_color = _seam_fill(np.ascontiguousarray(base_color), mask)
mr_image = None mr_image = None
if has_pbr: if has_pbr:
# glTF metallicRoughness: R unused, G=roughness, B=metallic. # glTF metallicRoughness: R unused, G=roughness, B=metallic.
mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1) mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1)
mr_image = _seam_fill(np.ascontiguousarray(mr), mask, inpaint_radius) mr_image = _seam_fill(np.ascontiguousarray(mr), mask)
device = vertices.device device = vertices.device
out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32) out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32)
@ -1195,8 +1210,8 @@ class BakeTextureFromVoxel(IO.ComfyNode):
inputs=[ inputs=[
IO.Mesh.Input("mesh"), IO.Mesh.Input("mesh"),
IO.Voxel.Input("voxel_colors"), IO.Voxel.Input("voxel_colors"),
IO.Int.Input("texture_size", default=1024, min=64, max=8192, IO.Int.Input("texture_size", default=2048, min=64, max=8192,
tooltip="Square texture resolution."), tooltip="Square UV atlas resolution."),
IO.Mesh.Input("reference_mesh", optional=True, IO.Mesh.Input("reference_mesh", optional=True,
tooltip=( tooltip=(
"Optional dense pre-decimation mesh; back-projects each texel onto its " "Optional dense pre-decimation mesh; back-projects each texel onto its "
@ -1211,13 +1226,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None): def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None):
# Matches official to_glb; effectively on/off since the gutter fill ignores the value.
inpaint_radius = 3
voxels = voxel_colors voxels = voxel_colors
coords = voxels.data coords = voxels.data
colors = voxels.voxel_colors colors = voxels.voxel_colors
resolution = voxels.resolution resolution = voxels.resolution
mesh_uvs = getattr(mesh, "uvs", None) mesh_uvs = mesh.uvs
if mesh_uvs is None: if mesh_uvs is None:
raise ValueError( raise ValueError(
"BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the " "BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the "
@ -1249,7 +1262,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn( _bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v_i, f_i, item_coords, item_colors, v_i, f_i, item_coords, item_colors,
resolution=resolution, texture_size=texture_size, resolution=resolution, texture_size=texture_size,
uvs=ev_i, inpaint_radius=inpaint_radius, uvs=ev_i,
reference=ref_i, pbar=pbar, reference=ref_i, pbar=pbar,
) )
out_tex.append(bt) out_tex.append(bt)
@ -1275,7 +1288,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn( _bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v0, f0, coords, colors, v0, f0, coords, colors,
resolution=resolution, texture_size=texture_size, resolution=resolution, texture_size=texture_size,
uvs=ev0, inpaint_radius=inpaint_radius, uvs=ev0,
reference=ref0, pbar=pbar, reference=ref0, pbar=pbar,
) )
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0) base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
@ -1317,9 +1330,9 @@ class MeshTextureToImage(IO.ComfyNode):
t = t.unsqueeze(0) t = t.unsqueeze(0)
return t return t
base = _as_image(getattr(mesh, "texture", None)) base = _as_image(mesh.texture)
mr = _as_image(getattr(mesh, "metallic_roughness", None)) mr = _as_image(mesh.metallic_roughness)
normal_map = _as_image(getattr(mesh, "normal_map", None)) normal_map = _as_image(mesh.normal_map)
if base is None: if base is None:
raise ValueError( raise ValueError(
@ -1335,7 +1348,7 @@ class MeshTextureToImage(IO.ComfyNode):
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous() roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as # R is real occlusion only if AO was baked; else it's the unused zero channel, which as
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set. # "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
if getattr(mesh, "occlusion_in_mr", False): if mesh.occlusion_in_mr:
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous() occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
else: else:
occlusion = torch.ones_like(base) occlusion = torch.ones_like(base)
@ -1365,7 +1378,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None): def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
mesh_uvs = getattr(mesh, "uvs", None) mesh_uvs = mesh.uvs
if mesh_uvs is None: if mesh_uvs is None:
raise ValueError( raise ValueError(
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh " "ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
@ -1391,7 +1404,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
# and export the smooth normals the TBN was built on — without a NORMAL attribute the # and export the smooth normals the TBN was built on — without a NORMAL attribute the
# viewer shades flat and the tangent-space detail fights the faceting. # viewer shades flat and the tangent-space detail fights the faceting.
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(mesh, "normals", None) low_n_attr = mesh.normals
B = int(mesh.vertices.shape[0]) B = int(mesh.vertices.shape[0])
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0]) Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32) tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
@ -1467,7 +1480,7 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True): def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
low_uvs = getattr(low_poly, "uvs", None) low_uvs = low_poly.uvs
if low_uvs is None: if low_uvs is None:
raise ValueError( raise ValueError(
"BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped " "BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
@ -1475,8 +1488,8 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
"onto existing UVs and never unwraps.") "onto existing UVs and never unwraps.")
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None) low_n_attr = low_poly.normals
high_n_attr = getattr(high_poly, "normals", None) high_n_attr = high_poly.normals
B = int(low_poly.vertices.shape[0]) B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0]) h_batch = int(high_poly.vertices.shape[0])
@ -1549,13 +1562,13 @@ class BakeAmbientOcclusion(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias): def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
low_uvs = getattr(low_poly, "uvs", None) low_uvs = low_poly.uvs
if low_uvs is None: if low_uvs is None:
raise ValueError( raise ValueError(
"BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly " "BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
"(the same one used for the other bakes); this node never unwraps.") "(the same one used for the other bakes); this node never unwraps.")
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None) low_n_attr = low_poly.normals
B = int(low_poly.vertices.shape[0]) B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0]) h_batch = int(high_poly.vertices.shape[0])
@ -1630,7 +1643,7 @@ class SetMeshMaterial(IO.ComfyNode):
base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor, base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
normal_scale, occlusion_strength, double_sided, emissive_texture=None): normal_scale, occlusion_strength, double_sided, emissive_texture=None):
out_mesh = copy.copy(mesh) out_mesh = copy.copy(mesh)
material = dict(getattr(mesh, "material", {}) or {}) # merge over any prior material material = dict(mesh.material or {}) # merge over any prior material
material.update({ material.update({
"emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)], "emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)],
"emissive_strength": float(emissive_strength), "emissive_strength": float(emissive_strength),
@ -2201,7 +2214,8 @@ class DecimateMesh(IO.ComfyNode):
if rc is not None: if rc is not None:
c = rc.to(src_device) c = rc.to(src_device)
except Exception as e: except Exception as e:
logging.warning(f"DecimateMesh: QEM simplify failed, passing mesh through unchanged: {e!r}") comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
logging.warning(f"DecimateMesh: QEM simplify ran out of memory, passing mesh through unchanged: {e!r}")
counts["out"] += int(f.shape[0]) counts["out"] += int(f.shape[0])
return v, f, c return v, f, c
@ -2318,7 +2332,8 @@ class RemeshMesh(IO.ComfyNode):
f = rf.to(src_device) f = rf.to(src_device)
c = rc.to(src_device) if rc is not None else None c = rc.to(src_device) if rc is not None else None
except Exception as e: except Exception as e:
logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}") comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
logging.warning(f"RemeshMesh: remesh ran out of memory, passing mesh through unchanged: {e!r}")
counts["out"] += int(f.shape[0]) counts["out"] += int(f.shape[0])
return v, f, c return v, f, c
@ -2409,9 +2424,9 @@ def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance
if segmenter == "pec": if segmenter == "pec":
if mesh.faces.device.type != "cuda": if mesh.faces.device.type != "cuda":
raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.") raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
face_chart = _uv_seg.cluster_charts_pec(mesh, target_chart_count=0, max_cost=1.0) face_chart = _uv_seg.cluster_charts_pec(mesh, max_cost=1.0)
elif segmenter == "adaptive": elif segmenter == "adaptive":
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0, target_chart_count=0) face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0)
else: else:
raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive") raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
@ -2534,12 +2549,12 @@ class UnwrapMesh(IO.ComfyNode):
if is_list or is_batched: if is_list or is_batched:
vi, fi = mesh.vertices[i], mesh.faces[i] vi, fi = mesh.vertices[i], mesh.faces[i]
ci = None ci = None
vc = getattr(mesh, "vertex_colors", None) vc = mesh.vertex_colors
if vc is not None: if vc is not None:
ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc
else: else:
vi, fi = mesh.vertices, mesh.faces vi, fi = mesh.vertices, mesh.faces
ci = getattr(mesh, "vertex_colors", None) ci = mesh.vertex_colors
src_device = vi.device src_device = vi.device
vnp = vi.detach().cpu().numpy().astype(np.float32) vnp = vi.detach().cpu().numpy().astype(np.float32)
@ -2561,7 +2576,7 @@ class UnwrapMesh(IO.ComfyNode):
bar.update(1) bar.update(1)
out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None) out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
if getattr(mesh, "texture", None) is not None: if mesh.texture is not None:
out_mesh.texture = mesh.texture out_mesh.texture = mesh.texture
if cls.hidden.unique_id: if cls.hidden.unique_id:
@ -2743,7 +2758,7 @@ class RenderUVAtlas(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, mesh, resolution): def execute(cls, mesh, resolution):
uvs_t = getattr(mesh, "uvs", None) uvs_t = mesh.uvs
if uvs_t is None: if uvs_t is None:
raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.") raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
uvs_np = uvs_t.detach().cpu().numpy() uvs_np = uvs_t.detach().cpu().numpy()
@ -2847,8 +2862,8 @@ def merge_meshes(meshes):
def _b0(t): def _b0(t):
return t[0] if t.ndim == 3 else t return t[0] if t.ndim == 3 else t
any_uvs = any(getattr(m, "uvs", None) is not None for m in meshes) any_uvs = any(m.uvs is not None for m in meshes)
any_colors = any(getattr(m, "vertex_colors", None) is not None for m in meshes) any_colors = any(m.vertex_colors is not None for m in meshes)
verts_list, faces_list, uvs_list, colors_list = [], [], [], [] verts_list, faces_list, uvs_list, colors_list = [], [], [], []
texture = None texture = None
@ -2861,16 +2876,16 @@ def merge_meshes(meshes):
faces_list.append(f + offset) faces_list.append(f + offset)
offset += v.shape[0] offset += v.shape[0]
if any_uvs: if any_uvs:
mu = getattr(m, "uvs", None) mu = m.uvs
uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2))) uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2)))
if any_colors: if any_colors:
mc = getattr(m, "vertex_colors", None) mc = m.vertex_colors
if mc is not None: if mc is not None:
c = _b0(mc).cpu() c = _b0(mc).cpu()
else: else:
c = v.new_ones((v.shape[0], 3)) c = v.new_ones((v.shape[0], 3))
colors_list.append(c) colors_list.append(c)
mt = getattr(m, "texture", None) mt = m.texture
if mt is not None: if mt is not None:
if texture is None: if texture is None:
texture = mt.cpu() texture = mt.cpu()

View File

@ -105,16 +105,31 @@ def get_mesh_batch_item(mesh, index):
return mesh.vertices[index], mesh.faces[index], colors, uvs, normals return mesh.vertices[index], mesh.faces[index], colors, uvs, normals
def _smooth_vertex_normals(vertices_np, faces_np): def _smooth_vertex_normals(vertices_np, faces_np, weld=True):
"""Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting. """Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting.
Un-normalized face normals (the raw cross product) have magnitude 2*area, so Un-normalized face normals (the raw cross product) have magnitude 2*area, so
accumulating them onto their vertices yields an area-weighted average.""" accumulating them onto their vertices yields an area-weighted average. `weld` averages
across vertices that share a position — UV-seam duplicates created by unwrapping — so
both sides of a seam get one identical normal. Without it each side averages only its
own faces and a visible shading seam appears; welding matches the official, which
computes normals on the pre-split mesh and gathers them through the UV vmap."""
tris = vertices_np[faces_np] # (M, 3, 3) tris = vertices_np[faces_np] # (M, 3, 3)
face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0]) face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64) if weld and vertices_np.shape[0]:
for k in range(3): # Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
np.add.at(normals, faces_np[:, k], face_n) lo = vertices_np.min(0)
inv_tol = 1.0 / (max(float((vertices_np.max(0) - lo).max()), 1e-9) * 1e-5)
q = np.round((vertices_np - lo) * inv_tol).astype(np.int64)
_, group = np.unique(q, axis=0, return_inverse=True)
acc = np.zeros((int(group.max()) + 1, 3), dtype=np.float64)
for k in range(3):
np.add.at(acc, group[faces_np[:, k]], face_n)
normals = acc[group] # welded normal back to each vertex
else:
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64)
for k in range(3):
np.add.at(normals, faces_np[:, k], face_n)
lens = np.linalg.norm(normals, axis=1, keepdims=True) lens = np.linalg.norm(normals, axis=1, keepdims=True)
normals /= np.where(lens > 1e-12, lens, 1.0) normals /= np.where(lens > 1e-12, lens, 1.0)
return normals.astype(np.float32) return normals.astype(np.float32)