mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Fix normal smoothing and some cleanup
This commit is contained in:
parent
d635cc412d
commit
429b13f97c
@ -818,17 +818,32 @@ def _quality_checks_fused(
|
|||||||
return flip_out, skinny_out, link_out
|
return flip_out, skinny_out, link_out
|
||||||
|
|
||||||
|
|
||||||
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
|
def _compute_vertex_normals(verts: torch.Tensor, faces: torch.Tensor, weld: bool = True) -> torch.Tensor:
|
||||||
|
"""Area-weighted smooth vertex normals. `weld` averages face normals across vertices that
|
||||||
|
share a position (UV-seam duplicates from unwrapping) so both sides of a seam get one
|
||||||
|
identical normal — otherwise a visible shading seam appears in the exported GLB."""
|
||||||
if faces.numel() == 0:
|
if faces.numel() == 0:
|
||||||
return torch.zeros_like(verts)
|
return torch.zeros_like(verts)
|
||||||
faces_long = faces.to(torch.int64)
|
faces_long = faces.to(torch.int64)
|
||||||
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
|
i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2]
|
||||||
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
|
v0, v1, v2 = verts[i0], verts[i1], verts[i2]
|
||||||
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||||
vn = torch.zeros_like(verts)
|
if weld and verts.shape[0]:
|
||||||
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
|
# Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
|
||||||
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
|
lo = verts.min(0).values
|
||||||
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
|
inv_tol = 1.0 / (float((verts.max(0).values - lo).max().clamp_min(1e-9)) * 1e-5)
|
||||||
|
q = ((verts - lo) * inv_tol).round().to(torch.int64)
|
||||||
|
_, group = torch.unique(q, dim=0, return_inverse=True)
|
||||||
|
acc = torch.zeros((int(group.max()) + 1, 3), dtype=verts.dtype, device=verts.device)
|
||||||
|
acc.scatter_add_(0, group[i0].unsqueeze(-1).expand_as(fn), fn)
|
||||||
|
acc.scatter_add_(0, group[i1].unsqueeze(-1).expand_as(fn), fn)
|
||||||
|
acc.scatter_add_(0, group[i2].unsqueeze(-1).expand_as(fn), fn)
|
||||||
|
vn = acc[group]
|
||||||
|
else:
|
||||||
|
vn = torch.zeros_like(verts)
|
||||||
|
vn.scatter_add_(0, i0.unsqueeze(-1).expand_as(fn), fn)
|
||||||
|
vn.scatter_add_(0, i1.unsqueeze(-1).expand_as(fn), fn)
|
||||||
|
vn.scatter_add_(0, i2.unsqueeze(-1).expand_as(fn), fn)
|
||||||
return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6)
|
return torch.nn.functional.normalize(vn, p=2, dim=-1, eps=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,112 +28,6 @@ DEFAULT_MAX_COST = 2.0
|
|||||||
NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75°
|
NORMAL_DEVIATION_HARD_CUTOFF = 0.707 # ~75°
|
||||||
|
|
||||||
|
|
||||||
@njit(cache=True, fastmath=False)
|
|
||||||
def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
|
|
||||||
F = face_normal.shape[0]
|
|
||||||
raw = np.zeros(F, dtype=np.float32)
|
|
||||||
for f in range(F):
|
|
||||||
nx = face_normal[f, 0]
|
|
||||||
ny = face_normal[f, 1]
|
|
||||||
nz = face_normal[f, 2]
|
|
||||||
s = np.float32(0.0)
|
|
||||||
for e in range(3):
|
|
||||||
nb = face_face[f, e]
|
|
||||||
if nb < 0:
|
|
||||||
continue
|
|
||||||
mx = face_normal[nb, 0]
|
|
||||||
my = face_normal[nb, 1]
|
|
||||||
mz = face_normal[nb, 2]
|
|
||||||
d = nx*mx + ny*my + nz*mz
|
|
||||||
s += np.float32(1.0) - d
|
|
||||||
raw[f] = s
|
|
||||||
return raw
|
|
||||||
|
|
||||||
|
|
||||||
def _face_curvature_numpy(face_normal: np.ndarray, face_face: np.ndarray) -> np.ndarray:
|
|
||||||
nb_safe = np.maximum(face_face, 0)
|
|
||||||
nb_normal = face_normal[nb_safe]
|
|
||||||
d = (face_normal[:, None, :] * nb_normal).sum(axis=-1)
|
|
||||||
contrib = np.where(face_face >= 0, np.float32(1.0) - d, np.float32(0.0))
|
|
||||||
return contrib.sum(axis=1).astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
@njit(cache=True, fastmath=False)
|
|
||||||
def _farthest_point_seeds_jit(
|
|
||||||
face_centroid: np.ndarray, face_area: np.ndarray, face_weight: np.ndarray,
|
|
||||||
initial_seeds: np.ndarray, k_target: int,
|
|
||||||
):
|
|
||||||
F = face_centroid.shape[0]
|
|
||||||
INF = np.float32(1e30)
|
|
||||||
min_dist = np.full(F, INF, dtype=np.float32)
|
|
||||||
seeds = np.empty(k_target, dtype=np.int64)
|
|
||||||
n_seeds = 0
|
|
||||||
for i in range(initial_seeds.shape[0]):
|
|
||||||
s = initial_seeds[i]
|
|
||||||
if s < 0 or n_seeds >= k_target:
|
|
||||||
continue
|
|
||||||
seeds[n_seeds] = s
|
|
||||||
n_seeds += 1
|
|
||||||
sx = face_centroid[s, 0]
|
|
||||||
sy = face_centroid[s, 1]
|
|
||||||
sz = face_centroid[s, 2]
|
|
||||||
for f in range(F):
|
|
||||||
dx = face_centroid[f, 0] - sx
|
|
||||||
dy = face_centroid[f, 1] - sy
|
|
||||||
dz = face_centroid[f, 2] - sz
|
|
||||||
d2 = dx*dx + dy*dy + dz*dz
|
|
||||||
if d2 < min_dist[f]:
|
|
||||||
min_dist[f] = d2
|
|
||||||
while n_seeds < k_target:
|
|
||||||
best_f = -1
|
|
||||||
best_score = np.float32(-1.0)
|
|
||||||
for f in range(F):
|
|
||||||
d = min_dist[f]
|
|
||||||
if d >= INF * np.float32(0.5):
|
|
||||||
continue
|
|
||||||
score = d * face_weight[f]
|
|
||||||
if score > best_score:
|
|
||||||
best_score = score
|
|
||||||
best_f = f
|
|
||||||
if best_f < 0:
|
|
||||||
break
|
|
||||||
seeds[n_seeds] = best_f
|
|
||||||
n_seeds += 1
|
|
||||||
sx = face_centroid[best_f, 0]
|
|
||||||
sy = face_centroid[best_f, 1]
|
|
||||||
sz = face_centroid[best_f, 2]
|
|
||||||
for f in range(F):
|
|
||||||
dx = face_centroid[f, 0] - sx
|
|
||||||
dy = face_centroid[f, 1] - sy
|
|
||||||
dz = face_centroid[f, 2] - sz
|
|
||||||
d2 = dx*dx + dy*dy + dz*dz
|
|
||||||
if d2 < min_dist[f]:
|
|
||||||
min_dist[f] = d2
|
|
||||||
return seeds[:n_seeds]
|
|
||||||
|
|
||||||
|
|
||||||
def _farthest_point_seeds_numpy(
|
|
||||||
face_centroid: np.ndarray, initial_seeds: np.ndarray, k_target: int,
|
|
||||||
):
|
|
||||||
F = face_centroid.shape[0]
|
|
||||||
min_dist = np.full(F, np.inf, dtype=np.float32)
|
|
||||||
seeds: List[int] = []
|
|
||||||
for s in initial_seeds:
|
|
||||||
if s < 0 or len(seeds) >= k_target:
|
|
||||||
continue
|
|
||||||
seeds.append(int(s))
|
|
||||||
d = ((face_centroid - face_centroid[s])**2).sum(axis=-1)
|
|
||||||
min_dist = np.minimum(min_dist, d)
|
|
||||||
while len(seeds) < k_target:
|
|
||||||
best = int(np.argmax(min_dist))
|
|
||||||
if not np.isfinite(min_dist[best]) or min_dist[best] <= 0:
|
|
||||||
break
|
|
||||||
seeds.append(best)
|
|
||||||
d = ((face_centroid - face_centroid[best])**2).sum(axis=-1)
|
|
||||||
min_dist = np.minimum(min_dist, d)
|
|
||||||
return np.asarray(seeds, dtype=np.int64)
|
|
||||||
|
|
||||||
|
|
||||||
@njit(cache=True, fastmath=False)
|
@njit(cache=True, fastmath=False)
|
||||||
def _cost_grow_iter_jit(
|
def _cost_grow_iter_jit(
|
||||||
face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
|
face_chart: np.ndarray, face_face: np.ndarray, face_normal: np.ndarray,
|
||||||
@ -259,15 +153,14 @@ def _renumber(face_chart: np.ndarray, device) -> Tensor:
|
|||||||
return torch.from_numpy(out).to(device)
|
return torch.from_numpy(out).to(device)
|
||||||
|
|
||||||
|
|
||||||
def _segment_charts_fast(
|
def segment_charts(
|
||||||
mesh: MeshData,
|
mesh: MeshData,
|
||||||
max_cost: float,
|
max_cost: float = DEFAULT_MAX_COST,
|
||||||
w_normal_deviation: float,
|
w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
|
||||||
w_roundness: float = DEFAULT_W_ROUNDNESS,
|
w_roundness: float = DEFAULT_W_ROUNDNESS,
|
||||||
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
|
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
|
||||||
target_chart_count: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Parallel batch cost-grow; target_chart_count 0 = adaptive seeding, >0 = K curvature-weighted FPS seeds."""
|
"""Segment mesh into charts (parallel batch cost-grow). Returns face -> chart_id."""
|
||||||
F = mesh.faces.shape[0]
|
F = mesh.faces.shape[0]
|
||||||
device = mesh.faces.device
|
device = mesh.faces.device
|
||||||
if F == 0:
|
if F == 0:
|
||||||
@ -291,37 +184,9 @@ def _segment_charts_fast(
|
|||||||
else:
|
else:
|
||||||
initial_seeds = np.empty(0, dtype=np.int64)
|
initial_seeds = np.empty(0, dtype=np.int64)
|
||||||
|
|
||||||
adaptive_seeding = target_chart_count <= 0
|
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
|
||||||
if adaptive_seeding:
|
if not seed_faces:
|
||||||
seed_faces: List[int] = [int(s) for s in initial_seeds.tolist()]
|
seed_faces = [0]
|
||||||
if not seed_faces:
|
|
||||||
seed_faces = [0]
|
|
||||||
else:
|
|
||||||
if _HAVE_NUMBA:
|
|
||||||
curvature_raw = _face_curvature_jit(face_normal, face_face)
|
|
||||||
else:
|
|
||||||
curvature_raw = _face_curvature_numpy(face_normal, face_face)
|
|
||||||
cmax = float(curvature_raw.max()) if curvature_raw.size else 0.0
|
|
||||||
if cmax > 1e-6:
|
|
||||||
face_weight = (np.float32(1.0) + np.float32(50.0) *
|
|
||||||
(curvature_raw / np.float32(cmax))).astype(np.float32)
|
|
||||||
else:
|
|
||||||
face_weight = np.ones(F, dtype=np.float32)
|
|
||||||
n_comp = int(initial_seeds.size)
|
|
||||||
if n_comp < int(target_chart_count):
|
|
||||||
target_seeds = int(target_chart_count)
|
|
||||||
else:
|
|
||||||
target_seeds = n_comp + max(int(target_chart_count) // 4, 8)
|
|
||||||
target_seeds = min(target_seeds, F)
|
|
||||||
if _HAVE_NUMBA:
|
|
||||||
seeds_arr = _farthest_point_seeds_jit(
|
|
||||||
face_centroid, face_area, face_weight, initial_seeds, target_seeds,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
seeds_arr = _farthest_point_seeds_numpy(
|
|
||||||
face_centroid, initial_seeds, target_seeds,
|
|
||||||
)
|
|
||||||
seed_faces = [int(s) for s in seeds_arr.tolist()]
|
|
||||||
|
|
||||||
K = len(seed_faces)
|
K = len(seed_faces)
|
||||||
chart_basis = np.zeros((K, 3), dtype=np.float32)
|
chart_basis = np.zeros((K, 3), dtype=np.float32)
|
||||||
@ -345,10 +210,9 @@ def _segment_charts_fast(
|
|||||||
return _renumber(face_chart, device)
|
return _renumber(face_chart, device)
|
||||||
|
|
||||||
min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
|
min_dist_to_seed = np.full(F, np.inf, dtype=np.float32)
|
||||||
if adaptive_seeding:
|
for sf in seed_faces:
|
||||||
for sf in seed_faces:
|
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
|
||||||
d = ((face_centroid - face_centroid[sf]) ** 2).sum(axis=-1)
|
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
|
||||||
min_dist_to_seed = np.minimum(min_dist_to_seed, d)
|
|
||||||
|
|
||||||
if _HAVE_NUMBA:
|
if _HAVE_NUMBA:
|
||||||
# Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
|
# Multi-pass threshold schedule (low-cost first); tau cap 0.5 keeps cones ~30deg.
|
||||||
@ -375,8 +239,6 @@ def _segment_charts_fast(
|
|||||||
break
|
break
|
||||||
if (face_chart == -1).sum() == 0:
|
if (face_chart == -1).sum() == 0:
|
||||||
break
|
break
|
||||||
if not adaptive_seeding:
|
|
||||||
break
|
|
||||||
if chart_basis.shape[0] >= max_total_charts:
|
if chart_basis.shape[0] >= max_total_charts:
|
||||||
break
|
break
|
||||||
unassigned_mask = face_chart == -1
|
unassigned_mask = face_chart == -1
|
||||||
@ -462,24 +324,6 @@ def _segment_charts_fast(
|
|||||||
return _renumber(face_chart, device)
|
return _renumber(face_chart, device)
|
||||||
|
|
||||||
|
|
||||||
def segment_charts(
|
|
||||||
mesh: MeshData,
|
|
||||||
max_cost: float = DEFAULT_MAX_COST,
|
|
||||||
w_normal_deviation: float = DEFAULT_W_NORMAL_DEVIATION,
|
|
||||||
w_roundness: float = DEFAULT_W_ROUNDNESS,
|
|
||||||
w_straightness: float = DEFAULT_W_STRAIGHTNESS,
|
|
||||||
target_chart_count: int = 0,
|
|
||||||
) -> Tensor:
|
|
||||||
"""Segment mesh into charts. Returns face -> chart_id."""
|
|
||||||
return _segment_charts_fast(
|
|
||||||
mesh, max_cost=max_cost,
|
|
||||||
w_normal_deviation=w_normal_deviation,
|
|
||||||
w_roundness=w_roundness,
|
|
||||||
w_straightness=w_straightness,
|
|
||||||
target_chart_count=target_chart_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
|
# ---- Parallel edge-collapse (PEC) chart clustering (CUDA) ----
|
||||||
def _combine_normal_cones(
|
def _combine_normal_cones(
|
||||||
axis_a: Tensor, half_a: Tensor,
|
axis_a: Tensor, half_a: Tensor,
|
||||||
@ -558,10 +402,7 @@ def _build_chart_edges(
|
|||||||
|
|
||||||
def cluster_charts_pec(
|
def cluster_charts_pec(
|
||||||
mesh: MeshData,
|
mesh: MeshData,
|
||||||
target_chart_count: int = 0,
|
|
||||||
max_cost: float = 0.7,
|
max_cost: float = 0.7,
|
||||||
area_penalty_weight: float = 0.0,
|
|
||||||
roundness_weight: float = 0.0,
|
|
||||||
max_iters: int = 1024,
|
max_iters: int = 1024,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg)."""
|
"""Parallel edge-collapse clustering; returns face_chart [F]. max_cost is the per-merge cutoff (~0.7 rad ~ 40deg)."""
|
||||||
@ -570,7 +411,6 @@ def cluster_charts_pec(
|
|||||||
faces = mesh.faces.to(torch.long)
|
faces = mesh.faces.to(torch.long)
|
||||||
vertices = mesh.vertices.to(torch.float32)
|
vertices = mesh.vertices.to(torch.float32)
|
||||||
face_normal = mesh.face_normal.to(torch.float32)
|
face_normal = mesh.face_normal.to(torch.float32)
|
||||||
face_area = mesh.face_area.to(torch.float32)
|
|
||||||
face_face = mesh.face_face.to(torch.long)
|
face_face = mesh.face_face.to(torch.long)
|
||||||
|
|
||||||
face_edge_len = face_edge_lengths(vertices, faces)
|
face_edge_len = face_edge_lengths(vertices, faces)
|
||||||
@ -578,11 +418,9 @@ def cluster_charts_pec(
|
|||||||
chart_id = torch.arange(F, dtype=torch.long, device=device)
|
chart_id = torch.arange(F, dtype=torch.long, device=device)
|
||||||
chart_axis = face_normal.clone()
|
chart_axis = face_normal.clone()
|
||||||
chart_half = torch.zeros(F, dtype=torch.float32, device=device)
|
chart_half = torch.zeros(F, dtype=torch.float32, device=device)
|
||||||
chart_area = face_area.clone()
|
|
||||||
chart_perim = face_edge_len.sum(dim=1).clone()
|
|
||||||
|
|
||||||
for it in range(max_iters):
|
for it in range(max_iters):
|
||||||
edges, edge_len = _build_chart_edges(face_face, chart_id, face_edge_len)
|
edges, _ = _build_chart_edges(face_face, chart_id, face_edge_len)
|
||||||
if edges.shape[0] == 0:
|
if edges.shape[0] == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -594,13 +432,6 @@ def cluster_charts_pec(
|
|||||||
half_b = chart_half[b]
|
half_b = chart_half[b]
|
||||||
_, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
|
_, new_half, _ = _combine_normal_cones(axis_a, half_a, axis_b, half_b)
|
||||||
cost = new_half.clone()
|
cost = new_half.clone()
|
||||||
if area_penalty_weight > 0.0:
|
|
||||||
new_area = chart_area[a] + chart_area[b]
|
|
||||||
cost = cost + area_penalty_weight * new_area
|
|
||||||
if roundness_weight > 0.0:
|
|
||||||
new_area_r = chart_area[a] + chart_area[b]
|
|
||||||
new_perim_r = chart_perim[a] + chart_perim[b] - 2.0 * edge_len
|
|
||||||
cost = cost + roundness_weight * (new_perim_r * new_perim_r) / new_area_r.clamp_min(1e-12)
|
|
||||||
|
|
||||||
# Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
|
# Pack (cost, edge_id) so scatter_reduce amin picks the right edge.
|
||||||
E = edges.shape[0]
|
E = edges.shape[0]
|
||||||
@ -612,11 +443,12 @@ def cluster_charts_pec(
|
|||||||
chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
|
chart_min.scatter_reduce_(0, a, key, reduce="amin", include_self=True)
|
||||||
chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
|
chart_min.scatter_reduce_(0, b, key, reduce="amin", include_self=True)
|
||||||
|
|
||||||
# Mutual-min collapse: each chart in at most one merge per iter.
|
# Mutual-min collapse: each chart in at most one merge per iter (winners are disjoint pairs).
|
||||||
is_a_min = chart_min[a] == key
|
is_a_min = chart_min[a] == key
|
||||||
is_b_min = chart_min[b] == key
|
is_b_min = chart_min[b] == key
|
||||||
|
mutual = is_a_min & is_b_min
|
||||||
within = cost <= max_cost
|
within = cost <= max_cost
|
||||||
winners = is_a_min & is_b_min & within
|
winners = mutual & within
|
||||||
|
|
||||||
n_merge = int(winners.sum().item())
|
n_merge = int(winners.sum().item())
|
||||||
if n_merge == 0:
|
if n_merge == 0:
|
||||||
@ -624,7 +456,6 @@ def cluster_charts_pec(
|
|||||||
|
|
||||||
win_a = a[winners]
|
win_a = a[winners]
|
||||||
win_b = b[winners]
|
win_b = b[winners]
|
||||||
win_el = edge_len[winners]
|
|
||||||
|
|
||||||
axis_a_w = chart_axis[win_a]
|
axis_a_w = chart_axis[win_a]
|
||||||
half_a_w = chart_half[win_a]
|
half_a_w = chart_half[win_a]
|
||||||
@ -635,8 +466,6 @@ def cluster_charts_pec(
|
|||||||
)
|
)
|
||||||
chart_axis[win_a] = new_axis
|
chart_axis[win_a] = new_axis
|
||||||
chart_half[win_a] = new_half_w
|
chart_half[win_a] = new_half_w
|
||||||
chart_area[win_a] = chart_area[win_a] + chart_area[win_b]
|
|
||||||
chart_perim[win_a] = chart_perim[win_a] + chart_perim[win_b] - 2.0 * win_el
|
|
||||||
|
|
||||||
remap = torch.arange(N, dtype=torch.long, device=device)
|
remap = torch.arange(N, dtype=torch.long, device=device)
|
||||||
remap[win_b] = win_a
|
remap[win_b] = win_a
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||||
import copy
|
import copy
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -18,6 +18,9 @@ from tqdm import tqdm
|
|||||||
from scipy.sparse import csr_matrix
|
from scipy.sparse import csr_matrix
|
||||||
from scipy.sparse.csgraph import connected_components
|
from scipy.sparse.csgraph import connected_components
|
||||||
from scipy.spatial import cKDTree
|
from scipy.spatial import cKDTree
|
||||||
|
import scipy.ndimage as ndi
|
||||||
|
|
||||||
|
MeshCameras = io.Custom("MESH_CAMERAS") # carries the camera set from RenderMeshViews → BakeViewsToTexture
|
||||||
|
|
||||||
def get_mesh_batch_item(mesh, index):
|
def get_mesh_batch_item(mesh, index):
|
||||||
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
|
if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None:
|
||||||
@ -460,7 +463,8 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
|
|||||||
try:
|
try:
|
||||||
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
|
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU")
|
comfy.model_management.raise_non_oom(e) # only fall back on OOM; surface real errors
|
||||||
|
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear ran out of memory ({e}); falling back to CPU")
|
||||||
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
|
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
|
||||||
if not ok.all():
|
if not ok.all():
|
||||||
vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour
|
vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour
|
||||||
@ -657,25 +661,46 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False):
|
|||||||
return bestp
|
return bestp
|
||||||
|
|
||||||
|
|
||||||
def _back_project_positions(position_map, mask, ref_v, ref_f):
|
def _back_project_positions(position_map, mask, ref_v, ref_f, max_query_res=1024):
|
||||||
"""Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no
|
"""Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no
|
||||||
cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat
|
cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat
|
||||||
triangle chords. Returns a new position_map."""
|
triangle chords. Returns a new position_map.
|
||||||
valid = np.ascontiguousarray(position_map[mask].astype(np.float32))
|
"""
|
||||||
if valid.shape[0] == 0:
|
if not mask.any():
|
||||||
return position_map
|
return position_map
|
||||||
|
|
||||||
dev = comfy.model_management.get_torch_device()
|
dev = comfy.model_management.get_torch_device()
|
||||||
rv = ref_v.detach().to(dev).float()
|
rv = ref_v.detach().to(dev).float()
|
||||||
rf = ref_f.detach().to(dev).long()
|
rf = ref_f.detach().to(dev).long()
|
||||||
tri = rv[rf]
|
tri = rv[rf]
|
||||||
Q = torch.from_numpy(valid).to(dev)
|
|
||||||
|
|
||||||
bvh = _build_triangle_bvh(tri)
|
bvh = _build_triangle_bvh(tri)
|
||||||
bp = _closest_points_on_mesh_bvh(Q, tri, bvh)
|
|
||||||
|
|
||||||
|
def _closest(pts_np):
|
||||||
|
return _closest_points_on_mesh_bvh(
|
||||||
|
torch.from_numpy(np.ascontiguousarray(pts_np.astype(np.float32))).to(dev), tri, bvh
|
||||||
|
).detach().cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
H, W, _ = position_map.shape
|
||||||
|
stride = max(1, int(math.ceil(max(H, W) / float(max_query_res))))
|
||||||
|
if stride == 1 or not mask[::stride, ::stride].any():
|
||||||
|
out = position_map.copy()
|
||||||
|
out[mask] = _closest(position_map[mask]).astype(position_map.dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# Low-res correction, then bilinear upsample to full resolution.
|
||||||
|
pos_lo = position_map[::stride, ::stride]
|
||||||
|
mask_lo = mask[::stride, ::stride]
|
||||||
|
Hl, Wl = mask_lo.shape
|
||||||
|
corr_lo = np.zeros((Hl, Wl, 3), dtype=np.float32)
|
||||||
|
corr_lo[mask_lo] = _closest(pos_lo[mask_lo]) - pos_lo[mask_lo].astype(np.float32)
|
||||||
|
inds = ndi.distance_transform_edt(~mask_lo, return_distances=False, return_indices=True)
|
||||||
|
corr_lo = corr_lo[tuple(inds)] # extrapolate into gutter (nearest)
|
||||||
|
corr = torch.nn.functional.interpolate(
|
||||||
|
torch.from_numpy(np.ascontiguousarray(corr_lo)).permute(2, 0, 1)[None].to(dev),
|
||||||
|
size=(H, W), mode="bilinear", align_corners=False,
|
||||||
|
)[0].permute(1, 2, 0).cpu().numpy()
|
||||||
out = position_map.copy()
|
out = position_map.copy()
|
||||||
out[mask] = bp.detach().cpu().numpy().astype(position_map.dtype)
|
out[mask] = position_map[mask] + corr[mask]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -703,6 +728,7 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
|
|||||||
nmin, nmax = bvh['nmin'], bvh['nmax']
|
nmin, nmax = bvh['nmin'], bvh['nmax']
|
||||||
left, right, order = bvh['left'], bvh['right'], bvh['order']
|
left, right, order = bvh['left'], bvh['right'], bvh['order']
|
||||||
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
|
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
|
||||||
|
tmaxN = tmax if torch.is_tensor(tmax) else torch.full((N,), float(tmax), device=dev) # per-ray far bound
|
||||||
hit = torch.zeros(N, dtype=torch.bool, device=dev)
|
hit = torch.zeros(N, dtype=torch.bool, device=dev)
|
||||||
# int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
|
# int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
|
||||||
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
|
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
|
||||||
@ -710,24 +736,24 @@ def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
|
|||||||
stack[:, 0] = 0
|
stack[:, 0] = 0
|
||||||
active = torch.arange(N, device=dev)
|
active = torch.arange(N, device=dev)
|
||||||
|
|
||||||
def slab(node, o, i):
|
def slab(node, o, i, tmx):
|
||||||
t1 = (nmin[node] - o) * i
|
t1 = (nmin[node] - o) * i
|
||||||
t2 = (nmax[node] - o) * i
|
t2 = (nmax[node] - o) * i
|
||||||
tnear = torch.minimum(t1, t2).amax(-1)
|
tnear = torch.minimum(t1, t2).amax(-1)
|
||||||
tfar = torch.maximum(t1, t2).amin(-1)
|
tfar = torch.maximum(t1, t2).amin(-1)
|
||||||
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin)
|
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmx) & (tfar >= tmin)
|
||||||
|
|
||||||
while active.numel() > 0:
|
while active.numel() > 0:
|
||||||
a = active
|
a = active
|
||||||
node = stack[a, sp[a] - 1]
|
node = stack[a, sp[a] - 1]
|
||||||
sp[a] = sp[a] - 1
|
sp[a] = sp[a] - 1
|
||||||
within = slab(node, orig[a], inv[a])
|
within = slab(node, orig[a], inv[a], tmaxN[a])
|
||||||
isleaf = node >= LEAF
|
isleaf = node >= LEAF
|
||||||
lv = within & isleaf
|
lv = within & isleaf
|
||||||
if bool(lv.any()):
|
if bool(lv.any()):
|
||||||
ga = a[lv]
|
ga = a[lv]
|
||||||
tt = tri[order[node[lv] - LEAF]]
|
tt = tri[order[node[lv] - LEAF]]
|
||||||
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax)
|
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmaxN[ga])
|
||||||
hit[ga[h]] = True
|
hit[ga[h]] = True
|
||||||
iv = within & ~isleaf
|
iv = within & ~isleaf
|
||||||
if bool(iv.any()):
|
if bool(iv.any()):
|
||||||
@ -838,7 +864,7 @@ def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n
|
|||||||
# memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks.
|
# memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks.
|
||||||
try:
|
try:
|
||||||
free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
|
free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
|
||||||
except Exception:
|
except RuntimeError:
|
||||||
free = 2 << 30
|
free = 2 << 30
|
||||||
ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200))))
|
ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200))))
|
||||||
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
|
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
|
||||||
@ -1059,11 +1085,8 @@ def _jfa_fill_gpu(img01, mask):
|
|||||||
return filled.cpu().numpy()
|
return filled.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
def _seam_fill(img01, mask, inpaint_radius):
|
def _seam_fill(img01, mask):
|
||||||
"""Fill UV-gutter texels (so seams don't pull in black) via JFA. `inpaint_radius<=0`
|
"""Fill UV-gutter texels (so seams don't pull in black) via JFA nearest-coverage."""
|
||||||
disables; the radius value itself is ignored (JFA fills all uncovered by nearest)."""
|
|
||||||
if inpaint_radius <= 0:
|
|
||||||
return img01
|
|
||||||
return _jfa_fill_gpu(img01, mask)
|
return _jfa_fill_gpu(img01, mask)
|
||||||
|
|
||||||
|
|
||||||
@ -1092,7 +1115,7 @@ def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
|
|||||||
|
|
||||||
|
|
||||||
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
||||||
resolution, texture_size, uvs, inpaint_radius=3,
|
resolution, texture_size, uvs,
|
||||||
normalize_uvs=True, reference=None, pbar=None):
|
normalize_uvs=True, reference=None, pbar=None):
|
||||||
"""Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space,
|
"""Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space,
|
||||||
sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout,
|
sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout,
|
||||||
@ -1109,7 +1132,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
|||||||
|
|
||||||
v_np = vertices.detach().cpu().numpy().astype(np.float32)
|
v_np = vertices.detach().cpu().numpy().astype(np.float32)
|
||||||
f_np = faces.detach().cpu().numpy().astype(np.uint32)
|
f_np = faces.detach().cpu().numpy().astype(np.uint32)
|
||||||
fcount = int(f_np.shape[0])
|
|
||||||
|
|
||||||
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
|
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
|
||||||
if uv_np.shape[0] != v_np.shape[0]:
|
if uv_np.shape[0] != v_np.shape[0]:
|
||||||
@ -1117,13 +1139,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
|||||||
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
|
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
|
||||||
f"with vertices ({v_np.shape[0]})."
|
f"with vertices ({v_np.shape[0]})."
|
||||||
)
|
)
|
||||||
uv_min = uv_np.min(axis=0)
|
|
||||||
uv_max = uv_np.max(axis=0)
|
|
||||||
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
|
|
||||||
f"{fcount} faces")
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
|
|
||||||
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
|
|
||||||
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
|
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
|
||||||
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
||||||
|
|
||||||
@ -1151,12 +1166,12 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
|||||||
roughness = attrs[..., 4:5] if C >= 5 else None
|
roughness = attrs[..., 4:5] if C >= 5 else None
|
||||||
# alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode).
|
# alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode).
|
||||||
|
|
||||||
base_color = _seam_fill(np.ascontiguousarray(base_color), mask, inpaint_radius)
|
base_color = _seam_fill(np.ascontiguousarray(base_color), mask)
|
||||||
mr_image = None
|
mr_image = None
|
||||||
if has_pbr:
|
if has_pbr:
|
||||||
# glTF metallicRoughness: R unused, G=roughness, B=metallic.
|
# glTF metallicRoughness: R unused, G=roughness, B=metallic.
|
||||||
mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1)
|
mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1)
|
||||||
mr_image = _seam_fill(np.ascontiguousarray(mr), mask, inpaint_radius)
|
mr_image = _seam_fill(np.ascontiguousarray(mr), mask)
|
||||||
|
|
||||||
device = vertices.device
|
device = vertices.device
|
||||||
out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32)
|
out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32)
|
||||||
@ -1195,8 +1210,8 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
IO.Voxel.Input("voxel_colors"),
|
IO.Voxel.Input("voxel_colors"),
|
||||||
IO.Int.Input("texture_size", default=1024, min=64, max=8192,
|
IO.Int.Input("texture_size", default=2048, min=64, max=8192,
|
||||||
tooltip="Square texture resolution."),
|
tooltip="Square UV atlas resolution."),
|
||||||
IO.Mesh.Input("reference_mesh", optional=True,
|
IO.Mesh.Input("reference_mesh", optional=True,
|
||||||
tooltip=(
|
tooltip=(
|
||||||
"Optional dense pre-decimation mesh; back-projects each texel onto its "
|
"Optional dense pre-decimation mesh; back-projects each texel onto its "
|
||||||
@ -1211,13 +1226,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None):
|
def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None):
|
||||||
# Matches official to_glb; effectively on/off since the gutter fill ignores the value.
|
|
||||||
inpaint_radius = 3
|
|
||||||
voxels = voxel_colors
|
voxels = voxel_colors
|
||||||
coords = voxels.data
|
coords = voxels.data
|
||||||
colors = voxels.voxel_colors
|
colors = voxels.voxel_colors
|
||||||
resolution = voxels.resolution
|
resolution = voxels.resolution
|
||||||
mesh_uvs = getattr(mesh, "uvs", None)
|
mesh_uvs = mesh.uvs
|
||||||
if mesh_uvs is None:
|
if mesh_uvs is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the "
|
"BakeTextureFromVoxel: input mesh has no UVs. This node bakes onto the "
|
||||||
@ -1249,7 +1262,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||||
v_i, f_i, item_coords, item_colors,
|
v_i, f_i, item_coords, item_colors,
|
||||||
resolution=resolution, texture_size=texture_size,
|
resolution=resolution, texture_size=texture_size,
|
||||||
uvs=ev_i, inpaint_radius=inpaint_radius,
|
uvs=ev_i,
|
||||||
reference=ref_i, pbar=pbar,
|
reference=ref_i, pbar=pbar,
|
||||||
)
|
)
|
||||||
out_tex.append(bt)
|
out_tex.append(bt)
|
||||||
@ -1275,7 +1288,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||||
v0, f0, coords, colors,
|
v0, f0, coords, colors,
|
||||||
resolution=resolution, texture_size=texture_size,
|
resolution=resolution, texture_size=texture_size,
|
||||||
uvs=ev0, inpaint_radius=inpaint_radius,
|
uvs=ev0,
|
||||||
reference=ref0, pbar=pbar,
|
reference=ref0, pbar=pbar,
|
||||||
)
|
)
|
||||||
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
|
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
|
||||||
@ -1317,9 +1330,9 @@ class MeshTextureToImage(IO.ComfyNode):
|
|||||||
t = t.unsqueeze(0)
|
t = t.unsqueeze(0)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
base = _as_image(getattr(mesh, "texture", None))
|
base = _as_image(mesh.texture)
|
||||||
mr = _as_image(getattr(mesh, "metallic_roughness", None))
|
mr = _as_image(mesh.metallic_roughness)
|
||||||
normal_map = _as_image(getattr(mesh, "normal_map", None))
|
normal_map = _as_image(mesh.normal_map)
|
||||||
|
|
||||||
if base is None:
|
if base is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1335,7 +1348,7 @@ class MeshTextureToImage(IO.ComfyNode):
|
|||||||
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
|
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
|
||||||
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as
|
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as
|
||||||
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
|
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
|
||||||
if getattr(mesh, "occlusion_in_mr", False):
|
if mesh.occlusion_in_mr:
|
||||||
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
|
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
|
||||||
else:
|
else:
|
||||||
occlusion = torch.ones_like(base)
|
occlusion = torch.ones_like(base)
|
||||||
@ -1365,7 +1378,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
|
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
|
||||||
mesh_uvs = getattr(mesh, "uvs", None)
|
mesh_uvs = mesh.uvs
|
||||||
if mesh_uvs is None:
|
if mesh_uvs is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
|
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
|
||||||
@ -1391,7 +1404,7 @@ class ApplyTextureToMesh(IO.ComfyNode):
|
|||||||
# and export the smooth normals the TBN was built on — without a NORMAL attribute the
|
# and export the smooth normals the TBN was built on — without a NORMAL attribute the
|
||||||
# viewer shades flat and the tangent-space detail fights the faceting.
|
# viewer shades flat and the tangent-space detail fights the faceting.
|
||||||
dev = comfy.model_management.get_torch_device()
|
dev = comfy.model_management.get_torch_device()
|
||||||
low_n_attr = getattr(mesh, "normals", None)
|
low_n_attr = mesh.normals
|
||||||
B = int(mesh.vertices.shape[0])
|
B = int(mesh.vertices.shape[0])
|
||||||
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
|
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
|
||||||
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
|
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
|
||||||
@ -1467,7 +1480,7 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
|
def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
|
||||||
low_uvs = getattr(low_poly, "uvs", None)
|
low_uvs = low_poly.uvs
|
||||||
if low_uvs is None:
|
if low_uvs is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
|
"BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
|
||||||
@ -1475,8 +1488,8 @@ class BakeNormalMapFromMesh(IO.ComfyNode):
|
|||||||
"onto existing UVs and never unwraps.")
|
"onto existing UVs and never unwraps.")
|
||||||
dev = comfy.model_management.get_torch_device()
|
dev = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
low_n_attr = getattr(low_poly, "normals", None)
|
low_n_attr = low_poly.normals
|
||||||
high_n_attr = getattr(high_poly, "normals", None)
|
high_n_attr = high_poly.normals
|
||||||
B = int(low_poly.vertices.shape[0])
|
B = int(low_poly.vertices.shape[0])
|
||||||
h_batch = int(high_poly.vertices.shape[0])
|
h_batch = int(high_poly.vertices.shape[0])
|
||||||
|
|
||||||
@ -1549,13 +1562,13 @@ class BakeAmbientOcclusion(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
|
def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
|
||||||
low_uvs = getattr(low_poly, "uvs", None)
|
low_uvs = low_poly.uvs
|
||||||
if low_uvs is None:
|
if low_uvs is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
|
"BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
|
||||||
"(the same one used for the other bakes); this node never unwraps.")
|
"(the same one used for the other bakes); this node never unwraps.")
|
||||||
dev = comfy.model_management.get_torch_device()
|
dev = comfy.model_management.get_torch_device()
|
||||||
low_n_attr = getattr(low_poly, "normals", None)
|
low_n_attr = low_poly.normals
|
||||||
B = int(low_poly.vertices.shape[0])
|
B = int(low_poly.vertices.shape[0])
|
||||||
h_batch = int(high_poly.vertices.shape[0])
|
h_batch = int(high_poly.vertices.shape[0])
|
||||||
|
|
||||||
@ -1630,7 +1643,7 @@ class SetMeshMaterial(IO.ComfyNode):
|
|||||||
base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
|
base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
|
||||||
normal_scale, occlusion_strength, double_sided, emissive_texture=None):
|
normal_scale, occlusion_strength, double_sided, emissive_texture=None):
|
||||||
out_mesh = copy.copy(mesh)
|
out_mesh = copy.copy(mesh)
|
||||||
material = dict(getattr(mesh, "material", {}) or {}) # merge over any prior material
|
material = dict(mesh.material or {}) # merge over any prior material
|
||||||
material.update({
|
material.update({
|
||||||
"emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)],
|
"emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)],
|
||||||
"emissive_strength": float(emissive_strength),
|
"emissive_strength": float(emissive_strength),
|
||||||
@ -2201,7 +2214,8 @@ class DecimateMesh(IO.ComfyNode):
|
|||||||
if rc is not None:
|
if rc is not None:
|
||||||
c = rc.to(src_device)
|
c = rc.to(src_device)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"DecimateMesh: QEM simplify failed, passing mesh through unchanged: {e!r}")
|
comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
|
||||||
|
logging.warning(f"DecimateMesh: QEM simplify ran out of memory, passing mesh through unchanged: {e!r}")
|
||||||
counts["out"] += int(f.shape[0])
|
counts["out"] += int(f.shape[0])
|
||||||
return v, f, c
|
return v, f, c
|
||||||
|
|
||||||
@ -2318,7 +2332,8 @@ class RemeshMesh(IO.ComfyNode):
|
|||||||
f = rf.to(src_device)
|
f = rf.to(src_device)
|
||||||
c = rc.to(src_device) if rc is not None else None
|
c = rc.to(src_device) if rc is not None else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"RemeshMesh: remesh failed, passing mesh through unchanged: {e!r}")
|
comfy.model_management.raise_non_oom(e) # surface real errors; only OOM passes through
|
||||||
|
logging.warning(f"RemeshMesh: remesh ran out of memory, passing mesh through unchanged: {e!r}")
|
||||||
counts["out"] += int(f.shape[0])
|
counts["out"] += int(f.shape[0])
|
||||||
return v, f, c
|
return v, f, c
|
||||||
|
|
||||||
@ -2409,9 +2424,9 @@ def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance
|
|||||||
if segmenter == "pec":
|
if segmenter == "pec":
|
||||||
if mesh.faces.device.type != "cuda":
|
if mesh.faces.device.type != "cuda":
|
||||||
raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
|
raise RuntimeError("segmenter='pec' requires a CUDA mesh; use 'adaptive' for CPU.")
|
||||||
face_chart = _uv_seg.cluster_charts_pec(mesh, target_chart_count=0, max_cost=1.0)
|
face_chart = _uv_seg.cluster_charts_pec(mesh, max_cost=1.0)
|
||||||
elif segmenter == "adaptive":
|
elif segmenter == "adaptive":
|
||||||
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0, target_chart_count=0)
|
face_chart = _uv_seg.segment_charts(mesh, max_cost=2.0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
|
raise ValueError(f"unknown segmenter '{segmenter}'. valid: pec, adaptive")
|
||||||
|
|
||||||
@ -2534,12 +2549,12 @@ class UnwrapMesh(IO.ComfyNode):
|
|||||||
if is_list or is_batched:
|
if is_list or is_batched:
|
||||||
vi, fi = mesh.vertices[i], mesh.faces[i]
|
vi, fi = mesh.vertices[i], mesh.faces[i]
|
||||||
ci = None
|
ci = None
|
||||||
vc = getattr(mesh, "vertex_colors", None)
|
vc = mesh.vertex_colors
|
||||||
if vc is not None:
|
if vc is not None:
|
||||||
ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc
|
ci = vc[i] if (isinstance(vc, list) or vc.ndim == 3) else vc
|
||||||
else:
|
else:
|
||||||
vi, fi = mesh.vertices, mesh.faces
|
vi, fi = mesh.vertices, mesh.faces
|
||||||
ci = getattr(mesh, "vertex_colors", None)
|
ci = mesh.vertex_colors
|
||||||
|
|
||||||
src_device = vi.device
|
src_device = vi.device
|
||||||
vnp = vi.detach().cpu().numpy().astype(np.float32)
|
vnp = vi.detach().cpu().numpy().astype(np.float32)
|
||||||
@ -2561,7 +2576,7 @@ class UnwrapMesh(IO.ComfyNode):
|
|||||||
bar.update(1)
|
bar.update(1)
|
||||||
|
|
||||||
out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
|
out_mesh = _pack_uv_meshes(out_v, out_f, out_uv, out_c if out_c else None)
|
||||||
if getattr(mesh, "texture", None) is not None:
|
if mesh.texture is not None:
|
||||||
out_mesh.texture = mesh.texture
|
out_mesh.texture = mesh.texture
|
||||||
|
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
@ -2743,7 +2758,7 @@ class RenderUVAtlas(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, resolution):
|
def execute(cls, mesh, resolution):
|
||||||
uvs_t = getattr(mesh, "uvs", None)
|
uvs_t = mesh.uvs
|
||||||
if uvs_t is None:
|
if uvs_t is None:
|
||||||
raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
|
raise RuntimeError("mesh has no UVs to render. Run UnwrapMesh first.")
|
||||||
uvs_np = uvs_t.detach().cpu().numpy()
|
uvs_np = uvs_t.detach().cpu().numpy()
|
||||||
@ -2847,8 +2862,8 @@ def merge_meshes(meshes):
|
|||||||
def _b0(t):
|
def _b0(t):
|
||||||
return t[0] if t.ndim == 3 else t
|
return t[0] if t.ndim == 3 else t
|
||||||
|
|
||||||
any_uvs = any(getattr(m, "uvs", None) is not None for m in meshes)
|
any_uvs = any(m.uvs is not None for m in meshes)
|
||||||
any_colors = any(getattr(m, "vertex_colors", None) is not None for m in meshes)
|
any_colors = any(m.vertex_colors is not None for m in meshes)
|
||||||
|
|
||||||
verts_list, faces_list, uvs_list, colors_list = [], [], [], []
|
verts_list, faces_list, uvs_list, colors_list = [], [], [], []
|
||||||
texture = None
|
texture = None
|
||||||
@ -2861,16 +2876,16 @@ def merge_meshes(meshes):
|
|||||||
faces_list.append(f + offset)
|
faces_list.append(f + offset)
|
||||||
offset += v.shape[0]
|
offset += v.shape[0]
|
||||||
if any_uvs:
|
if any_uvs:
|
||||||
mu = getattr(m, "uvs", None)
|
mu = m.uvs
|
||||||
uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2)))
|
uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2)))
|
||||||
if any_colors:
|
if any_colors:
|
||||||
mc = getattr(m, "vertex_colors", None)
|
mc = m.vertex_colors
|
||||||
if mc is not None:
|
if mc is not None:
|
||||||
c = _b0(mc).cpu()
|
c = _b0(mc).cpu()
|
||||||
else:
|
else:
|
||||||
c = v.new_ones((v.shape[0], 3))
|
c = v.new_ones((v.shape[0], 3))
|
||||||
colors_list.append(c)
|
colors_list.append(c)
|
||||||
mt = getattr(m, "texture", None)
|
mt = m.texture
|
||||||
if mt is not None:
|
if mt is not None:
|
||||||
if texture is None:
|
if texture is None:
|
||||||
texture = mt.cpu()
|
texture = mt.cpu()
|
||||||
|
|||||||
@ -105,16 +105,31 @@ def get_mesh_batch_item(mesh, index):
|
|||||||
return mesh.vertices[index], mesh.faces[index], colors, uvs, normals
|
return mesh.vertices[index], mesh.faces[index], colors, uvs, normals
|
||||||
|
|
||||||
|
|
||||||
def _smooth_vertex_normals(vertices_np, faces_np):
|
def _smooth_vertex_normals(vertices_np, faces_np, weld=True):
|
||||||
"""Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting.
|
"""Area-weighted per-vertex normals (unit length), fully smooth — no vertex splitting.
|
||||||
|
|
||||||
Un-normalized face normals (the raw cross product) have magnitude 2*area, so
|
Un-normalized face normals (the raw cross product) have magnitude 2*area, so
|
||||||
accumulating them onto their vertices yields an area-weighted average."""
|
accumulating them onto their vertices yields an area-weighted average. `weld` averages
|
||||||
|
across vertices that share a position — UV-seam duplicates created by unwrapping — so
|
||||||
|
both sides of a seam get one identical normal. Without it each side averages only its
|
||||||
|
own faces and a visible shading seam appears; welding matches the official, which
|
||||||
|
computes normals on the pre-split mesh and gathers them through the UV vmap."""
|
||||||
tris = vertices_np[faces_np] # (M, 3, 3)
|
tris = vertices_np[faces_np] # (M, 3, 3)
|
||||||
face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
|
face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
|
||||||
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64)
|
if weld and vertices_np.shape[0]:
|
||||||
for k in range(3):
|
# Group coincident positions (quantized to ~1e-5 of the bbox) into one shared normal.
|
||||||
np.add.at(normals, faces_np[:, k], face_n)
|
lo = vertices_np.min(0)
|
||||||
|
inv_tol = 1.0 / (max(float((vertices_np.max(0) - lo).max()), 1e-9) * 1e-5)
|
||||||
|
q = np.round((vertices_np - lo) * inv_tol).astype(np.int64)
|
||||||
|
_, group = np.unique(q, axis=0, return_inverse=True)
|
||||||
|
acc = np.zeros((int(group.max()) + 1, 3), dtype=np.float64)
|
||||||
|
for k in range(3):
|
||||||
|
np.add.at(acc, group[faces_np[:, k]], face_n)
|
||||||
|
normals = acc[group] # welded normal back to each vertex
|
||||||
|
else:
|
||||||
|
normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64)
|
||||||
|
for k in range(3):
|
||||||
|
np.add.at(normals, faces_np[:, k], face_n)
|
||||||
lens = np.linalg.norm(normals, axis=1, keepdims=True)
|
lens = np.linalg.norm(normals, axis=1, keepdims=True)
|
||||||
normals /= np.where(lens > 1e-12, lens, 1.0)
|
normals /= np.where(lens > 1e-12, lens, 1.0)
|
||||||
return normals.astype(np.float32)
|
return normals.astype(np.float32)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user