mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Normal and AO baking
This commit is contained in:
parent
ab58d1b79f
commit
42ac23f6f6
@ -7,7 +7,7 @@ import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from server import PromptServer
|
||||
from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate
|
||||
from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate, _compute_vertex_normals
|
||||
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest
|
||||
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
|
||||
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
|
||||
@ -166,23 +166,23 @@ class PaintMesh(IO.ComfyNode):
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
|
||||
def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
|
||||
"""Rasterize the mesh in UV space and barycentric-interpolate the per-vertex vec3
|
||||
(world position, or any vec3 attr e.g. normals) at each covered texel. Pure torch,
|
||||
tiled point-in-triangle — no GL/EGL, runs anywhere torch does. Returns (attr_map
|
||||
[H,W,3] float32, mask [H,W] bool). """
|
||||
def _rasterize_uv_barycentric(faces_np, uvs_np, texture_size):
|
||||
"""Rasterize the mesh in UV space (tiled point-in-triangle, pure torch). Returns per-texel
|
||||
face index [H,W], barycentric coords [H,W,3] and coverage mask [H,W], on the torch device.
|
||||
Interpolate any per-vertex attribute from these with _interp_vertex_attr."""
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
H = W = int(texture_size)
|
||||
face_idx = torch.zeros((H, W), dtype=torch.long, device=dev)
|
||||
bary = torch.zeros((H, W, 3), device=dev)
|
||||
cov = torch.zeros((H, W), dtype=torch.bool, device=dev)
|
||||
if faces_np.shape[0] == 0:
|
||||
return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool)
|
||||
return face_idx, bary, cov
|
||||
|
||||
verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev)
|
||||
uvs = torch.from_numpy(np.ascontiguousarray(uvs_np, dtype=np.float32)).to(dev)
|
||||
faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev)
|
||||
|
||||
# GL convention: window coord = uv * resolution, coverage tested at texel centre.
|
||||
tri_uv = (uvs * float(W))[faces] # [F,3,2]
|
||||
tri_attr = verts[faces] # [F,3,3]
|
||||
x0, y0 = tri_uv[:, 0, 0], tri_uv[:, 0, 1]
|
||||
x1, y1 = tri_uv[:, 1, 0], tri_uv[:, 1, 1]
|
||||
x2, y2 = tri_uv[:, 2, 0], tri_uv[:, 2, 1]
|
||||
@ -194,9 +194,6 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
|
||||
ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, H - 1).long()
|
||||
ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, H - 1).long()
|
||||
|
||||
pos_out = torch.zeros((H, W, 3), device=dev)
|
||||
cov = torch.zeros((H, W), dtype=torch.bool, device=dev)
|
||||
|
||||
# Tile so point-in-triangle only runs over the triangles whose bbox hits the tile.
|
||||
TILE = 64
|
||||
eps = 1e-6
|
||||
@ -224,15 +221,40 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
|
||||
continue
|
||||
hit = inside.any(dim=0) # [th,tw]
|
||||
sel = inside.int().argmax(dim=0) # [th,tw] first covering local tri
|
||||
b0s = b0.gather(0, sel[None]).squeeze(0) # [th,tw] bary of selected tri
|
||||
b1s = b1.gather(0, sel[None]).squeeze(0)
|
||||
b2s = b2.gather(0, sel[None]).squeeze(0)
|
||||
p = tri_attr[idx[sel]] # [th,tw,3,3]
|
||||
attr = b0s[..., None] * p[..., 0, :] + b1s[..., None] * p[..., 1, :] + b2s[..., None] * p[..., 2, :]
|
||||
pos_out[ty:ty_end, tx:tx_end][hit] = attr[hit] # slice is a view → writes through
|
||||
bsel = torch.stack([b0.gather(0, sel[None]).squeeze(0),
|
||||
b1.gather(0, sel[None]).squeeze(0),
|
||||
b2.gather(0, sel[None]).squeeze(0)], dim=-1) # [th,tw,3]
|
||||
face_idx[ty:ty_end, tx:tx_end][hit] = idx[sel][hit] # slice is a view → writes through
|
||||
bary[ty:ty_end, tx:tx_end][hit] = bsel[hit]
|
||||
cov[ty:ty_end, tx:tx_end] |= hit
|
||||
|
||||
return pos_out.cpu().numpy(), cov.cpu().numpy()
|
||||
return face_idx, bary, cov
|
||||
|
||||
|
||||
def _interp_vertex_attr(attr_v, faces, face_idx, bary, mask):
    """Interpolate a per-vertex attribute into a texel map from a UV rasterization.

    Args:
        attr_v: per-vertex attribute, [N,C].
        faces: triangle vertex indices, indexed by ``face_idx``.
        face_idx: covering face per texel, [H,W].
        bary: barycentric weights per texel, [H,W,3].
        mask: texel coverage, [H,W] bool.

    Returns:
        [H,W,C] tensor with attr_v's device and dtype. Uncovered texels stay zero.
    """
    H, W = mask.shape
    out = torch.zeros((H, W, attr_v.shape[1]), device=attr_v.device, dtype=attr_v.dtype)
    if mask.any():
        vtri = attr_v[faces[face_idx[mask]]]                 # [K,3,C]
        texels = (bary[mask][:, :, None] * vtri).sum(1)      # [K,C]
        # The weighted sum is promoted to the wider of bary/attr dtypes, while a masked
        # assignment requires source and destination dtypes to match — cast back explicitly
        # so non-float32 attributes (e.g. half) don't fail here.
        out[mask] = texels.to(out.dtype)
    return out
|
||||
|
||||
|
||||
def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
    """Bake a per-vertex vec3 (world position, or any vec3 such as normals) into UV space.

    The mesh is rasterized in UV space once and the attribute is barycentric-interpolated
    at every covered texel.

    Returns:
        (attr_map [H,W,3] float32 ndarray, mask [H,W] bool ndarray); both all-zero/False
        for a mesh without faces.
    """
    device = comfy.model_management.get_torch_device()
    side = int(texture_size)

    # Nothing to rasterize: hand back an empty map and an empty coverage mask.
    if faces_np.shape[0] == 0:
        empty_map = np.zeros((side, side, 3), dtype=np.float32)
        empty_mask = np.zeros((side, side), dtype=bool)
        return empty_map, empty_mask

    texel_face, texel_bary, covered = _rasterize_uv_barycentric(faces_np, uvs_np, texture_size)

    verts_t = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(device)
    faces_t = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(device)

    baked = _interp_vertex_attr(verts_t, faces_t, texel_face, texel_bary, covered)
    return baked.cpu().numpy(), covered.cpu().numpy()
|
||||
|
||||
|
||||
def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):
|
||||
@ -565,9 +587,10 @@ def _build_triangle_bvh(tri):
|
||||
return dict(LEAF=LEAF, left=left, right=right, nmin=nmin, nmax=nmax, order=order, T=T)
|
||||
|
||||
|
||||
def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
|
||||
def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False):
|
||||
"""Exact closest surface point per query via per-query BVH stack traversal
|
||||
(nearest-child-first), pure torch. Returns [N,3]. `max_stack` bounds the stack
|
||||
(nearest-child-first), pure torch. Returns [N,3], or (points [N,3], face_idx [N])
|
||||
when return_face=True (face_idx indexes `tri`). `max_stack` bounds the stack
|
||||
(= tree height); overflow is counted+warned, not silently wrong."""
|
||||
dev = Q.device
|
||||
N = Q.shape[0]
|
||||
@ -582,6 +605,7 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
|
||||
stack[:, 0] = 0
|
||||
best = torch.full((N,), 1e30, device=dev)
|
||||
bestp = Q.clone()
|
||||
bestf = torch.full((N,), -1, dtype=torch.long, device=dev)
|
||||
active = torch.arange(N, device=dev)
|
||||
overflow = 0
|
||||
|
||||
@ -599,12 +623,14 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
|
||||
lv = within & isleaf
|
||||
if bool(lv.any()):
|
||||
ga = a[lv]
|
||||
tt = tri[order[node[lv] - LEAF]]
|
||||
fidx = order[node[lv] - LEAF] # triangle index of each leaf
|
||||
tt = tri[fidx]
|
||||
cp, d2 = _point_tri_closest(qa[lv], tt)
|
||||
upd = d2 < best[ga]
|
||||
gu = ga[upd]
|
||||
best[gu] = d2[upd]
|
||||
bestp[gu] = cp[upd]
|
||||
bestf[gu] = fidx[upd]
|
||||
iv = within & ~isleaf
|
||||
if bool(iv.any()):
|
||||
gi = a[iv]
|
||||
@ -627,6 +653,8 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64):
|
||||
logging.warning(f"[back-project] BVH stack overflow on {overflow} pushes "
|
||||
f"(max_stack={max_stack}); a few texels may be slightly off — "
|
||||
f"raise max_stack if this is large.")
|
||||
if return_face:
|
||||
return bestp, bestf
|
||||
return bestp
|
||||
|
||||
|
||||
@ -652,6 +680,352 @@ def _back_project_positions(position_map, mask, ref_v, ref_f):
|
||||
return out
|
||||
|
||||
|
||||
def _ray_tri_hit(o, d, tri, tmin, tmax):
    """Double-sided Möller-Trumbore any-hit test, one triangle per ray.

    Row i of `tri` is tested only against ray i (origin o[i], direction d[i]).
    Returns a bool tensor [N], True where the intersection parameter lies strictly
    inside (tmin, tmax).
    """
    corner = tri[:, 0]
    edge_ab = tri[:, 1] - corner
    edge_ac = tri[:, 2] - corner

    pvec = torch.cross(d, edge_ac, dim=-1)
    det = (edge_ab * pvec).sum(-1)
    # Near-parallel rays: divide by a placeholder instead of ~0. Those rows are rejected
    # by the determinant test in the return, so the placeholder never reaches a result.
    det_safe = torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det)
    inv_det = 1.0 / det_safe

    rel = o - corner
    u = (rel * pvec).sum(-1) * inv_det
    qvec = torch.cross(rel, edge_ab, dim=-1)
    v = (d * qvec).sum(-1) * inv_det
    t = (edge_ac * qvec).sum(-1) * inv_det

    inside_triangle = (u >= 0) & (v >= 0) & (u + v <= 1)
    inside_interval = (t > tmin) & (t < tmax)
    return (det.abs() > 1e-20) & inside_triangle & inside_interval
|
||||
|
||||
|
||||
def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64):
    """Any-hit ray test over the BVH (slab cull + Möller-Trumbore), pure torch.

    Args:
        orig, dirs: ray origins and directions, [N,3].
        tri: triangle corners, indexed through ``bvh['order']``.
        bvh: dict with 'LEAF', 'nmin', 'nmax', 'left', 'right', 'order'. Node 0 is the
            root and node ids >= LEAF are leaves.
        tmin, tmax: open ray-parameter interval that counts as a hit.
        max_stack: per-ray traversal stack depth. Overflowing pushes are counted and
            warned about once, matching _closest_points_on_mesh_bvh, instead of being
            dropped silently.

    Returns:
        bool [N]: True if the ray hits any triangle in (tmin, tmax). Rays early-out once
        they hit.
    """
    dev = orig.device
    N = orig.shape[0]
    LEAF = bvh['LEAF']
    nmin, nmax = bvh['nmin'], bvh['nmax']
    left, right, order = bvh['left'], bvh['right'], bvh['order']
    inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
    hit = torch.zeros(N, dtype=torch.bool, device=dev)
    # int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory.
    stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
    sp = torch.ones(N, dtype=torch.long, device=dev)
    stack[:, 0] = 0
    active = torch.arange(N, device=dev)
    # Kept on-device so counting costs no extra host sync per iteration; read once at the end.
    overflow = torch.zeros((), dtype=torch.long, device=dev)

    def slab(node, o, i):
        t1 = (nmin[node] - o) * i
        t2 = (nmax[node] - o) * i
        tnear = torch.minimum(t1, t2).amax(-1)
        tfar = torch.maximum(t1, t2).amin(-1)
        return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin)

    while active.numel() > 0:
        a = active
        node = stack[a, sp[a] - 1]
        sp[a] = sp[a] - 1
        within = slab(node, orig[a], inv[a])
        isleaf = node >= LEAF
        lv = within & isleaf
        if bool(lv.any()):
            ga = a[lv]
            tt = tri[order[node[lv] - LEAF]]
            h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax)
            hit[ga[h]] = True
        iv = within & ~isleaf
        if bool(iv.any()):
            gi = a[iv]
            s0 = sp[gi]
            # A slot was just popped, so the first push always fits.
            stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32)
            sp[gi] = (s0 + 1).clamp(max=max_stack)
            s1 = sp[gi]
            # The second push is the one that can overflow: it then lands on the top slot
            # and overwrites the child pushed just above, losing that subtree.
            overflow += (s1 >= max_stack).sum()
            stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32)
            sp[gi] = (s1 + 1).clamp(max=max_stack)
        active = a[(sp[a] > 0) & ~hit[a]]  # drop finished + already-hit rays

    dropped = int(overflow)
    if dropped:
        logging.warning("[ray-cast] BVH stack overflow on %d pushes (max_stack=%d); "
                        "some rays may miss geometry — raise max_stack if this is large.",
                        dropped, max_stack)
    return hit
|
||||
|
||||
|
||||
def _ray_tri_intersect(o, d, tri, tmin, tmax, cull_backface=False):
    """Möller-Trumbore intersection, one triangle per ray.

    Returns (hit [N], t [N]): t is the ray parameter and hit is True where the
    intersection lies strictly inside (tmin, tmax). With cull_backface, triangles whose
    winding normal points along the ray are rejected, so only surfaces facing the ray
    origin are kept.
    """
    corner = tri[:, 0]
    edge_ab = tri[:, 1] - corner
    edge_ac = tri[:, 2] - corner

    pvec = torch.cross(d, edge_ac, dim=-1)
    det = (edge_ab * pvec).sum(-1)
    # Placeholder divisor for near-parallel rays; those rows fail the determinant test below.
    det_safe = torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det)
    inv_det = 1.0 / det_safe

    rel = o - corner
    u = (rel * pvec).sum(-1) * inv_det
    qvec = torch.cross(rel, edge_ab, dim=-1)
    v = (d * qvec).sum(-1) * inv_det
    t = (edge_ac * qvec).sum(-1) * inv_det

    inside_triangle = (u >= 0) & (v >= 0) & (u + v <= 1)
    inside_interval = (t > tmin) & (t < tmax)
    hit = (det.abs() > 1e-20) & inside_triangle & inside_interval

    if cull_backface:
        winding_normal = torch.cross(edge_ab, edge_ac, dim=-1)
        front_facing = (winding_normal * d).sum(-1) < 0
        hit = hit & front_facing
    return hit, t
|
||||
|
||||
|
||||
def _closest_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64, cull_backface=False):
    """Nearest-hit ray cast over the BVH, pure torch.

    Args:
        orig, dirs: ray origins and directions, [N,3].
        tri: triangle corners, indexed through ``bvh['order']``.
        bvh: dict with 'LEAF', 'nmin', 'nmax', 'left', 'right', 'order'. Node 0 is the
            root and node ids >= LEAF are leaves.
        tmin, tmax: open ray-parameter interval that counts as a hit.
        max_stack: per-ray traversal stack depth. Overflowing pushes are counted and
            warned about once, matching _closest_points_on_mesh_bvh, instead of being
            dropped silently.
        cull_backface: forwarded to _ray_tri_intersect.

    Returns:
        (t [N], face [N] long, hit [N] bool): the closest intersection in (tmin, tmax).
        On a miss, face is -1 and t stays at tmax. Nodes whose entry distance is past the
        current best are pruned.
    """
    dev = orig.device
    N = orig.shape[0]
    LEAF = bvh['LEAF']
    nmin, nmax = bvh['nmin'], bvh['nmax']
    left, right, order = bvh['left'], bvh['right'], bvh['order']
    inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs)
    best_t = torch.full((N,), float(tmax), device=dev)
    best_f = torch.full((N,), -1, dtype=torch.long, device=dev)
    stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev)
    sp = torch.ones(N, dtype=torch.long, device=dev)
    stack[:, 0] = 0
    active = torch.arange(N, device=dev)
    # Kept on-device so counting costs no extra host sync per iteration; read once at the end.
    overflow = torch.zeros((), dtype=torch.long, device=dev)

    while active.numel() > 0:
        a = active
        node = stack[a, sp[a] - 1]
        sp[a] = sp[a] - 1
        t1 = (nmin[node] - orig[a]) * inv[a]
        t2 = (nmax[node] - orig[a]) * inv[a]
        tnear = torch.minimum(t1, t2).amax(-1)
        tfar = torch.maximum(t1, t2).amin(-1)
        within = (tfar >= tnear.clamp_min(tmin)) & (tfar >= tmin) & (tnear < best_t[a])  # prune past best
        isleaf = node >= LEAF
        lv = within & isleaf
        if bool(lv.any()):
            ga = a[lv]
            fidx = order[node[lv] - LEAF]
            h, t = _ray_tri_intersect(orig[ga], dirs[ga], tri[fidx], tmin, tmax, cull_backface)
            upd = h & (t < best_t[ga])
            gu = ga[upd]
            best_t[gu] = t[upd]
            best_f[gu] = fidx[upd]
        iv = within & ~isleaf
        if bool(iv.any()):
            gi = a[iv]
            s0 = sp[gi]
            # A slot was just popped, so the first push always fits.
            stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32)
            sp[gi] = (s0 + 1).clamp(max=max_stack)
            s1 = sp[gi]
            # The second push is the one that can overflow: it then lands on the top slot
            # and overwrites the child pushed just above, losing that subtree.
            overflow += (s1 >= max_stack).sum()
            stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32)
            sp[gi] = (s1 + 1).clamp(max=max_stack)
        active = a[sp[a] > 0]

    dropped = int(overflow)
    if dropped:
        logging.warning("[ray-cast] BVH stack overflow on %d pushes (max_stack=%d); "
                        "some rays may miss geometry — raise max_stack if this is large.",
                        dropped, max_stack)
    return best_t, best_f, best_f >= 0
|
||||
|
||||
|
||||
def _onb(n):
    """Branchless orthonormal basis (t, b) around unit normals n [N,3].

    The helper axis is +Z unless the normal is nearly parallel to it (|n.z| >= 0.999), in
    which case +X is used, so the cross product never degenerates. The axis constants are
    created in n's dtype so float64/half normals don't mix dtypes in the cross product.

    Returns:
        (t, b) with t = normalize(cross(axis, n)) and b = cross(n, t), both shaped like n.
    """
    z_axis = torch.tensor([0.0, 0.0, 1.0], device=n.device, dtype=n.dtype)
    x_axis = torch.tensor([1.0, 0.0, 0.0], device=n.device, dtype=n.dtype)
    up = torch.where(n[..., 2:3].abs() < 0.999, z_axis.expand_as(n), x_axis.expand_as(n))
    t = torch.nn.functional.normalize(torch.cross(up, n, dim=-1), dim=-1, eps=1e-6)
    return t, torch.cross(n, t, dim=-1)
|
||||
|
||||
|
||||
def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n, resolution,
                            num_samples=64, max_distance=0.5, strength=1.0, bias=0.01,
                            ray_chunk=None, pbar=None):
    """Bake high-poly ambient occlusion into the low-poly's UV layout.

    Per covered texel, a cosine-weighted hemisphere of rays around the interpolated normal
    is cast at the high-poly. AO = 1 - strength * hit-fraction (cosine weighting makes the
    hit-fraction the estimator).

    Args:
        high_v, high_f: high-poly vertices / faces (tensors).
        low_v_np, low_f_np, low_uv_np: low-poly vertices / faces / UVs (numpy).
        low_n: low-poly per-vertex normals (tensor).
        resolution: square texture size.
        num_samples: rays per texel; values below 1 are treated as 1.
        max_distance: ray length as a fraction of the high-poly bbox diagonal.
        strength: occlusion multiplier.
        bias: origin lift / minimum hit distance as a fraction of the bbox diagonal.
        ray_chunk: cap on rays cast at once (the per-chunk BVH stack is its dominant
            transient VRAM); None auto-sizes it to a slice of free VRAM — big chunks
            (fast) on large GPUs, small (safe) on small ones.
        pbar: optional progress object; ``update(1)`` is called once per chunk.

    Returns:
        ao_img [H,W,3] in [0,1], gutter-filled via _jfa_fill_gpu; all ones if no texel
        is covered.
    """
    dev = comfy.model_management.get_torch_device()
    H = W = int(resolution)
    # At least one ray per texel: zero samples would divide by zero and bake NaNs.
    S = max(1, int(num_samples))
    if ray_chunk is None:
        # ~376 B/ray: the int32 traversal stack (_any_hit_rays_bvh's default max_stack=64,
        # 4 bytes each) plus a few [N,3] ray buffers. This is a per-ray cost and does not
        # depend on how many samples a texel takes. Spend a quarter of free VRAM. Speed
        # saturates around 4M rays/chunk, so cap there rather than grow memory for no gain;
        # the floor keeps tiny GPUs from thrashing into too many chunks.
        bytes_per_ray = 64 * 4 + 120
        try:
            free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30)
        except Exception:
            # Best effort: fall back to a conservative guess if the query is unavailable.
            free = 2 << 30
        ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / bytes_per_ray)))
    face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
    if not mask.any():
        return np.ones((H, W, 3), dtype=np.float32)
    lf = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(dev)
    lv = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(dev)
    low_n = low_n.to(dev).float()
    m = mask
    vtri = lf[face_idx[m]]                                   # [K,3] vertex ids
    bsel = bary_uv[m]                                        # [K,3]
    P = (bsel[:, :, None] * lv[vtri]).sum(1)                 # [K,3]
    Nl = torch.nn.functional.normalize((bsel[:, :, None] * low_n[vtri]).sum(1), dim=-1, eps=1e-6)

    hv = high_v.to(dev).float()
    hf = high_f.to(dev).long()
    tri = hv[hf]
    bvh = _build_triangle_bvh(tri)
    diag = float((hv.amax(0) - hv.amin(0)).norm().clamp_min(1e-6))
    biasw = max(1e-5, float(bias) * diag)
    tmax = float(max_distance) * diag

    # Back-project onto the high surface, then lift along the normal: the low-poly chord can sit
    # below the high surface, and casting from below floods false self-occlusion (dark blotches).
    bp = _closest_points_on_mesh_bvh(P, tri, bvh)
    origins = bp + Nl * biasw

    K = P.shape[0]
    T, B = _onb(Nl)
    occ = torch.zeros(K, device=dev)
    tex_per_chunk = max(1, int(ray_chunk) // S)
    for s in range(0, K, tex_per_chunk):
        e = min(s + tex_per_chunk, K)
        kk = e - s
        o, n, t, b = origins[s:e], Nl[s:e], T[s:e], B[s:e]
        r1 = torch.rand(kk, S, device=dev)
        r2 = torch.rand(kk, S, device=dev)
        sr = r1.sqrt()
        lz = r1.mul_(-1.0).add_(1.0).clamp_min_(0.0).sqrt_()   # sqrt(1-r1) (r1 dead after sr)
        ang = r2.mul_(2.0 * math.pi)                           # in place (r2 dead)
        lx = ang.cos().mul_(sr)
        ly = ang.sin().mul_(sr)
        d = t[:, None, :] * lx[..., None]                      # cosine-weighted hemisphere,
        d.addcmul_(b[:, None, :], ly[..., None])               # fused d += b*ly
        d.addcmul_(n[:, None, :], lz[..., None])               # fused d += n*lz (no extra temps)
        d = torch.nn.functional.normalize(d.reshape(-1, 3), dim=-1, eps=1e-6)
        oo = o[:, None, :].expand(-1, S, -1).reshape(-1, 3)
        hit = _any_hit_rays_bvh(oo, d, tri, bvh, tmin=biasw, tmax=tmax)
        occ[s:e] = hit.reshape(kk, S).sum(1, dtype=torch.float32).div_(float(S))  # mean without a float copy
        if pbar is not None:
            pbar.update(1)

    ao = occ.mul_(-float(strength)).add_(1.0).clamp_(0.0, 1.0)  # 1 - occ*strength, in place (occ is dead)
    out = torch.ones((H, W), device=dev)
    out[m] = ao
    out3 = np.repeat(out.cpu().numpy()[..., None], 3, axis=2)
    return _jfa_fill_gpu(np.ascontiguousarray(out3, dtype=np.float32), mask.cpu().numpy())
|
||||
|
||||
|
||||
def _compute_vertex_tangents(verts, faces, uvs, normals):
    """Per-vertex tangents (Lengyel) orthonormalized against `normals`. Pure torch.

    Args:
        verts: vertex positions, [N,3].
        faces: triangle vertex indices, [F,3].
        uvs: per-vertex UVs, [N,2].
        normals: per-vertex normals, [N,3] (re-normalized here).

    Returns:
        [N,4]: unit tangent xyz + handedness w (the bitangent is w * cross(N, T)).

    Faces whose UVs are degenerate (zero UV area) contribute nothing, so they cannot swamp
    the tangents of valid neighbouring faces. A vertex left without any usable tangent gets
    an arbitrary unit vector perpendicular to its normal, with w = +1.
    """
    N = verts.shape[0]
    i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long()
    e1, e2 = verts[i1] - verts[i0], verts[i2] - verts[i0]
    d1, d2 = uvs[i1] - uvs[i0], uvs[i2] - uvs[i0]
    denom = d1[:, 0] * d2[:, 1] - d2[:, 0] * d1[:, 1]
    # Zero-area UV triangles have no defined tangent: give them weight 0 instead of a huge
    # 1/epsilon scale that would dominate every vertex they touch.
    degenerate = denom.abs() < 1e-20
    r = torch.where(degenerate, torch.zeros_like(denom),
                    1.0 / torch.where(degenerate, torch.ones_like(denom), denom))
    tan = (d2[:, 1:2] * e1 - d1[:, 1:2] * e2) * r[:, None]      # [F,3]
    bit = (d1[:, 0:1] * e2 - d2[:, 0:1] * e1) * r[:, None]
    tacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype)
    bacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype)
    for idx in (i0, i1, i2):
        cols = idx[:, None].expand(-1, 3)
        tacc.scatter_add_(0, cols, tan)
        bacc.scatter_add_(0, cols, bit)
    n = torch.nn.functional.normalize(normals, dim=-1, eps=1e-6)
    # Gram-Schmidt: drop the normal component, then renormalize.
    t = torch.nn.functional.normalize(tacc - n * (n * tacc).sum(-1, keepdim=True), dim=-1, eps=1e-6)
    # Vertices with no (or a vanishing) accumulated tangent come out far from unit length;
    # substitute any unit vector perpendicular to the normal so the basis stays valid.
    missing = (t * t).sum(-1, keepdim=True) < 0.25
    z_axis = torch.tensor([0.0, 0.0, 1.0], device=n.device, dtype=n.dtype)
    x_axis = torch.tensor([1.0, 0.0, 0.0], device=n.device, dtype=n.dtype)
    up = torch.where(n[:, 2:3].abs() < 0.999, z_axis, x_axis)   # [N,3] by broadcasting
    fallback = torch.nn.functional.normalize(torch.cross(up, n, dim=-1), dim=-1, eps=1e-6)
    t = torch.where(missing, fallback, t)
    w = torch.sign((torch.cross(n, t, dim=-1) * bacc).sum(-1))
    w = torch.where(w == 0, torch.ones_like(w), w)              # degenerate → right-handed
    return torch.cat([t, w[:, None]], dim=-1)
|
||||
|
||||
|
||||
def _vertex_tangents_for_item(lv, lf, uv, low_n_attr_i, dev):
    """Shading normals and tangents for one batch item.

    Used by both the bake (BakeNormalMapFromMesh) and the export attach
    (ApplyTextureToMesh), so the two always build the same tangent basis.

    Args:
        lv, lf: the item's vertices and faces.
        uv: the item's UVs; moved to `dev` as float before use.
        low_n_attr_i: the mesh's own normals for this item, or None to compute them from
            the geometry.
        dev: torch device for the normals and UVs.

    Returns:
        (low_n [N,3], tangents [N,4]).
    """
    if low_n_attr_i is None:
        shading_normals = _compute_vertex_normals(lv, lf)
    else:
        shading_normals = low_n_attr_i.to(dev).float()
    uv_on_dev = uv.to(dev).float()
    item_tangents = _compute_vertex_tangents(lv, lf, uv_on_dev, shading_normals)
    return shading_normals, item_tangents
|
||||
|
||||
|
||||
def _barycentric(p, tri):
    """Barycentric coords [N,3] of points p [N,3] wrt triangles tri [N,3,3] (per-pair).

    Row i of `p` is expressed against row i of `tri`; the weights are ordered like the
    triangle's corners and sum to 1. Points off the triangle's plane are implicitly
    projected, because only dot products with the two edge vectors are used.

    Degenerate triangles (zero-length edge or collinear corners) have no defined
    coordinates and return uniform weights (1/3 each). The degeneracy test is relative to
    the edge lengths, so small but well-shaped triangles are solved correctly regardless
    of mesh scale.
    """
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    v0, v1, v2 = b - a, c - a, p - a
    d00 = (v0 * v0).sum(-1)
    d01 = (v0 * v1).sum(-1)
    d11 = (v1 * v1).sum(-1)
    d20 = (v2 * v0).sum(-1)
    d21 = (v2 * v1).sum(-1)
    denom = d00 * d11 - d01 * d01
    # denom = |v0|^2 |v1|^2 sin^2(angle): comparing against |v0|^2 |v1|^2 tests the angle
    # alone. An absolute floor would misclassify valid triangles with tiny edges.
    degenerate = denom <= 1e-12 * (d00 * d11)
    safe = torch.where(degenerate, torch.ones_like(denom), denom)
    v = (d11 * d20 - d01 * d21) / safe
    w = (d00 * d21 - d01 * d20) / safe
    bary = torch.stack([1.0 - v - w, v, w], dim=-1)
    uniform = torch.full_like(bary, 1.0 / 3.0)
    return torch.where(degenerate[:, None], uniform, bary)
|
||||
|
||||
|
||||
def _bake_normal_map(high_v, high_f, high_n, low_v_np, low_f_np, low_uv_np, low_n, tangents,
                     resolution, cage_distance=0.05, ignore_backfaces=True):
    """Bake the high-poly's normals into the low-poly's UV layout as a tangent-space map.

    Output convention is glTF/OpenGL (+Y). For every covered texel a cage ray is cast along
    the interpolated low-poly normal (search band: cage_distance * high-poly bbox diagonal,
    on both sides) to find the matching high-poly surface; that surface's interpolated
    vertex normal is then expressed in the texel's TBN frame. With ignore_backfaces, the
    ray skips surfaces facing away from it (crevices/enclosures). Texels whose ray misses
    fall back to the closest point on the high-poly.

    Returns:
        [H,W,3] float image in [0,1], gutter-filled via _jfa_fill_gpu; a flat
        (0.5, 0.5, 1.0) image when no texel is covered.
    """
    device = comfy.model_management.get_torch_device()
    side = int(resolution)
    flat_rgb = np.array([0.5, 0.5, 1.0], dtype=np.float32)

    # Rasterize once; every per-vertex quantity below is interpolated from this result.
    texel_face, texel_bary, covered = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution)
    if not covered.any():
        return np.tile(flat_rgb, (side, side, 1))

    low_faces = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(device)
    low_verts = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(device)
    vert_normals = low_n.to(device).float()
    vert_tangents = tangents.to(device).float()

    weights = texel_bary[covered]                     # [K,3]
    corner_ids = low_faces[texel_face[covered]]       # [K,3] vertex ids of each texel's face

    def blend(per_vertex):
        # [N,C] per-vertex attribute -> [K,C] per-texel value
        return (weights[:, :, None] * per_vertex[corner_ids]).sum(1)

    pos = blend(low_verts)                            # [K,3] world position
    nrm = torch.nn.functional.normalize(blend(vert_normals), dim=-1, eps=1e-6)
    tan_raw = blend(vert_tangents[:, :3])
    handedness = blend(vert_tangents[:, 3:4])[:, 0]

    high_verts = high_v.to(device).float()
    high_faces = high_f.to(device).long()
    high_tri = high_verts[high_faces]
    bvh = _build_triangle_bvh(high_tri)

    # Start on the cage (lifted along the normal) and march back along -normal; the nearest
    # hit is the outermost surface inside the search band.
    bbox_diag = float((high_verts.amax(0) - high_verts.amin(0)).norm().clamp_min(1e-6))
    cage = max(1e-6, float(cage_distance) * bbox_diag)
    ray_origin = pos + nrm * cage
    t_hit, f_hit, ray_hit = _closest_hit_rays_bvh(
        ray_origin, -nrm, high_tri, bvh,
        tmin=0.0, tmax=2.0 * cage, cull_backface=bool(ignore_backfaces))
    src_face = f_hit.clamp_min(0)
    src_point = ray_origin - t_hit[:, None] * nrm

    # Only the texels the ray missed (usually few) pay for the closest-point traversal.
    missed = ~ray_hit
    if bool(missed.any()):
        near_point, near_face = _closest_points_on_mesh_bvh(
            pos[missed], high_tri, bvh, return_face=True)
        src_face = src_face.clone()
        src_point = src_point.clone()
        src_face[missed] = near_face.clamp_min(0)
        src_point[missed] = near_point

    hit_tri = high_tri[src_face]                                   # [K,3,3]
    hit_bary = _barycentric(src_point, hit_tri)
    hit_vert_normals = high_n.to(device).float()[high_faces[src_face]]   # [K,3,3]
    high_nrm = torch.nn.functional.normalize(
        (hit_bary[:, :, None] * hit_vert_normals).sum(1), dim=-1, eps=1e-6)

    # Texel TBN frame: tangent Gram-Schmidt'ed against the interpolated normal, bitangent
    # from the interpolated handedness.
    tan = torch.nn.functional.normalize(
        tan_raw - nrm * (nrm * tan_raw).sum(-1, keepdim=True), dim=-1, eps=1e-6)
    bitan = handedness[:, None] * torch.cross(nrm, tan, dim=-1)
    facing = (high_nrm * nrm).sum(-1)      # both the z channel and the back-face test
    tangent_space = torch.stack(
        [(high_nrm * tan).sum(-1), (high_nrm * bitan).sum(-1), facing], dim=-1)
    tangent_space = torch.nn.functional.normalize(tangent_space, dim=-1, eps=1e-6)
    # A matched normal that points away from the texel (e.g. a back surface picked up by the
    # fallback in a deep crevice) is replaced by the flat base normal.
    tangent_space[facing < 0.0] = torch.tensor([0.0, 0.0, 1.0], device=device)
    encoded = tangent_space.mul_(0.5).add_(0.5).clamp_(0.0, 1.0)   # in place; tangent_space is dead

    image = torch.from_numpy(np.tile(flat_rgb, (side, side, 1))).to(device)
    image[covered] = encoded
    # Dilate into the UV gutter so bilinear/mip sampling at chart edges doesn't bleed flat blue.
    return _jfa_fill_gpu(image.cpu().numpy(), covered.cpu().numpy())
|
||||
|
||||
|
||||
def _jfa_fill_gpu(img01, mask):
|
||||
"""Fill uncovered texels with nearest covered value via GPU Jump Flooding
|
||||
(O(log n) passes; replaces cv2.inpaint). img01 [H,W,C] float, mask [H,W] bool."""
|
||||
@ -919,15 +1293,17 @@ class MeshTextureToImage(IO.ComfyNode):
|
||||
display_name="Mesh Texture to Image",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Extracts a mesh's baked textures as IMAGE outputs: base_color and the packed "
|
||||
"glTF MR map (G=roughness, B=metallic; black if no PBR texture)."
|
||||
"Extracts a mesh's baked textures as individual IMAGEs: base_color, metallic, "
|
||||
"roughness, occlusion and normal_map. Channels with nothing baked come back "
|
||||
"neutral (occlusion white, normal flat)."
|
||||
),
|
||||
inputs=[IO.Mesh.Input("mesh")],
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="base_color"),
|
||||
IO.Image.Output(display_name="metallic_roughness"),
|
||||
IO.Image.Output(display_name="metallic"),
|
||||
IO.Image.Output(display_name="roughness"),
|
||||
IO.Image.Output(display_name="occlusion"),
|
||||
IO.Image.Output(display_name="normal_map"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -944,6 +1320,7 @@ class MeshTextureToImage(IO.ComfyNode):
|
||||
|
||||
base = _as_image(getattr(mesh, "texture", None))
|
||||
mr = _as_image(getattr(mesh, "metallic_roughness", None))
|
||||
normal_map = _as_image(getattr(mesh, "normal_map", None))
|
||||
|
||||
if base is None:
|
||||
raise ValueError(
|
||||
@ -952,10 +1329,18 @@ class MeshTextureToImage(IO.ComfyNode):
|
||||
)
|
||||
if mr is None:
|
||||
mr = torch.zeros_like(base)
|
||||
# Split packed MR into grayscale previews (G=roughness, B=metallic), to 3ch.
|
||||
if normal_map is None:
|
||||
normal_map = torch.ones_like(base) * torch.tensor([0.5, 0.5, 1.0]) # neutral flat normal
|
||||
# Unpack the ORM map (R=occlusion, G=roughness, B=metallic) to 3-channel grayscale.
|
||||
metallic = mr[..., 2:3].expand(-1, -1, -1, 3).contiguous()
|
||||
roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous()
|
||||
return IO.NodeOutput(base, mr, metallic, roughness)
|
||||
# R is real occlusion only if AO was baked; else it's the unused zero channel, which as
|
||||
# "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set.
|
||||
if getattr(mesh, "occlusion_in_mr", False):
|
||||
occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous()
|
||||
else:
|
||||
occlusion = torch.ones_like(base)
|
||||
return IO.NodeOutput(base, metallic, roughness, occlusion, normal_map)
|
||||
|
||||
|
||||
class ApplyTextureToMesh(IO.ComfyNode):
|
||||
@ -966,29 +1351,26 @@ class ApplyTextureToMesh(IO.ComfyNode):
|
||||
display_name="Apply Texture to Mesh",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Attaches baked texture IMAGEs to a mesh's existing UV layout for SaveGLB. "
|
||||
"Pairs with BakeTextureFromVoxel: feed the SAME mesh and its base_color "
|
||||
"(optionally metallic/roughness); don't re-unwrap in between. metallic/roughness "
|
||||
"repack into the glTF MR map (G=roughness, B=metallic); missing metallic=0, "
|
||||
"roughness=1."
|
||||
"Attaches baked texture IMAGEs to a mesh's UV layout for SaveGLB. Feed the SAME mesh you baked"
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Image.Input("base_color"),
|
||||
IO.Image.Input("metallic", optional=True),
|
||||
IO.Image.Input("roughness", optional=True),
|
||||
IO.Image.Input("occlusion", optional=True),
|
||||
IO.Image.Input("normal_map", optional=True),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, base_color, metallic=None, roughness=None):
|
||||
def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None):
|
||||
mesh_uvs = getattr(mesh, "uvs", None)
|
||||
if mesh_uvs is None:
|
||||
raise ValueError(
|
||||
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
|
||||
"you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and "
|
||||
"never unwraps).")
|
||||
"you fed to BakeTextureFromVoxel.")
|
||||
|
||||
# Re-derive the exact UVs the bake used (shared _normalize_uvs_to_unit), per item.
|
||||
if mesh_uvs.ndim == 3:
|
||||
@ -1005,15 +1387,264 @@ class ApplyTextureToMesh(IO.ComfyNode):
|
||||
out_mesh = copy.copy(mesh)
|
||||
out_mesh.uvs = new_uvs
|
||||
out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu()
|
||||
if metallic is not None or roughness is not None:
|
||||
# Repack glTF MR (G=roughness, B=metallic); missing channel → scalar (metal 0/rough 1).
|
||||
prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu()
|
||||
B, H, W, _ = prov.shape
|
||||
rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||
if roughness is not None else torch.ones((B, H, W, 1)))
|
||||
metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||
if metallic is not None else torch.zeros((B, H, W, 1)))
|
||||
out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1)
|
||||
if normal_map is not None:
|
||||
# Recompute tangents (shared helper, same normalized UVs → same basis as the bake)
|
||||
# and export the smooth normals the TBN was built on — without a NORMAL attribute the
|
||||
# viewer shades flat and the tangent-space detail fights the faceting.
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
low_n_attr = getattr(mesh, "normals", None)
|
||||
B = int(mesh.vertices.shape[0])
|
||||
Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0])
|
||||
tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32)
|
||||
normals_padded = torch.zeros((B, Nmax, 3), dtype=torch.float32)
|
||||
for i in range(B):
|
||||
v_i, f_i, _ = get_mesh_batch_item(mesh, i)
|
||||
n = int(v_i.shape[0])
|
||||
if f_i.numel() == 0:
|
||||
continue
|
||||
lv, lf = v_i.to(dev).float(), f_i.to(dev).long()
|
||||
uv_i = new_uvs[i, :n] if new_uvs.ndim == 3 else new_uvs[:n]
|
||||
n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None
|
||||
low_n, tangents = _vertex_tangents_for_item(lv, lf, uv_i, n_attr_i, dev)
|
||||
tangents_padded[i, :n] = tangents.cpu()
|
||||
normals_padded[i, :n] = low_n.cpu()
|
||||
out_mesh.normal_map = normal_map.float().clamp(0.0, 1.0).cpu()
|
||||
out_mesh.tangents = tangents_padded
|
||||
out_mesh.normals = normals_padded
|
||||
if metallic is not None or roughness is not None or occlusion is not None:
|
||||
# Pack glTF ORM (R=occlusion, G=roughness, B=metallic); missing → 1/1/0. Maps may
|
||||
# arrive at different resolutions, so resize each channel to a common H×W first.
|
||||
provided = [x for x in (metallic, roughness, occlusion) if x is not None]
|
||||
B = int(provided[0].shape[0])
|
||||
H = max(int(x.shape[1]) for x in provided)
|
||||
W = max(int(x.shape[2]) for x in provided)
|
||||
|
||||
def _chan(img, default):
|
||||
if img is None:
|
||||
return torch.full((B, H, W, 1), float(default))
|
||||
t = img.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||
if int(t.shape[1]) != H or int(t.shape[2]) != W:
|
||||
t = torch.nn.functional.interpolate(t.permute(0, 3, 1, 2), size=(H, W),
|
||||
mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
|
||||
return t
|
||||
|
||||
out_mesh.metallic_roughness = torch.cat(
|
||||
[_chan(occlusion, 1.0), _chan(roughness, 1.0), _chan(metallic, 0.0)], dim=-1)
|
||||
if occlusion is not None:
|
||||
# Tells SaveGLB to also point occlusionTexture at the MR image (R = AO).
|
||||
out_mesh.occlusion_in_mr = True
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
|
||||
class BakeNormalMapFromMesh(IO.ComfyNode):
    """Node that bakes a tangent-space normal map from a high-poly mesh into a low-poly's UVs.

    The node never unwraps: the low-poly must already carry a ``uvs`` attribute.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and the single ``normal_map`` IMAGE output."""
        return IO.Schema(
            node_id="BakeNormalMapFromMesh",
            display_name="Bake Normal Map from Mesh",
            category="latent/3d",
            description=(
                "Bakes a tangent-space normal map (glTF/OpenGL +Y) from a high-poly mesh into a "
                "low-poly's UV layout, capturing detail lost to decimation. Feed the UV-unwrapped "
                "low_poly and the same-frame high_poly it was decimated from. Outputs an IMAGE for "
                "ApplyTextureToMesh's normal_map input."
            ),
            inputs=[
                IO.Mesh.Input("low_poly"),
                IO.Mesh.Input("high_poly"),
                IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64,
                             display_name="resolution"),
                IO.Float.Input("cage_distance", default=0.05, min=0.001, max=0.5, step=0.001,
                               tooltip="Surface search band, as a fraction of the bbox diagonal. "
                                       "Raise for wrong/missing patches under heavy decimation; "
                                       "lower if it grabs across gaps."),
                IO.Boolean.Input("ignore_backfaces", default=True,
                                 tooltip="Skip high-poly surfaces facing away from the texel, so "
                                         "crevices/enclosed spaces don't grab the opposite wall. "
                                         "Disable only if the high-poly winding is inconsistent."),
            ],
            outputs=[IO.Image.Output(display_name="normal_map")],
        )

    @classmethod
    def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True):
        """Bake one normal map per low-poly batch item and return them stacked as an IMAGE.

        Args:
            low_poly: Mesh with a ``uvs`` attribute; its per-item vertices/faces are read
                through ``get_mesh_batch_item``.
            high_poly: Detail source mesh. When its batch size is 1 the single item is
                reused for every low-poly item; otherwise item ``i`` pairs with item ``i``.
            resolution: Edge length of the (square) output texture, in texels.
            cage_distance: Surface search band, as a fraction of the bbox diagonal
                (forwarded to ``_bake_normal_map``).
            ignore_backfaces: Forwarded to ``_bake_normal_map``.

        Returns:
            ``IO.NodeOutput`` wrapping the per-item maps, clamped to [0, 1] and stacked
            along a new leading batch dimension.

        Raises:
            ValueError: If ``low_poly`` has no ``uvs`` attribute.
        """
        low_uvs = getattr(low_poly, "uvs", None)
        if low_uvs is None:
            raise ValueError(
                "BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped "
                "low-poly (the same one you fed to BakeTextureFromVoxel); this node bakes "
                "onto existing UVs and never unwraps.")
        dev = comfy.model_management.get_torch_device()

        low_n_attr = getattr(low_poly, "normals", None)
        high_n_attr = getattr(high_poly, "normals", None)
        B = int(low_poly.vertices.shape[0])
        h_batch = int(high_poly.vertices.shape[0])

        imgs = []
        for i in range(B):
            v_i, f_i, _ = get_mesh_batch_item(low_poly, i)
            n = int(v_i.shape[0])
            if f_i.numel() == 0:
                # Nothing to rasterize: emit a constant 0.5 placeholder so the batch stays aligned.
                logging.warning(f"BakeNormalMapFromMesh: skipping batch {i} (empty mesh)")
                imgs.append(torch.full((int(resolution), int(resolution), 3), 0.5))
                continue

            # uvs may be batched (B, N, 2)-style or a single unbatched array; trim padding to n.
            uv_i = low_uvs[i, :n] if low_uvs.ndim == 3 else low_uvs[:n]
            uv_np = _normalize_uvs_to_unit(uv_i.detach().cpu().numpy(), log_prefix="[BakeNormalMapFromMesh] ")

            lv = v_i.to(dev).float()
            lf = f_i.to(dev).long()
            # Tangents build the per-texel TBN; ApplyTextureToMesh recomputes the same basis on export.
            n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None
            low_n, tangents = _vertex_tangents_for_item(lv, lf, torch.from_numpy(uv_np).to(dev), n_attr_i, dev)

            # A single high-poly is broadcast over the whole low-poly batch. The SAME index must
            # select both its geometry and its normals; indexing normals with `i` would run past
            # the high-poly batch (IndexError) whenever h_batch == 1 and B > 1.
            hi = i if h_batch > 1 else 0
            hv_i, hf_i, _ = get_mesh_batch_item(high_poly, hi)
            hv = hv_i.to(dev).float()
            hf = hf_i.to(dev).long()
            high_n = (high_n_attr[hi, :hv.shape[0]].to(dev).float() if high_n_attr is not None
                      else _compute_vertex_normals(hv, hf))

            img = _bake_normal_map(
                hv, hf, high_n,
                lv.detach().cpu().numpy(), lf.detach().cpu().numpy().astype(np.uint32), uv_np,
                low_n, tangents, resolution, cage_distance=float(cage_distance),
                ignore_backfaces=bool(ignore_backfaces),
            )
            imgs.append(torch.from_numpy(np.ascontiguousarray(img)).float())

        normal_img = torch.stack([t.clamp(0.0, 1.0) for t in imgs], dim=0)
        return IO.NodeOutput(normal_img)
|
||||
|
||||
|
||||
class BakeAmbientOcclusion(IO.ComfyNode):
    """Node that bakes an ambient-occlusion map from a high-poly mesh into a low-poly's UVs."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and the single ``occlusion`` IMAGE output."""
        node_inputs = [
            IO.Mesh.Input("low_poly"),
            IO.Mesh.Input("high_poly"),
            IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64),
            IO.Int.Input("samples", default=64, min=4, max=1024, step=4,
                         tooltip="Rays per texel. More = smoother, slower. Raise if grainy."),
            IO.Float.Input("max_distance", default=0.5, min=0.01, max=2.0, step=0.01,
                           tooltip="Ray length, as a fraction of the bbox diagonal. "
                                   "Smaller = tighter, more local occlusion."),
            IO.Float.Input("strength", default=1.0, min=0.0, max=2.0, step=0.05,
                           tooltip="Scales the occlusion. >1 darkens, <1 lightens."),
            IO.Float.Input("bias", default=0.01, min=0.0001, max=0.2, step=0.0005,
                           tooltip="Ray origin lift off the surface, as a fraction of the bbox "
                                   "diagonal. Raise if even surfaces show dark blotches/holes."),
        ]
        return IO.Schema(
            node_id="BakeAmbientOcclusion",
            display_name="Bake Ambient Occlusion",
            category="latent/3d",
            description=(
                "Bakes an ambient-occlusion map from a high-poly mesh into a low-poly's UV "
                "layout (white = open, dark = crevices). Feed the UV-unwrapped low_poly and the "
                "high_poly it was decimated from. Outputs a grayscale IMAGE for "
                "ApplyTextureToMesh's occlusion input (packed into the ORM map / occlusionTexture)."
            ),
            inputs=node_inputs,
            outputs=[IO.Image.Output(display_name="occlusion")],
        )

    @classmethod
    def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias):
        """Bake one AO map per low-poly batch item and return them stacked as an IMAGE.

        A high-poly batch of size 1 is reused for every low-poly item; otherwise item ``i``
        of the low-poly is paired with item ``i`` of the high-poly. Raises ``ValueError``
        when ``low_poly`` carries no ``uvs`` attribute.
        """
        uv_attr = getattr(low_poly, "uvs", None)
        if uv_attr is None:
            raise ValueError(
                "BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly "
                "(the same one used for the other bakes); this node never unwraps.")
        device = comfy.model_management.get_torch_device()
        normal_attr = getattr(low_poly, "normals", None)
        low_count = int(low_poly.vertices.shape[0])
        high_count = int(high_poly.vertices.shape[0])

        progress = comfy.utils.ProgressBar(max(1, low_count))  # one tick per batch item
        baked = []
        for idx in range(low_count):
            verts, tris, _ = get_mesh_batch_item(low_poly, idx)
            vert_total = int(verts.shape[0])

            if tris.numel() == 0:
                # Empty item: all-white (fully unoccluded) placeholder keeps the batch aligned.
                logging.warning(f"BakeAmbientOcclusion: skipping batch {idx} (empty mesh)")
                side = int(resolution)
                baked.append(torch.ones((side, side, 3)))
                progress.update(1)
                continue

            # Batched UVs are sliced per item; unbatched UVs are shared. Either way drop padding.
            if uv_attr.ndim == 3:
                item_uv = uv_attr[idx, :vert_total]
            else:
                item_uv = uv_attr[:vert_total]
            unit_uv = _normalize_uvs_to_unit(item_uv.detach().cpu().numpy(),
                                             log_prefix="[BakeAmbientOcclusion] ")

            low_verts = verts.to(device).float()
            low_tris = tris.to(device).long()
            # Prefer normals stored on the mesh; otherwise derive them from the geometry.
            if normal_attr is not None:
                low_normals = normal_attr[idx, :vert_total].to(device).float()
            else:
                low_normals = _compute_vertex_normals(low_verts, low_tris)

            high_idx = idx if high_count > 1 else 0
            high_verts, high_tris, _ = get_mesh_batch_item(high_poly, high_idx)
            ao = _bake_ambient_occlusion(
                high_verts.to(device).float(), high_tris.to(device).long(),
                low_verts.detach().cpu().numpy(),
                low_tris.detach().cpu().numpy().astype(np.uint32),
                unit_uv,
                low_normals, resolution, num_samples=int(samples),
                max_distance=float(max_distance), strength=float(strength), bias=float(bias),
            )
            baked.append(torch.from_numpy(np.ascontiguousarray(ao)).float())
            progress.update(1)

        clamped = [item.clamp(0.0, 1.0) for item in baked]
        return IO.NodeOutput(torch.stack(clamped, dim=0))
|
||||
|
||||
|
||||
class SetMeshMaterial(IO.ComfyNode):
    """Node that attaches scalar glTF material overrides (and an optional emissive map) to a mesh."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, material inputs and the ``mesh`` output."""
        # Shared numeric range for the 0..1 colour/strength sliders.
        unit = {"min": 0.0, "max": 1.0, "step": 0.01}
        # Factor sliders additionally accept -1 as the "leave auto" sentinel.
        auto_factor = {"default": -1.0, "min": -1.0, "max": 1.0, "step": 0.01,
                       "tooltip": "-1 = leave auto; 0..1 overrides."}
        return IO.Schema(
            node_id="SetMeshMaterial",
            display_name="Set Mesh Material",
            category="latent/3d",
            description=(
                "Sets glTF material properties SaveGLB can't derive from textures: emissive "
                "(color + strength + optional texture), baseColor tint, metallic/roughness "
                "factors, normal scale, occlusion strength, double-sided. Place before SaveGLB."
            ),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Float.Input("emissive_r", default=0.0, **unit),
                IO.Float.Input("emissive_g", default=0.0, **unit),
                IO.Float.Input("emissive_b", default=0.0, **unit),
                IO.Float.Input("emissive_strength", default=1.0, min=0.0, max=100.0, step=0.1,
                               tooltip=">1 for HDR glow (KHR_materials_emissive_strength)."),
                IO.Image.Input("emissive_texture", optional=True),
                IO.Float.Input("base_color_r", default=1.0, **unit),
                IO.Float.Input("base_color_g", default=1.0, **unit),
                IO.Float.Input("base_color_b", default=1.0, **unit),
                IO.Float.Input("metallic_factor", **auto_factor),
                IO.Float.Input("roughness_factor", **auto_factor),
                IO.Float.Input("normal_scale", default=1.0, min=0.0, max=10.0, step=0.05),
                IO.Float.Input("occlusion_strength", default=1.0, **unit),
                IO.Boolean.Input("double_sided", default=True),
            ],
            outputs=[IO.Mesh.Output("mesh")],
        )

    @classmethod
    def execute(cls, mesh, emissive_r, emissive_g, emissive_b, emissive_strength,
                base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor,
                normal_scale, occlusion_strength, double_sided, emissive_texture=None):
        """Return a shallow copy of ``mesh`` carrying the requested material settings.

        The settings are merged over any ``material`` dict already on the mesh, so keys this
        node does not set survive. When ``emissive_texture`` is given it is stored, clamped
        to [0, 1] and moved to CPU, as the copy's ``emissive`` attribute.
        """
        result = copy.copy(mesh)

        # Copy the prior material (if any) so the input mesh's dict is never mutated.
        merged = dict(getattr(mesh, "material", {}) or {})
        merged["emissive_factor"] = [float(emissive_r), float(emissive_g), float(emissive_b)]
        merged["emissive_strength"] = float(emissive_strength)
        # Alpha is fixed at 1.0; only the RGB tint is exposed.
        merged["base_color_factor"] = [float(base_color_r), float(base_color_g),
                                       float(base_color_b), 1.0]
        merged["metallic_factor"] = float(metallic_factor)  # <0 => leave auto
        merged["roughness_factor"] = float(roughness_factor)
        merged["normal_scale"] = float(normal_scale)
        merged["occlusion_strength"] = float(occlusion_strength)
        merged["double_sided"] = bool(double_sided)
        result.material = merged

        if emissive_texture is None:
            return IO.NodeOutput(result)

        result.emissive = emissive_texture.float().clamp(0.0, 1.0).cpu()
        return IO.NodeOutput(result)
|
||||
|
||||
|
||||
@ -2278,8 +2909,7 @@ class MergeMeshes(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Concatenate N meshes into one by offsetting face indices and stacking verts, "
|
||||
"faces, uvs, and colors. E.g. combine a Pixal3D object with a MoGe background "
|
||||
"into one GLB."
|
||||
"faces, uvs, and colors."
|
||||
),
|
||||
inputs=[
|
||||
IO.Autogrow.Input("meshes", template=autogrow_template),
|
||||
@ -2306,6 +2936,9 @@ class PostProcessMeshExtension(ComfyExtension):
|
||||
BakeTextureFromVoxel,
|
||||
MeshTextureToImage,
|
||||
ApplyTextureToMesh,
|
||||
BakeNormalMapFromMesh,
|
||||
BakeAmbientOcclusion,
|
||||
SetMeshMaterial,
|
||||
MergeMeshes,
|
||||
]
|
||||
|
||||
|
||||
@ -20,11 +20,11 @@ from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False,
|
||||
normals=None, metallic_roughness=None):
|
||||
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
|
||||
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
|
||||
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
|
||||
# texture is (B, H, W, 3) — passed through unchanged
|
||||
normals=None, metallic_roughness=None, tangents=None, normal_map=None,
|
||||
occlusion_in_mr=False, material=None, emissive=None):
|
||||
# Pack per-item tensors into padded batches, stashing per-item lengths as runtime attrs.
|
||||
# colors/uvs/normals/tangents are 1:1 with vertices (padded to max_vertices); texture/
|
||||
# metallic_roughness/normal_map are (B,H,W,*) image stacks passed through unchanged.
|
||||
batch_size = len(vertices)
|
||||
max_vertices = max(v.shape[0] for v in vertices)
|
||||
max_faces = max(f.shape[0] for f in faces)
|
||||
@ -65,11 +65,31 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
|
||||
)
|
||||
packed_normals[i, :nrm.shape[0]] = nrm
|
||||
|
||||
return Types.MESH(packed_vertices, packed_faces,
|
||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||
metallic_roughness=metallic_roughness,
|
||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
||||
normals=packed_normals)
|
||||
packed_tangents = None
|
||||
if tangents is not None:
|
||||
packed_tangents = tangents[0].new_zeros((batch_size, max_vertices, tangents[0].shape[1]))
|
||||
for i, tn in enumerate(tangents):
|
||||
assert tn.shape[0] == vertices[i].shape[0], (
|
||||
f"tangents[{i}] has {tn.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)"
|
||||
)
|
||||
packed_tangents[i, :tn.shape[0]] = tn
|
||||
|
||||
out = Types.MESH(packed_vertices, packed_faces,
|
||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||
metallic_roughness=metallic_roughness,
|
||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
||||
normals=packed_normals)
|
||||
if packed_tangents is not None:
|
||||
out.tangents = packed_tangents
|
||||
if normal_map is not None:
|
||||
out.normal_map = normal_map
|
||||
if occlusion_in_mr:
|
||||
out.occlusion_in_mr = True
|
||||
if material is not None:
|
||||
out.material = material
|
||||
if emissive is not None:
|
||||
out.emissive = emissive
|
||||
return out
|
||||
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
@ -180,7 +200,8 @@ def _compute_vertex_normals(vertices_np, faces_np, crease_angle=None):
|
||||
def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs=None, vertex_colors=None, texture_image=None,
|
||||
metallic_roughness_image=None, unlit=False,
|
||||
normals=None):
|
||||
normals=None, normal_map_image=None, tangents=None, occlusion_in_mr=False,
|
||||
material=None, emissive_image=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
@ -197,6 +218,16 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
normals: torch.Tensor of shape (N, 3) - Optional per-vertex normals, written as the
|
||||
glTF NORMAL attribute. When omitted, NO normals are written and viewers fall back
|
||||
to flat (per-face) shading — use the MeshSmoothNormals node to generate them.
|
||||
normal_map_image: PIL.Image - Optional tangent-space normal map (glTF/OpenGL +Y),
|
||||
written as the material normalTexture. Needs TEXCOORD_0.
|
||||
tangents: torch.Tensor of shape (N, 4) - Optional per-vertex tangents (xyz + handedness w),
|
||||
written as the glTF TANGENT attribute. Without it viewers derive tangents in-shader.
|
||||
occlusion_in_mr: bool - When True, R of metallic_roughness_image holds AO (ORM packing) and
|
||||
occlusionTexture is pointed at that same image.
|
||||
material: dict - Optional scalar overrides from SetMeshMaterial (base_color_factor,
|
||||
metallic/roughness_factor with <0 = auto, emissive_factor/strength, normal_scale,
|
||||
occlusion_strength, double_sided).
|
||||
emissive_image: PIL.Image - Optional emissive (glow) texture, written as emissiveTexture.
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
@ -231,6 +262,11 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
raise ValueError(
|
||||
f"save_glb: normals has {normals_np.shape[0]} entries but vertex count is {n_verts}"
|
||||
)
|
||||
tangents_np = tangents.cpu().numpy().astype(np.float32) if tangents is not None else None
|
||||
if tangents_np is not None and tangents_np.shape != (n_verts, 4):
|
||||
raise ValueError(
|
||||
f"save_glb: tangents must be (N, 4) with N={n_verts}, got {tuple(tangents_np.shape)}"
|
||||
)
|
||||
faces_np = faces_signed.astype(np.uint32)
|
||||
texture_png_bytes = None
|
||||
if texture_image is not None:
|
||||
@ -242,46 +278,60 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
buf = BytesIO()
|
||||
metallic_roughness_image.save(buf, format="PNG")
|
||||
mr_png_bytes = buf.getvalue()
|
||||
nm_png_bytes = None
|
||||
if normal_map_image is not None:
|
||||
buf = BytesIO()
|
||||
normal_map_image.save(buf, format="PNG")
|
||||
nm_png_bytes = buf.getvalue()
|
||||
em_png_bytes = None
|
||||
if emissive_image is not None:
|
||||
buf = BytesIO()
|
||||
emissive_image.save(buf, format="PNG")
|
||||
em_png_bytes = buf.getvalue()
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
|
||||
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
|
||||
normals_buffer = normals_np.tobytes() if normals_np is not None else b""
|
||||
tangents_buffer = tangents_np.tobytes() if tangents_np is not None else b""
|
||||
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
|
||||
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
|
||||
nm_buffer = nm_png_bytes if nm_png_bytes is not None else b""
|
||||
em_buffer = em_png_bytes if em_png_bytes is not None else b""
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
return buffer + b'\x00' * padding_length
|
||||
|
||||
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
||||
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
||||
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
|
||||
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
|
||||
normals_buffer_padded = pad_to_4_bytes(normals_buffer)
|
||||
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
|
||||
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
|
||||
|
||||
buffer_data = b"".join([
|
||||
vertices_buffer_padded,
|
||||
indices_buffer_padded,
|
||||
uvs_buffer_padded,
|
||||
colors_buffer_padded,
|
||||
normals_buffer_padded,
|
||||
texture_buffer_padded,
|
||||
mr_buffer_padded,
|
||||
])
|
||||
# Blob order in one place; offsets accumulated in a pass so adding a buffer is one entry.
|
||||
_blobs = [
|
||||
("vertices", vertices_buffer), ("indices", indices_buffer), ("uvs", uvs_buffer),
|
||||
("colors", colors_buffer), ("normals", normals_buffer), ("tangents", tangents_buffer),
|
||||
("texture", texture_buffer), ("mr", mr_buffer), ("nm", nm_buffer), ("em", em_buffer),
|
||||
]
|
||||
byte_offset = {}
|
||||
acc = 0
|
||||
parts = []
|
||||
for name, b in _blobs:
|
||||
padded = pad_to_4_bytes(b)
|
||||
byte_offset[name] = acc
|
||||
acc += len(padded)
|
||||
parts.append(padded)
|
||||
buffer_data = b"".join(parts)
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
vertices_byte_offset = 0
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
|
||||
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
|
||||
normals_byte_offset = colors_byte_offset + len(colors_buffer_padded)
|
||||
texture_byte_offset = normals_byte_offset + len(normals_buffer_padded)
|
||||
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
|
||||
vertices_byte_offset = byte_offset["vertices"]
|
||||
indices_byte_offset = byte_offset["indices"]
|
||||
uvs_byte_offset = byte_offset["uvs"]
|
||||
colors_byte_offset = byte_offset["colors"]
|
||||
normals_byte_offset = byte_offset["normals"]
|
||||
tangents_byte_offset = byte_offset["tangents"]
|
||||
texture_byte_offset = byte_offset["texture"]
|
||||
mr_byte_offset = byte_offset["mr"]
|
||||
nm_byte_offset = byte_offset["nm"]
|
||||
em_byte_offset = byte_offset["em"]
|
||||
|
||||
buffer_views = [
|
||||
{
|
||||
@ -368,6 +418,23 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
})
|
||||
primitive_attributes["NORMAL"] = accessor_idx
|
||||
|
||||
if tangents_np is not None and len(tangents_np) > 0:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": tangents_byte_offset,
|
||||
"byteLength": len(tangents_buffer),
|
||||
"target": 34962
|
||||
})
|
||||
accessor_idx = len(accessors)
|
||||
accessors.append({
|
||||
"bufferView": len(buffer_views) - 1,
|
||||
"byteOffset": 0,
|
||||
"componentType": 5126, # FLOAT
|
||||
"count": len(tangents_np),
|
||||
"type": "VEC4", # xyz tangent + w handedness (glTF TANGENT)
|
||||
})
|
||||
primitive_attributes["TANGENT"] = accessor_idx
|
||||
|
||||
primitive = {
|
||||
"attributes": primitive_attributes,
|
||||
"indices": 1,
|
||||
@ -379,9 +446,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
samplers = []
|
||||
materials = []
|
||||
extensions_used = []
|
||||
|
||||
def add_image_texture(png_byte_offset, png_byte_length):
|
||||
"""Append an embedded PNG image + a texture referencing it; return the texture index."""
|
||||
buffer_views.append({"buffer": 0, "byteOffset": png_byte_offset, "byteLength": png_byte_length})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
if not samplers:
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||
return len(textures) - 1
|
||||
|
||||
has_uv = "TEXCOORD_0" in primitive_attributes
|
||||
if unlit and texture_png_bytes is None:
|
||||
# Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a
|
||||
# gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours.
|
||||
if nm_png_bytes is not None or em_png_bytes is not None or occlusion_in_mr or material is not None:
|
||||
logging.warning(
|
||||
"save_glb: unlit material ignores normal/occlusion/emissive maps and SetMeshMaterial "
|
||||
"overrides — those are PBR-lit features. Disable unlit to export them.")
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0},
|
||||
"extensions": {"KHR_materials_unlit": {}},
|
||||
@ -395,37 +477,57 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
"roughnessFactor": 0.5,
|
||||
"baseColorFactor": [0.22, 0.22, 0.22, 1.0],
|
||||
}
|
||||
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": texture_byte_offset,
|
||||
"byteLength": len(texture_buffer),
|
||||
})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||
pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||
if texture_png_bytes is not None and has_uv:
|
||||
pbr["baseColorTexture"] = {"index": add_image_texture(texture_byte_offset, len(texture_buffer)), "texCoord": 0}
|
||||
|
||||
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": mr_byte_offset,
|
||||
"byteLength": len(mr_buffer),
|
||||
})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
if not samplers:
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||
if mr_png_bytes is not None and has_uv:
|
||||
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
|
||||
pbr["metallicRoughnessTexture"] = {"index": mr_texture_index, "texCoord": 0}
|
||||
# When a metallicRoughness texture is present, the factors scale it; use 1.0
|
||||
# so the texture values pass through unchanged (glTF convention).
|
||||
pbr["metallicFactor"] = 1.0
|
||||
pbr["roughnessFactor"] = 1.0
|
||||
|
||||
materials.append({
|
||||
mat = material if isinstance(material, dict) else {}
|
||||
# Scalar overrides from SetMeshMaterial (factor < 0 means "leave auto").
|
||||
if mat.get("base_color_factor") is not None:
|
||||
pbr["baseColorFactor"] = [float(x) for x in mat["base_color_factor"]]
|
||||
if mat.get("metallic_factor", -1.0) >= 0.0:
|
||||
pbr["metallicFactor"] = float(mat["metallic_factor"])
|
||||
if mat.get("roughness_factor", -1.0) >= 0.0:
|
||||
pbr["roughnessFactor"] = float(mat["roughness_factor"])
|
||||
|
||||
material = {
|
||||
"pbrMetallicRoughness": pbr,
|
||||
"doubleSided": True,
|
||||
})
|
||||
"doubleSided": bool(mat.get("double_sided", True)),
|
||||
}
|
||||
if occlusion_in_mr and mr_png_bytes is not None and has_uv:
|
||||
# ORM packing: occlusionTexture reuses the MR image (glTF reads its R channel).
|
||||
material["occlusionTexture"] = {"index": mr_texture_index, "texCoord": 0,
|
||||
"strength": float(mat.get("occlusion_strength", 1.0))}
|
||||
if nm_png_bytes is not None and has_uv:
|
||||
material["normalTexture"] = {"index": add_image_texture(nm_byte_offset, len(nm_buffer)),
|
||||
"texCoord": 0, "scale": float(mat.get("normal_scale", 1.0))}
|
||||
|
||||
emissive_factor = [float(x) for x in mat.get("emissive_factor", [0.0, 0.0, 0.0])]
|
||||
emissive_strength = float(mat.get("emissive_strength", 1.0))
|
||||
has_em_tex = em_png_bytes is not None and has_uv
|
||||
if any(c > 0.0 for c in emissive_factor) or has_em_tex:
|
||||
# glTF multiplies emissiveFactor × texture, so a texture with no color would go black;
|
||||
# default the factor to white in that case.
|
||||
if has_em_tex and not any(c > 0.0 for c in emissive_factor):
|
||||
emissive_factor = [1.0, 1.0, 1.0]
|
||||
material["emissiveFactor"] = [min(1.0, c) for c in emissive_factor]
|
||||
if has_em_tex:
|
||||
material["emissiveTexture"] = {"index": add_image_texture(em_byte_offset, len(em_buffer)),
|
||||
"texCoord": 0}
|
||||
if emissive_strength != 1.0:
|
||||
material.setdefault("extensions", {})["KHR_materials_emissive_strength"] = {
|
||||
"emissiveStrength": emissive_strength}
|
||||
if "KHR_materials_emissive_strength" not in extensions_used:
|
||||
extensions_used.append("KHR_materials_emissive_strength")
|
||||
|
||||
materials.append(material)
|
||||
primitive["material"] = 0
|
||||
|
||||
gltf = {
|
||||
@ -556,6 +658,22 @@ class SaveGLB(IO.ComfyNode):
|
||||
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
|
||||
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
|
||||
)
|
||||
nm_b = getattr(mesh, "normal_map", None)
|
||||
nm_np = None
|
||||
if nm_b is not None:
|
||||
nm_np = (nm_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert nm_np.ndim == 4 and nm_np.shape[-1] == 3, (
|
||||
f"normal_map must be (B, H, W, 3), got shape {tuple(nm_np.shape)}"
|
||||
)
|
||||
em_b = getattr(mesh, "emissive", None)
|
||||
em_np = None
|
||||
if em_b is not None:
|
||||
em_np = (em_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert em_np.ndim == 4 and em_np.shape[-1] == 3, (
|
||||
f"emissive must be (B, H, W, 3), got shape {tuple(em_np.shape)}"
|
||||
)
|
||||
tangents_b = getattr(mesh, "tangents", None)
|
||||
material = getattr(mesh, "material", None)
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
vertices_i, faces_i, v_colors, uvs_i, normals_i = get_mesh_batch_item(mesh, i)
|
||||
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
|
||||
@ -563,6 +681,9 @@ class SaveGLB(IO.ComfyNode):
|
||||
continue
|
||||
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
|
||||
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
|
||||
nm_img = Image.fromarray(nm_np[i], mode="RGB") if nm_np is not None else None
|
||||
em_img = Image.fromarray(em_np[i], mode="RGB") if em_np is not None else None
|
||||
tangents_i = tangents_b[i, :vertices_i.shape[0]] if tangents_b is not None else None
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(
|
||||
vertices_i, faces_i,
|
||||
@ -574,6 +695,11 @@ class SaveGLB(IO.ComfyNode):
|
||||
metallic_roughness_image=mr_img,
|
||||
unlit=getattr(mesh, "unlit", False),
|
||||
normals=normals_i,
|
||||
normal_map_image=nm_img,
|
||||
tangents=tangents_i,
|
||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
||||
material=material,
|
||||
emissive_image=em_img,
|
||||
)
|
||||
results.append({
|
||||
"filename": f,
|
||||
@ -723,9 +849,11 @@ class MeshSmoothNormals(IO.ComfyNode):
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
|
||||
tangents_b = getattr(mesh, "tangents", None)
|
||||
v_list, f_list, n_list = [], [], []
|
||||
c_list = [] if mesh.vertex_colors is not None else None
|
||||
u_list = [] if mesh.uvs is not None else None
|
||||
t_list = [] if tangents_b is not None else None
|
||||
for i in range(batch_size):
|
||||
v_i, f_i, c_i, u_i, _ = get_mesh_batch_item(mesh, i)
|
||||
if v_i.shape[0] == 0 or f_i.shape[0] == 0:
|
||||
@ -742,12 +870,19 @@ class MeshSmoothNormals(IO.ComfyNode):
|
||||
c_list.append(c_i[remap_t.to(c_i.device)])
|
||||
if u_list is not None:
|
||||
u_list.append(u_i[remap_t.to(u_i.device)])
|
||||
if t_list is not None:
|
||||
# Remap (not recompute) so TANGENT keeps the baked basis; split verts copy theirs.
|
||||
t_i = tangents_b[i, :v_i.shape[0]]
|
||||
t_list.append(t_i[remap_t.to(t_i.device)])
|
||||
if not v_list:
|
||||
return IO.NodeOutput(mesh)
|
||||
out = pack_variable_mesh_batch(
|
||||
v_list, f_list, colors=c_list, uvs=u_list,
|
||||
texture=mesh.texture, unlit=getattr(mesh, "unlit", False),
|
||||
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None))
|
||||
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None),
|
||||
tangents=t_list, normal_map=getattr(mesh, "normal_map", None),
|
||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
||||
material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None))
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
from server import PromptServer
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
@ -414,38 +415,44 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
"y_up" if proj_pack is not None else "z_up")}
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
|
||||
"""Run DINOv3 once at the requested resolution.
|
||||
|
||||
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||
image_bchw: [B, 3, H, W] float in [0, 1] (any source resolution; resized here).
|
||||
Returns the full sequence tensor (Trellis2 path) or a dict with the global
|
||||
tokens split out + a 2D patch grid (Pixal3D path) when `want_patches=True`.
|
||||
"""
|
||||
model_internal = model.model
|
||||
device = comfy.model_management.get_torch_device()
|
||||
img_t = comfy.utils.common_upscale(image_bchw, image_size, image_size, "lanczos", "disabled").to(device)
|
||||
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
|
||||
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
|
||||
img_t = (img_t - mean) / std
|
||||
model_internal.image_size = image_size
|
||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||
if not want_patches:
|
||||
return tokens
|
||||
h_p = w_p = image_size // 16
|
||||
n_reg = tokens.shape[1] - 1 - h_p * w_p
|
||||
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
|
||||
|
||||
|
||||
def run_conditioning(model, cropped_pil_img, include_1024=True):
|
||||
device = comfy.model_management.intermediate_device()
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def prepare_tensor(pil_img, size):
|
||||
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||
img_np = np.array(resized_pil).astype(np.float32) / 255.0
|
||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
|
||||
model_internal.image_size = 512
|
||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||
|
||||
cond_1024 = None
|
||||
if include_1024:
|
||||
model_internal.image_size = 1024
|
||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||
img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0
|
||||
image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous()
|
||||
|
||||
cond_512 = _dinov3_encode(model, image_bchw, 512)
|
||||
conditioning = {
|
||||
'cond_512': cond_512.to(device),
|
||||
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||
"cond_512": cond_512.to(device),
|
||||
"neg_cond": torch.zeros_like(cond_512).to(device),
|
||||
}
|
||||
if cond_1024 is not None:
|
||||
conditioning['cond_1024'] = cond_1024.to(device)
|
||||
|
||||
if include_1024:
|
||||
cond_1024 = _dinov3_encode(model, image_bchw, 1024)
|
||||
conditioning["cond_1024"] = cond_1024.to(device)
|
||||
return conditioning
|
||||
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -780,21 +787,6 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
||||
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
||||
|
||||
|
||||
def _run_dinov3_with_patches(model, composite, image_size):
|
||||
model_internal = model.model
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
img_t = comfy.utils.common_upscale(composite, image_size, image_size, "lanczos", "disabled")
|
||||
img_t = img_t.to(torch_device)
|
||||
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
model_internal.image_size = image_size
|
||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||
patches = _dinov3_patches_to_2d(tokens, image_size)
|
||||
h_p = w_p = image_size // 16
|
||||
n_reg = tokens.shape[1] - 1 - h_p * w_p
|
||||
global_tokens = tokens[:, :1 + n_reg]
|
||||
return {"tokens": global_tokens, "patches_2d": patches}
|
||||
|
||||
|
||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
|
||||
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
|
||||
@ -802,6 +794,12 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
|
||||
# Detect & correct an inverted mask
|
||||
m2d = mask[0, 0]
|
||||
border = torch.cat([m2d[0, :], m2d[-1, :], m2d[:, 0], m2d[:, -1]])
|
||||
if float(border.mean()) > 0.5:
|
||||
mask = 1.0 - mask
|
||||
|
||||
H, W = img.shape[-2:]
|
||||
if max(H, W) > max_image_size:
|
||||
scale = max_image_size / max(H, W)
|
||||
@ -923,8 +921,8 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
scene_size_list.append(scene_size)
|
||||
composite_list.append(composite)
|
||||
|
||||
cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512)
|
||||
cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024)
|
||||
cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True)
|
||||
cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True)
|
||||
cond_512_list.append(cond_512["tokens"].to(device))
|
||||
cond_1024_list.append(cond_1024["tokens"].to(device))
|
||||
patches_512_list.append(cond_512["patches_2d"].to(device))
|
||||
@ -1104,8 +1102,7 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
moge_per_vertex = moge_points[batch_index, sy, sx]
|
||||
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
|
||||
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
|
||||
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
|
||||
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
|
||||
# frame as vertices_one (Pixal3D model frame = glTF Y-up).
|
||||
moge_per_vertex = moge_per_vertex * torch.tensor(
|
||||
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
|
||||
)
|
||||
@ -1188,6 +1185,7 @@ class GetMeshInfo(IO.ComfyNode):
|
||||
IO.Mesh.Output(display_name="mesh"),
|
||||
IO.String.Output(display_name="info"),
|
||||
],
|
||||
hidden=[IO.Hidden.unique_id],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1212,10 +1210,10 @@ class GetMeshInfo(IO.ComfyNode):
|
||||
f_counts = [int(mesh.faces.shape[1])] * B
|
||||
|
||||
attrs = []
|
||||
for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"):
|
||||
for name in ("uvs", "vertex_colors", "normals", "tangents", "texture", "metallic_roughness", "normal_map"):
|
||||
t = getattr(mesh, name, None)
|
||||
if t is not None:
|
||||
if name in ("texture", "metallic_roughness"):
|
||||
if name in ("texture", "metallic_roughness", "normal_map"):
|
||||
attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W
|
||||
else:
|
||||
attrs.append(name)
|
||||
@ -1234,6 +1232,9 @@ class GetMeshInfo(IO.ComfyNode):
|
||||
|
||||
info = "\n".join(lines)
|
||||
logging.info("[GetMeshInfo]\n%s", info)
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(info, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user