From 42ac23f6f63b68c863559f84ecf9bffab4d9acb7 Mon Sep 17 00:00:00 2001 From: kijai Date: Tue, 30 Jun 2026 01:18:33 +0300 Subject: [PATCH] Normal and AO baking --- comfy_extras/nodes_mesh_postprocess.py | 725 +++++++++++++++++++++++-- comfy_extras/nodes_save_3d.py | 255 +++++++-- comfy_extras/nodes_trellis2.py | 91 ++-- 3 files changed, 920 insertions(+), 151 deletions(-) diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 090e55399..68cb3a603 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -7,7 +7,7 @@ import copy import comfy.utils import comfy.model_management from server import PromptServer -from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate +from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate, _compute_vertex_normals from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg @@ -166,23 +166,23 @@ class PaintMesh(IO.ComfyNode): return IO.NodeOutput(out_mesh) -def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): - """Rasterize the mesh in UV space and barycentric-interpolate the per-vertex vec3 - (world position, or any vec3 attr e.g. normals) at each covered texel. Pure torch, - tiled point-in-triangle — no GL/EGL, runs anywhere torch does. Returns (attr_map - [H,W,3] float32, mask [H,W] bool). """ +def _rasterize_uv_barycentric(faces_np, uvs_np, texture_size): + """Rasterize the mesh in UV space (tiled point-in-triangle, pure torch). Returns per-texel + face index [H,W], barycentric coords [H,W,3] and coverage mask [H,W], on the torch device. + Interpolate any per-vertex attribute from these with _interp_vertex_attr.""" dev = comfy.model_management.get_torch_device() H = W = int(texture_size) + face_idx = torch.zeros((H, W), dtype=torch.long, device=dev) + bary = torch.zeros((H, W, 3), device=dev) + cov = torch.zeros((H, W), dtype=torch.bool, device=dev) if faces_np.shape[0] == 0: - return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool) + return face_idx, bary, cov - verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev) uvs = torch.from_numpy(np.ascontiguousarray(uvs_np, dtype=np.float32)).to(dev) faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev) # GL convention: window coord = uv * resolution, coverage tested at texel centre. tri_uv = (uvs * float(W))[faces] # [F,3,2] - tri_attr = verts[faces] # [F,3,3] x0, y0 = tri_uv[:, 0, 0], tri_uv[:, 0, 1] x1, y1 = tri_uv[:, 1, 0], tri_uv[:, 1, 1] x2, y2 = tri_uv[:, 2, 0], tri_uv[:, 2, 1] @@ -194,9 +194,6 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, H - 1).long() ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, H - 1).long() - pos_out = torch.zeros((H, W, 3), device=dev) - cov = torch.zeros((H, W), dtype=torch.bool, device=dev) - # Tile so point-in-triangle only runs over the triangles whose bbox hits the tile. TILE = 64 eps = 1e-6 @@ -224,15 +221,40 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): continue hit = inside.any(dim=0) # [th,tw] sel = inside.int().argmax(dim=0) # [th,tw] first covering local tri - b0s = b0.gather(0, sel[None]).squeeze(0) # [th,tw] bary of selected tri - b1s = b1.gather(0, sel[None]).squeeze(0) - b2s = b2.gather(0, sel[None]).squeeze(0) - p = tri_attr[idx[sel]] # [th,tw,3,3] - attr = b0s[..., None] * p[..., 0, :] + b1s[..., None] * p[..., 1, :] + b2s[..., None] * p[..., 2, :] - pos_out[ty:ty_end, tx:tx_end][hit] = attr[hit] # slice is a view → writes through + bsel = torch.stack([b0.gather(0, sel[None]).squeeze(0), + b1.gather(0, sel[None]).squeeze(0), + b2.gather(0, sel[None]).squeeze(0)], dim=-1) # [th,tw,3] + face_idx[ty:ty_end, tx:tx_end][hit] = idx[sel][hit] # slice is a view → writes through + bary[ty:ty_end, tx:tx_end][hit] = bsel[hit] cov[ty:ty_end, tx:tx_end] |= hit - return pos_out.cpu().numpy(), cov.cpu().numpy() + return face_idx, bary, cov + + +def _interp_vertex_attr(attr_v, faces, face_idx, bary, mask): + """Interpolate a per-vertex attribute [N,C] into a [H,W,C] map via a rasterized + (face_idx, bary, mask). Uncovered texels stay zero.""" + H, W = mask.shape + out = torch.zeros((H, W, attr_v.shape[1]), device=attr_v.device, dtype=attr_v.dtype) + if mask.any(): + vtri = attr_v[faces[face_idx[mask]]] # [K,3,C] + out[mask] = (bary[mask][:, :, None] * vtri).sum(1) + return out + + +def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): + """Barycentric-interpolate a per-vertex vec3 (world position, or any vec3 e.g. normals) + at each covered texel. Returns (attr_map [H,W,3] float32, mask [H,W] bool).""" + dev = comfy.model_management.get_torch_device() + H = W = int(texture_size) + if faces_np.shape[0] == 0: + return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool) + + face_idx, bary, mask = _rasterize_uv_barycentric(faces_np, uvs_np, texture_size) + verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev) + faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev) + attr = _interp_vertex_attr(verts, faces, face_idx, bary, mask) + return attr.cpu().numpy(), mask.cpu().numpy() def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): @@ -565,9 +587,10 @@ def _build_triangle_bvh(tri): return dict(LEAF=LEAF, left=left, right=right, nmin=nmin, nmax=nmax, order=order, T=T) -def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): +def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64, return_face=False): """Exact closest surface point per query via per-query BVH stack traversal - (nearest-child-first), pure torch. Returns [N,3]. `max_stack` bounds the stack + (nearest-child-first), pure torch. Returns [N,3], or (points [N,3], face_idx [N]) + when return_face=True (face_idx indexes `tri`). `max_stack` bounds the stack (= tree height); overflow is counted+warned, not silently wrong.""" dev = Q.device N = Q.shape[0] @@ -582,6 +605,7 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): stack[:, 0] = 0 best = torch.full((N,), 1e30, device=dev) bestp = Q.clone() + bestf = torch.full((N,), -1, dtype=torch.long, device=dev) active = torch.arange(N, device=dev) overflow = 0 @@ -599,12 +623,14 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): lv = within & isleaf if bool(lv.any()): ga = a[lv] - tt = tri[order[node[lv] - LEAF]] + fidx = order[node[lv] - LEAF] # triangle index of each leaf + tt = tri[fidx] cp, d2 = _point_tri_closest(qa[lv], tt) upd = d2 < best[ga] gu = ga[upd] best[gu] = d2[upd] bestp[gu] = cp[upd] + bestf[gu] = fidx[upd] iv = within & ~isleaf if bool(iv.any()): gi = a[iv] @@ -627,6 +653,8 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): logging.warning(f"[back-project] BVH stack overflow on {overflow} pushes " f"(max_stack={max_stack}); a few texels may be slightly off — " f"raise max_stack if this is large.") + if return_face: + return bestp, bestf return bestp @@ -652,6 +680,352 @@ def _back_project_positions(position_map, mask, ref_v, ref_f): return out +def _ray_tri_hit(o, d, tri, tmin, tmax): + """Möller-Trumbore any-hit per (ray, triangle) pair, double-sided. Returns bool [N].""" + a, b, c = tri[:, 0], tri[:, 1], tri[:, 2] + e1, e2 = b - a, c - a + p = torch.cross(d, e2, dim=-1) + det = (e1 * p).sum(-1) + inv = 1.0 / torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det) + tvec = o - a + u = (tvec * p).sum(-1) * inv + q = torch.cross(tvec, e1, dim=-1) + v = (d * q).sum(-1) * inv + t = (e2 * q).sum(-1) * inv + return (det.abs() > 1e-20) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > tmin) & (t < tmax) + + +def _any_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64): + """Any-hit ray test over the BVH (slab cull + Möller-Trumbore), pure torch. Returns bool + [N]: True if the ray hits any triangle in (tmin, tmax). Rays early-out once they hit.""" + dev = orig.device + N = orig.shape[0] + LEAF = bvh['LEAF'] + nmin, nmax = bvh['nmin'], bvh['nmax'] + left, right, order = bvh['left'], bvh['right'], bvh['order'] + inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs) + hit = torch.zeros(N, dtype=torch.bool, device=dev) + # int32 stack: node indices fit in 31 bits and this [N, max_stack] array dominates memory. + stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev) + sp = torch.ones(N, dtype=torch.long, device=dev) + stack[:, 0] = 0 + active = torch.arange(N, device=dev) + + def slab(node, o, i): + t1 = (nmin[node] - o) * i + t2 = (nmax[node] - o) * i + tnear = torch.minimum(t1, t2).amax(-1) + tfar = torch.maximum(t1, t2).amin(-1) + return (tfar >= tnear.clamp_min(tmin)) & (tnear <= tmax) & (tfar >= tmin) + + while active.numel() > 0: + a = active + node = stack[a, sp[a] - 1] + sp[a] = sp[a] - 1 + within = slab(node, orig[a], inv[a]) + isleaf = node >= LEAF + lv = within & isleaf + if bool(lv.any()): + ga = a[lv] + tt = tri[order[node[lv] - LEAF]] + h = _ray_tri_hit(orig[ga], dirs[ga], tt, tmin, tmax) + hit[ga[h]] = True + iv = within & ~isleaf + if bool(iv.any()): + gi = a[iv] + s0 = sp[gi] + stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32) + sp[gi] = (s0 + 1).clamp(max=max_stack) + s1 = sp[gi] + stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32) + sp[gi] = (s1 + 1).clamp(max=max_stack) + active = a[(sp[a] > 0) & ~hit[a]] # drop finished + already-hit rays + return hit + + +def _ray_tri_intersect(o, d, tri, tmin, tmax, cull_backface=False): + """Möller-Trumbore per (ray, triangle) pair. Returns (hit [N], t [N]) where t is the ray + parameter and hit means the meeting is in (tmin, tmax). With cull_backface, drops faces whose + outward (winding) normal points along the ray — i.e. only keep surfaces facing the origin.""" + a, b, c = tri[:, 0], tri[:, 1], tri[:, 2] + e1, e2 = b - a, c - a + p = torch.cross(d, e2, dim=-1) + det = (e1 * p).sum(-1) + inv = 1.0 / torch.where(det.abs() < 1e-20, torch.full_like(det, 1e-20), det) + tvec = o - a + u = (tvec * p).sum(-1) * inv + q = torch.cross(tvec, e1, dim=-1) + v = (d * q).sum(-1) * inv + t = (e2 * q).sum(-1) * inv + hit = (det.abs() > 1e-20) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > tmin) & (t < tmax) + if cull_backface: + hit = hit & ((torch.cross(e1, e2, dim=-1) * d).sum(-1) < 0) # keep only front-facing + return hit, t + + +def _closest_hit_rays_bvh(orig, dirs, tri, bvh, tmin=0.0, tmax=1e30, max_stack=64, cull_backface=False): + """Nearest-hit ray cast over the BVH, pure torch. Returns (t [N], face [N] long, -1 on + miss; hit [N] bool) — the closest intersection in (tmin, tmax), pruning nodes past best_t.""" + dev = orig.device + N = orig.shape[0] + LEAF = bvh['LEAF'] + nmin, nmax = bvh['nmin'], bvh['nmax'] + left, right, order = bvh['left'], bvh['right'], bvh['order'] + inv = 1.0 / torch.where(dirs.abs() < 1e-20, torch.full_like(dirs, 1e-20), dirs) + best_t = torch.full((N,), float(tmax), device=dev) + best_f = torch.full((N,), -1, dtype=torch.long, device=dev) + stack = torch.full((N, max_stack), -1, dtype=torch.int32, device=dev) + sp = torch.ones(N, dtype=torch.long, device=dev) + stack[:, 0] = 0 + active = torch.arange(N, device=dev) + + while active.numel() > 0: + a = active + node = stack[a, sp[a] - 1] + sp[a] = sp[a] - 1 + t1 = (nmin[node] - orig[a]) * inv[a] + t2 = (nmax[node] - orig[a]) * inv[a] + tnear = torch.minimum(t1, t2).amax(-1) + tfar = torch.maximum(t1, t2).amin(-1) + within = (tfar >= tnear.clamp_min(tmin)) & (tfar >= tmin) & (tnear < best_t[a]) # prune past best + isleaf = node >= LEAF + lv = within & isleaf + if bool(lv.any()): + ga = a[lv] + fidx = order[node[lv] - LEAF] + h, t = _ray_tri_intersect(orig[ga], dirs[ga], tri[fidx], tmin, tmax, cull_backface) + upd = h & (t < best_t[ga]) + gu = ga[upd] + best_t[gu] = t[upd] + best_f[gu] = fidx[upd] + iv = within & ~isleaf + if bool(iv.any()): + gi = a[iv] + s0 = sp[gi] + stack[gi, s0.clamp(max=max_stack - 1)] = left[node[iv]].to(torch.int32) + sp[gi] = (s0 + 1).clamp(max=max_stack) + s1 = sp[gi] + stack[gi, s1.clamp(max=max_stack - 1)] = right[node[iv]].to(torch.int32) + sp[gi] = (s1 + 1).clamp(max=max_stack) + active = a[sp[a] > 0] + return best_t, best_f, best_f >= 0 + + +def _onb(n): + """Branchless orthonormal basis (t, b) around unit normals n [N,3].""" + up = torch.where(n[..., 2:3].abs() < 0.999, + torch.tensor([0.0, 0.0, 1.0], device=n.device).expand_as(n), + torch.tensor([1.0, 0.0, 0.0], device=n.device).expand_as(n)) + t = torch.nn.functional.normalize(torch.cross(up, n, dim=-1), dim=-1, eps=1e-6) + return t, torch.cross(n, t, dim=-1) + + +def _bake_ambient_occlusion(high_v, high_f, low_v_np, low_f_np, low_uv_np, low_n, resolution, + num_samples=64, max_distance=0.5, strength=1.0, bias=0.01, + ray_chunk=None, pbar=None): + """Bake high-poly ambient occlusion into the low-poly's UV layout: per texel, cosine-weight + a hemisphere of rays around the normal and cast them at the high-poly. AO = 1 - hit-fraction + (cosine weighting makes the hit-fraction the estimator). Returns ao_img [H,W,3] in [0,1]. + + ray_chunk caps rays cast at once (the per-chunk BVH stack is its dominant transient VRAM); + None auto-sizes it to a slice of free VRAM — big chunks (fast) on large GPUs, small (safe) + on small ones.""" + dev = comfy.model_management.get_torch_device() + H = W = int(resolution) + S = int(num_samples) + if ray_chunk is None: + # ~376 B/ray (int32 stack max_stack*4 + a few [N,3] ray buffers); spend a quarter of free + # VRAM. Speed saturates around 4M rays/chunk, so cap there (≈2 GB peak) rather than grow + # memory for no gain; floor keeps tiny GPUs from thrashing into too many chunks. + try: + free = torch.cuda.mem_get_info(dev)[0] if dev.type == "cuda" else (2 << 30) + except Exception: + free = 2 << 30 + ray_chunk = int(min(1 << 22, max(1 << 20, (free * 0.25) / (num_samples * 4 + 200)))) + face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution) + if not mask.any(): + return np.ones((H, W, 3), dtype=np.float32) + lf = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(dev) + lv = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(dev) + low_n = low_n.to(dev).float() + m = mask + vtri = lf[face_idx[m]] # [K,3] vertex ids + bsel = bary_uv[m] # [K,3] + P = (bsel[:, :, None] * lv[vtri]).sum(1) # [K,3] + Nl = torch.nn.functional.normalize((bsel[:, :, None] * low_n[vtri]).sum(1), dim=-1, eps=1e-6) + + hv = high_v.to(dev).float() + hf = high_f.to(dev).long() + tri = hv[hf] + bvh = _build_triangle_bvh(tri) + diag = float((hv.amax(0) - hv.amin(0)).norm().clamp_min(1e-6)) + biasw = max(1e-5, float(bias) * diag) + tmax = float(max_distance) * diag + + # Back-project onto the high surface, then lift along the normal: the low-poly chord can sit + # below the high surface, and casting from below floods false self-occlusion (dark blotches). + bp = _closest_points_on_mesh_bvh(P, tri, bvh) + origins = bp + Nl * biasw + + K = P.shape[0] + T, B = _onb(Nl) + occ = torch.zeros(K, device=dev) + tex_per_chunk = max(1, int(ray_chunk) // max(1, S)) + for s in range(0, K, tex_per_chunk): + e = min(s + tex_per_chunk, K) + kk = e - s + o, n, t, b = origins[s:e], Nl[s:e], T[s:e], B[s:e] + r1 = torch.rand(kk, S, device=dev) + r2 = torch.rand(kk, S, device=dev) + sr = r1.sqrt() + lz = r1.mul_(-1.0).add_(1.0).clamp_min_(0.0).sqrt_() # sqrt(1-r1) (r1 dead after sr) + ang = r2.mul_(2.0 * math.pi) # in place (r2 dead) + lx = ang.cos().mul_(sr) + ly = ang.sin().mul_(sr) + d = t[:, None, :] * lx[..., None] # cosine-weighted hemisphere, + d.addcmul_(b[:, None, :], ly[..., None]) # fused d += b*ly + d.addcmul_(n[:, None, :], lz[..., None]) # fused d += n*lz (no extra temps) + d = torch.nn.functional.normalize(d.reshape(-1, 3), dim=-1, eps=1e-6) + oo = o[:, None, :].expand(-1, S, -1).reshape(-1, 3) + hit = _any_hit_rays_bvh(oo, d, tri, bvh, tmin=biasw, tmax=tmax) + occ[s:e] = hit.reshape(kk, S).sum(1, dtype=torch.float32).div_(float(S)) # mean without a float copy + if pbar is not None: + pbar.update(1) + + ao = occ.mul_(-float(strength)).add_(1.0).clamp_(0.0, 1.0) # 1 - occ*strength, in place (occ is dead) + out = torch.ones((H, W), device=dev) + out[m] = ao + out3 = np.repeat(out.cpu().numpy()[..., None], 3, axis=2) + return _jfa_fill_gpu(np.ascontiguousarray(out3, dtype=np.float32), mask.cpu().numpy()) + + +def _compute_vertex_tangents(verts, faces, uvs, normals): + """Per-vertex tangents (Lengyel) orthonormalized against `normals`. Returns [N,4]: + unit tangent xyz + handedness w (the bitangent is w * cross(N, T)). Pure torch.""" + N = verts.shape[0] + i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long() + e1, e2 = verts[i1] - verts[i0], verts[i2] - verts[i0] + d1, d2 = uvs[i1] - uvs[i0], uvs[i2] - uvs[i0] + denom = d1[:, 0] * d2[:, 1] - d2[:, 0] * d1[:, 1] + r = 1.0 / torch.where(denom.abs() < 1e-20, torch.full_like(denom, 1e-20), denom) + tan = (d2[:, 1:2] * e1 - d1[:, 1:2] * e2) * r[:, None] # [F,3] + bit = (d1[:, 0:1] * e2 - d2[:, 0:1] * e1) * r[:, None] + tacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype) + bacc = torch.zeros((N, 3), device=verts.device, dtype=verts.dtype) + for idx in (i0, i1, i2): + tacc.scatter_add_(0, idx[:, None].expand(-1, 3), tan) + bacc.scatter_add_(0, idx[:, None].expand(-1, 3), bit) + n = torch.nn.functional.normalize(normals, dim=-1, eps=1e-6) + # Gram-Schmidt: drop the normal component, then renormalize. + t = torch.nn.functional.normalize(tacc - n * (n * tacc).sum(-1, keepdim=True), dim=-1, eps=1e-6) + w = torch.sign((torch.cross(n, t, dim=-1) * bacc).sum(-1)) + w = torch.where(w == 0, torch.ones_like(w), w) # degenerate → right-handed + return torch.cat([t, w[:, None]], dim=-1) + + +def _vertex_tangents_for_item(lv, lf, uv, low_n_attr_i, dev): + """Per-item shading normals + tangents. Shared by the bake (BakeNormalMapFromMesh) and the + export attach (ApplyTextureToMesh) so their basis can't diverge. `low_n_attr_i` is the + mesh's per-item normals or None (then computed). Returns (low_n [N,3], tangents [N,4]).""" + low_n = low_n_attr_i.to(dev).float() if low_n_attr_i is not None else _compute_vertex_normals(lv, lf) + tangents = _compute_vertex_tangents(lv, lf, uv.to(dev).float(), low_n) + return low_n, tangents + + +def _barycentric(p, tri): + """Barycentric coords [N,3] of points p [N,3] wrt triangles tri [N,3,3] (per-pair).""" + a, b, c = tri[:, 0], tri[:, 1], tri[:, 2] + v0, v1, v2 = b - a, c - a, p - a + d00 = (v0 * v0).sum(-1) + d01 = (v0 * v1).sum(-1) + d11 = (v1 * v1).sum(-1) + d20 = (v2 * v0).sum(-1) + d21 = (v2 * v1).sum(-1) + denom = d00 * d11 - d01 * d01 + denom = torch.where(denom.abs() < 1e-20, torch.full_like(denom, 1e-20), denom) + v = (d11 * d20 - d01 * d21) / denom + w = (d00 * d21 - d01 * d20) / denom + return torch.stack([1.0 - v - w, v, w], dim=-1) + + +def _bake_normal_map(high_v, high_f, high_n, low_v_np, low_f_np, low_uv_np, low_n, tangents, + resolution, cage_distance=0.05, ignore_backfaces=True): + """Tangent-space normal map (glTF/OpenGL +Y) of the high-poly baked into the low-poly's UV + layout. Per texel a cage ray (along the normal, over cage_distance * bbox-diagonal) finds the + matching high-poly surface, whose normal is projected into the texel's TBN frame. + ignore_backfaces skips surfaces facing away (crevices/enclosures); misses fall back to + closest-point. Returns [H,W,3] in [0,1].""" + dev = comfy.model_management.get_torch_device() + H = W = int(resolution) + flat = np.array([0.5, 0.5, 1.0], dtype=np.float32) + + # One rasterization, then interpolate position/normal/tangent/handedness by indexing it. + face_idx, bary_uv, mask = _rasterize_uv_barycentric(low_f_np, low_uv_np, resolution) + if not mask.any(): + return np.tile(flat, (H, W, 1)) + lf = torch.from_numpy(np.ascontiguousarray(low_f_np).astype(np.int64)).to(dev) + lv = torch.from_numpy(np.ascontiguousarray(low_v_np, dtype=np.float32)).to(dev) + low_n = low_n.to(dev).float() + tangents = tangents.to(dev).float() + m = mask + fsel = face_idx[m] # [K] source face per texel + bsel = bary_uv[m] # [K,3] + vtri = lf[fsel] # [K,3] vertex ids + + def _interp(attr): # attr [N,C] -> [K,C] + return (bsel[:, :, None] * attr[vtri]).sum(1) + + P = _interp(lv) # [K,3] world pos + Nl = torch.nn.functional.normalize(_interp(low_n), dim=-1, eps=1e-6) + Tl = _interp(tangents[:, :3]) + Wl = _interp(tangents[:, 3:4])[:, 0] + + hv = high_v.to(dev).float() + hf = high_f.to(dev).long() + tri = hv[hf] + bvh = _build_triangle_bvh(tri) + + # Cage ray-cast: from cage outward, march back along -normal and take the nearest (outermost) + # hit. Closest-point is the fallback where the ray misses. + diag = float((hv.amax(0) - hv.amin(0)).norm().clamp_min(1e-6)) + cage = max(1e-6, float(cage_distance) * diag) + origin = P + Nl * cage + t_hit, f_hit, ray_hit = _closest_hit_rays_bvh(origin, -Nl, tri, bvh, tmin=0.0, tmax=2.0 * cage, + cull_backface=bool(ignore_backfaces)) + bface = f_hit.clamp_min(0) + hitpoint = origin - t_hit[:, None] * Nl + # Closest-point fallback only for texels the ray missed (usually few) — running it over every + # texel wastes a full BVH traversal on the ones already resolved by the ray. + miss = ~ray_hit + if bool(miss.any()): + bp_m, bf_m = _closest_points_on_mesh_bvh(P[miss], tri, bvh, return_face=True) + bface = bface.clone() + hitpoint = hitpoint.clone() + bface[miss] = bf_m.clamp_min(0) + hitpoint[miss] = bp_m + + htri = tri[bface] # [K,3,3] + bary = _barycentric(hitpoint, htri) + hn_tri = high_n.to(dev).float()[hf[bface]] # [K,3,3] vertex normals + Nh = torch.nn.functional.normalize((bary[:, :, None] * hn_tri).sum(1), dim=-1, eps=1e-6) + + # Per-texel TBN (Gram-Schmidt tangent against the interpolated normal). + T = torch.nn.functional.normalize(Tl - Nl * (Nl * Tl).sum(-1, keepdim=True), dim=-1, eps=1e-6) + Bn = Wl[:, None] * torch.cross(Nl, T, dim=-1) + nz = (Nh * Nl).sum(-1) # reused as z-channel and the back-face test + ts = torch.stack([(Nh * T).sum(-1), (Nh * Bn).sum(-1), nz], dim=-1) + ts = torch.nn.functional.normalize(ts, dim=-1, eps=1e-6) + # Safety net: if the matched high normal faces away from the texel (a back surface the fallback + # grabbed in a deep crevice), use the flat base normal rather than a wrong one. + ts[nz < 0.0] = torch.tensor([0.0, 0.0, 1.0], device=dev) + enc = ts.mul_(0.5).add_(0.5).clamp_(0.0, 1.0) # encode in place (ts is dead) + + out = torch.from_numpy(np.tile(flat, (H, W, 1))).to(dev) + out[m] = enc + # Dilate into the UV gutter so bilinear/mip sampling at chart edges doesn't bleed flat blue. + return _jfa_fill_gpu(out.cpu().numpy(), mask.cpu().numpy()) + + def _jfa_fill_gpu(img01, mask): """Fill uncovered texels with nearest covered value via GPU Jump Flooding (O(log n) passes; replaces cv2.inpaint). img01 [H,W,C] float, mask [H,W] bool.""" @@ -919,15 +1293,17 @@ class MeshTextureToImage(IO.ComfyNode): display_name="Mesh Texture to Image", category="latent/3d", description=( - "Extracts a mesh's baked textures as IMAGE outputs: base_color and the packed " - "glTF MR map (G=roughness, B=metallic; black if no PBR texture)." + "Extracts a mesh's baked textures as individual IMAGEs: base_color, metallic, " + "roughness, occlusion and normal_map. Channels with nothing baked come back " + "neutral (occlusion white, normal flat)." ), inputs=[IO.Mesh.Input("mesh")], outputs=[ IO.Image.Output(display_name="base_color"), - IO.Image.Output(display_name="metallic_roughness"), IO.Image.Output(display_name="metallic"), IO.Image.Output(display_name="roughness"), + IO.Image.Output(display_name="occlusion"), + IO.Image.Output(display_name="normal_map"), ], ) @@ -944,6 +1320,7 @@ class MeshTextureToImage(IO.ComfyNode): base = _as_image(getattr(mesh, "texture", None)) mr = _as_image(getattr(mesh, "metallic_roughness", None)) + normal_map = _as_image(getattr(mesh, "normal_map", None)) if base is None: raise ValueError( @@ -952,10 +1329,18 @@ class MeshTextureToImage(IO.ComfyNode): ) if mr is None: mr = torch.zeros_like(base) - # Split packed MR into grayscale previews (G=roughness, B=metallic), to 3ch. + if normal_map is None: + normal_map = torch.ones_like(base) * torch.tensor([0.5, 0.5, 1.0]) # neutral flat normal + # Unpack the ORM map (R=occlusion, G=roughness, B=metallic) to 3-channel grayscale. metallic = mr[..., 2:3].expand(-1, -1, -1, 3).contiguous() roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous() - return IO.NodeOutput(base, mr, metallic, roughness) + # R is real occlusion only if AO was baked; else it's the unused zero channel, which as + # "occlusion" would read fully-dark — so report white unless occlusion_in_mr is set. + if getattr(mesh, "occlusion_in_mr", False): + occlusion = mr[..., 0:1].expand(-1, -1, -1, 3).contiguous() + else: + occlusion = torch.ones_like(base) + return IO.NodeOutput(base, metallic, roughness, occlusion, normal_map) class ApplyTextureToMesh(IO.ComfyNode): @@ -966,29 +1351,26 @@ class ApplyTextureToMesh(IO.ComfyNode): display_name="Apply Texture to Mesh", category="latent/3d", description=( - "Attaches baked texture IMAGEs to a mesh's existing UV layout for SaveGLB. " - "Pairs with BakeTextureFromVoxel: feed the SAME mesh and its base_color " - "(optionally metallic/roughness); don't re-unwrap in between. metallic/roughness " - "repack into the glTF MR map (G=roughness, B=metallic); missing metallic=0, " - "roughness=1." + "Attaches baked texture IMAGEs to a mesh's UV layout for SaveGLB. Feed the SAME mesh you baked" ), inputs=[ IO.Mesh.Input("mesh"), IO.Image.Input("base_color"), IO.Image.Input("metallic", optional=True), IO.Image.Input("roughness", optional=True), + IO.Image.Input("occlusion", optional=True), + IO.Image.Input("normal_map", optional=True), ], outputs=[IO.Mesh.Output("mesh")], ) @classmethod - def execute(cls, mesh, base_color, metallic=None, roughness=None): + def execute(cls, mesh, base_color, metallic=None, roughness=None, occlusion=None, normal_map=None): mesh_uvs = getattr(mesh, "uvs", None) if mesh_uvs is None: raise ValueError( "ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh " - "you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and " - "never unwraps).") + "you fed to BakeTextureFromVoxel.") # Re-derive the exact UVs the bake used (shared _normalize_uvs_to_unit), per item. if mesh_uvs.ndim == 3: @@ -1005,15 +1387,264 @@ class ApplyTextureToMesh(IO.ComfyNode): out_mesh = copy.copy(mesh) out_mesh.uvs = new_uvs out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu() - if metallic is not None or roughness is not None: - # Repack glTF MR (G=roughness, B=metallic); missing channel → scalar (metal 0/rough 1). - prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu() - B, H, W, _ = prov.shape - rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1] - if roughness is not None else torch.ones((B, H, W, 1))) - metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1] - if metallic is not None else torch.zeros((B, H, W, 1))) - out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1) + if normal_map is not None: + # Recompute tangents (shared helper, same normalized UVs → same basis as the bake) + # and export the smooth normals the TBN was built on — without a NORMAL attribute the + # viewer shades flat and the tangent-space detail fights the faceting. + dev = comfy.model_management.get_torch_device() + low_n_attr = getattr(mesh, "normals", None) + B = int(mesh.vertices.shape[0]) + Nmax = int(mesh.vertices.shape[1]) if mesh.vertices.ndim == 3 else int(mesh.vertices.shape[0]) + tangents_padded = torch.zeros((B, Nmax, 4), dtype=torch.float32) + normals_padded = torch.zeros((B, Nmax, 3), dtype=torch.float32) + for i in range(B): + v_i, f_i, _ = get_mesh_batch_item(mesh, i) + n = int(v_i.shape[0]) + if f_i.numel() == 0: + continue + lv, lf = v_i.to(dev).float(), f_i.to(dev).long() + uv_i = new_uvs[i, :n] if new_uvs.ndim == 3 else new_uvs[:n] + n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None + low_n, tangents = _vertex_tangents_for_item(lv, lf, uv_i, n_attr_i, dev) + tangents_padded[i, :n] = tangents.cpu() + normals_padded[i, :n] = low_n.cpu() + out_mesh.normal_map = normal_map.float().clamp(0.0, 1.0).cpu() + out_mesh.tangents = tangents_padded + out_mesh.normals = normals_padded + if metallic is not None or roughness is not None or occlusion is not None: + # Pack glTF ORM (R=occlusion, G=roughness, B=metallic); missing → 1/1/0. Maps may + # arrive at different resolutions, so resize each channel to a common H×W first. + provided = [x for x in (metallic, roughness, occlusion) if x is not None] + B = int(provided[0].shape[0]) + H = max(int(x.shape[1]) for x in provided) + W = max(int(x.shape[2]) for x in provided) + + def _chan(img, default): + if img is None: + return torch.full((B, H, W, 1), float(default)) + t = img.float().clamp(0.0, 1.0).cpu()[..., 0:1] + if int(t.shape[1]) != H or int(t.shape[2]) != W: + t = torch.nn.functional.interpolate(t.permute(0, 3, 1, 2), size=(H, W), + mode="bilinear", align_corners=False).permute(0, 2, 3, 1) + return t + + out_mesh.metallic_roughness = torch.cat( + [_chan(occlusion, 1.0), _chan(roughness, 1.0), _chan(metallic, 0.0)], dim=-1) + if occlusion is not None: + # Tells SaveGLB to also point occlusionTexture at the MR image (R = AO). + out_mesh.occlusion_in_mr = True + return IO.NodeOutput(out_mesh) + + +class BakeNormalMapFromMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BakeNormalMapFromMesh", + display_name="Bake Normal Map from Mesh", + category="latent/3d", + description=( + "Bakes a tangent-space normal map (glTF/OpenGL +Y) from a high-poly mesh into a " + "low-poly's UV layout, capturing detail lost to decimation. Feed the UV-unwrapped " + "low_poly and the same-frame high_poly it was decimated from. Outputs an IMAGE for " + "ApplyTextureToMesh's normal_map input." + ), + inputs=[ + IO.Mesh.Input("low_poly"), + IO.Mesh.Input("high_poly"), + IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64, + display_name="resolution"), + IO.Float.Input("cage_distance", default=0.05, min=0.001, max=0.5, step=0.001, + tooltip="Surface search band, as a fraction of the bbox diagonal. " + "Raise for wrong/missing patches under heavy decimation; " + "lower if it grabs across gaps."), + IO.Boolean.Input("ignore_backfaces", default=True, + tooltip="Skip high-poly surfaces facing away from the texel, so " + "crevices/enclosed spaces don't grab the opposite wall. " + "Disable only if the high-poly winding is inconsistent."), + ], + outputs=[IO.Image.Output(display_name="normal_map")], + ) + + @classmethod + def execute(cls, low_poly, high_poly, resolution, cage_distance=0.05, ignore_backfaces=True): + low_uvs = getattr(low_poly, "uvs", None) + if low_uvs is None: + raise ValueError( + "BakeNormalMapFromMesh: low_poly has no UVs. Connect the UV-unwrapped " + "low-poly (the same one you fed to BakeTextureFromVoxel); this node bakes " + "onto existing UVs and never unwraps.") + dev = comfy.model_management.get_torch_device() + + low_n_attr = getattr(low_poly, "normals", None) + high_n_attr = getattr(high_poly, "normals", None) + B = int(low_poly.vertices.shape[0]) + h_batch = int(high_poly.vertices.shape[0]) + + imgs = [] + for i in range(B): + v_i, f_i, _ = get_mesh_batch_item(low_poly, i) + n = int(v_i.shape[0]) + if f_i.numel() == 0: + logging.warning(f"BakeNormalMapFromMesh: skipping batch {i} (empty mesh)") + imgs.append(torch.full((int(resolution), int(resolution), 3), 0.5)) + continue + + uv_i = low_uvs[i, :n] if low_uvs.ndim == 3 else low_uvs[:n] + uv_np = _normalize_uvs_to_unit(uv_i.detach().cpu().numpy(), log_prefix="[BakeNormalMapFromMesh] ") + + lv = v_i.to(dev).float() + lf = f_i.to(dev).long() + # Tangents build the per-texel TBN; ApplyTextureToMesh recomputes the same basis on export. + n_attr_i = low_n_attr[i, :n] if low_n_attr is not None else None + low_n, tangents = _vertex_tangents_for_item(lv, lf, torch.from_numpy(uv_np).to(dev), n_attr_i, dev) + + hv_i, hf_i, _ = get_mesh_batch_item(high_poly, i if h_batch > 1 else 0) + hv = hv_i.to(dev).float() + hf = hf_i.to(dev).long() + high_n = (high_n_attr[i, :hv.shape[0]].to(dev).float() if high_n_attr is not None + else _compute_vertex_normals(hv, hf)) + + img = _bake_normal_map( + hv, hf, high_n, + lv.detach().cpu().numpy(), lf.detach().cpu().numpy().astype(np.uint32), uv_np, + low_n, tangents, resolution, cage_distance=float(cage_distance), + ignore_backfaces=bool(ignore_backfaces), + ) + imgs.append(torch.from_numpy(np.ascontiguousarray(img)).float()) + + normal_img = torch.stack([t.clamp(0.0, 1.0) for t in imgs], dim=0) + return IO.NodeOutput(normal_img) + + +class BakeAmbientOcclusion(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BakeAmbientOcclusion", + display_name="Bake Ambient Occlusion", + category="latent/3d", + description=( + "Bakes an ambient-occlusion map from a high-poly mesh into a low-poly's UV " + "layout (white = open, dark = crevices). Feed the UV-unwrapped low_poly and the " + "high_poly it was decimated from. Outputs a grayscale IMAGE for " + "ApplyTextureToMesh's occlusion input (packed into the ORM map / occlusionTexture)." + ), + inputs=[ + IO.Mesh.Input("low_poly"), + IO.Mesh.Input("high_poly"), + IO.Int.Input("resolution", default=1024, min=64, max=8192, step=64), + IO.Int.Input("samples", default=64, min=4, max=1024, step=4, + tooltip="Rays per texel. More = smoother, slower. Raise if grainy."), + IO.Float.Input("max_distance", default=0.5, min=0.01, max=2.0, step=0.01, + tooltip="Ray length, as a fraction of the bbox diagonal. " + "Smaller = tighter, more local occlusion."), + IO.Float.Input("strength", default=1.0, min=0.0, max=2.0, step=0.05, + tooltip="Scales the occlusion. >1 darkens, <1 lightens."), + IO.Float.Input("bias", default=0.01, min=0.0001, max=0.2, step=0.0005, + tooltip="Ray origin lift off the surface, as a fraction of the bbox " + "diagonal. Raise if even surfaces show dark blotches/holes."), + ], + outputs=[IO.Image.Output(display_name="occlusion")], + ) + + @classmethod + def execute(cls, low_poly, high_poly, resolution, samples, max_distance, strength, bias): + low_uvs = getattr(low_poly, "uvs", None) + if low_uvs is None: + raise ValueError( + "BakeAmbientOcclusion: low_poly has no UVs. Connect the UV-unwrapped low-poly " + "(the same one used for the other bakes); this node never unwraps.") + dev = comfy.model_management.get_torch_device() + low_n_attr = getattr(low_poly, "normals", None) + B = int(low_poly.vertices.shape[0]) + h_batch = int(high_poly.vertices.shape[0]) + + pbar = comfy.utils.ProgressBar(max(1, B)) # one tick per batch item + imgs = [] + for i in range(B): + v_i, f_i, _ = get_mesh_batch_item(low_poly, i) + n = int(v_i.shape[0]) + if f_i.numel() == 0: + logging.warning(f"BakeAmbientOcclusion: skipping batch {i} (empty mesh)") + imgs.append(torch.ones((int(resolution), int(resolution), 3))) + pbar.update(1) + continue + + uv_i = low_uvs[i, :n] if low_uvs.ndim == 3 else low_uvs[:n] + uv_np = _normalize_uvs_to_unit(uv_i.detach().cpu().numpy(), log_prefix="[BakeAmbientOcclusion] ") + lv = v_i.to(dev).float() + lf = f_i.to(dev).long() + low_n = (low_n_attr[i, :n].to(dev).float() if low_n_attr is not None + else _compute_vertex_normals(lv, lf)) + + hv_i, hf_i, _ = get_mesh_batch_item(high_poly, i if h_batch > 1 else 0) + img = _bake_ambient_occlusion( + hv_i.to(dev).float(), hf_i.to(dev).long(), + lv.detach().cpu().numpy(), lf.detach().cpu().numpy().astype(np.uint32), uv_np, + low_n, resolution, num_samples=int(samples), + max_distance=float(max_distance), strength=float(strength), bias=float(bias), + ) + imgs.append(torch.from_numpy(np.ascontiguousarray(img)).float()) + pbar.update(1) + + ao_img = torch.stack([t.clamp(0.0, 1.0) for t in imgs], dim=0) + return IO.NodeOutput(ao_img) + + +class SetMeshMaterial(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SetMeshMaterial", + display_name="Set Mesh Material", + category="latent/3d", + description=( + "Sets glTF material properties SaveGLB can't derive from textures: emissive " + "(color + strength + optional texture), baseColor tint, metallic/roughness " + "factors, normal scale, occlusion strength, double-sided. Place before SaveGLB." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Float.Input("emissive_r", default=0.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("emissive_g", default=0.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("emissive_b", default=0.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("emissive_strength", default=1.0, min=0.0, max=100.0, step=0.1, + tooltip=">1 for HDR glow (KHR_materials_emissive_strength)."), + IO.Image.Input("emissive_texture", optional=True), + IO.Float.Input("base_color_r", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("base_color_g", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("base_color_b", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Float.Input("metallic_factor", default=-1.0, min=-1.0, max=1.0, step=0.01, + tooltip="-1 = leave auto; 0..1 overrides."), + IO.Float.Input("roughness_factor", default=-1.0, min=-1.0, max=1.0, step=0.01, + tooltip="-1 = leave auto; 0..1 overrides."), + IO.Float.Input("normal_scale", default=1.0, min=0.0, max=10.0, step=0.05), + IO.Float.Input("occlusion_strength", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Boolean.Input("double_sided", default=True), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh, emissive_r, emissive_g, emissive_b, emissive_strength, + base_color_r, base_color_g, base_color_b, metallic_factor, roughness_factor, + normal_scale, occlusion_strength, double_sided, emissive_texture=None): + out_mesh = copy.copy(mesh) + material = dict(getattr(mesh, "material", {}) or {}) # merge over any prior material + material.update({ + "emissive_factor": [float(emissive_r), float(emissive_g), float(emissive_b)], + "emissive_strength": float(emissive_strength), + "base_color_factor": [float(base_color_r), float(base_color_g), float(base_color_b), 1.0], + "metallic_factor": float(metallic_factor), # <0 => leave auto + "roughness_factor": float(roughness_factor), + "normal_scale": float(normal_scale), + "occlusion_strength": float(occlusion_strength), + "double_sided": bool(double_sided), + }) + out_mesh.material = material + if emissive_texture is not None: + out_mesh.emissive = emissive_texture.float().clamp(0.0, 1.0).cpu() return IO.NodeOutput(out_mesh) @@ -2278,8 +2909,7 @@ class MergeMeshes(IO.ComfyNode): category="latent/3d", description=( "Concatenate N meshes into one by offsetting face indices and stacking verts, " - "faces, uvs, and colors. E.g. combine a Pixal3D object with a MoGe background " - "into one GLB." + "faces, uvs, and colors." ), inputs=[ IO.Autogrow.Input("meshes", template=autogrow_template), @@ -2306,6 +2936,9 @@ class PostProcessMeshExtension(ComfyExtension): BakeTextureFromVoxel, MeshTextureToImage, ApplyTextureToMesh, + BakeNormalMapFromMesh, + BakeAmbientOcclusion, + SetMeshMaterial, MergeMeshes, ] diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index ef471eeee..d7afd1713 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -20,11 +20,11 @@ from comfy_api.latest import ComfyExtension, IO, Types def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False, - normals=None, metallic_roughness=None): - # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, - # stashing per-item lengths as runtime attrs so consumers can recover the real slice. - # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. - # texture is (B, H, W, 3) — passed through unchanged + normals=None, metallic_roughness=None, tangents=None, normal_map=None, + occlusion_in_mr=False, material=None, emissive=None): + # Pack per-item tensors into padded batches, stashing per-item lengths as runtime attrs. + # colors/uvs/normals/tangents are 1:1 with vertices (padded to max_vertices); texture/ + # metallic_roughness/normal_map are (B,H,W,*) image stacks passed through unchanged. batch_size = len(vertices) max_vertices = max(v.shape[0] for v in vertices) max_faces = max(f.shape[0] for f in faces) @@ -65,11 +65,31 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non ) packed_normals[i, :nrm.shape[0]] = nrm - return Types.MESH(packed_vertices, packed_faces, - uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - metallic_roughness=metallic_roughness, - vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit, - normals=packed_normals) + packed_tangents = None + if tangents is not None: + packed_tangents = tangents[0].new_zeros((batch_size, max_vertices, tangents[0].shape[1])) + for i, tn in enumerate(tangents): + assert tn.shape[0] == vertices[i].shape[0], ( + f"tangents[{i}] has {tn.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)" + ) + packed_tangents[i, :tn.shape[0]] = tn + + out = Types.MESH(packed_vertices, packed_faces, + uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, + metallic_roughness=metallic_roughness, + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit, + normals=packed_normals) + if packed_tangents is not None: + out.tangents = packed_tangents + if normal_map is not None: + out.normal_map = normal_map + if occlusion_in_mr: + out.occlusion_in_mr = True + if material is not None: + out.material = material + if emissive is not None: + out.emissive = emissive + return out def get_mesh_batch_item(mesh, index): @@ -180,7 +200,8 @@ def _compute_vertex_normals(vertices_np, faces_np, crease_angle=None): def save_glb(vertices, faces, filepath, metadata=None, uvs=None, vertex_colors=None, texture_image=None, metallic_roughness_image=None, unlit=False, - normals=None): + normals=None, normal_map_image=None, tangents=None, occlusion_in_mr=False, + material=None, emissive_image=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -197,6 +218,16 @@ def save_glb(vertices, faces, filepath, metadata=None, normals: torch.Tensor of shape (N, 3) - Optional per-vertex normals, written as the glTF NORMAL attribute. When omitted, NO normals are written and viewers fall back to flat (per-face) shading — use the MeshSmoothNormals node to generate them. + normal_map_image: PIL.Image - Optional tangent-space normal map (glTF/OpenGL +Y), + written as the material normalTexture. Needs TEXCOORD_0. + tangents: torch.Tensor of shape (N, 4) - Optional per-vertex tangents (xyz + handedness w), + written as the glTF TANGENT attribute. Without it viewers derive tangents in-shader. + occlusion_in_mr: bool - When True, R of metallic_roughness_image holds AO (ORM packing) and + occlusionTexture is pointed at that same image. + material: dict - Optional scalar overrides from SetMeshMaterial (base_color_factor, + metallic/roughness_factor with <0 = auto, emissive_factor/strength, normal_scale, + occlusion_strength, double_sided). + emissive_image: PIL.Image - Optional emissive (glow) texture, written as emissiveTexture. """ # Convert tensors to numpy arrays @@ -231,6 +262,11 @@ def save_glb(vertices, faces, filepath, metadata=None, raise ValueError( f"save_glb: normals has {normals_np.shape[0]} entries but vertex count is {n_verts}" ) + tangents_np = tangents.cpu().numpy().astype(np.float32) if tangents is not None else None + if tangents_np is not None and tangents_np.shape != (n_verts, 4): + raise ValueError( + f"save_glb: tangents must be (N, 4) with N={n_verts}, got {tuple(tangents_np.shape)}" + ) faces_np = faces_signed.astype(np.uint32) texture_png_bytes = None if texture_image is not None: @@ -242,46 +278,60 @@ def save_glb(vertices, faces, filepath, metadata=None, buf = BytesIO() metallic_roughness_image.save(buf, format="PNG") mr_png_bytes = buf.getvalue() + nm_png_bytes = None + if normal_map_image is not None: + buf = BytesIO() + normal_map_image.save(buf, format="PNG") + nm_png_bytes = buf.getvalue() + em_png_bytes = None + if emissive_image is not None: + buf = BytesIO() + emissive_image.save(buf, format="PNG") + em_png_bytes = buf.getvalue() vertices_buffer = vertices_np.tobytes() indices_buffer = faces_np.tobytes() uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" colors_buffer = colors_np.tobytes() if colors_np is not None else b"" normals_buffer = normals_np.tobytes() if normals_np is not None else b"" + tangents_buffer = tangents_np.tobytes() if tangents_np is not None else b"" texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" mr_buffer = mr_png_bytes if mr_png_bytes is not None else b"" + nm_buffer = nm_png_bytes if nm_png_bytes is not None else b"" + em_buffer = em_png_bytes if em_png_bytes is not None else b"" def pad_to_4_bytes(buffer): padding_length = (4 - (len(buffer) % 4)) % 4 return buffer + b'\x00' * padding_length - vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) - indices_buffer_padded = pad_to_4_bytes(indices_buffer) - uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) - colors_buffer_padded = pad_to_4_bytes(colors_buffer) - normals_buffer_padded = pad_to_4_bytes(normals_buffer) - texture_buffer_padded = pad_to_4_bytes(texture_buffer) - mr_buffer_padded = pad_to_4_bytes(mr_buffer) - - buffer_data = b"".join([ - vertices_buffer_padded, - indices_buffer_padded, - uvs_buffer_padded, - colors_buffer_padded, - normals_buffer_padded, - texture_buffer_padded, - mr_buffer_padded, - ]) + # Blob order in one place; offsets accumulated in a pass so adding a buffer is one entry. + _blobs = [ + ("vertices", vertices_buffer), ("indices", indices_buffer), ("uvs", uvs_buffer), + ("colors", colors_buffer), ("normals", normals_buffer), ("tangents", tangents_buffer), + ("texture", texture_buffer), ("mr", mr_buffer), ("nm", nm_buffer), ("em", em_buffer), + ] + byte_offset = {} + acc = 0 + parts = [] + for name, b in _blobs: + padded = pad_to_4_bytes(b) + byte_offset[name] = acc + acc += len(padded) + parts.append(padded) + buffer_data = b"".join(parts) vertices_byte_length = len(vertices_buffer) - vertices_byte_offset = 0 indices_byte_length = len(indices_buffer) - indices_byte_offset = len(vertices_buffer_padded) - uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) - colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) - normals_byte_offset = colors_byte_offset + len(colors_buffer_padded) - texture_byte_offset = normals_byte_offset + len(normals_buffer_padded) - mr_byte_offset = texture_byte_offset + len(texture_buffer_padded) + vertices_byte_offset = byte_offset["vertices"] + indices_byte_offset = byte_offset["indices"] + uvs_byte_offset = byte_offset["uvs"] + colors_byte_offset = byte_offset["colors"] + normals_byte_offset = byte_offset["normals"] + tangents_byte_offset = byte_offset["tangents"] + texture_byte_offset = byte_offset["texture"] + mr_byte_offset = byte_offset["mr"] + nm_byte_offset = byte_offset["nm"] + em_byte_offset = byte_offset["em"] buffer_views = [ { @@ -368,6 +418,23 @@ def save_glb(vertices, faces, filepath, metadata=None, }) primitive_attributes["NORMAL"] = accessor_idx + if tangents_np is not None and len(tangents_np) > 0: + buffer_views.append({ + "buffer": 0, + "byteOffset": tangents_byte_offset, + "byteLength": len(tangents_buffer), + "target": 34962 + }) + accessor_idx = len(accessors) + accessors.append({ + "bufferView": len(buffer_views) - 1, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(tangents_np), + "type": "VEC4", # xyz tangent + w handedness (glTF TANGENT) + }) + primitive_attributes["TANGENT"] = accessor_idx + primitive = { "attributes": primitive_attributes, "indices": 1, @@ -379,9 +446,24 @@ def save_glb(vertices, faces, filepath, metadata=None, samplers = [] materials = [] extensions_used = [] + + def add_image_texture(png_byte_offset, png_byte_length): + """Append an embedded PNG image + a texture referencing it; return the texture index.""" + buffer_views.append({"buffer": 0, "byteOffset": png_byte_offset, "byteLength": png_byte_length}) + images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) + if not samplers: + samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) + textures.append({"source": len(images) - 1, "sampler": 0}) + return len(textures) - 1 + + has_uv = "TEXCOORD_0" in primitive_attributes if unlit and texture_png_bytes is None: # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. + if nm_png_bytes is not None or em_png_bytes is not None or occlusion_in_mr or material is not None: + logging.warning( + "save_glb: unlit material ignores normal/occlusion/emissive maps and SetMeshMaterial " + "overrides — those are PBR-lit features. Disable unlit to export them.") materials.append({ "pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0}, "extensions": {"KHR_materials_unlit": {}}, @@ -395,37 +477,57 @@ def save_glb(vertices, faces, filepath, metadata=None, "roughnessFactor": 0.5, "baseColorFactor": [0.22, 0.22, 0.22, 1.0], } - if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: - buffer_views.append({ - "buffer": 0, - "byteOffset": texture_byte_offset, - "byteLength": len(texture_buffer), - }) - images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) - samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) - textures.append({"source": len(images) - 1, "sampler": 0}) - pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0} + if texture_png_bytes is not None and has_uv: + pbr["baseColorTexture"] = {"index": add_image_texture(texture_byte_offset, len(texture_buffer)), "texCoord": 0} - if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: - buffer_views.append({ - "buffer": 0, - "byteOffset": mr_byte_offset, - "byteLength": len(mr_buffer), - }) - images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) - if not samplers: - samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) - textures.append({"source": len(images) - 1, "sampler": 0}) - pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0} + if mr_png_bytes is not None and has_uv: + mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer)) + pbr["metallicRoughnessTexture"] = {"index": mr_texture_index, "texCoord": 0} # When a metallicRoughness texture is present, the factors scale it; use 1.0 # so the texture values pass through unchanged (glTF convention). pbr["metallicFactor"] = 1.0 pbr["roughnessFactor"] = 1.0 - materials.append({ + mat = material if isinstance(material, dict) else {} + # Scalar overrides from SetMeshMaterial (factor < 0 means "leave auto"). + if mat.get("base_color_factor") is not None: + pbr["baseColorFactor"] = [float(x) for x in mat["base_color_factor"]] + if mat.get("metallic_factor", -1.0) >= 0.0: + pbr["metallicFactor"] = float(mat["metallic_factor"]) + if mat.get("roughness_factor", -1.0) >= 0.0: + pbr["roughnessFactor"] = float(mat["roughness_factor"]) + + material = { "pbrMetallicRoughness": pbr, - "doubleSided": True, - }) + "doubleSided": bool(mat.get("double_sided", True)), + } + if occlusion_in_mr and mr_png_bytes is not None and has_uv: + # ORM packing: occlusionTexture reuses the MR image (glTF reads its R channel). + material["occlusionTexture"] = {"index": mr_texture_index, "texCoord": 0, + "strength": float(mat.get("occlusion_strength", 1.0))} + if nm_png_bytes is not None and has_uv: + material["normalTexture"] = {"index": add_image_texture(nm_byte_offset, len(nm_buffer)), + "texCoord": 0, "scale": float(mat.get("normal_scale", 1.0))} + + emissive_factor = [float(x) for x in mat.get("emissive_factor", [0.0, 0.0, 0.0])] + emissive_strength = float(mat.get("emissive_strength", 1.0)) + has_em_tex = em_png_bytes is not None and has_uv + if any(c > 0.0 for c in emissive_factor) or has_em_tex: + # glTF multiplies emissiveFactor × texture, so a texture with no color would go black; + # default the factor to white in that case. + if has_em_tex and not any(c > 0.0 for c in emissive_factor): + emissive_factor = [1.0, 1.0, 1.0] + material["emissiveFactor"] = [min(1.0, c) for c in emissive_factor] + if has_em_tex: + material["emissiveTexture"] = {"index": add_image_texture(em_byte_offset, len(em_buffer)), + "texCoord": 0} + if emissive_strength != 1.0: + material.setdefault("extensions", {})["KHR_materials_emissive_strength"] = { + "emissiveStrength": emissive_strength} + if "KHR_materials_emissive_strength" not in extensions_used: + extensions_used.append("KHR_materials_emissive_strength") + + materials.append(material) primitive["material"] = 0 gltf = { @@ -556,6 +658,22 @@ class SaveGLB(IO.ComfyNode): assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, ( f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}" ) + nm_b = getattr(mesh, "normal_map", None) + nm_np = None + if nm_b is not None: + nm_np = (nm_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8) + assert nm_np.ndim == 4 and nm_np.shape[-1] == 3, ( + f"normal_map must be (B, H, W, 3), got shape {tuple(nm_np.shape)}" + ) + em_b = getattr(mesh, "emissive", None) + em_np = None + if em_b is not None: + em_np = (em_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8) + assert em_np.ndim == 4 and em_np.shape[-1] == 3, ( + f"emissive must be (B, H, W, 3), got shape {tuple(em_np.shape)}" + ) + tangents_b = getattr(mesh, "tangents", None) + material = getattr(mesh, "material", None) for i in range(mesh.vertices.shape[0]): vertices_i, faces_i, v_colors, uvs_i, normals_i = get_mesh_batch_item(mesh, i) if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0: @@ -563,6 +681,9 @@ class SaveGLB(IO.ComfyNode): continue tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None + nm_img = Image.fromarray(nm_np[i], mode="RGB") if nm_np is not None else None + em_img = Image.fromarray(em_np[i], mode="RGB") if em_np is not None else None + tangents_i = tangents_b[i, :vertices_i.shape[0]] if tangents_b is not None else None f = f"{filename}_{counter:05}_.glb" save_glb( vertices_i, faces_i, @@ -574,6 +695,11 @@ class SaveGLB(IO.ComfyNode): metallic_roughness_image=mr_img, unlit=getattr(mesh, "unlit", False), normals=normals_i, + normal_map_image=nm_img, + tangents=tangents_i, + occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False), + material=material, + emissive_image=em_img, ) results.append({ "filename": f, @@ -723,9 +849,11 @@ class MeshSmoothNormals(IO.ComfyNode): return IO.NodeOutput(out) # Crease split changes per-item vertex counts -> rebuild as a variable-size batch. + tangents_b = getattr(mesh, "tangents", None) v_list, f_list, n_list = [], [], [] c_list = [] if mesh.vertex_colors is not None else None u_list = [] if mesh.uvs is not None else None + t_list = [] if tangents_b is not None else None for i in range(batch_size): v_i, f_i, c_i, u_i, _ = get_mesh_batch_item(mesh, i) if v_i.shape[0] == 0 or f_i.shape[0] == 0: @@ -742,12 +870,19 @@ class MeshSmoothNormals(IO.ComfyNode): c_list.append(c_i[remap_t.to(c_i.device)]) if u_list is not None: u_list.append(u_i[remap_t.to(u_i.device)]) + if t_list is not None: + # Remap (not recompute) so TANGENT keeps the baked basis; split verts copy theirs. + t_i = tangents_b[i, :v_i.shape[0]] + t_list.append(t_i[remap_t.to(t_i.device)]) if not v_list: return IO.NodeOutput(mesh) out = pack_variable_mesh_batch( v_list, f_list, colors=c_list, uvs=u_list, texture=mesh.texture, unlit=getattr(mesh, "unlit", False), - normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None)) + normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None), + tangents=t_list, normal_map=getattr(mesh, "normal_map", None), + occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False), + material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None)) return IO.NodeOutput(out) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bb948831f..ad775fc80 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch +from server import PromptServer import comfy.latent_formats import comfy.model_management import comfy.utils @@ -414,38 +415,44 @@ class Trellis2UpsampleStage(IO.ComfyNode): "y_up" if proj_pack is not None else "z_up")} return IO.NodeOutput(positive_out, negative_out, out_latent) -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) +def _dinov3_encode(model, image_bchw, image_size, want_patches=False): + """Run DINOv3 once at the requested resolution. -def run_conditioning(model, cropped_img_tensor, include_1024=True): + image_bchw: [B, 3, H, W] float in [0, 1] (any source resolution; resized here). + Returns the full sequence tensor (Trellis2 path) or a dict with the global + tokens split out + a 2D patch grid (Pixal3D path) when `want_patches=True`. + """ model_internal = model.model + device = comfy.model_management.get_torch_device() + img_t = comfy.utils.common_upscale(image_bchw, image_size, image_size, "lanczos", "disabled").to(device) + mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1) + std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1) + img_t = (img_t - mean) / std + model_internal.image_size = image_size + tokens = model_internal(img_t, skip_norm_elementwise=True)[0] + if not want_patches: + return tokens + h_p = w_p = image_size // 16 + n_reg = tokens.shape[1] - 1 - h_p * w_p + return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)} + + +def run_conditioning(model, cropped_pil_img, include_1024=True): device = comfy.model_management.intermediate_device() - torch_device = comfy.model_management.get_torch_device() - def prepare_tensor(pil_img, size): - resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) - img_np = np.array(resized_pil).astype(np.float32) / 255.0 - img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) - return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - - cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0 + image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous() + cond_512 = _dinov3_encode(model, image_bchw, 512) conditioning = { - 'cond_512': cond_512.to(device), - 'neg_cond': torch.zeros_like(cond_512).to(device), + "cond_512": cond_512.to(device), + "neg_cond": torch.zeros_like(cond_512).to(device), } - if cond_1024 is not None: - conditioning['cond_1024'] = cond_1024.to(device) - + if include_1024: + cond_1024 = _dinov3_encode(model, image_bchw, 1024) + conditioning["cond_1024"] = cond_1024.to(device) return conditioning + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -780,21 +787,6 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16): return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous() -def _run_dinov3_with_patches(model, composite, image_size): - model_internal = model.model - torch_device = comfy.model_management.get_torch_device() - img_t = comfy.utils.common_upscale(composite, image_size, image_size, "lanczos", "disabled") - img_t = img_t.to(torch_device) - img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - model_internal.image_size = image_size - tokens = model_internal(img_t, skip_norm_elementwise=True)[0] - patches = _dinov3_patches_to_2d(tokens, image_size) - h_p = w_p = image_size // 16 - n_reg = tokens.shape[1] - 1 - h_p * w_p - global_tokens = tokens[:, :1 + n_reg] - return {"tokens": global_tokens, "patches_2d": patches} - - def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float() mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float() @@ -802,6 +794,12 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0 mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0 + # Detect & correct an inverted mask + m2d = mask[0, 0] + border = torch.cat([m2d[0, :], m2d[-1, :], m2d[:, 0], m2d[:, -1]]) + if float(border.mean()) > 0.5: + mask = 1.0 - mask + H, W = img.shape[-2:] if max(H, W) > max_image_size: scale = max_image_size / max(H, W) @@ -923,8 +921,8 @@ class Pixal3DConditioning(IO.ComfyNode): scene_size_list.append(scene_size) composite_list.append(composite) - cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512) - cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024) + cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True) + cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True) cond_512_list.append(cond_512["tokens"].to(device)) cond_1024_list.append(cond_1024["tokens"].to(device)) patches_512_list.append(cond_512["patches_2d"].to(device)) @@ -1104,8 +1102,7 @@ class Pixal3DAlignObject(IO.ComfyNode): moge_per_vertex = moge_points[batch_index, sy, sx] # MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF # Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same - # frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the - # `verts * [1, -1, -1]` step in MoGePointMapToMesh. + # frame as vertices_one (Pixal3D model frame = glTF Y-up). moge_per_vertex = moge_per_vertex * torch.tensor( [1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device ) @@ -1188,6 +1185,7 @@ class GetMeshInfo(IO.ComfyNode): IO.Mesh.Output(display_name="mesh"), IO.String.Output(display_name="info"), ], + hidden=[IO.Hidden.unique_id], ) @staticmethod @@ -1212,10 +1210,10 @@ class GetMeshInfo(IO.ComfyNode): f_counts = [int(mesh.faces.shape[1])] * B attrs = [] - for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"): + for name in ("uvs", "vertex_colors", "normals", "tangents", "texture", "metallic_roughness", "normal_map"): t = getattr(mesh, name, None) if t is not None: - if name in ("texture", "metallic_roughness"): + if name in ("texture", "metallic_roughness", "normal_map"): attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W else: attrs.append(name) @@ -1234,6 +1232,9 @@ class GetMeshInfo(IO.ComfyNode): info = "\n".join(lines) logging.info("[GetMeshInfo]\n%s", info) + + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(info, cls.hidden.unique_id) return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))