More cleanup

This commit is contained in:
kijai 2026-06-27 00:13:13 +03:00
parent 1f7acd9354
commit 41f5f4b2c0
2 changed files with 380 additions and 416 deletions

View File

@ -180,12 +180,13 @@ class PaintMesh(IO.ComfyNode):
# =============================================================================
# Texture baking from sparse voxel volume.
#
# Pipeline: xatlas UV unwrap → OpenGL UV-space rasterize to position map →
# nearest-voxel color sample per texel → cv2.inpaint to fill UV seams →
# attach texture + UVs to the Mesh for SaveGLB to serialize.
# Pipeline: take the mesh's existing UVs → OpenGL UV-space rasterize to position
# map → nearest-voxel color sample per texel → GPU Jump-Flood fill UV seams →
# attach texture + UVs to the Mesh for SaveGLB to serialize. Unwrapping is done
# upstream (Trellis2OfficialUnwrap / TorchXatlasUVWrap); this path never unwraps.
#
# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles
# GLFW / EGL / OSMesa backend selection); xatlas for UV parameterization.
# GLFW / EGL / OSMesa backend selection).
# =============================================================================
_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache"
@ -407,6 +408,54 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):
return vals, ok
def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolution):
    """GPU port of `_trilinear_sample_sparse`: trilinear interpolation normalized
    over the occupied corners only.

    The per-texel 8-corner accumulation runs on the torch device via a sorted-key
    `searchsorted` lookup instead of NumPy float64. This is the bake hot path
    (millions of covered texels × 8 corners); per the original notes the port is
    ~identical numerically and 10-50× faster than the CPU version.

    Args:
        positions: [K,3] query points; the grid covers [-0.5, 0.5) per axis
            (origin -0.5, voxel size 1/R).
        voxel_coords_np: [M,3] integer coordinates of the occupied voxels.
        color_np: [M,C] per-voxel attribute values.
        resolution: grid resolution R (voxels per axis).

    Returns:
        (vals, ok) as NumPy arrays, matching the NumPy signature: `vals` is
        [K,C] float32, `ok` is [K] bool and is False where none of the 8
        surrounding voxels is occupied (those rows of `vals` are zero).
    """
    dev = comfy.model_management.get_torch_device()
    R = int(resolution)
    origin = -0.5
    voxel_size = 1.0 / R
    P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
    VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
    col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float()
    K, C = P.shape[0], col.shape[1]
    M = VC.shape[0]
    if M == 0:
        # No occupied voxels: nothing can match. Without this guard the
        # clamp(max=M - 1) below produces index -1 into an empty key table.
        return np.zeros((K, C), dtype=np.float32), np.zeros((K,), dtype=bool)

    # Same cell-CENTER convention as the NumPy path (see its docstring): integer
    # voxel coord c sits at (c+0.5)*voxel_size + origin, so subtract 0.5 to bracket.
    gc = (P - origin) / voxel_size - 0.5
    base = torch.floor(gc).long()
    frac = gc - base.float()

    # Linearized voxel keys, sorted once so each corner lookup is a binary search.
    key = (VC[:, 0] * R + VC[:, 1]) * R + VC[:, 2]
    skey, order = key.sort()

    acc = torch.zeros((K, C), device=dev)
    wsum = torch.zeros((K, 1), device=dev)
    for dx in (0, 1):
        wx = frac[:, 0] if dx else 1.0 - frac[:, 0]
        cx = base[:, 0] + dx
        for dy in (0, 1):
            wy = frac[:, 1] if dy else 1.0 - frac[:, 1]
            cy = base[:, 1] + dy
            for dz in (0, 1):
                wz = frac[:, 2] if dz else 1.0 - frac[:, 2]
                cz = base[:, 2] + dz
                # Out-of-grid corners could alias a valid key once linearized,
                # so they are rejected explicitly.
                inb = (cx >= 0) & (cx < R) & (cy >= 0) & (cy < R) & (cz >= 0) & (cz < R)
                qk = (cx * R + cy) * R + cz
                ins = torch.searchsorted(skey, qk).clamp(max=M - 1)
                matched = inb & (skey[ins] == qk)
                idx = order[ins]  # garbage where !matched
                w = torch.where(matched, wx * wy * wz, torch.zeros_like(wx))[:, None]
                acc += w * col[idx]  # w=0 cancels garbage rows
                wsum += w

    # Normalize by the weight that actually landed on occupied corners.
    ok = wsum[:, 0] > 1e-8
    vals = torch.zeros((K, C), device=dev)
    vals[ok] = acc[ok] / wsum[ok].clamp_min(1e-8)
    return vals.cpu().numpy(), ok.cpu().numpy()
def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
"""GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a
regular integer grid (coord c world c/R-0.5), so the nearest voxel to a
@ -415,7 +464,7 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found`
is False for the rare query whose nearest occupied voxel is >1 cell away (the
caller falls back to a cKDTree on just those)."""
dev = "cuda" if torch.cuda.is_available() else "cpu"
dev = comfy.model_management.get_torch_device()
R = int(resolution)
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
@ -451,6 +500,27 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
fnd |= match
return bi, fnd
def _brute_nearest(idx):
    """Exact nearest occupied voxel for a (small) subset of queries.

    Runs a chunked brute-force search over all M voxels on the torch device.
    Used only for the handful of stragglers the grid scan misses (>4 cells
    from any voxel) — it replaces a cKDTree build over all M voxels, which
    costs seconds even for a few query points.

    `idx` indexes into the enclosing function's query tensor `P`; the result
    is a [N] long tensor of voxel indices.
    """
    queries = P[idx]                                   # [N,3] world
    n_query = queries.shape[0]
    centres = (VC.float() + 0.5) / R - 0.5             # [M,3] voxel centres
    nearest_d = torch.full((n_query,), 1e30, device=dev)
    nearest_j = torch.zeros(n_query, dtype=torch.long, device=dev)
    # Bound the N×chunk distance matrix to ~64M elements regardless of N.
    step = max(1, (1 << 26) // max(1, n_query))
    for start in range(0, M, step):
        block = centres[start:start + step]            # [B,3]
        sq_dist = (queries[:, None, :] - block[None, :, :]).pow(2).sum(-1)  # [N,B]
        block_min, block_arg = sq_dist.min(1)
        closer = block_min < nearest_d                 # strict: earlier chunk wins ties
        nearest_d = torch.where(closer, block_min, nearest_d)
        nearest_j = torch.where(closer, block_arg + start, nearest_j)
    return nearest_j
all_idx = torch.arange(K, device=dev)
best_i = torch.zeros(K, dtype=torch.long, device=dev)
found = torch.zeros(K, dtype=torch.bool, device=dev)
@ -458,27 +528,30 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
bi1, fnd1 = _search(all_idx, 1)
best_i[all_idx] = bi1
found[all_idx] = fnd1
# Pass 2: wider radius on ONLY the few misses (avoids ever building a cKDTree
# over millions of voxels just for a handful of >1-cell-away points).
# Pass 2: wider radius (9³) on ONLY the radius-1 misses.
miss = torch.nonzero(~found, as_tuple=True)[0]
if miss.numel() > 0:
bi2, fnd2 = _search(miss, 4)
best_i[miss] = bi2
found[miss] = fnd2
# Pass 3: exact GPU brute force for the few stragglers still unfound (>4 cells
# out). Always resolves them, so `found` is all-True on return — no cKDTree.
miss2 = torch.nonzero(~found, as_tuple=True)[0]
if miss2.numel() > 0:
best_i[miss2] = _brute_nearest(miss2)
found[miss2] = True
vals = col[best_i]
return vals.cpu().numpy(), found.cpu().numpy()
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution,
mode="trilinear"):
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution):
"""For every masked texel, sample the voxel field and return ALL its attribute
channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature
width (3 for plain color, 6 for full PBR).
mode="trilinear" normalized trilinear over occupied voxels (the default; matches
the official o_voxel.to_glb path), with nearest fallback for texels whose 8
surrounding voxels are all empty. This is the only mode the nodes expose now.
mode="nearest" nearest-voxel; kept as an internal/dev lever (blocky)."""
Normalized trilinear over occupied voxels (matches the official o_voxel.to_glb
path), with nearest fallback for texels whose 8 surrounding voxels are all
empty."""
H, W, _ = position_map.shape
color_np = voxel_colors.detach().cpu().numpy().astype(np.float32)
C = color_np.shape[-1]
@ -496,22 +569,26 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
valid_positions = position_map[mask]
def _nearest(query):
# GPU grid lookup; cKDTree only for the rare >1-cell miss.
# Fully on-GPU nearest-occupied-voxel: grid scan + brute-force tail. Always
# resolves every query, so no cKDTree (its build over all voxels cost ~3s).
vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution)
if not found.all():
# Defensive: only reachable on a non-CUDA device where the GPU path is
# unavailable; fall back to a one-off cKDTree.
tree = scipy.spatial.cKDTree(voxel_pos)
_, nearest_idx = tree.query(query[~found], k=1, workers=-1)
vals[~found] = color_np[nearest_idx]
return vals
if mode == "trilinear":
try:
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
except Exception as e:
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU")
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
if not ok.all():
# Texels with no occupied neighbour fall back to nearest.
vals[~ok] = _nearest(valid_positions[~ok])
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
else:
out[mask] = np.clip(_nearest(valid_positions), 0.0, 1.0)
if not ok.all():
# Texels with no occupied neighbour fall back to nearest.
vals[~ok] = _nearest(valid_positions[~ok])
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
return out
@ -729,7 +806,7 @@ def _back_project_positions(position_map, mask, ref_v, ref_f):
return position_map
import time as _time
dev = "cuda" if torch.cuda.is_available() else "cpu"
dev = comfy.model_management.get_torch_device()
rv = ref_v.detach().to(dev).float()
rf = ref_f.detach().to(dev).long()
tri = rv[rf]
@ -755,7 +832,7 @@ def _jfa_fill_gpu(img01, mask):
(True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map."""
if not mask.any():
return img01
dev = "cuda"
dev = comfy.model_management.get_torch_device()
it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float()
mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev)
H, W = mm.shape
@ -786,41 +863,58 @@ def _jfa_fill_gpu(img01, mask):
def _seam_fill(img01, mask, inpaint_radius):
"""Fill the UV-gutter texels around covered charts so seam sampling doesn't
pull in black. GPU Jump Flooding (nearest fill) when CUDA is available, else
cv2 Telea inpaint. `inpaint_radius<=0` disables; the radius only affects the
cv2 fallback (JFA fills every uncovered texel by nearest)."""
pull in black, via GPU Jump Flooding (nearest fill). `inpaint_radius<=0`
disables; otherwise the radius is ignored — JFA fills every uncovered texel
by nearest regardless."""
if inpaint_radius <= 0:
return img01
if torch.cuda.is_available():
return _jfa_fill_gpu(img01, mask)
import cv2
u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8)
u8 = cv2.inpaint(u8, ((~mask).astype(np.uint8)) * 255, int(inpaint_radius), cv2.INPAINT_TELEA)
if u8.ndim == 2:
u8 = u8[..., None]
return u8.astype(np.float32) / 255.0
return _jfa_fill_gpu(img01, mask)
def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
    """Uniformly fit a UV layout's bbox into [0,1] when it spills outside the unit
    square (preserves chart aspect ratios; handles packers that overflow slightly).

    No-op when the UVs are already in [0,1] — the normal case for official/xatlas
    unwraps. NOT a UDIM de-tiler; warns if the span looks tiled.

    Deterministic from the input UVs alone, so the texture bake and
    ApplyTextureToMesh both call it to agree on the exact UVs the texture was baked
    against (the bake no longer emits the mesh, so apply must re-derive them).

    Args:
        uv_np: [N,2] array of UV coordinates.
        normalize: when False the UVs are only cast to float32.
        log_prefix: optional tag for log lines; nothing is logged when falsy.

    Returns:
        float32 [N,2].
    """
    uv_np = uv_np.astype(np.float32)
    if not normalize or uv_np.size == 0:
        # Nothing to fit. The empty check also keeps min()/max() below from
        # raising on a zero-vertex mesh.
        return uv_np

    uv_min = uv_np.min(axis=0)
    uv_max = uv_np.max(axis=0)
    # Small tolerances so float noise right at the border doesn't trigger a refit.
    out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
    if not out_of_unit:
        return uv_np

    # Largest per-axis extent of the bbox; UVs are [N,2], so this is also the
    # "span" reported in the tiling warning.
    extent = float((uv_max - uv_min).max())
    if extent > 1.5 and log_prefix:
        logging.warning(
            f"{log_prefix} UV span {extent:.2f} looks like a tiled/UDIM layout; "
            f"uniform-fitting it into [0,1] will overlap tiles. Re-unwrap upstream instead.")
    if extent > 0:
        # One scale for both axes so chart aspect ratios survive.
        uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
        if log_prefix:
            logging.info(f"{log_prefix} normalized UVs into [0,1] (uniform scale 1/{extent:.4f})")
    return uv_np
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
resolution, texture_size, inpaint_radius=3,
fast_unwrap=True, existing_uvs=None,
normalize_uvs=True, sample_mode="trilinear",
reference=None, pbar=None):
resolution, texture_size, uvs, inpaint_radius=3,
normalize_uvs=True, reference=None, pbar=None):
"""Bake a baseColor (+ optional metallicRoughness) texture for
`vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each
texel from the provided sparse colored voxel volume.
If `existing_uvs` (N, 2) is given, it is used directly and xatlas is
skipped bakes onto the mesh's current UV layout without re-unwrapping.
Otherwise xatlas computes a fresh atlas (verts/faces may grow at seams).
`uvs` (N, 2) is the mesh's existing UV layout — baked onto directly (this
node never unwraps; connect a UV unwrap node upstream). It must be 1:1 with
`vertices`.
Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr).
`fast_unwrap=True` configures xatlas with permissive chart options so it
finishes in a reasonable time on large meshes at the cost of less even
UV distribution. Set False to use xatlas defaults (slow on >100k faces).
Progress: drives a local tqdm over its 5 stages (unwrap rasterize
Progress: drives a local tqdm over its 5 stages (uvs → rasterize →
back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is
passed, ticks it once per stage too — so callers should size it as 5 per
bake."""
@ -845,111 +939,25 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
v_np = vertices.detach().cpu().numpy().astype(np.float32)
f_np = faces.detach().cpu().numpy().astype(np.uint32)
fcount = int(f_np.shape[0])
t0 = time.perf_counter()
if existing_uvs is not None:
# Bake onto the mesh's current UVs — no xatlas, no seam-splitting.
uv_np = existing_uvs.detach().cpu().numpy().astype(np.float32)
if uv_np.shape[0] != v_np.shape[0]:
raise ValueError(
f"BakeTextureFromVoxel: existing UVs ({uv_np.shape[0]}) must be 1:1 "
f"with vertices ({v_np.shape[0]})."
)
uv_min = uv_np.min(axis=0)
uv_max = uv_np.max(axis=0)
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
f"{fcount} faces (xatlas skipped)")
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
if normalize_uvs and out_of_unit:
# Uniform fit of the UV bbox into [0,1] (preserves chart aspect ratios).
# Handles packers that overflow the unit square slightly. NOT a UDIM
# de-tiler — a true multi-tile layout would get squashed; warn if the
# span is large enough to look like tiling.
extent = float((uv_max - uv_min).max())
span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1]))
if span > 1.5:
logging.warning(
f"[BakeTextureFromVoxel] UV span {span:.2f} looks like a tiled/UDIM "
f"layout; uniform-fitting it into [0,1] will overlap tiles. "
f"Re-unwrap instead (use_existing_uvs=False)."
)
if extent > 0:
uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
logging.info(f"[BakeTextureFromVoxel] normalized UVs into [0,1] "
f"(uniform scale 1/{extent:.4f})")
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
else:
import xatlas
if fcount > 300_000:
logging.warning(
f"[BakeTextureFromVoxel] mesh has {fcount} faces — xatlas chart "
f"decomposition is CPU-bound and may take many minutes. Consider "
f"decimating to under ~200k faces before baking."
)
logging.info(f"[BakeTextureFromVoxel] xatlas unwrap: {v_np.shape[0]} verts, {fcount} faces")
if fast_unwrap and hasattr(xatlas, "Atlas"):
atlas = xatlas.Atlas()
atlas.add_mesh(v_np, f_np)
logging.info(f"[BakeTextureFromVoxel] add_mesh: {time.perf_counter() - t0:.1f}s")
gen_kwargs = {}
applied = []
# ChartOptions: looser growth → larger / fewer charts → faster.
if hasattr(xatlas, "ChartOptions"):
co = xatlas.ChartOptions()
for attr, val in (
("max_iterations", 1),
("max_cost", 8.0),
("normal_deviation_weight", 1.0),
("roundness_weight", 0.0),
("straightness_weight", 0.0),
("normal_seam_weight", 1.0),
("texture_seam_weight", 0.0),
("use_input_mesh_uvs", False),
):
if hasattr(co, attr):
setattr(co, attr, val)
applied.append(f"chart.{attr}")
gen_kwargs["chart_options"] = co
# PackOptions.bruteForce defaults to True — tries many rotations per
# chart and is the single biggest contributor to pack time on small
# meshes. Off it loses ~5-15% packing efficiency but runs ~5× faster.
if hasattr(xatlas, "PackOptions"):
po = xatlas.PackOptions()
for attr, val in (
("bruteForce", False),
("brute_force", False), # snake_case alias on some builds
("create_image", False),
("createImage", False),
("padding", 2),
):
if hasattr(po, attr):
setattr(po, attr, val)
applied.append(f"pack.{attr}")
gen_kwargs["pack_options"] = po
logging.info(f"[BakeTextureFromVoxel] options applied: {applied}")
tgen = time.perf_counter()
try:
atlas.generate(**gen_kwargs)
except TypeError as e:
logging.warning(f"[BakeTextureFromVoxel] generate(**kwargs) rejected ({e}); retrying with defaults")
atlas.generate()
logging.info(f"[BakeTextureFromVoxel] generate: {time.perf_counter() - tgen:.1f}s")
tget = time.perf_counter()
vmapping, indices, uvs = atlas[0]
logging.info(f"[BakeTextureFromVoxel] retrieve: {time.perf_counter() - tget:.1f}s")
else:
vmapping, indices, uvs = xatlas.parametrize(v_np, f_np)
# Bake onto the mesh's current UVs — no unwrap, no seam-splitting.
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
if uv_np.shape[0] != v_np.shape[0]:
raise ValueError(
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
f"with vertices ({v_np.shape[0]})."
)
uv_min = uv_np.min(axis=0)
uv_max = uv_np.max(axis=0)
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
f"{fcount} faces")
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
logging.info(f"[BakeTextureFromVoxel] xatlas total {time.perf_counter() - t0:.1f}s "
f"({vmapping.shape[0]} verts after seams)")
new_verts = v_np[vmapping]
new_faces = indices.astype(np.uint32)
new_uvs = uvs.astype(np.float32)
_tick("unwrap")
_tick("uvs")
t1 = time.perf_counter()
position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size)
@ -968,7 +976,7 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
t2 = time.perf_counter()
attrs = _sample_voxel_attrs_per_texel(
position_map, mask, voxel_coords, voxel_colors, resolution, mode=sample_mode,
position_map, mask, voxel_coords, voxel_colors, resolution,
)
logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s "
f"({attrs.shape[-1]} channels)")
@ -1022,12 +1030,11 @@ def _per_vertex_normals(verts_np, faces_np):
def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution,
texture_size, views, blend_temperature=0.25,
inpaint_radius=3, fast_unwrap=True, existing_uvs=None,
normalize_uvs=True, sample_mode="trilinear"):
texture_size, views, uvs, blend_temperature=0.25,
inpaint_radius=3, normalize_uvs=True):
"""Bake a baseColor texture by projecting view photos onto the mesh.
Reuses bake_texture_from_voxel_fn for the xatlas unwrap + the nearest-voxel
Reuses bake_texture_from_voxel_fn for the UV-space bake + the nearest-voxel
fallback colour, then overlays photo colour on every covered+visible texel:
each texel's world position/normal is projected into each view, occlusion is
resolved with a texel z-buffer, and the views are blended weighted by how
@ -1045,8 +1052,8 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
# Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end).
out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn(
vertices, faces, voxel_coords, voxel_colors, resolution=resolution,
texture_size=texture_size, inpaint_radius=0, fast_unwrap=fast_unwrap,
existing_uvs=existing_uvs, normalize_uvs=normalize_uvs, sample_mode=sample_mode)
texture_size=texture_size, uvs=uvs, inpaint_radius=0,
normalize_uvs=normalize_uvs)
v_np = out_v.detach().cpu().numpy().astype(np.float32)
f_np = out_f.detach().cpu().numpy().astype(np.uint32)
@ -1078,6 +1085,16 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
return out_v, out_f, out_uvs, out_tex, voxel_mr
def _mr_channel(packed_mr, ch, ref):
    """Extract a single channel from a packed glTF metallic-roughness map.

    The packed layout keeps roughness in G (index 1) and metallic in B
    (index 2). The chosen channel is clamped to [0,1], moved to the CPU and
    replicated to a 3-channel grayscale IMAGE [H,W,3].

    When `packed_mr` is None (non-PBR voxel field) the result is an all-black
    float CPU tensor with the same shape as `ref`.
    """
    if packed_mr is not None:
        packed = packed_mr.float().clamp(0.0, 1.0).cpu()
        single = packed[..., ch:ch + 1]                # keep a trailing channel axis
        return single.expand(-1, -1, 3).contiguous()   # materialize the broadcast
    blank_like = ref.float().cpu()
    return torch.zeros_like(blank_like)
class BakeTextureFromVoxel(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -1089,12 +1106,12 @@ class BakeTextureFromVoxel(IO.ComfyNode):
"Bakes PBR textures onto the mesh's existing UV layout by rasterizing it "
"in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling "
"the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node "
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Produces a "
"baseColor texture, plus a metallicRoughness texture when the voxel field "
"carries the full PBR set (6 channels). Returns a Mesh with `uvs`, `texture`, "
"and `metallic_roughness` attached — SaveGLB serializes them as real "
"baseColorTexture / metallicRoughnessTexture maps. UVs that spill outside "
"[0,1] are uniformly fit back into the unit square."
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Outputs the "
"baked maps as IMAGEs: base_color, plus metallic and roughness as separate "
"grayscale maps (both black when the voxel field has no PBR set). "
"Preview/save/post-process them, then feed them to ApplyTextureToMesh (with "
"the SAME mesh) to attach them for SaveGLB. UVs that spill outside [0,1] are "
"uniformly fit into the unit square."
),
inputs=[
IO.Mesh.Input("mesh"),
@ -1108,7 +1125,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
"o_voxel.to_glb step that removes faceted/pixelized baking on coarse "
"meshes. Pure scipy+torch, no extra deps.")),
],
outputs=[IO.Mesh.Output("mesh")],
outputs=[
IO.Image.Output(display_name="base_color"),
IO.Image.Output(display_name="metallic"),
IO.Image.Output(display_name="roughness"),
],
)
@classmethod
@ -1132,7 +1153,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
batch_idx = coords[:, 0].long()
voxel_xyz = coords[:, 1:]
mesh_batch_size = int(mesh.vertices.shape[0])
out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], []
out_tex, out_mr = [], []
# 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items
# tick all 5 so the bar stays aligned.
pbar = comfy.utils.ProgressBar(mesh_batch_size * 5)
@ -1150,28 +1171,25 @@ class BakeTextureFromVoxel(IO.ComfyNode):
if reference_mesh is not None:
rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i)
ref_i = (rv_i, rf_i)
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v_i, f_i, item_coords, item_colors,
resolution=resolution, texture_size=texture_size,
inpaint_radius=inpaint_radius,
existing_uvs=ev_i, reference=ref_i, pbar=pbar,
uvs=ev_i, inpaint_radius=inpaint_radius,
reference=ref_i, pbar=pbar,
)
out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu)
out_tex.append(bt); out_mr.append(bmr)
if not out_verts:
return IO.NodeOutput(mesh)
# Local pack_variable_mesh_batch doesn't take uvs/texture; build the
# packed mesh ourselves so we can attach both. UVs are 1:1 with verts.
packed = pack_variable_mesh_batch(out_verts, out_faces)
max_v = packed.vertices.shape[1]
packed_uvs = out_uvs[0].new_zeros((len(out_uvs), max_v, 2))
for i, u in enumerate(out_uvs):
packed_uvs[i, :u.shape[0]] = u
packed.uvs = packed_uvs
packed.texture = torch.stack(out_tex, dim=0)
if all(mr is not None for mr in out_mr):
packed.metallic_roughness = torch.stack(out_mr, dim=0)
return IO.NodeOutput(packed)
if not out_tex:
# Every item skipped (degenerate) — emit one black map so the IMAGE
# outputs stay valid.
black = torch.zeros((1, texture_size, texture_size, 3))
return IO.NodeOutput(black, black, black)
# All maps are texture_size² — stack into [B,H,W,3] IMAGE batches. The
# packed MR (G=roughness, B=metallic) is split into separate grayscale
# maps; both black where the voxel field carried no PBR set.
base_img = torch.stack([t.float().clamp(0.0, 1.0).cpu() for t in out_tex], dim=0)
metallic_img = torch.stack([_mr_channel(m, 2, out_tex[0]) for m in out_mr], dim=0)
roughness_img = torch.stack([_mr_channel(m, 1, out_tex[0]) for m in out_mr], dim=0)
return IO.NodeOutput(base_img, metallic_img, roughness_img)
# Single-item path.
v0 = mesh.vertices.squeeze(0)
@ -1181,19 +1199,16 @@ class BakeTextureFromVoxel(IO.ComfyNode):
if reference_mesh is not None:
ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0))
pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn)
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
v0, f0, coords, colors,
resolution=resolution, texture_size=texture_size,
inpaint_radius=inpaint_radius,
existing_uvs=ev0, reference=ref0, pbar=pbar,
uvs=ev0, inpaint_radius=inpaint_radius,
reference=ref0, pbar=pbar,
)
out_mesh = Types.MESH(
vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0),
uvs=bu.unsqueeze(0), texture=bt.unsqueeze(0),
)
if bmr is not None:
out_mesh.metallic_roughness = bmr.unsqueeze(0)
return IO.NodeOutput(out_mesh)
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
metallic_img = _mr_channel(bmr, 2, bt).unsqueeze(0)
roughness_img = _mr_channel(bmr, 1, bt).unsqueeze(0)
return IO.NodeOutput(base_img, metallic_img, roughness_img)
class MeshTextureToImage(IO.ComfyNode):
@ -1247,6 +1262,73 @@ class MeshTextureToImage(IO.ComfyNode):
return IO.NodeOutput(base, mr, metallic, roughness)
# Node that re-attaches baked texture IMAGEs (base color, optional metallic /
# roughness) to a mesh that already carries UVs, so SaveGLB can serialize them.
class ApplyTextureToMesh(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ApplyTextureToMesh",
            display_name="Apply Texture to Mesh",
            category="latent/3d",
            description=(
                "Attaches baked texture IMAGEs to a mesh's existing UV layout so SaveGLB "
                "serializes them as baseColorTexture / metallicRoughnessTexture maps. Pairs "
                "with BakeTextureFromVoxel: feed it the SAME mesh you baked from, plus that "
                "node's base_color (and optionally metallic / roughness grayscale maps) — the "
                "UVs must match the ones the texture was baked against, so don't re-unwrap in "
                "between. metallic and roughness are repacked into the glTF MR map "
                "(G=roughness, B=metallic); leave them unconnected for non-PBR meshes (a "
                "missing metallic defaults to 0, a missing roughness to 1). Lets you preview / "
                "upscale / edit the baked maps before applying them."
            ),
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Image.Input("base_color"),
                IO.Image.Input("metallic", optional=True),
                IO.Image.Input("roughness", optional=True),
            ],
            outputs=[IO.Mesh.Output("mesh")],
        )

    @classmethod
    def execute(cls, mesh, base_color, metallic=None, roughness=None):
        """Return a shallow copy of `mesh` with normalized UVs and textures attached.

        Raises:
            ValueError: if the mesh has no UVs, or if both metallic and
                roughness are given but differ in batch/height/width.
        """
        mesh_uvs = getattr(mesh, "uvs", None)
        if mesh_uvs is None:
            raise ValueError(
                "ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
                "you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and "
                "never unwraps).")

        # Re-derive the exact UVs the bake rasterized against — it uniformly fits
        # out-of-[0,1] layouts into the unit square, so apply the identical
        # deterministic transform here (per batch item, over each item's real verts).
        if mesh_uvs.ndim == 3:
            new_uvs = mesh_uvs.clone()
            for i in range(mesh_uvs.shape[0]):
                v_i, _f_i, _ = get_mesh_batch_item(mesh, i)
                n = v_i.shape[0]
                norm = _normalize_uvs_to_unit(mesh_uvs[i, :n].detach().cpu().numpy())
                new_uvs[i, :n] = torch.from_numpy(norm).to(new_uvs)
        else:
            norm = _normalize_uvs_to_unit(mesh_uvs.detach().cpu().numpy())
            new_uvs = torch.from_numpy(norm).to(mesh_uvs)

        out_mesh = copy.copy(mesh)
        out_mesh.uvs = new_uvs
        out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu()

        if metallic is not None or roughness is not None:
            if (metallic is not None and roughness is not None
                    and tuple(metallic.shape[:3]) != tuple(roughness.shape[:3])):
                # Fail with an actionable message instead of an opaque torch.cat error.
                raise ValueError(
                    f"ApplyTextureToMesh: metallic {tuple(metallic.shape)} and roughness "
                    f"{tuple(roughness.shape)} must share batch, height and width to be "
                    f"packed into one metallicRoughness map.")

            def _gray(img):
                # First channel of a grayscale IMAGE, clamped to [0,1], on CPU.
                return img.float().clamp(0.0, 1.0).cpu()[..., 0:1]

            # Repack separate grayscale maps into glTF MR: R unused, G=roughness,
            # B=metallic. Size defaults off whichever map is connected; a missing
            # channel falls back to a sensible scalar (metal 0 / rough 1).
            # Only the shape is needed here, so no tensor copy is made.
            B, H, W, _ = (metallic if metallic is not None else roughness).shape
            rough_ch = _gray(roughness) if roughness is not None else torch.ones((B, H, W, 1))
            metal_ch = _gray(metallic) if metallic is not None else torch.zeros((B, H, W, 1))
            out_mesh.metallic_roughness = torch.cat(
                [torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1)
        return IO.NodeOutput(out_mesh)
def paint_mesh_default_colors(mesh):
out_mesh = copy.copy(mesh)
vertex_count = mesh.vertices.shape[1]
@ -1397,14 +1479,14 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()):
nm_share_breakdown.append(f"{n} edges×{c}faces")
logging.info(f"[FillHolesV2 diag] V={V} F={F} | "
logging.info(f"[FillHoles diag] V={V} F={F} | "
f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} "
f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})")
if nm_share_breakdown:
logging.info(f"[FillHolesV2 diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
logging.info(f"[FillHoles diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
if n_boundary == 0:
logging.info("[FillHolesV2 diag] no boundary edges → no cycles to fill")
logging.info("[FillHoles diag] no boundary edges → no cycles to fill")
return
# Walk components same as production path (bidir-prop, by-vertex count).
@ -1459,15 +1541,15 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle
cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component
logging.info(f"[FillHolesV2 diag] components={L} "
logging.info(f"[FillHoles diag] components={L} "
f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} "
f"non-simple={int(is_other.sum().item())}")
logging.info(f"[FillHolesV2 diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
logging.info(f"[FillHolesV2 diag] actually kept={int(actually_kept.sum().item())} "
logging.info(f"[FillHoles diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
logging.info(f"[FillHoles diag] actually kept={int(actually_kept.sum().item())} "
f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} "
f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}")
logging.info(f"[FillHolesV2 diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
logging.info(f"[FillHolesV2 diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
logging.info(f"[FillHoles diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
logging.info(f"[FillHoles diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
# Cycle vert-count distribution
if is_cycle.any():
@ -1485,12 +1567,12 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
elif s <= 20: buckets["11-20"] += n
elif s <= 50: buckets["21-50"] += n
else: buckets["51+"] += n
logging.info(f"[FillHolesV2 diag] cycle vert-count buckets: {buckets}")
logging.info(f"[FillHoles diag] cycle vert-count buckets: {buckets}")
if is_cycle.any():
cycle_perims = perim[is_cycle].cpu().tolist()
head = sorted(cycle_perims, reverse=True)[:10]
logging.info(f"[FillHolesV2 diag] top-10 cycle perimeters: "
logging.info(f"[FillHoles diag] top-10 cycle perimeters: "
f"{['%.4f' % p for p in head]}")
@ -1804,10 +1886,10 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi
n_escalations += 1
if total_welded > 0 or n_escalations > 0:
tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else ""
logging.info(f"[FillHolesV2] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
logging.info(f"[FillHoles] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
if ratio >= WELDED_THRESHOLD:
logging.warning(
f"[FillHolesV2] even at weld epsilon_rel={WELD_CAP} the mesh stays "
f"[FillHoles] even at weld epsilon_rel={WELD_CAP} the mesh stays "
f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has "
f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream "
f"(decimate node settings) or run WeldVertices manually with a larger epsilon."
@ -2077,6 +2159,23 @@ def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4):
return offset_verts, f_unwelded, None
def _fmt_count(n) -> str:
"""Compact human-readable integer for node status lines, e.g. 853, 12.3K, 1.23M."""
n = int(n)
if n >= 1_000_000:
return f"{n / 1_000_000:.2f}".rstrip("0").rstrip(".") + "M"
if n >= 1_000:
return f"{n / 1_000:.1f}".rstrip("0").rstrip(".") + "K"
return str(n)
def _fmt_face_change(n_in, n_out) -> str:
"""'faces: 1.23M → 200K (-84%)' — the count delta for decimate/remesh status."""
n_in, n_out = int(n_in), int(n_out)
pct = f" ({(n_out - n_in) / n_in * 100:+.0f}%)" if n_in else ""
return f"faces: {_fmt_count(n_in)}{_fmt_count(n_out)}{pct}"
class DecimateMesh(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -2168,7 +2267,7 @@ class DecimateMesh(IO.ComfyNode):
# Send progress text to display the face reduction on the node
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
return result
@ -2306,7 +2405,7 @@ class RemeshMesh(IO.ComfyNode):
# Send progress text to display the face change on the node
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
return result
@ -2546,7 +2645,8 @@ class UnwrapMesh(IO.ComfyNode):
if cls.hidden.unique_id:
PromptServer.instance.send_progress_text(
f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px",
f"UV: {_fmt_count(out_v[0].shape[0])} verts / {_fmt_count(out_f[0].shape[0])} faces"
f" · atlas ~{resolution}px",
cls.hidden.unique_id)
return IO.NodeOutput(out_mesh)
@ -2740,36 +2840,12 @@ class FillHoles(IO.ComfyNode):
node_id="FillHoles",
display_name="Fill Holes",
category="latent/3d",
description="Fills holes in a mesh up to a maximum perimeter threshold.",
inputs=[
IO.Mesh.Input("mesh"),
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
],
outputs=[IO.Mesh.Output("mesh")],
)
@classmethod
def execute(cls, mesh, max_perimeter):
def _fn(v, f, c):
if max_perimeter > 0:
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
return v, f, c
return _process_mesh_batch(mesh, _fn)
class FillHolesV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="FillHolesV2",
display_name="Fill Holes (v2)",
category="latent/3d",
description=(
"GPU-vectorised hole-filling via directed-half-edge pointer-doubling. "
"Drop-in alternative to FillHoles for comparison: same max_perimeter "
"cutoff and fan-from-centroid triangulation, but no Python loop, "
"auto-correct winding from face direction, and centroid colors are "
"averaged from the loop instead of left zero."
"Fills holes in a mesh up to a maximum perimeter threshold, preserving "
"the existing geometry/UVs (only patch triangles are added). GPU-vectorised "
"via directed-half-edge pointer-doubling: no Python loop, auto-correct "
"winding from face direction, and centroid colors are averaged from the hole "
"loop. Falls back to a CPU walker on non-CUDA devices."
),
inputs=[
IO.Mesh.Input("mesh"),
@ -2947,7 +3023,6 @@ class PostProcessMeshExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
FillHoles,
FillHolesV2,
WeldVertices,
DecimateMesh,
RemeshMesh,
@ -2956,6 +3031,7 @@ class PostProcessMeshExtension(ComfyExtension):
PaintMesh,
BakeTextureFromVoxel,
MeshTextureToImage,
ApplyTextureToMesh,
MergeMeshes,
]

View File

@ -1,18 +1,15 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import (
_build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats,
)
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
import comfy.utils
import folder_paths
from comfy.ldm.trellis2 import sampling_preview
from PIL import Image
import logging
import os
import numpy as np
import math
import torch
@ -21,89 +18,6 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
NAFModel = io.Custom("NAF_MODEL")
# Texture latent -> base-color calibration for the per-step preview
def _tex_rgb_factors_path():
    """Return the path of the persisted tex-rgb factors file.

    The file lives in the first directory that `folder_paths` reports for the
    "vae_approx" folder type.
    """
    approx_dirs = folder_paths.get_folder_paths("vae_approx")
    return os.path.join(approx_dirs[0], "trellis2_tex_rgb_factors.pt")
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
in_sp = in_coords[:, 1:4].long()
out_sp = out_coords[:, 1:4].long()
in_b = in_coords[:, 0].long()
out_b = out_coords[:, 0].long()
in_res = int(in_sp.max().item()) + 1
out_res = int(out_sp.max().item()) + 1
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
R = in_res
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
order = torch.argsort(in_flat)
in_sorted = in_flat[order]
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
matched = in_sorted[pos] == par_flat
in_idx = order[pos][matched]
cols = out_colors[matched].float()
N = in_coords.shape[0]
csum = cols.new_zeros((N, 3))
ccount = cols.new_zeros((N, 1))
csum.index_add_(0, in_idx, cols)
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
valid = ccount[:, 0] > 0
albedo = torch.zeros_like(csum)
albedo[valid] = csum[valid] / ccount[valid]
return albedo, valid
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
    """Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish.

    Pools the decoded colors back onto the input voxels, adds this decode's
    normal-equation terms (Xaug^T Xaug and Xaug^T Y) to the running sums stored
    on disk, saves the updated sums, then solves the ridge-regularised system
    and passes the resulting weights and bias to `sampling_preview.set_tex_rgb`.

    Best-effort: any failure is logged at debug level and nothing is raised.
    Decodes with fewer than 64 matched voxels are skipped entirely.
    """
    try:
        dev = out_colors.device
        in_latent = in_latent.to(dev)
        in_coords = in_coords.to(dev)
        out_coords = out_coords.to(dev)
        albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
        X = in_latent[valid].float().cpu()
        Y = albedo[valid].float().cpu()
        if X.shape[0] < 64:
            return
        # Append a constant-1 column so the solve also yields a bias term.
        Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1)  # [K, C+1]
        A_run = Xaug.transpose(0, 1) @ Xaug                       # [C+1, C+1]
        B_run = Xaug.transpose(0, 1) @ Y                          # [C+1, 3]
        path = _tex_rgb_factors_path()
        if os.path.exists(path):
            try:
                prev = torch.load(path, map_location="cpu")
                # Sum both terms before rebinding either one, so a failure on
                # the second cannot leave A and B out of step.
                A_sum = A_run + prev["A"]
                B_sum = B_run + prev["B"]
            except Exception as e:
                # Unreadable / incompatible file: restart from this decode only.
                logging.debug(f"Trellis2 tex-rgb: ignoring stored factors at {path}: {e}")
            else:
                A_run, B_run = A_sum, B_sum
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save({"A": A_run, "B": B_run}, path)
        eye = torch.eye(A_run.shape[0])
        WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run)        # [C+1, 3]
        # Last row is the bias (it multiplies the appended ones column).
        W, b = WB[:-1].contiguous(), WB[-1].contiguous()
        sampling_preview.set_tex_rgb(W, b)
    except Exception as e:
        logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
def _load_tex_rgb_factors():
    """Publish tex-rgb weights solved from the persisted accumulators, if any.

    Does nothing when the factors file is absent. Best-effort: any failure is
    logged at debug level and nothing is raised.
    """
    try:
        factors_file = _tex_rgb_factors_path()
        if not os.path.exists(factors_file):
            return
        saved = torch.load(factors_file, map_location="cpu")
        identity = torch.eye(saved["A"].shape[0])
        # Same ridge-regularised solve as the calibration step (lambda = 1e-3).
        solution = torch.linalg.solve(saved["A"] + 1e-3 * identity, saved["B"])
        weights = solution[:-1].contiguous()
        bias = solution[-1].contiguous()
        sampling_preview.set_tex_rgb(weights, bias)
    except Exception as e:
        logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
_load_tex_rgb_factors()
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5:
@ -271,7 +185,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
if coord_counts is None:
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = shape_norm(samples.to(device), coords.to(device))
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
mesh, subs = trellis_vae.decode_shape_slat(samples.to(vae.vae_dtype), resolution)
else:
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
mesh = []
@ -280,7 +194,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
sample_i = shape_norm(feats_i.to(device), coords_i)
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i.to(vae.vae_dtype), resolution)
mesh.append(mesh_i[0])
subs_per_sample.append(subs_i)
@ -338,14 +252,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = samples.to(device)
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
cal_in_coords = coords
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords.to(device))
samples = samples * std + mean
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
@ -353,12 +265,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
color_feats = voxel.feats
voxel_coords = voxel.coords
# Calibrate the latent->base_color map for the per-step texture preview.
# Done here while input coords and voxel_coords share the model frame
# (before the z_up remap below) and on the real decoded albedo.
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
if coord_resolution is not None:
tex_resolution = int(coord_resolution) * 16
elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
@ -416,7 +322,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
decoded_batches = []
for start in range(0, sample_tensor.shape[0], batch_number):
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0)
decoded_batches.append(shape_vae.decode_structure(sample_chunk.to(vae.vae_dtype)) > 0)
decoded = torch.cat(decoded_batches, dim=0)
current_res = decoded.shape[2]
@ -491,7 +397,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
shape_latent["samples"], shape_latent["coords"], coord_counts,
)
slat = shape_norm(feats.to(device), coords_512.to(device))
sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)]
sample_hr_coords = [shape_vae.upsample_shape(slat.to(vae.vae_dtype), upsample_times=4)]
else:
items = split_batched_sparse_latent(
shape_latent["samples"], shape_latent["coords"], coord_counts,
@ -501,7 +407,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
slat_i = shape_norm(feats_i.to(device), coords_i)
sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4))
sample_hr_coords.append(shape_vae.upsample_shape(slat_i.to(vae.vae_dtype), upsample_times=4))
# Resolution search — cache the final iteration's quantized unique tensors
# so we don't recompute .unique() per sample after picking hr_resolution.
@ -977,8 +883,10 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
if pad_l or pad_t or pad_r or pad_b:
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
crop_x1 += pad_l; crop_x2 += pad_l
crop_y1 += pad_t; crop_y2 += pad_t
crop_x1 += pad_l
crop_x2 += pad_l
crop_y1 += pad_t
crop_y2 += pad_t
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
@ -1100,7 +1008,7 @@ class Pixal3DConditioning(IO.ComfyNode):
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
proj_pack = {
"stages": {
@ -1119,15 +1027,6 @@ class Pixal3DConditioning(IO.ComfyNode):
}
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
# proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet
# hints etc. live); the sampler auto-promotes it to a model.forward kwarg via
# Trellis2.extra_conds. The same pack object is shared between pos/neg —
# CONDConstant.can_concat sees them equal and concats to a single dict, then
# Trellis2.forward zeros proj for the uncond slots via cond_or_uncond.
# Pre-compute the SS-stage proj features (dense 16³ grid) once here — the
# shape/texture stages do their own computes in their respective stage nodes.
# proj_pack lives on intermediate (CPU); force the compute onto cuda so
# the bilinear-sampling step doesn't run on CPU.
ss_proj_feats = compute_stage_proj_feats(
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
device=torch_device,
@ -1278,12 +1177,6 @@ class Pixal3DAlignObject(IO.ComfyNode):
q_mean = Q.mean(dim=0, keepdim=True)
P_c = P - p_mean
Q_c = Q - q_mean
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
# gets multiplied by cos(yaw) and shrinks the object. Using
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
# rotation; translation still positions the mesh at MoGe's centroid.
p_var = (P_c * P_c).sum().clamp(min=1e-8)
q_var = (Q_c * Q_c).sum()
scale = float(torch.sqrt(q_var / p_var).item())
@ -1326,74 +1219,69 @@ class LoadNAFModel(IO.ComfyNode):
return IO.NodeOutput(model)
class CFGGuidanceInterval(IO.ComfyNode):
"""Generic model patch: apply CFG only during [start_percent, end_percent] of
the sampling schedule. Outside that window, skip the uncond computation and
collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's
guidance_interval_mixin, but lives at the sampler level (via
sampler_calc_cond_batch_function) so it works for any model.
class GetMeshInfo(IO.ComfyNode):
"""Report vertex / face counts and attributes for a MESH, displayed on the
node (and as a string output). Counts are comma-formatted since meshes can
run into the millions of faces. Passes the mesh through unchanged."""
Percents use ComfyUI's standard convention: 0.0 = start of sampling
(max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
is done via model_sampling.percent_to_sigma so the window is portable
across schedules (flow / EDM / discrete) and shift settings.
Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D
pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the
SS and shape samplers — CFG active only in the first 40% of sampling.
Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to
match. Texture defaults to cfg=1 so the node is moot there."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="CFGGuidanceInterval",
category="model_patches/sampling",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
node_id="GetMeshInfo",
display_name="Get Mesh Info",
category="latent/3d",
inputs=[IO.Mesh.Input("mesh")],
outputs=[
IO.Mesh.Output(display_name="mesh"),
IO.String.Output(display_name="info"),
],
outputs=[IO.Model.Output()],
)
@staticmethod
def _fmt(n: int) -> str:
# e.g. 1234567 -> "1,234,567 (1.23M)"; small numbers stay plain.
s = f"{n:,}"
if n >= 1_000_000:
s += f" ({n / 1_000_000:.2f}M)"
elif n >= 10_000:
s += f" ({n / 1_000:.1f}K)"
return s
@classmethod
def execute(cls, model, start_percent, end_percent):
import comfy.samplers
def execute(cls, mesh):
B = mesh.vertices.shape[0]
# Honour per-item counts when the batch is zero-padded; else use the row sizes.
if mesh.vertex_counts is not None:
v_counts = [int(x) for x in mesh.vertex_counts.tolist()]
f_counts = [int(x) for x in mesh.face_counts.tolist()]
else:
v_counts = [int(mesh.vertices.shape[1])] * B
f_counts = [int(mesh.faces.shape[1])] * B
model_sampling = model.get_model_object("model_sampling")
# percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
# percent=1 -> sigma_min. So start_percent < end_percent in user space
# means sigma_start > sigma_end. "Inside the window" is sigma in
# [sigma_end, sigma_start].
sigma_start = float(model_sampling.percent_to_sigma(start_percent))
sigma_end = float(model_sampling.percent_to_sigma(end_percent))
attrs = []
for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"):
t = getattr(mesh, name, None)
if t is not None:
if name in ("texture", "metallic_roughness"):
attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W
else:
attrs.append(name)
def calc_cond_batch_with_interval(args):
sigma_val = args["sigma"][0].item()
conds = args["conds"]
input_x = args["input"]
timestep = args["sigma"]
model_ref = args["model"]
model_opts = args["model_options"]
lines = []
if B > 1:
lines.append(f"Batch: {B} meshes")
lines.append(f"Vertices: {cls._fmt(sum(v_counts))} total")
lines.append(f"Faces: {cls._fmt(sum(f_counts))} total")
for i in range(B):
lines.append(f" [{i}] {v_counts[i]:>10,} verts · {f_counts[i]:>10,} faces")
else:
lines.append(f"Vertices: {cls._fmt(v_counts[0])}")
lines.append(f"Faces: {cls._fmt(f_counts[0])}")
lines.append(f"Attributes: {', '.join(attrs) if attrs else 'none'}")
# conds is typically [cond, uncond]; uncond may be None when ComfyUI's
# global cfg=1 optimization has already pruned it.
cond = conds[0]
uncond = conds[1] if len(conds) > 1 else None
inside = sigma_end <= sigma_val <= sigma_start
if uncond is None or inside:
return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts)
# Outside the window: compute cond only, mirror it into the uncond slot
# so the downstream cfg_function collapses to `cond` (effective cfg=1).
out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts)
return [out[0], out[0]]
m = model.clone()
m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
return IO.NodeOutput(m)
info = "\n".join(lines)
logging.info("[GetMeshInfo]\n%s", info)
return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))
class Trellis2Extension(ComfyExtension):
@ -1411,7 +1299,7 @@ class Trellis2Extension(ComfyExtension):
VaeDecodeShapeTrellis,
VaeDecodeStructureTrellis2,
Trellis2UpsampleStage,
CFGGuidanceInterval,
GetMeshInfo,
]