diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 2c1605759..b7f1f3f0b 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -17,9 +17,10 @@ from comfy_extras.mesh3d.uv_unwrap import parameterize as _uv_param from comfy_extras.mesh3d.uv_unwrap import pack as _uv_pack import warnings import logging -import scipy +from tqdm import tqdm from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +from scipy.spatial import cKDTree def get_mesh_batch_item(mesh, index): if hasattr(mesh, "vertex_counts") and mesh.vertex_counts is not None: @@ -72,15 +73,12 @@ def pack_variable_mesh_batch(vertices, faces, colors=None): def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): - """ - Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. - """ + """Paint a mesh using nearest-neighbor colors from a sparse voxel field.""" device = comfy.model_management.vae_offload_device() origin = torch.tensor([-0.5, -0.5, -0.5], device=device) voxel_size = 1.0 / resolution - # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) voxel_colors = voxel_colors.to(device) @@ -88,19 +86,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): voxel_pos_np = voxel_pos.numpy() verts_np = verts.numpy() - tree = scipy.spatial.cKDTree(voxel_pos_np) - - # nearest neighbour k=1 + tree = cKDTree(voxel_pos_np) _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] - # Voxel field may carry the full PBR set (base_color, metallic, roughness, - # alpha); vertex colors only use base_color RGB. + # Voxel field may carry full PBR; vertex colors use only base_color RGB. if v_colors.shape[-1] > 3: v_colors = v_colors[:, :3] - # to [0, 1] srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1) # to Linear RGB (required for GLTF) @@ -120,10 +114,7 @@ class PaintMesh(IO.ComfyNode): node_id="PaintMesh", display_name="Paint Mesh", category="latent/3d", - description=( - "Paints the mesh using colors from the input voxel field by matching each vertex " - "to the nearest voxel color." - ), + description="Paints each mesh vertex with its nearest voxel color from the input voxel field.", inputs=[ IO.Mesh.Input("mesh"), IO.Voxel.Input("voxel_colors") @@ -177,24 +168,15 @@ class PaintMesh(IO.ComfyNode): return IO.NodeOutput(out_mesh) -# ============================================================================= -# Texture baking from sparse voxel volume. -# -# Pipeline: take the mesh's existing UVs → OpenGL UV-space rasterize to position -# map → nearest-voxel color sample per texel → GPU Jump-Flood fill UV seams → -# attach texture + UVs to the Mesh for SaveGLB to serialize. Unwrapping is done -# upstream (Trellis2OfficialUnwrap / TorchXatlasUVWrap); this path never unwraps. -# -# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles -# GLFW / EGL / OSMesa backend selection). -# ============================================================================= +# Texture baking from sparse voxel volume: existing UVs → OpenGL UV-space +# rasterize → per-texel voxel sample → JFA seam fill → attach to mesh for SaveGLB. +# Never unwraps (done upstream). GL context via nodes_glsl.GLContext. _GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache" def _gl_compile_program(gl, vert_src: str, frag_src: str): - """Compile and link a minimal vert+frag GL program. Caller owns the GLuint - and must glDeleteProgram when done.""" + """Compile+link a vert+frag GL program (caller glDeleteProgram).""" def _check_shader(s, kind): if not gl.glGetShaderiv(s, gl.GL_COMPILE_STATUS): log = gl.glGetShaderInfoLog(s).decode(errors="replace") @@ -222,8 +204,8 @@ def _gl_compile_program(gl, vert_src: str, frag_src: str): return prog -# Position-passthrough shader. Vertex maps UV → clip space; fragment outputs the -# interpolated world-space vertex position (with alpha=1 marking valid texels). +# Position-passthrough: vertex maps UV → clip space; fragment outputs interpolated +# world-space position (alpha=1 marks valid texels). _BAKE_VERT_SRC = """ #version 330 core layout (location = 0) in vec3 a_pos; @@ -246,20 +228,15 @@ void main() { def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): - """Rasterize unwrapped mesh in UV space; return (position_map, mask). - position_map: (H, W, 3) float32 — interpolated 3D position per texel. - mask: (H, W) bool — valid (covered) texels. - - Uses comfy_extras.nodes_glsl.GLContext, which lazily picks GLFW/EGL/OSMesa.""" + """Rasterize unwrapped mesh in UV space. Returns (position_map [H,W,3] float32, + mask [H,W] bool covered).""" from comfy_extras.nodes_glsl import GLContext, _import_opengl GLContext() # ensure backend is initialized + current gl = _import_opengl() - # PyOpenGL's high-level wrappers for the buffer/draw/readback functions - # store array refs in OpenGL.contextdata, which on EGL contexts triggers - # "Attempt to retrieve context when no valid context". Use the raw C - # entry points (OpenGL.raw.*) instead — they skip the bookkeeping. + # PyOpenGL high-level wrappers store array refs in OpenGL.contextdata, which on + # EGL contexts errors ("no valid context"); use raw C entry points instead. import ctypes as _ctypes from OpenGL.raw.GL.VERSION.GL_1_1 import ( glReadPixels as _raw_glReadPixels, @@ -272,7 +249,7 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): H = W = int(texture_size) fbo = color_tex = vbo = ibo = vao = prog = None try: - # Interleaved [pos.x, pos.y, pos.z, uv.x, uv.y] per vertex (stride=20 bytes). + # Interleaved [pos.xyz, uv.xy] per vertex (stride=20). verts32 = np.ascontiguousarray(verts_np, dtype=np.float32) uvs32 = np.ascontiguousarray(uvs_np, dtype=np.float32) faces32 = np.ascontiguousarray(faces_np, dtype=np.uint32) @@ -321,16 +298,12 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): _raw_glDrawElements(gl.GL_TRIANGLES, int(faces32.size), gl.GL_UNSIGNED_INT, None) gl.glFinish() - # Pre-allocate readback buffer and pass it as a pointer so PyOpenGL - # doesn't try to allocate one through its array-handler machinery. + # Pre-allocated readback buffer passed as a pointer (skips PyOpenGL alloc). arr = np.empty((H, W, 4), dtype=np.float32) _raw_glReadPixels(0, 0, W, H, gl.GL_RGBA, gl.GL_FLOAT, arr.ctypes.data_as(_ctypes.c_void_p)) - # Do NOT flipud here. Our shader places UV(0,0) at FBO bottom-left - # (clip(-1,-1)), and glReadPixels returns bottom-row-first, so arr[0] - # already holds the UV v=0 data. glTF samples PNG with row 0 = upper-left - # = UV v=0, so storing arr as-is gives a consistent mapping. Flipping - # would invert V and make every sample come from the wrong row. + # Do NOT flipud: shader puts UV(0,0) at FBO bottom-left and glReadPixels + # returns bottom-row-first, so arr[0] is UV v=0 — matches glTF PNG row 0. position_map = arr[..., :3] mask = arr[..., 3] > 0.5 return position_map, mask @@ -353,31 +326,21 @@ def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): - """Normalized trilinear interpolation of a SPARSE voxel attribute field. - - The official o_voxel.to_glb trilinear-samples a *dense* attribute volume; here - the field is sparse (only surface voxels carry values), so a plain trilinear - would bleed zeros from empty cells. Instead we accumulate, per query, only the - occupied corners among the 8 surrounding voxels and renormalize by their - weights — i.e. trilinear over the occupied subset. Voxel centres sit at integer - coords c with world position c/resolution - 0.5. - - Returns (vals [K, C] float64, ok [K] bool). `ok` is False where none of the 8 - corners is occupied (caller falls back to nearest there).""" + """Normalized trilinear over a SPARSE voxel field (only occupied corners of the 8, + renormalized; matches official o_voxel.to_glb but without dense-volume zero-bleed). + Returns (vals [K,C] float64, ok [K] bool); ok=False where no corner is occupied.""" R = int(resolution) origin = -0.5 voxel_size = 1.0 / R - # Cell-CENTER convention: voxel coord c sits at world origin + (c+0.5)*voxel_size, - # matching the official flex_gemm grid_sample_3d (its trilinear weight centers - # integer coord c at query c+0.5). The `- 0.5` puts integer gc on voxel centres - # so the 8 trilinear corners bracket the query correctly. Omitting it samples - # half a voxel toward the corner — colour bleed at boundaries / thin features. - gc = (positions.astype(np.float64) - origin) / voxel_size - 0.5 # continuous voxel-index coords - base = np.floor(gc).astype(np.int64) # [K,3] lower corner - frac = gc - base # [K,3] in [0,1) + # Cell-CENTER convention: coord c sits at origin+(c+0.5)*voxel_size (matches + # official grid_sample_3d); the -0.5 puts integer gc on centres so the 8 corners + # bracket the query (omitting it bleeds colour at boundaries/thin features). + gc = (positions.astype(np.float64) - origin) / voxel_size - 0.5 + base = np.floor(gc).astype(np.int64) + frac = gc - base vc = voxel_coords_np.astype(np.int64) - occ_keys = (vc[:, 0] * R + vc[:, 1]) * R + vc[:, 2] # linear key per occupied voxel + occ_keys = (vc[:, 0] * R + vc[:, 1]) * R + vc[:, 2] order = np.argsort(occ_keys) occ_sorted = occ_keys[order] @@ -398,9 +361,9 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): key = (cx * R + cy) * R + cz ins = np.clip(np.searchsorted(occ_sorted, key), 0, len(occ_sorted) - 1) matched = inb & (occ_sorted[ins] == key) - idx = order[ins] # original voxel index (garbage where !matched) + idx = order[ins] # garbage where !matched w = np.where(matched, wx * wy * wz, 0.0)[:, None] - acc += w * color_np[idx] # w=0 cancels the garbage rows + acc += w * color_np[idx] # w=0 cancels garbage rows wsum += w ok = wsum[:, 0] > 1e-8 vals = np.zeros((K, C), dtype=np.float64) @@ -409,12 +372,7 @@ def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution): def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolution): - """GPU port of `_trilinear_sample_sparse` — same normalized-over-occupied-corners - trilinear, but the per-texel 8-corner accumulation runs on CUDA via sorted-key - `searchsorted` instead of NumPy float64. This is the bake hot path (millions of - covered texels × 8 corners), so the CPU version dominates runtime; the GPU port - is ~identical numerically and 10-50× faster. Returns (vals [K,C] float32, ok - [K] bool), matching the NumPy signature.""" + """GPU port of `_trilinear_sample_sparse`. Returns (vals [K,C] float32, ok [K] bool).""" dev = comfy.model_management.get_torch_device() R = int(resolution) origin = -0.5 @@ -424,8 +382,7 @@ def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolutio col = torch.from_numpy(np.ascontiguousarray(color_np)).to(dev).float() K, C = P.shape[0], col.shape[1] M = VC.shape[0] - # Same cell-CENTER convention as the NumPy path (see its docstring): integer - # voxel coord c sits at (c+0.5)*voxel_size + origin, so subtract 0.5 to bracket. + # Cell-CENTER convention (see NumPy path): -0.5 to bracket the query. gc = (P - origin) / voxel_size - 0.5 base = torch.floor(gc).long() frac = gc - base.float() @@ -457,13 +414,8 @@ def _trilinear_sample_sparse_gpu(positions, voxel_coords_np, color_np, resolutio def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): - """GPU nearest-occupied-voxel lookup for surface points. Voxels sit on a - regular integer grid (coord c ↔ world c/R-0.5), so the nearest voxel to a - query is round((p+0.5)*R) plus a 3³ neighbour check — an O(1)-per-query grid - lookup (sorted-key binary search), ~10-30× faster than a cKDTree over millions - of voxels and ~identical. Returns (vals [K,C] float32, found [K] bool); `found` - is False for the rare query whose nearest occupied voxel is >1 cell away (the - caller falls back to a cKDTree on just those).""" + """GPU nearest-occupied-voxel lookup via sorted-key grid scan. Returns (vals [K,C] + float32, found [K] bool).""" dev = comfy.model_management.get_torch_device() R = int(resolution) P = torch.from_numpy(np.ascontiguousarray(positions)).to(dev).float() @@ -476,9 +428,7 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): def _search(idx, radius): """Nearest occupied voxel within ±radius cells, for query subset P[idx].""" Ps = P[idx] - # Cell-CENTER convention: voxel c is centred at (c+0.5)/R - 0.5 in world, - # so the coord nearest a point is round((p+0.5)*R - 0.5) (matches the - # official grid_sample_3d). The distance test below uses the same centre. + # Cell-CENTER convention: nearest coord = round((p+0.5)*R-0.5) (matches official). rc = ((Ps + 0.5) * R - 0.5).round().long() n = idx.shape[0] bd = torch.full((n,), 1e30, device=dev) @@ -501,16 +451,14 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): return bi, fnd def _brute_nearest(idx): - """Exact nearest occupied voxel for a (small) query subset by chunked GPU - brute force over all M voxels. Used only for the handful of stragglers the - grid scan misses (>4 cells from any voxel) — replaces a cKDTree build over - all M voxels, which costs seconds even for a few query points.""" + """Exact nearest occupied voxel for the few grid-scan stragglers, chunked GPU + brute force (avoids a seconds-long cKDTree build over all M voxels).""" Ps = P[idx] # [N,3] world N = Ps.shape[0] vox_pos = (VC.float() + 0.5) / R - 0.5 # [M,3] voxel centres best_d = torch.full((N,), 1e30, device=dev) best_j = torch.zeros(N, dtype=torch.long, device=dev) - # Bound the N×chunk distance matrix to ~64M elements regardless of N. + # Bound the N×chunk matrix to ~64M elements. chunk = max(1, (1 << 26) // max(1, N)) for s in range(0, M, chunk): vc = vox_pos[s:s + chunk] # [B,3] @@ -524,18 +472,16 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): all_idx = torch.arange(K, device=dev) best_i = torch.zeros(K, dtype=torch.long, device=dev) found = torch.zeros(K, dtype=torch.bool, device=dev) - # Pass 1: radius 1 (3³) over everything — catches ~all surface texels cheaply. + # Pass 1: radius 1 over everything; Pass 2: radius 4 on misses; Pass 3: brute force. bi1, fnd1 = _search(all_idx, 1) best_i[all_idx] = bi1 found[all_idx] = fnd1 - # Pass 2: wider radius (9³) on ONLY the radius-1 misses. miss = torch.nonzero(~found, as_tuple=True)[0] if miss.numel() > 0: bi2, fnd2 = _search(miss, 4) best_i[miss] = bi2 found[miss] = fnd2 - # Pass 3: exact GPU brute force for the few stragglers still unfound (>4 cells - # out). Always resolves them, so `found` is all-True on return — no cKDTree. + # Brute force always resolves, so `found` is all-True on return. miss2 = torch.nonzero(~found, as_tuple=True)[0] if miss2.numel() > 0: best_i[miss2] = _brute_nearest(miss2) @@ -545,13 +491,9 @@ def _nearest_voxel_sample_gpu(positions, voxel_coords_np, color_np, resolution): def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution): - """For every masked texel, sample the voxel field and return ALL its attribute - channels. Returns (H, W, C) float32 in [0, 1] where C is the voxel feature - width (3 for plain color, 6 for full PBR). - - Normalized trilinear over occupied voxels (matches the official o_voxel.to_glb - path), with nearest fallback for texels whose 8 surrounding voxels are all - empty.""" + """Sample all voxel attribute channels at every masked texel. Returns (H,W,C) + float32 in [0,1] (C = feature width: 3 color, 6 PBR). Normalized trilinear over + occupied voxels (matches official), nearest fallback where all 8 corners empty.""" H, W, _ = position_map.shape color_np = voxel_colors.detach().cpu().numpy().astype(np.float32) C = color_np.shape[-1] @@ -562,20 +504,17 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors origin = np.array([-0.5, -0.5, -0.5], dtype=np.float32) voxel_size = 1.0 / float(resolution) coords_np = voxel_coords.detach().cpu().numpy() - # Cell-CENTER convention (+0.5 voxel), matching the official grid_sample_3d and - # the _trilinear/_nearest paths above; this cKDTree only serves the rare - # >cell-radius nearest fallback but must use the same world mapping. + # Cell-CENTER convention (+0.5 voxel) — same world mapping as the GPU paths; this + # cKDTree only serves the rare non-CUDA nearest fallback. voxel_pos = (coords_np.astype(np.float32) + 0.5) * voxel_size + origin valid_positions = position_map[mask] def _nearest(query): - # Fully on-GPU nearest-occupied-voxel: grid scan + brute-force tail. Always - # resolves every query, so no cKDTree (its build over all voxels cost ~3s). + # On-GPU nearest-voxel (grid scan + brute tail); always resolves, no cKDTree. vals, found = _nearest_voxel_sample_gpu(query, coords_np, color_np, resolution) if not found.all(): - # Defensive: only reachable on a non-CUDA device where the GPU path is - # unavailable; fall back to a one-off cKDTree. - tree = scipy.spatial.cKDTree(voxel_pos) + # Only reachable on non-CUDA: fall back to a one-off cKDTree. + tree = cKDTree(voxel_pos) _, nearest_idx = tree.query(query[~found], k=1, workers=-1) vals[~found] = color_np[nearest_idx] return vals @@ -586,16 +525,14 @@ def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors logging.warning(f"[BakeTextureFromVoxel] GPU trilinear failed ({e}); falling back to CPU") vals, ok = _trilinear_sample_sparse(valid_positions, coords_np, color_np, resolution) if not ok.all(): - # Texels with no occupied neighbour fall back to nearest. - vals[~ok] = _nearest(valid_positions[~ok]) + vals[~ok] = _nearest(valid_positions[~ok]) # no occupied neighbour out[mask] = np.clip(vals, 0.0, 1.0).astype(np.float32) return out def _closest_point_on_triangles(p, a, b, c): - """Vectorized exact closest point on triangles (Ericson, Real-Time Collision - Detection §5.1.5). p/a/b/c are [..., 3]; returns [..., 3]. Handles all - vertex/edge/face Voronoi regions, applied highest-priority-last via where.""" + """Vectorized exact closest point on triangles (Ericson §5.1.5). p/a/b/c [...,3] → + [...,3]; all vertex/edge/face Voronoi regions, highest-priority-last via where.""" ab = b - a ac = c - a ap = p - a @@ -619,32 +556,28 @@ def _closest_point_on_triangles(p, a, b, c): v = vb * denom w = vc * denom res = a + ab * u(v) + ac * u(w) - # edge BC den_bc = (d4 - d3) + (d5 - d6) w_bc = (d4 - d3) / den_bc.clamp_min(1e-20) - res = torch.where(u((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)), - b + (c - b) * u(w_bc), res) - # edge AC + res = torch.where(u((va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)), b + (c - b) * u(w_bc), res) # edge BC w_ac = d2 / (d2 - d6).clamp_min(1e-20) - res = torch.where(u((vb <= 0) & (d2 >= 0) & (d6 <= 0)), a + ac * u(w_ac), res) - # vertex C - res = torch.where(u((d6 >= 0) & (d5 <= d6)), c, res) - # edge AB + res = torch.where(u((vb <= 0) & (d2 >= 0) & (d6 <= 0)), a + ac * u(w_ac), res) # edge AC + res = torch.where(u((d6 >= 0) & (d5 <= d6)), c, res) # vertex C v_ab = d1 / (d1 - d3).clamp_min(1e-20) - res = torch.where(u((vc <= 0) & (d1 >= 0) & (d3 <= 0)), a + ab * u(v_ab), res) - # vertex B - res = torch.where(u((d3 >= 0) & (d4 <= d3)), b, res) - # vertex A - res = torch.where(u((d1 <= 0) & (d2 <= 0)), a, res) + res = torch.where(u((vc <= 0) & (d1 >= 0) & (d3 <= 0)), a + ab * u(v_ab), res) # edge AB + res = torch.where(u((d3 >= 0) & (d4 <= d3)), b, res) # vertex B + res = torch.where(u((d1 <= 0) & (d2 <= 0)), a, res) # vertex A return res def _msb_int64(x): """floor(log2(x)) elementwise for int64 x >= 1 (bit-search, no float).""" - r = torch.zeros_like(x); xx = x.clone() + r = torch.zeros_like(x) + xx = x.clone() for s in (32, 16, 8, 4, 2, 1): - sh = xx >> s; m = sh > 0 - r = torch.where(m, r + s, r); xx = torch.where(m, sh, xx) + sh = xx >> s + m = sh > 0 + r = torch.where(m, r + s, r) + xx = torch.where(m, sh, xx) return r @@ -660,36 +593,38 @@ def _morton_expand21(v): def _build_triangle_bvh(tri): - """Linear BVH (Karras 2012) over triangle AABBs, pure torch, NO external deps. - - 21-bit-per-axis Morton sort of triangle centroids -> parallel radix-tree - construction -> bottom-up node AABBs. Internal nodes are indexed 0..T-2, leaves - are encoded as LEAF+i (i in 0..T-1) where leaf i holds triangle `order[i]`. - Returns a dict with node AABBs (nmin,nmax over 2T entries), child links - (left,right), the leaf->triangle map `order`, LEAF offset and T. - - A real tree (not a uniform grid) is what makes the closest-point query prune - empty space and dense clusters, so it stays fast on huge, non-uniform references - where the grid's ring search blows up — i.e. the cuMesh BVH approach, in torch.""" - dev = tri.device; T = tri.shape[0] - amin = tri.amin(1); amax = tri.amax(1); cent = (amin + amax) * 0.5 - lo = cent.amin(0); hi = cent.amax(0); span = (hi - lo).clamp_min(1e-12) + """Linear BVH (Karras 2012) over triangle AABBs, pure torch, no external deps + (the cuMesh approach, in torch). Internal nodes 0..T-2; leaves encoded LEAF+i, + leaf i holds triangle order[i]. Returns dict(LEAF, left, right, nmin, nmax over + 2T entries, order, T).""" + dev = tri.device + T = tri.shape[0] + amin = tri.amin(1) + amax = tri.amax(1) + cent = (amin + amax) * 0.5 + lo = cent.amin(0) + hi = cent.amax(0) + span = (hi - lo).clamp_min(1e-12) q = (((cent - lo) / span).clamp(0, 1) * float((1 << 21) - 1)).long() morton = (_morton_expand21(q[:, 0]) << 2 | _morton_expand21(q[:, 1]) << 1 | _morton_expand21(q[:, 2])).long() - order = torch.argsort(morton); msort = morton[order] + order = torch.argsort(morton) + msort = morton[order] - # delta(i,j): length of the common prefix of the (morton, index) keys of leaves - # i and j (index breaks ties so duplicate Morton codes still split); -1 if OOB. + # delta(i,j): common-prefix length of (morton, index) keys of leaves i,j (index + # breaks ties so duplicate codes still split); -1 if OOB. def delta(i, j): - ok = (j >= 0) & (j < T); jj = j.clamp(0, T - 1) - x = msort[i] ^ msort[jj]; same = x == 0 + ok = (j >= 0) & (j < T) + jj = j.clamp(0, T - 1) + x = msort[i] ^ msort[jj] + same = x == 0 cp = torch.where(same, torch.full_like(x, 63), 62 - _msb_int64(x.clamp_min(1))) xi = i ^ jj cpi = torch.where(xi == 0, torch.full_like(x, 32), 31 - _msb_int64(xi.clamp_min(1))) return torch.where(ok, cp + torch.where(same, cpi, torch.zeros_like(cp)), torch.full_like(x, -1)) I = torch.arange(T - 1, device=dev) - dplus = delta(I, I + 1); dminus = delta(I, I - 1) + dplus = delta(I, I + 1) + dminus = delta(I, I - 1) direction = torch.where(dplus >= dminus, torch.ones_like(I), -torch.ones_like(I)) dmin = torch.minimum(dplus, dminus) # range length: exponential probe then binary search @@ -701,7 +636,8 @@ def _build_triangle_bvh(tri): lmax = torch.where(cond, lmax * 2, lmax) if int(lmax.max()) > 2 * T: break - l = torch.zeros_like(I); t = lmax.clone() + l = torch.zeros_like(I) + t = lmax.clone() while True: t = t // 2 if int(t.max()) == 0: @@ -709,10 +645,13 @@ def _build_triangle_bvh(tri): cond = delta(I, I + (l + t) * direction) > dmin l = torch.where(cond, l + t, l) j = I + l * direction - first = torch.minimum(I, j); last = torch.maximum(I, j) + first = torch.minimum(I, j) + last = torch.maximum(I, j) # split position: binary search on delta within [first, last] dnode = delta(first, last) - s = torch.zeros_like(I); div = torch.full_like(I, 2); rng = last - first + s = torch.zeros_like(I) + div = torch.full_like(I, 2) + rng = last - first while True: step = (rng + div - 1) // div cond = delta(first, (first + s + step).clamp(max=T - 1)) > dnode @@ -722,15 +661,18 @@ def _build_triangle_bvh(tri): s = torch.where(cond1, s + 1, s) break div = div * 2 - gamma = first + s; LEAF = T + gamma = first + s + LEAF = T left = torch.where(gamma == first, LEAF + gamma, gamma) right = torch.where(gamma + 1 == last, LEAF + gamma + 1, gamma + 1) - # node AABBs: leaves seeded, internal unioned bottom-up over a few passes (a - # balanced tree settles in ~log2(T) passes; the cap is a safety bound). - nmin = torch.empty((2 * T, 3), device=dev); nmax = torch.empty((2 * T, 3), device=dev) - nmin[LEAF:] = amin[order]; nmax[LEAF:] = amax[order] - setm = torch.zeros(2 * T, dtype=torch.bool, device=dev); setm[LEAF:] = True + # node AABBs: leaves seeded, internal unioned bottom-up (~log2(T) passes; cap is a backstop). + nmin = torch.empty((2 * T, 3), device=dev) + nmax = torch.empty((2 * T, 3), device=dev) + nmin[LEAF:] = amin[order] + nmax[LEAF:] = amax[order] + setm = torch.zeros(2 * T, dtype=torch.bool, device=dev) + setm[LEAF:] = True for _ in range(128): need = ~setm[:T - 1] if not bool(need.any()): @@ -746,45 +688,63 @@ def _build_triangle_bvh(tri): def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): - """Exact closest surface point per query, via per-query stack traversal of the - triangle BVH (nearest-child-first for tight pruning), pure torch. Returns [N,3]. - - Each while-iteration advances all still-active queries by one node; the active - set shrinks fast, so even a few thousand iterations are cheap big GPU kernels. - `max_stack` bounds the per-query stack (= tree height); overflow is counted and - warned (a handful of texels could be slightly off) rather than silently wrong.""" - dev = Q.device; N = Q.shape[0] - LEAF = bvh['LEAF']; nmin = bvh['nmin']; nmax = bvh['nmax'] - left = bvh['left']; right = bvh['right']; order = bvh['order'] + """Exact closest surface point per query via per-query BVH stack traversal + (nearest-child-first), pure torch. Returns [N,3]. `max_stack` bounds the stack + (= tree height); overflow is counted+warned, not silently wrong.""" + dev = Q.device + N = Q.shape[0] + LEAF = bvh['LEAF'] + nmin = bvh['nmin'] + nmax = bvh['nmax'] + left = bvh['left'] + right = bvh['right'] + order = bvh['order'] stack = torch.full((N, max_stack), -1, dtype=torch.long, device=dev) - sp = torch.ones(N, dtype=torch.long, device=dev); stack[:, 0] = 0 - best = torch.full((N,), 1e30, device=dev); bestp = Q.clone() - active = torch.arange(N, device=dev); overflow = 0 + sp = torch.ones(N, dtype=torch.long, device=dev) + stack[:, 0] = 0 + best = torch.full((N,), 1e30, device=dev) + bestp = Q.clone() + active = torch.arange(N, device=dev) + overflow = 0 def aabb_d2(node, q): d = (nmin[node] - q).clamp_min(0) + (q - nmax[node]).clamp_min(0) return (d * d).sum(-1) while active.numel() > 0: - a = active; qa = Q[a] - node = stack[a, sp[a] - 1]; sp[a] = sp[a] - 1 + a = active + qa = Q[a] + node = stack[a, sp[a] - 1] + sp[a] = sp[a] - 1 within = aabb_d2(node, qa) < best[a] isleaf = node >= LEAF lv = within & isleaf if bool(lv.any()): - ga = a[lv]; tt = tri[order[node[lv] - LEAF]] + ga = a[lv] + tt = tri[order[node[lv] - LEAF]] cp = _closest_point_on_triangles(qa[lv], tt[:, 0], tt[:, 1], tt[:, 2]) d2 = ((cp - qa[lv]) ** 2).sum(-1) - upd = d2 < best[ga]; gu = ga[upd]; best[gu] = d2[upd]; bestp[gu] = cp[upd] + upd = d2 < best[ga] + gu = ga[upd] + best[gu] = d2[upd] + bestp[gu] = cp[upd] iv = within & ~isleaf if bool(iv.any()): - gi = a[iv]; qi = qa[iv]; lc = left[node[iv]]; rc = right[node[iv]] - dl = aabb_d2(lc, qi); dr = aabb_d2(rc, qi) - near = torch.where(dl <= dr, lc, rc); far = torch.where(dl <= dr, rc, lc) + gi = a[iv] + qi = qa[iv] + lc = left[node[iv]] + rc = right[node[iv]] + dl = aabb_d2(lc, qi) + dr = aabb_d2(rc, qi) + near = torch.where(dl <= dr, lc, rc) + far = torch.where(dl <= dr, rc, lc) s0 = sp[gi] - stack[gi, s0.clamp(max=max_stack - 1)] = far; sp[gi] = (s0 + 1).clamp(max=max_stack) - s1 = sp[gi]; overflow += int((s1 >= max_stack).sum()) - stack[gi, s1.clamp(max=max_stack - 1)] = near; sp[gi] = (s1 + 1).clamp(max=max_stack) + stack[gi, s0.clamp(max=max_stack - 1)] = far + sp[gi] = (s0 + 1).clamp(max=max_stack) + s1 = sp[gi] + overflow += int((s1 >= max_stack).sum()) + stack[gi, s1.clamp(max=max_stack - 1)] = near + sp[gi] = (s1 + 1).clamp(max=max_stack) active = a[sp[a] > 0] if overflow: logging.warning(f"[back-project] BVH stack overflow on {overflow} pushes " @@ -794,31 +754,21 @@ def _closest_points_on_mesh_bvh(Q, tri, bvh, max_stack=64): def _back_project_positions(position_map, mask, ref_v, ref_f): - """Snap each covered texel's interpolated position onto the reference mesh's true - surface, so the voxel field is sampled at full surface detail instead of along - flat triangle chords (the cause of faceted/pixelized bakes on coarse meshes). - Mirrors o_voxel.to_glb step 7c but with NO cumesh/scipy/trimesh dependency: a - pure-torch linear BVH (`_build_triangle_bvh`) + exact closest-point traversal, - the same approach as cuMesh's cuBVH. Returns a new position_map with the covered - texels replaced.""" + """Snap covered texels onto the reference mesh's true surface (pure-torch BVH, no + cumesh/scipy/trimesh) so the voxel field is sampled at full detail, not along flat + triangle chords. Returns a new position_map.""" valid = np.ascontiguousarray(position_map[mask].astype(np.float32)) if valid.shape[0] == 0: return position_map - import time as _time dev = comfy.model_management.get_torch_device() rv = ref_v.detach().to(dev).float() rf = ref_f.detach().to(dev).long() tri = rv[rf] Q = torch.from_numpy(valid).to(dev) - _t = _time.perf_counter() bvh = _build_triangle_bvh(tri) - _tb = _time.perf_counter() bp = _closest_points_on_mesh_bvh(Q, tri, bvh) - logging.info(f"[back-project] BVH build {_tb - _t:.1f}s + traverse " - f"{_time.perf_counter() - _tb:.1f}s ({rf.shape[0]} ref tris, " - f"{valid.shape[0]} texels)") out = position_map.copy() out[mask] = bp.detach().cpu().numpy().astype(position_map.dtype) @@ -826,10 +776,8 @@ def _back_project_positions(position_map, mask, ref_v, ref_f): def _jfa_fill_gpu(img01, mask): - """Fill every uncovered texel with its nearest covered texel's value via GPU - Jump Flooding (O(log n) passes) — a fast nearest-fill replacement for - cv2.inpaint on UV seam/gutter filling. img01 [H,W,C] float, mask [H,W] bool - (True = covered). Returns [H,W,C] float. ~6× faster than cv2 Telea per map.""" + """Fill uncovered texels with nearest covered value via GPU Jump Flooding + (O(log n) passes; replaces cv2.inpaint). img01 [H,W,C] float, mask [H,W] bool.""" if not mask.any(): return img01 dev = comfy.model_management.get_torch_device() @@ -862,25 +810,17 @@ def _jfa_fill_gpu(img01, mask): def _seam_fill(img01, mask, inpaint_radius): - """Fill the UV-gutter texels around covered charts so seam sampling doesn't - pull in black, via GPU Jump Flooding (nearest fill). `inpaint_radius<=0` - disables; otherwise the radius is ignored — JFA fills every uncovered texel - by nearest regardless.""" + """Fill UV-gutter texels (so seams don't pull in black) via JFA. `inpaint_radius<=0` + disables; the radius value itself is ignored (JFA fills all uncovered by nearest).""" if inpaint_radius <= 0: return img01 return _jfa_fill_gpu(img01, mask) def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None): - """Uniformly fit a UV layout's bbox into [0,1] when it spills outside the unit - square (preserves chart aspect ratios; handles packers that overflow slightly). - No-op when the UVs are already in [0,1] — the normal case for official/xatlas - unwraps. NOT a UDIM de-tiler; warns if the span looks tiled. - - Deterministic from the input UVs alone, so the texture bake and - ApplyTextureToMesh both call it to agree on the exact UVs the texture was baked - against (the bake no longer emits the mesh, so apply must re-derive them). - + """Uniformly fit a UV bbox into [0,1] when it spills outside (preserves aspect; + no-op if already in [0,1]; not a UDIM de-tiler). Shared deterministic helper — + bake and ApplyTextureToMesh both call it so UVs agree (keep both paths in sync). Returns float32 [N,2].""" uv_np = uv_np.astype(np.float32) uv_min = uv_np.min(axis=0) @@ -904,35 +844,16 @@ def _normalize_uvs_to_unit(uv_np, normalize=True, log_prefix=None): def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, resolution, texture_size, uvs, inpaint_radius=3, normalize_uvs=True, reference=None, pbar=None): - """Bake a baseColor (+ optional metallicRoughness) texture for - `vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each - texel from the provided sparse colored voxel volume. - - `uvs` (N, 2) is the mesh's existing UV layout — baked onto directly (this - node never unwraps; connect a UV unwrap node upstream). It must be 1:1 with - `vertices`. - - Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr). - - Progress: drives a local tqdm over its 5 stages (uvs → rasterize → - back-project → sample → finalize) and, if a comfy `pbar` (ProgressBar) is - passed, ticks it once per stage too — so callers should size it as 5 per - bake.""" - import time - - # 5-stage progress: tqdm (console) + optional comfy ProgressBar (UI). _tick is - # called exactly once at each stage boundary, including no-op stages (e.g. no - # back-projection), so the comfy pbar stays aligned at 5 ticks per bake. - try: - from tqdm import tqdm as _tqdm - _tq = _tqdm(total=5, desc="Bake texture", leave=False) - except Exception: - _tq = None + """Bake a baseColor (+ optional metallicRoughness) texture: rasterize in UV space, + sample each texel from the sparse voxel volume. `uvs` (N,2) is the existing layout, + 1:1 with `vertices` (never unwraps). Returns (v, f, uvs, texture, mr). Ticks `pbar` + once per stage; size it 5 per bake.""" + # _tick fires once per stage boundary, including no-op stages, so the 5-tick pbar stays aligned. + _tq = tqdm(total=5, desc="Bake texture", leave=False) def _tick(name): - if _tq is not None: - _tq.set_postfix_str(name) - _tq.update(1) + _tq.set_postfix_str(name) + _tq.update(1) if pbar is not None: pbar.update(1) @@ -940,7 +861,6 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, f_np = faces.detach().cpu().numpy().astype(np.uint32) fcount = int(f_np.shape[0]) - # Bake onto the mesh's current UVs — no unwrap, no seam-splitting. uv_np = uvs.detach().cpu().numpy().astype(np.float32) if uv_np.shape[0] != v_np.shape[0]: raise ValueError( @@ -959,48 +879,34 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, _tick("uvs") - t1 = time.perf_counter() position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size) - logging.info(f"[BakeTextureFromVoxel] GL rasterize {texture_size}² in {time.perf_counter() - t1:.1f}s " - f"({int(mask.sum())}/{mask.size} texels covered)") _tick("rasterize") if reference is not None: - # Back-project texel positions onto the original dense surface before - # sampling — the o_voxel.to_glb step that makes the bake smooth on coarse - # meshes (instead of sampling along flat triangle chords). - tb = time.perf_counter() + # Back-project onto the dense surface before sampling (smooth bake on coarse + # meshes, not along flat triangle chords). position_map = _back_project_positions(position_map, mask, reference[0], reference[1]) - logging.info(f"[BakeTextureFromVoxel] BVH back-project in {time.perf_counter() - tb:.1f}s") _tick("back-project") - t2 = time.perf_counter() attrs = _sample_voxel_attrs_per_texel( position_map, mask, voxel_coords, voxel_colors, resolution, ) - logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s " - f"({attrs.shape[-1]} channels)") _tick("sample") - # Split into PBR maps. Layout matches upstream pbr_attr_layout: - # 0:3 base_color, 3 metallic, 4 roughness, 5 alpha. + # PBR layout (upstream pbr_attr_layout): 0:3 base_color, 3 metallic, 4 roughness, 5 alpha. C = attrs.shape[-1] base_color = attrs[..., 0:3] has_pbr = C >= 5 metallic = attrs[..., 3:4] if C >= 4 else None roughness = attrs[..., 4:5] if C >= 5 else None - # alpha channel exists at index 5 but we keep meshes opaque (upstream uses - # alpha_mode=OPAQUE in the remesh path); plumb later if needed. + # alpha (idx 5) ignored — meshes kept opaque (upstream OPAQUE alpha_mode). - t3 = time.perf_counter() base_color = _seam_fill(np.ascontiguousarray(base_color), mask, inpaint_radius) mr_image = None if has_pbr: # glTF metallicRoughness: R unused, G=roughness, B=metallic. mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1) mr_image = _seam_fill(np.ascontiguousarray(mr), mask, inpaint_radius) - if inpaint_radius > 0: - logging.info(f"[BakeTextureFromVoxel] inpaint in {time.perf_counter() - t3:.1f}s") device = vertices.device out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32) @@ -1010,85 +916,13 @@ def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, out_mr = (torch.from_numpy(np.ascontiguousarray(mr_image)).to(device=device, dtype=torch.float32) if mr_image is not None else None) _tick("finalize") - if _tq is not None: - _tq.close() + _tq.close() return out_v, out_f, out_uvs, out_tex, out_mr -def _per_vertex_normals(verts_np, faces_np): - """Area-weighted per-vertex normals (unit length) for a triangle mesh.""" - v = verts_np.astype(np.float64) - f = faces_np.astype(np.int64) - # Un-normalized face normals are area-weighted (cross product magnitude = 2*area), - # so accumulating them onto vertices gives an area-weighted vertex normal. - fn = np.cross(v[f[:, 1]] - v[f[:, 0]], v[f[:, 2]] - v[f[:, 0]]) - vn = np.zeros_like(v) - for k in range(3): - np.add.at(vn, f[:, k], fn) - vn = vn / np.clip(np.linalg.norm(vn, axis=1, keepdims=True), 1e-12, None) - return vn.astype(np.float32) - - -def bake_texture_multiview_fn(vertices, faces, voxel_coords, voxel_colors, resolution, - texture_size, views, uvs, blend_temperature=0.25, - inpaint_radius=3, normalize_uvs=True): - """Bake a baseColor texture by projecting view photos onto the mesh. - - Reuses bake_texture_from_voxel_fn for the UV-space bake + the nearest-voxel - fallback colour, then overlays photo colour on every covered+visible texel: - each texel's world position/normal is projected into each view, occlusion is - resolved with a texel z-buffer, and the views are blended weighted by how - directly each camera faces the surface. Texels seen by no view keep the voxel - colour. The seam inpaint runs last, over the composited result. - - `views`: list of dicts {image[H,W,3] in [0,1], azimuth_deg, transform_matrix[4,4], - camera_angle_x (scalar tensor), image_resolution}. All Pixal3D views share the - one front camera and differ only by azimuth. - - Returns (verts, faces, uvs, tex, mr) — same shape contract as - bake_texture_from_voxel_fn, so the node attaches them identically.""" - from comfy.ldm.trellis2 import multiview_bake as mvbake - - # Voxel bake → unwrapped geometry + fallback colour (inpaint deferred to the end). - out_v, out_f, out_uvs, voxel_tex, voxel_mr = bake_texture_from_voxel_fn( - vertices, faces, voxel_coords, voxel_colors, resolution=resolution, - texture_size=texture_size, uvs=uvs, inpaint_radius=0, - normalize_uvs=normalize_uvs) - - v_np = out_v.detach().cpu().numpy().astype(np.float32) - f_np = out_f.detach().cpu().numpy().astype(np.uint32) - uv_np = out_uvs.detach().cpu().numpy().astype(np.float32) - - # Per-texel world position + normal (the GL baker outputs any per-vertex vec3). - position_map, mask = _bake_position_map(v_np, f_np, uv_np, texture_size) - normal_map, _ = _bake_position_map(_per_vertex_normals(v_np, f_np), f_np, uv_np, texture_size) - - device = out_v.device - base = voxel_tex.detach().cpu().numpy().copy() - if mask.any() and views: - pos = torch.from_numpy(np.ascontiguousarray(position_map[mask])).to(device) - nrm = torch.from_numpy(np.ascontiguousarray(normal_map[mask])).to(device) - fallback = torch.from_numpy(np.ascontiguousarray(base[mask])).to(device) - view_objs = [{ - "image": vw["image"].to(device), - "azimuth_deg": vw["azimuth_deg"], - "transform_matrix": vw["transform_matrix"].to(device), - "camera_angle_x": vw["camera_angle_x"].to(device), - "image_resolution": vw["image_resolution"], - } for vw in views] - rgb, _seen = mvbake.composite_views(pos, nrm, view_objs, fallback, blend_temperature) - base[mask] = rgb.detach().cpu().numpy() - - base = _seam_fill(np.ascontiguousarray(base), mask, inpaint_radius) - - out_tex = torch.from_numpy(np.ascontiguousarray(base)).to(device=device, dtype=torch.float32) - return out_v, out_f, out_uvs, out_tex, voxel_mr - - def _mr_channel(packed_mr, ch, ref): - """Pull one channel out of a packed glTF MR map (G=roughness at idx 1, B=metallic - at idx 2) as a 3-channel grayscale IMAGE [H,W,3] in [0,1]. Returns black sized - like `ref` when there's no MR map (non-PBR voxel field).""" + """Pull one channel (G=roughness idx 1, B=metallic idx 2) from a packed glTF MR map + as 3-channel grayscale [H,W,3] in [0,1]. Black sized like `ref` if no MR map.""" if packed_mr is None: return torch.zeros_like(ref.float().cpu()) m = packed_mr.float().clamp(0.0, 1.0).cpu() @@ -1103,27 +937,20 @@ class BakeTextureFromVoxel(IO.ComfyNode): display_name="Bake Texture From Voxel", category="latent/3d", description=( - "Bakes PBR textures onto the mesh's existing UV layout by rasterizing it " - "in UV space via OpenGL (ComfyUI's PyOpenGL backend) and trilinear-sampling " - "the input sparse voxel volume. Does NOT unwrap — connect a UV unwrap node " - "(e.g. Trellis2OfficialUnwrap or TorchXatlasUVWrap) upstream. Outputs the " - "baked maps as IMAGEs: base_color, plus metallic and roughness as separate " - "grayscale maps (both black when the voxel field has no PBR set). " - "Preview/save/post-process them, then feed them to ApplyTextureToMesh (with " - "the SAME mesh) to attach them for SaveGLB. UVs that spill outside [0,1] are " - "uniformly fit into the unit square." + "Bakes PBR textures onto the mesh's existing UV layout (trilinear-sample the " + "sparse voxel volume). Does NOT unwrap — connect a UV unwrap node upstream. Outputs " + "base_color + metallic/roughness grayscale IMAGEs (black if no PBR); feed them to " + "ApplyTextureToMesh (SAME mesh) for SaveGLB." ), inputs=[ IO.Mesh.Input("mesh"), IO.Voxel.Input("voxel_colors"), IO.Int.Input("texture_size", default=1024, min=64, max=8192, - tooltip="Square texture resolution. Larger = sharper but slower / bigger file."), + tooltip="Square texture resolution."), IO.Mesh.Input("reference_mesh", optional=True, tooltip=( - "Optional original (dense, pre-decimation) mesh. If connected, each " - "texel is back-projected onto its true surface before sampling — the " - "o_voxel.to_glb step that removes faceted/pixelized baking on coarse " - "meshes. Pure scipy+torch, no extra deps.")), + "Optional dense pre-decimation mesh; back-projects each texel onto its " + "true surface before sampling, removing faceted baking on coarse meshes.")), ], outputs=[ IO.Image.Output(display_name="base_color"), @@ -1134,8 +961,7 @@ class BakeTextureFromVoxel(IO.ComfyNode): @classmethod def execute(cls, mesh, voxel_colors, texture_size, reference_mesh=None): - # Seam-gutter inpaint radius is hardcoded to 3 (matches the official to_glb); - # it's an on/off-grade knob — Telea fills the whole gutter regardless of value. + # Matches official to_glb; effectively on/off since the gutter fill ignores the value. inpaint_radius = 3 voxels = voxel_colors coords = voxels.data @@ -1154,8 +980,7 @@ class BakeTextureFromVoxel(IO.ComfyNode): voxel_xyz = coords[:, 1:] mesh_batch_size = int(mesh.vertices.shape[0]) out_tex, out_mr = [], [] - # 5 stage ticks per item (see bake_texture_from_voxel_fn); skipped items - # tick all 5 so the bar stays aligned. + # 5 ticks per item; skipped items tick all 5 to stay aligned. pbar = comfy.utils.ProgressBar(mesh_batch_size * 5) for i in range(mesh_batch_size): sel = batch_idx == i @@ -1177,15 +1002,13 @@ class BakeTextureFromVoxel(IO.ComfyNode): uvs=ev_i, inpaint_radius=inpaint_radius, reference=ref_i, pbar=pbar, ) - out_tex.append(bt); out_mr.append(bmr) + out_tex.append(bt) + out_mr.append(bmr) if not out_tex: - # Every item skipped (degenerate) — emit one black map so the IMAGE - # outputs stay valid. + # All items skipped — emit one black map so IMAGE outputs stay valid. black = torch.zeros((1, texture_size, texture_size, 3)) return IO.NodeOutput(black, black, black) - # All maps are texture_size² — stack into [B,H,W,3] IMAGE batches. The - # packed MR (G=roughness, B=metallic) is split into separate grayscale - # maps; both black where the voxel field carried no PBR set. + # Stack [B,H,W,3]; split packed MR (G=roughness, B=metallic) into grayscale maps. base_img = torch.stack([t.float().clamp(0.0, 1.0).cpu() for t in out_tex], dim=0) metallic_img = torch.stack([_mr_channel(m, 2, out_tex[0]) for m in out_mr], dim=0) roughness_img = torch.stack([_mr_channel(m, 1, out_tex[0]) for m in out_mr], dim=0) @@ -1198,7 +1021,7 @@ class BakeTextureFromVoxel(IO.ComfyNode): ref0 = None if reference_mesh is not None: ref0 = (reference_mesh.vertices.squeeze(0), reference_mesh.faces.squeeze(0)) - pbar = comfy.utils.ProgressBar(5) # 5 stage ticks (see bake_texture_from_voxel_fn) + pbar = comfy.utils.ProgressBar(5) # 5 stage ticks _bv, _bf, _bu, bt, bmr = bake_texture_from_voxel_fn( v0, f0, coords, colors, resolution=resolution, texture_size=texture_size, @@ -1219,10 +1042,8 @@ class MeshTextureToImage(IO.ComfyNode): display_name="Mesh Texture to Image", category="latent/3d", description=( - "Extracts a mesh's baked textures as IMAGE outputs for preview/save. " - "base_color is the baseColor map; metallic_roughness is the packed " - "glTF MR map (R unused, G=roughness, B=metallic) — black if the mesh " - "has no PBR texture." + "Extracts a mesh's baked textures as IMAGE outputs: base_color and the packed " + "glTF MR map (G=roughness, B=metallic; black if no PBR texture)." ), inputs=[IO.Mesh.Input("mesh")], outputs=[ @@ -1236,7 +1057,7 @@ class MeshTextureToImage(IO.ComfyNode): @classmethod def execute(cls, mesh): def _as_image(tex): - # Mesh textures are (B, H, W, 3) float in [0, 1] — already IMAGE layout. + # Mesh textures are (B,H,W,3) float [0,1] — already IMAGE layout. if tex is None: return None t = tex.float().clamp(0.0, 1.0).cpu() @@ -1254,9 +1075,7 @@ class MeshTextureToImage(IO.ComfyNode): ) if mr is None: mr = torch.zeros_like(base) - # Split the packed glTF MR map into single-channel grayscale previews: - # G=roughness, B=metallic. Replicated to 3 channels so they display - # as proper grayscale IMAGEs. + # Split packed MR into grayscale previews (G=roughness, B=metallic), to 3ch. metallic = mr[..., 2:3].expand(-1, -1, -1, 3).contiguous() roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous() return IO.NodeOutput(base, mr, metallic, roughness) @@ -1270,15 +1089,11 @@ class ApplyTextureToMesh(IO.ComfyNode): display_name="Apply Texture to Mesh", category="latent/3d", description=( - "Attaches baked texture IMAGEs to a mesh's existing UV layout so SaveGLB " - "serializes them as baseColorTexture / metallicRoughnessTexture maps. Pairs " - "with BakeTextureFromVoxel: feed it the SAME mesh you baked from, plus that " - "node's base_color (and optionally metallic / roughness grayscale maps) — the " - "UVs must match the ones the texture was baked against, so don't re-unwrap in " - "between. metallic and roughness are repacked into the glTF MR map " - "(G=roughness, B=metallic); leave them unconnected for non-PBR meshes (a " - "missing metallic defaults to 0, a missing roughness to 1). Lets you preview / " - "upscale / edit the baked maps before applying them." + "Attaches baked texture IMAGEs to a mesh's existing UV layout for SaveGLB. " + "Pairs with BakeTextureFromVoxel: feed the SAME mesh and its base_color " + "(optionally metallic/roughness); don't re-unwrap in between. metallic/roughness " + "repack into the glTF MR map (G=roughness, B=metallic); missing metallic=0, " + "roughness=1." ), inputs=[ IO.Mesh.Input("mesh"), @@ -1298,9 +1113,7 @@ class ApplyTextureToMesh(IO.ComfyNode): "you fed to BakeTextureFromVoxel (this node attaches onto existing UVs and " "never unwraps).") - # Re-derive the exact UVs the bake rasterized against — it uniformly fits - # out-of-[0,1] layouts into the unit square, so apply the identical - # deterministic transform here (per batch item, over each item's real verts). + # Re-derive the exact UVs the bake used (shared _normalize_uvs_to_unit), per item. if mesh_uvs.ndim == 3: new_uvs = mesh_uvs.clone() for i in range(mesh_uvs.shape[0]): @@ -1316,9 +1129,7 @@ class ApplyTextureToMesh(IO.ComfyNode): out_mesh.uvs = new_uvs out_mesh.texture = base_color.float().clamp(0.0, 1.0).cpu() if metallic is not None or roughness is not None: - # Repack separate grayscale maps into glTF MR: R unused, G=roughness, - # B=metallic. Size defaults off whichever map is connected; a missing - # channel falls back to a sensible scalar (metal 0 / rough 1). + # Repack glTF MR (G=roughness, B=metallic); missing channel → scalar (metal 0/rough 1). prov = (metallic if metallic is not None else roughness).float().clamp(0.0, 1.0).cpu() B, H, W, _ = prov.shape rough_ch = (roughness.float().clamp(0.0, 1.0).cpu()[..., 0:1] @@ -1413,7 +1224,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): mesh_normal = face_normals.mean(dim=0) mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8) - # === FIX: Fill ALL boundary loops below perimeter threshold === + # Fill all boundary loops below the perimeter threshold. new_verts = [] new_faces = [] v_idx = v.shape[0] @@ -1422,7 +1233,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): loop_t = torch.tensor(loop, device=device, dtype=torch.long) loop_v = v[loop_t] - # Perimeter check next_v = torch.roll(loop_v, -1, dims=0) diffs = loop_v - next_v perimeter = torch.norm(diffs, dim=1).sum().item() @@ -1456,138 +1266,9 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f -def _fill_holes_v2_diagnostic(verts, faces, max_perimeter): - """Topology dump for debugging missed-hole cases. Logs edge count - distribution, cycle count, and per-cycle (size, perimeter).""" - device = verts.device - V = verts.shape[0] - F = faces.shape[0] - e_all = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) - e_sorted, _ = e_all.sort(dim=1) - packed = e_sorted[:, 0].long() * V + e_sorted[:, 1].long() - unique_packed, counts = torch.unique(packed, return_counts=True) - - n_boundary = int((counts == 1).sum().item()) - n_interior = int((counts == 2).sum().item()) - n_nonmanifold = int((counts >= 3).sum().item()) - nm_max = int(counts.max().item()) if counts.numel() > 0 else 0 - nm_share_breakdown = [] - if n_nonmanifold > 0: - # Show top-5 non-manifold counts - nm_counts = counts[counts >= 3] - unique_nm, cnt_nm = torch.unique(nm_counts, return_counts=True) - for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()): - nm_share_breakdown.append(f"{n} edges×{c}faces") - - logging.info(f"[FillHoles diag] V={V} F={F} | " - f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} " - f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})") - if nm_share_breakdown: - logging.info(f"[FillHoles diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}") - - if n_boundary == 0: - logging.info("[FillHoles diag] no boundary edges → no cycles to fill") - return - - # Walk components same as production path (bidir-prop, by-vertex count). - boundary_packed = unique_packed[counts == 1] - is_b = torch.isin(packed, boundary_packed) - b_directed = e_all[is_b] - src = b_directed[:, 0].long() - tgt = b_directed[:, 1].long() - - labels = torch.arange(V, dtype=torch.long, device=device) - for _ in range(64): - edge_min = torch.minimum(labels[src], labels[tgt]) - new_labels = labels.clone() - new_labels.scatter_reduce_(0, src, edge_min, reduce="amin", include_self=True) - new_labels.scatter_reduce_(0, tgt, edge_min, reduce="amin", include_self=True) - new_labels = new_labels[new_labels] - if torch.equal(new_labels, labels): - break - labels = new_labels - - edge_component = labels[src] - unique_components, component_idx = torch.unique(edge_component, return_inverse=True) - L = unique_components.shape[0] - edge_len = (verts[src] - verts[tgt]).norm(dim=-1) - perim = torch.zeros(L, dtype=verts.dtype, device=device) - perim.scatter_add_(0, component_idx, edge_len) - edge_count = torch.bincount(component_idx, minlength=L) - - pair_keys = torch.unique(torch.cat([ - component_idx.long() * V + src, - component_idx.long() * V + tgt, - ])) - pair_c = pair_keys // V - vert_count = torch.bincount(pair_c, minlength=L) - - # Open chain = vert_count == edge_count + 1; closed cycle = vert_count == edge_count. - is_chain = (vert_count == edge_count + 1) - is_cycle = (vert_count == edge_count) & (vert_count > 0) - is_other = ~(is_chain | is_cycle) - - # Match production filter (cycles only, default fill_chains=False, default max_verts=16). - MAX_VERTS_DEFAULT = 16 - CENTROID_FAN_THRESHOLD = 8 - cycle_perim_ok = is_cycle & (perim < max_perimeter) - cycle_size_ok = is_cycle & (vert_count >= 3) & (vert_count <= MAX_VERTS_DEFAULT) - actually_kept = is_cycle & (vert_count >= 3) & (vert_count <= MAX_VERTS_DEFAULT) & (perim < max_perimeter) - - # Triangulation strategy split. - vfan = actually_kept & (vert_count <= CENTROID_FAN_THRESHOLD) - cfan = actually_kept & (vert_count > CENTROID_FAN_THRESHOLD) - vfan_tris = int((vert_count[vfan] - 2).sum().item()) # N-2 tris per N-vert cycle - cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle - cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component - - logging.info(f"[FillHoles diag] components={L} " - f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} " - f"non-simple={int(is_other.sum().item())}") - logging.info(f"[FillHoles diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})") - logging.info(f"[FillHoles diag] actually kept={int(actually_kept.sum().item())} " - f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} " - f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}") - logging.info(f"[FillHoles diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)") - logging.info(f"[FillHoles diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts") - - # Cycle vert-count distribution - if is_cycle.any(): - from collections import Counter - cycle_sizes = vert_count[is_cycle].tolist() - sc = Counter(cycle_sizes) - # show buckets: 3, 4, 5, 6, 7-10, 11-20, 21-50, 51+ - buckets = {"3":0,"4":0,"5":0,"6":0,"7-10":0,"11-20":0,"21-50":0,"51+":0} - for s, n in sc.items(): - if s == 3: buckets["3"] += n - elif s == 4: buckets["4"] += n - elif s == 5: buckets["5"] += n - elif s == 6: buckets["6"] += n - elif s <= 10: buckets["7-10"] += n - elif s <= 20: buckets["11-20"] += n - elif s <= 50: buckets["21-50"] += n - else: buckets["51+"] += n - logging.info(f"[FillHoles diag] cycle vert-count buckets: {buckets}") - - if is_cycle.any(): - cycle_perims = perim[is_cycle].cpu().tolist() - head = sorted(cycle_perims, reverse=True)[:10] - logging.info(f"[FillHoles diag] top-10 cycle perimeters: " - f"{['%.4f' % p for p in head]}") - - def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=False, max_verts=16): - # Bidirectional connected-component labeling on the undirected boundary - # subgraph. Fixes the original pointer-doubling bug where chains starting at - # the lowest-id vertex never propagated their label backward, producing - # spurious size-1/2 fragments (see qem_core._propagate_face_labels for - # the same pattern applied to face adjacency). - # - # By default we only close TRUE cycles (each boundary vert has degree 2 in - # the component). Chains tend to be either real surface boundaries or - # fragments of a cycle broken by non-manifold edges — fan-filling them with - # an arbitrary centroid produces overlapping/noisy geometry. Pass - # fill_chains=True to opt in to chain closure. + # Bidirectional (not pointer-doubling) CC labeling so low-id chains propagate + # backward. Cycles-only by default; fill_chains=True opts into noisy chain fills. device = verts.device V = verts.shape[0] dtype = verts.dtype @@ -1617,7 +1298,7 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal break labels = new_labels - # Each boundary edge -> its component id. After bidir-prop, labels[src] == labels[tgt]. + # After bidir-prop, labels[src] == labels[tgt], so labels[src] is the edge's component. edge_component = labels[src] unique_components, component_idx = torch.unique(edge_component, return_inverse=True) L = unique_components.shape[0] @@ -1626,8 +1307,7 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal perim = torch.zeros(L, dtype=dtype, device=device) perim.scatter_add_(0, component_idx, edge_len) - # Unique boundary-vertex set per component, to count verts and place centroids. - # Pack (component, vert) into one key; dedup via torch.unique. + # Unique boundary verts per component, via packed (comp,vert) keys. pair_keys = torch.cat([ component_idx.long() * V + src, component_idx.long() * V + tgt, @@ -1641,14 +1321,11 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal centroids.scatter_add_(0, pair_c[:, None].expand(-1, 3), verts[pair_v]) centroids = centroids / vert_count.clamp_min(1).to(dtype).unsqueeze(-1) - # Identify closed cycles: every boundary vert in the component has exactly - # degree 2 in the boundary subgraph. Equivalent: vert_count == edge_count. + # Closed cycle ⇔ every boundary vert has degree 2 ⇔ vert_count == edge_count. is_cycle_component = (vert_count == edge_count) & (vert_count > 0) - # Filter: keep cycles (always) and chains (only if fill_chains=True), under perim limit. - # Also cap vert_count: fan-from-centroid only triangulates correctly for small, - # near-planar cycles. Larger holes produce overlapping geometry because the - # centroid lands far from any surface. + # Keep cycles (and chains if fill_chains) under perim/vert limits; centroid-fan + # only works for small near-planar holes (else centroid lands off-surface → overlap). size_ok = (vert_count >= 3) & (vert_count <= max_verts) & (perim < max_perimeter) if fill_chains: keep_component = size_ok @@ -1656,37 +1333,25 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal keep_component = is_cycle_component & size_ok if not keep_component.any(): return verts, faces, colors, 0 - # Only centroid-fan components actually allocate a new vertex slot. - # We pre-compute their indices here so the triangulation step below has them ready. - use_centroid_per_comp_pre = keep_component & (vert_count > 8) # threshold mirrored below + # Only centroid-fan components allocate a new vertex (threshold mirrored below). + use_centroid_per_comp_pre = keep_component & (vert_count > 8) centroid_long = use_centroid_per_comp_pre.long() centroid_idx_per_comp = V + centroid_long.cumsum(0) - 1 - # Triangulate kept components. Two strategies: - # - # Vertex-fan (small cycles): pick one boundary vert as apex, connect to all - # non-adjacent boundary edges. N verts -> N-2 triangles, no inserted vertex. - # Apex stays on the existing surface, so no off-surface centroid → no overlap. - # Right choice for 6-vert dual-grid pinches around interior verts. - # - # Centroid-fan (large cycles): insert a new vertex at the boundary centroid, - # fan from it. N triangles. Only safe if the cycle is close to planar. - # We fall back to centroid-fan above `centroid_fan_threshold` verts where - # vertex-fan would produce excessively skinny triangles. - CENTROID_FAN_THRESHOLD = 8 # tune: lower = more vertex-fan, higher = more centroid-fan + # vertex-fan (small cycles): boundary vert as apex, on-surface. centroid-fan (large): + # insert centroid (near-planar only, but avoids skinny tris on big holes). + CENTROID_FAN_THRESHOLD = 8 - # Edge kept mask edge_kept = keep_component[component_idx] edge_comp = component_idx[edge_kept] kept_src = src[edge_kept] kept_tgt = tgt[edge_kept] - # Per-edge tag: which strategy does its component use? use_centroid_per_comp = keep_component & (vert_count > CENTROID_FAN_THRESHOLD) use_centroid_per_edge = use_centroid_per_comp[edge_comp] fan_pieces = [] - # ---- Centroid-fan branch (only for components > threshold) ---- + # Centroid-fan branch if use_centroid_per_edge.any(): kept_centroid = centroid_idx_per_comp[edge_comp[use_centroid_per_edge]] fan_pieces.append(torch.stack([ @@ -1695,22 +1360,17 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal kept_centroid, ], dim=1).to(faces.dtype)) - # ---- Vertex-fan branch (small cycles, no centroid inserted) ---- + # Vertex-fan branch (small cycles) use_vertex_fan_per_comp = keep_component & (vert_count <= CENTROID_FAN_THRESHOLD) if use_vertex_fan_per_comp.any(): - # For each vertex-fan component, pick the smallest-id boundary vert as apex - # (deterministic & matches labels[*] = smallest after bidir-prop). - # Then emit edges in component as fan tris (apex, src, tgt) EXCEPT for - # the two edges incident to the apex (those would be degenerate). - apex_per_comp = labels[unique_components] # labels[u]==u after convergence - # Edges that DON'T touch their component's apex + # Apex = smallest-id boundary vert; fan (apex, src, tgt) skipping apex-incident edges. + apex_per_comp = labels[unique_components] vf_mask = use_vertex_fan_per_comp[edge_comp] if vf_mask.any(): vf_src = kept_src[vf_mask] vf_tgt = kept_tgt[vf_mask] vf_comp = edge_comp[vf_mask] vf_apex = apex_per_comp[vf_comp] - # Skip edges that include the apex (apex==src or apex==tgt → degenerate tri). non_apex = (vf_src != vf_apex) & (vf_tgt != vf_apex) fan_pieces.append(torch.stack([ vf_tgt[non_apex], vf_src[non_apex], vf_apex[non_apex], @@ -1718,9 +1378,7 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal fan_faces = torch.cat(fan_pieces, dim=0) if fan_pieces else torch.empty((0, 3), dtype=faces.dtype, device=device) - # Open chains: close them with a closing triangle ONLY for centroid-fan - # components (vertex-fan chains would need a different closing strategy). - # In practice fill_chains=False makes this a no-op since chains aren't kept. + # Close open chains (centroid-fan only; no-op when fill_chains=False). if fill_chains: vert_degree = torch.zeros(V, dtype=torch.long, device=device) vert_degree.scatter_add_(0, src, torch.ones_like(src)) @@ -1764,21 +1422,17 @@ def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=Fal def weld_vertices_fn(vertices, faces, epsilon=None, epsilon_rel=1e-5, colors=None): - """Merge coincident vertices via L_inf grid quantization. - Ported from custom_nodes/qem_simplify/qem_core.py:_weld_vertices. - - `epsilon`: absolute L_inf distance; verts within this collapse together. - If None, `epsilon_rel * bbox_diag` is used. - Attributes (colors) are averaged across each cluster. - - Returns (new_verts, new_faces, new_colors, n_welded).""" + """Merge coincident vertices via L_inf grid quantization. `epsilon` absolute (None → + epsilon_rel*bbox_diag); colors averaged per cluster. Returns (v, f, colors, n_welded).""" if vertices.ndim == 3: v_out, f_out, c_out = [], [], [] if colors is not None else None total = 0 for i in range(vertices.shape[0]): ci = colors[i] if colors is not None else None v_i, f_i, c_i, n = weld_vertices_fn(vertices[i], faces[i], epsilon, epsilon_rel, ci) - v_out.append(v_i); f_out.append(f_i); total += n + v_out.append(v_i) + f_out.append(f_i) + total += n if c_out is not None: c_out.append(c_i) max_v = max(v.shape[0] for v in v_out) @@ -1833,20 +1487,17 @@ def weld_vertices_fn(vertices, faces, epsilon=None, epsilon_rel=1e-5, colors=Non return new_verts, new_faces, new_colors, int(vertices.shape[0] - n_unique) -def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsilon_rel=1e-5, fill_chains=False, max_verts=16, diagnostic=False): - """Batched wrapper for the v2 GPU hole-filler. CPU tensors get pulled - onto CUDA when available; otherwise fall back to the v1 (CPU walker) fn. - - Pre-welds vertices via `weld_vertices_fn(epsilon_rel=weld_epsilon_rel)` — - boundary detection requires shared edges, which requires welded verts. - Already-welded meshes early-exit cheaply. Set `weld_epsilon_rel=0` to skip.""" +def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsilon_rel=1e-5, fill_chains=False, max_verts=16): + """Batched v2 GPU hole-filler (v1 CPU walker fallback on non-CUDA). Pre-welds verts + first — boundary detection needs shared edges; `weld_epsilon_rel=0` skips it.""" if vertices.ndim == 3: v_list, f_list, c_list = [], [], [] if colors is not None else None pbar = comfy.utils.ProgressBar(vertices.shape[0]) for i in range(vertices.shape[0]): ci = colors[i] if colors is not None else None - v_i, f_i, c_i = fill_holes_v2_fn(vertices[i], faces[i], max_perimeter, ci, weld_epsilon_rel, fill_chains, max_verts, diagnostic) - v_list.append(v_i); f_list.append(f_i) + v_i, f_i, c_i = fill_holes_v2_fn(vertices[i], faces[i], max_perimeter, ci, weld_epsilon_rel, fill_chains, max_verts) + v_list.append(v_i) + f_list.append(f_i) if c_list is not None: c_list.append(c_i) pbar.update(1) @@ -1864,14 +1515,12 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi if faces.numel() == 0: return vertices, faces, colors - # Adaptive weld: a properly welded triangle surface has V/F ≈ 0.5 (closed) - # to ~1.0 (with boundaries). V/F > 1 means most faces still share no verts - # and hole-fill would emit one bogus closing tri per face. We double the - # weld epsilon until V/F < WELDED_THRESHOLD or we hit WELD_CAP. + # Adaptive weld: welded surfaces have V/F ≈ 0.5-1.0; V/F > 1 means unwelded (hole-fill + # would emit a bogus tri per face). Double epsilon until V/F < WELDED_THRESHOLD or WELD_CAP. if weld_epsilon_rel > 0: eps = float(weld_epsilon_rel) - WELD_CAP = 1e-2 # ≈ 10 voxels at 1024-res — aggressive but bounded - WELDED_THRESHOLD = 1.0 # V/F below this is "welded enough" for hole-fill + WELD_CAP = 1e-2 # ≈ 10 voxels at 1024-res + WELDED_THRESHOLD = 1.0 # V/F below this is welded enough total_welded = 0 n_escalations = 0 while True: @@ -1894,39 +1543,16 @@ def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsi f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream " f"(decimate node settings) or run WeldVertices manually with a larger epsilon." ) - # Diag runs AFTER welding so its topology numbers match what the filler sees. - if diagnostic and vertices.device.type == "cuda" and faces.numel() > 0: - _fill_holes_v2_diagnostic(vertices, faces, max_perimeter) if vertices.device.type == "cuda": out_v, out_f, out_c, _ = _fill_holes_v2_gpu(vertices, faces, max_perimeter, colors, fill_chains, max_verts) return out_v, out_f, out_c - # CPU fallback: re-use the v1 walker (no attribute prop, but topologically equivalent - # for manifold boundary; v2 GPU is the path that actually matters for pixal3d output). + # CPU fallback: v1 walker (no attribute prop, but topologically equivalent for manifold boundary). out_v, out_f = fill_holes_fn(vertices, faces, max_perimeter=max_perimeter) return out_v, out_f, colors -def compute_vertex_normals(verts, faces): - """Computes area-weighted vertex normals.""" - # QUICK FIX: Ensure indices are int64 for scatter_add_ - faces_long = faces.to(torch.int64) - - i0, i1, i2 = faces_long[:, 0], faces_long[:, 1], faces_long[:, 2] - v0, v1, v2 = verts[i0], verts[i1], verts[i2] - - # calculate unnormalized face normals (magnitude is proportional to area) - face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) - - # accumulate face normals to vertices - vertex_normals = torch.zeros_like(verts) - vertex_normals.scatter_add_(0, i0.unsqueeze(-1).expand_as(face_normals), face_normals) - vertex_normals.scatter_add_(0, i1.unsqueeze(-1).expand_as(face_normals), face_normals) - vertex_normals.scatter_add_(0, i2.unsqueeze(-1).expand_as(face_normals), face_normals) - - return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6) - def _process_mesh_batch(mesh, per_item_fn): - """Handles list/batched/single mesh dispatching, color extraction, and stacking.""" + """Dispatch list/batched/single mesh, extract colors, stack results.""" mesh = copy.deepcopy(mesh) def process_single(v, f, c, bar): @@ -1978,189 +1604,8 @@ def _process_mesh_batch(mesh, per_item_fn): return IO.NodeOutput(mesh) -def fix_face_orientation(vertices, faces, reference_normals=None): - num_faces = faces.shape[0] - if num_faces == 0: - return faces - - device = faces.device - corrected = faces.clone() - max_vert = vertices.shape[0] - - # Manifold edge adjacency: pair faces that share an edge (run length 2 after - # canonicalizing + sorting edge hashes). - idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device) - edges = corrected[:, idx] # (num_faces, 3, 2) directed - edges_canon = torch.sort(edges, dim=2)[0].view(-1, 2) - edge_hash = edges_canon[:, 0] * max_vert + edges_canon[:, 1] - hash_sorted, sort_idx = torch.sort(edge_hash) - start = torch.cat([torch.ones(1, dtype=torch.bool, device=device), - hash_sorted[1:] != hash_sorted[:-1]]) - unique_starts = torch.nonzero(start, as_tuple=True)[0] - unique_ends = torch.cat([unique_starts[1:], - torch.tensor([hash_sorted.numel()], device=device)]) - manifold_starts = unique_starts[(unique_ends - unique_starts) == 2] - - if manifold_starts.numel() > 0: - f_a = sort_idx[manifold_starts] // 3 - f_b = sort_idx[manifold_starts + 1] // 3 - le_a = sort_idx[manifold_starts] % 3 - le_b = sort_idx[manifold_starts + 1] % 3 - opposite = (edges[f_a, le_a] == edges[f_b, le_b].flip(dims=[1])).all(dim=1) - - # Connected components via scipy (fast C), replacing a per-face Python BFS. - import scipy.sparse - import scipy.sparse.csgraph - fa_np = f_a.cpu().numpy(); fb_np = f_b.cpu().numpy() - graph = scipy.sparse.coo_matrix( - (np.ones(fa_np.shape[0] * 2, dtype=np.int8), - (np.concatenate([fa_np, fb_np]), np.concatenate([fb_np, fa_np]))), - shape=(num_faces, num_faces)) - num_components, comp = scipy.sparse.csgraph.connected_components(graph, directed=False) - component_id = torch.from_numpy(comp.astype(np.int64)).to(device) - - # Within-component consistent winding. A QEM output from a consistently wound - # source is already consistent (every shared edge is traversed oppositely) -> - # no flips needed, the common fast path. Otherwise propagate a parity flip - # across the dual graph by vectorized label relaxation (min-root carrying - # parity), instead of the old per-face CPU BFS. - if not bool(opposite.all()): - nf = ~opposite - src = torch.cat([f_a, f_b]); dst = torch.cat([f_b, f_a]); nfd = torch.cat([nf, nf]) - root = torch.arange(num_faces, device=device) - par = torch.zeros(num_faces, dtype=torch.bool, device=device) - for _ in range(num_faces + 8): # breaks at graph diameter; cap is a backstop - cand_root = root[src]; cand_par = par[src] ^ nfd - new_root = root.clone() - new_root.scatter_reduce_(0, dst, cand_root, reduce='amin', include_self=True) - changed = new_root < root - if not bool(changed.any()): - break - apply = changed[dst] & (cand_root == new_root[dst]) - par[dst[apply]] = cand_par[apply] - root = new_root - if bool(par.any()): - corrected[par] = corrected[par][:, [0, 2, 1]] - else: - component_id = torch.arange(num_faces, device=device) - num_components = num_faces - - v0 = vertices[corrected[:, 0]] - v1 = vertices[corrected[:, 1]] - v2 = vertices[corrected[:, 2]] - - face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) - face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-8) - - if reference_normals is not None: - n0 = reference_normals[corrected[:, 0]] - n1 = reference_normals[corrected[:, 1]] - n2 = reference_normals[corrected[:, 2]] - ref_normals = (n0 + n1 + n2) / 3.0 - ref_normals = ref_normals / (torch.norm(ref_normals, dim=-1, keepdim=True) + 1e-8) - - votes = (face_normals * ref_normals).sum(dim=-1) - - outward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device) - inward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device) - - outward_votes_comp.scatter_add_(0, component_id, (votes > 0).to(torch.int64)) - inward_votes_comp.scatter_add_(0, component_id, (votes < 0).to(torch.int64)) - - n_faces_comp_int = torch.zeros(num_components, dtype=torch.int64, device=device) - n_faces_comp_int.scatter_add_(0, component_id, torch.ones(num_faces, dtype=torch.int64, device=device)) - - thresholds = torch.maximum(torch.ones_like(n_faces_comp_int), n_faces_comp_int // 10) - should_flip_comp = inward_votes_comp > outward_votes_comp + thresholds - else: - # Vectorized 3-Axis Extreme Majority Vote (Geometrically Infallible) - face_centroids = (v0 + v1 + v2) / 3.0 - - votes_by_axis = [] - for axis in range(3): - coords = face_centroids[:, axis] - - # Double stable sort acts as a vectorized lexsort on (coords, component_id) - sort_idx2 = torch.argsort(coords, stable=True) - sort_idx2 = sort_idx2[torch.argsort(component_id[sort_idx2], stable=True)] - - # Find group boundaries to get the extreme outer face along this axis per component - comp_id_sorted = component_id[sort_idx2] - group_ends = torch.nonzero(comp_id_sorted[1:] != comp_id_sorted[:-1], as_tuple=True)[0] - group_ends = torch.cat([group_ends, torch.tensor([len(comp_id_sorted) - 1], device=device)]) - - extreme_face_indices = sort_idx2[group_ends] - extreme_normals = face_normals[extreme_face_indices] - - # Normal's component along the respective axis should be positive - votes_by_axis.append(extreme_normals[:, axis] > 0) - - stacked_votes = torch.stack(votes_by_axis, dim=0) - should_flip_comp = stacked_votes.sum(dim=0) < 2 # False if at least 2 axes agree outward - - should_flip_face = should_flip_comp[component_id] - if should_flip_face.any(): - corrected[should_flip_face] = corrected[should_flip_face][:, [0, 2, 1]] - - return corrected - - -def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4): - is_batched = vertices.ndim == 3 - device = vertices.device - - if is_batched: - B = vertices.shape[0] - F = faces.shape[1] - - # 1. Advanced index broadcast to pull all faces in parallel without any Python loops - batch_idx = torch.arange(B, device=device).view(-1, 1, 1) - v_faces = vertices[batch_idx, faces] # shape (B, F, 3, 3) - - v0, v1, v2 = v_faces[:, :, 0], v_faces[:, :, 1], v_faces[:, :, 2] - - # 2. Compute face normals - fn = torch.cross(v1 - v0, v2 - v0, dim=-1) - fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) - - # 3. Translate directly along the face normals in parallel - offset_verts = v_faces + fn.unsqueeze(2) * z_offset - out_v = offset_verts.reshape(B, -1, 3) - - # 4. Generate identical faces for all batches using constant expansion (O(1)) - f_single = torch.arange(F * 3, device=device).reshape(-1, 3) - out_f = f_single.unsqueeze(0).expand(B, -1, -1) - - if colors is not None: - c_faces = colors[batch_idx, faces] - out_c = c_faces.reshape(B, -1, colors.shape[-1]) - return out_v, out_f, out_c - return out_v, out_f - - # --- Unbatched (Single Mesh) --- - v_faces = vertices[faces] # shape (F, 3, 3) - v0, v1, v2 = v_faces[:, 0], v_faces[:, 1], v_faces[:, 2] - - # Compute face normals - fn = torch.cross(v1 - v0, v2 - v0, dim=-1) - fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) - - # Offset each face's private vertices along its face normal - offset_verts = v_faces + fn.unsqueeze(1) * z_offset - offset_verts = offset_verts.reshape(-1, 3) - - # Generate sequential face indices for the unwelded vertices - f_unwelded = torch.arange(faces.shape[0] * 3, device=vertices.device).reshape(-1, 3) - - if colors is not None: - c_faces = colors[faces] - c_unwelded = c_faces.reshape(-1, colors.shape[-1]) - return offset_verts, f_unwelded, c_unwelded - - return offset_verts, f_unwelded, None - def _fmt_count(n) -> str: - """Compact human-readable integer for node status lines, e.g. 853, 12.3K, 1.23M.""" + """Compact integer for status lines, e.g. 853, 12.3K, 1.23M.""" n = int(n) if n >= 1_000_000: return f"{n / 1_000_000:.2f}".rstrip("0").rstrip(".") + "M" @@ -2179,24 +1624,18 @@ def _fmt_face_change(n_in, n_out) -> str: class DecimateMesh(IO.ComfyNode): @classmethod def define_schema(cls): - # placement_mode picks how the merged vertex is positioned, and which extra - # quality knobs are surfaced (DynamicCombo: the qem sub-widgets only appear - # when 'qem' is selected). + # qem sub-widgets show only when 'qem' is selected (DynamicCombo). placement_options = [ IO.DynamicCombo.Option(key="midpoint", inputs=[]), IO.DynamicCombo.Option(key="qem", inputs=[ IO.Float.Input("line_quadric_weight", default=0.0, min=0.0, max=100.0, step=0.1, - tooltip="Weight of the per-edge line quadric (squared distance to the edge " - "line). Biases collapses to preserve sharp ridges/valleys. 0 = off."), + tooltip="Per-edge line-quadric weight; preserves sharp ridges/valleys. 0 = off."), IO.Float.Input("feature_edge_quadric_weight", default=0.0, min=0.0, max=1000.0, step=1.0, - tooltip="Extra quadric weight on dihedral feature edges (creases). Higher = " - "more aggressively preserves hard edges. 0 = off."), + tooltip="Extra quadric weight on dihedral feature edges (creases). 0 = off."), IO.Float.Input("feature_edge_min_dihedral_deg", default=30.0, min=0.0, max=180.0, step=1.0, - tooltip="Minimum dihedral angle (degrees) for an edge to count as a feature " - "edge for feature_edge_quadric_weight."), + tooltip="Min dihedral angle (deg) to count an edge as a feature edge."), IO.Boolean.Input("clamp_v_to_edge", default=True, - tooltip="Project the QEM-optimal position onto the collapsed edge segment. " - "Prevents inward-cascade drift on curved surfaces."), + tooltip="Project the QEM-optimal position onto the collapsed edge segment."), ]), ] return IO.Schema( @@ -2205,19 +1644,17 @@ class DecimateMesh(IO.ComfyNode): category="latent/3d", description=( "Simplifies a mesh to a target face count using QEM, on the active compute " - "device. 'midpoint' placement uses the cumesh-faithful preset (best quality, " - "preserves thin features / hair). 'qem' places each merged vertex at the QEM " - "optimum and exposes line/feature-edge quadric controls. Output stays welded " - "so it smooth-shades." + "device. 'midpoint' is the cumesh-faithful preset (best quality, preserves thin " + "features / hair); 'qem' places verts at the QEM optimum with line/feature-edge " + "controls. Output stays welded." ), inputs=[ IO.Mesh.Input("mesh"), IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000, - tooltip="Target maximum number of faces. Set to 0 to disable."), + tooltip="Target max faces. 0 disables."), IO.DynamicCombo.Input("placement_mode", options=placement_options, display_name="placement_mode", - tooltip="midpoint: cumesh-faithful preset (recommended). " - "qem: QEM-optimal placement with line/feature-edge controls."), + tooltip="midpoint: cumesh-faithful (recommended). qem: QEM-optimal placement."), ], outputs=[IO.Mesh.Output("mesh")], hidden=[IO.Hidden.unique_id], @@ -2227,7 +1664,7 @@ class DecimateMesh(IO.ComfyNode): def execute(cls, mesh, target_face_count, placement_mode): mode = placement_mode.get("placement_mode", "midpoint") if mode == "qem": - # QEM-optimum placement + ratio driver; everything else inherits the defaults. + # QEM-optimum placement; rest inherit defaults. cfg = QEMConfig( placement_mode="qem", line_quadric_weight=float(placement_mode.get("line_quadric_weight", 0.0)), @@ -2236,10 +1673,9 @@ class DecimateMesh(IO.ComfyNode): clamp_v_to_edge=bool(placement_mode.get("clamp_v_to_edge", True)), ) else: - cfg = QEMConfig() # midpoint placement + threshold driver (the defaults) + cfg = QEMConfig() # midpoint defaults - # ComfyUI passes meshes on CPU; the QEM is ~30x slower there. Run on the - # selected compute device and return on the mesh's original device. + # ComfyUI passes meshes on CPU (QEM much slower there); compute on device, return on original. compute_device = comfy.model_management.get_torch_device() counts = {"in": 0, "out": 0} @@ -2264,7 +1700,7 @@ class DecimateMesh(IO.ComfyNode): result = _process_mesh_batch(mesh, _fn) - # Send progress text to display the face reduction on the node + # Display the face reduction on the node if cls.hidden.unique_id: PromptServer.instance.send_progress_text( _fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id) @@ -2275,29 +1711,23 @@ class DecimateMesh(IO.ComfyNode): class RemeshMesh(IO.ComfyNode): @classmethod def define_schema(cls): - # sign_mode picks the scalar field, and exposes only the knobs relevant to it - # (DynamicCombo: udf sub-widgets show for 'udf', sdf sub-widgets for 'sdf'). + # sub-widgets show per sign_mode (DynamicCombo). sign_mode_options = [ IO.DynamicCombo.Option(key="udf", inputs=[ IO.Boolean.Input("qef", default=False, - tooltip="Experimental: place dual vertices via QEF (closest-triangle normals) " - "instead of edge-crossing centroid. QEF is sign-agnostic so it works " - "in UDF too — pulls the ±eps surface back onto the planes for sharper " - "edges. May misbehave near the UDF double shell; compare with it off."), + tooltip="Experimental: QEF dual-vertex placement for sharper edges; may " + "misbehave near the UDF double shell."), IO.Boolean.Input("drop_inverted_components", default=True, - tooltip="Drop closed components with inward normals (negative signed volume) — " - "the inner shell UDF produces on closed regions."), + tooltip="Drop inward-normal (negative-volume) closed components — the UDF inner shell."), IO.Boolean.Input("drop_enclosed_components", default=True, - tooltip="Drop components whose bbox is inside the largest's AND fail a raycast " - "point-in-mesh test. Disable if you have legitimate parts inside others."), + tooltip="Drop components inside the largest's bbox that fail a point-in-mesh " + "raycast. Disable for legitimate nested parts."), ]), IO.DynamicCombo.Option(key="sdf", inputs=[ IO.Boolean.Input("qef", default=True, - tooltip="Place dual vertices via QEF solve from closest-triangle normals " - "(recovers sharp features) vs edge-crossing centroid."), + tooltip="QEF dual-vertex placement (recovers sharp features) vs edge-crossing centroid."), IO.Boolean.Input("manifold", default=False, - tooltip="Manifold Dual Contouring: emit 1-4 dual verts per voxel for " - "multi-sheet (thin/touching) cases. Slower; guarantees manifold output."), + tooltip="Manifold Dual Contouring: 1-4 dual verts/voxel for multi-sheet cases. Slower."), ]), ] return IO.Schema( @@ -2305,42 +1735,34 @@ class RemeshMesh(IO.ComfyNode): display_name="Remesh Mesh (Narrow-Band DC)", category="latent/3d", description=( - "Re-extracts a uniformly tessellated mesh by sampling a distance field on a " - "narrow-band voxel grid and contouring it with Dual Contouring, on the active " - "compute device. Normalizes topology of messy / non-manifold / self-intersecting " - "input; run before DecimateMesh to hit an exact face count. Output stays welded." + "Re-extracts a uniformly tessellated mesh via a narrow-band distance field + Dual " + "Contouring, on the active compute device. Normalizes messy / non-manifold / " + "self-intersecting topology; run before DecimateMesh to hit an exact face count. " + "Output stays welded." ), inputs=[ IO.Mesh.Input("mesh"), IO.Int.Input("target_faces", default=0, min=0, max=50_000_000, tooltip="0 = use 'resolution'. >0 = auto-pick resolution to roughly hit this " - "face count (±30-50%); usually overshoot then DecimateMesh to exact."), + "count (±30-50%); overshoot then DecimateMesh to exact."), IO.Int.Input("resolution", default=256, min=32, max=1024, - tooltip="Voxel grid resolution (used when target_faces=0). Higher = more detail, " - "slower. 256 ~ 100k faces, 512 ~ 1M."), + tooltip="Voxel grid resolution (when target_faces=0). 256 ~ 100k faces, 512 ~ 1M."), IO.DynamicCombo.Input("sign_mode", options=sign_mode_options, display_name="sign_mode", - tooltip="udf: robust to messy/non-manifold input (double shell cleaned by " - "the inner-shell filters). sdf: clean single surface with optional " - "QEF sharp-feature recovery, but needs consistent winding."), + tooltip="udf: robust to messy/non-manifold input. sdf: clean single " + "surface with QEF sharp-feature recovery, but needs consistent winding."), IO.Float.Input("band", default=1.0, min=0.5, max=4.0, step=0.1, - tooltip="Narrow-band width in voxel units (which voxels are sampled). In UDF " - "mode also offsets the surface by this many voxels."), + tooltip="Narrow-band width in voxel units. In UDF mode also offsets the surface."), IO.Float.Input("project_back", default=0.0, min=0.0, max=1.0, step=0.05, - tooltip="Lerp output verts toward the closest point on the original surface " - "(0 = pure DC, 1 = snapped). Recovers voxelization-lost detail."), + tooltip="Lerp verts toward the original surface (0 = pure DC, 1 = snapped)."), IO.Boolean.Input("fix_poles", default=False, - tooltip="Collapse valence-3 vertex pairs (DC T-junction artifact). Cheap; " - "improves shading and downstream simplification."), + tooltip="Collapse valence-3 vertex pairs (DC T-junction artifact)."), IO.Int.Input("smooth_iters", default=0, min=0, max=20, - tooltip="Taubin λ|μ smoothing iterations (0 = off). Volume-preserving; cleans DC " - "stairstepping. 2-3 is enough; higher rounds off QEF sharp features."), + tooltip="Taubin smoothing iters (0 = off). 2-3 cleans DC stairstepping; higher rounds off QEF edges."), IO.Float.Input("drop_small_components", default=0.01, min=0.0, max=0.5, step=0.005, - tooltip="Drop components with fewer than this fraction of the largest component's " - "faces (inner-shell fragments, noise). 0 disables."), + tooltip="Drop components below this fraction of the largest's face count. 0 disables."), IO.Int.Input("precluster_max_verts", default=0, min=0, max=50_000_000, - tooltip="Safety fallback: if input has more verts than this (>0), cluster-decimate " - "it down first so the distance-field queries don't OOM on huge inputs. " - "0 = off; 1-2M is reasonable for very large meshes."), + tooltip="If input exceeds this (>0), cluster-decimate first so field queries don't " + "OOM. 0 = off; 1-2M for very large meshes."), ], outputs=[IO.Mesh.Output("mesh")], hidden=[IO.Hidden.unique_id], @@ -2351,14 +1773,13 @@ class RemeshMesh(IO.ComfyNode): project_back, fix_poles, smooth_iters, drop_small_components, precluster_max_verts): mode = sign_mode.get("sign_mode", "udf") - # mode-specific sub-widgets (absent ones fall back to defaults) + # mode-specific sub-widgets (absent → defaults) qef = bool(sign_mode.get("qef", True)) manifold = bool(sign_mode.get("manifold", False)) drop_inverted_components = bool(sign_mode.get("drop_inverted_components", True)) drop_enclosed_components = bool(sign_mode.get("drop_enclosed_components", True)) - # ComfyUI passes meshes on CPU; remesh is far faster on GPU. Run on the - # selected compute device and return on the mesh's original device. + # ComfyUI passes meshes on CPU (remesh far faster on GPU); compute on device, return on original. compute_device = comfy.model_management.get_torch_device() counts = {"in": 0, "out": 0} @@ -2370,13 +1791,12 @@ class RemeshMesh(IO.ComfyNode): ff = f.to(compute_device).to(torch.int64) cc = c.to(compute_device).float() if c is not None else None - # safety fallback: cluster-decimate very large inputs before the field queries + # cluster-decimate very large inputs before field queries if precluster_max_verts > 0 and vv.shape[0] > precluster_max_verts: vv, ff, cc = qem_cluster_decimate( vv, ff, target_verts=int(precluster_max_verts), colors=cc) - # Fixed [-0.5,0.5] cube domain (matches cumesh / TRELLIS2). scale ≈ 1.0 - # for any resolution, so this is consistent in target_faces auto mode too. + # Fixed [-0.5,0.5] cube domain (matches cumesh/TRELLIS2); scale ≈ 1.0 any resolution. rs_scale = (resolution + 3.0 * band) / resolution rs_center = torch.zeros(3, dtype=vv.dtype, device=compute_device) @@ -2402,7 +1822,7 @@ class RemeshMesh(IO.ComfyNode): result = _process_mesh_batch(mesh, _fn) - # Send progress text to display the face change on the node + # Display the face change on the node if cls.hidden.unique_id: PromptServer.instance.send_progress_text( _fmt_face_change(counts["in"], counts["out"]), cls.hidden.unique_id) @@ -2411,7 +1831,7 @@ class RemeshMesh(IO.ComfyNode): def _pack_uv_meshes(vs, fs, uvs, colors): - """Pack per-item (verts, faces, uvs[, colors]) into a MESH; stack if single, else pad with counts.""" + """Pack per-item (verts, faces, uvs[, colors]) into a MESH; stack if single, else pad.""" if len(vs) == 1: m = Types.MESH(vertices=vs[0].unsqueeze(0), faces=fs[0].unsqueeze(0), uvs=uvs[0].unsqueeze(0)) if colors is not None: @@ -2440,7 +1860,7 @@ def _pack_uv_meshes(vs, fs, uvs, colors): def _uv_weld_vertices(v, f, weld_distance): - """Merge coincident verts; returns (welded_v, welded_f, welded_to_orig) (last None if no welding).""" + """Merge coincident verts; returns (welded_v, welded_f, welded_to_orig); last None if no welding.""" v_np = v.cpu().numpy() f_np = f.cpu().numpy() if v_np.size == 0: @@ -2467,13 +1887,13 @@ def _uv_weld_vertices(v, f, weld_distance): def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance): - """UV-unwrap a single mesh; returns (vmapping, indices, uvs) — vmapping maps each output + """UV-unwrap a single mesh; returns (vmapping, indices, uvs); vmapping maps each output vertex to an input vertex (seam verts duplicated).""" v_in = positions.to(torch.float32) f_in = indices.to(torch.long).reshape(-1, 3) v_in, f_in, welded_to_orig = _uv_weld_vertices(v_in, f_in, weld_distance) - # drop degenerate faces (repeated index) — they corrupt edge adjacency + # drop degenerate faces (repeated index; corrupt edge adjacency) degen = ((f_in[:, 0] == f_in[:, 1]) | (f_in[:, 1] == f_in[:, 2]) | (f_in[:, 2] == f_in[:, 0])) if bool(degen.any()): f_in = f_in[~degen] @@ -2496,7 +1916,7 @@ def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance n_charts = int(face_chart.max().item()) + 1 if face_chart.numel() else 0 areas_cpu = _uv_mesh.chart_3d_areas(mesh.face_area, face_chart, n_charts).detach().cpu() - # per-chart loop runs on CPU/numpy to avoid per-chart GPU sync + # per-chart loop on CPU/numpy to avoid per-chart GPU sync face_chart_np = face_chart.cpu().numpy() faces_np = mesh.faces.cpu().numpy() vertices_np = mesh.vertices.cpu().numpy() @@ -2536,7 +1956,7 @@ def _uv_unwrap(positions, indices, segmenter, resolution, padding, weld_distance all_chart_uv_areas.append(uv_area_sum) all_chart_faces.append(lf) - # auto-tune texel density to land near `resolution` (assumes ~0.62 pack fill) + # auto-tune texel density toward `resolution` (~0.62 pack fill) total_3d_area = sum(all_chart_3d_areas) or 1.0 target_dim = float(resolution) if resolution > 0 else 1024.0 tex_per_unit = math.sqrt((target_dim * target_dim) * 0.62 / total_3d_area) @@ -2573,23 +1993,22 @@ class UnwrapMesh(IO.ComfyNode): display_name="Unwrap Mesh UVs", category="latent/3d", description=( - "Generates a UV atlas (pure-torch, no xatlas dependency): segments the surface into " - "charts, parameterizes each, and packs them into a [0,1] atlas. Verts on chart seams " - "are duplicated. Run after DecimateMesh/RemeshMesh, before texture baking." + "Generates a UV atlas (pure-torch, no xatlas): segments the surface into charts, " + "parameterizes each, packs into a [0,1] atlas. Seam verts duplicated. Run after " + "DecimateMesh/RemeshMesh, before texture baking." ), inputs=[ IO.Mesh.Input("mesh"), IO.Combo.Input("segmenter", options=["pec", "adaptive"], default="pec", - tooltip="pec: fast parallel-edge-collapse charting (CUDA; falls back to " - "adaptive on CPU). adaptive: CPU charting, slower."), + tooltip="pec: fast parallel-edge-collapse charting (CUDA; CPU falls back to " + "adaptive). adaptive: CPU, slower."), IO.Int.Input("resolution", default=1024, min=0, max=8192, step=256, - tooltip="Target atlas resolution used to auto-scale texel density (0 = fit-to-content)."), + tooltip="Target atlas resolution for texel-density auto-scale (0 = fit-to-content)."), IO.Int.Input("padding", default=1, min=0, max=16, - tooltip="Texel padding between charts in the packed atlas."), + tooltip="Texel padding between charts."), IO.Float.Input("weld_distance", default=0.0, min=0.0, max=1.0, step=0.0001, - tooltip="Merge radius for coincident verts as a fraction of mesh extent " - "(0 = auto, 1e-5). Raise to ~0.001 if you get per-triangle charts " - "(unwelded / triangle-soup input)."), + tooltip="Coincident-vert merge radius as a fraction of mesh extent (0 = auto). " + "Raise to ~0.001 if you get per-triangle charts (unwelded input)."), ], outputs=[IO.Mesh.Output("mesh")], hidden=[IO.Hidden.unique_id], @@ -2652,7 +2071,7 @@ class UnwrapMesh(IO.ComfyNode): def _uv_sorted_edge_keys(indices: np.ndarray): - """Undirected edge keys per face-edge, sorted; returns (sorted_keys, face_id, lo, hi, first_mask).""" + """Sorted undirected edge keys; returns (sorted_keys, face_id, lo, hi, first_mask).""" a = indices.ravel().astype(np.int64) b = np.roll(indices, -1, axis=1).ravel().astype(np.int64) lo = np.minimum(a, b) @@ -2668,7 +2087,7 @@ def _uv_sorted_edge_keys(indices: np.ndarray): def _uv_faces_to_chart_ids(indices: np.ndarray) -> np.ndarray: - """Chart = connected component of faces adjacent iff they share a (non-seam-duplicated) UV vertex.""" + """Chart = connected component of faces sharing a (non-seam-duplicated) UV vertex.""" F = indices.shape[0] if F == 0: return np.empty(0, dtype=np.int64) @@ -2709,15 +2128,14 @@ def _uv_palette(n: int) -> np.ndarray: def _uv_render_atlas(uvs_np, indices_np, resolution, device, bg=(0.13, 0.13, 0.13), edge=(0.0, 0.0, 0.0)): - """Tile-based torch rasterizer of the UV atlas (charts colored, borders outlined); returns (H,W,3).""" + """Tile-based torch rasterizer of the UV atlas (charts colored, borders outlined); (H,W,3).""" w = h = int(resolution) chart_ids_np = _uv_faces_to_chart_ids(indices_np) uvs = torch.from_numpy(uvs_np).to(device=device, dtype=torch.float32) indices = torch.from_numpy(indices_np).to(device=device, dtype=torch.long) chart_ids = torch.from_numpy(chart_ids_np).to(device=device, dtype=torch.long) - img = torch.zeros((h, w, 3), dtype=torch.float32, device=device) - img[..., 0] = bg[0]; img[..., 1] = bg[1]; img[..., 2] = bg[2] + img = torch.tensor(bg, dtype=torch.float32, device=device).expand(h, w, 3).contiguous() if indices.numel() == 0: return img @@ -2729,9 +2147,12 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, uv_px[:, 1] = uv_px[:, 1].clamp(0.0, 1.0) * (h - 1) tri = uv_px[indices] - x0 = tri[:, 0, 0]; y0 = tri[:, 0, 1] - x1 = tri[:, 1, 0]; y1 = tri[:, 1, 1] - x2 = tri[:, 2, 0]; y2 = tri[:, 2, 1] + x0 = tri[:, 0, 0] + y0 = tri[:, 0, 1] + x1 = tri[:, 1, 0] + y1 = tri[:, 1, 1] + x2 = tri[:, 2, 0] + y2 = tri[:, 2, 1] denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2) nondegen = denom.abs() > 1e-20 @@ -2740,7 +2161,7 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, h - 1).long() ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, h - 1).long() - # full point-in-triangle over all (pixel, tri) pairs is O(H*W*F); tile and test only bbox-overlapping tris + # full point-in-tri over all pairs is O(H*W*F); tile and test only bbox-overlapping tris TILE = 64 eps = 1e-6 for ty in range(0, h, TILE): @@ -2755,9 +2176,12 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, ys = torch.arange(ty, ty_end, dtype=torch.float32, device=device) + 0.5 xs = torch.arange(tx, tx_end, dtype=torch.float32, device=device) + 0.5 yy, xx = torch.meshgrid(ys, xs, indexing="ij") - sub_x0 = x0[idx][:, None, None]; sub_y0 = y0[idx][:, None, None] - sub_x1 = x1[idx][:, None, None]; sub_y1 = y1[idx][:, None, None] - sub_x2 = x2[idx][:, None, None]; sub_y2 = y2[idx][:, None, None] + sub_x0 = x0[idx][:, None, None] + sub_y0 = y0[idx][:, None, None] + sub_x1 = x1[idx][:, None, None] + sub_y1 = y1[idx][:, None, None] + sub_x2 = x2[idx][:, None, None] + sub_y2 = y2[idx][:, None, None] sub_den = denom[idx][:, None, None] bx = ((sub_y1 - sub_y2) * (xx - sub_x2) + (sub_x2 - sub_x1) * (yy - sub_y2)) / sub_den by = ((sub_y2 - sub_y0) * (xx - sub_x2) + (sub_x0 - sub_x2) * (yy - sub_y2)) / sub_den @@ -2772,7 +2196,7 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, tile_img[hit_any] = tile_color[hit_any] img[ty:ty_end, tx:tx_end] = tile_img - # chart outlines: a chart border is an open boundary in UV space (seam verts duplicated) → edges with 1 incident face + # chart outlines: UV-space borders are open boundaries (edges with 1 incident face) _sk, _fid, lo, hi, first = _uv_sorted_edge_keys(indices_np) starts = np.nonzero(first)[0] counts = np.diff(np.append(starts, first.size)) @@ -2780,7 +2204,8 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, uv_cpu = uv_px.cpu().numpy() px_xs, px_ys = [], [] for a, b in zip(lo[starts[boundary]], hi[starts[boundary]]): - p0 = uv_cpu[a]; p1 = uv_cpu[b] + p0 = uv_cpu[a] + p1 = uv_cpu[b] steps = int(max(abs(p1[0] - p0[0]), abs(p1[1] - p0[1])) + 1) if steps <= 1: continue @@ -2788,7 +2213,8 @@ def _uv_render_atlas(uvs_np, indices_np, resolution, device, xs = (p0[0] + (p1[0] - p0[0]) * ts).astype(np.int32) ys = (p0[1] + (p1[1] - p0[1]) * ts).astype(np.int32) valid = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h) - px_xs.append(xs[valid]); px_ys.append(ys[valid]) + px_xs.append(xs[valid]) + px_ys.append(ys[valid]) if px_xs: xs_all = torch.from_numpy(np.concatenate(px_xs)).to(device=device, dtype=torch.long) ys_all = torch.from_numpy(np.concatenate(px_ys)).to(device=device, dtype=torch.long) @@ -2804,8 +2230,8 @@ class RenderUVAtlas(IO.ComfyNode): node_id="RenderUVAtlas", display_name="Render UV Atlas", category="latent/3d", - description=("Renders a mesh's UV layout as an image — each chart a distinct color, " - "outlined where it borders other charts. Run UnwrapMesh first."), + description=("Renders a mesh's UV layout as an image (each chart a distinct color, " + "outlined at borders). Run UnwrapMesh first."), inputs=[ IO.Mesh.Input("mesh"), IO.Int.Input("resolution", default=1024, min=64, max=4096, step=64), @@ -2841,46 +2267,28 @@ class FillHoles(IO.ComfyNode): display_name="Fill Holes", category="latent/3d", description=( - "Fills holes in a mesh up to a maximum perimeter threshold, preserving " - "the existing geometry/UVs (only patch triangles are added). GPU-vectorised " - "via directed-half-edge pointer-doubling: no Python loop, auto-correct " - "winding from face direction, and centroid colors are averaged from the hole " - "loop. Falls back to a CPU walker on non-CUDA devices." + "Fills holes up to a max perimeter, preserving existing geometry/UVs (only patch " + "tris added). GPU-vectorised with auto-corrected winding and loop-averaged centroid " + "colors; CPU walker fallback on non-CUDA." ), inputs=[ IO.Mesh.Input("mesh"), IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001, - tooltip="Maximum hole perimeter to fill. Set to 0 to disable."), + tooltip="Max hole perimeter to fill. 0 disables."), IO.Float.Input("weld_epsilon_rel", default=1e-5, min=0.0, step=1e-6, - tooltip=( - "Pre-weld tolerance as a fraction of the bbox diagonal. " - "Boundary detection needs welded verts; already-welded meshes " - "early-exit free. Set to 0 to skip pre-weld." - )), + tooltip="Pre-weld tolerance (fraction of bbox diagonal); boundary detection " + "needs welded verts. 0 skips."), IO.Int.Input("max_verts", default=16, min=3, max=1024, - tooltip=( - "Cap the boundary-vertex count per cycle. Fan-from-centroid " - "only triangulates correctly for small, near-planar holes — " - "larger cycles produce overlapping geometry because the centroid " - "lands far from any surface. Keep low (≤16) for clean fills." - )), + tooltip="Cap boundary verts per cycle; centroid-fan only works for small " + "near-planar holes. Keep ≤16."), IO.Boolean.Input("fill_chains", default=False, - tooltip=( - "Also fill open boundary chains (not just closed cycles) " - "by closing them with a fan + closing triangle. " - "Often produces noisy/overlapping geometry on real meshes " - "because chains are usually genuine surface boundaries or " - "fragments of cycles broken by non-manifold edges. Leave OFF " - "to match cumesh/upstream behavior." - )), - IO.Boolean.Input("verbose", default=False, - tooltip="Log topology diagnostics (edge counts, cycles found, reject reasons) for debugging."), + tooltip="Also fill open chains (not just cycles). Noisy; OFF matches cumesh."), ], outputs=[IO.Mesh.Output("mesh")], ) @classmethod - def execute(cls, mesh, max_perimeter, weld_epsilon_rel, max_verts, fill_chains, verbose): + def execute(cls, mesh, max_perimeter, weld_epsilon_rel, max_verts, fill_chains): def _fn(v, f, c): if max_perimeter > 0: v, f, c = fill_holes_v2_fn( @@ -2888,7 +2296,6 @@ class FillHoles(IO.ComfyNode): weld_epsilon_rel=weld_epsilon_rel, fill_chains=fill_chains, max_verts=max_verts, - diagnostic=verbose, ) return v, f, c return _process_mesh_batch(mesh, _fn) @@ -2902,16 +2309,14 @@ class WeldVertices(IO.ComfyNode): display_name="Weld Vertices", category="latent/3d", description=( - "Merge coincident vertices via L_inf grid quantization. Use when a " - "mesh comes in unwelded (every face has its own 3 verts, no shared edges) " - "— pre-pass before FillHoles, DecimateMesh, or any topology-aware op. " - "Per-vertex colors are averaged across each merged cluster." + "Merge coincident vertices via L_inf grid quantization. Use when a mesh comes in " + "unwelded (per-face verts, no shared edges) — pre-pass before FillHoles, " + "DecimateMesh, or any topology-aware op. Colors averaged per cluster." ), inputs=[ IO.Mesh.Input("mesh"), IO.Float.Input("epsilon_rel", default=1e-5, min=0.0, step=1e-6, - tooltip="Weld tolerance as a fraction of the bbox diagonal. " - "1e-5 is enough for floating-point dedup; raise to " + tooltip="Weld tolerance (fraction of bbox diagonal). 1e-5 for float dedup; " "1e-3 for visibly-close-but-distinct verts."), IO.Float.Input("epsilon_abs", default=0.0, min=0.0, step=1e-6, tooltip="Absolute weld tolerance (overrides epsilon_rel when > 0)."), @@ -2931,15 +2336,9 @@ class WeldVertices(IO.ComfyNode): def merge_meshes(meshes): - """Concatenate a list of Types.MESH into a single (B=1) mesh. - - Vertices, faces (with cumulative index offset), uvs, and vertex_colors are - concatenated. If only some inputs carry uvs/vertex_colors, the missing sides - are padded — zeros for uvs, white (1.0) for vertex_colors — so the merged - primitive has a uniform attribute set. Texture is taken from the first input - that has one; later textures are dropped with a warning (single-primitive glb - can't carry multiple texture atlases without baking). - """ + """Concatenate Types.MESH list into one (B=1) mesh: cumulative face-index offset, + missing uvs/colors padded (zeros/white), texture from the first input that has one + (later dropped — single-primitive glb can't carry multiple atlases).""" if not meshes: raise ValueError("merge_meshes: need at least one mesh") @@ -2953,8 +2352,7 @@ def merge_meshes(meshes): texture = None offset = 0 for m in meshes: - # Mesh tensors are normalized to CPU by our producer nodes; coerce defensively - # so MoGe-side meshes (which may land on CUDA) merge cleanly with our outputs. + # Coerce to CPU so CUDA-side (MoGe) meshes merge cleanly with our outputs. v = _b0(m.vertices).cpu() f = _b0(m.faces).cpu() verts_list.append(v) @@ -3002,10 +2400,9 @@ class MergeMeshes(IO.ComfyNode): display_name="Merge Meshes", category="latent/3d", description=( - "Concatenate N meshes into a single mesh by offsetting face indices " - "and stacking vertices, faces, uvs, and vertex colors. Useful for combining a " - "Pixal3D-reconstructed object (via Pixal3DAlignObject) with a MoGe scene " - "background (via MoGePointMapToMesh) into one GLB." + "Concatenate N meshes into one by offsetting face indices and stacking verts, " + "faces, uvs, and colors. E.g. combine a Pixal3D object with a MoGe background " + "into one GLB." ), inputs=[ IO.Autogrow.Input("meshes", template=autogrow_template),