From 41f5f4b2c0dfdc2fd1d27e4cf390a81427cf34ec Mon Sep 17 00:00:00 2001 From: kijai Date: Sat, 27 Jun 2026 00:13:13 +0300 Subject: [PATCH] More cleanup --- comfy_extras/nodes_mesh_postprocess.py | 550 ++++++++++++++----------- comfy_extras/nodes_trellis2.py | 246 +++-------- 2 files changed, 380 insertions(+), 416 deletions(-) diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index a0da42139..2c1605759 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -180,12 +180,13 @@ class PaintMesh(IO.ComfyNode): # ============================================================================= # Texture baking from sparse voxel volume. # -# Pipeline: xatlas UV unwrap → OpenGL UV-space rasterize to position map → -# nearest-voxel color sample per texel → cv2.inpaint to fill UV seams → -# attach texture + UVs to the Mesh for SaveGLB to serialize. +# Pipeline: take the mesh's existing UVs → OpenGL UV-space rasterize to position +# map → nearest-voxel color sample per texel → GPU Jump-Flood fill UV seams → +# attach texture + UVs to the Mesh for SaveGLB to serialize. Unwrapping is done +# upstream (Trellis2OfficialUnwrap / TorchXatlasUVWrap); this path never unwraps. # # Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles -# GLFW / EGL / OSMesa backend selection); xatlas for UV parameterization. +# GLFW / EGL / OSMesa backend selection). # ============================================================================= _GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache" @@ -407,6 +408,54 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): return vals, ok +def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolution): + """GPU port of `_trilinear_sample_sparse` — same normalized-over-occupied-corners + trilinear, but the per-texel 8-corner accumulation runs on CUDA via sorted-key + `searchsorted` instead of NumPy float64. This is the bake hot path (millions of + covered texels × 8 corners), so the CPU version dominates runtime; the GPU port + is ~identical numerically and 10-50× faster. Returns (vals [K,C] float32, ok + [K] bool), matching the NumPy signature.""" + dev = comfy.model_management.get_torch_device() + R = int(resolution) + origin = -0.5 + voxel_size = 1.0 / R + P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float() + VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long() + col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float() + K, C = P.shape[0], col.shape[1] + M = VC.shape[0] + # Same cell-CENTER convention as the NumPy path (see its docstring): integer + # voxel coord c sits at (c+0.5)*voxel_size + origin, so subtract 0.5 to bracket. + gc = (P - origin) / voxel_size - 0.5 + base = torch.floor(gc).long() + frac = gc - base.float() + key = (VC[:, 0] * R + VC[:, 1]) * R + VC[:, 2] + skey, order = key.sort() + acc = torch.zeros((K, C), device=dev) + wsum = torch.zeros((K, 1), device=dev) + for dx in (0, 1): + wx = frac[:, 0] if dx else 1.0 - frac[:, 0] + for dy in (0, 1): + wy = frac[:, 1] if dy else 1.0 - frac[:, 1] + for dz in (0, 1): + wz = frac[:, 2] if dz else 1.0 - frac[:, 2] + cx = base[:, 0] + dx + cy = base[:, 1] + dy + cz = base[:, 2] + dz + inb = (cx >= 0) & (cx < R) & (cy >= 0) & (cy < R) & (cz >= 0) & (cz < R) + qk = (cx * R + cy) * R + cz + ins = torch.searchsorted(skey, qk).clamp(max=M - 1) + matched = inb & (skey[ins] == qk) + idx = order[ins] # garbage where !matched + w = torch.where(matched, wx * wy * wz, torch.zeros_like(wx))[:, None] + acc += w * col[idx] # w=0 cancels garbage rows + wsum += w + ok = wsum[:, 0] > 1e-8 + vals = torch.zeros((K, C), device=dev) + vals[ok] = acc[ok] / wsum[ok].clamp_min(1e-8) + return vals.cpu().numpy(), ok.cpu().numpy() + + def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): """GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a @@ -415,7 +464,7 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found` is False for the rare query whose nearest occupied voxel is >1 cell away (the caller falls back to a cKDTree on just those).""" - dev = "cuda" if torch.cuda.is_available() else "cpu" + dev = comfy.model_management.get_torch_device() R = int(resolution) P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float() VC = torch.from_numpy(np.ascontiguousarray(voxel_coords_np)).to(dev).long() @@ -451,6 +500,27 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): fnd |= match return bi, fnd + def _brute_nearest(idx): + """Exact nearest occupied voxel for a (small) query subset by chunked GPU + brute force over all M voxels. Used only for the handful of stragglers the + grid scan misses (>4 cells from any voxel) — replaces a cKDTree build over + all M voxels, which costs seconds even for a few query points.""" + Ps = P[idx] # [N,3] world + N = Ps.shape[0] + vox_pos = (VC.float() + 0.5) / R - 0.5 # [M,3] voxel centres + best_d = torch.full((N,), 1e30, device=dev) + best_j = torch.zeros(N, dtype=torch.long, device=dev) + # Bound the N×chunk distance matrix to ~64M elements regardless of N. + chunk = max(1, (1 << 26) // max(1, N)) + for s in range(0, M, chunk): + vc = vox_pos[s:s + chunk] # [B,3] + dd = (Ps[:, None, :] - vc[None, :, :]).pow(2).sum(-1) # [N,B] + md, mj = dd.min(1) + upd = md < best_d + best_d = torch.where(upd, md, best_d) + best_j = torch.where(upd, mj + s, best_j) + return best_j + all_idx = torch.arange(K, device=dev) best_i = torch.zeros(K, dtype=torch.long, device=dev) found = torch.zeros(K, dtype=torch.bool, device=dev) @@ -458,27 +528,30 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): bi1, fnd1 = _search(all_idx, 1) best_i[all_idx] = bi1 found[all_idx] = fnd1 - # Pass 2: wider radius on ONLY the few misses (avoids ever building a cKDTree - # over millions of voxels just for a handful of >1-cell-away points). + # Pass 2: wider radius (9³) on ONLY the radius-1 misses. miss = torch.nonzero(~found, as_tuple=True)[0] if miss.numel() > 0: bi2, fnd2 = _search(miss, 4) best_i[miss] = bi2 found[miss] = fnd2 + # Pass 3: exact GPU brute force for the few stragglers still unfound (>4 cells + # out). Always resolves them, so `found` is all-True on return — no cKDTree. + miss2 = torch.nonzero(~found, as_tuple=True)[0] + if miss2.numel() > 0: + best_i[miss2] = _brute_nearest(miss2) + found[miss2] = True vals = col[best_i] return vals.cpu().numpy(), found.cpu().numpy() -def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution, - mode="trilinear"): +def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution): """For every masked texel, sample the voxel field and return ALL its attribute channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature width (3 for plain color, 6 for full PBR). - mode="trilinear" — normalized trilinear over occupied voxels (the default; matches - the official o_voxel.to_glb path), with nearest fallback for texels whose 8 - surrounding voxels are all empty. This is the only mode the nodes expose now. - mode="nearest" — nearest-voxel; kept as an internal/dev lever (blocky).""" + Normalized trilinear over occupied voxels (matches the official o_voxel.to_glb + path), with nearest fallback for texels whose 8 surrounding voxels are all + empty.""" H, W, _ = position_map.shape color_np = voxel_colors.detach().cpu().numpy().astype(np.float32) C = color_np.shape[-1] @@ -496,22 +569,26 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors valid_positions = position_map[mask] def _nearest(query): - # GPU grid lookup; cKDTree only for the rare >1-cell miss. + # Fully on-GPU nearest-occupied-voxel: grid scan + brute-force tail. Always + # resolves every query, so no cKDTree (its build over all voxels cost ~3s). vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution) if not found.all(): + # Defensive: only reachable on a non-CUDA device where the GPU path is + # unavailable; fall back to a one-off cKDTree. tree = scipy.spatial.cKDTree(voxel_pos) _, nearest_idx = tree.query(query[~found], k=1, workers=-1) vals[~found] = color_np[nearest_idx] return vals - if mode == "trilinear": + try: + vals, ok = _trilinear_sample_sparse_gpu(valid_positions, coords_np, color_np, resolution) + except Exception as e: + logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU") vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution) - if not ok.all(): - # Texels with no occupied neighbour fall back to nearest. - vals[~ok] = _nearest(valid_positions[~ok]) - out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32) - else: - out[mask] = np.clip(_nearest(valid_positions), 0.0, 1.0) + if not ok.all(): + # Texels with no occupied neighbour fall back to nearest. + vals[~ok] = _nearest(valid_positions[~ok]) + out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32) return out @@ -729,7 +806,7 @@ def _back_project_positions(position_map, mask, ref_v, ref_f): return position_map import time as _time - dev = "cuda" if torch.cuda.is_available() else "cpu" + dev = comfy.model_management.get_torch_device() rv = ref_v.detach().to(dev).float() rf = ref_f.detach().to(dev).long() tri = rv[rf] @@ -755,7 +832,7 @@ def _jfa_fill_gpu(img01, mask): (True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map.""" if not mask.any(): return img01 - dev = "cuda" + dev = comfy.model_management.get_torch_device() it = torch.from_numpy(np.ascontiguousarray(img01)).to(dev).float() mm = torch.from_numpy(np.ascontiguousarray(mask)).to(dev) H, W = mm.shape @@ -786,41 +863,58 @@ def _jfa_fill_gpu(img01, mask): def _seam_fill(img01, mask, inpaint_radius): """Fill the UV-gutter texels around covered charts so seam sampling doesn't - pull in black. GPU Jump Flooding (nearest fill) when CUDA is available, else - cv2 Telea inpaint. `inpaint_radius<=0` disables; the radius only affects the - cv2 fallback (JFA fills every uncovered texel by nearest).""" + pull in black, via GPU Jump Flooding (nearest fill). `inpaint_radius<=0` + disables; otherwise the radius is ignored — JFA fills every uncovered texel + by nearest regardless.""" if inpaint_radius <= 0: return img01 - if torch.cuda.is_available(): - return _jfa_fill_gpu(img01, mask) - import cv2 - u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8) - u8 = cv2.inpaint(u8, ((~mask).astype(np.uint8)) * 255, int(inpaint_radius), cv2.INPAINT_TELEA) - if u8.ndim == 2: - u8 = u8[..., None] - return u8.astype(np.float32) / 255.0 + return _jfa_fill_gpu(img01, mask) + + +def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None): + """Uniformly fit a UV layout's bbox into [0,1] when it spills outside the unit + square (preserves chart aspect ratios; handles packers that overflow slightly). + No-op when the UVs are already in [0,1] — the normal case for official/xatlas + unwraps. NOT a UDIM de-tiler; warns if the span looks tiled. + + Deterministic from the input UVs alone, so the texture bake and + ApplyTextureToMesh both call it to agree on the exact UVs the texture was baked + against (the bake no longer emits the mesh, so apply must re-derive them). + + Returns float32 [N,2].""" + uv_np = uv_np.astype(np.float32) + uv_min = uv_np.min(axis=0) + uv_max = uv_np.max(axis=0) + out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001) + if not (normalize and out_of_unit): + return uv_np + extent = float((uv_max - uv_min).max()) + span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1])) + if span > 1.5 and log_prefix: + logging.warning( + f"{log_prefix} UV span {span:.2f} looks like a tiled/UDIM layout; " + f"uniform-fitting it into [0,1] will overlap tiles. Re-unwrap upstream instead.") + if extent > 0: + uv_np = ((uv_np - uv_min) / extent).astype(np.float32) + if log_prefix: + logging.info(f"{log_prefix} normalized UVs into [0,1] (uniform scale 1/{extent:.4f})") + return uv_np def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, - resolution, texture_size, inpaint_radius=3, - fast_unwrap=True, existing_uvs=None, - normalize_uvs=True, sample_mode="trilinear", - reference=None, pbar=None): + resolution, texture_size, uvs, inpaint_radius=3, + normalize_uvs=True, reference=None, pbar=None): """Bake a baseColor (+ optional metallicRoughness) texture for `vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each texel from the provided sparse colored voxel volume. - If `existing_uvs` (N, 2) is given, it is used directly and xatlas is - skipped — bakes onto the mesh's current UV layout without re-unwrapping. - Otherwise xatlas computes a fresh atlas (verts/faces may grow at seams). + `uvs` (N, 2) is the mesh's existing UV layout — baked onto directly (this + node never unwraps; connect a UV unwrap node upstream). It must be 1:1 with + `vertices`. Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr). - `fast_unwrap=True` configures xatlas with permissive chart options so it - finishes in a reasonable time on large meshes — at the cost of less even - UV distribution. Set False to use xatlas defaults (slow on >100k faces). - - Progress: drives a local tqdm over its 5 stages (unwrap → rasterize → + Progress: drives a local tqdm over its 5 stages (uvs → rasterize → back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is passed, ticks it once per stage too — so callers should size it as 5 per bake.""" @@ -845,111 +939,25 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, v_np = vertices.detach().cpu().numpy().astype(np.float32) f_np = faces.detach().cpu().numpy().astype(np.uint32) fcount = int(f_np.shape[0]) - t0 = time.perf_counter() - if existing_uvs is not None: - # Bake onto the mesh's current UVs — no xatlas, no seam-splitting. - uv_np = existing_uvs.detach().cpu().numpy().astype(np.float32) - if uv_np.shape[0] != v_np.shape[0]: - raise ValueError( - f"BakeTextureFromVoxel: existing UVs ({uv_np.shape[0]}) must be 1:1 " - f"with vertices ({v_np.shape[0]})." - ) - uv_min = uv_np.min(axis=0) - uv_max = uv_np.max(axis=0) - oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum()) - logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, " - f"{fcount} faces (xatlas skipped)") - logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] " - f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}") - out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001) - if normalize_uvs and out_of_unit: - # Uniform fit of the UV bbox into [0,1] (preserves chart aspect ratios). - # Handles packers that overflow the unit square slightly. NOT a UDIM - # de-tiler — a true multi-tile layout would get squashed; warn if the - # span is large enough to look like tiling. - extent = float((uv_max - uv_min).max()) - span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1])) - if span > 1.5: - logging.warning( - f"[BakeTextureFromVoxel] UV span {span:.2f} looks like a tiled/UDIM " - f"layout; uniform-fitting it into [0,1] will overlap tiles. " - f"Re-unwrap instead (use_existing_uvs=False)." - ) - if extent > 0: - uv_np = ((uv_np - uv_min) / extent).astype(np.float32) - logging.info(f"[BakeTextureFromVoxel] normalized UVs into [0,1] " - f"(uniform scale 1/{extent:.4f})") - new_verts, new_faces, new_uvs = v_np, f_np, uv_np - else: - import xatlas - if fcount > 300_000: - logging.warning( - f"[BakeTextureFromVoxel] mesh has {fcount} faces — xatlas chart " - f"decomposition is CPU-bound and may take many minutes. Consider " - f"decimating to under ~200k faces before baking." - ) - logging.info(f"[BakeTextureFromVoxel] xatlas unwrap: {v_np.shape[0]} verts, {fcount} faces") - if fast_unwrap and hasattr(xatlas, "Atlas"): - atlas = xatlas.Atlas() - atlas.add_mesh(v_np, f_np) - logging.info(f"[BakeTextureFromVoxel] add_mesh: {time.perf_counter() - t0:.1f}s") - gen_kwargs = {} - applied = [] - # ChartOptions: looser growth → larger / fewer charts → faster. - if hasattr(xatlas, "ChartOptions"): - co = xatlas.ChartOptions() - for attr, val in ( - ("max_iterations", 1), - ("max_cost", 8.0), - ("normal_deviation_weight", 1.0), - ("roundness_weight", 0.0), - ("straightness_weight", 0.0), - ("normal_seam_weight", 1.0), - ("texture_seam_weight", 0.0), - ("use_input_mesh_uvs", False), - ): - if hasattr(co, attr): - setattr(co, attr, val) - applied.append(f"chart.{attr}") - gen_kwargs["chart_options"] = co - # PackOptions.bruteForce defaults to True — tries many rotations per - # chart and is the single biggest contributor to pack time on small - # meshes. Off it loses ~5-15% packing efficiency but runs ~5× faster. - if hasattr(xatlas, "PackOptions"): - po = xatlas.PackOptions() - for attr, val in ( - ("bruteForce", False), - ("brute_force", False), # snake_case alias on some builds - ("create_image", False), - ("createImage", False), - ("padding", 2), - ): - if hasattr(po, attr): - setattr(po, attr, val) - applied.append(f"pack.{attr}") - gen_kwargs["pack_options"] = po - logging.info(f"[BakeTextureFromVoxel] options applied: {applied}") - tgen = time.perf_counter() - try: - atlas.generate(**gen_kwargs) - except TypeError as e: - logging.warning(f"[BakeTextureFromVoxel] generate(**kwargs) rejected ({e}); retrying with defaults") - atlas.generate() - logging.info(f"[BakeTextureFromVoxel] generate: {time.perf_counter() - tgen:.1f}s") - tget = time.perf_counter() - vmapping, indices, uvs = atlas[0] - logging.info(f"[BakeTextureFromVoxel] retrieve: {time.perf_counter() - tget:.1f}s") - else: - vmapping, indices, uvs = xatlas.parametrize(v_np, f_np) + # Bake onto the mesh's current UVs — no unwrap, no seam-splitting. + uv_np = uvs.detach().cpu().numpy().astype(np.float32) + if uv_np.shape[0] != v_np.shape[0]: + raise ValueError( + f"BakeTextureFromVoxel: UVs ({uv_np.shape[0]}) must be 1:1 " + f"with vertices ({v_np.shape[0]})." + ) + uv_min = uv_np.min(axis=0) + uv_max = uv_np.max(axis=0) + oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum()) + logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, " + f"{fcount} faces") + logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] " + f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}") + uv_np = _normalize_uvs_to_unit(uv_np, normalize_uvs, log_prefix="[BakeTextureFromVoxel] ") + new_verts, new_faces, new_uvs = v_np, f_np, uv_np - logging.info(f"[BakeTextureFromVoxel] xatlas total {time.perf_counter() - t0:.1f}s " - f"({vmapping.shape[0]} verts after seams)") - new_verts = v_np[vmapping] - new_faces = indices.astype(np.uint32) - new_uvs = uvs.astype(np.float32) - - _tick("unwrap") + _tick("uvs") t1 = time.perf_counter() position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size) @@ -968,7 +976,7 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, t2 = time.perf_counter() attrs = _sample_voxel_attrs_per_texel( - position_map, mask, voxel_coords, voxel_colors, resolution, mode=sample_mode, + position_map, mask, voxel_coords, voxel_colors, resolution, ) logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s " f"({attrs.shape[-1]} channels)") @@ -1022,12 +1030,11 @@ def _per_vertex_normals(verts_np, faces_np): def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution, - texture_size, views, blend_temperature=0.25, - inpaint_radius=3, fast_unwrap=True, existing_uvs=None, - normalize_uvs=True, sample_mode="trilinear"): + texture_size, views, uvs, blend_temperature=0.25, + inpaint_radius=3, normalize_uvs=True): """Bake a baseColor texture by projecting view photos onto the mesh. - Reuses bake_texture_from_voxel_fn for the xatlas unwrap + the nearest-voxel + Reuses bake_texture_from_voxel_fn for the UV-space bake + the nearest-voxel fallback colour, then overlays photo colour on every covered+visible texel: each texel's world position/normal is projected into each view, occlusion is resolved with a texel z-buffer, and the views are blended weighted by how @@ -1045,8 +1052,8 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol # Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end). out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn( vertices, faces, voxel_coords, voxel_colors, resolution=resolution, - texture_size=texture_size, inpaint_radius=0, fast_unwrap=fast_unwrap, - existing_uvs=existing_uvs, normalize_uvs=normalize_uvs, sample_mode=sample_mode) + texture_size=texture_size, uvs=uvs, inpaint_radius=0, + normalize_uvs=normalize_uvs) v_np = out_v.detach().cpu().numpy().astype(np.float32) f_np = out_f.detach().cpu().numpy().astype(np.uint32) @@ -1078,6 +1085,16 @@ def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resol return out_v, out_f, out_uvs, out_tex, voxel_mr +def _mr_channel(packed_mr, ch, ref): + """Pull one channel out of a packed glTF MR map (G=roughness at idx 1, B=metallic + at idx 2) as a 3-channel grayscale IMAGE [H,W,3] in [0,1]. Returns black sized + like `ref` when there's no MR map (non-PBR voxel field).""" + if packed_mr is None: + return torch.zeros_like(ref.float().cpu()) + m = packed_mr.float().clamp(0.0, 1.0).cpu() + return m[..., ch:ch + 1].expand(-1, -1, 3).contiguous() + + class BakeTextureFromVoxel(IO.ComfyNode): @classmethod def define_schema(cls): @@ -1089,12 +1106,12 @@ class BakeTextureFromVoxel(IO.ComfyNode): "Bakes PBR textures onto the mesh's existing UV layout by rasterizing it " "in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling " "the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node " - "(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Produces a " - "baseColor texture, plus a metallicRoughness texture when the voxel field " - "carries the full PBR set (6 channels). Returns a Mesh with `uvs`, `texture`, " - "and `metallic_roughness` attached — SaveGLB serializes them as real " - "baseColorTexture / metallicRoughnessTexture maps. UVs that spill outside " - "[0,1] are uniformly fit back into the unit square." + "(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Outputs the " + "baked maps as IMAGEs: base_color, plus metallic and roughness as separate " + "grayscale maps (both black when the voxel field has no PBR set). " + "Preview/save/post-process them, then feed them to ApplyTextureToMesh (with " + "the SAME mesh) to attach them for SaveGLB. UVs that spill outside [0,1] are " + "uniformly fit into the unit square." ), inputs=[ IO.Mesh.Input("mesh"), @@ -1108,7 +1125,11 @@ class BakeTextureFromVoxel(IO.ComfyNode): "o_voxel.to_glb step that removes faceted/pixelized baking on coarse " "meshes. Pure scipy+torch, no extra deps.")), ], - outputs=[IO.Mesh.Output("mesh")], + outputs=[ + IO.Image.Output(display_name="base_color"), + IO.Image.Output(display_name="metallic"), + IO.Image.Output(display_name="roughness"), + ], ) @classmethod @@ -1132,7 +1153,7 @@ class BakeTextureFromVoxel(IO.ComfyNode): batch_idx = coords[:, 0].long() voxel_xyz = coords[:, 1:] mesh_batch_size = int(mesh.vertices.shape[0]) - out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], [] + out_tex, out_mr = [], [] # 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items # tick all 5 so the bar stays aligned. pbar = comfy.utils.ProgressBar(mesh_batch_size * 5) @@ -1150,28 +1171,25 @@ class BakeTextureFromVoxel(IO.ComfyNode): if reference_mesh is not None: rv_i, rf_i, _ = get_mesh_batch_item(reference_mesh, i) ref_i = (rv_i, rf_i) - bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( + _bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn( v_i, f_i, item_coords, item_colors, resolution=resolution, texture_size=texture_size, - inpaint_radius=inpaint_radius, - existing_uvs=ev_i, reference=ref_i, pbar=pbar, + uvs=ev_i, inpaint_radius=inpaint_radius, + reference=ref_i, pbar=pbar, ) - out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu) out_tex.append(bt); out_mr.append(bmr) - if not out_verts: - return IO.NodeOutput(mesh) - # Local pack_variable_mesh_batch doesn't take uvs/texture; build the - # packed mesh ourselves so we can attach both. UVs are 1:1 with verts. - packed = pack_variable_mesh_batch(out_verts, out_faces) - max_v = packed.vertices.shape[1] - packed_uvs = out_uvs[0].new_zeros((len(out_uvs), max_v, 2)) - for i, u in enumerate(out_uvs): - packed_uvs[i, :u.shape[0]] = u - packed.uvs = packed_uvs - packed.texture = torch.stack(out_tex, dim=0) - if all(mr is not None for mr in out_mr): - packed.metallic_roughness = torch.stack(out_mr, dim=0) - return IO.NodeOutput(packed) + if not out_tex: + # Every item skipped (degenerate) — emit one black map so the IMAGE + # outputs stay valid. + black = torch.zeros((1, texture_size, texture_size, 3)) + return IO.NodeOutput(black, black, black) + # All maps are texture_size² — stack into [B,H,W,3] IMAGE batches. The + # packed MR (G=roughness, B=metallic) is split into separate grayscale + # maps; both black where the voxel field carried no PBR set. + base_img = torch.stack([t.float().clamp(0.0, 1.0).cpu() for t in out_tex], dim=0) + metallic_img = torch.stack([_mr_channel(m, 2, out_tex[0]) for m in out_mr], dim=0) + roughness_img = torch.stack([_mr_channel(m, 1, out_tex[0]) for m in out_mr], dim=0) + return IO.NodeOutput(base_img, metallic_img, roughness_img) # Single-item path. v0 = mesh.vertices.squeeze(0) @@ -1181,19 +1199,16 @@ class BakeTextureFromVoxel(IO.ComfyNode): if reference_mesh is not None: ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0)) pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn) - bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( + _bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn( v0, f0, coords, colors, resolution=resolution, texture_size=texture_size, - inpaint_radius=inpaint_radius, - existing_uvs=ev0, reference=ref0, pbar=pbar, + uvs=ev0, inpaint_radius=inpaint_radius, + reference=ref0, pbar=pbar, ) - out_mesh = Types.MESH( - vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0), - uvs=bu.unsqueeze(0), texture=bt.unsqueeze(0), - ) - if bmr is not None: - out_mesh.metallic_roughness = bmr.unsqueeze(0) - return IO.NodeOutput(out_mesh) + base_img = bt.float().clamp(0.0, 1.0).cpu().unsqueeze(0) + metallic_img = _mr_channel(bmr, 2, bt).unsqueeze(0) + roughness_img = _mr_channel(bmr, 1, bt).unsqueeze(0) + return IO.NodeOutput(base_img, metallic_img, roughness_img) class MeshTextureToImage(IO.ComfyNode): @@ -1247,6 +1262,73 @@ class MeshTextureToImage(IO.ComfyNode): return IO.NodeOutput(base, mr, metallic, roughness) +class ApplyTextureToMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ApplyTextureToMesh", + display_name="Apply Texture to Mesh", + category="latent/3d", + description=( + "Attaches baked texture IMAGEs to a mesh's existing UV layout so SaveGLB " + "serializes them as baseColorTexture / metallicRoughnessTexture maps. Pairs " + "with BakeTextureFromVoxel: feed it the SAME mesh you baked from, plus that " + "node's base_color (and optionally metallic / roughness grayscale maps) — the " + "UVs must match the ones the texture was baked against, so don't re-unwrap in " + "between. metallic and roughness are repacked into the glTF MR map " + "(G=roughness, B=metallic); leave them unconnected for non-PBR meshes (a " + "missing metallic defaults to 0, a missing roughness to 1). Lets you preview / " + "upscale / edit the baked maps before applying them." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Image.Input("base_color"), + IO.Image.Input("metallic", optional=True), + IO.Image.Input("roughness", optional=True), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh, base_color, metallic=None, roughness=None): + mesh_uvs = getattr(mesh, "uvs", None) + if mesh_uvs is None: + raise ValueError( + "ApplyTextureToMesh: mesh has no UVs. Connect the same UV-unwrapped mesh " + "you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and " + "never unwraps).") + + # Re-derive the exact UVs the bake rasterized against — it uniformly fits + # out-of-[0,1] layouts into the unit square, so apply the identical + # deterministic transform here (per batch item, over each item's real verts). + if mesh_uvs.ndim == 3: + new_uvs = mesh_uvs.clone() + for i in range(mesh_uvs.shape[0]): + v_i, _f_i, _ = get_mesh_batch_item(mesh, i) + n = v_i.shape[0] + norm = _normalize_uvs_to_unit(mesh_uvs[i, :n].detach().cpu().numpy()) + new_uvs[i, :n] = torch.from_numpy(norm).to(new_uvs) + else: + norm = _normalize_uvs_to_unit(mesh_uvs.detach().cpu().numpy()) + new_uvs = torch.from_numpy(norm).to(mesh_uvs) + + out_mesh = copy.copy(mesh) + out_mesh.uvs = new_uvs + out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu() + if metallic is not None or roughness is not None: + # Repack separate grayscale maps into glTF MR: R unused, G=roughness, + # B=metallic. Size defaults off whichever map is connected; a missing + # channel falls back to a sensible scalar (metal 0 / rough 1). + prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu() + B, H, W, _ = prov.shape + rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1] + if roughness is not None else torch.ones((B, H, W, 1))) + metal_ch = (metallic.float().clamp(0.0, 1.0).cpu()[..., 0:1] + if metallic is not None else torch.zeros((B, H, W, 1))) + out_mesh.metallic_roughness = torch.cat([torch.zeros((B, H, W, 1)), rough_ch, metal_ch], dim=-1) + return IO.NodeOutput(out_mesh) + + def paint_mesh_default_colors(mesh): out_mesh = copy.copy(mesh) vertex_count = mesh.vertices.shape[1] @@ -1397,14 +1479,14 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter): for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()): nm_share_breakdown.append(f"{n} edges×{c}faces") - logging.info(f"[FillHolesV2 diag] V={V} F={F} | " + logging.info(f"[FillHoles diag] V={V} F={F} | " f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} " f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})") if nm_share_breakdown: - logging.info(f"[FillHolesV2 diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}") + logging.info(f"[FillHoles diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}") if n_boundary == 0: - logging.info("[FillHolesV2 diag] no boundary edges → no cycles to fill") + logging.info("[FillHoles diag] no boundary edges → no cycles to fill") return # Walk components same as production path (bidir-prop, by-vertex count). @@ -1459,15 +1541,15 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter): cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component - logging.info(f"[FillHolesV2 diag] components={L} " + logging.info(f"[FillHoles diag] components={L} " f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} " f"non-simple={int(is_other.sum().item())}") - logging.info(f"[FillHolesV2 diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})") - logging.info(f"[FillHolesV2 diag] actually kept={int(actually_kept.sum().item())} " + logging.info(f"[FillHoles diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})") + logging.info(f"[FillHoles diag] actually kept={int(actually_kept.sum().item())} " f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} " f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}") - logging.info(f"[FillHolesV2 diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)") - logging.info(f"[FillHolesV2 diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts") + logging.info(f"[FillHoles diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)") + logging.info(f"[FillHoles diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts") # Cycle vert-count distribution if is_cycle.any(): @@ -1485,12 +1567,12 @@ def _fill_holes_v2_diagnostic(verts, faces, max_perimeter): elif s <= 20: buckets["11-20"] += n elif s <= 50: buckets["21-50"] += n else: buckets["51+"] += n - logging.info(f"[FillHolesV2 diag] cycle vert-count buckets: {buckets}") + logging.info(f"[FillHoles diag] cycle vert-count buckets: {buckets}") if is_cycle.any(): cycle_perims = perim[is_cycle].cpu().tolist() head = sorted(cycle_perims, reverse=True)[:10] - logging.info(f"[FillHolesV2 diag] top-10 cycle perimeters: " + logging.info(f"[FillHoles diag] top-10 cycle perimeters: " f"{['%.4f' % p for p in head]}") @@ -1804,10 +1886,10 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi n_escalations += 1 if total_welded > 0 or n_escalations > 0: tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else "" - logging.info(f"[FillHolesV2] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}") + logging.info(f"[FillHoles] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}") if ratio >= WELDED_THRESHOLD: logging.warning( - f"[FillHolesV2] even at weld epsilon_rel={WELD_CAP} the mesh stays " + f"[FillHoles] even at weld epsilon_rel={WELD_CAP} the mesh stays " f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has " f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream " f"(decimate node settings) or run WeldVertices manually with a larger epsilon." @@ -2077,6 +2159,23 @@ def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4): return offset_verts, f_unwelded, None +def _fmt_count(n) -> str: + """Compact human-readable integer for node status lines, e.g. 853, 12.3K, 1.23M.""" + n = int(n) + if n >= 1_000_000: + return f"{n / 1_000_000:.2f}".rstrip("0").rstrip(".") + "M" + if n >= 1_000: + return f"{n / 1_000:.1f}".rstrip("0").rstrip(".") + "K" + return str(n) + + +def _fmt_face_change(n_in, n_out) -> str: + """'faces: 1.23M → 200K (-84%)' — the count delta for decimate/remesh status.""" + n_in, n_out = int(n_in), int(n_out) + pct = f" ({(n_out - n_in) / n_in * 100:+.0f}%)" if n_in else "" + return f"faces: {_fmt_count(n_in)} → {_fmt_count(n_out)}{pct}" + + class DecimateMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -2168,7 +2267,7 @@ class DecimateMesh(IO.ComfyNode): # Send progress text to display the face reduction on the node if cls.hidden.unique_id: PromptServer.instance.send_progress_text( - f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id) + _fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id) return result @@ -2306,7 +2405,7 @@ class RemeshMesh(IO.ComfyNode): # Send progress text to display the face change on the node if cls.hidden.unique_id: PromptServer.instance.send_progress_text( - f"faces: {counts['in']} -> {counts['out']}", cls.hidden.unique_id) + _fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id) return result @@ -2546,7 +2645,8 @@ class UnwrapMesh(IO.ComfyNode): if cls.hidden.unique_id: PromptServer.instance.send_progress_text( - f"UV: {out_v[0].shape[0]}v / {out_f[0].shape[0]}f, atlas ~{resolution}px", + f"UV: {_fmt_count(out_v[0].shape[0])} verts / {_fmt_count(out_f[0].shape[0])} faces" + f" · atlas ~{resolution}px", cls.hidden.unique_id) return IO.NodeOutput(out_mesh) @@ -2740,36 +2840,12 @@ class FillHoles(IO.ComfyNode): node_id="FillHoles", display_name="Fill Holes", category="latent/3d", - description="Fills holes in a mesh up to a maximum perimeter threshold.", - inputs=[ - IO.Mesh.Input("mesh"), - IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001, - tooltip="Maximum hole perimeter to fill. Set to 0 to disable."), - ], - outputs=[IO.Mesh.Output("mesh")], - ) - - @classmethod - def execute(cls, mesh, max_perimeter): - def _fn(v, f, c): - if max_perimeter > 0: - v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter) - return v, f, c - return _process_mesh_batch(mesh, _fn) - -class FillHolesV2(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="FillHolesV2", - display_name="Fill Holes (v2)", - category="latent/3d", description=( - "GPU-vectorised hole-filling via directed-half-edge pointer-doubling. " - "Drop-in alternative to FillHoles for comparison: same max_perimeter " - "cutoff and fan-from-centroid triangulation, but no Python loop, " - "auto-correct winding from face direction, and centroid colors are " - "averaged from the loop instead of left zero." + "Fills holes in a mesh up to a maximum perimeter threshold, preserving " + "the existing geometry/UVs (only patch triangles are added). GPU-vectorised " + "via directed-half-edge pointer-doubling: no Python loop, auto-correct " + "winding from face direction, and centroid colors are averaged from the hole " + "loop. Falls back to a CPU walker on non-CUDA devices." ), inputs=[ IO.Mesh.Input("mesh"), @@ -2947,7 +3023,6 @@ class PostProcessMeshExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FillHoles, - FillHolesV2, WeldVertices, DecimateMesh, RemeshMesh, @@ -2956,6 +3031,7 @@ class PostProcessMeshExtension(ComfyExtension): PaintMesh, BakeTextureFromVoxel, MeshTextureToImage, + ApplyTextureToMesh, MergeMeshes, ] diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ceace5b7d..b1bf6807c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,18 +1,15 @@ from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO, Types, io +from comfy_api.latest import ComfyExtension, IO, Types, UI, io from comfy.ldm.trellis2.vae import SparseTensor -from comfy.ldm.trellis2.model import ( - _build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats, -) +from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict + from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch import comfy.model_management import comfy.utils import folder_paths -from comfy.ldm.trellis2 import sampling_preview from PIL import Image import logging -import os import numpy as np import math import torch @@ -21,89 +18,6 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") NAFModel = io.Custom("NAF_MODEL") -# Texture latent -> base-color calibration for the per-step preview -def _tex_rgb_factors_path(): - return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt") - - -def _pool_albedo_to_input(in_coords, out_coords, out_colors): - in_sp = in_coords[:, 1:4].long() - out_sp = out_coords[:, 1:4].long() - in_b = in_coords[:, 0].long() - out_b = out_coords[:, 0].long() - in_res = int(in_sp.max().item()) + 1 - out_res = int(out_sp.max().item()) + 1 - parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1) - R = in_res - in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2] - par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2] - order = torch.argsort(in_flat) - in_sorted = in_flat[order] - pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1) - matched = in_sorted[pos] == par_flat - in_idx = order[pos][matched] - cols = out_colors[matched].float() - N = in_coords.shape[0] - csum = cols.new_zeros((N, 3)) - ccount = cols.new_zeros((N, 1)) - csum.index_add_(0, in_idx, cols) - ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype)) - valid = ccount[:, 0] > 0 - albedo = torch.zeros_like(csum) - albedo[valid] = csum[valid] / ccount[valid] - return albedo, valid - - -def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords): - """Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish.""" - try: - dev = out_colors.device - in_latent = in_latent.to(dev) - in_coords = in_coords.to(dev) - out_coords = out_coords.to(dev) - albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors) - X = in_latent[valid].float().cpu() - Y = albedo[valid].float().cpu() - if X.shape[0] < 64: - return - Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1] - A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1] - B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3] - - path = _tex_rgb_factors_path() - if os.path.exists(path): - try: - prev = torch.load(path, map_location="cpu") - A_run = A_run + prev["A"] - B_run = B_run + prev["B"] - except Exception: - pass - os.makedirs(os.path.dirname(path), exist_ok=True) - torch.save({"A": A_run, "B": B_run}, path) - - eye = torch.eye(A_run.shape[0]) - WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3] - W, b = WB[:-1].contiguous(), WB[-1].contiguous() - sampling_preview.set_tex_rgb(W, b) - except Exception as e: - logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}") - - -def _load_tex_rgb_factors(): - try: - path = _tex_rgb_factors_path() - if os.path.exists(path): - d = torch.load(path, map_location="cpu") - eye = torch.eye(d["A"].shape[0]) - WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"]) - sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous()) - except Exception as e: - logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}") - - -_load_tex_rgb_factors() - - def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) if len(sample_shape) == 5: @@ -271,7 +185,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): if coord_counts is None: samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) samples = shape_norm(samples.to(device), coords.to(device)) - mesh, subs = trellis_vae.decode_shape_slat(samples, resolution) + mesh, subs = trellis_vae.decode_shape_slat(samples.to(vae.vae_dtype), resolution) else: split_items = split_batched_sparse_latent(samples, coords, coord_counts) mesh = [] @@ -280,7 +194,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): coords_i = coords_i.to(device).clone() coords_i[:, 0] = 0 sample_i = shape_norm(feats_i.to(device), coords_i) - mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution) + mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i.to(vae.vae_dtype), resolution) mesh.append(mesh_i[0]) subs_per_sample.append(subs_i) @@ -338,14 +252,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = samples["samples"] samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) samples = samples.to(device) - cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration - cal_in_coords = coords std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = samples * std + mean - voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides) + voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides) # Keep all decoded channels. The texture VAE emits 6: base_color (0:3), # metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color # consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full @@ -353,12 +265,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): color_feats = voxel.feats voxel_coords = voxel.coords - # Calibrate the latent->base_color map for the per-step texture preview. - # Done here while input coords and voxel_coords share the model frame - # (before the z_up remap below) and on the real decoded albedo. - if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3: - _calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords) - if coord_resolution is not None: tex_resolution = int(coord_resolution) * 16 elif voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: @@ -416,7 +322,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): decoded_batches = [] for start in range(0, sample_tensor.shape[0], batch_number): sample_chunk = sample_tensor[start:start + batch_number].to(load_device) - decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0) + decoded_batches.append(shape_vae.decode_structure(sample_chunk.to(vae.vae_dtype)) > 0) decoded = torch.cat(decoded_batches, dim=0) current_res = decoded.shape[2] @@ -491,7 +397,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): shape_latent["samples"], shape_latent["coords"], coord_counts, ) slat = shape_norm(feats.to(device), coords_512.to(device)) - sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)] + sample_hr_coords = [shape_vae.upsample_shape(slat.to(vae.vae_dtype), upsample_times=4)] else: items = split_batched_sparse_latent( shape_latent["samples"], shape_latent["coords"], coord_counts, @@ -501,7 +407,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): coords_i = coords_i.to(device).clone() coords_i[:, 0] = 0 slat_i = shape_norm(feats_i.to(device), coords_i) - sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4)) + sample_hr_coords.append(shape_vae.upsample_shape(slat_i.to(vae.vae_dtype), upsample_times=4)) # Resolution search — cache the final iteration's quantized unique tensors # so we don't recompute .unique() per sample after picking hr_resolution. @@ -977,8 +883,10 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): if pad_l or pad_t or pad_r or pad_b: img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0) mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0) - crop_x1 += pad_l; crop_x2 += pad_l - crop_y1 += pad_t; crop_y2 += pad_t + crop_x1 += pad_l + crop_x2 += pad_l + crop_y1 += pad_t + crop_y2 += pad_t cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2] cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2] @@ -1100,7 +1008,7 @@ class Pixal3DConditioning(IO.ComfyNode): cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32) dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32) scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32) - T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32) + T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32) proj_pack = { "stages": { @@ -1119,15 +1027,6 @@ class Pixal3DConditioning(IO.ComfyNode): } # global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024. - # proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet - # hints etc. live); the sampler auto-promotes it to a model.forward kwarg via - # Trellis2.extra_conds. The same pack object is shared between pos/neg — - # CONDConstant.can_concat sees them equal and concats to a single dict, then - # Trellis2.forward zeros proj for the uncond slots via cond_or_uncond. - # Pre-compute the SS-stage proj features (dense 16³ grid) once here — the - # shape/texture stages do their own computes in their respective stage nodes. - # proj_pack lives on intermediate (CPU); force the compute onto cuda so - # the bilinear-sampling step doesn't run on CPU. ss_proj_feats = compute_stage_proj_feats( proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size, device=torch_device, @@ -1278,12 +1177,6 @@ class Pixal3DAlignObject(IO.ComfyNode): q_mean = Q.mean(dim=0, keepdim=True) P_c = P - p_mean Q_c = Q - q_mean - # Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is - # noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper - # acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c) - # gets multiplied by cos(yaw) and shrinks the object. Using - # sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of - # rotation; translation still positions the mesh at MoGe's centroid. p_var = (P_c * P_c).sum().clamp(min=1e-8) q_var = (Q_c * Q_c).sum() scale = float(torch.sqrt(q_var / p_var).item()) @@ -1326,74 +1219,69 @@ class LoadNAFModel(IO.ComfyNode): return IO.NodeOutput(model) -class CFGGuidanceInterval(IO.ComfyNode): - """Generic model patch: apply CFG only during [start_percent, end_percent] of - the sampling schedule. Outside that window, skip the uncond computation and - collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's - guidance_interval_mixin, but lives at the sampler level (via - sampler_calc_cond_batch_function) so it works for any model. +class GetMeshInfo(IO.ComfyNode): + """Report vertex / face counts and attributes for a MESH, displayed on the + node (and as a string output). Counts are comma-formatted since meshes can + run into the millions of faces. Passes the mesh through unchanged.""" - Percents use ComfyUI's standard convention: 0.0 = start of sampling - (max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma - is done via model_sampling.percent_to_sigma so the window is portable - across schedules (flow / EDM / discrete) and shift settings. - - Defaults are full-range (no bypass). Upstream Trellis2 / Pixal3D - pipeline.json sets guidance_interval=[0.6, 1.0] (upstream t-space) on the - SS and shape samplers — CFG active only in the first 40% of sampling. - Wire (start_percent=0.0, end_percent=0.4) on the SS / shape KSamplers to - match. Texture defaults to cfg=1 so the node is moot there.""" @classmethod def define_schema(cls): return IO.Schema( - node_id="CFGGuidanceInterval", - category="model_patches/sampling", - inputs=[ - IO.Model.Input("model"), - IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, - tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."), - IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, - tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."), + node_id="GetMeshInfo", + display_name="Get Mesh Info", + category="latent/3d", + inputs=[IO.Mesh.Input("mesh")], + outputs=[ + IO.Mesh.Output(display_name="mesh"), + IO.String.Output(display_name="info"), ], - outputs=[IO.Model.Output()], ) + @staticmethod + def _fmt(n: int) -> str: + # e.g. 1234567 -> "1,234,567 (1.23M)"; small numbers stay plain. + s = f"{n:,}" + if n >= 1_000_000: + s += f" ({n / 1_000_000:.2f}M)" + elif n >= 10_000: + s += f" ({n / 1_000:.1f}K)" + return s + @classmethod - def execute(cls, model, start_percent, end_percent): - import comfy.samplers + def execute(cls, mesh): + B = mesh.vertices.shape[0] + # Honour per-item counts when the batch is zero-padded; else use the row sizes. + if mesh.vertex_counts is not None: + v_counts = [int(x) for x in mesh.vertex_counts.tolist()] + f_counts = [int(x) for x in mesh.face_counts.tolist()] + else: + v_counts = [int(mesh.vertices.shape[1])] * B + f_counts = [int(mesh.faces.shape[1])] * B - model_sampling = model.get_model_object("model_sampling") - # percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max, - # percent=1 -> sigma_min. So start_percent < end_percent in user space - # means sigma_start > sigma_end. "Inside the window" is sigma in - # [sigma_end, sigma_start]. - sigma_start = float(model_sampling.percent_to_sigma(start_percent)) - sigma_end = float(model_sampling.percent_to_sigma(end_percent)) + attrs = [] + for name in ("uvs", "vertex_colors", "normals", "texture", "metallic_roughness"): + t = getattr(mesh, name, None) + if t is not None: + if name in ("texture", "metallic_roughness"): + attrs.append(f"{name} {int(t.shape[-3])}×{int(t.shape[-2])}") # H×W + else: + attrs.append(name) - def calc_cond_batch_with_interval(args): - sigma_val = args["sigma"][0].item() - conds = args["conds"] - input_x = args["input"] - timestep = args["sigma"] - model_ref = args["model"] - model_opts = args["model_options"] + lines = [] + if B > 1: + lines.append(f"Batch: {B} meshes") + lines.append(f"Vertices: {cls._fmt(sum(v_counts))} total") + lines.append(f"Faces: {cls._fmt(sum(f_counts))} total") + for i in range(B): + lines.append(f" [{i}] {v_counts[i]:>10,} verts · {f_counts[i]:>10,} faces") + else: + lines.append(f"Vertices: {cls._fmt(v_counts[0])}") + lines.append(f"Faces: {cls._fmt(f_counts[0])}") + lines.append(f"Attributes: {', '.join(attrs) if attrs else 'none'}") - # conds is typically [cond, uncond]; uncond may be None when ComfyUI's - # global cfg=1 optimization has already pruned it. - cond = conds[0] - uncond = conds[1] if len(conds) > 1 else None - inside = sigma_end <= sigma_val <= sigma_start - - if uncond is None or inside: - return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts) - # Outside the window: compute cond only, mirror it into the uncond slot - # so the downstream cfg_function collapses to `cond` (effective cfg=1). - out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts) - return [out[0], out[0]] - - m = model.clone() - m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval - return IO.NodeOutput(m) + info = "\n".join(lines) + logging.info("[GetMeshInfo]\n%s", info) + return IO.NodeOutput(mesh, info, ui=UI.PreviewText(info)) class Trellis2Extension(ComfyExtension): @@ -1411,7 +1299,7 @@ class Trellis2Extension(ComfyExtension): VaeDecodeShapeTrellis, VaeDecodeStructureTrellis2, Trellis2UpsampleStage, - CFGGuidanceInterval, + GetMeshInfo, ]