mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
More cleanup
This commit is contained in:
parent
1f7acd9354
commit
41f5f4b2c0
@ -180,12 +180,13 @@ class PaintMesh(IO.ComfyNode):
|
||||
# =============================================================================
|
||||
# Texture baking from sparse voxel volume.
|
||||
#
|
||||
# Pipeline: xatlas UV unwrap → OpenGL UV-space rasterize to position map →
|
||||
# nearest-voxel color sample per texel → cv2.inpaint to fill UV seams →
|
||||
# attach texture + UVs to the Mesh for SaveGLB to serialize.
|
||||
# Pipeline: take the mesh's existing UVs → OpenGL UV-space rasterize to position
|
||||
# map → nearest-voxel color sample per texel → GPU Jump-Flood fill UV seams →
|
||||
# attach texture + UVs to the Mesh for SaveGLB to serialize. Unwrapping is done
|
||||
# upstream (Trellis2OfficialUnwrap / TorchXatlasUVWrap); this path never unwraps.
|
||||
#
|
||||
# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles
|
||||
# GLFW / EGL / OSMesa backend selection); xatlas for UV parameterization.
|
||||
# GLFW / EGL / OSMesa backend selection).
|
||||
# =============================================================================
|
||||
|
||||
_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache"
|
||||
@ -407,6 +408,54 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):
|
||||
return vals, ok
|
||||
|
||||
|
||||
def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||
"""GPU port of `_trilinear_sample_sparse` — same normalized-over-occupied-corners
|
||||
trilinear, but the per-texel 8-corner accumulation runs on CUDA via sorted-key
|
||||
`searchsorted` instead of NumPy float64. This is the bake hot path (millions of
|
||||
covered texels × 8 corners), so the CPU version dominates runtime; the GPU port
|
||||
is ~identical numerically and 10-50× faster. Returns (vals [K,C] float32, ok
|
||||
[K] bool), matching the NumPy signature."""
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
R = int(resolution)
|
||||
origin = -0.5
|
||||
voxel_size = 1.0 / R
|
||||
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
|
||||
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
|
||||
col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float()
|
||||
K, C = P.shape[0], col.shape[1]
|
||||
M = VC.shape[0]
|
||||
# Same cell-CENTER convention as the NumPy path (see its docstring): integer
|
||||
# voxel coord c sits at (c+0.5)*voxel_size + origin, so subtract 0.5 to bracket.
|
||||
gc = (P - origin) / voxel_size - 0.5
|
||||
base = torch.floor(gc).long()
|
||||
frac = gc - base.float()
|
||||
key = (VC[:, 0] * R + VC[:, 1]) * R + VC[:, 2]
|
||||
skey, order = key.sort()
|
||||
acc = torch.zeros((K, C), device=dev)
|
||||
wsum = torch.zeros((K, 1), device=dev)
|
||||
for dx in (0, 1):
|
||||
wx = frac[:, 0] if dx else 1.0 - frac[:, 0]
|
||||
for dy in (0, 1):
|
||||
wy = frac[:, 1] if dy else 1.0 - frac[:, 1]
|
||||
for dz in (0, 1):
|
||||
wz = frac[:, 2] if dz else 1.0 - frac[:, 2]
|
||||
cx = base[:, 0] + dx
|
||||
cy = base[:, 1] + dy
|
||||
cz = base[:, 2] + dz
|
||||
inb = (cx >= 0) & (cx < R) & (cy >= 0) & (cy < R) & (cz >= 0) & (cz < R)
|
||||
qk = (cx * R + cy) * R + cz
|
||||
ins = torch.searchsorted(skey, qk).clamp(max=M - 1)
|
||||
matched = inb & (skey[ins] == qk)
|
||||
idx = order[ins] # garbage where !matched
|
||||
w = torch.where(matched, wx * wy * wz, torch.zeros_like(wx))[:, None]
|
||||
acc += w * col[idx] # w=0 cancels garbage rows
|
||||
wsum += w
|
||||
ok = wsum[:, 0] > 1e-8
|
||||
vals = torch.zeros((K, C), device=dev)
|
||||
vals[ok] = acc[ok] / wsum[ok].clamp_min(1e-8)
|
||||
return vals.cpu().numpy(), ok.cpu().numpy()
|
||||
|
||||
|
||||
def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||
"""GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a
|
||||
regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a
|
||||
@ -415,7 +464,7 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||
of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found`
|
||||
is False for the rare query whose nearest occupied voxel is >1 cell away (the
|
||||
caller falls back to a cKDTree on just those)."""
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
R = int(resolution)
|
||||
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
|
||||
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
|
||||
@ -451,6 +500,27 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||
fnd |= match
|
||||
return bi, fnd
|
||||
|
||||
def _brute_nearest(idx):
|
||||
"""Exact nearest occupied voxel for a (small) query subset by chunked GPU
|
||||
brute force over all M voxels. Used only for the handful of stragglers the
|
||||
grid scan misses (>4 cells from any voxel) — replaces a cKDTree build over
|
||||
all M voxels, which costs seconds even for a few query points."""
|
||||
Ps = P[idx] # [N,3] world
|
||||
N = Ps.shape[0]
|
||||
vox_pos = (VC.float() + 0.5) / R - 0.5 # [M,3] voxel centres
|
||||
best_d = torch.full((N,), 1e30, device=dev)
|
||||
best_j = torch.zeros(N, dtype=torch.long, device=dev)
|
||||
# Bound the N×chunk distance matrix to ~64M elements regardless of N.
|
||||
chunk = max(1, (1 << 26) // max(1, N))
|
||||
for s in range(0, M, chunk):
|
||||
vc = vox_pos[s:s + chunk] # [B,3]
|
||||
dd = (Ps[:, None, :] - vc[None, :, :]).pow(2).sum(-1) # [N,B]
|
||||
md, mj = dd.min(1)
|
||||
upd = md < best_d
|
||||
best_d = torch.where(upd, md, best_d)
|
||||
best_j = torch.where(upd, mj + s, best_j)
|
||||
return best_j
|
||||
|
||||
all_idx = torch.arange(K, device=dev)
|
||||
best_i = torch.zeros(K, dtype=torch.long, device=dev)
|
||||
found = torch.zeros(K, dtype=torch.bool, device=dev)
|
||||
@ -458,27 +528,30 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||
bi1, fnd1 = _search(all_idx, 1)
|
||||
best_i[all_idx] = bi1
|
||||
found[all_idx] = fnd1
|
||||
# Pass 2: wider radius on ONLY the few misses (avoids ever building a cKDTree
|
||||
# over millions of voxels just for a handful of >1-cell-away points).
|
||||
# Pass 2: wider radius (9³) on ONLY the radius-1 misses.
|
||||
miss = torch.nonzero(~found, as_tuple=True)[0]
|
||||
if miss.numel() > 0:
|
||||
bi2, fnd2 = _search(miss, 4)
|
||||
best_i[miss] = bi2
|
||||
found[miss] = fnd2
|
||||
# Pass 3: exact GPU brute force for the few stragglers still unfound (>4 cells
|
||||
# out). Always resolves them, so `found` is all-True on return — no cKDTree.
|
||||
miss2 = torch.nonzero(~found, as_tuple=True)[0]
|
||||
if miss2.numel() > 0:
|
||||
best_i[miss2] = _brute_nearest(miss2)
|
||||
found[miss2] = True
|
||||
vals = col[best_i]
|
||||
return vals.cpu().numpy(), found.cpu().numpy()
|
||||
|
||||
|
||||
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution,
|
||||
mode="trilinear"):
|
||||
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution):
|
||||
"""For every masked texel, sample the voxel field and return ALL its attribute
|
||||
channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature
|
||||
width (3 for plain color, 6 for full PBR).
|
||||
|
||||
mode="trilinear" — normalized trilinear over occupied voxels (the default; matches
|
||||
the official o_voxel.to_glb path), with nearest fallback for texels whose 8
|
||||
surrounding voxels are all empty. This is the only mode the nodes expose now.
|
||||
mode="nearest" — nearest-voxel; kept as an internal/dev lever (blocky)."""
|
||||
Normalized trilinear over occupied voxels (matches the official o_voxel.to_glb
|
||||
path), with nearest fallback for texels whose 8 surrounding voxels are all
|
||||
empty."""
|
||||
H, W, _ = position_map.shape
|
||||
color_np = voxel_colors.detach().cpu().numpy().astype(np.float32)
|
||||
C = color_np.shape[-1]
|
||||
@ -496,22 +569,26 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
|
||||
valid_positions = position_map[mask]
|
||||
|
||||
def _nearest(query):
|
||||
# GPU grid lookup; cKDTree only for the rare >1-cell miss.
|
||||
# Fully on-GPU nearest-occupied-voxel: grid scan + brute-force tail. Always
|
||||
# resolves every query, so no cKDTree (its build over all voxels cost ~3s).
|
||||
vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution)
|
||||
if not found.all():
|
||||
# Defensive: only reachable on a non-CUDA device where the GPU path is
|
||||
# unavailable; fall back to a one-off cKDTree.
|
||||
tree = scipy.spatial.cKDTree(voxel_pos)
|
||||
_, nearest_idx = tree.query(query[~found], k=1, workers=-1)
|
||||
vals[~found] = color_np[nearest_idx]
|
||||
return vals
|
||||
|
||||
if mode == "trilinear":
|
||||
try:
|
||||
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
|
||||
except Exception as e:
|
||||
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU")
|
||||
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
|
||||
if not ok.all():
|
||||
# Texels with no occupied neighbour fall back to nearest.
|
||||
vals[~ok] = _nearest(valid_positions[~ok])
|
||||
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
|
||||
else:
|
||||
out[mask] = np.clip(_nearest(valid_positions), 0.0, 1.0)
|
||||
if not ok.all():
|
||||
# Texels with no occupied neighbour fall back to nearest.
|
||||
vals[~ok] = _nearest(valid_positions[~ok])
|
||||
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
|
||||
return out
|
||||
|
||||
|
||||
@ -729,7 +806,7 @@ def _back_project_positions(position_map, mask, ref_v, ref_f):
|
||||
return position_map
|
||||
|
||||
import time as _time
|
||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
rv = ref_v.detach().to(dev).float()
|
||||
rf = ref_f.detach().to(dev).long()
|
||||
tri = rv[rf]
|
||||
@ -755,7 +832,7 @@ def _jfa_fill_gpu(img01, mask):
|
||||
(True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map."""
|
||||
if not mask.any():
|
||||
return img01
|
||||
dev = "cuda"
|
||||
dev = comfy.model_management.get_torch_device()
|
||||
it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float()
|
||||
mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev)
|
||||
H, W = mm.shape
|
||||
@ -786,41 +863,58 @@ def _jfa_fill_gpu(img01, mask):
|
||||
|
||||
def _seam_fill(img01, mask, inpaint_radius):
|
||||
"""Fill the UV-gutter texels around covered charts so seam sampling doesn't
|
||||
pull in black. GPU Jump Flooding (nearest fill) when CUDA is available, else
|
||||
cv2 Telea inpaint. `inpaint_radius<=0` disables; the radius only affects the
|
||||
cv2 fallback (JFA fills every uncovered texel by nearest)."""
|
||||
pull in black, via GPU Jump Flooding (nearest fill). `inpaint_radius<=0`
|
||||
disables; otherwise the radius is ignored — JFA fills every uncovered texel
|
||||
by nearest regardless."""
|
||||
if inpaint_radius <= 0:
|
||||
return img01
|
||||
if torch.cuda.is_available():
|
||||
return _jfa_fill_gpu(img01, mask)
|
||||
import cv2
|
||||
u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8)
|
||||
u8 = cv2.inpaint(u8, ((~mask).astype(np.uint8)) * 255, int(inpaint_radius), cv2.INPAINT_TELEA)
|
||||
if u8.ndim == 2:
|
||||
u8 = u8[..., None]
|
||||
return u8.astype(np.float32) / 255.0
|
||||
return _jfa_fill_gpu(img01, mask)
|
||||
|
||||
|
||||
def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
|
||||
"""Uniformly fit a UV layout's bbox into [0,1] when it spills outside the unit
|
||||
square (preserves chart aspect ratios; handles packers that overflow slightly).
|
||||
No-op when the UVs are already in [0,1] — the normal case for official/xatlas
|
||||
unwraps. NOT a UDIM de-tiler; warns if the span looks tiled.
|
||||
|
||||
Deterministic from the input UVs alone, so the texture bake and
|
||||
ApplyTextureToMesh both call it to agree on the exact UVs the texture was baked
|
||||
against (the bake no longer emits the mesh, so apply must re-derive them).
|
||||
|
||||
Returns float32 [N,2]."""
|
||||
uv_np = uv_np.astype(np.float32)
|
||||
uv_min = uv_np.min(axis=0)
|
||||
uv_max = uv_np.max(axis=0)
|
||||
out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
|
||||
if not (normalize and out_of_unit):
|
||||
return uv_np
|
||||
extent = float((uv_max - uv_min).max())
|
||||
span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1]))
|
||||
if span > 1.5 and log_prefix:
|
||||
logging.warning(
|
||||
f"{log_prefix} UV span {span:.2f} looks like a tiled/UDIM layout; "
|
||||
f"uniform-fitting it into [0,1] will overlap tiles. Re-unwrap upstream instead.")
|
||||
if extent > 0:
|
||||
uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
|
||||
if log_prefix:
|
||||
logging.info(f"{log_prefix} normalized UVs into [0,1] (uniform scale 1/{extent:.4f})")
|
||||
return uv_np
|
||||
|
||||
|
||||
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
||||
resolution, texture_size, inpaint_radius=3,
|
||||
fast_unwrap=True, existing_uvs=None,
|
||||
normalize_uvs=True, sample_mode="trilinear",
|
||||
reference=None, pbar=None):
|
||||
resolution, texture_size, uvs, inpaint_radius=3,
|
||||
normalize_uvs=True, reference=None, pbar=None):
|
||||
"""Bake a baseColor (+ optional metallicRoughness) texture for
|
||||
`vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each
|
||||
texel from the provided sparse colored voxel volume.
|
||||
|
||||
If `existing_uvs` (N, 2) is given, it is used directly and xatlas is
|
||||
skipped — bakes onto the mesh's current UV layout without re-unwrapping.
|
||||
Otherwise xatlas computes a fresh atlas (verts/faces may grow at seams).
|
||||
`uvs` (N, 2) is the mesh's existing UV layout — baked onto directly (this
|
||||
node never unwraps; connect a UV unwrap node upstream). It must be 1:1 with
|
||||
`vertices`.
|
||||
|
||||
Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr).
|
||||
|
||||
`fast_unwrap=True` configures xatlas with permissive chart options so it
|
||||
finishes in a reasonable time on large meshes — at the cost of less even
|
||||
UV distribution. Set False to use xatlas defaults (slow on >100k faces).
|
||||
|
||||
Progress: drives a local tqdm over its 5 stages (unwrap → rasterize →
|
||||
Progress: drives a local tqdm over its 5 stages (uvs → rasterize →
|
||||
back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is
|
||||
passed, ticks it once per stage too — so callers should size it as 5 per
|
||||
bake."""
|
||||
@ -845,111 +939,25 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
||||
v_np = vertices.detach().cpu().numpy().astype(np.float32)
|
||||
f_np = faces.detach().cpu().numpy().astype(np.uint32)
|
||||
fcount = int(f_np.shape[0])
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if existing_uvs is not None:
|
||||
# Bake onto the mesh's current UVs — no xatlas, no seam-splitting.
|
||||
uv_np = existing_uvs.detach().cpu().numpy().astype(np.float32)
|
||||
if uv_np.shape[0] != v_np.shape[0]:
|
||||
raise ValueError(
|
||||
f"BakeTextureFromVoxel: existing UVs ({uv_np.shape[0]}) must be 1:1 "
|
||||
f"with vertices ({v_np.shape[0]})."
|
||||
)
|
||||
uv_min = uv_np.min(axis=0)
|
||||
uv_max = uv_np.max(axis=0)
|
||||
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
|
||||
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
|
||||
f"{fcount} faces (xatlas skipped)")
|
||||
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
|
||||
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
|
||||
out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
|
||||
if normalize_uvs and out_of_unit:
|
||||
# Uniform fit of the UV bbox into [0,1] (preserves chart aspect ratios).
|
||||
# Handles packers that overflow the unit square slightly. NOT a UDIM
|
||||
# de-tiler — a true multi-tile layout would get squashed; warn if the
|
||||
# span is large enough to look like tiling.
|
||||
extent = float((uv_max - uv_min).max())
|
||||
span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1]))
|
||||
if span > 1.5:
|
||||
logging.warning(
|
||||
f"[BakeTextureFromVoxel] UV span {span:.2f} looks like a tiled/UDIM "
|
||||
f"layout; uniform-fitting it into [0,1] will overlap tiles. "
|
||||
f"Re-unwrap instead (use_existing_uvs=False)."
|
||||
)
|
||||
if extent > 0:
|
||||
uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
|
||||
logging.info(f"[BakeTextureFromVoxel] normalized UVs into [0,1] "
|
||||
f"(uniform scale 1/{extent:.4f})")
|
||||
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
||||
else:
|
||||
import xatlas
|
||||
if fcount > 300_000:
|
||||
logging.warning(
|
||||
f"[BakeTextureFromVoxel] mesh has {fcount} faces — xatlas chart "
|
||||
f"decomposition is CPU-bound and may take many minutes. Consider "
|
||||
f"decimating to under ~200k faces before baking."
|
||||
)
|
||||
logging.info(f"[BakeTextureFromVoxel] xatlas unwrap: {v_np.shape[0]} verts, {fcount} faces")
|
||||
if fast_unwrap and hasattr(xatlas, "Atlas"):
|
||||
atlas = xatlas.Atlas()
|
||||
atlas.add_mesh(v_np, f_np)
|
||||
logging.info(f"[BakeTextureFromVoxel] add_mesh: {time.perf_counter() - t0:.1f}s")
|
||||
gen_kwargs = {}
|
||||
applied = []
|
||||
# ChartOptions: looser growth → larger / fewer charts → faster.
|
||||
if hasattr(xatlas, "ChartOptions"):
|
||||
co = xatlas.ChartOptions()
|
||||
for attr, val in (
|
||||
("max_iterations", 1),
|
||||
("max_cost", 8.0),
|
||||
("normal_deviation_weight", 1.0),
|
||||
("roundness_weight", 0.0),
|
||||
("straightness_weight", 0.0),
|
||||
("normal_seam_weight", 1.0),
|
||||
("texture_seam_weight", 0.0),
|
||||
("use_input_mesh_uvs", False),
|
||||
):
|
||||
if hasattr(co, attr):
|
||||
setattr(co, attr, val)
|
||||
applied.append(f"chart.{attr}")
|
||||
gen_kwargs["chart_options"] = co
|
||||
# PackOptions.bruteForce defaults to True — tries many rotations per
|
||||
# chart and is the single biggest contributor to pack time on small
|
||||
# meshes. Off it loses ~5-15% packing efficiency but runs ~5× faster.
|
||||
if hasattr(xatlas, "PackOptions"):
|
||||
po = xatlas.PackOptions()
|
||||
for attr, val in (
|
||||
("bruteForce", False),
|
||||
("brute_force", False), # snake_case alias on some builds
|
||||
("create_image", False),
|
||||
("createImage", False),
|
||||
("padding", 2),
|
||||
):
|
||||
if hasattr(po, attr):
|
||||
setattr(po, attr, val)
|
||||
applied.append(f"pack.{attr}")
|
||||
gen_kwargs["pack_options"] = po
|
||||
logging.info(f"[BakeTextureFromVoxel] options applied: {applied}")
|
||||
tgen = time.perf_counter()
|
||||
try:
|
||||
atlas.generate(**gen_kwargs)
|
||||
except TypeError as e:
|
||||
logging.warning(f"[BakeTextureFromVoxel] generate(**kwargs) rejected ({e}); retrying with defaults")
|
||||
atlas.generate()
|
||||
logging.info(f"[BakeTextureFromVoxel] generate: {time.perf_counter() - tgen:.1f}s")
|
||||
tget = time.perf_counter()
|
||||
vmapping, indices, uvs = atlas[0]
|
||||
logging.info(f"[BakeTextureFromVoxel] retrieve: {time.perf_counter() - tget:.1f}s")
|
||||
else:
|
||||
vmapping, indices, uvs = xatlas.parametrize(v_np, f_np)
|
||||
# Bake onto the mesh's current UVs — no unwrap, no seam-splitting.
|
||||
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
|
||||
if uv_np.shape[0] != v_np.shape[0]:
|
||||
raise ValueError(
|
||||
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
|
||||
f"with vertices ({v_np.shape[0]})."
|
||||
)
|
||||
uv_min = uv_np.min(axis=0)
|
||||
uv_max = uv_np.max(axis=0)
|
||||
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
|
||||
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
|
||||
f"{fcount} faces")
|
||||
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
|
||||
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
|
||||
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
|
||||
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
||||
|
||||
logging.info(f"[BakeTextureFromVoxel] xatlas total {time.perf_counter() - t0:.1f}s "
|
||||
f"({vmapping.shape[0]} verts after seams)")
|
||||
new_verts = v_np[vmapping]
|
||||
new_faces = indices.astype(np.uint32)
|
||||
new_uvs = uvs.astype(np.float32)
|
||||
|
||||
_tick("unwrap")
|
||||
_tick("uvs")
|
||||
|
||||
t1 = time.perf_counter()
|
||||
position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size)
|
||||
@ -968,7 +976,7 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
||||
|
||||
t2 = time.perf_counter()
|
||||
attrs = _sample_voxel_attrs_per_texel(
|
||||
position_map, mask, voxel_coords, voxel_colors, resolution, mode=sample_mode,
|
||||
position_map, mask, voxel_coords, voxel_colors, resolution,
|
||||
)
|
||||
logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s "
|
||||
f"({attrs.shape[-1]} channels)")
|
||||
@ -1022,12 +1030,11 @@ def _per_vertex_normals(verts_np, faces_np):
|
||||
|
||||
|
||||
def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution,
|
||||
texture_size, views, blend_temperature=0.25,
|
||||
inpaint_radius=3, fast_unwrap=True, existing_uvs=None,
|
||||
normalize_uvs=True, sample_mode="trilinear"):
|
||||
texture_size, views, uvs, blend_temperature=0.25,
|
||||
inpaint_radius=3, normalize_uvs=True):
|
||||
"""Bake a baseColor texture by projecting view photos onto the mesh.
|
||||
|
||||
Reuses bake_texture_from_voxel_fn for the xatlas unwrap + the nearest-voxel
|
||||
Reuses bake_texture_from_voxel_fn for the UV-space bake + the nearest-voxel
|
||||
fallback colour, then overlays photo colour on every covered+visible texel:
|
||||
each texel's world position/normal is projected into each view, occlusion is
|
||||
resolved with a texel z-buffer, and the views are blended weighted by how
|
||||
@ -1045,8 +1052,8 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
|
||||
# Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end).
|
||||
out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn(
|
||||
vertices, faces, voxel_coords, voxel_colors, resolution=resolution,
|
||||
texture_size=texture_size, inpaint_radius=0, fast_unwrap=fast_unwrap,
|
||||
existing_uvs=existing_uvs, normalize_uvs=normalize_uvs, sample_mode=sample_mode)
|
||||
texture_size=texture_size, uvs=uvs, inpaint_radius=0,
|
||||
normalize_uvs=normalize_uvs)
|
||||
|
||||
v_np = out_v.detach().cpu().numpy().astype(np.float32)
|
||||
f_np = out_f.detach().cpu().numpy().astype(np.uint32)
|
||||
@ -1078,6 +1085,16 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
|
||||
return out_v, out_f, out_uvs, out_tex, voxel_mr
|
||||
|
||||
|
||||
def _mr_channel(packed_mr, ch, ref):
|
||||
"""Pull one channel out of a packed glTF MR map (G=roughness at idx 1, B=metallic
|
||||
at idx 2) as a 3-channel grayscale IMAGE [H,W,3] in [0,1]. Returns black sized
|
||||
like `ref` when there's no MR map (non-PBR voxel field)."""
|
||||
if packed_mr is None:
|
||||
return torch.zeros_like(ref.float().cpu())
|
||||
m = packed_mr.float().clamp(0.0, 1.0).cpu()
|
||||
return m[..., ch:ch + 1].expand(-1, -1, 3).contiguous()
|
||||
|
||||
|
||||
class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1089,12 +1106,12 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
"Bakes PBR textures onto the mesh's existing UV layout by rasterizing it "
|
||||
"in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling "
|
||||
"the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node "
|
||||
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Produces a "
|
||||
"baseColor texture, plus a metallicRoughness texture when the voxel field "
|
||||
"carries the full PBR set (6 channels). Returns a Mesh with `uvs`, `texture`, "
|
||||
"and `metallic_roughness` attached — SaveGLB serializes them as real "
|
||||
"baseColorTexture / metallicRoughnessTexture maps. UVs that spill outside "
|
||||
"[0,1] are uniformly fit back into the unit square."
|
||||
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Outputs the "
|
||||
"baked maps as IMAGEs: base_color, plus metallic and roughness as separate "
|
||||
"grayscale maps (both black when the voxel field has no PBR set). "
|
||||
"Preview/save/post-process them, then feed them to ApplyTextureToMesh (with "
|
||||
"the SAME mesh) to attach them for SaveGLB. UVs that spill outside [0,1] are "
|
||||
"uniformly fit into the unit square."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
@ -1108,7 +1125,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
"o_voxel.to_glb step that removes faceted/pixelized baking on coarse "
|
||||
"meshes. Pure scipy+torch, no extra deps.")),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="base_color"),
|
||||
IO.Image.Output(display_name="metallic"),
|
||||
IO.Image.Output(display_name="roughness"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1132,7 +1153,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
batch_idx = coords[:, 0].long()
|
||||
voxel_xyz = coords[:, 1:]
|
||||
mesh_batch_size = int(mesh.vertices.shape[0])
|
||||
out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], []
|
||||
out_tex, out_mr = [], []
|
||||
# 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items
|
||||
# tick all 5 so the bar stays aligned.
|
||||
pbar = comfy.utils.ProgressBar(mesh_batch_size * 5)
|
||||
@ -1150,28 +1171,25 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
if reference_mesh is not None:
|
||||
rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i)
|
||||
ref_i = (rv_i, rf_i)
|
||||
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||
v_i, f_i, item_coords, item_colors,
|
||||
resolution=resolution, texture_size=texture_size,
|
||||
inpaint_radius=inpaint_radius,
|
||||
existing_uvs=ev_i, reference=ref_i, pbar=pbar,
|
||||
uvs=ev_i, inpaint_radius=inpaint_radius,
|
||||
reference=ref_i, pbar=pbar,
|
||||
)
|
||||
out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu)
|
||||
out_tex.append(bt); out_mr.append(bmr)
|
||||
if not out_verts:
|
||||
return IO.NodeOutput(mesh)
|
||||
# Local pack_variable_mesh_batch doesn't take uvs/texture; build the
|
||||
# packed mesh ourselves so we can attach both. UVs are 1:1 with verts.
|
||||
packed = pack_variable_mesh_batch(out_verts, out_faces)
|
||||
max_v = packed.vertices.shape[1]
|
||||
packed_uvs = out_uvs[0].new_zeros((len(out_uvs), max_v, 2))
|
||||
for i, u in enumerate(out_uvs):
|
||||
packed_uvs[i, :u.shape[0]] = u
|
||||
packed.uvs = packed_uvs
|
||||
packed.texture = torch.stack(out_tex, dim=0)
|
||||
if all(mr is not None for mr in out_mr):
|
||||
packed.metallic_roughness = torch.stack(out_mr, dim=0)
|
||||
return IO.NodeOutput(packed)
|
||||
if not out_tex:
|
||||
# Every item skipped (degenerate) — emit one black map so the IMAGE
|
||||
# outputs stay valid.
|
||||
black = torch.zeros((1, texture_size, texture_size, 3))
|
||||
return IO.NodeOutput(black, black, black)
|
||||
# All maps are texture_size² — stack into [B,H,W,3] IMAGE batches. The
|
||||
# packed MR (G=roughness, B=metallic) is split into separate grayscale
|
||||
# maps; both black where the voxel field carried no PBR set.
|
||||
base_img = torch.stack([t.float().clamp(0.0, 1.0).cpu() for t in out_tex], dim=0)
|
||||
metallic_img = torch.stack([_mr_channel(m, 2, out_tex[0]) for m in out_mr], dim=0)
|
||||
roughness_img = torch.stack([_mr_channel(m, 1, out_tex[0]) for m in out_mr], dim=0)
|
||||
return IO.NodeOutput(base_img, metallic_img, roughness_img)
|
||||
|
||||
# Single-item path.
|
||||
v0 = mesh.vertices.squeeze(0)
|
||||
@ -1181,19 +1199,16 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
||||
if reference_mesh is not None:
|
||||
ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0))
|
||||
pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn)
|
||||
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||
v0, f0, coords, colors,
|
||||
resolution=resolution, texture_size=texture_size,
|
||||
inpaint_radius=inpaint_radius,
|
||||
existing_uvs=ev0, reference=ref0, pbar=pbar,
|
||||
uvs=ev0, inpaint_radius=inpaint_radius,
|
||||
reference=ref0, pbar=pbar,
|
||||
)
|
||||
out_mesh = Types.MESH(
|
||||
vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0),
|
||||
uvs=bu.unsqueeze(0), texture=bt.unsqueeze(0),
|
||||
)
|
||||
if bmr is not None:
|
||||
out_mesh.metallic_roughness = bmr.unsqueeze(0)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
|
||||
metallic_img = _mr_channel(bmr, 2, bt).unsqueeze(0)
|
||||
roughness_img = _mr_channel(bmr, 1, bt).unsqueeze(0)
|
||||
return IO.NodeOutput(base_img, metallic_img, roughness_img)
|
||||
|
||||
|
||||
class MeshTextureToImage(IO.ComfyNode):
|
||||
@ -1247,6 +1262,73 @@ class MeshTextureToImage(IO.ComfyNode):
|
||||
return IO.NodeOutput(base, mr, metallic, roughness)
|
||||
|
||||
|
||||
class ApplyTextureToMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ApplyTextureToMesh",
|
||||
display_name="Apply Texture to Mesh",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Attaches baked texture IMAGEs to a mesh's existing UV layout so SaveGLB "
|
||||
"serializes them as baseColorTexture / metallicRoughnessTexture maps. Pairs "
|
||||
"with BakeTextureFromVoxel: feed it the SAME mesh you baked from, plus that "
|
||||
"node's base_color (and optionally metallic / roughness grayscale maps) — the "
|
||||
"UVs must match the ones the texture was baked against, so don't re-unwrap in "
|
||||
"between. metallic and roughness are repacked into the glTF MR map "
|
||||
"(G=roughness, B=metallic); leave them unconnected for non-PBR meshes (a "
|
||||
"missing metallic defaults to 0, a missing roughness to 1). Lets you preview / "
|
||||
"upscale / edit the baked maps before applying them."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Image.Input("base_color"),
|
||||
IO.Image.Input("metallic", optional=True),
|
||||
IO.Image.Input("roughness", optional=True),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, base_color, metallic=None, roughness=None):
|
||||
mesh_uvs = getattr(mesh, "uvs", None)
|
||||
if mesh_uvs is None:
|
||||
raise ValueError(
|
||||
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
|
||||
"you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and "
|
||||
"never unwraps).")
|
||||
|
||||
# Re-derive the exact UVs the bake rasterized against — it uniformly fits
|
||||
# out-of-[0,1] layouts into the unit square, so apply the identical
|
||||
# deterministic transform here (per batch item, over each item's real verts).
|
||||
if mesh_uvs.ndim == 3:
|
||||
new_uvs = mesh_uvs.clone()
|
||||
for i in range(mesh_uvs.shape[0]):
|
||||
v_i, _f_i, _ = get_mesh_batch_item(mesh, i)
|
||||
n = v_i.shape[0]
|
||||
norm = _normalize_uvs_to_unit(mesh_uvs[i, :n].detach().cpu().numpy())
|
||||
new_uvs[i, :n] = torch.from_numpy(norm).to(new_uvs)
|
||||
else:
|
||||
norm = _normalize_uvs_to_unit(mesh_uvs.detach().cpu().numpy())
|
||||
new_uvs = torch.from_numpy(norm).to(mesh_uvs)
|
||||
|
||||
out_mesh = copy.copy(mesh)
|
||||
out_mesh.uvs = new_uvs
|
||||
out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu()
|
||||
if metallic is not None or roughness is not None:
|
||||
# Repack separate grayscale maps into glTF MR: R unused, G=roughness,
|
||||
# B=metallic. Size defaults off whichever map is connected; a missing
|
||||
# channel falls back to a sensible scalar (metal 0 / rough 1).
|
||||
prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu()
|
||||
B, H, W, _ = prov.shape
|
||||
rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||
if roughness is not None else torch.ones((B, H, W, 1)))
|
||||
metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||
if metallic is not None else torch.zeros((B, H, W, 1)))
|
||||
out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
|
||||
def paint_mesh_default_colors(mesh):
|
||||
out_mesh = copy.copy(mesh)
|
||||
vertex_count = mesh.vertices.shape[1]
|
||||
@ -1397,14 +1479,14 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
||||
for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()):
|
||||
nm_share_breakdown.append(f"{n} edges×{c}faces")
|
||||
|
||||
logging.info(f"[FillHolesV2 diag] V={V} F={F} | "
|
||||
logging.info(f"[FillHoles diag] V={V} F={F} | "
|
||||
f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} "
|
||||
f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})")
|
||||
if nm_share_breakdown:
|
||||
logging.info(f"[FillHolesV2 diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
|
||||
logging.info(f"[FillHoles diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
|
||||
|
||||
if n_boundary == 0:
|
||||
logging.info("[FillHolesV2 diag] no boundary edges → no cycles to fill")
|
||||
logging.info("[FillHoles diag] no boundary edges → no cycles to fill")
|
||||
return
|
||||
|
||||
# Walk components same as production path (bidir-prop, by-vertex count).
|
||||
@ -1459,15 +1541,15 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
||||
cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle
|
||||
cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component
|
||||
|
||||
logging.info(f"[FillHolesV2 diag] components={L} "
|
||||
logging.info(f"[FillHoles diag] components={L} "
|
||||
f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} "
|
||||
f"non-simple={int(is_other.sum().item())}")
|
||||
logging.info(f"[FillHolesV2 diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
|
||||
logging.info(f"[FillHolesV2 diag] actually kept={int(actually_kept.sum().item())} "
|
||||
logging.info(f"[FillHoles diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
|
||||
logging.info(f"[FillHoles diag] actually kept={int(actually_kept.sum().item())} "
|
||||
f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} "
|
||||
f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}")
|
||||
logging.info(f"[FillHolesV2 diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
|
||||
logging.info(f"[FillHolesV2 diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
|
||||
logging.info(f"[FillHoles diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
|
||||
logging.info(f"[FillHoles diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
|
||||
|
||||
# Cycle vert-count distribution
|
||||
if is_cycle.any():
|
||||
@ -1485,12 +1567,12 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
||||
elif s <= 20: buckets["11-20"] += n
|
||||
elif s <= 50: buckets["21-50"] += n
|
||||
else: buckets["51+"] += n
|
||||
logging.info(f"[FillHolesV2 diag] cycle vert-count buckets: {buckets}")
|
||||
logging.info(f"[FillHoles diag] cycle vert-count buckets: {buckets}")
|
||||
|
||||
if is_cycle.any():
|
||||
cycle_perims = perim[is_cycle].cpu().tolist()
|
||||
head = sorted(cycle_perims, reverse=True)[:10]
|
||||
logging.info(f"[FillHolesV2 diag] top-10 cycle perimeters: "
|
||||
logging.info(f"[FillHoles diag] top-10 cycle perimeters: "
|
||||
f"{['%.4f' % p for p in head]}")
|
||||
|
||||
|
||||
@ -1804,10 +1886,10 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi
|
||||
n_escalations += 1
|
||||
if total_welded > 0 or n_escalations > 0:
|
||||
tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else ""
|
||||
logging.info(f"[FillHolesV2] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
|
||||
logging.info(f"[FillHoles] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
|
||||
if ratio >= WELDED_THRESHOLD:
|
||||
logging.warning(
|
||||
f"[FillHolesV2] even at weld epsilon_rel={WELD_CAP} the mesh stays "
|
||||
f"[FillHoles] even at weld epsilon_rel={WELD_CAP} the mesh stays "
|
||||
f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has "
|
||||
f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream "
|
||||
f"(decimate node settings) or run WeldVertices manually with a larger epsilon."
|
||||
@ -2077,6 +2159,23 @@ def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4):
|
||||
|
||||
return offset_verts, f_unwelded, None
|
||||
|
||||
def _fmt_count(n) -> str:
|
||||
"""Compact human-readable integer for node status lines, e.g. 853, 12.3K, 1.23M."""
|
||||
n = int(n)
|
||||
if n >= 1_000_000:
|
||||
return f"{n / 1_000_000:.2f}".rstrip("0").rstrip(".") + "M"
|
||||
if n >= 1_000:
|
||||
return f"{n / 1_000:.1f}".rstrip("0").rstrip(".") + "K"
|
||||
return str(n)
|
||||
|
||||
|
||||
def _fmt_face_change(n_in, n_out) -> str:
|
||||
"""'faces: 1.23M → 200K (-84%)' — the count delta for decimate/remesh status."""
|
||||
n_in, n_out = int(n_in), int(n_out)
|
||||
pct = f" ({(n_out - n_in) / n_in * 100:+.0f}%)" if n_in else ""
|
||||
return f"faces: {_fmt_count(n_in)} → {_fmt_count(n_out)}{pct}"
|
||||
|
||||
|
||||
class DecimateMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -2168,7 +2267,7 @@ class DecimateMesh(IO.ComfyNode):
|
||||
# Send progress text to display the face reduction on the node
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
|
||||
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
|
||||
|
||||
return result
|
||||
|
||||
@ -2306,7 +2405,7 @@ class RemeshMesh(IO.ComfyNode):
|
||||
# Send progress text to display the face change on the node
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
|
||||
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
|
||||
|
||||
return result
|
||||
|
||||
@ -2546,7 +2645,8 @@ class UnwrapMesh(IO.ComfyNode):
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px",
|
||||
f"UV: {_fmt_count(out_v[0].shape[0])} verts / {_fmt_count(out_f[0].shape[0])} faces"
|
||||
f" · atlas ~{resolution}px",
|
||||
cls.hidden.unique_id)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
@ -2740,36 +2840,12 @@ class FillHoles(IO.ComfyNode):
|
||||
node_id="FillHoles",
|
||||
display_name="Fill Holes",
|
||||
category="latent/3d",
|
||||
description="Fills holes in a mesh up to a maximum perimeter threshold.",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
|
||||
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, max_perimeter):
|
||||
def _fn(v, f, c):
|
||||
if max_perimeter > 0:
|
||||
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
|
||||
return v, f, c
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
class FillHolesV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FillHolesV2",
|
||||
display_name="Fill Holes (v2)",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"GPU-vectorised hole-filling via directed-half-edge pointer-doubling. "
|
||||
"Drop-in alternative to FillHoles for comparison: same max_perimeter "
|
||||
"cutoff and fan-from-centroid triangulation, but no Python loop, "
|
||||
"auto-correct winding from face direction, and centroid colors are "
|
||||
"averaged from the loop instead of left zero."
|
||||
"Fills holes in a mesh up to a maximum perimeter threshold, preserving "
|
||||
"the existing geometry/UVs (only patch triangles are added). GPU-vectorised "
|
||||
"via directed-half-edge pointer-doubling: no Python loop, auto-correct "
|
||||
"winding from face direction, and centroid colors are averaged from the hole "
|
||||
"loop. Falls back to a CPU walker on non-CUDA devices."
|
||||
),
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
@ -2947,7 +3023,6 @@ class PostProcessMeshExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
FillHoles,
|
||||
FillHolesV2,
|
||||
WeldVertices,
|
||||
DecimateMesh,
|
||||
RemeshMesh,
|
||||
@ -2956,6 +3031,7 @@ class PostProcessMeshExtension(ComfyExtension):
|
||||
PaintMesh,
|
||||
BakeTextureFromVoxel,
|
||||
MeshTextureToImage,
|
||||
ApplyTextureToMesh,
|
||||
MergeMeshes,
|
||||
]
|
||||
|
||||
|
||||
@ -1,18 +1,15 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import (
|
||||
_build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats,
|
||||
)
|
||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy.ldm.trellis2 import sampling_preview
|
||||
from PIL import Image
|
||||
import logging
|
||||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
@ -21,89 +18,6 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
# Texture latent -> base-color calibration for the per-step preview
|
||||
def _tex_rgb_factors_path():
|
||||
return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt")
|
||||
|
||||
|
||||
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
|
||||
in_sp = in_coords[:, 1:4].long()
|
||||
out_sp = out_coords[:, 1:4].long()
|
||||
in_b = in_coords[:, 0].long()
|
||||
out_b = out_coords[:, 0].long()
|
||||
in_res = int(in_sp.max().item()) + 1
|
||||
out_res = int(out_sp.max().item()) + 1
|
||||
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
|
||||
R = in_res
|
||||
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
|
||||
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
|
||||
order = torch.argsort(in_flat)
|
||||
in_sorted = in_flat[order]
|
||||
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
|
||||
matched = in_sorted[pos] == par_flat
|
||||
in_idx = order[pos][matched]
|
||||
cols = out_colors[matched].float()
|
||||
N = in_coords.shape[0]
|
||||
csum = cols.new_zeros((N, 3))
|
||||
ccount = cols.new_zeros((N, 1))
|
||||
csum.index_add_(0, in_idx, cols)
|
||||
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
|
||||
valid = ccount[:, 0] > 0
|
||||
albedo = torch.zeros_like(csum)
|
||||
albedo[valid] = csum[valid] / ccount[valid]
|
||||
return albedo, valid
|
||||
|
||||
|
||||
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
|
||||
"""Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish."""
|
||||
try:
|
||||
dev = out_colors.device
|
||||
in_latent = in_latent.to(dev)
|
||||
in_coords = in_coords.to(dev)
|
||||
out_coords = out_coords.to(dev)
|
||||
albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
|
||||
X = in_latent[valid].float().cpu()
|
||||
Y = albedo[valid].float().cpu()
|
||||
if X.shape[0] < 64:
|
||||
return
|
||||
Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1]
|
||||
A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1]
|
||||
B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3]
|
||||
|
||||
path = _tex_rgb_factors_path()
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
prev = torch.load(path, map_location="cpu")
|
||||
A_run = A_run + prev["A"]
|
||||
B_run = B_run + prev["B"]
|
||||
except Exception:
|
||||
pass
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({"A": A_run, "B": B_run}, path)
|
||||
|
||||
eye = torch.eye(A_run.shape[0])
|
||||
WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3]
|
||||
W, b = WB[:-1].contiguous(), WB[-1].contiguous()
|
||||
sampling_preview.set_tex_rgb(W, b)
|
||||
except Exception as e:
|
||||
logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
|
||||
|
||||
|
||||
def _load_tex_rgb_factors():
|
||||
try:
|
||||
path = _tex_rgb_factors_path()
|
||||
if os.path.exists(path):
|
||||
d = torch.load(path, map_location="cpu")
|
||||
eye = torch.eye(d["A"].shape[0])
|
||||
WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
|
||||
sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
|
||||
except Exception as e:
|
||||
logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
|
||||
|
||||
|
||||
_load_tex_rgb_factors()
|
||||
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
@ -271,7 +185,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
if coord_counts is None:
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = shape_norm(samples.to(device), coords.to(device))
|
||||
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
|
||||
mesh, subs = trellis_vae.decode_shape_slat(samples.to(vae.vae_dtype), resolution)
|
||||
else:
|
||||
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
||||
mesh = []
|
||||
@ -280,7 +194,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
sample_i = shape_norm(feats_i.to(device), coords_i)
|
||||
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
|
||||
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i.to(vae.vae_dtype), resolution)
|
||||
mesh.append(mesh_i[0])
|
||||
subs_per_sample.append(subs_i)
|
||||
|
||||
@ -338,14 +252,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = samples.to(device)
|
||||
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
|
||||
cal_in_coords = coords
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||
samples = samples * std + mean
|
||||
|
||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
||||
voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
|
||||
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
||||
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
||||
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
||||
@ -353,12 +265,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
color_feats = voxel.feats
|
||||
voxel_coords = voxel.coords
|
||||
|
||||
# Calibrate the latent->base_color map for the per-step texture preview.
|
||||
# Done here while input coords and voxel_coords share the model frame
|
||||
# (before the z_up remap below) and on the real decoded albedo.
|
||||
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
|
||||
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
|
||||
|
||||
if coord_resolution is not None:
|
||||
tex_resolution = int(coord_resolution) * 16
|
||||
elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||
@ -416,7 +322,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
decoded_batches = []
|
||||
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||
decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0)
|
||||
decoded_batches.append(shape_vae.decode_structure(sample_chunk.to(vae.vae_dtype)) > 0)
|
||||
decoded = torch.cat(decoded_batches, dim=0)
|
||||
current_res = decoded.shape[2]
|
||||
|
||||
@ -491,7 +397,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
)
|
||||
slat = shape_norm(feats.to(device), coords_512.to(device))
|
||||
sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)]
|
||||
sample_hr_coords = [shape_vae.upsample_shape(slat.to(vae.vae_dtype), upsample_times=4)]
|
||||
else:
|
||||
items = split_batched_sparse_latent(
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
@ -501,7 +407,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
slat_i = shape_norm(feats_i.to(device), coords_i)
|
||||
sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4))
|
||||
sample_hr_coords.append(shape_vae.upsample_shape(slat_i.to(vae.vae_dtype), upsample_times=4))
|
||||
|
||||
# Resolution search — cache the final iteration's quantized unique tensors
|
||||
# so we don't recompute .unique() per sample after picking hr_resolution.
|
||||
@ -977,8 +883,10 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
if pad_l or pad_t or pad_r or pad_b:
|
||||
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||
crop_x1 += pad_l; crop_x2 += pad_l
|
||||
crop_y1 += pad_t; crop_y2 += pad_t
|
||||
crop_x1 += pad_l
|
||||
crop_x2 += pad_l
|
||||
crop_y1 += pad_t
|
||||
crop_y2 += pad_t
|
||||
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
|
||||
@ -1100,7 +1008,7 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
||||
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
|
||||
T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
||||
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
||||
|
||||
proj_pack = {
|
||||
"stages": {
|
||||
@ -1119,15 +1027,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
}
|
||||
|
||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
||||
# proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet
|
||||
# hints etc. live); the sampler auto-promotes it to a model.forward kwarg via
|
||||
# Trellis2.extra_conds. The same pack object is shared between pos/neg —
|
||||
# CONDConstant.can_concat sees them equal and concats to a single dict, then
|
||||
# Trellis2.forward zeros proj for the uncond slots via cond_or_uncond.
|
||||
# Pre-compute the SS-stage proj features (dense 16³ grid) once here — the
|
||||
# shape/texture stages do their own computes in their respective stage nodes.
|
||||
# proj_pack lives on intermediate (CPU); force the compute onto cuda so
|
||||
# the bilinear-sampling step doesn't run on CPU.
|
||||
ss_proj_feats = compute_stage_proj_feats(
|
||||
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
||||
device=torch_device,
|
||||
@ -1278,12 +1177,6 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
q_mean = Q.mean(dim=0, keepdim=True)
|
||||
P_c = P - p_mean
|
||||
Q_c = Q - q_mean
|
||||
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
|
||||
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
|
||||
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
|
||||
# gets multiplied by cos(yaw) and shrinks the object. Using
|
||||
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
|
||||
# rotation; translation still positions the mesh at MoGe's centroid.
|
||||
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
||||
q_var = (Q_c * Q_c).sum()
|
||||
scale = float(torch.sqrt(q_var / p_var).item())
|
||||
@ -1326,74 +1219,69 @@ class LoadNAFModel(IO.ComfyNode):
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class CFGGuidanceInterval(IO.ComfyNode):
|
||||
"""Generic model patch: apply CFG only during [start_percent, end_percent] of
|
||||
the sampling schedule. Outside that window, skip the uncond computation and
|
||||
collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's
|
||||
guidance_interval_mixin, but lives at the sampler level (via
|
||||
sampler_calc_cond_batch_function) so it works for any model.
|
||||
class GetMeshInfo(IO.ComfyNode):
|
||||
"""Report vertex / face counts and attributes for a MESH, displayed on the
|
||||
node (and as a string output). Counts are comma-formatted since meshes can
|
||||
run into the millions of faces. Passes the mesh through unchanged."""
|
||||
|
||||
Percents use ComfyUI's standard convention: 0.0 = start of sampling
|
||||
(max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
|
||||
is done via model_sampling.percent_to_sigma so the window is portable
|
||||
across schedules (flow / EDM / discrete) and shift settings.
|
||||
|
||||
Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D
|
||||
pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the
|
||||
SS and shape samplers — CFG active only in the first 40% of sampling.
|
||||
Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to
|
||||
match. Texture defaults to cfg=1 so the node is moot there."""
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="CFGGuidanceInterval",
|
||||
category="model_patches/sampling",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
|
||||
tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
|
||||
IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
|
||||
tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
|
||||
node_id="GetMeshInfo",
|
||||
display_name="Get Mesh Info",
|
||||
category="latent/3d",
|
||||
inputs=[IO.Mesh.Input("mesh")],
|
||||
outputs=[
|
||||
IO.Mesh.Output(display_name="mesh"),
|
||||
IO.String.Output(display_name="info"),
|
||||
],
|
||||
outputs=[IO.Model.Output()],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fmt(n: int) -> str:
|
||||
# e.g. 1234567 -> "1,234,567 (1.23M)"; small numbers stay plain.
|
||||
s = f"{n:,}"
|
||||
if n >= 1_000_000:
|
||||
s += f" ({n / 1_000_000:.2f}M)"
|
||||
elif n >= 10_000:
|
||||
s += f" ({n / 1_000:.1f}K)"
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, start_percent, end_percent):
|
||||
import comfy.samplers
|
||||
def execute(cls, mesh):
|
||||
B = mesh.vertices.shape[0]
|
||||
# Honour per-item counts when the batch is zero-padded; else use the row sizes.
|
||||
if mesh.vertex_counts is not None:
|
||||
v_counts = [int(x) for x in mesh.vertex_counts.tolist()]
|
||||
f_counts = [int(x) for x in mesh.face_counts.tolist()]
|
||||
else:
|
||||
v_counts = [int(mesh.vertices.shape[1])] * B
|
||||
f_counts = [int(mesh.faces.shape[1])] * B
|
||||
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
# percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
|
||||
# percent=1 -> sigma_min. So start_percent < end_percent in user space
|
||||
# means sigma_start > sigma_end. "Inside the window" is sigma in
|
||||
# [sigma_end, sigma_start].
|
||||
sigma_start = float(model_sampling.percent_to_sigma(start_percent))
|
||||
sigma_end = float(model_sampling.percent_to_sigma(end_percent))
|
||||
attrs = []
|
||||
for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"):
|
||||
t = getattr(mesh, name, None)
|
||||
if t is not None:
|
||||
if name in ("texture", "metallic_roughness"):
|
||||
attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W
|
||||
else:
|
||||
attrs.append(name)
|
||||
|
||||
def calc_cond_batch_with_interval(args):
|
||||
sigma_val = args["sigma"][0].item()
|
||||
conds = args["conds"]
|
||||
input_x = args["input"]
|
||||
timestep = args["sigma"]
|
||||
model_ref = args["model"]
|
||||
model_opts = args["model_options"]
|
||||
lines = []
|
||||
if B > 1:
|
||||
lines.append(f"Batch: {B} meshes")
|
||||
lines.append(f"Vertices: {cls._fmt(sum(v_counts))} total")
|
||||
lines.append(f"Faces: {cls._fmt(sum(f_counts))} total")
|
||||
for i in range(B):
|
||||
lines.append(f" [{i}] {v_counts[i]:>10,} verts · {f_counts[i]:>10,} faces")
|
||||
else:
|
||||
lines.append(f"Vertices: {cls._fmt(v_counts[0])}")
|
||||
lines.append(f"Faces: {cls._fmt(f_counts[0])}")
|
||||
lines.append(f"Attributes: {', '.join(attrs) if attrs else 'none'}")
|
||||
|
||||
# conds is typically [cond, uncond]; uncond may be None when ComfyUI's
|
||||
# global cfg=1 optimization has already pruned it.
|
||||
cond = conds[0]
|
||||
uncond = conds[1] if len(conds) > 1 else None
|
||||
inside = sigma_end <= sigma_val <= sigma_start
|
||||
|
||||
if uncond is None or inside:
|
||||
return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts)
|
||||
# Outside the window: compute cond only, mirror it into the uncond slot
|
||||
# so the downstream cfg_function collapses to `cond` (effective cfg=1).
|
||||
out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts)
|
||||
return [out[0], out[0]]
|
||||
|
||||
m = model.clone()
|
||||
m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
|
||||
return IO.NodeOutput(m)
|
||||
info = "\n".join(lines)
|
||||
logging.info("[GetMeshInfo]\n%s", info)
|
||||
return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))
|
||||
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@ -1411,7 +1299,7 @@ class Trellis2Extension(ComfyExtension):
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2,
|
||||
Trellis2UpsampleStage,
|
||||
CFGGuidanceInterval,
|
||||
GetMeshInfo,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user