mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 22:21:31 +08:00
More cleanup
This commit is contained in:
parent
1f7acd9354
commit
41f5f4b2c0
@ -180,12 +180,13 @@ class PaintMesh(IO.ComfyNode):
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Texture baking from sparse voxel volume.
|
# Texture baking from sparse voxel volume.
|
||||||
#
|
#
|
||||||
# Pipeline: xatlas UV unwrap → OpenGL UV-space rasterize to position map →
|
# Pipeline: take the mesh's existing UVs → OpenGL UV-space rasterize to position
|
||||||
# nearest-voxel color sample per texel → cv2.inpaint to fill UV seams →
|
# map → nearest-voxel color sample per texel → GPU Jump-Flood fill UV seams →
|
||||||
# attach texture + UVs to the Mesh for SaveGLB to serialize.
|
# attach texture + UVs to the Mesh for SaveGLB to serialize. Unwrapping is done
|
||||||
|
# upstream (Trellis2OfficialUnwrap / TorchXatlasUVWrap); this path never unwraps.
|
||||||
#
|
#
|
||||||
# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles
|
# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles
|
||||||
# GLFW / EGL / OSMesa backend selection); xatlas for UV parameterization.
|
# GLFW / EGL / OSMesa backend selection).
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache"
|
_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache"
|
||||||
@ -407,6 +408,54 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):
|
|||||||
return vals, ok
|
return vals, ok
|
||||||
|
|
||||||
|
|
||||||
|
def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||||
|
"""GPU port of `_trilinear_sample_sparse` — same normalized-over-occupied-corners
|
||||||
|
trilinear, but the per-texel 8-corner accumulation runs on CUDA via sorted-key
|
||||||
|
`searchsorted` instead of NumPy float64. This is the bake hot path (millions of
|
||||||
|
covered texels × 8 corners), so the CPU version dominates runtime; the GPU port
|
||||||
|
is ~identical numerically and 10-50× faster. Returns (vals [K,C] float32, ok
|
||||||
|
[K] bool), matching the NumPy signature."""
|
||||||
|
dev = comfy.model_management.get_torch_device()
|
||||||
|
R = int(resolution)
|
||||||
|
origin = -0.5
|
||||||
|
voxel_size = 1.0 / R
|
||||||
|
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
|
||||||
|
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
|
||||||
|
col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float()
|
||||||
|
K, C = P.shape[0], col.shape[1]
|
||||||
|
M = VC.shape[0]
|
||||||
|
# Same cell-CENTER convention as the NumPy path (see its docstring): integer
|
||||||
|
# voxel coord c sits at (c+0.5)*voxel_size + origin, so subtract 0.5 to bracket.
|
||||||
|
gc = (P - origin) / voxel_size - 0.5
|
||||||
|
base = torch.floor(gc).long()
|
||||||
|
frac = gc - base.float()
|
||||||
|
key = (VC[:, 0] * R + VC[:, 1]) * R + VC[:, 2]
|
||||||
|
skey, order = key.sort()
|
||||||
|
acc = torch.zeros((K, C), device=dev)
|
||||||
|
wsum = torch.zeros((K, 1), device=dev)
|
||||||
|
for dx in (0, 1):
|
||||||
|
wx = frac[:, 0] if dx else 1.0 - frac[:, 0]
|
||||||
|
for dy in (0, 1):
|
||||||
|
wy = frac[:, 1] if dy else 1.0 - frac[:, 1]
|
||||||
|
for dz in (0, 1):
|
||||||
|
wz = frac[:, 2] if dz else 1.0 - frac[:, 2]
|
||||||
|
cx = base[:, 0] + dx
|
||||||
|
cy = base[:, 1] + dy
|
||||||
|
cz = base[:, 2] + dz
|
||||||
|
inb = (cx >= 0) & (cx < R) & (cy >= 0) & (cy < R) & (cz >= 0) & (cz < R)
|
||||||
|
qk = (cx * R + cy) * R + cz
|
||||||
|
ins = torch.searchsorted(skey, qk).clamp(max=M - 1)
|
||||||
|
matched = inb & (skey[ins] == qk)
|
||||||
|
idx = order[ins] # garbage where !matched
|
||||||
|
w = torch.where(matched, wx * wy * wz, torch.zeros_like(wx))[:, None]
|
||||||
|
acc += w * col[idx] # w=0 cancels garbage rows
|
||||||
|
wsum += w
|
||||||
|
ok = wsum[:, 0] > 1e-8
|
||||||
|
vals = torch.zeros((K, C), device=dev)
|
||||||
|
vals[ok] = acc[ok] / wsum[ok].clamp_min(1e-8)
|
||||||
|
return vals.cpu().numpy(), ok.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
||||||
"""GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a
|
"""GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a
|
||||||
regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a
|
regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a
|
||||||
@ -415,7 +464,7 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
|||||||
of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found`
|
of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found`
|
||||||
is False for the rare query whose nearest occupied voxel is >1 cell away (the
|
is False for the rare query whose nearest occupied voxel is >1 cell away (the
|
||||||
caller falls back to a cKDTree on just those)."""
|
caller falls back to a cKDTree on just those)."""
|
||||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
dev = comfy.model_management.get_torch_device()
|
||||||
R = int(resolution)
|
R = int(resolution)
|
||||||
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
|
P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float()
|
||||||
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
|
VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long()
|
||||||
@ -451,6 +500,27 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
|||||||
fnd |= match
|
fnd |= match
|
||||||
return bi, fnd
|
return bi, fnd
|
||||||
|
|
||||||
|
def _brute_nearest(idx):
|
||||||
|
"""Exact nearest occupied voxel for a (small) query subset by chunked GPU
|
||||||
|
brute force over all M voxels. Used only for the handful of stragglers the
|
||||||
|
grid scan misses (>4 cells from any voxel) — replaces a cKDTree build over
|
||||||
|
all M voxels, which costs seconds even for a few query points."""
|
||||||
|
Ps = P[idx] # [N,3] world
|
||||||
|
N = Ps.shape[0]
|
||||||
|
vox_pos = (VC.float() + 0.5) / R - 0.5 # [M,3] voxel centres
|
||||||
|
best_d = torch.full((N,), 1e30, device=dev)
|
||||||
|
best_j = torch.zeros(N, dtype=torch.long, device=dev)
|
||||||
|
# Bound the N×chunk distance matrix to ~64M elements regardless of N.
|
||||||
|
chunk = max(1, (1 << 26) // max(1, N))
|
||||||
|
for s in range(0, M, chunk):
|
||||||
|
vc = vox_pos[s:s + chunk] # [B,3]
|
||||||
|
dd = (Ps[:, None, :] - vc[None, :, :]).pow(2).sum(-1) # [N,B]
|
||||||
|
md, mj = dd.min(1)
|
||||||
|
upd = md < best_d
|
||||||
|
best_d = torch.where(upd, md, best_d)
|
||||||
|
best_j = torch.where(upd, mj + s, best_j)
|
||||||
|
return best_j
|
||||||
|
|
||||||
all_idx = torch.arange(K, device=dev)
|
all_idx = torch.arange(K, device=dev)
|
||||||
best_i = torch.zeros(K, dtype=torch.long, device=dev)
|
best_i = torch.zeros(K, dtype=torch.long, device=dev)
|
||||||
found = torch.zeros(K, dtype=torch.bool, device=dev)
|
found = torch.zeros(K, dtype=torch.bool, device=dev)
|
||||||
@ -458,27 +528,30 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution):
|
|||||||
bi1, fnd1 = _search(all_idx, 1)
|
bi1, fnd1 = _search(all_idx, 1)
|
||||||
best_i[all_idx] = bi1
|
best_i[all_idx] = bi1
|
||||||
found[all_idx] = fnd1
|
found[all_idx] = fnd1
|
||||||
# Pass 2: wider radius on ONLY the few misses (avoids ever building a cKDTree
|
# Pass 2: wider radius (9³) on ONLY the radius-1 misses.
|
||||||
# over millions of voxels just for a handful of >1-cell-away points).
|
|
||||||
miss = torch.nonzero(~found, as_tuple=True)[0]
|
miss = torch.nonzero(~found, as_tuple=True)[0]
|
||||||
if miss.numel() > 0:
|
if miss.numel() > 0:
|
||||||
bi2, fnd2 = _search(miss, 4)
|
bi2, fnd2 = _search(miss, 4)
|
||||||
best_i[miss] = bi2
|
best_i[miss] = bi2
|
||||||
found[miss] = fnd2
|
found[miss] = fnd2
|
||||||
|
# Pass 3: exact GPU brute force for the few stragglers still unfound (>4 cells
|
||||||
|
# out). Always resolves them, so `found` is all-True on return — no cKDTree.
|
||||||
|
miss2 = torch.nonzero(~found, as_tuple=True)[0]
|
||||||
|
if miss2.numel() > 0:
|
||||||
|
best_i[miss2] = _brute_nearest(miss2)
|
||||||
|
found[miss2] = True
|
||||||
vals = col[best_i]
|
vals = col[best_i]
|
||||||
return vals.cpu().numpy(), found.cpu().numpy()
|
return vals.cpu().numpy(), found.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution,
|
def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution):
|
||||||
mode="trilinear"):
|
|
||||||
"""For every masked texel, sample the voxel field and return ALL its attribute
|
"""For every masked texel, sample the voxel field and return ALL its attribute
|
||||||
channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature
|
channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature
|
||||||
width (3 for plain color, 6 for full PBR).
|
width (3 for plain color, 6 for full PBR).
|
||||||
|
|
||||||
mode="trilinear" — normalized trilinear over occupied voxels (the default; matches
|
Normalized trilinear over occupied voxels (matches the official o_voxel.to_glb
|
||||||
the official o_voxel.to_glb path), with nearest fallback for texels whose 8
|
path), with nearest fallback for texels whose 8 surrounding voxels are all
|
||||||
surrounding voxels are all empty. This is the only mode the nodes expose now.
|
empty."""
|
||||||
mode="nearest" — nearest-voxel; kept as an internal/dev lever (blocky)."""
|
|
||||||
H, W, _ = position_map.shape
|
H, W, _ = position_map.shape
|
||||||
color_np = voxel_colors.detach().cpu().numpy().astype(np.float32)
|
color_np = voxel_colors.detach().cpu().numpy().astype(np.float32)
|
||||||
C = color_np.shape[-1]
|
C = color_np.shape[-1]
|
||||||
@ -496,22 +569,26 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors
|
|||||||
valid_positions = position_map[mask]
|
valid_positions = position_map[mask]
|
||||||
|
|
||||||
def _nearest(query):
|
def _nearest(query):
|
||||||
# GPU grid lookup; cKDTree only for the rare >1-cell miss.
|
# Fully on-GPU nearest-occupied-voxel: grid scan + brute-force tail. Always
|
||||||
|
# resolves every query, so no cKDTree (its build over all voxels cost ~3s).
|
||||||
vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution)
|
vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution)
|
||||||
if not found.all():
|
if not found.all():
|
||||||
|
# Defensive: only reachable on a non-CUDA device where the GPU path is
|
||||||
|
# unavailable; fall back to a one-off cKDTree.
|
||||||
tree = scipy.spatial.cKDTree(voxel_pos)
|
tree = scipy.spatial.cKDTree(voxel_pos)
|
||||||
_, nearest_idx = tree.query(query[~found], k=1, workers=-1)
|
_, nearest_idx = tree.query(query[~found], k=1, workers=-1)
|
||||||
vals[~found] = color_np[nearest_idx]
|
vals[~found] = color_np[nearest_idx]
|
||||||
return vals
|
return vals
|
||||||
|
|
||||||
if mode == "trilinear":
|
try:
|
||||||
|
vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU")
|
||||||
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
|
vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution)
|
||||||
if not ok.all():
|
if not ok.all():
|
||||||
# Texels with no occupied neighbour fall back to nearest.
|
# Texels with no occupied neighbour fall back to nearest.
|
||||||
vals[~ok] = _nearest(valid_positions[~ok])
|
vals[~ok] = _nearest(valid_positions[~ok])
|
||||||
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
|
out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32)
|
||||||
else:
|
|
||||||
out[mask] = np.clip(_nearest(valid_positions), 0.0, 1.0)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -729,7 +806,7 @@ def _back_project_positions(position_map, mask, ref_v, ref_f):
|
|||||||
return position_map
|
return position_map
|
||||||
|
|
||||||
import time as _time
|
import time as _time
|
||||||
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
dev = comfy.model_management.get_torch_device()
|
||||||
rv = ref_v.detach().to(dev).float()
|
rv = ref_v.detach().to(dev).float()
|
||||||
rf = ref_f.detach().to(dev).long()
|
rf = ref_f.detach().to(dev).long()
|
||||||
tri = rv[rf]
|
tri = rv[rf]
|
||||||
@ -755,7 +832,7 @@ def _jfa_fill_gpu(img01, mask):
|
|||||||
(True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map."""
|
(True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map."""
|
||||||
if not mask.any():
|
if not mask.any():
|
||||||
return img01
|
return img01
|
||||||
dev = "cuda"
|
dev = comfy.model_management.get_torch_device()
|
||||||
it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float()
|
it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float()
|
||||||
mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev)
|
mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev)
|
||||||
H, W = mm.shape
|
H, W = mm.shape
|
||||||
@ -786,41 +863,58 @@ def _jfa_fill_gpu(img01, mask):
|
|||||||
|
|
||||||
def _seam_fill(img01, mask, inpaint_radius):
|
def _seam_fill(img01, mask, inpaint_radius):
|
||||||
"""Fill the UV-gutter texels around covered charts so seam sampling doesn't
|
"""Fill the UV-gutter texels around covered charts so seam sampling doesn't
|
||||||
pull in black. GPU Jump Flooding (nearest fill) when CUDA is available, else
|
pull in black, via GPU Jump Flooding (nearest fill). `inpaint_radius<=0`
|
||||||
cv2 Telea inpaint. `inpaint_radius<=0` disables; the radius only affects the
|
disables; otherwise the radius is ignored — JFA fills every uncovered texel
|
||||||
cv2 fallback (JFA fills every uncovered texel by nearest)."""
|
by nearest regardless."""
|
||||||
if inpaint_radius <= 0:
|
if inpaint_radius <= 0:
|
||||||
return img01
|
return img01
|
||||||
if torch.cuda.is_available():
|
return _jfa_fill_gpu(img01, mask)
|
||||||
return _jfa_fill_gpu(img01, mask)
|
|
||||||
import cv2
|
|
||||||
u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8)
|
def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None):
|
||||||
u8 = cv2.inpaint(u8, ((~mask).astype(np.uint8)) * 255, int(inpaint_radius), cv2.INPAINT_TELEA)
|
"""Uniformly fit a UV layout's bbox into [0,1] when it spills outside the unit
|
||||||
if u8.ndim == 2:
|
square (preserves chart aspect ratios; handles packers that overflow slightly).
|
||||||
u8 = u8[..., None]
|
No-op when the UVs are already in [0,1] — the normal case for official/xatlas
|
||||||
return u8.astype(np.float32) / 255.0
|
unwraps. NOT a UDIM de-tiler; warns if the span looks tiled.
|
||||||
|
|
||||||
|
Deterministic from the input UVs alone, so the texture bake and
|
||||||
|
ApplyTextureToMesh both call it to agree on the exact UVs the texture was baked
|
||||||
|
against (the bake no longer emits the mesh, so apply must re-derive them).
|
||||||
|
|
||||||
|
Returns float32 [N,2]."""
|
||||||
|
uv_np = uv_np.astype(np.float32)
|
||||||
|
uv_min = uv_np.min(axis=0)
|
||||||
|
uv_max = uv_np.max(axis=0)
|
||||||
|
out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
|
||||||
|
if not (normalize and out_of_unit):
|
||||||
|
return uv_np
|
||||||
|
extent = float((uv_max - uv_min).max())
|
||||||
|
span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1]))
|
||||||
|
if span > 1.5 and log_prefix:
|
||||||
|
logging.warning(
|
||||||
|
f"{log_prefix} UV span {span:.2f} looks like a tiled/UDIM layout; "
|
||||||
|
f"uniform-fitting it into [0,1] will overlap tiles. Re-unwrap upstream instead.")
|
||||||
|
if extent > 0:
|
||||||
|
uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
|
||||||
|
if log_prefix:
|
||||||
|
logging.info(f"{log_prefix} normalized UVs into [0,1] (uniform scale 1/{extent:.4f})")
|
||||||
|
return uv_np
|
||||||
|
|
||||||
|
|
||||||
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
||||||
resolution, texture_size, inpaint_radius=3,
|
resolution, texture_size, uvs, inpaint_radius=3,
|
||||||
fast_unwrap=True, existing_uvs=None,
|
normalize_uvs=True, reference=None, pbar=None):
|
||||||
normalize_uvs=True, sample_mode="trilinear",
|
|
||||||
reference=None, pbar=None):
|
|
||||||
"""Bake a baseColor (+ optional metallicRoughness) texture for
|
"""Bake a baseColor (+ optional metallicRoughness) texture for
|
||||||
`vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each
|
`vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each
|
||||||
texel from the provided sparse colored voxel volume.
|
texel from the provided sparse colored voxel volume.
|
||||||
|
|
||||||
If `existing_uvs` (N, 2) is given, it is used directly and xatlas is
|
`uvs` (N, 2) is the mesh's existing UV layout — baked onto directly (this
|
||||||
skipped — bakes onto the mesh's current UV layout without re-unwrapping.
|
node never unwraps; connect a UV unwrap node upstream). It must be 1:1 with
|
||||||
Otherwise xatlas computes a fresh atlas (verts/faces may grow at seams).
|
`vertices`.
|
||||||
|
|
||||||
Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr).
|
Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr).
|
||||||
|
|
||||||
`fast_unwrap=True` configures xatlas with permissive chart options so it
|
Progress: drives a local tqdm over its 5 stages (uvs → rasterize →
|
||||||
finishes in a reasonable time on large meshes — at the cost of less even
|
|
||||||
UV distribution. Set False to use xatlas defaults (slow on >100k faces).
|
|
||||||
|
|
||||||
Progress: drives a local tqdm over its 5 stages (unwrap → rasterize →
|
|
||||||
back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is
|
back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is
|
||||||
passed, ticks it once per stage too — so callers should size it as 5 per
|
passed, ticks it once per stage too — so callers should size it as 5 per
|
||||||
bake."""
|
bake."""
|
||||||
@ -845,111 +939,25 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
|||||||
v_np = vertices.detach().cpu().numpy().astype(np.float32)
|
v_np = vertices.detach().cpu().numpy().astype(np.float32)
|
||||||
f_np = faces.detach().cpu().numpy().astype(np.uint32)
|
f_np = faces.detach().cpu().numpy().astype(np.uint32)
|
||||||
fcount = int(f_np.shape[0])
|
fcount = int(f_np.shape[0])
|
||||||
t0 = time.perf_counter()
|
|
||||||
|
|
||||||
if existing_uvs is not None:
|
# Bake onto the mesh's current UVs — no unwrap, no seam-splitting.
|
||||||
# Bake onto the mesh's current UVs — no xatlas, no seam-splitting.
|
uv_np = uvs.detach().cpu().numpy().astype(np.float32)
|
||||||
uv_np = existing_uvs.detach().cpu().numpy().astype(np.float32)
|
if uv_np.shape[0] != v_np.shape[0]:
|
||||||
if uv_np.shape[0] != v_np.shape[0]:
|
raise ValueError(
|
||||||
raise ValueError(
|
f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 "
|
||||||
f"BakeTextureFromVoxel: existing UVs ({uv_np.shape[0]}) must be 1:1 "
|
f"with vertices ({v_np.shape[0]})."
|
||||||
f"with vertices ({v_np.shape[0]})."
|
)
|
||||||
)
|
uv_min = uv_np.min(axis=0)
|
||||||
uv_min = uv_np.min(axis=0)
|
uv_max = uv_np.max(axis=0)
|
||||||
uv_max = uv_np.max(axis=0)
|
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
|
||||||
oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum())
|
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
|
||||||
logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, "
|
f"{fcount} faces")
|
||||||
f"{fcount} faces (xatlas skipped)")
|
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
|
||||||
logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] "
|
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
|
||||||
f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}")
|
uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ")
|
||||||
out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001)
|
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
||||||
if normalize_uvs and out_of_unit:
|
|
||||||
# Uniform fit of the UV bbox into [0,1] (preserves chart aspect ratios).
|
|
||||||
# Handles packers that overflow the unit square slightly. NOT a UDIM
|
|
||||||
# de-tiler — a true multi-tile layout would get squashed; warn if the
|
|
||||||
# span is large enough to look like tiling.
|
|
||||||
extent = float((uv_max - uv_min).max())
|
|
||||||
span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1]))
|
|
||||||
if span > 1.5:
|
|
||||||
logging.warning(
|
|
||||||
f"[BakeTextureFromVoxel] UV span {span:.2f} looks like a tiled/UDIM "
|
|
||||||
f"layout; uniform-fitting it into [0,1] will overlap tiles. "
|
|
||||||
f"Re-unwrap instead (use_existing_uvs=False)."
|
|
||||||
)
|
|
||||||
if extent > 0:
|
|
||||||
uv_np = ((uv_np - uv_min) / extent).astype(np.float32)
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] normalized UVs into [0,1] "
|
|
||||||
f"(uniform scale 1/{extent:.4f})")
|
|
||||||
new_verts, new_faces, new_uvs = v_np, f_np, uv_np
|
|
||||||
else:
|
|
||||||
import xatlas
|
|
||||||
if fcount > 300_000:
|
|
||||||
logging.warning(
|
|
||||||
f"[BakeTextureFromVoxel] mesh has {fcount} faces — xatlas chart "
|
|
||||||
f"decomposition is CPU-bound and may take many minutes. Consider "
|
|
||||||
f"decimating to under ~200k faces before baking."
|
|
||||||
)
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] xatlas unwrap: {v_np.shape[0]} verts, {fcount} faces")
|
|
||||||
if fast_unwrap and hasattr(xatlas, "Atlas"):
|
|
||||||
atlas = xatlas.Atlas()
|
|
||||||
atlas.add_mesh(v_np, f_np)
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] add_mesh: {time.perf_counter() - t0:.1f}s")
|
|
||||||
gen_kwargs = {}
|
|
||||||
applied = []
|
|
||||||
# ChartOptions: looser growth → larger / fewer charts → faster.
|
|
||||||
if hasattr(xatlas, "ChartOptions"):
|
|
||||||
co = xatlas.ChartOptions()
|
|
||||||
for attr, val in (
|
|
||||||
("max_iterations", 1),
|
|
||||||
("max_cost", 8.0),
|
|
||||||
("normal_deviation_weight", 1.0),
|
|
||||||
("roundness_weight", 0.0),
|
|
||||||
("straightness_weight", 0.0),
|
|
||||||
("normal_seam_weight", 1.0),
|
|
||||||
("texture_seam_weight", 0.0),
|
|
||||||
("use_input_mesh_uvs", False),
|
|
||||||
):
|
|
||||||
if hasattr(co, attr):
|
|
||||||
setattr(co, attr, val)
|
|
||||||
applied.append(f"chart.{attr}")
|
|
||||||
gen_kwargs["chart_options"] = co
|
|
||||||
# PackOptions.bruteForce defaults to True — tries many rotations per
|
|
||||||
# chart and is the single biggest contributor to pack time on small
|
|
||||||
# meshes. Off it loses ~5-15% packing efficiency but runs ~5× faster.
|
|
||||||
if hasattr(xatlas, "PackOptions"):
|
|
||||||
po = xatlas.PackOptions()
|
|
||||||
for attr, val in (
|
|
||||||
("bruteForce", False),
|
|
||||||
("brute_force", False), # snake_case alias on some builds
|
|
||||||
("create_image", False),
|
|
||||||
("createImage", False),
|
|
||||||
("padding", 2),
|
|
||||||
):
|
|
||||||
if hasattr(po, attr):
|
|
||||||
setattr(po, attr, val)
|
|
||||||
applied.append(f"pack.{attr}")
|
|
||||||
gen_kwargs["pack_options"] = po
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] options applied: {applied}")
|
|
||||||
tgen = time.perf_counter()
|
|
||||||
try:
|
|
||||||
atlas.generate(**gen_kwargs)
|
|
||||||
except TypeError as e:
|
|
||||||
logging.warning(f"[BakeTextureFromVoxel] generate(**kwargs) rejected ({e}); retrying with defaults")
|
|
||||||
atlas.generate()
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] generate: {time.perf_counter() - tgen:.1f}s")
|
|
||||||
tget = time.perf_counter()
|
|
||||||
vmapping, indices, uvs = atlas[0]
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] retrieve: {time.perf_counter() - tget:.1f}s")
|
|
||||||
else:
|
|
||||||
vmapping, indices, uvs = xatlas.parametrize(v_np, f_np)
|
|
||||||
|
|
||||||
logging.info(f"[BakeTextureFromVoxel] xatlas total {time.perf_counter() - t0:.1f}s "
|
_tick("uvs")
|
||||||
f"({vmapping.shape[0]} verts after seams)")
|
|
||||||
new_verts = v_np[vmapping]
|
|
||||||
new_faces = indices.astype(np.uint32)
|
|
||||||
new_uvs = uvs.astype(np.float32)
|
|
||||||
|
|
||||||
_tick("unwrap")
|
|
||||||
|
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size)
|
position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size)
|
||||||
@ -968,7 +976,7 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors,
|
|||||||
|
|
||||||
t2 = time.perf_counter()
|
t2 = time.perf_counter()
|
||||||
attrs = _sample_voxel_attrs_per_texel(
|
attrs = _sample_voxel_attrs_per_texel(
|
||||||
position_map, mask, voxel_coords, voxel_colors, resolution, mode=sample_mode,
|
position_map, mask, voxel_coords, voxel_colors, resolution,
|
||||||
)
|
)
|
||||||
logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s "
|
logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s "
|
||||||
f"({attrs.shape[-1]} channels)")
|
f"({attrs.shape[-1]} channels)")
|
||||||
@ -1022,12 +1030,11 @@ def _per_vertex_normals(verts_np, faces_np):
|
|||||||
|
|
||||||
|
|
||||||
def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution,
|
def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution,
|
||||||
texture_size, views, blend_temperature=0.25,
|
texture_size, views, uvs, blend_temperature=0.25,
|
||||||
inpaint_radius=3, fast_unwrap=True, existing_uvs=None,
|
inpaint_radius=3, normalize_uvs=True):
|
||||||
normalize_uvs=True, sample_mode="trilinear"):
|
|
||||||
"""Bake a baseColor texture by projecting view photos onto the mesh.
|
"""Bake a baseColor texture by projecting view photos onto the mesh.
|
||||||
|
|
||||||
Reuses bake_texture_from_voxel_fn for the xatlas unwrap + the nearest-voxel
|
Reuses bake_texture_from_voxel_fn for the UV-space bake + the nearest-voxel
|
||||||
fallback colour, then overlays photo colour on every covered+visible texel:
|
fallback colour, then overlays photo colour on every covered+visible texel:
|
||||||
each texel's world position/normal is projected into each view, occlusion is
|
each texel's world position/normal is projected into each view, occlusion is
|
||||||
resolved with a texel z-buffer, and the views are blended weighted by how
|
resolved with a texel z-buffer, and the views are blended weighted by how
|
||||||
@ -1045,8 +1052,8 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
|
|||||||
# Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end).
|
# Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end).
|
||||||
out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn(
|
out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn(
|
||||||
vertices, faces, voxel_coords, voxel_colors, resolution=resolution,
|
vertices, faces, voxel_coords, voxel_colors, resolution=resolution,
|
||||||
texture_size=texture_size, inpaint_radius=0, fast_unwrap=fast_unwrap,
|
texture_size=texture_size, uvs=uvs, inpaint_radius=0,
|
||||||
existing_uvs=existing_uvs, normalize_uvs=normalize_uvs, sample_mode=sample_mode)
|
normalize_uvs=normalize_uvs)
|
||||||
|
|
||||||
v_np = out_v.detach().cpu().numpy().astype(np.float32)
|
v_np = out_v.detach().cpu().numpy().astype(np.float32)
|
||||||
f_np = out_f.detach().cpu().numpy().astype(np.uint32)
|
f_np = out_f.detach().cpu().numpy().astype(np.uint32)
|
||||||
@ -1078,6 +1085,16 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol
|
|||||||
return out_v, out_f, out_uvs, out_tex, voxel_mr
|
return out_v, out_f, out_uvs, out_tex, voxel_mr
|
||||||
|
|
||||||
|
|
||||||
|
def _mr_channel(packed_mr, ch, ref):
|
||||||
|
"""Pull one channel out of a packed glTF MR map (G=roughness at idx 1, B=metallic
|
||||||
|
at idx 2) as a 3-channel grayscale IMAGE [H,W,3] in [0,1]. Returns black sized
|
||||||
|
like `ref` when there's no MR map (non-PBR voxel field)."""
|
||||||
|
if packed_mr is None:
|
||||||
|
return torch.zeros_like(ref.float().cpu())
|
||||||
|
m = packed_mr.float().clamp(0.0, 1.0).cpu()
|
||||||
|
return m[..., ch:ch + 1].expand(-1, -1, 3).contiguous()
|
||||||
|
|
||||||
|
|
||||||
class BakeTextureFromVoxel(IO.ComfyNode):
|
class BakeTextureFromVoxel(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -1089,12 +1106,12 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
"Bakes PBR textures onto the mesh's existing UV layout by rasterizing it "
|
"Bakes PBR textures onto the mesh's existing UV layout by rasterizing it "
|
||||||
"in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling "
|
"in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling "
|
||||||
"the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node "
|
"the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node "
|
||||||
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Produces a "
|
"(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Outputs the "
|
||||||
"baseColor texture, plus a metallicRoughness texture when the voxel field "
|
"baked maps as IMAGEs: base_color, plus metallic and roughness as separate "
|
||||||
"carries the full PBR set (6 channels). Returns a Mesh with `uvs`, `texture`, "
|
"grayscale maps (both black when the voxel field has no PBR set). "
|
||||||
"and `metallic_roughness` attached — SaveGLB serializes them as real "
|
"Preview/save/post-process them, then feed them to ApplyTextureToMesh (with "
|
||||||
"baseColorTexture / metallicRoughnessTexture maps. UVs that spill outside "
|
"the SAME mesh) to attach them for SaveGLB. UVs that spill outside [0,1] are "
|
||||||
"[0,1] are uniformly fit back into the unit square."
|
"uniformly fit into the unit square."
|
||||||
),
|
),
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
@ -1108,7 +1125,11 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
"o_voxel.to_glb step that removes faceted/pixelized baking on coarse "
|
"o_voxel.to_glb step that removes faceted/pixelized baking on coarse "
|
||||||
"meshes. Pure scipy+torch, no extra deps.")),
|
"meshes. Pure scipy+torch, no extra deps.")),
|
||||||
],
|
],
|
||||||
outputs=[IO.Mesh.Output("mesh")],
|
outputs=[
|
||||||
|
IO.Image.Output(display_name="base_color"),
|
||||||
|
IO.Image.Output(display_name="metallic"),
|
||||||
|
IO.Image.Output(display_name="roughness"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1132,7 +1153,7 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
batch_idx = coords[:, 0].long()
|
batch_idx = coords[:, 0].long()
|
||||||
voxel_xyz = coords[:, 1:]
|
voxel_xyz = coords[:, 1:]
|
||||||
mesh_batch_size = int(mesh.vertices.shape[0])
|
mesh_batch_size = int(mesh.vertices.shape[0])
|
||||||
out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], []
|
out_tex, out_mr = [], []
|
||||||
# 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items
|
# 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items
|
||||||
# tick all 5 so the bar stays aligned.
|
# tick all 5 so the bar stays aligned.
|
||||||
pbar = comfy.utils.ProgressBar(mesh_batch_size * 5)
|
pbar = comfy.utils.ProgressBar(mesh_batch_size * 5)
|
||||||
@ -1150,28 +1171,25 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
if reference_mesh is not None:
|
if reference_mesh is not None:
|
||||||
rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i)
|
rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i)
|
||||||
ref_i = (rv_i, rf_i)
|
ref_i = (rv_i, rf_i)
|
||||||
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
|
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||||
v_i, f_i, item_coords, item_colors,
|
v_i, f_i, item_coords, item_colors,
|
||||||
resolution=resolution, texture_size=texture_size,
|
resolution=resolution, texture_size=texture_size,
|
||||||
inpaint_radius=inpaint_radius,
|
uvs=ev_i, inpaint_radius=inpaint_radius,
|
||||||
existing_uvs=ev_i, reference=ref_i, pbar=pbar,
|
reference=ref_i, pbar=pbar,
|
||||||
)
|
)
|
||||||
out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu)
|
|
||||||
out_tex.append(bt); out_mr.append(bmr)
|
out_tex.append(bt); out_mr.append(bmr)
|
||||||
if not out_verts:
|
if not out_tex:
|
||||||
return IO.NodeOutput(mesh)
|
# Every item skipped (degenerate) — emit one black map so the IMAGE
|
||||||
# Local pack_variable_mesh_batch doesn't take uvs/texture; build the
|
# outputs stay valid.
|
||||||
# packed mesh ourselves so we can attach both. UVs are 1:1 with verts.
|
black = torch.zeros((1, texture_size, texture_size, 3))
|
||||||
packed = pack_variable_mesh_batch(out_verts, out_faces)
|
return IO.NodeOutput(black, black, black)
|
||||||
max_v = packed.vertices.shape[1]
|
# All maps are texture_size² — stack into [B,H,W,3] IMAGE batches. The
|
||||||
packed_uvs = out_uvs[0].new_zeros((len(out_uvs), max_v, 2))
|
# packed MR (G=roughness, B=metallic) is split into separate grayscale
|
||||||
for i, u in enumerate(out_uvs):
|
# maps; both black where the voxel field carried no PBR set.
|
||||||
packed_uvs[i, :u.shape[0]] = u
|
base_img = torch.stack([t.float().clamp(0.0, 1.0).cpu() for t in out_tex], dim=0)
|
||||||
packed.uvs = packed_uvs
|
metallic_img = torch.stack([_mr_channel(m, 2, out_tex[0]) for m in out_mr], dim=0)
|
||||||
packed.texture = torch.stack(out_tex, dim=0)
|
roughness_img = torch.stack([_mr_channel(m, 1, out_tex[0]) for m in out_mr], dim=0)
|
||||||
if all(mr is not None for mr in out_mr):
|
return IO.NodeOutput(base_img, metallic_img, roughness_img)
|
||||||
packed.metallic_roughness = torch.stack(out_mr, dim=0)
|
|
||||||
return IO.NodeOutput(packed)
|
|
||||||
|
|
||||||
# Single-item path.
|
# Single-item path.
|
||||||
v0 = mesh.vertices.squeeze(0)
|
v0 = mesh.vertices.squeeze(0)
|
||||||
@ -1181,19 +1199,16 @@ class BakeTextureFromVoxel(IO.ComfyNode):
|
|||||||
if reference_mesh is not None:
|
if reference_mesh is not None:
|
||||||
ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0))
|
ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0))
|
||||||
pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn)
|
pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn)
|
||||||
bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn(
|
_bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn(
|
||||||
v0, f0, coords, colors,
|
v0, f0, coords, colors,
|
||||||
resolution=resolution, texture_size=texture_size,
|
resolution=resolution, texture_size=texture_size,
|
||||||
inpaint_radius=inpaint_radius,
|
uvs=ev0, inpaint_radius=inpaint_radius,
|
||||||
existing_uvs=ev0, reference=ref0, pbar=pbar,
|
reference=ref0, pbar=pbar,
|
||||||
)
|
)
|
||||||
out_mesh = Types.MESH(
|
base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0)
|
||||||
vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0),
|
metallic_img = _mr_channel(bmr, 2, bt).unsqueeze(0)
|
||||||
uvs=bu.unsqueeze(0), texture=bt.unsqueeze(0),
|
roughness_img = _mr_channel(bmr, 1, bt).unsqueeze(0)
|
||||||
)
|
return IO.NodeOutput(base_img, metallic_img, roughness_img)
|
||||||
if bmr is not None:
|
|
||||||
out_mesh.metallic_roughness = bmr.unsqueeze(0)
|
|
||||||
return IO.NodeOutput(out_mesh)
|
|
||||||
|
|
||||||
|
|
||||||
class MeshTextureToImage(IO.ComfyNode):
|
class MeshTextureToImage(IO.ComfyNode):
|
||||||
@ -1247,6 +1262,73 @@ class MeshTextureToImage(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(base, mr, metallic, roughness)
|
return IO.NodeOutput(base, mr, metallic, roughness)
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyTextureToMesh(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ApplyTextureToMesh",
|
||||||
|
display_name="Apply Texture to Mesh",
|
||||||
|
category="latent/3d",
|
||||||
|
description=(
|
||||||
|
"Attaches baked texture IMAGEs to a mesh's existing UV layout so SaveGLB "
|
||||||
|
"serializes them as baseColorTexture / metallicRoughnessTexture maps. Pairs "
|
||||||
|
"with BakeTextureFromVoxel: feed it the SAME mesh you baked from, plus that "
|
||||||
|
"node's base_color (and optionally metallic / roughness grayscale maps) — the "
|
||||||
|
"UVs must match the ones the texture was baked against, so don't re-unwrap in "
|
||||||
|
"between. metallic and roughness are repacked into the glTF MR map "
|
||||||
|
"(G=roughness, B=metallic); leave them unconnected for non-PBR meshes (a "
|
||||||
|
"missing metallic defaults to 0, a missing roughness to 1). Lets you preview / "
|
||||||
|
"upscale / edit the baked maps before applying them."
|
||||||
|
),
|
||||||
|
inputs=[
|
||||||
|
IO.Mesh.Input("mesh"),
|
||||||
|
IO.Image.Input("base_color"),
|
||||||
|
IO.Image.Input("metallic", optional=True),
|
||||||
|
IO.Image.Input("roughness", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mesh.Output("mesh")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, mesh, base_color, metallic=None, roughness=None):
|
||||||
|
mesh_uvs = getattr(mesh, "uvs", None)
|
||||||
|
if mesh_uvs is None:
|
||||||
|
raise ValueError(
|
||||||
|
"ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh "
|
||||||
|
"you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and "
|
||||||
|
"never unwraps).")
|
||||||
|
|
||||||
|
# Re-derive the exact UVs the bake rasterized against — it uniformly fits
|
||||||
|
# out-of-[0,1] layouts into the unit square, so apply the identical
|
||||||
|
# deterministic transform here (per batch item, over each item's real verts).
|
||||||
|
if mesh_uvs.ndim == 3:
|
||||||
|
new_uvs = mesh_uvs.clone()
|
||||||
|
for i in range(mesh_uvs.shape[0]):
|
||||||
|
v_i, _f_i, _ = get_mesh_batch_item(mesh, i)
|
||||||
|
n = v_i.shape[0]
|
||||||
|
norm = _normalize_uvs_to_unit(mesh_uvs[i, :n].detach().cpu().numpy())
|
||||||
|
new_uvs[i, :n] = torch.from_numpy(norm).to(new_uvs)
|
||||||
|
else:
|
||||||
|
norm = _normalize_uvs_to_unit(mesh_uvs.detach().cpu().numpy())
|
||||||
|
new_uvs = torch.from_numpy(norm).to(mesh_uvs)
|
||||||
|
|
||||||
|
out_mesh = copy.copy(mesh)
|
||||||
|
out_mesh.uvs = new_uvs
|
||||||
|
out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu()
|
||||||
|
if metallic is not None or roughness is not None:
|
||||||
|
# Repack separate grayscale maps into glTF MR: R unused, G=roughness,
|
||||||
|
# B=metallic. Size defaults off whichever map is connected; a missing
|
||||||
|
# channel falls back to a sensible scalar (metal 0 / rough 1).
|
||||||
|
prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu()
|
||||||
|
B, H, W, _ = prov.shape
|
||||||
|
rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||||
|
if roughness is not None else torch.ones((B, H, W, 1)))
|
||||||
|
metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1]
|
||||||
|
if metallic is not None else torch.zeros((B, H, W, 1)))
|
||||||
|
out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1)
|
||||||
|
return IO.NodeOutput(out_mesh)
|
||||||
|
|
||||||
|
|
||||||
def paint_mesh_default_colors(mesh):
|
def paint_mesh_default_colors(mesh):
|
||||||
out_mesh = copy.copy(mesh)
|
out_mesh = copy.copy(mesh)
|
||||||
vertex_count = mesh.vertices.shape[1]
|
vertex_count = mesh.vertices.shape[1]
|
||||||
@ -1397,14 +1479,14 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
|||||||
for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()):
|
for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()):
|
||||||
nm_share_breakdown.append(f"{n} edges×{c}faces")
|
nm_share_breakdown.append(f"{n} edges×{c}faces")
|
||||||
|
|
||||||
logging.info(f"[FillHolesV2 diag] V={V} F={F} | "
|
logging.info(f"[FillHoles diag] V={V} F={F} | "
|
||||||
f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} "
|
f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} "
|
||||||
f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})")
|
f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})")
|
||||||
if nm_share_breakdown:
|
if nm_share_breakdown:
|
||||||
logging.info(f"[FillHolesV2 diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
|
logging.info(f"[FillHoles diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}")
|
||||||
|
|
||||||
if n_boundary == 0:
|
if n_boundary == 0:
|
||||||
logging.info("[FillHolesV2 diag] no boundary edges → no cycles to fill")
|
logging.info("[FillHoles diag] no boundary edges → no cycles to fill")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Walk components same as production path (bidir-prop, by-vertex count).
|
# Walk components same as production path (bidir-prop, by-vertex count).
|
||||||
@ -1459,15 +1541,15 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
|||||||
cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle
|
cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle
|
||||||
cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component
|
cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component
|
||||||
|
|
||||||
logging.info(f"[FillHolesV2 diag] components={L} "
|
logging.info(f"[FillHoles diag] components={L} "
|
||||||
f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} "
|
f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} "
|
||||||
f"non-simple={int(is_other.sum().item())}")
|
f"non-simple={int(is_other.sum().item())}")
|
||||||
logging.info(f"[FillHolesV2 diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
|
logging.info(f"[FillHoles diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})")
|
||||||
logging.info(f"[FillHolesV2 diag] actually kept={int(actually_kept.sum().item())} "
|
logging.info(f"[FillHoles diag] actually kept={int(actually_kept.sum().item())} "
|
||||||
f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} "
|
f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} "
|
||||||
f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}")
|
f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}")
|
||||||
logging.info(f"[FillHolesV2 diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
|
logging.info(f"[FillHoles diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)")
|
||||||
logging.info(f"[FillHolesV2 diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
|
logging.info(f"[FillHoles diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts")
|
||||||
|
|
||||||
# Cycle vert-count distribution
|
# Cycle vert-count distribution
|
||||||
if is_cycle.any():
|
if is_cycle.any():
|
||||||
@ -1485,12 +1567,12 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter):
|
|||||||
elif s <= 20: buckets["11-20"] += n
|
elif s <= 20: buckets["11-20"] += n
|
||||||
elif s <= 50: buckets["21-50"] += n
|
elif s <= 50: buckets["21-50"] += n
|
||||||
else: buckets["51+"] += n
|
else: buckets["51+"] += n
|
||||||
logging.info(f"[FillHolesV2 diag] cycle vert-count buckets: {buckets}")
|
logging.info(f"[FillHoles diag] cycle vert-count buckets: {buckets}")
|
||||||
|
|
||||||
if is_cycle.any():
|
if is_cycle.any():
|
||||||
cycle_perims = perim[is_cycle].cpu().tolist()
|
cycle_perims = perim[is_cycle].cpu().tolist()
|
||||||
head = sorted(cycle_perims, reverse=True)[:10]
|
head = sorted(cycle_perims, reverse=True)[:10]
|
||||||
logging.info(f"[FillHolesV2 diag] top-10 cycle perimeters: "
|
logging.info(f"[FillHoles diag] top-10 cycle perimeters: "
|
||||||
f"{['%.4f' % p for p in head]}")
|
f"{['%.4f' % p for p in head]}")
|
||||||
|
|
||||||
|
|
||||||
@ -1804,10 +1886,10 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi
|
|||||||
n_escalations += 1
|
n_escalations += 1
|
||||||
if total_welded > 0 or n_escalations > 0:
|
if total_welded > 0 or n_escalations > 0:
|
||||||
tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else ""
|
tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else ""
|
||||||
logging.info(f"[FillHolesV2] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
|
logging.info(f"[FillHoles] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}")
|
||||||
if ratio >= WELDED_THRESHOLD:
|
if ratio >= WELDED_THRESHOLD:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"[FillHolesV2] even at weld epsilon_rel={WELD_CAP} the mesh stays "
|
f"[FillHoles] even at weld epsilon_rel={WELD_CAP} the mesh stays "
|
||||||
f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has "
|
f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has "
|
||||||
f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream "
|
f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream "
|
||||||
f"(decimate node settings) or run WeldVertices manually with a larger epsilon."
|
f"(decimate node settings) or run WeldVertices manually with a larger epsilon."
|
||||||
@ -2077,6 +2159,23 @@ def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4):
|
|||||||
|
|
||||||
return offset_verts, f_unwelded, None
|
return offset_verts, f_unwelded, None
|
||||||
|
|
||||||
|
def _fmt_count(n) -> str:
|
||||||
|
"""Compact human-readable integer for node status lines, e.g. 853, 12.3K, 1.23M."""
|
||||||
|
n = int(n)
|
||||||
|
if n >= 1_000_000:
|
||||||
|
return f"{n / 1_000_000:.2f}".rstrip("0").rstrip(".") + "M"
|
||||||
|
if n >= 1_000:
|
||||||
|
return f"{n / 1_000:.1f}".rstrip("0").rstrip(".") + "K"
|
||||||
|
return str(n)
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_face_change(n_in, n_out) -> str:
|
||||||
|
"""'faces: 1.23M → 200K (-84%)' — the count delta for decimate/remesh status."""
|
||||||
|
n_in, n_out = int(n_in), int(n_out)
|
||||||
|
pct = f" ({(n_out - n_in) / n_in * 100:+.0f}%)" if n_in else ""
|
||||||
|
return f"faces: {_fmt_count(n_in)} → {_fmt_count(n_out)}{pct}"
|
||||||
|
|
||||||
|
|
||||||
class DecimateMesh(IO.ComfyNode):
|
class DecimateMesh(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -2168,7 +2267,7 @@ class DecimateMesh(IO.ComfyNode):
|
|||||||
# Send progress text to display the face reduction on the node
|
# Send progress text to display the face reduction on the node
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
PromptServer.instance.send_progress_text(
|
PromptServer.instance.send_progress_text(
|
||||||
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
|
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -2306,7 +2405,7 @@ class RemeshMesh(IO.ComfyNode):
|
|||||||
# Send progress text to display the face change on the node
|
# Send progress text to display the face change on the node
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
PromptServer.instance.send_progress_text(
|
PromptServer.instance.send_progress_text(
|
||||||
f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id)
|
_fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -2546,7 +2645,8 @@ class UnwrapMesh(IO.ComfyNode):
|
|||||||
|
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
PromptServer.instance.send_progress_text(
|
PromptServer.instance.send_progress_text(
|
||||||
f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px",
|
f"UV: {_fmt_count(out_v[0].shape[0])} verts / {_fmt_count(out_f[0].shape[0])} faces"
|
||||||
|
f" · atlas ~{resolution}px",
|
||||||
cls.hidden.unique_id)
|
cls.hidden.unique_id)
|
||||||
return IO.NodeOutput(out_mesh)
|
return IO.NodeOutput(out_mesh)
|
||||||
|
|
||||||
@ -2740,36 +2840,12 @@ class FillHoles(IO.ComfyNode):
|
|||||||
node_id="FillHoles",
|
node_id="FillHoles",
|
||||||
display_name="Fill Holes",
|
display_name="Fill Holes",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
description="Fills holes in a mesh up to a maximum perimeter threshold.",
|
|
||||||
inputs=[
|
|
||||||
IO.Mesh.Input("mesh"),
|
|
||||||
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
|
|
||||||
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
|
|
||||||
],
|
|
||||||
outputs=[IO.Mesh.Output("mesh")],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, mesh, max_perimeter):
|
|
||||||
def _fn(v, f, c):
|
|
||||||
if max_perimeter > 0:
|
|
||||||
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
|
|
||||||
return v, f, c
|
|
||||||
return _process_mesh_batch(mesh, _fn)
|
|
||||||
|
|
||||||
class FillHolesV2(IO.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="FillHolesV2",
|
|
||||||
display_name="Fill Holes (v2)",
|
|
||||||
category="latent/3d",
|
|
||||||
description=(
|
description=(
|
||||||
"GPU-vectorised hole-filling via directed-half-edge pointer-doubling. "
|
"Fills holes in a mesh up to a maximum perimeter threshold, preserving "
|
||||||
"Drop-in alternative to FillHoles for comparison: same max_perimeter "
|
"the existing geometry/UVs (only patch triangles are added). GPU-vectorised "
|
||||||
"cutoff and fan-from-centroid triangulation, but no Python loop, "
|
"via directed-half-edge pointer-doubling: no Python loop, auto-correct "
|
||||||
"auto-correct winding from face direction, and centroid colors are "
|
"winding from face direction, and centroid colors are averaged from the hole "
|
||||||
"averaged from the loop instead of left zero."
|
"loop. Falls back to a CPU walker on non-CUDA devices."
|
||||||
),
|
),
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
@ -2947,7 +3023,6 @@ class PostProcessMeshExtension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
FillHoles,
|
FillHoles,
|
||||||
FillHolesV2,
|
|
||||||
WeldVertices,
|
WeldVertices,
|
||||||
DecimateMesh,
|
DecimateMesh,
|
||||||
RemeshMesh,
|
RemeshMesh,
|
||||||
@ -2956,6 +3031,7 @@ class PostProcessMeshExtension(ComfyExtension):
|
|||||||
PaintMesh,
|
PaintMesh,
|
||||||
BakeTextureFromVoxel,
|
BakeTextureFromVoxel,
|
||||||
MeshTextureToImage,
|
MeshTextureToImage,
|
||||||
|
ApplyTextureToMesh,
|
||||||
MergeMeshes,
|
MergeMeshes,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,18 +1,15 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
from comfy.ldm.trellis2.model import (
|
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
|
||||||
_build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats,
|
|
||||||
)
|
|
||||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||||
|
|
||||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.ldm.trellis2 import sampling_preview
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
@ -21,89 +18,6 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
|||||||
NAFModel = io.Custom("NAF_MODEL")
|
NAFModel = io.Custom("NAF_MODEL")
|
||||||
|
|
||||||
|
|
||||||
# Texture latent -> base-color calibration for the per-step preview
|
|
||||||
def _tex_rgb_factors_path():
|
|
||||||
return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt")
|
|
||||||
|
|
||||||
|
|
||||||
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
|
|
||||||
in_sp = in_coords[:, 1:4].long()
|
|
||||||
out_sp = out_coords[:, 1:4].long()
|
|
||||||
in_b = in_coords[:, 0].long()
|
|
||||||
out_b = out_coords[:, 0].long()
|
|
||||||
in_res = int(in_sp.max().item()) + 1
|
|
||||||
out_res = int(out_sp.max().item()) + 1
|
|
||||||
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
|
|
||||||
R = in_res
|
|
||||||
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
|
|
||||||
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
|
|
||||||
order = torch.argsort(in_flat)
|
|
||||||
in_sorted = in_flat[order]
|
|
||||||
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
|
|
||||||
matched = in_sorted[pos] == par_flat
|
|
||||||
in_idx = order[pos][matched]
|
|
||||||
cols = out_colors[matched].float()
|
|
||||||
N = in_coords.shape[0]
|
|
||||||
csum = cols.new_zeros((N, 3))
|
|
||||||
ccount = cols.new_zeros((N, 1))
|
|
||||||
csum.index_add_(0, in_idx, cols)
|
|
||||||
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
|
|
||||||
valid = ccount[:, 0] > 0
|
|
||||||
albedo = torch.zeros_like(csum)
|
|
||||||
albedo[valid] = csum[valid] / ccount[valid]
|
|
||||||
return albedo, valid
|
|
||||||
|
|
||||||
|
|
||||||
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
|
|
||||||
"""Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish."""
|
|
||||||
try:
|
|
||||||
dev = out_colors.device
|
|
||||||
in_latent = in_latent.to(dev)
|
|
||||||
in_coords = in_coords.to(dev)
|
|
||||||
out_coords = out_coords.to(dev)
|
|
||||||
albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
|
|
||||||
X = in_latent[valid].float().cpu()
|
|
||||||
Y = albedo[valid].float().cpu()
|
|
||||||
if X.shape[0] < 64:
|
|
||||||
return
|
|
||||||
Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1]
|
|
||||||
A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1]
|
|
||||||
B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3]
|
|
||||||
|
|
||||||
path = _tex_rgb_factors_path()
|
|
||||||
if os.path.exists(path):
|
|
||||||
try:
|
|
||||||
prev = torch.load(path, map_location="cpu")
|
|
||||||
A_run = A_run + prev["A"]
|
|
||||||
B_run = B_run + prev["B"]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
||||||
torch.save({"A": A_run, "B": B_run}, path)
|
|
||||||
|
|
||||||
eye = torch.eye(A_run.shape[0])
|
|
||||||
WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3]
|
|
||||||
W, b = WB[:-1].contiguous(), WB[-1].contiguous()
|
|
||||||
sampling_preview.set_tex_rgb(W, b)
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_tex_rgb_factors():
|
|
||||||
try:
|
|
||||||
path = _tex_rgb_factors_path()
|
|
||||||
if os.path.exists(path):
|
|
||||||
d = torch.load(path, map_location="cpu")
|
|
||||||
eye = torch.eye(d["A"].shape[0])
|
|
||||||
WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
|
|
||||||
sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
_load_tex_rgb_factors()
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||||
if len(sample_shape) == 5:
|
if len(sample_shape) == 5:
|
||||||
@ -271,7 +185,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
if coord_counts is None:
|
if coord_counts is None:
|
||||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
samples = shape_norm(samples.to(device), coords.to(device))
|
samples = shape_norm(samples.to(device), coords.to(device))
|
||||||
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
|
mesh, subs = trellis_vae.decode_shape_slat(samples.to(vae.vae_dtype), resolution)
|
||||||
else:
|
else:
|
||||||
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
mesh = []
|
mesh = []
|
||||||
@ -280,7 +194,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
coords_i = coords_i.to(device).clone()
|
coords_i = coords_i.to(device).clone()
|
||||||
coords_i[:, 0] = 0
|
coords_i[:, 0] = 0
|
||||||
sample_i = shape_norm(feats_i.to(device), coords_i)
|
sample_i = shape_norm(feats_i.to(device), coords_i)
|
||||||
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
|
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i.to(vae.vae_dtype), resolution)
|
||||||
mesh.append(mesh_i[0])
|
mesh.append(mesh_i[0])
|
||||||
subs_per_sample.append(subs_i)
|
subs_per_sample.append(subs_i)
|
||||||
|
|
||||||
@ -338,14 +252,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
samples = samples.to(device)
|
samples = samples.to(device)
|
||||||
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
|
|
||||||
cal_in_coords = coords
|
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
|
||||||
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
||||||
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
||||||
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
||||||
@ -353,12 +265,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
color_feats = voxel.feats
|
color_feats = voxel.feats
|
||||||
voxel_coords = voxel.coords
|
voxel_coords = voxel.coords
|
||||||
|
|
||||||
# Calibrate the latent->base_color map for the per-step texture preview.
|
|
||||||
# Done here while input coords and voxel_coords share the model frame
|
|
||||||
# (before the z_up remap below) and on the real decoded albedo.
|
|
||||||
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
|
|
||||||
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
|
|
||||||
|
|
||||||
if coord_resolution is not None:
|
if coord_resolution is not None:
|
||||||
tex_resolution = int(coord_resolution) * 16
|
tex_resolution = int(coord_resolution) * 16
|
||||||
elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||||
@ -416,7 +322,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
decoded_batches = []
|
decoded_batches = []
|
||||||
for start in range(0, sample_tensor.shape[0], batch_number):
|
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||||
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||||
decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0)
|
decoded_batches.append(shape_vae.decode_structure(sample_chunk.to(vae.vae_dtype)) > 0)
|
||||||
decoded = torch.cat(decoded_batches, dim=0)
|
decoded = torch.cat(decoded_batches, dim=0)
|
||||||
current_res = decoded.shape[2]
|
current_res = decoded.shape[2]
|
||||||
|
|
||||||
@ -491,7 +397,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
|||||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||||
)
|
)
|
||||||
slat = shape_norm(feats.to(device), coords_512.to(device))
|
slat = shape_norm(feats.to(device), coords_512.to(device))
|
||||||
sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)]
|
sample_hr_coords = [shape_vae.upsample_shape(slat.to(vae.vae_dtype), upsample_times=4)]
|
||||||
else:
|
else:
|
||||||
items = split_batched_sparse_latent(
|
items = split_batched_sparse_latent(
|
||||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||||
@ -501,7 +407,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
|||||||
coords_i = coords_i.to(device).clone()
|
coords_i = coords_i.to(device).clone()
|
||||||
coords_i[:, 0] = 0
|
coords_i[:, 0] = 0
|
||||||
slat_i = shape_norm(feats_i.to(device), coords_i)
|
slat_i = shape_norm(feats_i.to(device), coords_i)
|
||||||
sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4))
|
sample_hr_coords.append(shape_vae.upsample_shape(slat_i.to(vae.vae_dtype), upsample_times=4))
|
||||||
|
|
||||||
# Resolution search — cache the final iteration's quantized unique tensors
|
# Resolution search — cache the final iteration's quantized unique tensors
|
||||||
# so we don't recompute .unique() per sample after picking hr_resolution.
|
# so we don't recompute .unique() per sample after picking hr_resolution.
|
||||||
@ -977,8 +883,10 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
|||||||
if pad_l or pad_t or pad_r or pad_b:
|
if pad_l or pad_t or pad_r or pad_b:
|
||||||
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||||
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
|
||||||
crop_x1 += pad_l; crop_x2 += pad_l
|
crop_x1 += pad_l
|
||||||
crop_y1 += pad_t; crop_y2 += pad_t
|
crop_x2 += pad_l
|
||||||
|
crop_y1 += pad_t
|
||||||
|
crop_y2 += pad_t
|
||||||
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
|
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||||
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||||
|
|
||||||
@ -1100,7 +1008,7 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
||||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
||||||
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
|
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
|
||||||
T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
proj_pack = {
|
proj_pack = {
|
||||||
"stages": {
|
"stages": {
|
||||||
@ -1119,15 +1027,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
||||||
# proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet
|
|
||||||
# hints etc. live); the sampler auto-promotes it to a model.forward kwarg via
|
|
||||||
# Trellis2.extra_conds. The same pack object is shared between pos/neg —
|
|
||||||
# CONDConstant.can_concat sees them equal and concats to a single dict, then
|
|
||||||
# Trellis2.forward zeros proj for the uncond slots via cond_or_uncond.
|
|
||||||
# Pre-compute the SS-stage proj features (dense 16³ grid) once here — the
|
|
||||||
# shape/texture stages do their own computes in their respective stage nodes.
|
|
||||||
# proj_pack lives on intermediate (CPU); force the compute onto cuda so
|
|
||||||
# the bilinear-sampling step doesn't run on CPU.
|
|
||||||
ss_proj_feats = compute_stage_proj_feats(
|
ss_proj_feats = compute_stage_proj_feats(
|
||||||
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
@ -1278,12 +1177,6 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
|||||||
q_mean = Q.mean(dim=0, keepdim=True)
|
q_mean = Q.mean(dim=0, keepdim=True)
|
||||||
P_c = P - p_mean
|
P_c = P - p_mean
|
||||||
Q_c = Q - q_mean
|
Q_c = Q - q_mean
|
||||||
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
|
|
||||||
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
|
|
||||||
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
|
|
||||||
# gets multiplied by cos(yaw) and shrinks the object. Using
|
|
||||||
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
|
|
||||||
# rotation; translation still positions the mesh at MoGe's centroid.
|
|
||||||
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
||||||
q_var = (Q_c * Q_c).sum()
|
q_var = (Q_c * Q_c).sum()
|
||||||
scale = float(torch.sqrt(q_var / p_var).item())
|
scale = float(torch.sqrt(q_var / p_var).item())
|
||||||
@ -1326,74 +1219,69 @@ class LoadNAFModel(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class CFGGuidanceInterval(IO.ComfyNode):
|
class GetMeshInfo(IO.ComfyNode):
|
||||||
"""Generic model patch: apply CFG only during [start_percent, end_percent] of
|
"""Report vertex / face counts and attributes for a MESH, displayed on the
|
||||||
the sampling schedule. Outside that window, skip the uncond computation and
|
node (and as a string output). Counts are comma-formatted since meshes can
|
||||||
collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's
|
run into the millions of faces. Passes the mesh through unchanged."""
|
||||||
guidance_interval_mixin, but lives at the sampler level (via
|
|
||||||
sampler_calc_cond_batch_function) so it works for any model.
|
|
||||||
|
|
||||||
Percents use ComfyUI's standard convention: 0.0 = start of sampling
|
|
||||||
(max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
|
|
||||||
is done via model_sampling.percent_to_sigma so the window is portable
|
|
||||||
across schedules (flow / EDM / discrete) and shift settings.
|
|
||||||
|
|
||||||
Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D
|
|
||||||
pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the
|
|
||||||
SS and shape samplers — CFG active only in the first 40% of sampling.
|
|
||||||
Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to
|
|
||||||
match. Texture defaults to cfg=1 so the node is moot there."""
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="CFGGuidanceInterval",
|
node_id="GetMeshInfo",
|
||||||
category="model_patches/sampling",
|
display_name="Get Mesh Info",
|
||||||
inputs=[
|
category="latent/3d",
|
||||||
IO.Model.Input("model"),
|
inputs=[IO.Mesh.Input("mesh")],
|
||||||
IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
|
outputs=[
|
||||||
tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
|
IO.Mesh.Output(display_name="mesh"),
|
||||||
IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
|
IO.String.Output(display_name="info"),
|
||||||
tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
|
|
||||||
],
|
],
|
||||||
outputs=[IO.Model.Output()],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fmt(n: int) -> str:
|
||||||
|
# e.g. 1234567 -> "1,234,567 (1.23M)"; small numbers stay plain.
|
||||||
|
s = f"{n:,}"
|
||||||
|
if n >= 1_000_000:
|
||||||
|
s += f" ({n / 1_000_000:.2f}M)"
|
||||||
|
elif n >= 10_000:
|
||||||
|
s += f" ({n / 1_000:.1f}K)"
|
||||||
|
return s
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, start_percent, end_percent):
|
def execute(cls, mesh):
|
||||||
import comfy.samplers
|
B = mesh.vertices.shape[0]
|
||||||
|
# Honour per-item counts when the batch is zero-padded; else use the row sizes.
|
||||||
|
if mesh.vertex_counts is not None:
|
||||||
|
v_counts = [int(x) for x in mesh.vertex_counts.tolist()]
|
||||||
|
f_counts = [int(x) for x in mesh.face_counts.tolist()]
|
||||||
|
else:
|
||||||
|
v_counts = [int(mesh.vertices.shape[1])] * B
|
||||||
|
f_counts = [int(mesh.faces.shape[1])] * B
|
||||||
|
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
attrs = []
|
||||||
# percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
|
for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"):
|
||||||
# percent=1 -> sigma_min. So start_percent < end_percent in user space
|
t = getattr(mesh, name, None)
|
||||||
# means sigma_start > sigma_end. "Inside the window" is sigma in
|
if t is not None:
|
||||||
# [sigma_end, sigma_start].
|
if name in ("texture", "metallic_roughness"):
|
||||||
sigma_start = float(model_sampling.percent_to_sigma(start_percent))
|
attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W
|
||||||
sigma_end = float(model_sampling.percent_to_sigma(end_percent))
|
else:
|
||||||
|
attrs.append(name)
|
||||||
|
|
||||||
def calc_cond_batch_with_interval(args):
|
lines = []
|
||||||
sigma_val = args["sigma"][0].item()
|
if B > 1:
|
||||||
conds = args["conds"]
|
lines.append(f"Batch: {B} meshes")
|
||||||
input_x = args["input"]
|
lines.append(f"Vertices: {cls._fmt(sum(v_counts))} total")
|
||||||
timestep = args["sigma"]
|
lines.append(f"Faces: {cls._fmt(sum(f_counts))} total")
|
||||||
model_ref = args["model"]
|
for i in range(B):
|
||||||
model_opts = args["model_options"]
|
lines.append(f" [{i}] {v_counts[i]:>10,} verts · {f_counts[i]:>10,} faces")
|
||||||
|
else:
|
||||||
|
lines.append(f"Vertices: {cls._fmt(v_counts[0])}")
|
||||||
|
lines.append(f"Faces: {cls._fmt(f_counts[0])}")
|
||||||
|
lines.append(f"Attributes: {', '.join(attrs) if attrs else 'none'}")
|
||||||
|
|
||||||
# conds is typically [cond, uncond]; uncond may be None when ComfyUI's
|
info = "\n".join(lines)
|
||||||
# global cfg=1 optimization has already pruned it.
|
logging.info("[GetMeshInfo]\n%s", info)
|
||||||
cond = conds[0]
|
return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info))
|
||||||
uncond = conds[1] if len(conds) > 1 else None
|
|
||||||
inside = sigma_end <= sigma_val <= sigma_start
|
|
||||||
|
|
||||||
if uncond is None or inside:
|
|
||||||
return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts)
|
|
||||||
# Outside the window: compute cond only, mirror it into the uncond slot
|
|
||||||
# so the downstream cfg_function collapses to `cond` (effective cfg=1).
|
|
||||||
out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts)
|
|
||||||
return [out[0], out[0]]
|
|
||||||
|
|
||||||
m = model.clone()
|
|
||||||
m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
|
|
||||||
return IO.NodeOutput(m)
|
|
||||||
|
|
||||||
|
|
||||||
class Trellis2Extension(ComfyExtension):
|
class Trellis2Extension(ComfyExtension):
|
||||||
@ -1411,7 +1299,7 @@ class Trellis2Extension(ComfyExtension):
|
|||||||
VaeDecodeShapeTrellis,
|
VaeDecodeShapeTrellis,
|
||||||
VaeDecodeStructureTrellis2,
|
VaeDecodeStructureTrellis2,
|
||||||
Trellis2UpsampleStage,
|
Trellis2UpsampleStage,
|
||||||
CFGGuidanceInterval,
|
GetMeshInfo,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user