Normal and AO baking

This commit is contained in:
kijai 2026-06-30 01:18:33 +03:00
parent ab58d1b79f
commit 42ac23f6f6
3 changed files with 920 additions and 151 deletions

View File

@ -7,7 +7,7 @@ import copy
import comfy.utils
import comfy.model_management
from server import PromptServer
from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate
from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate, _compute_vertex_normals
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
@ -166,23 +166,23 @@ class PaintMesh(IO.ComfyNode):
return IO.NodeOutput(out_mesh)
def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
"""Rasterize the mesh in UV space and barycentric-interpolate the per-vertex vec3
(world position, or any vec3 attr e.g. normals) at each covered texel. Pure torch,
tiled point-in-triangle no GL/EGL, runs anywhere torch does. Returns (attr_map
[H,W,3] float32, mask [H,W] bool). """
def _rasterize_uv_barycentric(faces_np, uvs_np, texture_size):
"""Rasterize the mesh in UV space (tiled point-in-triangle, pure torch). Returns per-texel
face index [H,W], barycentric coords [H,W,3] and coverage mask [H,W], on the torch device.
Interpolate any per-vertex attribute from these with _interp_vertex_attr."""
dev = comfy.model_management.get_torch_device()
H = W = int(texture_size)
face_idx = torch.zeros((H, W), dtype=torch.long, device=dev)
bary = torch.zeros((H, W, 3), device=dev)
cov = torch.zeros((H, W), dtype=torch.bool, device=dev)
if faces_np.shape[0] == 0:
return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool)
return face_idx, bary, cov
verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev)
uvs = torch.from_numpy(np.ascontiguousarray(uvs_np, dtype=np.float32)).to(dev)
faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev)
# GL convention: window coord = uv * resolution, coverage tested at texel centre.
tri_uv = (uvs * float(W))[faces] # [F,3,2]
tri_attr = verts[faces] # [F,3,3]
x0, y0 = tri_uv[:, 0, 0], tri_uv[:, 0, 1]
x1, y1 = tri_uv[:, 1, 0], tri_uv[:, 1, 1]
x2, y2 = tri_uv[:, 2, 0], tri_uv[:, 2, 1]
@ -194,9 +194,6 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, H - 1).long()
ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, H - 1).long()
pos_out = torch.zeros((H, W, 3), device=dev)
cov = torch.zeros((H, W), dtype=torch.bool, device=dev)
# Tile so point-in-triangle only runs over the triangles whose bbox hits the tile.
TILE = 64
eps = 1e-6
@ -224,15 +221,40 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
continue
hit = inside.any(dim=0) # [th,tw]
sel = inside.int().argmax(dim=0) # [th,tw] first covering local tri
b0s = b0.gather(0, sel[None]).squeeze(0) # [th,tw] bary of selected tri
b1s = b1.gather(0, sel[None]).squeeze(0)
b2s = b2.gather(0, sel[None]).squeeze(0)
p = tri_attr[idx[sel]] # [th,tw,3,3]
attr = b0s[..., None] * p[..., 0, :] + b1s[..., None] * p[..., 1, :] + b2s[..., None] * p[..., 2, :]
pos_out[ty:ty_end, tx:tx_end][hit] = attr[hit] # slice is a view → writes through
bsel = torch.stack([b0.gather(0, sel[None]).squeeze(0),
b1.gather(0, sel[None]).squeeze(0),
b2.gather(0, sel[None]).squeeze(0)], dim=-1) # [th,tw,3]
face_idx[ty:ty_end, tx:tx_end][hit] = idx[sel][hit] # slice is a view → writes through
bary[ty:ty_end, tx:tx_end][hit] = bsel[hit]
cov[ty:ty_end, tx:tx_end] |= hit
return pos_out.cpu().numpy(), cov.cpu().numpy()
return face_idx, bary, cov
def _interp_vertex_attr(attr_v, faces, face_idx, bary, mask):
"""Interpolate a per-vertex attribute [N,C] into a [H,W,C] map via a rasterized
(face_idx, bary, mask). Uncovered texels stay zero."""
H, W = mask.shape
out = torch.zeros((H, W, attr_v.shape[1]), device=attr_v.device, dtype=attr_v.dtype)
if mask.any():
vtri = attr_v[faces[face_idx[mask]]] # [K,3,C]
out[mask] = (bary[mask][:, :, None] * vtri).sum(1)
return out
def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
"""Barycentric-interpolate a per-vertex vec3 (world position, or any vec3 e.g. normals)
at each covered texel. Returns (attr_map [H,W,3] float32, mask [H,W] bool)."""
dev = comfy.model_management.get_torch_device()
H = W = int(texture_size)
if faces_np.shape[0] == 0:
return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool)
face_idx, bary, mask = _rasterize_uv_barycentric(faces_np, uvs_np, texture_size)
verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev)
faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev)
attr = _interp_vertex_attr(verts, faces, face_idx, bary, mask)
return attr.cpu().numpy(), mask.cpu().numpy()
def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):
@ -565,9 +587,10 @@ def _build_triangle_bvh(tri):
return dict(LEAF=LEAF, left=left, right=right, nmin=nmin, nmax=nmax, order=order, T=T)
def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False):
"""Exact closest surface point per query via per-query BVH stack traversal
(nearest-child-first), pure torch. Returns [N,3]. `max_stack` bounds the stack
(nearest-child-first), pure torch. Returns [N,3], or (points [N,3], face_idx [N])
when return_face=True (face_idx indexes `tri`). `max_stack` bounds the stack
(= tree height); overflow is counted+warned, not silently wrong."""
dev = Q.device
N = Q.shape[0]
@ -582,6 +605,7 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
stack[:, 0] = 0
best = torch.full((N,), 1e30, device=dev)
bestp = Q.clone()
bestf = torch.full((N,), -1, dtype=torch.long, device=dev)
active = torch.arange(N, device=dev)
overflow = 0
@ -599,12 +623,14 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
lv = within & isleaf
if bool(lv.any()):
ga = a[lv]
tt = tri[order[node[lv] - LEAF]]
fidx = order[node[lv] - LEAF] # triangle index of each leaf
tt = tri[fidx]
cp, d2 = _point_tri_closest(qa[lv], tt)
upd = d2 < best[ga]
gu = ga[upd]
best[gu] = d2[upd]
bestp[gu] = cp[upd]
bestf[gu] = fidx[upd]
iv = within & ~isleaf
if bool(iv.any()):
gi = a[iv]
@ -627,6 +653,8 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
logging.warning(f"[back-project] BVH stack overflow on {overflow} pushes "
f"(max_stack={max_stack}); a few texels may be slightly off — "
f"raise max_stack if this is large.")
if return_face:
return bestp, bestf
return bestp
@ -652,6 +680,352 @@ def _back_project_positions(position_map, mask, ref_v, ref_f):
return out
def _ray_tri_hit(o, d, tri, tmin, tmax):
"""Möller-Trumbore any-hit per (ray, triangle) pair, double-sided. Returns bool [N]."""
a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
e1, e2 = b - a, c - a
p = torch.cross(d, e2, dim=-1)
det = (e1 * p).sum(-1)
inv = 1.0 / torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det)
tvec = o - a
u = (tvec * p).sum(-1) * inv
q = torch.cross(tvec, e1, dim=-1)
v = (d * q).sum(-1) * inv
t = (e2 * q).sum(-1) * inv
return (det.abs() > 1e-20) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > tmin) & (t < tmax)
def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
"""Any-hit ray test over the BVH (slab cull + Möller-Trumbore), pure torch. Returns bool
[N]: True if the ray hits any triangle in (tmin, tmax). Rays early-out once they hit."""
dev = orig.device
N = orig.shape[0]
LEAF = bvh['LEAF']
nmin, nmax = bvh['nmin'], bvh['nmax']
left, right, order = bvh['left'], bvh['right'], bvh['order']
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
hit = torch.zeros(N, dtype=torch.bool, device=dev)
# int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
sp = torch.ones(N, dtype=torch.long, device=dev)
stack[:, 0] = 0
active = torch.arange(N, device=dev)
def slab(node, o, i):
t1 = (nmin[node] - o) * i
t2 = (nmax[node] - o) * i
tnear = torch.minimum(t1, t2).amax(-1)
tfar = torch.maximum(t1, t2).amin(-1)
return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin)
while active.numel() > 0:
a = active
node = stack[a, sp[a] - 1]
sp[a] = sp[a] - 1
within = slab(node, orig[a], inv[a])
isleaf = node >= LEAF
lv = within & isleaf
if bool(lv.any()):
ga = a[lv]
tt = tri[order[node[lv] - LEAF]]
h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax)
hit[ga[h]] = True
iv = within & ~isleaf
if bool(iv.any()):
gi = a[iv]
s0 = sp[gi]
stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32)
sp[gi] = (s0 + 1).clamp(max=max_stack)
s1 = sp[gi]
stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32)
sp[gi] = (s1 + 1).clamp(max=max_stack)
active = a[(sp[a] > 0) & ~hit[a]] # drop finished + already-hit rays
return hit
def _ray_tri_intersect(o, d, tri, tmin, tmax, cull_backface=False):
"""Möller-Trumbore per (ray, triangle) pair. Returns (hit [N], t [N]) where t is the ray
parameter and hit means the meeting is in (tmin, tmax). With cull_backface, drops faces whose
outward (winding) normal points along the ray i.e. only keep surfaces facing the origin."""
a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
e1, e2 = b - a, c - a
p = torch.cross(d, e2, dim=-1)
det = (e1 * p).sum(-1)
inv = 1.0 / torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det)
tvec = o - a
u = (tvec * p).sum(-1) * inv
q = torch.cross(tvec, e1, dim=-1)
v = (d * q).sum(-1) * inv
t = (e2 * q).sum(-1) * inv
hit = (det.abs() > 1e-20) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > tmin) & (t < tmax)
if cull_backface:
hit = hit & ((torch.cross(e1, e2, dim=-1) * d).sum(-1) < 0) # keep only front-facing
return hit, t
def _closest_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64, cull_backface=False):
"""Nearest-hit ray cast over the BVH, pure torch. Returns (t [N], face [N] long, -1 on
miss; hit [N] bool) the closest intersection in (tmin, tmax), pruning nodes past best_t."""
dev = orig.device
N = orig.shape[0]
LEAF = bvh['LEAF']
nmin, nmax = bvh['nmin'], bvh['nmax']
left, right, order = bvh['left'], bvh['right'], bvh['order']
inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
best_t = torch.full((N,), float(tmax), device=dev)
best_f = torch.full((N,), -1, dtype=torch.long, device=dev)
stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
sp = torch.ones(N, dtype=torch.long, device=dev)
stack[:, 0] = 0
active = torch.arange(N, device=dev)
while active.numel() > 0:
a = active
node = stack[a, sp[a] - 1]
sp[a] = sp[a] - 1
t1 = (nmin[node] - orig[a]) * inv[a]
t2 = (nmax[node] - orig[a]) * inv[a]
tnear = torch.minimum(t1, t2).amax(-1)
tfar = torch.maximum(t1, t2).amin(-1)
within = (tfar >= tnear.clamp_min(tmin)) & (tfar >= tmin) & (tnear < best_t[a]) # prune past best
isleaf = node >= LEAF
lv = within & isleaf
if bool(lv.any()):
ga = a[lv]
fidx = order[node[lv] - LEAF]
h, t = _ray_tri_intersect(orig[ga], dirs[ga], tri[fidx], tmin, tmax, cull_backface)
upd = h & (t < best_t[ga])
gu = ga[upd]
best_t[gu] = t[upd]
best_f[gu] = fidx[upd]
iv = within & ~isleaf
if bool(iv.any()):
gi = a[iv]
s0 = sp[gi]
stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32)
sp[gi] = (s0 + 1).clamp(max=max_stack)
s1 = sp[gi]
stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32)
sp[gi] = (s1 + 1).clamp(max=max_stack)
active = a[sp[a] > 0]
return best_t, best_f, best_f >= 0
def _onb(n):
"""Branchless orthonormal basis (t, b) around unit normals n [N,3]."""
up = torch.where(n[..., 2:3].abs() < 0.999,
torch.tensor([0.0, 0.0, 1.0], device=n.device).expand_as(n),
torch.tensor([1.0, 0.0, 0.0], device=n.device).expand_as(n))
t = torch.nn.functional.normalize(torch.cross(up, n, dim=-1), dim=-1, eps=1e-6)
return t, torch.cross(n, t, dim=-1)
def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n, resolution,
num_samples=64, max_distance=0.5, strength=1.0, bias=0.01,
ray_chunk=None, pbar=None):
"""Bake high-poly ambient occlusion into the low-poly's UV layout: per texel, cosine-weight
a hemisphere of rays around the normal and cast them at the high-poly. AO = 1 - hit-fraction
(cosine weighting makes the hit-fraction the estimator). Returns ao_img [H,W,3] in [0,1].
ray_chunk caps rays cast at once (the per-chunk BVH stack is its dominant transient VRAM);
None auto-sizes it to a slice of free VRAM big chunks (fast) on large GPUs, small (safe)
on small ones."""
dev = comfy.model_management.get_torch_device()
H = W = int(resolution)
S = int(num_samples)
if ray_chunk is None:
# ~376 B/ray (int32 stack max_stack*4 + a few [N,3] ray buffers); spend a quarter of free
# VRAM. Speed saturates around 4M rays/chunk, so cap there (≈2 GB peak) rather than grow
# memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks.
try:
free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
except Exception:
free = 2 << 30
ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200))))
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
if not mask.any():
return np.ones((H, W, 3), dtype=np.float32)
lf = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(dev)
lv = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(dev)
low_n = low_n.to(dev).float()
m = mask
vtri = lf[face_idx[m]] # [K,3] vertex ids
bsel = bary_uv[m] # [K,3]
P = (bsel[:, :, None] * lv[vtri]).sum(1) # [K,3]
Nl = torch.nn.functional.normalize((bsel[:, :, None] * low_n[vtri]).sum(1), dim=-1, eps=1e-6)
hv = high_v.to(dev).float()
hf = high_f.to(dev).long()
tri = hv[hf]
bvh = _build_triangle_bvh(tri)
diag = float((hv.amax(0) - hv.amin(0)).norm().clamp_min(1e-6))
biasw = max(1e-5, float(bias) * diag)
tmax = float(max_distance) * diag
# Back-project onto the high surface, then lift along the normal: the low-poly chord can sit
# below the high surface, and casting from below floods false self-occlusion (dark blotches).
bp = _closest_points_on_mesh_bvh(P, tri, bvh)
origins = bp + Nl * biasw
K = P.shape[0]
T, B = _onb(Nl)
occ = torch.zeros(K, device=dev)
tex_per_chunk = max(1, int(ray_chunk) // max(1, S))
for s in range(0, K, tex_per_chunk):
e = min(s + tex_per_chunk, K)
kk = e - s
o, n, t, b = origins[s:e], Nl[s:e], T[s:e], B[s:e]
r1 = torch.rand(kk, S, device=dev)
r2 = torch.rand(kk, S, device=dev)
sr = r1.sqrt()
lz = r1.mul_(-1.0).add_(1.0).clamp_min_(0.0).sqrt_() # sqrt(1-r1) (r1 dead after sr)
ang = r2.mul_(2.0 * math.pi) # in place (r2 dead)
lx = ang.cos().mul_(sr)
ly = ang.sin().mul_(sr)
d = t[:, None, :] * lx[..., None] # cosine-weighted hemisphere,
d.addcmul_(b[:, None, :], ly[..., None]) # fused d += b*ly
d.addcmul_(n[:, None, :], lz[..., None]) # fused d += n*lz (no extra temps)
d = torch.nn.functional.normalize(d.reshape(-1, 3), dim=-1, eps=1e-6)
oo = o[:, None, :].expand(-1, S, -1).reshape(-1, 3)
hit = _any_hit_rays_bvh(oo, d, tri, bvh, tmin=biasw, tmax=tmax)
occ[s:e] = hit.reshape(kk, S).sum(1, dtype=torch.float32).div_(float(S)) # mean without a float copy
if pbar is not None:
pbar.update(1)
ao = occ.mul_(-float(strength)).add_(1.0).clamp_(0.0, 1.0) # 1 - occ*strength, in place (occ is dead)
out = torch.ones((H, W), device=dev)
out[m] = ao
out3 = np.repeat(out.cpu().numpy()[..., None], 3, axis=2)
return _jfa_fill_gpu(np.ascontiguousarray(out3, dtype=np.float32), mask.cpu().numpy())
def _compute_vertex_tangents(verts, faces, uvs, normals):
"""Per-vertex tangents (Lengyel) orthonormalized against `normals`. Returns [N,4]:
unit tangent xyz + handedness w (the bitangent is w * cross(N, T)). Pure torch."""
N = verts.shape[0]
i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long()
e1, e2 = verts[i1] - verts[i0], verts[i2] - verts[i0]
d1, d2 = uvs[i1] - uvs[i0], uvs[i2] - uvs[i0]
denom = d1[:, 0] * d2[:, 1] - d2[:, 0] * d1[:, 1]
r = 1.0 / torch.where(denom.abs() < 1e-20, torch.full_like(denom, 1e-20), denom)
tan = (d2[:, 1:2] * e1 - d1[:, 1:2] * e2) * r[:, None] # [F,3]
bit = (d1[:, 0:1] * e2 - d2[:, 0:1] * e1) * r[:, None]
tacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype)
bacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype)
for idx in (i0, i1, i2):
tacc.scatter_add_(0, idx[:, None].expand(-1, 3), tan)
bacc.scatter_add_(0, idx[:, None].expand(-1, 3), bit)
n = torch.nn.functional.normalize(normals, dim=-1, eps=1e-6)
# Gram-Schmidt: drop the normal component, then renormalize.
t = torch.nn.functional.normalize(tacc - n * (n * tacc).sum(-1, keepdim=True), dim=-1, eps=1e-6)
w = torch.sign((torch.cross(n, t, dim=-1) * bacc).sum(-1))
w = torch.where(w == 0, torch.ones_like(w), w) # degenerate → right-handed
return torch.cat([t, w[:, None]], dim=-1)
def _vertex_tangents_for_item(lv, lf, uv, low_n_attr_i, dev):
"""Per-item shading normals + tangents. Shared by the bake (BakeNormalMapFromMesh) and the
export attach (ApplyTextureToMesh) so their basis can't diverge. `low_n_attr_i` is the
mesh's per-item normals or None (then computed). Returns (low_n [N,3], tangents [N,4])."""
low_n = low_n_attr_i.to(dev).float() if low_n_attr_i is not None else _compute_vertex_normals(lv, lf)
tangents = _compute_vertex_tangents(lv, lf, uv.to(dev).float(), low_n)
return low_n, tangents
def _barycentric(p, tri):
"""Barycentric coords [N,3] of points p [N,3] wrt triangles tri [N,3,3] (per-pair)."""
a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
v0, v1, v2 = b - a, c - a, p - a
d00 = (v0 * v0).sum(-1)
d01 = (v0 * v1).sum(-1)
d11 = (v1 * v1).sum(-1)
d20 = (v2 * v0).sum(-1)
d21 = (v2 * v1).sum(-1)
denom = d00 * d11 - d01 * d01
denom = torch.where(denom.abs() < 1e-20, torch.full_like(denom, 1e-20), denom)
v = (d11 * d20 - d01 * d21) / denom
w = (d00 * d21 - d01 * d20) / denom
return torch.stack([1.0 - v - w, v, w], dim=-1)
def _bake_normal_map(high_v, high_f, high_n, low_v_np, low_f_np, low_uv_np, low_n, tangents,
resolution, cage_distance=0.05, ignore_backfaces=True):
"""Tangent-space normal map (glTF/OpenGL +Y) of the high-poly baked into the low-poly's UV
layout. Per texel a cage ray (along the normal, over cage_distance * bbox-diagonal) finds the
matching high-poly surface, whose normal is projected into the texel's TBN frame.
ignore_backfaces skips surfaces facing away (crevices/enclosures); misses fall back to
closest-point. Returns [H,W,3] in [0,1]."""
dev = comfy.model_management.get_torch_device()
H = W = int(resolution)
flat = np.array([0.5, 0.5, 1.0], dtype=np.float32)
# One rasterization, then interpolate position/normal/tangent/handedness by indexing it.
face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
if not mask.any():
return np.tile(flat, (H, W, 1))
lf = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(dev)
lv = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(dev)
low_n = low_n.to(dev).float()
tangents = tangents.to(dev).float()
m = mask
fsel = face_idx[m] # [K] source face per texel
bsel = bary_uv[m] # [K,3]
vtri = lf[fsel] # [K,3] vertex ids
def _interp(attr): # attr [N,C] -> [K,C]
return (bsel[:, :, None] * attr[vtri]).sum(1)
P = _interp(lv) # [K,3] world pos
Nl = torch.nn.functional.normalize(_interp(low_n), dim=-1, eps=1e-6)
Tl = _interp(tangents[:, :3])
Wl = _interp(tangents[:, 3:4])[:, 0]
hv = high_v.to(dev).float()
hf = high_f.to(dev).long()
tri = hv[hf]
bvh = _build_triangle_bvh(tri)
# Cage ray-cast: from cage outward, march back along -normal and take the nearest (outermost)
# hit. Closest-point is the fallback where the ray misses.
diag = float((hv.amax(0) - hv.amin(0)).norm().clamp_min(1e-6))
cage = max(1e-6, float(cage_distance) * diag)
origin = P + Nl * cage
t_hit, f_hit, ray_hit = _closest_hit_rays_bvh(origin, -Nl, tri, bvh, tmin=0.0, tmax=2.0 * cage,
cull_backface=bool(ignore_backfaces))
bface = f_hit.clamp_min(0)
hitpoint = origin - t_hit[:, None] * Nl
# Closest-point fallback only for texels the ray missed (usually few) — running it over every
# texel wastes a full BVH traversal on the ones already resolved by the ray.
miss = ~ray_hit
if bool(miss.any()):
bp_m, bf_m = _closest_points_on_mesh_bvh(P[miss], tri, bvh, return_face=True)
bface = bface.clone()
hitpoint = hitpoint.clone()
bface[miss] = bf_m.clamp_min(0)
hitpoint[miss] = bp_m
htri = tri[bface] # [K,3,3]
bary = _barycentric(hitpoint, htri)
hn_tri = high_n.to(dev).float()[hf[bface]] # [K,3,3] vertex normals
Nh = torch.nn.functional.normalize((bary[:, :, None] * hn_tri).sum(1), dim=-1, eps=1e-6)
# Per-texel TBN (Gram-Schmidt tangent against the interpolated normal).
T = torch.nn.functional.normalize(Tl - Nl * (Nl * Tl).sum(-1, keepdim=True), dim=-1, eps=1e-6)
Bn = Wl[:, None] * torch.cross(Nl, T, dim=-1)
nz = (Nh * Nl).sum(-1) # reused as z-channel and the back-face test
ts = torch.stack([(Nh * T).sum(-1), (Nh * Bn).sum(-1), nz], dim=-1)
ts = torch.nn.functional.normalize(ts, dim=-1, eps=1e-6)
# Safety net: if the matched high normal faces away from the texel (a back surface the fallback
# grabbed in a deep crevice), use the flat base normal rather than a wrong one.
ts[nz < 0.0] = torch.tensor([0.0, 0.0, 1.0], device=dev)
enc = ts.mul_(0.5).add_(0.5).clamp_(0.0, 1.0) # encode in place (ts is dead)
out = torch.from_numpy(np.tile(flat, (H, W, 1))).to(dev)
out[m] = enc
# Dilate into the UV gutter so bilinear/mip sampling at chart edges doesn't bleed flat blue.
return _jfa_fill_gpu(out.cpu().numpy(), mask.cpu().numpy())
def _jfa_fill_gpu(img01, mask):
"""Fill uncovered texels with nearest covered value via GPU Jump Flooding
(O(log n) passes; replaces cv2.inpaint). img01 [H,W,C] float, mask [H,W] bool."""
@ -919,15 +1293,17 @@ class MeshTextureToImage(IO.ComfyNode):
display_name="Mesh Texture to Image",
category="latent/3d",
description=(
"Extracts a mesh's baked textures as IMAGE outputs: base_color and the packed "
"glTF MR map (G=roughness, B=metallic; black if no PBR texture)."
"Extracts a mesh's baked textures as individual IMAGEs: base_color, metallic, "
"roughness, occlusion and normal_map. Channels with nothing baked come back "
"neutral (occlusion white, normal flat)."
),
inputs=[IO.Mesh.Input("mesh")],
outputs=[
IO.Image.Output(display_name="base_color"),
IO.Image.Output(display_name="metallic_roughness"),
IO.Image.Output(display_name="metallic"),
IO.Image.Output(display_name="roughness"),
IO.Image.Output(display_name="occlusion"),
IO.Image.Output(display_name="normal_map"),
],
)
@ -944,6 +1320,7 @@ class MeshTextureToImage(IO.ComfyNode):
base = _as_image(getattr(mesh, "texture", None))
mr = _as_image(getattr(mesh, "metallic_roughness", None))
normal_map = _as_image(getattr(mesh, "normal_map", None))
if base is None:
raise ValueError(
@ -952,10 +1329,18 @@ class MeshTextureToImage(IO.ComfyNode):
)
if mr is None:
mr = torch.zeros_like(base)
# Split packed MR into grayscale previews (G=roughness, B=metallic), to 3ch.
if normal_map is None:
normal_map = torch.ones_like(base) * torch.tensor([0.5, 0.5, 1.0]) # neutral flat normal
# Unpack the ORM map (R=occlusion, G=roughness, B=metallic) to 3-channel grayscale.
metallic = mr[..., 2:3].expand(-1, -1, -1, 3).contiguous()
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
return IO.NodeOutput(base, mr, metallic, roughness)
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
if getattr(mesh, "occlusion_in_mr", False):
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
else:
occlusion = torch.ones_like(base)
return IO.NodeOutput(base, metallic, roughness, occlusion, normal_map)
class ApplyTextureToMesh(IO.ComfyNode):
@ -966,29 +1351,26 @@ class ApplyTextureToMesh(IO.ComfyNode):
display_name="Apply Texture to Mesh",
category="latent/3d",
description=(
"Attaches baked texture IMAGEs to a mesh's existing UV layout for SaveGLB. "
"Pairs with BakeTextureFromVoxel: feed the SAME mesh and its base_color "
"(optionally metallic/roughness); don't re-unwrap in between. metallic/roughness "
"repack into the glTF MR map (G=roughness, B=metallic); missing metallic=0, "
"roughness=1."
"Attaches baked texture IMAGEs to a mesh's UV layout for SaveGLB. Feed the SAME mesh you baked"
),
inputs=[
IO.Mesh.Input("mesh"),
IO.Image.Input("base_color"),
IO.Image.Input("metallic", optional=True),
IO.Image.Input("roughness", optional=True),
IO.Image.Input("occlusion", optional=True),
IO.Image.Input("normal_map", optional=True),
],
outputs=[IO.Mesh.Output("mesh")],
)
@classmethod
def execute(cls, mesh, base_color, metallic=None, roughness=None):
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
mesh_uvs = getattr(mesh, "uvs", None)
if mesh_uvs is None:
raise ValueError(
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
"you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and "
"never unwraps).")
"you fed to BakeTextureFromVoxel.")
# Re-derive the exact UVs the bake used (shared _normalize_uvs_to_unit), per item.
if mesh_uvs.ndim == 3:
@ -1005,15 +1387,264 @@ class ApplyTextureToMesh(IO.ComfyNode):
out_mesh = copy.copy(mesh)
out_mesh.uvs = new_uvs
out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu()
if metallic is not None or roughness is not None:
# Repack glTF MR (G=roughness, B=metallic); missing channel → scalar (metal 0/rough 1).
prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu()
B, H, W, _ = prov.shape
rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1]
if roughness is not None else torch.ones((B, H, W, 1)))
metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1]
if metallic is not None else torch.zeros((B, H, W, 1)))
out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1)
if normal_map is not None:
# Recompute tangents (shared helper, same normalized UVs → same basis as the bake)
# and export the smooth normals the TBN was built on — without a NORMAL attribute the
# viewer shades flat and the tangent-space detail fights the faceting.
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(mesh, "normals", None)
B = int(mesh.vertices.shape[0])
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
normals_padded = torch.zeros((B, Nmax, 3), dtype=torch.float32)
for i in range(B):
v_i, f_i, _ = get_mesh_batch_item(mesh, i)
n = int(v_i.shape[0])
if f_i.numel() == 0:
continue
lv, lf = v_i.to(dev).float(), f_i.to(dev).long()
uv_i = new_uvs[i, :n] if new_uvs.ndim == 3 else new_uvs[:n]
n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None
low_n, tangents = _vertex_tangents_for_item(lv, lf, uv_i, n_attr_i, dev)
tangents_padded[i, :n] = tangents.cpu()
normals_padded[i, :n] = low_n.cpu()
out_mesh.normal_map = normal_map.float().clamp(0.0, 1.0).cpu()
out_mesh.tangents = tangents_padded
out_mesh.normals = normals_padded
if metallic is not None or roughness is not None or occlusion is not None:
# Pack glTF ORM (R=occlusion, G=roughness, B=metallic); missing → 1/1/0. Maps may
# arrive at different resolutions, so resize each channel to a common H×W first.
provided = [x for x in (metallic, roughness, occlusion) if x is not None]
B = int(provided[0].shape[0])
H = max(int(x.shape[1]) for x in provided)
W = max(int(x.shape[2]) for x in provided)
def _chan(img, default):
if img is None:
return torch.full((B, H, W, 1), float(default))
t = img.float().clamp(0.0, 1.0).cpu()[..., 0:1]
if int(t.shape[1]) != H or int(t.shape[2]) != W:
t = torch.nn.functional.interpolate(t.permute(0, 3, 1, 2), size=(H, W),
mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
return t
out_mesh.metallic_roughness = torch.cat(
[_chan(occlusion, 1.0), _chan(roughness, 1.0), _chan(metallic, 0.0)], dim=-1)
if occlusion is not None:
# Tells SaveGLB to also point occlusionTexture at the MR image (R = AO).
out_mesh.occlusion_in_mr = True
return IO.NodeOutput(out_mesh)
class BakeNormalMapFromMesh(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BakeNormalMapFromMesh",
display_name="Bake Normal Map from Mesh",
category="latent/3d",
description=(
"Bakes a tangent-space normal map (glTF/OpenGL +Y) from a high-poly mesh into a "
"low-poly's UV layout, capturing detail lost to decimation. Feed the UV-unwrapped "
"low_poly and the same-frame high_poly it was decimated from. Outputs an IMAGE for "
"ApplyTextureToMesh's normal_map input."
),
inputs=[
IO.Mesh.Input("low_poly"),
IO.Mesh.Input("high_poly"),
IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64,
display_name="resolution"),
IO.Float.Input("cage_distance", default=0.05, min=0.001, max=0.5, step=0.001,
tooltip="Surface search band, as a fraction of the bbox diagonal. "
"Raise for wrong/missing patches under heavy decimation; "
"lower if it grabs across gaps."),
IO.Boolean.Input("ignore_backfaces", default=True,
tooltip="Skip high-poly surfaces facing away from the texel, so "
"crevices/enclosed spaces don't grab the opposite wall. "
"Disable only if the high-poly winding is inconsistent."),
],
outputs=[IO.Image.Output(display_name="normal_map")],
)
@classmethod
def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
low_uvs = getattr(low_poly, "uvs", None)
if low_uvs is None:
raise ValueError(
"BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
"low-poly (the same one you fed to BakeTextureFromVoxel); this node bakes "
"onto existing UVs and never unwraps.")
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None)
high_n_attr = getattr(high_poly, "normals", None)
B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0])
imgs = []
for i in range(B):
v_i, f_i, _ = get_mesh_batch_item(low_poly, i)
n = int(v_i.shape[0])
if f_i.numel() == 0:
logging.warning(f"BakeNormalMapFromMesh: skipping batch {i} (empty mesh)")
imgs.append(torch.full((int(resolution), int(resolution), 3), 0.5))
continue
uv_i = low_uvs[i, :n] if low_uvs.ndim == 3 else low_uvs[:n]
uv_np = _normalize_uvs_to_unit(uv_i.detach().cpu().numpy(), log_prefix="[BakeNormalMapFromMesh] ")
lv = v_i.to(dev).float()
lf = f_i.to(dev).long()
# Tangents build the per-texel TBN; ApplyTextureToMesh recomputes the same basis on export.
n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None
low_n, tangents = _vertex_tangents_for_item(lv, lf, torch.from_numpy(uv_np).to(dev), n_attr_i, dev)
hv_i, hf_i, _ = get_mesh_batch_item(high_poly, i if h_batch > 1 else 0)
hv = hv_i.to(dev).float()
hf = hf_i.to(dev).long()
high_n = (high_n_attr[i, :hv.shape[0]].to(dev).float() if high_n_attr is not None
else _compute_vertex_normals(hv, hf))
img = _bake_normal_map(
hv, hf, high_n,
lv.detach().cpu().numpy(), lf.detach().cpu().numpy().astype(np.uint32), uv_np,
low_n, tangents, resolution, cage_distance=float(cage_distance),
ignore_backfaces=bool(ignore_backfaces),
)
imgs.append(torch.from_numpy(np.ascontiguousarray(img)).float())
normal_img = torch.stack([t.clamp(0.0, 1.0) for t in imgs], dim=0)
return IO.NodeOutput(normal_img)
class BakeAmbientOcclusion(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BakeAmbientOcclusion",
display_name="Bake Ambient Occlusion",
category="latent/3d",
description=(
"Bakes an ambient-occlusion map from a high-poly mesh into a low-poly's UV "
"layout (white = open, dark = crevices). Feed the UV-unwrapped low_poly and the "
"high_poly it was decimated from. Outputs a grayscale IMAGE for "
"ApplyTextureToMesh's occlusion input (packed into the ORM map / occlusionTexture)."
),
inputs=[
IO.Mesh.Input("low_poly"),
IO.Mesh.Input("high_poly"),
IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64),
IO.Int.Input("samples", default=64, min=4, max=1024, step=4,
tooltip="Rays per texel. More = smoother, slower. Raise if grainy."),
IO.Float.Input("max_distance", default=0.5, min=0.01, max=2.0, step=0.01,
tooltip="Ray length, as a fraction of the bbox diagonal. "
"Smaller = tighter, more local occlusion."),
IO.Float.Input("strength", default=1.0, min=0.0, max=2.0, step=0.05,
tooltip="Scales the occlusion. >1 darkens, <1 lightens."),
IO.Float.Input("bias", default=0.01, min=0.0001, max=0.2, step=0.0005,
tooltip="Ray origin lift off the surface, as a fraction of the bbox "
"diagonal. Raise if even surfaces show dark blotches/holes."),
],
outputs=[IO.Image.Output(display_name="occlusion")],
)
@classmethod
def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
low_uvs = getattr(low_poly, "uvs", None)
if low_uvs is None:
raise ValueError(
"BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
"(the same one used for the other bakes); this node never unwraps.")
dev = comfy.model_management.get_torch_device()
low_n_attr = getattr(low_poly, "normals", None)
B = int(low_poly.vertices.shape[0])
h_batch = int(high_poly.vertices.shape[0])
pbar = comfy.utils.ProgressBar(max(1, B)) # one tick per batch item
imgs = []
for i in range(B):
v_i, f_i, _ = get_mesh_batch_item(low_poly, i)
n = int(v_i.shape[0])
if f_i.numel() == 0:
logging.warning(f"BakeAmbientOcclusion: skipping batch {i} (empty mesh)")
imgs.append(torch.ones((int(resolution), int(resolution), 3)))
pbar.update(1)
continue
uv_i = low_uvs[i, :n] if low_uvs.ndim == 3 else low_uvs[:n]
uv_np = _normalize_uvs_to_unit(uv_i.detach().cpu().numpy(), log_prefix="[BakeAmbientOcclusion] ")
lv = v_i.to(dev).float()
lf = f_i.to(dev).long()
low_n = (low_n_attr[i, :n].to(dev).float() if low_n_attr is not None
else _compute_vertex_normals(lv, lf))
hv_i, hf_i, _ = get_mesh_batch_item(high_poly, i if h_batch > 1 else 0)
img = _bake_ambient_occlusion(
hv_i.to(dev).float(), hf_i.to(dev).long(),
lv.detach().cpu().numpy(), lf.detach().cpu().numpy().astype(np.uint32), uv_np,
low_n, resolution, num_samples=int(samples),
max_distance=float(max_distance), strength=float(strength), bias=float(bias),
)
imgs.append(torch.from_numpy(np.ascontiguousarray(img)).float())
pbar.update(1)
ao_img = torch.stack([t.clamp(0.0, 1.0) for t in imgs], dim=0)
return IO.NodeOutput(ao_img)
class SetMeshMaterial(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SetMeshMaterial",
display_name="Set Mesh Material",
category="latent/3d",
description=(
"Sets glTF material properties SaveGLB can't derive from textures: emissive "
"(color + strength + optional texture), baseColor tint, metallic/roughness "
"factors, normal scale, occlusion strength, double-sided. Place before SaveGLB."
),
inputs=[
IO.Mesh.Input("mesh"),
IO.Float.Input("emissive_r", default=0.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("emissive_g", default=0.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("emissive_b", default=0.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("emissive_strength", default=1.0, min=0.0, max=100.0, step=0.1,
tooltip=">1 for HDR glow (KHR_materials_emissive_strength)."),
IO.Image.Input("emissive_texture", optional=True),
IO.Float.Input("base_color_r", default=1.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("base_color_g", default=1.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("base_color_b", default=1.0, min=0.0, max=1.0, step=0.01),
IO.Float.Input("metallic_factor", default=-1.0, min=-1.0, max=1.0, step=0.01,
tooltip="-1 = leave auto; 0..1 overrides."),
IO.Float.Input("roughness_factor", default=-1.0, min=-1.0, max=1.0, step=0.01,
tooltip="-1 = leave auto; 0..1 overrides."),
IO.Float.Input("normal_scale", default=1.0, min=0.0, max=10.0, step=0.05),
IO.Float.Input("occlusion_strength", default=1.0, min=0.0, max=1.0, step=0.01),
IO.Boolean.Input("double_sided", default=True),
],
outputs=[IO.Mesh.Output("mesh")],
)
@classmethod
def execute(cls, mesh, emissive_r, emissive_g, emissive_b, emissive_strength,
base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
normal_scale, occlusion_strength, double_sided, emissive_texture=None):
out_mesh = copy.copy(mesh)
material = dict(getattr(mesh, "material", {}) or {}) # merge over any prior material
material.update({
"emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)],
"emissive_strength": float(emissive_strength),
"base_color_factor": [float(base_color_r), float(base_color_g), float(base_color_b), 1.0],
"metallic_factor": float(metallic_factor), # <0 => leave auto
"roughness_factor": float(roughness_factor),
"normal_scale": float(normal_scale),
"occlusion_strength": float(occlusion_strength),
"double_sided": bool(double_sided),
})
out_mesh.material = material
if emissive_texture is not None:
out_mesh.emissive = emissive_texture.float().clamp(0.0, 1.0).cpu()
return IO.NodeOutput(out_mesh)
@ -2278,8 +2909,7 @@ class MergeMeshes(IO.ComfyNode):
category="latent/3d",
description=(
"Concatenate N meshes into one by offsetting face indices and stacking verts, "
"faces, uvs, and colors. E.g. combine a Pixal3D object with a MoGe background "
"into one GLB."
"faces, uvs, and colors."
),
inputs=[
IO.Autogrow.Input("meshes", template=autogrow_template),
@ -2306,6 +2936,9 @@ class PostProcessMeshExtension(ComfyExtension):
BakeTextureFromVoxel,
MeshTextureToImage,
ApplyTextureToMesh,
BakeNormalMapFromMesh,
BakeAmbientOcclusion,
SetMeshMaterial,
MergeMeshes,
]

View File

@ -20,11 +20,11 @@ from comfy_api.latest import ComfyExtension, IO, Types
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False,
normals=None, metallic_roughness=None):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
# texture is (B, H, W, 3) — passed through unchanged
normals=None, metallic_roughness=None, tangents=None, normal_map=None,
occlusion_in_mr=False, material=None, emissive=None):
# Pack per-item tensors into padded batches, stashing per-item lengths as runtime attrs.
# colors/uvs/normals/tangents are 1:1 with vertices (padded to max_vertices); texture/
# metallic_roughness/normal_map are (B,H,W,*) image stacks passed through unchanged.
batch_size = len(vertices)
max_vertices = max(v.shape[0] for v in vertices)
max_faces = max(f.shape[0] for f in faces)
@ -65,11 +65,31 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
)
packed_normals[i, :nrm.shape[0]] = nrm
return Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
metallic_roughness=metallic_roughness,
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
normals=packed_normals)
packed_tangents = None
if tangents is not None:
packed_tangents = tangents[0].new_zeros((batch_size, max_vertices, tangents[0].shape[1]))
for i, tn in enumerate(tangents):
assert tn.shape[0] == vertices[i].shape[0], (
f"tangents[{i}] has {tn.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)"
)
packed_tangents[i, :tn.shape[0]] = tn
out = Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
metallic_roughness=metallic_roughness,
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
normals=packed_normals)
if packed_tangents is not None:
out.tangents = packed_tangents
if normal_map is not None:
out.normal_map = normal_map
if occlusion_in_mr:
out.occlusion_in_mr = True
if material is not None:
out.material = material
if emissive is not None:
out.emissive = emissive
return out
def get_mesh_batch_item(mesh, index):
@ -180,7 +200,8 @@ def _compute_vertex_normals(vertices_np, faces_np, crease_angle=None):
def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None,
metallic_roughness_image=None, unlit=False,
normals=None):
normals=None, normal_map_image=None, tangents=None, occlusion_in_mr=False,
material=None, emissive_image=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -197,6 +218,16 @@ def save_glb(vertices, faces, filepath, metadata=None,
normals: torch.Tensor of shape (N, 3) - Optional per-vertex normals, written as the
glTF NORMAL attribute. When omitted, NO normals are written and viewers fall back
to flat (per-face) shading use the MeshSmoothNormals node to generate them.
normal_map_image: PIL.Image - Optional tangent-space normal map (glTF/OpenGL +Y),
written as the material normalTexture. Needs TEXCOORD_0.
tangents: torch.Tensor of shape (N, 4) - Optional per-vertex tangents (xyz + handedness w),
written as the glTF TANGENT attribute. Without it viewers derive tangents in-shader.
occlusion_in_mr: bool - When True, R of metallic_roughness_image holds AO (ORM packing) and
occlusionTexture is pointed at that same image.
material: dict - Optional scalar overrides from SetMeshMaterial (base_color_factor,
metallic/roughness_factor with <0 = auto, emissive_factor/strength, normal_scale,
occlusion_strength, double_sided).
emissive_image: PIL.Image - Optional emissive (glow) texture, written as emissiveTexture.
"""
# Convert tensors to numpy arrays
@ -231,6 +262,11 @@ def save_glb(vertices, faces, filepath, metadata=None,
raise ValueError(
f"save_glb: normals has {normals_np.shape[0]} entries but vertex count is {n_verts}"
)
tangents_np = tangents.cpu().numpy().astype(np.float32) if tangents is not None else None
if tangents_np is not None and tangents_np.shape != (n_verts, 4):
raise ValueError(
f"save_glb: tangents must be (N, 4) with N={n_verts}, got {tuple(tangents_np.shape)}"
)
faces_np = faces_signed.astype(np.uint32)
texture_png_bytes = None
if texture_image is not None:
@ -242,46 +278,60 @@ def save_glb(vertices, faces, filepath, metadata=None,
buf = BytesIO()
metallic_roughness_image.save(buf, format="PNG")
mr_png_bytes = buf.getvalue()
nm_png_bytes = None
if normal_map_image is not None:
buf = BytesIO()
normal_map_image.save(buf, format="PNG")
nm_png_bytes = buf.getvalue()
em_png_bytes = None
if emissive_image is not None:
buf = BytesIO()
emissive_image.save(buf, format="PNG")
em_png_bytes = buf.getvalue()
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
normals_buffer = normals_np.tobytes() if normals_np is not None else b""
tangents_buffer = tangents_np.tobytes() if tangents_np is not None else b""
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
nm_buffer = nm_png_bytes if nm_png_bytes is not None else b""
em_buffer = em_png_bytes if em_png_bytes is not None else b""
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
normals_buffer_padded = pad_to_4_bytes(normals_buffer)
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
buffer_data = b"".join([
vertices_buffer_padded,
indices_buffer_padded,
uvs_buffer_padded,
colors_buffer_padded,
normals_buffer_padded,
texture_buffer_padded,
mr_buffer_padded,
])
# Blob order in one place; offsets accumulated in a pass so adding a buffer is one entry.
_blobs = [
("vertices", vertices_buffer), ("indices", indices_buffer), ("uvs", uvs_buffer),
("colors", colors_buffer), ("normals", normals_buffer), ("tangents", tangents_buffer),
("texture", texture_buffer), ("mr", mr_buffer), ("nm", nm_buffer), ("em", em_buffer),
]
byte_offset = {}
acc = 0
parts = []
for name, b in _blobs:
padded = pad_to_4_bytes(b)
byte_offset[name] = acc
acc += len(padded)
parts.append(padded)
buffer_data = b"".join(parts)
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
normals_byte_offset = colors_byte_offset + len(colors_buffer_padded)
texture_byte_offset = normals_byte_offset + len(normals_buffer_padded)
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
vertices_byte_offset = byte_offset["vertices"]
indices_byte_offset = byte_offset["indices"]
uvs_byte_offset = byte_offset["uvs"]
colors_byte_offset = byte_offset["colors"]
normals_byte_offset = byte_offset["normals"]
tangents_byte_offset = byte_offset["tangents"]
texture_byte_offset = byte_offset["texture"]
mr_byte_offset = byte_offset["mr"]
nm_byte_offset = byte_offset["nm"]
em_byte_offset = byte_offset["em"]
buffer_views = [
{
@ -368,6 +418,23 @@ def save_glb(vertices, faces, filepath, metadata=None,
})
primitive_attributes["NORMAL"] = accessor_idx
if tangents_np is not None and len(tangents_np) > 0:
buffer_views.append({
"buffer": 0,
"byteOffset": tangents_byte_offset,
"byteLength": len(tangents_buffer),
"target": 34962
})
accessor_idx = len(accessors)
accessors.append({
"bufferView": len(buffer_views) - 1,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(tangents_np),
"type": "VEC4", # xyz tangent + w handedness (glTF TANGENT)
})
primitive_attributes["TANGENT"] = accessor_idx
primitive = {
"attributes": primitive_attributes,
"indices": 1,
@ -379,9 +446,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
samplers = []
materials = []
extensions_used = []
def add_image_texture(png_byte_offset, png_byte_length):
"""Append an embedded PNG image + a texture referencing it; return the texture index."""
buffer_views.append({"buffer": 0, "byteOffset": png_byte_offset, "byteLength": png_byte_length})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
if not samplers:
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": len(images) - 1, "sampler": 0})
return len(textures) - 1
has_uv = "TEXCOORD_0" in primitive_attributes
if unlit and texture_png_bytes is None:
# Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a
# gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours.
if nm_png_bytes is not None or em_png_bytes is not None or occlusion_in_mr or material is not None:
logging.warning(
"save_glb: unlit material ignores normal/occlusion/emissive maps and SetMeshMaterial "
"overrides — those are PBR-lit features. Disable unlit to export them.")
materials.append({
"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0},
"extensions": {"KHR_materials_unlit": {}},
@ -395,37 +477,57 @@ def save_glb(vertices, faces, filepath, metadata=None,
"roughnessFactor": 0.5,
"baseColorFactor": [0.22, 0.22, 0.22, 1.0],
}
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
"byteOffset": texture_byte_offset,
"byteLength": len(texture_buffer),
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": len(images) - 1, "sampler": 0})
pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
if texture_png_bytes is not None and has_uv:
pbr["baseColorTexture"] = {"index": add_image_texture(texture_byte_offset, len(texture_buffer)), "texCoord": 0}
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
"byteOffset": mr_byte_offset,
"byteLength": len(mr_buffer),
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
if not samplers:
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": len(images) - 1, "sampler": 0})
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
if mr_png_bytes is not None and has_uv:
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
pbr["metallicRoughnessTexture"] = {"index": mr_texture_index, "texCoord": 0}
# When a metallicRoughness texture is present, the factors scale it; use 1.0
# so the texture values pass through unchanged (glTF convention).
pbr["metallicFactor"] = 1.0
pbr["roughnessFactor"] = 1.0
materials.append({
mat = material if isinstance(material, dict) else {}
# Scalar overrides from SetMeshMaterial (factor < 0 means "leave auto").
if mat.get("base_color_factor") is not None:
pbr["baseColorFactor"] = [float(x) for x in mat["base_color_factor"]]
if mat.get("metallic_factor", -1.0) >= 0.0:
pbr["metallicFactor"] = float(mat["metallic_factor"])
if mat.get("roughness_factor", -1.0) >= 0.0:
pbr["roughnessFactor"] = float(mat["roughness_factor"])
material = {
"pbrMetallicRoughness": pbr,
"doubleSided": True,
})
"doubleSided": bool(mat.get("double_sided", True)),
}
if occlusion_in_mr and mr_png_bytes is not None and has_uv:
# ORM packing: occlusionTexture reuses the MR image (glTF reads its R channel).
material["occlusionTexture"] = {"index": mr_texture_index, "texCoord": 0,
"strength": float(mat.get("occlusion_strength", 1.0))}
if nm_png_bytes is not None and has_uv:
material["normalTexture"] = {"index": add_image_texture(nm_byte_offset, len(nm_buffer)),
"texCoord": 0, "scale": float(mat.get("normal_scale", 1.0))}
emissive_factor = [float(x) for x in mat.get("emissive_factor", [0.0, 0.0, 0.0])]
emissive_strength = float(mat.get("emissive_strength", 1.0))
has_em_tex = em_png_bytes is not None and has_uv
if any(c > 0.0 for c in emissive_factor) or has_em_tex:
# glTF multiplies emissiveFactor × texture, so a texture with no color would go black;
# default the factor to white in that case.
if has_em_tex and not any(c > 0.0 for c in emissive_factor):
emissive_factor = [1.0, 1.0, 1.0]
material["emissiveFactor"] = [min(1.0, c) for c in emissive_factor]
if has_em_tex:
material["emissiveTexture"] = {"index": add_image_texture(em_byte_offset, len(em_buffer)),
"texCoord": 0}
if emissive_strength != 1.0:
material.setdefault("extensions", {})["KHR_materials_emissive_strength"] = {
"emissiveStrength": emissive_strength}
if "KHR_materials_emissive_strength" not in extensions_used:
extensions_used.append("KHR_materials_emissive_strength")
materials.append(material)
primitive["material"] = 0
gltf = {
@ -556,6 +658,22 @@ class SaveGLB(IO.ComfyNode):
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
)
nm_b = getattr(mesh, "normal_map", None)
nm_np = None
if nm_b is not None:
nm_np = (nm_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
assert nm_np.ndim == 4 and nm_np.shape[-1] == 3, (
f"normal_map must be (B, H, W, 3), got shape {tuple(nm_np.shape)}"
)
em_b = getattr(mesh, "emissive", None)
em_np = None
if em_b is not None:
em_np = (em_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
assert em_np.ndim == 4 and em_np.shape[-1] == 3, (
f"emissive must be (B, H, W, 3), got shape {tuple(em_np.shape)}"
)
tangents_b = getattr(mesh, "tangents", None)
material = getattr(mesh, "material", None)
for i in range(mesh.vertices.shape[0]):
vertices_i, faces_i, v_colors, uvs_i, normals_i = get_mesh_batch_item(mesh, i)
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
@ -563,6 +681,9 @@ class SaveGLB(IO.ComfyNode):
continue
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
nm_img = Image.fromarray(nm_np[i], mode="RGB") if nm_np is not None else None
em_img = Image.fromarray(em_np[i], mode="RGB") if em_np is not None else None
tangents_i = tangents_b[i, :vertices_i.shape[0]] if tangents_b is not None else None
f = f"{filename}_{counter:05}_.glb"
save_glb(
vertices_i, faces_i,
@ -574,6 +695,11 @@ class SaveGLB(IO.ComfyNode):
metallic_roughness_image=mr_img,
unlit=getattr(mesh, "unlit", False),
normals=normals_i,
normal_map_image=nm_img,
tangents=tangents_i,
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
material=material,
emissive_image=em_img,
)
results.append({
"filename": f,
@ -723,9 +849,11 @@ class MeshSmoothNormals(IO.ComfyNode):
return IO.NodeOutput(out)
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
tangents_b = getattr(mesh, "tangents", None)
v_list, f_list, n_list = [], [], []
c_list = [] if mesh.vertex_colors is not None else None
u_list = [] if mesh.uvs is not None else None
t_list = [] if tangents_b is not None else None
for i in range(batch_size):
v_i, f_i, c_i, u_i, _ = get_mesh_batch_item(mesh, i)
if v_i.shape[0] == 0 or f_i.shape[0] == 0:
@ -742,12 +870,19 @@ class MeshSmoothNormals(IO.ComfyNode):
c_list.append(c_i[remap_t.to(c_i.device)])
if u_list is not None:
u_list.append(u_i[remap_t.to(u_i.device)])
if t_list is not None:
# Remap (not recompute) so TANGENT keeps the baked basis; split verts copy theirs.
t_i = tangents_b[i, :v_i.shape[0]]
t_list.append(t_i[remap_t.to(t_i.device)])
if not v_list:
return IO.NodeOutput(mesh)
out = pack_variable_mesh_batch(
v_list, f_list, colors=c_list, uvs=u_list,
texture=mesh.texture, unlit=getattr(mesh, "unlit", False),
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None))
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None),
tangents=t_list, normal_map=getattr(mesh, "normal_map", None),
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None))
return IO.NodeOutput(out)

View File

@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
from server import PromptServer
import comfy.latent_formats
import comfy.model_management
import comfy.utils
@ -414,38 +415,44 @@ class Trellis2UpsampleStage(IO.ComfyNode):
"y_up" if proj_pack is not None else "z_up")}
return IO.NodeOutput(positive_out, negative_out, out_latent)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
"""Run DINOv3 once at the requested resolution.
def run_conditioning(model, cropped_img_tensor, include_1024=True):
image_bchw: [B, 3, H, W] float in [0, 1] (any source resolution; resized here).
Returns the full sequence tensor (Trellis2 path) or a dict with the global
tokens split out + a 2D patch grid (Pixal3D path) when `want_patches=True`.
"""
model_internal = model.model
device = comfy.model_management.get_torch_device()
img_t = comfy.utils.common_upscale(image_bchw, image_size, image_size, "lanczos", "disabled").to(device)
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
img_t = (img_t - mean) / std
model_internal.image_size = image_size
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
if not want_patches:
return tokens
h_p = w_p = image_size // 16
n_reg = tokens.shape[1] - 1 - h_p * w_p
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
def run_conditioning(model, cropped_pil_img, include_1024=True):
device = comfy.model_management.intermediate_device()
torch_device = comfy.model_management.get_torch_device()
def prepare_tensor(pil_img, size):
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
img_np = np.array(resized_pil).astype(np.float32) / 255.0
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = 512
input_512 = prepare_tensor(cropped_img_tensor, 512)
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
cond_1024 = None
if include_1024:
model_internal.image_size = 1024
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0
image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous()
cond_512 = _dinov3_encode(model, image_bchw, 512)
conditioning = {
'cond_512': cond_512.to(device),
'neg_cond': torch.zeros_like(cond_512).to(device),
"cond_512": cond_512.to(device),
"neg_cond": torch.zeros_like(cond_512).to(device),
}
if cond_1024 is not None:
conditioning['cond_1024'] = cond_1024.to(device)
if include_1024:
cond_1024 = _dinov3_encode(model, image_bchw, 1024)
conditioning["cond_1024"] = cond_1024.to(device)
return conditioning
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -780,21 +787,6 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
def _run_dinov3_with_patches(model, composite, image_size):
model_internal = model.model
torch_device = comfy.model_management.get_torch_device()
img_t = comfy.utils.common_upscale(composite, image_size, image_size, "lanczos", "disabled")
img_t = img_t.to(torch_device)
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = image_size
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
patches = _dinov3_patches_to_2d(tokens, image_size)
h_p = w_p = image_size // 16
n_reg = tokens.shape[1] - 1 - h_p * w_p
global_tokens = tokens[:, :1 + n_reg]
return {"tokens": global_tokens, "patches_2d": patches}
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
@ -802,6 +794,12 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
# Detect & correct an inverted mask
m2d = mask[0, 0]
border = torch.cat([m2d[0, :], m2d[-1, :], m2d[:, 0], m2d[:, -1]])
if float(border.mean()) > 0.5:
mask = 1.0 - mask
H, W = img.shape[-2:]
if max(H, W) > max_image_size:
scale = max_image_size / max(H, W)
@ -923,8 +921,8 @@ class Pixal3DConditioning(IO.ComfyNode):
scene_size_list.append(scene_size)
composite_list.append(composite)
cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512)
cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024)
cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True)
cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True)
cond_512_list.append(cond_512["tokens"].to(device))
cond_1024_list.append(cond_1024["tokens"].to(device))
patches_512_list.append(cond_512["patches_2d"].to(device))
@ -1104,8 +1102,7 @@ class Pixal3DAlignObject(IO.ComfyNode):
moge_per_vertex = moge_points[batch_index, sy, sx]
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
# frame as vertices_one (Pixal3D model frame = glTF Y-up).
moge_per_vertex = moge_per_vertex * torch.tensor(
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
)
@ -1188,6 +1185,7 @@ class GetMeshInfo(IO.ComfyNode):
IO.Mesh.Output(display_name="mesh"),
IO.String.Output(display_name="info"),
],
hidden=[IO.Hidden.unique_id],
)
@staticmethod
@ -1212,10 +1210,10 @@ class GetMeshInfo(IO.ComfyNode):
f_counts = [int(mesh.faces.shape[1])] * B
attrs = []
for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"):
for name in ("uvs", "vertex_colors", "normals", "tangents", "texture", "metallic_roughness", "normal_map"):
t = getattr(mesh, name, None)
if t is not None:
if name in ("texture", "metallic_roughness"):
if name in ("texture", "metallic_roughness", "normal_map"):
attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W
else:
attrs.append(name)
@ -1234,6 +1232,9 @@ class GetMeshInfo(IO.ComfyNode):
info = "\n".join(lines)
logging.info("[GetMeshInfo]\n%s", info)
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(info, cls.hidden.unique_id)
return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))