From 1697da460b1df3e91a7eafd2ec881908fdf7251e Mon Sep 17 00:00:00 2001 From: kijai Date: Wed, 10 Jun 2026 10:30:41 +0300 Subject: [PATCH] PBR baking --- comfy_api/latest/_util/geometry_types.py | 5 +- comfy_extras/nodes_mesh_postprocess.py | 1195 +++++++++++++++++++++- comfy_extras/nodes_save_3d.py | 43 +- comfy_extras/nodes_trellis2.py | 181 +++- 4 files changed, 1398 insertions(+), 26 deletions(-) diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index b821fd620..93f697119 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -17,6 +17,7 @@ class MESH: uvs: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None, texture: torch.Tensor | None = None, + metallic_roughness: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, face_counts: torch.Tensor | None = None): @@ -26,7 +27,9 @@ class MESH: self.faces = faces # faces: (B, M, 3) self.uvs = uvs # uvs: (B, N, 2) self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4) - self.texture = texture # texture: (B, H, W, 3) + self.texture = texture # texture (baseColor): (B, H, W, 3) + # glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic + self.metallic_roughness = metallic_roughness # When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch), # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. self.vertex_counts = vertex_counts diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 9cc77a206..5e4955695 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -81,6 +81,10 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] + # Voxel field may carry the full PBR set (base_color, metallic, roughness, + # alpha); vertex colors only use base_color RGB. + if v_colors.shape[-1] > 3: + v_colors = v_colors[:, :3] # to [0, 1] srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1) @@ -158,6 +162,554 @@ class PaintMesh(IO.ComfyNode): out_mesh = paint_mesh_with_voxels(mesh, coords, colors, resolution=resolution) return IO.NodeOutput(out_mesh) + +# ============================================================================= +# Texture baking from sparse voxel volume. +# +# Pipeline: xatlas UV unwrap → OpenGL UV-space rasterize to position map → +# nearest-voxel color sample per texel → cv2.inpaint to fill UV seams → +# attach texture + UVs to the Mesh for SaveGLB to serialize. +# +# Uses comfy_extras.nodes_glsl.GLContext for OpenGL context (already handles +# GLFW / EGL / OSMesa backend selection); xatlas for UV parameterization. +# ============================================================================= + +_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache" + + +def _gl_compile_program(gl, vert_src: str, frag_src: str): + """Compile and link a minimal vert+frag GL program. Caller owns the GLuint + and must glDeleteProgram when done.""" + def _check_shader(s, kind): + if not gl.glGetShaderiv(s, gl.GL_COMPILE_STATUS): + log = gl.glGetShaderInfoLog(s).decode(errors="replace") + gl.glDeleteShader(s) + raise RuntimeError(f"GL {kind} shader compile failed: {log}") + + vs = gl.glCreateShader(gl.GL_VERTEX_SHADER) + gl.glShaderSource(vs, vert_src) + gl.glCompileShader(vs) + _check_shader(vs, "vertex") + fs = gl.glCreateShader(gl.GL_FRAGMENT_SHADER) + gl.glShaderSource(fs, frag_src) + gl.glCompileShader(fs) + _check_shader(fs, "fragment") + prog = gl.glCreateProgram() + gl.glAttachShader(prog, vs) + gl.glAttachShader(prog, fs) + gl.glLinkProgram(prog) + gl.glDeleteShader(vs) + gl.glDeleteShader(fs) + if not gl.glGetProgramiv(prog, gl.GL_LINK_STATUS): + log = gl.glGetProgramInfoLog(prog).decode(errors="replace") + gl.glDeleteProgram(prog) + raise RuntimeError(f"GL program link failed: {log}") + return prog + + +# Position-passthrough shader. Vertex maps UV → clip space; fragment outputs the +# interpolated world-space vertex position (with alpha=1 marking valid texels). +_BAKE_VERT_SRC = """ +#version 330 core +layout (location = 0) in vec3 a_pos; +layout (location = 1) in vec2 a_uv; +out vec3 v_pos; +void main() { + v_pos = a_pos; + gl_Position = vec4(a_uv * 2.0 - 1.0, 0.0, 1.0); +} +""" + +_BAKE_FRAG_SRC = """ +#version 330 core +in vec3 v_pos; +out vec4 frag_color; +void main() { + frag_color = vec4(v_pos, 1.0); +} +""" + + +def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): + """Rasterize unwrapped mesh in UV space; return (position_map, mask). + position_map: (H, W, 3) float32 — interpolated 3D position per texel. + mask: (H, W) bool — valid (covered) texels. + + Uses comfy_extras.nodes_glsl.GLContext, which lazily picks GLFW/EGL/OSMesa.""" + from comfy_extras.nodes_glsl import GLContext, _import_opengl + + GLContext() # ensure backend is initialized + current + gl = _import_opengl() + + # PyOpenGL's high-level wrappers for the buffer/draw/readback functions + # store array refs in OpenGL.contextdata, which on EGL contexts triggers + # "Attempt to retrieve context when no valid context". Use the raw C + # entry points (OpenGL.raw.*) instead — they skip the bookkeeping. + import ctypes as _ctypes + from OpenGL.raw.GL.VERSION.GL_1_1 import ( + glReadPixels as _raw_glReadPixels, + glTexImage2D as _raw_glTexImage2D, + glDrawElements as _raw_glDrawElements, + ) + from OpenGL.raw.GL.VERSION.GL_1_5 import glBufferData as _raw_glBufferData + from OpenGL.raw.GL.VERSION.GL_2_0 import glVertexAttribPointer as _raw_glVertexAttribPointer + + H = W = int(texture_size) + fbo = color_tex = vbo = ibo = vao = prog = None + try: + # Interleaved [pos.x, pos.y, pos.z, uv.x, uv.y] per vertex (stride=20 bytes). + verts32 = np.ascontiguousarray(verts_np, dtype=np.float32) + uvs32 = np.ascontiguousarray(uvs_np, dtype=np.float32) + faces32 = np.ascontiguousarray(faces_np, dtype=np.uint32) + vbo_data = np.ascontiguousarray(np.concatenate([verts32, uvs32], axis=-1), dtype=np.float32) + + prog = _gl_compile_program(gl, _BAKE_VERT_SRC, _BAKE_FRAG_SRC) + + vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(vao) + + vbo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo) + _raw_glBufferData(gl.GL_ARRAY_BUFFER, int(vbo_data.nbytes), + vbo_data.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW) + + ibo = gl.glGenBuffers(1) + gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ibo) + _raw_glBufferData(gl.GL_ELEMENT_ARRAY_BUFFER, int(faces32.nbytes), + faces32.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW) + + gl.glEnableVertexAttribArray(0) + _raw_glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(0)) + gl.glEnableVertexAttribArray(1) + _raw_glVertexAttribPointer(1, 2, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(12)) + + fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + color_tex = gl.glGenTextures(1) + gl.glBindTexture(gl.GL_TEXTURE_2D, color_tex) + _raw_glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, W, H, + 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_NEAREST) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, + gl.GL_TEXTURE_2D, color_tex, 0) + status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) + if status != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError(f"FBO incomplete (status=0x{status:x})") + + gl.glViewport(0, 0, W, H) + gl.glClearColor(0.0, 0.0, 0.0, 0.0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDisable(gl.GL_CULL_FACE) + gl.glDisable(gl.GL_DEPTH_TEST) + gl.glUseProgram(prog) + _raw_glDrawElements(gl.GL_TRIANGLES, int(faces32.size), gl.GL_UNSIGNED_INT, None) + gl.glFinish() + + # Pre-allocate readback buffer and pass it as a pointer so PyOpenGL + # doesn't try to allocate one through its array-handler machinery. + arr = np.empty((H, W, 4), dtype=np.float32) + _raw_glReadPixels(0, 0, W, H, gl.GL_RGBA, gl.GL_FLOAT, + arr.ctypes.data_as(_ctypes.c_void_p)) + # Do NOT flipud here. Our shader places UV(0,0) at FBO bottom-left + # (clip(-1,-1)), and glReadPixels returns bottom-row-first, so arr[0] + # already holds the UV v=0 data. glTF samples PNG with row 0 = upper-left + # = UV v=0, so storing arr as-is gives a consistent mapping. Flipping + # would invert V and make every sample come from the wrong row. + position_map = arr[..., :3] + mask = arr[..., 3] > 0.5 + return position_map, mask + finally: + if fbo is not None: + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + gl.glDeleteFramebuffers(1, [fbo]) + if color_tex is not None: + gl.glDeleteTextures(1, [color_tex]) + if vbo is not None: + gl.glDeleteBuffers(1, [vbo]) + if ibo is not None: + gl.glDeleteBuffers(1, [ibo]) + if vao is not None: + gl.glBindVertexArray(0) + gl.glDeleteVertexArrays(1, [vao]) + if prog is not None: + gl.glUseProgram(0) + gl.glDeleteProgram(prog) + + +def _sample_voxel_attrs_per_texel(position_map, mask, voxel_coords, voxel_colors, resolution): + """For every masked texel, query the nearest voxel and return ALL its + attribute channels. Returns (H, W, C) float32 in [0, 1] where C is the + voxel feature width (3 for plain color, 6 for full PBR).""" + H, W, _ = position_map.shape + color_np = voxel_colors.detach().cpu().numpy().astype(np.float32) + C = color_np.shape[-1] + out = np.zeros((H, W, C), dtype=np.float32) + if not mask.any(): + return out + + origin = np.array([-0.5, -0.5, -0.5], dtype=np.float32) + voxel_size = 1.0 / float(resolution) + voxel_pos = voxel_coords.detach().cpu().numpy().astype(np.float32) * voxel_size + origin + + tree = scipy.spatial.cKDTree(voxel_pos) + valid_positions = position_map[mask] + _, nearest_idx = tree.query(valid_positions, k=1, workers=-1) + out[mask] = np.clip(color_np[nearest_idx], 0.0, 1.0) + return out + + +def bake_texture_from_voxel_fn(vertices, faces, voxel_coords, voxel_colors, + resolution, texture_size, inpaint_radius=3, + fast_unwrap=True, existing_uvs=None, + normalize_uvs=True): + """Bake a baseColor (+ optional metallicRoughness) texture for + `vertices/faces`, rasterizing in UV space and nearest-voxel-sampling each + texel from the provided sparse colored voxel volume. + + If `existing_uvs` (N, 2) is given, it is used directly and xatlas is + skipped — bakes onto the mesh's current UV layout without re-unwrapping. + Otherwise xatlas computes a fresh atlas (verts/faces may grow at seams). + + Returns (out_vertices, out_faces, out_uvs, out_texture, out_mr). + + `fast_unwrap=True` configures xatlas with permissive chart options so it + finishes in a reasonable time on large meshes — at the cost of less even + UV distribution. Set False to use xatlas defaults (slow on >100k faces).""" + import time + + v_np = vertices.detach().cpu().numpy().astype(np.float32) + f_np = faces.detach().cpu().numpy().astype(np.uint32) + fcount = int(f_np.shape[0]) + t0 = time.perf_counter() + + if existing_uvs is not None: + # Bake onto the mesh's current UVs — no xatlas, no seam-splitting. + uv_np = existing_uvs.detach().cpu().numpy().astype(np.float32) + if uv_np.shape[0] != v_np.shape[0]: + raise ValueError( + f"BakeTextureFromVoxel: existing UVs ({uv_np.shape[0]}) must be 1:1 " + f"with vertices ({v_np.shape[0]})." + ) + uv_min = uv_np.min(axis=0) + uv_max = uv_np.max(axis=0) + oob = int(((uv_np < 0.0) | (uv_np > 1.0)).any(axis=1).sum()) + logging.info(f"[BakeTextureFromVoxel] using existing UVs: {v_np.shape[0]} verts, " + f"{fcount} faces (xatlas skipped)") + logging.info(f"[BakeTextureFromVoxel] UV range: u[{uv_min[0]:.3f},{uv_max[0]:.3f}] " + f"v[{uv_min[1]:.3f},{uv_max[1]:.3f}] out-of-[0,1] verts: {oob}/{uv_np.shape[0]}") + out_of_unit = (uv_min.min() < -1e-4) or (uv_max.max() > 1.0001) + if normalize_uvs and out_of_unit: + # Uniform fit of the UV bbox into [0,1] (preserves chart aspect ratios). + # Handles packers that overflow the unit square slightly. NOT a UDIM + # de-tiler — a true multi-tile layout would get squashed; warn if the + # span is large enough to look like tiling. + extent = float((uv_max - uv_min).max()) + span = max(float(uv_max[0] - uv_min[0]), float(uv_max[1] - uv_min[1])) + if span > 1.5: + logging.warning( + f"[BakeTextureFromVoxel] UV span {span:.2f} looks like a tiled/UDIM " + f"layout; uniform-fitting it into [0,1] will overlap tiles. " + f"Re-unwrap instead (use_existing_uvs=False)." + ) + if extent > 0: + uv_np = ((uv_np - uv_min) / extent).astype(np.float32) + logging.info(f"[BakeTextureFromVoxel] normalized UVs into [0,1] " + f"(uniform scale 1/{extent:.4f})") + new_verts, new_faces, new_uvs = v_np, f_np, uv_np + else: + import xatlas + if fcount > 300_000: + logging.warning( + f"[BakeTextureFromVoxel] mesh has {fcount} faces — xatlas chart " + f"decomposition is CPU-bound and may take many minutes. Consider " + f"decimating to under ~200k faces before baking." + ) + logging.info(f"[BakeTextureFromVoxel] xatlas unwrap: {v_np.shape[0]} verts, {fcount} faces") + if fast_unwrap and hasattr(xatlas, "Atlas"): + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + logging.info(f"[BakeTextureFromVoxel] add_mesh: {time.perf_counter() - t0:.1f}s") + gen_kwargs = {} + applied = [] + # ChartOptions: looser growth → larger / fewer charts → faster. + if hasattr(xatlas, "ChartOptions"): + co = xatlas.ChartOptions() + for attr, val in ( + ("max_iterations", 1), + ("max_cost", 8.0), + ("normal_deviation_weight", 1.0), + ("roundness_weight", 0.0), + ("straightness_weight", 0.0), + ("normal_seam_weight", 1.0), + ("texture_seam_weight", 0.0), + ("use_input_mesh_uvs", False), + ): + if hasattr(co, attr): + setattr(co, attr, val) + applied.append(f"chart.{attr}") + gen_kwargs["chart_options"] = co + # PackOptions.bruteForce defaults to True — tries many rotations per + # chart and is the single biggest contributor to pack time on small + # meshes. Off it loses ~5-15% packing efficiency but runs ~5× faster. + if hasattr(xatlas, "PackOptions"): + po = xatlas.PackOptions() + for attr, val in ( + ("bruteForce", False), + ("brute_force", False), # snake_case alias on some builds + ("create_image", False), + ("createImage", False), + ("padding", 2), + ): + if hasattr(po, attr): + setattr(po, attr, val) + applied.append(f"pack.{attr}") + gen_kwargs["pack_options"] = po + logging.info(f"[BakeTextureFromVoxel] options applied: {applied}") + tgen = time.perf_counter() + try: + atlas.generate(**gen_kwargs) + except TypeError as e: + logging.warning(f"[BakeTextureFromVoxel] generate(**kwargs) rejected ({e}); retrying with defaults") + atlas.generate() + logging.info(f"[BakeTextureFromVoxel] generate: {time.perf_counter() - tgen:.1f}s") + tget = time.perf_counter() + vmapping, indices, uvs = atlas[0] + logging.info(f"[BakeTextureFromVoxel] retrieve: {time.perf_counter() - tget:.1f}s") + else: + vmapping, indices, uvs = xatlas.parametrize(v_np, f_np) + + logging.info(f"[BakeTextureFromVoxel] xatlas total {time.perf_counter() - t0:.1f}s " + f"({vmapping.shape[0]} verts after seams)") + new_verts = v_np[vmapping] + new_faces = indices.astype(np.uint32) + new_uvs = uvs.astype(np.float32) + + t1 = time.perf_counter() + position_map, mask = _bake_position_map(new_verts, new_faces, new_uvs, texture_size) + logging.info(f"[BakeTextureFromVoxel] GL rasterize {texture_size}² in {time.perf_counter() - t1:.1f}s " + f"({int(mask.sum())}/{mask.size} texels covered)") + + t2 = time.perf_counter() + attrs = _sample_voxel_attrs_per_texel( + position_map, mask, voxel_coords, voxel_colors, resolution, + ) + logging.info(f"[BakeTextureFromVoxel] voxel sample in {time.perf_counter() - t2:.1f}s " + f"({attrs.shape[-1]} channels)") + + # Split into PBR maps. Layout matches upstream pbr_attr_layout: + # 0:3 base_color, 3 metallic, 4 roughness, 5 alpha. + C = attrs.shape[-1] + base_color = attrs[..., 0:3] + has_pbr = C >= 5 + metallic = attrs[..., 3:4] if C >= 4 else None + roughness = attrs[..., 4:5] if C >= 5 else None + # alpha channel exists at index 5 but we keep meshes opaque (upstream uses + # alpha_mode=OPAQUE in the remesh path); plumb later if needed. + + def _inpaint(img01, n_ch): + if inpaint_radius <= 0: + return img01 + import cv2 + u8 = (img01 * 255.0).clip(0, 255).astype(np.uint8) + mask_inv = ((~mask).astype(np.uint8)) * 255 + u8 = cv2.inpaint(u8, mask_inv, int(inpaint_radius), cv2.INPAINT_TELEA) + if u8.ndim == 2: + u8 = u8[..., None] + return u8.astype(np.float32) / 255.0 + + t3 = time.perf_counter() + base_color = _inpaint(np.ascontiguousarray(base_color), 3) + mr_image = None + if has_pbr: + # glTF metallicRoughness: R unused, G=roughness, B=metallic. + mr = np.concatenate([np.zeros_like(roughness), roughness, metallic], axis=-1) + mr_image = _inpaint(np.ascontiguousarray(mr), 3) + if inpaint_radius > 0: + logging.info(f"[BakeTextureFromVoxel] inpaint in {time.perf_counter() - t3:.1f}s") + + device = vertices.device + out_v = torch.from_numpy(new_verts).to(device=device, dtype=torch.float32) + out_f = torch.from_numpy(new_faces.astype(np.int32)).to(device=device, dtype=torch.int32) + out_uvs = torch.from_numpy(new_uvs).to(device=device, dtype=torch.float32) + out_tex = torch.from_numpy(np.ascontiguousarray(base_color)).to(device=device, dtype=torch.float32) + out_mr = (torch.from_numpy(np.ascontiguousarray(mr_image)).to(device=device, dtype=torch.float32) + if mr_image is not None else None) + return out_v, out_f, out_uvs, out_tex, out_mr + + +class BakeTextureFromVoxel(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BakeTextureFromVoxel", + display_name="Bake Texture From Voxel", + category="latent/3d", + description=( + "Unwraps the mesh with xatlas, rasterizes it in UV space via OpenGL " + "(using ComfyUI's existing PyOpenGL backend), and bakes PBR textures " + "by nearest-voxel sampling of the input sparse voxel volume. Produces " + "a baseColor texture, plus a metallicRoughness texture when the voxel " + "field carries the full PBR set (6 channels). Returns a Mesh with `uvs`, " + "`texture`, and `metallic_roughness` attached — SaveGLB serializes them " + "as real baseColorTexture / metallicRoughnessTexture maps." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Voxel.Input("voxel_colors"), + IO.Int.Input("texture_size", default=1024, min=64, max=8192, + tooltip="Square texture resolution. Larger = sharper but slower / bigger file."), + IO.Int.Input("inpaint_radius", default=3, min=0, max=32, + tooltip="OpenCV inpaint radius for filling UV seam gutters. 0 disables."), + IO.Boolean.Input("fast_unwrap", default=True, + tooltip=( + "Use looser xatlas chart options to finish unwrap " + "much faster on large meshes (cost: less even UV " + "distribution). Off uses xatlas defaults, which can " + "take many minutes on >100k-face meshes." + )), + IO.Boolean.Input("use_existing_uvs", default=False, + tooltip=( + "Bake onto the mesh's existing UV layout instead of " + "re-unwrapping with xatlas. Requires the input mesh to " + "already carry UVs (e.g. from TorchXatlasUVWrap or a " + "retopologized mesh). Much faster and preserves your " + "UV layout. Ignored if the mesh has no UVs." + )), + IO.Boolean.Input("normalize_uvs", default=True, + tooltip=( + "When using existing UVs that spill outside [0,1] " + "(common with packers that overflow the unit square), " + "uniformly rescale them to fit. Without this, out-of-range " + "regions are clipped and don't bake. Disable only if your " + "UVs are already exactly in [0,1]." + )), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh, voxel_colors, texture_size, inpaint_radius, fast_unwrap, use_existing_uvs, normalize_uvs): + voxels = voxel_colors + coords = voxels.data + colors = voxels.voxel_colors + resolution = voxels.resolution + mesh_uvs = getattr(mesh, "uvs", None) + if use_existing_uvs and mesh_uvs is None: + logging.warning("BakeTextureFromVoxel: use_existing_uvs=True but mesh has no UVs; " + "falling back to xatlas unwrap.") + + if coords.shape[-1] == 4: + # Sparse coords have a batch column; bake per-item. + batch_idx = coords[:, 0].long() + voxel_xyz = coords[:, 1:] + mesh_batch_size = int(mesh.vertices.shape[0]) + out_verts, out_faces, out_uvs, out_tex, out_mr = [], [], [], [], [] + pbar = comfy.utils.ProgressBar(mesh_batch_size) + for i in range(mesh_batch_size): + sel = batch_idx == i + item_coords = voxel_xyz[sel] + item_colors = colors[sel] + v_i, f_i, _ = get_mesh_batch_item(mesh, i) + if item_coords.shape[0] == 0 or f_i.numel() == 0: + logging.warning(f"BakeTextureFromVoxel: skipping batch {i} (empty voxel/mesh)") + pbar.update(1) + continue + ev_i = mesh_uvs[i, :v_i.shape[0]] if (use_existing_uvs and mesh_uvs is not None) else None + bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( + v_i, f_i, item_coords, item_colors, + resolution=resolution, texture_size=texture_size, + inpaint_radius=inpaint_radius, fast_unwrap=fast_unwrap, + existing_uvs=ev_i, normalize_uvs=normalize_uvs, + ) + out_verts.append(bv); out_faces.append(bf); out_uvs.append(bu) + out_tex.append(bt); out_mr.append(bmr) + pbar.update(1) + if not out_verts: + return IO.NodeOutput(mesh) + # Local pack_variable_mesh_batch doesn't take uvs/texture; build the + # packed mesh ourselves so we can attach both. UVs are 1:1 with verts. + packed = pack_variable_mesh_batch(out_verts, out_faces) + max_v = packed.vertices.shape[1] + packed_uvs = out_uvs[0].new_zeros((len(out_uvs), max_v, 2)) + for i, u in enumerate(out_uvs): + packed_uvs[i, :u.shape[0]] = u + packed.uvs = packed_uvs + packed.texture = torch.stack(out_tex, dim=0) + if all(mr is not None for mr in out_mr): + packed.metallic_roughness = torch.stack(out_mr, dim=0) + return IO.NodeOutput(packed) + + # Single-item path. + v0 = mesh.vertices.squeeze(0) + f0 = mesh.faces.squeeze(0) + ev0 = mesh_uvs.squeeze(0) if (use_existing_uvs and mesh_uvs is not None) else None + bv, bf, bu, bt, bmr = bake_texture_from_voxel_fn( + v0, f0, coords, colors, + resolution=resolution, texture_size=texture_size, + inpaint_radius=inpaint_radius, fast_unwrap=fast_unwrap, + existing_uvs=ev0, normalize_uvs=normalize_uvs, + ) + out_mesh = Types.MESH( + vertices=bv.unsqueeze(0), faces=bf.unsqueeze(0), + uvs=bu.unsqueeze(0), texture=bt.unsqueeze(0), + ) + if bmr is not None: + out_mesh.metallic_roughness = bmr.unsqueeze(0) + return IO.NodeOutput(out_mesh) + + +class MeshTextureToImage(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshTextureToImage", + display_name="Mesh Texture to Image", + category="latent/3d", + description=( + "Extracts a mesh's baked textures as IMAGE outputs for preview/save. " + "base_color is the baseColor map; metallic_roughness is the packed " + "glTF MR map (R unused, G=roughness, B=metallic) — black if the mesh " + "has no PBR texture." + ), + inputs=[IO.Mesh.Input("mesh")], + outputs=[ + IO.Image.Output(display_name="base_color"), + IO.Image.Output(display_name="metallic_roughness"), + IO.Image.Output(display_name="metallic"), + IO.Image.Output(display_name="roughness"), + ], + ) + + @classmethod + def execute(cls, mesh): + def _as_image(tex): + # Mesh textures are (B, H, W, 3) float in [0, 1] — already IMAGE layout. + if tex is None: + return None + t = tex.float().clamp(0.0, 1.0).cpu() + if t.ndim == 3: + t = t.unsqueeze(0) + return t + + base = _as_image(getattr(mesh, "texture", None)) + mr = _as_image(getattr(mesh, "metallic_roughness", None)) + + if base is None: + raise ValueError( + "MeshTextureToImage: mesh has no baseColor texture. Run " + "BakeTextureFromVoxel first (PaintMesh only sets vertex colors, not a texture)." + ) + if mr is None: + mr = torch.zeros_like(base) + # Split the packed glTF MR map into single-channel grayscale previews: + # G=roughness, B=metallic. Replicated to 3 channels so they display + # as proper grayscale IMAGEs. + metallic = mr[..., 2:3].expand(-1, -1, -1, 3).contiguous() + roughness = mr[..., 1:2].expand(-1, -1, -1, 3).contiguous() + return IO.NodeOutput(base, mr, metallic, roughness) + + def paint_mesh_default_colors(mesh): out_mesh = copy.copy(mesh) vertex_count = mesh.vertices.shape[1] @@ -284,6 +836,457 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f + +def _fill_holes_v2_diagnostic(verts, faces, max_perimeter): + """Topology dump for debugging missed-hole cases. Logs edge count + distribution, cycle count, and per-cycle (size, perimeter).""" + device = verts.device + V = verts.shape[0] + F = faces.shape[0] + e_all = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + e_sorted, _ = e_all.sort(dim=1) + packed = e_sorted[:, 0].long() * V + e_sorted[:, 1].long() + unique_packed, counts = torch.unique(packed, return_counts=True) + + n_boundary = int((counts == 1).sum().item()) + n_interior = int((counts == 2).sum().item()) + n_nonmanifold = int((counts >= 3).sum().item()) + nm_max = int(counts.max().item()) if counts.numel() > 0 else 0 + nm_share_breakdown = [] + if n_nonmanifold > 0: + # Show top-5 non-manifold counts + nm_counts = counts[counts >= 3] + unique_nm, cnt_nm = torch.unique(nm_counts, return_counts=True) + for c, n in zip(unique_nm.tolist(), cnt_nm.tolist()): + nm_share_breakdown.append(f"{n} edges×{c}faces") + + logging.info(f"[FillHolesV2 diag] V={V} F={F} | " + f"boundary(cnt==1)={n_boundary} interior(cnt==2)={n_interior} " + f"non-manifold(cnt>=3)={n_nonmanifold} (max={nm_max})") + if nm_share_breakdown: + logging.info(f"[FillHolesV2 diag] non-manifold breakdown: {', '.join(nm_share_breakdown[:5])}") + + if n_boundary == 0: + logging.info("[FillHolesV2 diag] no boundary edges → no cycles to fill") + return + + # Walk components same as production path (bidir-prop, by-vertex count). + boundary_packed = unique_packed[counts == 1] + is_b = torch.isin(packed, boundary_packed) + b_directed = e_all[is_b] + src = b_directed[:, 0].long() + tgt = b_directed[:, 1].long() + + labels = torch.arange(V, dtype=torch.long, device=device) + for _ in range(64): + edge_min = torch.minimum(labels[src], labels[tgt]) + new_labels = labels.clone() + new_labels.scatter_reduce_(0, src, edge_min, reduce="amin", include_self=True) + new_labels.scatter_reduce_(0, tgt, edge_min, reduce="amin", include_self=True) + new_labels = new_labels[new_labels] + if torch.equal(new_labels, labels): + break + labels = new_labels + + edge_component = labels[src] + unique_components, component_idx = torch.unique(edge_component, return_inverse=True) + L = unique_components.shape[0] + edge_len = (verts[src] - verts[tgt]).norm(dim=-1) + perim = torch.zeros(L, dtype=verts.dtype, device=device) + perim.scatter_add_(0, component_idx, edge_len) + edge_count = torch.bincount(component_idx, minlength=L) + + pair_keys = torch.unique(torch.cat([ + component_idx.long() * V + src, + component_idx.long() * V + tgt, + ])) + pair_c = pair_keys // V + vert_count = torch.bincount(pair_c, minlength=L) + + # Open chain = vert_count == edge_count + 1; closed cycle = vert_count == edge_count. + is_chain = (vert_count == edge_count + 1) + is_cycle = (vert_count == edge_count) & (vert_count > 0) + is_other = ~(is_chain | is_cycle) + + # Match production filter (cycles only, default fill_chains=False, default max_verts=16). + MAX_VERTS_DEFAULT = 16 + CENTROID_FAN_THRESHOLD = 8 + cycle_perim_ok = is_cycle & (perim < max_perimeter) + cycle_size_ok = is_cycle & (vert_count >= 3) & (vert_count <= MAX_VERTS_DEFAULT) + actually_kept = is_cycle & (vert_count >= 3) & (vert_count <= MAX_VERTS_DEFAULT) & (perim < max_perimeter) + + # Triangulation strategy split. + vfan = actually_kept & (vert_count <= CENTROID_FAN_THRESHOLD) + cfan = actually_kept & (vert_count > CENTROID_FAN_THRESHOLD) + vfan_tris = int((vert_count[vfan] - 2).sum().item()) # N-2 tris per N-vert cycle + cfan_tris = int(vert_count[cfan].sum().item()) # N tris per N-vert cycle + cfan_new_verts = int(cfan.sum().item()) # 1 centroid per centroid-fan component + + logging.info(f"[FillHolesV2 diag] components={L} " + f"cycles={int(is_cycle.sum().item())} chains={int(is_chain.sum().item())} " + f"non-simple={int(is_other.sum().item())}") + logging.info(f"[FillHolesV2 diag] (with default filter: cycles only, verts in [3,{MAX_VERTS_DEFAULT}], perim<{max_perimeter})") + logging.info(f"[FillHolesV2 diag] actually kept={int(actually_kept.sum().item())} " + f"cycle rejected by perim={int((is_cycle & ~cycle_perim_ok).sum().item())} " + f"cycle rejected by verts={int((is_cycle & ~cycle_size_ok).sum().item())}") + logging.info(f"[FillHolesV2 diag] vertex-fan: {int(vfan.sum().item())} cycles → {vfan_tris} tris (no new verts)") + logging.info(f"[FillHolesV2 diag] centroid-fan: {int(cfan.sum().item())} cycles → {cfan_tris} tris + {cfan_new_verts} new verts") + + # Cycle vert-count distribution + if is_cycle.any(): + from collections import Counter + cycle_sizes = vert_count[is_cycle].tolist() + sc = Counter(cycle_sizes) + # show buckets: 3, 4, 5, 6, 7-10, 11-20, 21-50, 51+ + buckets = {"3":0,"4":0,"5":0,"6":0,"7-10":0,"11-20":0,"21-50":0,"51+":0} + for s, n in sc.items(): + if s == 3: buckets["3"] += n + elif s == 4: buckets["4"] += n + elif s == 5: buckets["5"] += n + elif s == 6: buckets["6"] += n + elif s <= 10: buckets["7-10"] += n + elif s <= 20: buckets["11-20"] += n + elif s <= 50: buckets["21-50"] += n + else: buckets["51+"] += n + logging.info(f"[FillHolesV2 diag] cycle vert-count buckets: {buckets}") + + if is_cycle.any(): + cycle_perims = perim[is_cycle].cpu().tolist() + head = sorted(cycle_perims, reverse=True)[:10] + logging.info(f"[FillHolesV2 diag] top-10 cycle perimeters: " + f"{['%.4f' % p for p in head]}") + + +def _fill_holes_v2_gpu(verts, faces, max_perimeter, colors=None, fill_chains=False, max_verts=16): + # Bidirectional connected-component labeling on the undirected boundary + # subgraph. Fixes the original pointer-doubling bug where chains starting at + # the lowest-id vertex never propagated their label backward, producing + # spurious size-1/2 fragments (see qem_core._propagate_face_labels for + # the same pattern applied to face adjacency). + # + # By default we only close TRUE cycles (each boundary vert has degree 2 in + # the component). Chains tend to be either real surface boundaries or + # fragments of a cycle broken by non-manifold edges — fan-filling them with + # an arbitrary centroid produces overlapping/noisy geometry. Pass + # fill_chains=True to opt in to chain closure. + device = verts.device + V = verts.shape[0] + dtype = verts.dtype + + e_all = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + e_sorted, _ = e_all.sort(dim=1) + packed = e_sorted[:, 0].long() * V + e_sorted[:, 1].long() + unique_packed, counts = torch.unique(packed, return_counts=True) + boundary_packed = unique_packed[counts == 1] + if boundary_packed.numel() == 0: + return verts, faces, colors, 0 + is_b = torch.isin(packed, boundary_packed) + + b_directed = e_all[is_b] + src = b_directed[:, 0].long() + tgt = b_directed[:, 1].long() + + # Undirected bidirectional min-prop with path compression. + labels = torch.arange(V, dtype=torch.long, device=device) + for _ in range(64): + edge_min = torch.minimum(labels[src], labels[tgt]) + new_labels = labels.clone() + new_labels.scatter_reduce_(0, src, edge_min, reduce="amin", include_self=True) + new_labels.scatter_reduce_(0, tgt, edge_min, reduce="amin", include_self=True) + new_labels = new_labels[new_labels] # path compression + if torch.equal(new_labels, labels): + break + labels = new_labels + + # Each boundary edge -> its component id. After bidir-prop, labels[src] == labels[tgt]. + edge_component = labels[src] + unique_components, component_idx = torch.unique(edge_component, return_inverse=True) + L = unique_components.shape[0] + edge_count = torch.bincount(component_idx, minlength=L) + edge_len = (verts[src] - verts[tgt]).norm(dim=-1) + perim = torch.zeros(L, dtype=dtype, device=device) + perim.scatter_add_(0, component_idx, edge_len) + + # Unique boundary-vertex set per component, to count verts and place centroids. + # Pack (component, vert) into one key; dedup via torch.unique. + pair_keys = torch.cat([ + component_idx.long() * V + src, + component_idx.long() * V + tgt, + ]) + pair_keys = torch.unique(pair_keys) + pair_v = pair_keys % V + pair_c = pair_keys // V + vert_count = torch.bincount(pair_c, minlength=L) + + centroids = torch.zeros((L, 3), dtype=dtype, device=device) + centroids.scatter_add_(0, pair_c[:, None].expand(-1, 3), verts[pair_v]) + centroids = centroids / vert_count.clamp_min(1).to(dtype).unsqueeze(-1) + + # Identify closed cycles: every boundary vert in the component has exactly + # degree 2 in the boundary subgraph. Equivalent: vert_count == edge_count. + is_cycle_component = (vert_count == edge_count) & (vert_count > 0) + + # Filter: keep cycles (always) and chains (only if fill_chains=True), under perim limit. + # Also cap vert_count: fan-from-centroid only triangulates correctly for small, + # near-planar cycles. Larger holes produce overlapping geometry because the + # centroid lands far from any surface. + size_ok = (vert_count >= 3) & (vert_count <= max_verts) & (perim < max_perimeter) + if fill_chains: + keep_component = size_ok + else: + keep_component = is_cycle_component & size_ok + if not keep_component.any(): + return verts, faces, colors, 0 + # Only centroid-fan components actually allocate a new vertex slot. + # We pre-compute their indices here so the triangulation step below has them ready. + use_centroid_per_comp_pre = keep_component & (vert_count > 8) # threshold mirrored below + centroid_long = use_centroid_per_comp_pre.long() + centroid_idx_per_comp = V + centroid_long.cumsum(0) - 1 + + # Triangulate kept components. Two strategies: + # + # Vertex-fan (small cycles): pick one boundary vert as apex, connect to all + # non-adjacent boundary edges. N verts -> N-2 triangles, no inserted vertex. + # Apex stays on the existing surface, so no off-surface centroid → no overlap. + # Right choice for 6-vert dual-grid pinches around interior verts. + # + # Centroid-fan (large cycles): insert a new vertex at the boundary centroid, + # fan from it. N triangles. Only safe if the cycle is close to planar. + # We fall back to centroid-fan above `centroid_fan_threshold` verts where + # vertex-fan would produce excessively skinny triangles. + CENTROID_FAN_THRESHOLD = 8 # tune: lower = more vertex-fan, higher = more centroid-fan + + # Edge kept mask + edge_kept = keep_component[component_idx] + edge_comp = component_idx[edge_kept] + kept_src = src[edge_kept] + kept_tgt = tgt[edge_kept] + + # Per-edge tag: which strategy does its component use? + use_centroid_per_comp = keep_component & (vert_count > CENTROID_FAN_THRESHOLD) + use_centroid_per_edge = use_centroid_per_comp[edge_comp] + + fan_pieces = [] + # ---- Centroid-fan branch (only for components > threshold) ---- + if use_centroid_per_edge.any(): + kept_centroid = centroid_idx_per_comp[edge_comp[use_centroid_per_edge]] + fan_pieces.append(torch.stack([ + kept_tgt[use_centroid_per_edge], + kept_src[use_centroid_per_edge], + kept_centroid, + ], dim=1).to(faces.dtype)) + + # ---- Vertex-fan branch (small cycles, no centroid inserted) ---- + use_vertex_fan_per_comp = keep_component & (vert_count <= CENTROID_FAN_THRESHOLD) + if use_vertex_fan_per_comp.any(): + # For each vertex-fan component, pick the smallest-id boundary vert as apex + # (deterministic & matches labels[*] = smallest after bidir-prop). + # Then emit edges in component as fan tris (apex, src, tgt) EXCEPT for + # the two edges incident to the apex (those would be degenerate). + apex_per_comp = labels[unique_components] # labels[u]==u after convergence + # Edges that DON'T touch their component's apex + vf_mask = use_vertex_fan_per_comp[edge_comp] + if vf_mask.any(): + vf_src = kept_src[vf_mask] + vf_tgt = kept_tgt[vf_mask] + vf_comp = edge_comp[vf_mask] + vf_apex = apex_per_comp[vf_comp] + # Skip edges that include the apex (apex==src or apex==tgt → degenerate tri). + non_apex = (vf_src != vf_apex) & (vf_tgt != vf_apex) + fan_pieces.append(torch.stack([ + vf_tgt[non_apex], vf_src[non_apex], vf_apex[non_apex], + ], dim=1).to(faces.dtype)) + + fan_faces = torch.cat(fan_pieces, dim=0) if fan_pieces else torch.empty((0, 3), dtype=faces.dtype, device=device) + + # Open chains: close them with a closing triangle ONLY for centroid-fan + # components (vertex-fan chains would need a different closing strategy). + # In practice fill_chains=False makes this a no-op since chains aren't kept. + if fill_chains: + vert_degree = torch.zeros(V, dtype=torch.long, device=device) + vert_degree.scatter_add_(0, src, torch.ones_like(src)) + vert_degree.scatter_add_(0, tgt, torch.ones_like(tgt)) + is_endpoint = (vert_degree[pair_v] == 1) & use_centroid_per_comp_pre[pair_c] + if is_endpoint.any(): + ep_v = pair_v[is_endpoint] + ep_c = pair_c[is_endpoint] + order = torch.argsort(ep_c, stable=True) + ep_v_sorted = ep_v[order] + ep_c_sorted = ep_c[order] + ep_count_per_c = torch.bincount(ep_c_sorted, minlength=L) + is_chain_comp = ep_count_per_c == 2 + ep_is_chain = is_chain_comp[ep_c_sorted] + if ep_is_chain.any(): + chain_ep_v = ep_v_sorted[ep_is_chain] + chain_ep_c = ep_c_sorted[ep_is_chain] + assert chain_ep_v.numel() % 2 == 0 + chain_ep_v = chain_ep_v.view(-1, 2) + chain_ep_c = chain_ep_c.view(-1, 2)[:, 0] + close_centroid = centroid_idx_per_comp[chain_ep_c] + close_faces = torch.stack( + [chain_ep_v[:, 0], chain_ep_v[:, 1], close_centroid], dim=1 + ).to(faces.dtype) + fan_faces = torch.cat([fan_faces, close_faces], dim=0) + + # Only centroid-fan components contribute a new vertex; vertex-fan reuses existing. + new_centroids_v = centroids[use_centroid_per_comp_pre] + out_v = torch.cat([verts, new_centroids_v], dim=0) + out_f = torch.cat([faces, fan_faces], dim=0) + + out_c = colors + if colors is not None: + c_sum = torch.zeros((L, colors.shape[1]), dtype=colors.dtype, device=device) + c_sum.scatter_add_( + 0, pair_c[:, None].expand(-1, colors.shape[1]), colors[pair_v]) + c_avg = c_sum / vert_count.clamp_min(1).to(colors.dtype).unsqueeze(-1) + out_c = torch.cat([colors, c_avg[use_centroid_per_comp_pre]], dim=0) + + return out_v, out_f, out_c, int(keep_component.sum().item()) + + +def weld_vertices_fn(vertices, faces, epsilon=None, epsilon_rel=1e-5, colors=None): + """Merge coincident vertices via L_inf grid quantization. + Ported from custom_nodes/qem_simplify/qem_core.py:_weld_vertices. + + `epsilon`: absolute L_inf distance; verts within this collapse together. + If None, `epsilon_rel * bbox_diag` is used. + Attributes (colors) are averaged across each cluster. + + Returns (new_verts, new_faces, new_colors, n_welded).""" + if vertices.ndim == 3: + v_out, f_out, c_out = [], [], [] if colors is not None else None + total = 0 + for i in range(vertices.shape[0]): + ci = colors[i] if colors is not None else None + v_i, f_i, c_i, n = weld_vertices_fn(vertices[i], faces[i], epsilon, epsilon_rel, ci) + v_out.append(v_i); f_out.append(f_i); total += n + if c_out is not None: + c_out.append(c_i) + max_v = max(v.shape[0] for v in v_out) + for i in range(len(v_out)): + pad_n = max_v - v_out[i].shape[0] + if pad_n > 0: + v_out[i] = torch.cat([v_out[i], + torch.zeros(pad_n, 3, device=v_out[i].device, dtype=v_out[i].dtype)], dim=0) + if c_out is not None: + c_out[i] = torch.cat([c_out[i], + torch.zeros(pad_n, c_out[i].shape[1], device=c_out[i].device, dtype=c_out[i].dtype)], dim=0) + c_stack = torch.stack(c_out) if c_out is not None else None + return torch.stack(v_out), torch.stack(f_out), c_stack, total + + if vertices.shape[0] == 0: + return vertices, faces, colors, 0 + device = vertices.device + if epsilon is None: + bbox = vertices.max(dim=0)[0] - vertices.min(dim=0)[0] + eps = torch.norm(bbox) * float(epsilon_rel) + eps = max(float(eps.item()), 1e-12) + else: + eps = float(epsilon) + if eps <= 0: + return vertices, faces, colors, 0 + + scale = 1.0 / eps + bbox_min = vertices.min(dim=0)[0] + q = ((vertices - bbox_min) * scale).round().to(torch.int64) + extent = ((vertices.max(dim=0)[0] - bbox_min) * scale).round().to(torch.int64) + 2 + key = (q[:, 0] * extent[1] + q[:, 1]) * extent[2] + q[:, 2] + unique_key, inv = torch.unique(key, return_inverse=True) + n_unique = unique_key.shape[0] + if n_unique == vertices.shape[0]: + return vertices, faces, colors, 0 + + counts = torch.zeros(n_unique, dtype=vertices.dtype, device=device) + counts.scatter_add_(0, inv, torch.ones(vertices.shape[0], dtype=vertices.dtype, device=device)) + counts_div = counts.unsqueeze(-1).clamp_min(1.0) + + new_verts = torch.zeros((n_unique, 3), dtype=vertices.dtype, device=device) + new_verts.scatter_add_(0, inv.unsqueeze(-1).expand_as(vertices), vertices) + new_verts = new_verts / counts_div + + new_colors = None + if colors is not None: + new_colors = torch.zeros((n_unique, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inv.unsqueeze(-1).expand_as(colors), colors) + new_colors = new_colors / counts_div.to(colors.dtype) + + new_faces = inv[faces.long()].to(faces.dtype) if faces.numel() > 0 else faces + return new_verts, new_faces, new_colors, int(vertices.shape[0] - n_unique) + + +def fill_holes_v2_fn(vertices, faces, max_perimeter=0.03, colors=None, weld_epsilon_rel=1e-5, fill_chains=False, max_verts=16, diagnostic=False): + """Batched wrapper for the v2 GPU hole-filler. CPU tensors get pulled + onto CUDA when available; otherwise fall back to the v1 (CPU walker) fn. + + Pre-welds vertices via `weld_vertices_fn(epsilon_rel=weld_epsilon_rel)` — + boundary detection requires shared edges, which requires welded verts. + Already-welded meshes early-exit cheaply. Set `weld_epsilon_rel=0` to skip.""" + if vertices.ndim == 3: + v_list, f_list, c_list = [], [], [] if colors is not None else None + pbar = comfy.utils.ProgressBar(vertices.shape[0]) + for i in range(vertices.shape[0]): + ci = colors[i] if colors is not None else None + v_i, f_i, c_i = fill_holes_v2_fn(vertices[i], faces[i], max_perimeter, ci, weld_epsilon_rel, fill_chains, max_verts, diagnostic) + v_list.append(v_i); f_list.append(f_i) + if c_list is not None: + c_list.append(c_i) + pbar.update(1) + max_v = max(v.shape[0] for v in v_list) + for i in range(len(v_list)): + pad_n = max_v - v_list[i].shape[0] + if pad_n > 0: + v_list[i] = torch.cat([v_list[i], + torch.zeros(pad_n, 3, device=v_list[i].device, dtype=v_list[i].dtype)], dim=0) + if c_list is not None: + c_list[i] = torch.cat([c_list[i], + torch.zeros(pad_n, c_list[i].shape[1], device=c_list[i].device, dtype=c_list[i].dtype)], dim=0) + c_out = torch.stack(c_list) if c_list is not None else None + return torch.stack(v_list), torch.stack(f_list), c_out + + if faces.numel() == 0: + return vertices, faces, colors + # Adaptive weld: a properly welded triangle surface has V/F ≈ 0.5 (closed) + # to ~1.0 (with boundaries). V/F > 1 means most faces still share no verts + # and hole-fill would emit one bogus closing tri per face. We double the + # weld epsilon until V/F < WELDED_THRESHOLD or we hit WELD_CAP. + if weld_epsilon_rel > 0: + eps = float(weld_epsilon_rel) + WELD_CAP = 1e-2 # ≈ 10 voxels at 1024-res — aggressive but bounded + WELDED_THRESHOLD = 1.0 # V/F below this is "welded enough" for hole-fill + total_welded = 0 + n_escalations = 0 + while True: + vertices, faces, colors, n = weld_vertices_fn( + vertices, faces, epsilon=None, epsilon_rel=eps, colors=colors, + ) + total_welded += n + ratio = vertices.shape[0] / max(faces.shape[0], 1) + if ratio < WELDED_THRESHOLD or eps >= WELD_CAP: + break + eps = min(eps * 2.0, WELD_CAP) + n_escalations += 1 + if total_welded > 0 or n_escalations > 0: + tag = f" (escalated weld epsilon_rel→{eps:.1e} after {n_escalations} step{'s' if n_escalations != 1 else ''})" if n_escalations > 0 else "" + logging.info(f"[FillHolesV2] pre-welded {total_welded} verts, V/F={ratio:.2f}{tag}") + if ratio >= WELDED_THRESHOLD: + logging.warning( + f"[FillHolesV2] even at weld epsilon_rel={WELD_CAP} the mesh stays " + f"unwelded (V/F={ratio:.2f}, want < {WELDED_THRESHOLD}). Source mesh has " + f"duplicate verts at distances >{WELD_CAP}× bbox; fix upstream " + f"(decimate node settings) or run WeldVertices manually with a larger epsilon." + ) + # Diag runs AFTER welding so its topology numbers match what the filler sees. + if diagnostic and vertices.device.type == "cuda" and faces.numel() > 0: + _fill_holes_v2_diagnostic(vertices, faces, max_perimeter) + if vertices.device.type == "cuda": + out_v, out_f, out_c, _ = _fill_holes_v2_gpu(vertices, faces, max_perimeter, colors, fill_chains, max_verts) + return out_v, out_f, out_c + # CPU fallback: re-use the v1 walker (no attribute prop, but topologically equivalent + # for manifold boundary; v2 GPU is the path that actually matters for pixal3d output). + out_v, out_f = fill_holes_fn(vertices, faces, max_perimeter=max_perimeter) + return out_v, out_f, colors + + def _cleanup_mesh(verts, faces, min_angle_deg=0.5, max_aspect=100.0): if faces.numel() == 0: return verts, faces @@ -1130,13 +2133,203 @@ class FillHoles(IO.ComfyNode): return v, f, c return _process_mesh_batch(mesh, _fn) +class FillHolesV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="FillHolesV2", + display_name="Fill Holes (v2)", + category="latent/3d", + description=( + "GPU-vectorised hole-filling via directed-half-edge pointer-doubling. " + "Drop-in alternative to FillHoles for comparison: same max_perimeter " + "cutoff and fan-from-centroid triangulation, but no Python loop, " + "auto-correct winding from face direction, and centroid colors are " + "averaged from the loop instead of left zero." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001, + tooltip="Maximum hole perimeter to fill. Set to 0 to disable."), + IO.Float.Input("weld_epsilon_rel", default=1e-5, min=0.0, step=1e-6, + tooltip=( + "Pre-weld tolerance as a fraction of the bbox diagonal. " + "Boundary detection needs welded verts; already-welded meshes " + "early-exit free. Set to 0 to skip pre-weld." + )), + IO.Int.Input("max_verts", default=16, min=3, max=1024, + tooltip=( + "Cap the boundary-vertex count per cycle. Fan-from-centroid " + "only triangulates correctly for small, near-planar holes — " + "larger cycles produce overlapping geometry because the centroid " + "lands far from any surface. Keep low (≤16) for clean fills." + )), + IO.Boolean.Input("fill_chains", default=False, + tooltip=( + "Also fill open boundary chains (not just closed cycles) " + "by closing them with a fan + closing triangle. " + "Often produces noisy/overlapping geometry on real meshes " + "because chains are usually genuine surface boundaries or " + "fragments of cycles broken by non-manifold edges. Leave OFF " + "to match cumesh/upstream behavior." + )), + IO.Boolean.Input("verbose", default=False, + tooltip="Log topology diagnostics (edge counts, cycles found, reject reasons) for debugging."), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh, max_perimeter, weld_epsilon_rel, max_verts, fill_chains, verbose): + def _fn(v, f, c): + if max_perimeter > 0: + v, f, c = fill_holes_v2_fn( + v, f, max_perimeter=max_perimeter, colors=c, + weld_epsilon_rel=weld_epsilon_rel, + fill_chains=fill_chains, + max_verts=max_verts, + diagnostic=verbose, + ) + return v, f, c + return _process_mesh_batch(mesh, _fn) + + +class WeldVertices(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WeldVertices", + display_name="Weld Vertices", + category="latent/3d", + description=( + "Merge coincident vertices via L_inf grid quantization. Use when a " + "mesh comes in unwelded (every face has its own 3 verts, no shared edges) " + "— pre-pass before FillHoles, DecimateMesh, or any topology-aware op. " + "Per-vertex colors are averaged across each merged cluster." + ), + inputs=[ + IO.Mesh.Input("mesh"), + IO.Float.Input("epsilon_rel", default=1e-5, min=0.0, step=1e-6, + tooltip="Weld tolerance as a fraction of the bbox diagonal. " + "1e-5 is enough for floating-point dedup; raise to " + "1e-3 for visibly-close-but-distinct verts."), + IO.Float.Input("epsilon_abs", default=0.0, min=0.0, step=1e-6, + tooltip="Absolute weld tolerance (overrides epsilon_rel when > 0)."), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, mesh, epsilon_rel, epsilon_abs): + eps = epsilon_abs if epsilon_abs > 0 else None + def _fn(v, f, c): + v, f, c, n = weld_vertices_fn(v, f, epsilon=eps, epsilon_rel=epsilon_rel, colors=c) + if n > 0: + logging.info(f"[WeldVertices] merged {n} verts ({v.shape[0]} remain)") + return v, f, c + return _process_mesh_batch(mesh, _fn) + + +def merge_meshes(meshes): + """Concatenate a list of Types.MESH into a single (B=1) mesh. + + Vertices, faces (with cumulative index offset), uvs, and vertex_colors are + concatenated. If only some inputs carry uvs/vertex_colors, the missing sides + are padded — zeros for uvs, white (1.0) for vertex_colors — so the merged + primitive has a uniform attribute set. Texture is taken from the first input + that has one; later textures are dropped with a warning (single-primitive glb + can't carry multiple texture atlases without baking). + """ + if not meshes: + raise ValueError("merge_meshes: need at least one mesh") + + def _b0(t): + return t[0] if t.ndim == 3 else t + + any_uvs = any(getattr(m, "uvs", None) is not None for m in meshes) + any_colors = any(getattr(m, "vertex_colors", None) is not None for m in meshes) + + verts_list, faces_list, uvs_list, colors_list = [], [], [], [] + texture = None + offset = 0 + for m in meshes: + # Mesh tensors are normalized to CPU by our producer nodes; coerce defensively + # so MoGe-side meshes (which may land on CUDA) merge cleanly with our outputs. + v = _b0(m.vertices).cpu() + f = _b0(m.faces).cpu() + verts_list.append(v) + faces_list.append(f + offset) + offset += v.shape[0] + if any_uvs: + mu = getattr(m, "uvs", None) + uvs_list.append(_b0(mu).cpu() if mu is not None else v.new_zeros((v.shape[0], 2))) + if any_colors: + mc = getattr(m, "vertex_colors", None) + if mc is not None: + c = _b0(mc).cpu() + else: + c = v.new_ones((v.shape[0], 3)) + colors_list.append(c) + mt = getattr(m, "texture", None) + if mt is not None: + if texture is None: + texture = mt.cpu() + else: + logging.warning("MergeMeshes: dropping extra texture from input; only one texture is kept.") + + merged_verts = torch.cat(verts_list, dim=0).unsqueeze(0) + merged_faces = torch.cat(faces_list, dim=0).unsqueeze(0) + merged_uvs = torch.cat(uvs_list, dim=0).unsqueeze(0) if any_uvs else None + merged_colors = torch.cat(colors_list, dim=0).unsqueeze(0) if any_colors else None + + return Types.MESH( + vertices=merged_verts, + faces=merged_faces, + uvs=merged_uvs, + vertex_colors=merged_colors, + texture=texture, + ) + + +class MergeMeshes(IO.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = IO.Autogrow.TemplatePrefix( + IO.Mesh.Input("mesh"), prefix="mesh", min=2, max=50, + ) + return IO.Schema( + node_id="MergeMeshes", + display_name="Merge Meshes", + category="latent/3d", + description=( + "Concatenate N meshes into a single mesh by offsetting face indices " + "and stacking vertices, faces, uvs, and vertex colors. Useful for combining a " + "Pixal3D-reconstructed object (via Pixal3DAlignObject) with a MoGe scene " + "background (via MoGePointMapToMesh) into one GLB." + ), + inputs=[ + IO.Autogrow.Input("meshes", template=autogrow_template), + ], + outputs=[IO.Mesh.Output("mesh")], + ) + + @classmethod + def execute(cls, meshes: IO.Autogrow.Type) -> IO.NodeOutput: + return IO.NodeOutput(merge_meshes(list(meshes.values()))) + + class PostProcessMeshExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FillHoles, + FillHolesV2, + WeldVertices, DecimateMesh, - PaintMesh + PaintMesh, + BakeTextureFromVoxel, + MeshTextureToImage, + MergeMeshes, ] diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index bc04ba928..8b5e1b8e0 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -80,7 +80,8 @@ def get_mesh_batch_item(mesh, index): def save_glb(vertices, faces, filepath, metadata=None, - uvs=None, vertex_colors=None, texture_image=None): + uvs=None, vertex_colors=None, texture_image=None, + metallic_roughness_image=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -92,6 +93,8 @@ def save_glb(vertices, faces, filepath, metadata=None, uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1] texture_image: PIL.Image - Optional baseColor texture, embedded as PNG + metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture + (R unused, G=roughness, B=metallic), embedded as PNG """ # Convert tensors to numpy arrays @@ -126,12 +129,18 @@ def save_glb(vertices, faces, filepath, metadata=None, buf = BytesIO() texture_image.save(buf, format="PNG") texture_png_bytes = buf.getvalue() + mr_png_bytes = None + if metallic_roughness_image is not None: + buf = BytesIO() + metallic_roughness_image.save(buf, format="PNG") + mr_png_bytes = buf.getvalue() vertices_buffer = vertices_np.tobytes() indices_buffer = faces_np.tobytes() uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" colors_buffer = colors_np.tobytes() if colors_np is not None else b"" texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" + mr_buffer = mr_png_bytes if mr_png_bytes is not None else b"" def pad_to_4_bytes(buffer): padding_length = (4 - (len(buffer) % 4)) % 4 @@ -142,6 +151,7 @@ def save_glb(vertices, faces, filepath, metadata=None, uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) colors_buffer_padded = pad_to_4_bytes(colors_buffer) texture_buffer_padded = pad_to_4_bytes(texture_buffer) + mr_buffer_padded = pad_to_4_bytes(mr_buffer) buffer_data = b"".join([ vertices_buffer_padded, @@ -149,6 +159,7 @@ def save_glb(vertices, faces, filepath, metadata=None, uvs_buffer_padded, colors_buffer_padded, texture_buffer_padded, + mr_buffer_padded, ]) vertices_byte_length = len(vertices_buffer) @@ -158,6 +169,7 @@ def save_glb(vertices, faces, filepath, metadata=None, uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) texture_byte_offset = colors_byte_offset + len(colors_buffer_padded) + mr_byte_offset = texture_byte_offset + len(texture_buffer_padded) buffer_views = [ { @@ -251,8 +263,24 @@ def save_glb(vertices, faces, filepath, metadata=None, }) images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) - textures.append({"source": 0, "sampler": 0}) - pbr["baseColorTexture"] = {"index": 0, "texCoord": 0} + textures.append({"source": len(images) - 1, "sampler": 0}) + pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0} + + if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: + buffer_views.append({ + "buffer": 0, + "byteOffset": mr_byte_offset, + "byteLength": len(mr_buffer), + }) + images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) + if not samplers: + samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) + textures.append({"source": len(images) - 1, "sampler": 0}) + pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0} + # When a metallicRoughness texture is present, the factors scale it; use 1.0 + # so the texture values pass through unchanged (glTF convention). + pbr["metallicFactor"] = 1.0 + pbr["roughnessFactor"] = 1.0 materials.append({ "pbrMetallicRoughness": pbr, @@ -373,12 +401,20 @@ class SaveGLB(IO.ComfyNode): assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, ( f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}" ) + mr_b = getattr(mesh, "metallic_roughness", None) + mr_np = None + if mr_b is not None: + mr_np = (mr_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8) + assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, ( + f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}" + ) for i in range(mesh.vertices.shape[0]): vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i) if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0: logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}") continue tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None + mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None f = f"{filename}_{counter:05}_.glb" save_glb( vertices_i, faces_i, @@ -387,6 +423,7 @@ class SaveGLB(IO.ComfyNode): uvs=uvs_i, vertex_colors=v_colors, texture_image=tex_img, + metallic_roughness_image=mr_img, ) results.append({ "filename": f, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f57b1b018..04bc19f25 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -9,8 +9,10 @@ from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch import comfy.model_management import comfy.utils import folder_paths +from comfy.ldm.trellis2 import sampling_preview from PIL import Image import logging +import os import numpy as np import math import torch @@ -19,6 +21,89 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") NAFModel = io.Custom("NAF_MODEL") +# Texture latent -> base-color calibration for the per-step preview +def _tex_rgb_factors_path(): + return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt") + + +def _pool_albedo_to_input(in_coords, out_coords, out_colors): + in_sp = in_coords[:, 1:4].long() + out_sp = out_coords[:, 1:4].long() + in_b = in_coords[:, 0].long() + out_b = out_coords[:, 0].long() + in_res = int(in_sp.max().item()) + 1 + out_res = int(out_sp.max().item()) + 1 + parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1) + R = in_res + in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2] + par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2] + order = torch.argsort(in_flat) + in_sorted = in_flat[order] + pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1) + matched = in_sorted[pos] == par_flat + in_idx = order[pos][matched] + cols = out_colors[matched].float() + N = in_coords.shape[0] + csum = cols.new_zeros((N, 3)) + ccount = cols.new_zeros((N, 1)) + csum.index_add_(0, in_idx, cols) + ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype)) + valid = ccount[:, 0] > 0 + albedo = torch.zeros_like(csum) + albedo[valid] = csum[valid] / ccount[valid] + return albedo, valid + + +def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords): + """Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish.""" + try: + dev = out_colors.device + in_latent = in_latent.to(dev) + in_coords = in_coords.to(dev) + out_coords = out_coords.to(dev) + albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors) + X = in_latent[valid].float().cpu() + Y = albedo[valid].float().cpu() + if X.shape[0] < 64: + return + Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1] + A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1] + B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3] + + path = _tex_rgb_factors_path() + if os.path.exists(path): + try: + prev = torch.load(path, map_location="cpu") + A_run = A_run + prev["A"] + B_run = B_run + prev["B"] + except Exception: + pass + os.makedirs(os.path.dirname(path), exist_ok=True) + torch.save({"A": A_run, "B": B_run}, path) + + eye = torch.eye(A_run.shape[0]) + WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3] + W, b = WB[:-1].contiguous(), WB[-1].contiguous() + sampling_preview.set_tex_rgb(W, b) + except Exception as e: + logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}") + + +def _load_tex_rgb_factors(): + try: + path = _tex_rgb_factors_path() + if os.path.exists(path): + d = torch.load(path, map_location="cpu") + eye = torch.eye(d["A"].shape[0]) + WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"]) + sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous()) + except Exception as e: + logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}") + + +_load_tex_rgb_factors() + + def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) if len(sample_shape) == 5: @@ -174,6 +259,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): resolution = int(coord_resolution) * 16 else: resolution = int(vae.first_stage_model.resolution.item()) + model_frame = samples.get("model_frame", "y_up") sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() coords = samples["coords"] @@ -205,8 +291,13 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) - vert_list = [v.float() for v, f in mesh] - face_list = [f.int() for v, f in mesh] + # Rotate Z-up (Trellis2 training frame) vertices to glTF Y-up. Pixal3D outputs are already Y-up. + if model_frame == "z_up": + vert_list = [torch.stack([v[..., 0], v[..., 2], -v[..., 1]], dim=-1).float().cpu() + for v, _ in mesh] + else: + vert_list = [v.float().cpu() for v, _ in mesh] + face_list = [f.int().cpu() for _, f in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) else: @@ -241,19 +332,32 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, sample_tensor.shape) trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") + model_frame = samples.get("model_frame", "y_up") samples = samples["samples"] samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) samples = samples.to(device) + cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration + cal_in_coords = coords std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = samples * std + mean voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides) - color_feats = voxel.feats[:, :3] + # Keep all decoded channels. The texture VAE emits 6: base_color (0:3), + # metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color + # consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full + # PBR set. Older 3-channel checkpoints pass through unchanged. + color_feats = voxel.feats voxel_coords = voxel.coords + # Calibrate the latent->base_color map for the per-step texture preview. + # Done here while input coords and voxel_coords share the model frame + # (before the z_up remap below) and on the real decoded albedo. + if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3: + _calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords) + if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords max_idx = int(spatial.max().item()) + 1 @@ -261,6 +365,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): else: tex_resolution = 1024 + # Remap Z-up voxel coords to Y-up: (x, y, z) -> (x, z, R-1-y), matching the + # R_x(-90°) applied to mesh vertices in VaeDecodeShapeTrellis. Keeps PaintMesh's + # NN lookup correctly aligned without it needing to know the source frame. + if model_frame == "z_up" and voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: + R = tex_resolution + if voxel_coords.shape[-1] == 4: + batch_col = voxel_coords[:, :1] + spatial = voxel_coords[:, 1:] + spatial_yup = torch.stack( + [spatial[:, 0], spatial[:, 2], (R - 1) - spatial[:, 1]], dim=-1 + ) + voxel_coords = torch.cat([batch_col, spatial_yup], dim=-1) + else: + voxel_coords = torch.stack( + [voxel_coords[:, 0], voxel_coords[:, 2], (R - 1) - voxel_coords[:, 1]], + dim=-1, + ) + voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution) return IO.NodeOutput(voxel) @@ -425,7 +547,9 @@ class Trellis2UpsampleStage(IO.ComfyNode): positive_out = _conditioning_set_extras(positive, extras) negative_out = _conditioning_set_extras(negative, extras) out_latent = {"samples": latent, "coords": coords, "coord_counts": counts, - "coord_resolution": coord_resolution, "type": "trellis2"} + "coord_resolution": coord_resolution, "type": "trellis2", + "model_frame": shape_latent.get("model_frame", + "y_up" if proj_pack is not None else "z_up")} return IO.NodeOutput(positive_out, negative_out, out_latent) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) @@ -694,7 +818,8 @@ class Trellis2ShapeStage(IO.ComfyNode): positive_out = _conditioning_set_extras(positive, extras) negative_out = _conditioning_set_extras(negative, extras) out_latent = {"samples": latent, "coords": coords, "coord_counts": counts, - "coord_resolution": coord_resolution, "type": "trellis2"} + "coord_resolution": coord_resolution, "type": "trellis2", + "model_frame": "y_up" if proj_pack is not None else "z_up"} return IO.NodeOutput(positive_out, negative_out, out_latent) class Trellis2TextureStage(IO.ComfyNode): @@ -747,7 +872,9 @@ class Trellis2TextureStage(IO.ComfyNode): positive_out = _conditioning_set_extras(positive, extras) negative_out = _conditioning_set_extras(negative, extras) - out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts} + out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts, + "model_frame": shape_latent.get("model_frame", + "y_up" if proj_pack is not None else "z_up")} if coord_resolution is not None: out_latent["coord_resolution"] = coord_resolution return IO.NodeOutput(positive_out, negative_out, out_latent) @@ -1018,6 +1145,8 @@ def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle points = vertices_world.unsqueeze(0).float() T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float() cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x + T = T.to(points.device) + cam = cam.to(points.device) uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution) uv = uv_pix.squeeze(0) / image_resolution return uv, depth.squeeze(0), valid.squeeze(0) @@ -1108,13 +1237,25 @@ class Pixal3DAlignObject(IO.ComfyNode): scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H)) in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) & (scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H)) + # MoGe geometry and object_mask can land on CPU after passing between nodes; + # match the indexed tensor's device for sy/sx so the gather works on either. + moge_points = moge_points.to(scene_pixels.device) + moge_mask = moge_mask.to(scene_pixels.device) sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1) sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1) moge_per_vertex = moge_points[batch_index, sy, sx] + # MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF + # Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same + # frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the + # `verts * [1, -1, -1]` step in MoGePointMapToMesh. + moge_per_vertex = moge_per_vertex * torch.tensor( + [1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device + ) moge_mask_per_vertex = moge_mask[batch_index, sy, sx] keep = valid & in_scene & moge_mask_per_vertex if object_mask is not None: om = object_mask if object_mask.ndim == 2 else object_mask[batch_index] + om = om.to(sy.device) keep = keep & (om[sy, sx] > 0.5) finite = torch.isfinite(moge_per_vertex).all(dim=-1) @@ -1131,25 +1272,23 @@ class Pixal3DAlignObject(IO.ComfyNode): q_mean = Q.mean(dim=0, keepdim=True) P_c = P - p_mean Q_c = Q - q_mean - num = (P_c * Q_c).sum() - den = (P_c * P_c).sum().clamp(min=1e-8) - scale = float((num / den).item()) - if not (scale > 0): - # Negative scale would mirror the mesh; treat as a camera-convention mismatch. - logging.warning( - f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; " - "refusing to apply mirroring. Check camera convention alignment.") - scale = 1.0 - aligned = vertices_one - else: - t = q_mean - scale * p_mean - aligned = scale * vertices_one + t + # Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is + # noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper + # acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c) + # gets multiplied by cos(yaw) and shrinks the object. Using + # sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of + # rotation; translation still positions the mesh at MoGe's centroid. + p_var = (P_c * P_c).sum().clamp(min=1e-8) + q_var = (Q_c * Q_c).sum() + scale = float(torch.sqrt(q_var / p_var).item()) + t = q_mean - scale * p_mean + aligned = scale * vertices_one + t if vertices.ndim == 3: aligned = aligned.unsqueeze(0) - out_mesh = Types.MESH(vertices=aligned, faces=faces) + out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu()) else: - out_mesh = Types.MESH(vertices=aligned, faces=faces_one) + out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu()) return IO.NodeOutput(out_mesh, float(scale))