From ab58d1b79f4e2b84627c120e5e99f4d0bfa60777 Mon Sep 17 00:00:00 2001 From: kijai Date: Sun, 28 Jun 2026 02:07:49 +0300 Subject: [PATCH] Bake without opengl --- .../mesh3d/postprocess/qem_decimate.py | 4 +- comfy_extras/nodes_mesh_postprocess.py | 210 +++++------------- 2 files changed, 62 insertions(+), 152 deletions(-) diff --git a/comfy_extras/mesh3d/postprocess/qem_decimate.py b/comfy_extras/mesh3d/postprocess/qem_decimate.py index 202f9dd3d..d382a0661 100644 --- a/comfy_extras/mesh3d/postprocess/qem_decimate.py +++ b/comfy_extras/mesh3d/postprocess/qem_decimate.py @@ -1612,7 +1612,7 @@ def qem_simplify( ) -def simplify( +def qem_decimate_simplify( vertices: torch.Tensor, faces: torch.Tensor, target: int, @@ -1640,7 +1640,7 @@ def simplify( return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config) -def cluster_decimate( +def qem_cluster_decimate( vertices: torch.Tensor, faces: torch.Tensor, target_verts: int = 1_000_000, colors: Optional[torch.Tensor] = None, diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 2c7f63736..090e55399 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -7,9 +7,7 @@ import copy import comfy.utils import comfy.model_management from server import PromptServer -from comfy_extras.mesh3d.postprocess.qem_decimate import ( - simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate, -) +from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg @@ -168,161 +166,73 @@ class PaintMesh(IO.ComfyNode): return IO.NodeOutput(out_mesh) -# Texture baking from sparse voxel volume: existing UVs → OpenGL UV-space -# rasterize → per-texel voxel sample → JFA seam fill → attach to mesh for SaveGLB. -# Never unwraps (done upstream). GL context via nodes_glsl.GLContext. - -_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache" - - -def _gl_compile_program(gl, vert_src: str, frag_src: str): - """Compile+link a vert+frag GL program (caller glDeleteProgram).""" - def _check_shader(s, kind): - if not gl.glGetShaderiv(s, gl.GL_COMPILE_STATUS): - log = gl.glGetShaderInfoLog(s).decode(errors="replace") - gl.glDeleteShader(s) - raise RuntimeError(f"GL {kind} shader compile failed: {log}") - - vs = gl.glCreateShader(gl.GL_VERTEX_SHADER) - gl.glShaderSource(vs, vert_src) - gl.glCompileShader(vs) - _check_shader(vs, "vertex") - fs = gl.glCreateShader(gl.GL_FRAGMENT_SHADER) - gl.glShaderSource(fs, frag_src) - gl.glCompileShader(fs) - _check_shader(fs, "fragment") - prog = gl.glCreateProgram() - gl.glAttachShader(prog, vs) - gl.glAttachShader(prog, fs) - gl.glLinkProgram(prog) - gl.glDeleteShader(vs) - gl.glDeleteShader(fs) - if not gl.glGetProgramiv(prog, gl.GL_LINK_STATUS): - log = gl.glGetProgramInfoLog(prog).decode(errors="replace") - gl.glDeleteProgram(prog) - raise RuntimeError(f"GL program link failed: {log}") - return prog - - -# Position-passthrough: vertex maps UV → clip space; fragment outputs interpolated -# world-space position (alpha=1 marks valid texels). -_BAKE_VERT_SRC = """ -#version 330 core -layout (location = 0) in vec3 a_pos; -layout (location = 1) in vec2 a_uv; -out vec3 v_pos; -void main() { - v_pos = a_pos; - gl_Position = vec4(a_uv * 2.0 - 1.0, 0.0, 1.0); -} -""" - -_BAKE_FRAG_SRC = """ -#version 330 core -in vec3 v_pos; -out vec4 frag_color; -void main() { - frag_color = vec4(v_pos, 1.0); -} -""" - - def _bake_position_map(verts_np, faces_np, uvs_np, texture_size): - """Rasterize unwrapped mesh in UV space. Returns (position_map [H,W,3] float32, - mask [H,W] bool covered).""" - from comfy_extras.nodes_glsl import GLContext, _import_opengl - - GLContext() # ensure backend is initialized + current - gl = _import_opengl() - - # PyOpenGL high-level wrappers store array refs in OpenGL.contextdata, which on - # EGL contexts errors ("no valid context"); use raw C entry points instead. - import ctypes as _ctypes - from OpenGL.raw.GL.VERSION.GL_1_1 import ( - glReadPixels as _raw_glReadPixels, - glTexImage2D as _raw_glTexImage2D, - glDrawElements as _raw_glDrawElements, - ) - from OpenGL.raw.GL.VERSION.GL_1_5 import glBufferData as _raw_glBufferData - from OpenGL.raw.GL.VERSION.GL_2_0 import glVertexAttribPointer as _raw_glVertexAttribPointer - + """Rasterize the mesh in UV space and barycentric-interpolate the per-vertex vec3 + (world position, or any vec3 attr e.g. normals) at each covered texel. Pure torch, + tiled point-in-triangle — no GL/EGL, runs anywhere torch does. Returns (attr_map + [H,W,3] float32, mask [H,W] bool). """ + dev = comfy.model_management.get_torch_device() H = W = int(texture_size) - fbo = color_tex = vbo = ibo = vao = prog = None - try: - # Interleaved [pos.xyz, uv.xy] per vertex (stride=20). - verts32 = np.ascontiguousarray(verts_np, dtype=np.float32) - uvs32 = np.ascontiguousarray(uvs_np, dtype=np.float32) - faces32 = np.ascontiguousarray(faces_np, dtype=np.uint32) - vbo_data = np.ascontiguousarray(np.concatenate([verts32, uvs32], axis=-1), dtype=np.float32) + if faces_np.shape[0] == 0: + return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool) - prog = _gl_compile_program(gl, _BAKE_VERT_SRC, _BAKE_FRAG_SRC) + verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev) + uvs = torch.from_numpy(np.ascontiguousarray(uvs_np, dtype=np.float32)).to(dev) + faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev) - vao = gl.glGenVertexArrays(1) - gl.glBindVertexArray(vao) + # GL convention: window coord = uv * resolution, coverage tested at texel centre. + tri_uv = (uvs * float(W))[faces] # [F,3,2] + tri_attr = verts[faces] # [F,3,3] + x0, y0 = tri_uv[:, 0, 0], tri_uv[:, 0, 1] + x1, y1 = tri_uv[:, 1, 0], tri_uv[:, 1, 1] + x2, y2 = tri_uv[:, 2, 0], tri_uv[:, 2, 1] + denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2) + nondegen = denom.abs() > 1e-20 - vbo = gl.glGenBuffers(1) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo) - _raw_glBufferData(gl.GL_ARRAY_BUFFER, int(vbo_data.nbytes), - vbo_data.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW) + xmin = torch.minimum(torch.minimum(x0, x1), x2).floor().clamp_(0, W - 1).long() + xmax = torch.maximum(torch.maximum(x0, x1), x2).ceil().clamp_(0, W - 1).long() + ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, H - 1).long() + ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, H - 1).long() - ibo = gl.glGenBuffers(1) - gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ibo) - _raw_glBufferData(gl.GL_ELEMENT_ARRAY_BUFFER, int(faces32.nbytes), - faces32.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW) + pos_out = torch.zeros((H, W, 3), device=dev) + cov = torch.zeros((H, W), dtype=torch.bool, device=dev) - gl.glEnableVertexAttribArray(0) - _raw_glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(0)) - gl.glEnableVertexAttribArray(1) - _raw_glVertexAttribPointer(1, 2, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(12)) + # Tile so point-in-triangle only runs over the triangles whose bbox hits the tile. + TILE = 64 + eps = 1e-6 + for ty in range(0, H, TILE): + ty_end = min(ty + TILE, H) + for tx in range(0, W, TILE): + tx_end = min(tx + TILE, W) + tri_mask = (nondegen & (xmin < tx_end) & (xmax >= tx) + & (ymin < ty_end) & (ymax >= ty)) + if not tri_mask.any(): + continue + idx = torch.nonzero(tri_mask, as_tuple=True)[0] + ys = torch.arange(ty, ty_end, dtype=torch.float32, device=dev) + 0.5 + xs = torch.arange(tx, tx_end, dtype=torch.float32, device=dev) + 0.5 + yy, xx = torch.meshgrid(ys, xs, indexing="ij") # [th,tw] + sx0, sy0 = x0[idx][:, None, None], y0[idx][:, None, None] + sx1, sy1 = x1[idx][:, None, None], y1[idx][:, None, None] + sx2, sy2 = x2[idx][:, None, None], y2[idx][:, None, None] + sden = denom[idx][:, None, None] + b0 = ((sy1 - sy2) * (xx - sx2) + (sx2 - sx1) * (yy - sy2)) / sden + b1 = ((sy2 - sy0) * (xx - sx2) + (sx0 - sx2) * (yy - sy2)) / sden + b2 = 1.0 - b0 - b1 + inside = (b0 >= -eps) & (b1 >= -eps) & (b2 >= -eps) # [K,th,tw] + if not inside.any(): + continue + hit = inside.any(dim=0) # [th,tw] + sel = inside.int().argmax(dim=0) # [th,tw] first covering local tri + b0s = b0.gather(0, sel[None]).squeeze(0) # [th,tw] bary of selected tri + b1s = b1.gather(0, sel[None]).squeeze(0) + b2s = b2.gather(0, sel[None]).squeeze(0) + p = tri_attr[idx[sel]] # [th,tw,3,3] + attr = b0s[..., None] * p[..., 0, :] + b1s[..., None] * p[..., 1, :] + b2s[..., None] * p[..., 2, :] + pos_out[ty:ty_end, tx:tx_end][hit] = attr[hit] # slice is a view → writes through + cov[ty:ty_end, tx:tx_end] |= hit - fbo = gl.glGenFramebuffers(1) - gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) - color_tex = gl.glGenTextures(1) - gl.glBindTexture(gl.GL_TEXTURE_2D, color_tex) - _raw_glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, W, H, - 0, gl.GL_RGBA, gl.GL_FLOAT, None) - gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_NEAREST) - gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST) - gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, - gl.GL_TEXTURE_2D, color_tex, 0) - status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) - if status != gl.GL_FRAMEBUFFER_COMPLETE: - raise RuntimeError(f"FBO incomplete (status=0x{status:x})") - - gl.glViewport(0, 0, W, H) - gl.glClearColor(0.0, 0.0, 0.0, 0.0) - gl.glClear(gl.GL_COLOR_BUFFER_BIT) - gl.glDisable(gl.GL_CULL_FACE) - gl.glDisable(gl.GL_DEPTH_TEST) - gl.glUseProgram(prog) - _raw_glDrawElements(gl.GL_TRIANGLES, int(faces32.size), gl.GL_UNSIGNED_INT, None) - gl.glFinish() - - # Pre-allocated readback buffer passed as a pointer (skips PyOpenGL alloc). - arr = np.empty((H, W, 4), dtype=np.float32) - _raw_glReadPixels(0, 0, W, H, gl.GL_RGBA, gl.GL_FLOAT, - arr.ctypes.data_as(_ctypes.c_void_p)) - # Do NOT flipud: shader puts UV(0,0) at FBO bottom-left and glReadPixels - # returns bottom-row-first, so arr[0] is UV v=0 — matches glTF PNG row 0. - position_map = arr[..., :3] - mask = arr[..., 3] > 0.5 - return position_map, mask - finally: - if fbo is not None: - gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) - gl.glDeleteFramebuffers(1, [fbo]) - if color_tex is not None: - gl.glDeleteTextures(1, [color_tex]) - if vbo is not None: - gl.glDeleteBuffers(1, [vbo]) - if ibo is not None: - gl.glDeleteBuffers(1, [ibo]) - if vao is not None: - gl.glBindVertexArray(0) - gl.glDeleteVertexArrays(1, [vao]) - if prog is not None: - gl.glUseProgram(0) - gl.glDeleteProgram(prog) + return pos_out.cpu().numpy(), cov.cpu().numpy() def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):