Fix normal smoothing and some cleanup

This commit is contained in:
kijai 2026-07-03 00:57:37 +03:00
parent d635cc412d
commit 429b13f97c
4 changed files with 132 additions and 258 deletions

View File

@ -818,17 +818,32 @@ def _quality_checks_fused(
return flip_out, skinny_out, link_out
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor, weld: bool = True) -> torch.Tensor:
"""Area-weighted smooth vertex normals. `weld` averages face normals across vertices that
share a position (UV-seam duplicates from unwrapping) so both sides of a seam get one
identical normal otherwise a visible shading seam appears in the exported GLB."""
if faces.numel() == 0:
return torch.zeros_like(verts)
faces_long = faces.to(torch.int64)
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
vn = torch.zeros_like(verts)
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
if weld and verts.shape[0]:
# Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
lo = verts.min(0).values
inv_tol = 1.0 / (float((verts.max(0).values - lo).max().clamp_min(1e-9)) * 1e-5)
q = ((verts - lo) * inv_tol).round().to(torch.int64)
_, group = torch.unique(q, dim=0, return_inverse=True)
acc = torch.zeros((int(group.max()) + 1, 3), dtype=verts.dtype, device=verts.device)
acc.scatter_add_(0, group[i0].unsqueeze(-1).expand_as(fn), fn)
acc.scatter_add_(0, group[i1].unsqueeze(-1).expand_as(fn), fn)
acc.scatter_add_(0, group[i2].unsqueeze(-1).expand_as(fn), fn)
vn = acc[group]
else:
vn = torch.zeros_like(verts)
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6)

View File

@ -28,112 +28,6 @@ DEFAULT_MAX_COST = 2.0
NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75°
@njit(cache=True, fastmath=False)
def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
F = face_normal.shape[0]
raw = np.zeros(F, dtype=np.float32)
for f in range(F):
nx = face_normal[f, 0]
ny = face_normal[f, 1]
nz = face_normal[f, 2]
s = np.float32(0.0)
for e in range(3):
nb = face_face[f, e]
if nb < 0:
continue
mx = face_normal[nb, 0]
my = face_normal[nb, 1]
mz = face_normal[nb, 2]
d = nx*mx + ny*my + nz*mz
s += np.float32(1.0) - d
raw[f] = s
return raw
def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
nb_safe = np.maximum(face_face, 0)
nb_normal = face_normal[nb_safe]
d = (face_normal[:, None, :] * nb_normal).sum(axis=-1)
contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0))
return contrib.sum(axis=1).astype(np.float32)
@njit(cache=True, fastmath=False)
def _farthest_point_seeds_jit(
face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray,
initial_seeds: np.ndarray, k_target: int,
):
F = face_centroid.shape[0]
INF = np.float32(1e30)
min_dist = np.full(F, INF, dtype=np.float32)
seeds = np.empty(k_target, dtype=np.int64)
n_seeds = 0
for i in range(initial_seeds.shape[0]):
s = initial_seeds[i]
if s < 0 or n_seeds >= k_target:
continue
seeds[n_seeds] = s
n_seeds += 1
sx = face_centroid[s, 0]
sy = face_centroid[s, 1]
sz = face_centroid[s, 2]
for f in range(F):
dx = face_centroid[f, 0] - sx
dy = face_centroid[f, 1] - sy
dz = face_centroid[f, 2] - sz
d2 = dx*dx + dy*dy + dz*dz
if d2 < min_dist[f]:
min_dist[f] = d2
while n_seeds < k_target:
best_f = -1
best_score = np.float32(-1.0)
for f in range(F):
d = min_dist[f]
if d >= INF * np.float32(0.5):
continue
score = d * face_weight[f]
if score > best_score:
best_score = score
best_f = f
if best_f < 0:
break
seeds[n_seeds] = best_f
n_seeds += 1
sx = face_centroid[best_f, 0]
sy = face_centroid[best_f, 1]
sz = face_centroid[best_f, 2]
for f in range(F):
dx = face_centroid[f, 0] - sx
dy = face_centroid[f, 1] - sy
dz = face_centroid[f, 2] - sz
d2 = dx*dx + dy*dy + dz*dz
if d2 < min_dist[f]:
min_dist[f] = d2
return seeds[:n_seeds]
def _farthest_point_seeds_numpy(
face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
):
F = face_centroid.shape[0]
min_dist = np.full(F, np.inf, dtype=np.float32)
seeds: List[int] = []
for s in initial_seeds:
if s < 0 or len(seeds) >= k_target:
continue
seeds.append(int(s))
d = ((face_centroid - face_centroid[s])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
while len(seeds) < k_target:
best = int(np.argmax(min_dist))
if not np.isfinite(min_dist[best]) or min_dist[best] <= 0:
break
seeds.append(best)
d = ((face_centroid - face_centroid[best])**2).sum(axis=-1)
min_dist = np.minimum(min_dist, d)
return np.asarray(seeds, dtype=np.int64)
@njit(cache=True, fastmath=False)
def _cost_grow_iter_jit(
face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
@ -259,15 +153,14 @@ def _renumber(face_chart: np.ndarray, device) -> Tensor:
return torch.from_numpy(out).to(device)
def _segment_charts_fast(
def segment_charts(
mesh: MeshData,
max_cost: float,
w_normal_deviation: float,
max_cost: float = DEFAULT_MAX_COST,
w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
w_roundness: float = DEFAULT_W_ROUNDNESS,
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
target_chart_count: int = 0,
) -> Tensor:
"""Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds."""
"""Segment mesh into charts (parallel batch cost-grow). Returns face -> chart_id."""
F = mesh.faces.shape[0]
device = mesh.faces.device
if F == 0:
@ -291,37 +184,9 @@ def _segment_charts_fast(
else:
initial_seeds = np.empty(0, dtype=np.int64)
adaptive_seeding = target_chart_count <= 0
if adaptive_seeding:
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
if not seed_faces:
seed_faces = [0]
else:
if _HAVE_NUMBA:
curvature_raw = _face_curvature_jit(face_normal, face_face)
else:
curvature_raw = _face_curvature_numpy(face_normal, face_face)
cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0
if cmax > 1e-6:
face_weight = (np.float32(1.0) + np.float32(50.0) *
(curvature_raw / np.float32(cmax))).astype(np.float32)
else:
face_weight = np.ones(F, dtype=np.float32)
n_comp = int(initial_seeds.size)
if n_comp < int(target_chart_count):
target_seeds = int(target_chart_count)
else:
target_seeds = n_comp + max(int(target_chart_count) // 4, 8)
target_seeds = min(target_seeds, F)
if _HAVE_NUMBA:
seeds_arr = _farthest_point_seeds_jit(
face_centroid, face_area, face_weight, initial_seeds, target_seeds,
)
else:
seeds_arr = _farthest_point_seeds_numpy(
face_centroid, initial_seeds, target_seeds,
)
seed_faces = [int(s) for s in seeds_arr.tolist()]
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
if not seed_faces:
seed_faces = [0]
K = len(seed_faces)
chart_basis = np.zeros((K, 3), dtype=np.float32)
@ -345,10 +210,9 @@ def _segment_charts_fast(
return _renumber(face_chart, device)
min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
if adaptive_seeding:
for sf in seed_faces:
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
for sf in seed_faces:
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
if _HAVE_NUMBA:
# Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
@ -375,8 +239,6 @@ def _segment_charts_fast(
break
if (face_chart == -1).sum() == 0:
break
if not adaptive_seeding:
break
if chart_basis.shape[0] >= max_total_charts:
break
unassigned_mask = face_chart == -1
@ -462,24 +324,6 @@ def _segment_charts_fast(
return _renumber(face_chart, device)
def segment_charts(
mesh: MeshData,
max_cost: float = DEFAULT_MAX_COST,
w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
w_roundness: float = DEFAULT_W_ROUNDNESS,
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
target_chart_count: int = 0,
) -> Tensor:
"""Segment mesh into charts. Returns face -> chart_id."""
return _segment_charts_fast(
mesh, max_cost=max_cost,
w_normal_deviation=w_normal_deviation,
w_roundness=w_roundness,
w_straightness=w_straightness,
target_chart_count=target_chart_count,
)
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
def _combine_normal_cones(
axis_a: Tensor, half_a: Tensor,
@ -558,10 +402,7 @@ def _build_chart_edges(
def cluster_charts_pec(
mesh: MeshData,
target_chart_count: int = 0,
max_cost: float = 0.7,
area_penalty_weight: float = 0.0,
roundness_weight: float = 0.0,
max_iters: int = 1024,
) -> Tensor:
"""Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg)."""
@ -570,7 +411,6 @@ def cluster_charts_pec(
faces = mesh.faces.to(torch.long)
vertices = mesh.vertices.to(torch.float32)
face_normal = mesh.face_normal.to(torch.float32)
face_area = mesh.face_area.to(torch.float32)
face_face = mesh.face_face.to(torch.long)
face_edge_len = face_edge_lengths(vertices, faces)
@ -578,11 +418,9 @@ def cluster_charts_pec(
chart_id = torch.arange(F, dtype=torch.long, device=device)
chart_axis = face_normal.clone()
chart_half = torch.zeros(F, dtype=torch.float32, device=device)
chart_area = face_area.clone()
chart_perim = face_edge_len.sum(dim=1).clone()
for it in range(max_iters):
edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len)
edges, _ = _build_chart_edges(face_face, chart_id, face_edge_len)
if edges.shape[0] == 0:
break
@ -594,13 +432,6 @@ def cluster_charts_pec(
half_b = chart_half[b]
_, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
cost = new_half.clone()
if area_penalty_weight > 0.0:
new_area = chart_area[a] + chart_area[b]
cost = cost + area_penalty_weight * new_area
if roundness_weight > 0.0:
new_area_r = chart_area[a] + chart_area[b]
new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)
# Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
E = edges.shape[0]
@ -612,11 +443,12 @@ def cluster_charts_pec(
chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
# Mutual-min collapse: each chart in at most one merge per iter.
# Mutual-min collapse: each chart in at most one merge per iter (winners are disjoint pairs).
is_a_min = chart_min[a] == key
is_b_min = chart_min[b] == key
mutual = is_a_min & is_b_min
within = cost <= max_cost
winners = is_a_min & is_b_min & within
winners = mutual & within
n_merge = int(winners.sum().item())
if n_merge == 0:
@ -624,7 +456,6 @@ def cluster_charts_pec(
win_a = a[winners]
win_b = b[winners]
win_el = edge_len[winners]
axis_a_w = chart_axis[win_a]
half_a_w = chart_half[win_a]
@ -635,8 +466,6 @@ def cluster_charts_pec(
)
chart_axis[win_a] = new_axis
chart_half[win_a] = new_half_w
chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el
remap = torch.arange(N, dtype=torch.long, device=device)
remap[win_b] = win_a

View File

@ -2,7 +2,7 @@ import torch
import numpy as np
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_api.latest import ComfyExtension, IO, Types, io
import copy
import comfy.utils
import comfy.model_management
@ -18,6 +18,9 @@ from tqdm import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial import cKDTree
import scipy.ndimage as ndi
MeshCameras = io.Custom("MESH_CAMERAS") # carries the camera set from RenderMeshViews → BakeViewsToTexture
def get_mesh_batch_item(mesh, index):
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
@ -460,7 +463,8 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
try:
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
except Exception as e:
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU")
comfy.model_management.raise_non_oom(e) # only fall back on OOM; surface real errors
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear ran out of memory ({e}); falling back to CPU")
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
if not ok.all():
vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour
@ -657,25 +661,46 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False):
return bestp
def _back_project_positions(position_map, mask, ref_v, ref_f):
def _back_project_positions(position_map, mask, ref_v, ref_f, max_query_res=1024):
"""Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no
cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat
triangle chords. Returns a new position_map."""
valid = np.ascontiguousarray(position_map[mask].astype(np.float32))
if valid.shape[0] == 0:
triangle chords. Returns a new position_map.
"""
if not mask.any():
return position_map
dev = comfy.model_management.get_torch_device()
rv = ref_v.detach().to(dev).float()
rf = ref_f.detach().to(dev).long()
tri = rv[rf]
Q = torch.from_numpy(valid).to(dev)
bvh = _build_triangle_bvh(tri)
bp = _closest_points_on_mesh_bvh(Q, tri, bvh)
def _closest(pts_np):
return _closest_points_on_mesh_bvh(
torch.from_numpy(np.ascontiguousarray(pts_np.astype(np.float32))).to(dev), tri, bvh
).detach().cpu().numpy().astype(np.float32)
H, W, _ = position_map.shape
stride = max(1, int(math.ceil(max(H, W) / float(max_query_res))))
if stride == 1 or not mask[::stride, ::stride].any():
out = position_map.copy()
out[mask] = _closest(position_map[mask]).astype(position_map.dtype)
return out
# Low-res correction, then bilinear upsample to full resolution.
pos_lo = position_map[::stride, ::stride]
mask_lo = mask[::stride, ::stride]
Hl, Wl = mask_lo.shape
corr_lo = np.zeros((Hl, Wl, 3), dtype=np.float32)
corr_lo[mask_lo] = _closest(pos_lo[mask_lo]) - pos_lo[mask_lo].astype(np.float32)
inds = ndi.distance_transform_edt(~mask_lo, return_distances=False, return_indices=True)
corr_lo = corr_lo[tuple(inds)] # extrapolate into gutter (nearest)
corr = torch.nn.functional.interpolate(
torch.from_numpy(np.ascontiguousarray(corr_lo)).permute(2, 0, 1)[None].to(dev),
size=(H, W), mode="bilinear", align_corners=False,
)[0].permute(1, 2, 0).cpu().numpy()
out = position_map.copy()
out[mask] = bp.detach().cpu().numpy().astype(position_map.dtype)
out[mask] = position_map[mask] + corr[mask]
return out
@ -703,6 +728,7 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
nmin, nmax = bvh['nmin'], bvh['nmax']
left, right, order = bvh['left'], bvh['right'], bvh['order']
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
tmaxN = tmax if torch.is_tensor(tmax) else torch.full((N,), float(tmax), device=dev) # per-ray far bound
hit = torch.zeros(N, dtype=torch.bool, device=dev)
# int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
@ -710,24 +736,24 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
stack[:, 0] = 0
active = torch.arange(N, device=dev)
def slab(node, o, i):
def slab(node, o, i, tmx):
t1 = (nmin[node] - o) * i
t2 = (nmax[node] - o) * i
tnear = torch.minimum(t1, t2).amax(-1)
tfar = torch.maximum(t1, t2).amin(-1)
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin)
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmx) & (tfar >= tmin)
while active.numel() > 0:
a = active
node = stack[a, sp[a] - 1]
sp[a] = sp[a] - 1
within = slab(node, orig[a], inv[a])
within = slab(node, orig[a], inv[a], tmaxN[a])
isleaf = node >= LEAF
lv = within & isleaf
if bool(lv.any()):
ga = a[lv]
tt = tri[order[node[lv] - LEAF]]
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax)
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmaxN[ga])
hit[ga[h]] = True
iv = within & ~isleaf
if bool(iv.any()):
@ -838,7 +864,7 @@ def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n
# memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks.
try:
free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
except Exception:
except RuntimeError:
free = 2 << 30
ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200))))
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
@ -1059,11 +1085,8 @@ def _jfa_fill_gpu(img01, mask):
return filled.cpu().numpy()
def _seam_fill(img01, mask, inpaint_radius):
"""Fill UV-gutter texels (so seams don't pull in black) via JFA. `inpaint_radius<=0`
disables; the radius value itself is ignored (JFA fills all uncovered by nearest)."""
if inpaint_radius <= 0:
return img01
def _seam_fill(img01, mask):
"""Fill UV-gutter texels (so seams don't pull in black) via JFA nearest-coverage."""
return _jfa_fill_gpu(img01, mask)
@ -1092,7 +1115,7 @@ def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
resolution, texture_size, uvs, inpaint_radius=3,
resolution, texture_size, uvs,
normalize_uvs=True, reference=None, pbar=None):
"""Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space,
sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout,
@ -1109,7 +1132,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
v_np = vertices.detach().cpu().numpy().astype(np.float32)
f_np = faces.detach().cpu().numpy().astype(np.uint32)
fcount = int(f_np.shape[0])
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
if uv_np.shape[0] != v_np.shape[0]:
@ -1117,13 +1139,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
f"with vertices ({v_np.shape[0]})."
)
uv_min = uv_np.min(axis=0)
uv_max = uv_np.max(axis=0)
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
f"{fcount} faces")
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
@ -1151,12 +1166,12 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
roughness = attrs[..., 4:5] if C >= 5 else None
# alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode).
base_color = _seam_fill(np.ascontiguousarray(base_color), mask, inpaint_radius)
base_color = _seam_fill(np.ascontiguousarray(base_color), mask)
mr_image = None
if has_pbr:
# glTF metallicRoughness: R unused, G=roughness, B=metallic.
mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1)
mr_image = _seam_fill(np.ascontiguousarray(mr), mask, inpaint_radius)
mr_image = _seam_fill(np.ascontiguousarray(mr), mask)
device = vertices.device
out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32)
@ -1195,8 +1210,8 @@ class BakeTextureFromVoxel(IO.ComfyNode):
inputs=[
IO.Mesh.Input("mesh"),
IO.Voxel.Input("voxel_colors"),
IO.Int.Input("texture_size", default=1024, min=64, max=8192,
tooltip="Square texture resolution."),
IO.Int.Input("texture_size", default=2048, min=64, max=8192,
tooltip="Square UV atlas resolution."),
IO.Mesh.Input("reference_mesh", optional=True,
tooltip=(
"Optional dense pre-decimation mesh; back-projects each texel onto its "
@ -1211,13 +1226,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
@classmethod
def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None):
# Matches official to_glb; effectively on/off since the gutter fill ignores the value.
inpaint_radius = 3
voxels = voxel_colors
coords = voxels.data
colors = voxels.voxel_colors
resolution = voxels.resolution
mesh_uvs = getattr(mesh, "uvs", None)
mesh_uvs = mesh.uvs
if mesh_uvs is None:
raise ValueError(
"BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the "
@ -1249,7 +1262,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v_i, f_i, item_coords, item_colors,
resolution=resolution, texture_size=texture_size,
uvs=ev_i, inpaint_radius=inpaint_radius,
uvs=ev_i,
reference=ref_i, pbar=pbar,
)
out_tex.append(bt)
@ -1275,7 +1288,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v0, f0, coords, colors,
resolution=resolution, texture_size=texture_size,
uvs=ev0, inpaint_radius=inpaint_radius,
uvs=ev0,
reference=ref0, pbar=pbar,
)
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
@ -1317,9 +1330,9 @@ class MeshTextureToImage(IO.ComfyNode):
t = t.unsqueeze(0)
return t
base = _as_image(getattr(mesh, "texture", None))
mr = _as_image(getattr(mesh, "metallic_roughness", None))
normal_map = _as_image(getattr(mesh, "normal_map", None))
base = _as_image(mesh.texture)
mr = _as_image(mesh.metallic_roughness)
normal_map = _as_image(mesh.normal_map)
if base is None:
raise ValueError(
@ -1335,7 +1348,7 @@ class MeshTextureToImage(IO.ComfyNode):
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
if getattr(mesh, "occlusion_in_mr", False):
if mesh.occlusion_in_mr:
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
else:
occlusion = torch.ones_like(base)
@ -1365,7 +1378,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
@classmethod
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
mesh_uvs = getattr(mesh, "uvs", None)
mesh_uvs = mesh.uvs
if mesh_uvs is None:
raise ValueError(
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
@ -1391,7 +1404,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
# and export the smooth normals the TBN was built on — without a NORMAL attribute the
# viewer shades flat and the tangent-space detail fights the faceting.
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(mesh, "normals", None)
low_n_attr = mesh.normals
B = int(mesh.vertices.shape[0])
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
@ -1467,7 +1480,7 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
@classmethod
def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
low_uvs = getattr(low_poly, "uvs", None)
low_uvs = low_poly.uvs
if low_uvs is None:
raise ValueError(
"BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
@ -1475,8 +1488,8 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
"onto existing UVs and never unwraps.")
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None)
high_n_attr = getattr(high_poly, "normals", None)
low_n_attr = low_poly.normals
high_n_attr = high_poly.normals
B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0])
@ -1549,13 +1562,13 @@ class BakeAmbientOcclusion(IO.ComfyNode):
@classmethod
def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
low_uvs = getattr(low_poly, "uvs", None)
low_uvs = low_poly.uvs
if low_uvs is None:
raise ValueError(
"BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
"(the same one used for the other bakes); this node never unwraps.")
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None)
low_n_attr = low_poly.normals
B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0])
@ -1630,7 +1643,7 @@ class SetMeshMaterial(IO.ComfyNode):
base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
normal_scale, occlusion_strength, double_sided, emissive_texture=None):
out_mesh = copy.copy(mesh)
material = dict(getattr(mesh, "material", {}) or {}) # merge over any prior material
material = dict(mesh.material or {}) # merge over any prior material
material.update({
"emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)],
"emissive_strength": float(emissive_strength),
@ -2201,7 +2214,8 @@ class DecimateMesh(IO.ComfyNode):
if rc is not None:
c = rc.to(src_device)
except Exception as e:
logging.warning(f"DecimateMesh: QEM simplify failed, passing mesh through unchanged: {e!r}")
comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
logging.warning(f"DecimateMesh: QEM simplify ran out of memory, passing mesh through unchanged: {e!r}")
counts["out"] += int(f.shape[0])
return v, f, c
@ -2318,7 +2332,8 @@ class RemeshMesh(IO.ComfyNode):
f = rf.to(src_device)
c = rc.to(src_device) if rc is not None else None
except Exception as e:
logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}")
comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
logging.warning(f"RemeshMesh: remesh ran out of memory, passing mesh through unchanged: {e!r}")
counts["out"] += int(f.shape[0])
return v, f, c
@ -2409,9 +2424,9 @@ def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance
if segmenter == "pec":
if mesh.faces.device.type != "cuda":
raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
face_chart = _uv_seg.cluster_charts_pec(mesh, target_chart_count=0, max_cost=1.0)
face_chart = _uv_seg.cluster_charts_pec(mesh, max_cost=1.0)
elif segmenter == "adaptive":
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0, target_chart_count=0)
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0)
else:
raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
@ -2534,12 +2549,12 @@ class UnwrapMesh(IO.ComfyNode):
if is_list or is_batched:
vi, fi = mesh.vertices[i], mesh.faces[i]
ci = None
vc = getattr(mesh, "vertex_colors", None)
vc = mesh.vertex_colors
if vc is not None:
ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc
else:
vi, fi = mesh.vertices, mesh.faces
ci = getattr(mesh, "vertex_colors", None)
ci = mesh.vertex_colors
src_device = vi.device
vnp = vi.detach().cpu().numpy().astype(np.float32)
@ -2561,7 +2576,7 @@ class UnwrapMesh(IO.ComfyNode):
bar.update(1)
out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
if getattr(mesh, "texture", None) is not None:
if mesh.texture is not None:
out_mesh.texture = mesh.texture
if cls.hidden.unique_id:
@ -2743,7 +2758,7 @@ class RenderUVAtlas(IO.ComfyNode):
@classmethod
def execute(cls, mesh, resolution):
uvs_t = getattr(mesh, "uvs", None)
uvs_t = mesh.uvs
if uvs_t is None:
raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
uvs_np = uvs_t.detach().cpu().numpy()
@ -2847,8 +2862,8 @@ def merge_meshes(meshes):
def _b0(t):
return t[0] if t.ndim == 3 else t
any_uvs = any(getattr(m, "uvs", None) is not None for m in meshes)
any_colors = any(getattr(m, "vertex_colors", None) is not None for m in meshes)
any_uvs = any(m.uvs is not None for m in meshes)
any_colors = any(m.vertex_colors is not None for m in meshes)
verts_list, faces_list, uvs_list, colors_list = [], [], [], []
texture = None
@ -2861,16 +2876,16 @@ def merge_meshes(meshes):
faces_list.append(f + offset)
offset += v.shape[0]
if any_uvs:
mu = getattr(m, "uvs", None)
mu = m.uvs
uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2)))
if any_colors:
mc = getattr(m, "vertex_colors", None)
mc = m.vertex_colors
if mc is not None:
c = _b0(mc).cpu()
else:
c = v.new_ones((v.shape[0], 3))
colors_list.append(c)
mt = getattr(m, "texture", None)
mt = m.texture
if mt is not None:
if texture is None:
texture = mt.cpu()

View File

@ -105,16 +105,31 @@ def get_mesh_batch_item(mesh, index):
return mesh.vertices[index], mesh.faces[index], colors, uvs, normals
def _smooth_vertex_normals(vertices_np, faces_np):
def _smooth_vertex_normals(vertices_np, faces_np, weld=True):
"""Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting.
Un-normalized face normals (the raw cross product) have magnitude 2*area, so
accumulating them onto their vertices yields an area-weighted average."""
accumulating them onto their vertices yields an area-weighted average. `weld` averages
across vertices that share a position UV-seam duplicates created by unwrapping so
both sides of a seam get one identical normal. Without it each side averages only its
own faces and a visible shading seam appears; welding matches the official, which
computes normals on the pre-split mesh and gathers them through the UV vmap."""
tris = vertices_np[faces_np] # (M, 3, 3)
face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64)
for k in range(3):
np.add.at(normals, faces_np[:, k], face_n)
if weld and vertices_np.shape[0]:
# Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
lo = vertices_np.min(0)
inv_tol = 1.0 / (max(float((vertices_np.max(0) - lo).max()), 1e-9) * 1e-5)
q = np.round((vertices_np - lo) * inv_tol).astype(np.int64)
_, group = np.unique(q, axis=0, return_inverse=True)
acc = np.zeros((int(group.max()) + 1, 3), dtype=np.float64)
for k in range(3):
np.add.at(acc, group[faces_np[:, k]], face_n)
normals = acc[group] # welded normal back to each vertex
else:
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64)
for k in range(3):
np.add.at(normals, faces_np[:, k], face_n)
lens = np.linalg.norm(normals, axis=1, keepdims=True)
normals /= np.where(lens > 1e-12, lens, 1.0)
return normals.astype(np.float32)