Bake without opengl

This commit is contained in:
kijai 2026-06-28 02:07:49 +03:00
parent 022503dbc9
commit ab58d1b79f
2 changed files with 62 additions and 152 deletions

View File

@ -1612,7 +1612,7 @@ def qem_simplify(
)
def simplify(
def qem_decimate_simplify(
vertices: torch.Tensor,
faces: torch.Tensor,
target: int,
@ -1640,7 +1640,7 @@ def simplify(
return qem_simplify(vertices, faces, target, colors, normals, max_edge_length, config)
def cluster_decimate(
def qem_cluster_decimate(
vertices: torch.Tensor, faces: torch.Tensor,
target_verts: int = 1_000_000,
colors: Optional[torch.Tensor] = None,

View File

@ -7,9 +7,7 @@ import copy
import comfy.utils
import comfy.model_management
from server import PromptServer
from comfy_extras.mesh3d.postprocess.qem_decimate import (
simplify as qem_decimate_simplify, QEMConfig, cluster_decimate as qem_cluster_decimate,
)
from comfy_extras.mesh3d.postprocess.qem_decimate import QEMConfig, qem_decimate_simplify, qem_cluster_decimate
from comfy_extras.mesh3d.postprocess.remesh import remesh_narrow_band_dc, _point_tri_closest
from comfy_extras.mesh3d.uv_unwrap import mesh as _uv_mesh
from comfy_extras.mesh3d.uv_unwrap import segment as _uv_seg
@ -168,161 +166,73 @@ class PaintMesh(IO.ComfyNode):
return IO.NodeOutput(out_mesh)
# Texture baking from sparse voxel volume: existing UVs → OpenGL UV-space
# rasterize → per-texel voxel sample → JFA seam fill → attach to mesh for SaveGLB.
# Never unwraps (done upstream). GL context via nodes_glsl.GLContext.
_GL_COMPILE_PROGRAM_CACHE_KEY = "_bake_texture_program_cache"
def _gl_compile_program(gl, vert_src: str, frag_src: str):
"""Compile+link a vert+frag GL program (caller glDeleteProgram)."""
def _check_shader(s, kind):
if not gl.glGetShaderiv(s, gl.GL_COMPILE_STATUS):
log = gl.glGetShaderInfoLog(s).decode(errors="replace")
gl.glDeleteShader(s)
raise RuntimeError(f"GL {kind} shader compile failed: {log}")
vs = gl.glCreateShader(gl.GL_VERTEX_SHADER)
gl.glShaderSource(vs, vert_src)
gl.glCompileShader(vs)
_check_shader(vs, "vertex")
fs = gl.glCreateShader(gl.GL_FRAGMENT_SHADER)
gl.glShaderSource(fs, frag_src)
gl.glCompileShader(fs)
_check_shader(fs, "fragment")
prog = gl.glCreateProgram()
gl.glAttachShader(prog, vs)
gl.glAttachShader(prog, fs)
gl.glLinkProgram(prog)
gl.glDeleteShader(vs)
gl.glDeleteShader(fs)
if not gl.glGetProgramiv(prog, gl.GL_LINK_STATUS):
log = gl.glGetProgramInfoLog(prog).decode(errors="replace")
gl.glDeleteProgram(prog)
raise RuntimeError(f"GL program link failed: {log}")
return prog
# Position-passthrough: vertex maps UV → clip space; fragment outputs interpolated
# world-space position (alpha=1 marks valid texels).
_BAKE_VERT_SRC = """
#version 330 core
layout (location = 0) in vec3 a_pos;
layout (location = 1) in vec2 a_uv;
out vec3 v_pos;
void main() {
v_pos = a_pos;
gl_Position = vec4(a_uv * 2.0 - 1.0, 0.0, 1.0);
}
"""
_BAKE_FRAG_SRC = """
#version 330 core
in vec3 v_pos;
out vec4 frag_color;
void main() {
frag_color = vec4(v_pos, 1.0);
}
"""
def _bake_position_map(verts_np, faces_np, uvs_np, texture_size):
"""Rasterize unwrapped mesh in UV space. Returns (position_map [H,W,3] float32,
mask [H,W] bool covered)."""
from comfy_extras.nodes_glsl import GLContext, _import_opengl
GLContext() # ensure backend is initialized + current
gl = _import_opengl()
# PyOpenGL high-level wrappers store array refs in OpenGL.contextdata, which on
# EGL contexts errors ("no valid context"); use raw C entry points instead.
import ctypes as _ctypes
from OpenGL.raw.GL.VERSION.GL_1_1 import (
glReadPixels as _raw_glReadPixels,
glTexImage2D as _raw_glTexImage2D,
glDrawElements as _raw_glDrawElements,
)
from OpenGL.raw.GL.VERSION.GL_1_5 import glBufferData as _raw_glBufferData
from OpenGL.raw.GL.VERSION.GL_2_0 import glVertexAttribPointer as _raw_glVertexAttribPointer
"""Rasterize the mesh in UV space and barycentric-interpolate the per-vertex vec3
(world position, or any vec3 attr e.g. normals) at each covered texel. Pure torch,
tiled point-in-triangle no GL/EGL, runs anywhere torch does. Returns (attr_map
[H,W,3] float32, mask [H,W] bool). """
dev = comfy.model_management.get_torch_device()
H = W = int(texture_size)
fbo = color_tex = vbo = ibo = vao = prog = None
try:
# Interleaved [pos.xyz, uv.xy] per vertex (stride=20).
verts32 = np.ascontiguousarray(verts_np, dtype=np.float32)
uvs32 = np.ascontiguousarray(uvs_np, dtype=np.float32)
faces32 = np.ascontiguousarray(faces_np, dtype=np.uint32)
vbo_data = np.ascontiguousarray(np.concatenate([verts32, uvs32], axis=-1), dtype=np.float32)
if faces_np.shape[0] == 0:
return np.zeros((H, W, 3), dtype=np.float32), np.zeros((H, W), dtype=bool)
prog = _gl_compile_program(gl, _BAKE_VERT_SRC, _BAKE_FRAG_SRC)
verts = torch.from_numpy(np.ascontiguousarray(verts_np, dtype=np.float32)).to(dev)
uvs = torch.from_numpy(np.ascontiguousarray(uvs_np, dtype=np.float32)).to(dev)
faces = torch.from_numpy(np.ascontiguousarray(faces_np).astype(np.int64)).to(dev)
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
# GL convention: window coord = uv * resolution, coverage tested at texel centre.
tri_uv = (uvs * float(W))[faces] # [F,3,2]
tri_attr = verts[faces] # [F,3,3]
x0, y0 = tri_uv[:, 0, 0], tri_uv[:, 0, 1]
x1, y1 = tri_uv[:, 1, 0], tri_uv[:, 1, 1]
x2, y2 = tri_uv[:, 2, 0], tri_uv[:, 2, 1]
denom = (y1 - y2) * (x0 - x2) + (x2 - x1) * (y0 - y2)
nondegen = denom.abs() > 1e-20
vbo = gl.glGenBuffers(1)
gl.glBindBuffer(gl.GL_ARRAY_BUFFER, vbo)
_raw_glBufferData(gl.GL_ARRAY_BUFFER, int(vbo_data.nbytes),
vbo_data.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW)
xmin = torch.minimum(torch.minimum(x0, x1), x2).floor().clamp_(0, W - 1).long()
xmax = torch.maximum(torch.maximum(x0, x1), x2).ceil().clamp_(0, W - 1).long()
ymin = torch.minimum(torch.minimum(y0, y1), y2).floor().clamp_(0, H - 1).long()
ymax = torch.maximum(torch.maximum(y0, y1), y2).ceil().clamp_(0, H - 1).long()
ibo = gl.glGenBuffers(1)
gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, ibo)
_raw_glBufferData(gl.GL_ELEMENT_ARRAY_BUFFER, int(faces32.nbytes),
faces32.ctypes.data_as(_ctypes.c_void_p), gl.GL_STATIC_DRAW)
pos_out = torch.zeros((H, W, 3), device=dev)
cov = torch.zeros((H, W), dtype=torch.bool, device=dev)
gl.glEnableVertexAttribArray(0)
_raw_glVertexAttribPointer(0, 3, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(0))
gl.glEnableVertexAttribArray(1)
_raw_glVertexAttribPointer(1, 2, gl.GL_FLOAT, gl.GL_FALSE, 20, _ctypes.c_void_p(12))
# Tile so point-in-triangle only runs over the triangles whose bbox hits the tile.
TILE = 64
eps = 1e-6
for ty in range(0, H, TILE):
ty_end = min(ty + TILE, H)
for tx in range(0, W, TILE):
tx_end = min(tx + TILE, W)
tri_mask = (nondegen & (xmin < tx_end) & (xmax >= tx)
& (ymin < ty_end) & (ymax >= ty))
if not tri_mask.any():
continue
idx = torch.nonzero(tri_mask, as_tuple=True)[0]
ys = torch.arange(ty, ty_end, dtype=torch.float32, device=dev) + 0.5
xs = torch.arange(tx, tx_end, dtype=torch.float32, device=dev) + 0.5
yy, xx = torch.meshgrid(ys, xs, indexing="ij") # [th,tw]
sx0, sy0 = x0[idx][:, None, None], y0[idx][:, None, None]
sx1, sy1 = x1[idx][:, None, None], y1[idx][:, None, None]
sx2, sy2 = x2[idx][:, None, None], y2[idx][:, None, None]
sden = denom[idx][:, None, None]
b0 = ((sy1 - sy2) * (xx - sx2) + (sx2 - sx1) * (yy - sy2)) / sden
b1 = ((sy2 - sy0) * (xx - sx2) + (sx0 - sx2) * (yy - sy2)) / sden
b2 = 1.0 - b0 - b1
inside = (b0 >= -eps) & (b1 >= -eps) & (b2 >= -eps) # [K,th,tw]
if not inside.any():
continue
hit = inside.any(dim=0) # [th,tw]
sel = inside.int().argmax(dim=0) # [th,tw] first covering local tri
b0s = b0.gather(0, sel[None]).squeeze(0) # [th,tw] bary of selected tri
b1s = b1.gather(0, sel[None]).squeeze(0)
b2s = b2.gather(0, sel[None]).squeeze(0)
p = tri_attr[idx[sel]] # [th,tw,3,3]
attr = b0s[..., None] * p[..., 0, :] + b1s[..., None] * p[..., 1, :] + b2s[..., None] * p[..., 2, :]
pos_out[ty:ty_end, tx:tx_end][hit] = attr[hit] # slice is a view → writes through
cov[ty:ty_end, tx:tx_end] |= hit
fbo = gl.glGenFramebuffers(1)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
color_tex = gl.glGenTextures(1)
gl.glBindTexture(gl.GL_TEXTURE_2D, color_tex)
_raw_glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, W, H,
0, gl.GL_RGBA, gl.GL_FLOAT, None)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_NEAREST)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0,
gl.GL_TEXTURE_2D, color_tex, 0)
status = gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER)
if status != gl.GL_FRAMEBUFFER_COMPLETE:
raise RuntimeError(f"FBO incomplete (status=0x{status:x})")
gl.glViewport(0, 0, W, H)
gl.glClearColor(0.0, 0.0, 0.0, 0.0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
gl.glDisable(gl.GL_CULL_FACE)
gl.glDisable(gl.GL_DEPTH_TEST)
gl.glUseProgram(prog)
_raw_glDrawElements(gl.GL_TRIANGLES, int(faces32.size), gl.GL_UNSIGNED_INT, None)
gl.glFinish()
# Pre-allocated readback buffer passed as a pointer (skips PyOpenGL alloc).
arr = np.empty((H, W, 4), dtype=np.float32)
_raw_glReadPixels(0, 0, W, H, gl.GL_RGBA, gl.GL_FLOAT,
arr.ctypes.data_as(_ctypes.c_void_p))
# Do NOT flipud: shader puts UV(0,0) at FBO bottom-left and glReadPixels
# returns bottom-row-first, so arr[0] is UV v=0 — matches glTF PNG row 0.
position_map = arr[..., :3]
mask = arr[..., 3] > 0.5
return position_map, mask
finally:
if fbo is not None:
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glDeleteFramebuffers(1, [fbo])
if color_tex is not None:
gl.glDeleteTextures(1, [color_tex])
if vbo is not None:
gl.glDeleteBuffers(1, [vbo])
if ibo is not None:
gl.glDeleteBuffers(1, [ibo])
if vao is not None:
gl.glBindVertexArray(0)
gl.glDeleteVertexArrays(1, [vao])
if prog is not None:
gl.glUseProgram(0)
gl.glDeleteProgram(prog)
return pos_out.cpu().numpy(), cov.cpu().numpy()
def _trilinear_sample_sparse(positions, voxel_coords_np, color_np, resolution):