mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 16:50:17 +08:00
convert to using PyOpenGL and glfw
This commit is contained in:
parent
aaea976f36
commit
a4317314d2
4
.github/workflows/test-build.yml
vendored
4
.github/workflows/test-build.yml
vendored
@ -25,10 +25,6 @@ jobs:
|
|||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install system dependencies
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y libx11-dev
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|||||||
@ -1,18 +1,59 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")
|
||||||
|
elif not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"):
|
||||||
|
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from typing import TypedDict
|
||||||
from typing import TypedDict, Generator
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.latest import ComfyExtension, io, ui
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
from comfy.cli_args import args
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from utils.install_util import get_missing_requirements_message
|
from utils.install_util import get_missing_requirements_message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import glfw
|
||||||
|
import OpenGL.GL as gl
|
||||||
|
except ImportError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
||||||
|
"Install with: pip install PyOpenGL PyOpenGL-accelerate glfw"
|
||||||
|
) from e
|
||||||
|
except AttributeError as e:
|
||||||
|
# This happens when PyOpenGL can't load the requested platform (e.g., OSMesa not installed)
|
||||||
|
platform = os.environ.get("PYOPENGL_PLATFORM", "default")
|
||||||
|
if platform == "osmesa":
|
||||||
|
raise RuntimeError(
|
||||||
|
"OSMesa (software rendering) requested but not installed.\n"
|
||||||
|
"OSMesa is required for --cpu mode.\n\n"
|
||||||
|
"Install OSMesa:\n"
|
||||||
|
" e.g. Ubuntu/Debian: sudo apt install libosmesa6-dev\n"
|
||||||
|
"Or disable CPU mode to use hardware rendering."
|
||||||
|
) from e
|
||||||
|
elif platform == "egl":
|
||||||
|
raise RuntimeError(
|
||||||
|
"EGL (headless rendering) requested but not available.\n"
|
||||||
|
"EGL is used for headless GPU rendering without a display.\n\n"
|
||||||
|
"Install EGL:\n"
|
||||||
|
" e.g. Ubuntu/Debian: sudo apt install libegl1-mesa-dev libgles2-mesa-dev\n"
|
||||||
|
"Or set DISPLAY/WAYLAND_DISPLAY environment variable if you have a display."
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"OpenGL initialization failed (platform: {platform}).\n"
|
||||||
|
"Ensure OpenGL drivers are installed and working."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class SizeModeInput(TypedDict):
|
class SizeModeInput(TypedDict):
|
||||||
size_mode: str
|
size_mode: str
|
||||||
@ -24,15 +65,25 @@ MAX_IMAGES = 5 # u_image0-4
|
|||||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||||
|
# Draws a single triangle that covers the entire screen:
|
||||||
|
#
|
||||||
|
# (-1,3)
|
||||||
|
# /|
|
||||||
|
# / | <- visible area is the unit square from (-1,-1) to (1,1)
|
||||||
|
# / | parts outside get clipped away
|
||||||
|
# (-1,-1)---(3,-1)
|
||||||
|
#
|
||||||
|
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
||||||
|
VERTEX_SHADER = """#version 330 core
|
||||||
|
out vec2 v_texCoord;
|
||||||
|
void main() {
|
||||||
|
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
||||||
|
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
|
||||||
|
gl_Position = vec4(verts[gl_VertexID], 0, 1);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
|
||||||
import moderngl
|
|
||||||
except ImportError as e:
|
|
||||||
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
|
|
||||||
|
|
||||||
# Default NOOP fragment shader that passes through the input image unchanged
|
|
||||||
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
|
|
||||||
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
||||||
precision highp float;
|
precision highp float;
|
||||||
|
|
||||||
@ -48,252 +99,267 @@ void main() {
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Simple vertex shader for full-screen quad
|
def _convert_es_to_desktop(source: str) -> str:
|
||||||
VERTEX_SHADER = """#version 330
|
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
||||||
|
# Remove any existing #version directive
|
||||||
in vec2 in_position;
|
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
||||||
in vec2 in_texcoord;
|
# Remove precision qualifiers (not needed in desktop GLSL)
|
||||||
|
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
||||||
out vec2 v_texCoord;
|
# Prepend desktop GLSL version
|
||||||
|
return "#version 330 core\n" + source
|
||||||
void main() {
|
|
||||||
gl_Position = vec4(in_position, 0.0, 1.0);
|
|
||||||
v_texCoord = in_texcoord;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_es_to_desktop_glsl(source: str) -> str:
|
class GLContext:
|
||||||
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
|
"""Manages OpenGL context and resources for shader execution."""
|
||||||
return re.sub(r'#version\s+300\s+es', '#version 330', source)
|
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if GLContext._initialized:
|
||||||
|
return
|
||||||
|
GLContext._initialized = True
|
||||||
|
|
||||||
|
import time
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
if not glfw.init():
|
||||||
|
raise RuntimeError("Failed to initialize GLFW")
|
||||||
|
|
||||||
|
glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
|
||||||
|
glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3)
|
||||||
|
glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3)
|
||||||
|
glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
|
||||||
|
|
||||||
|
self._window = glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
||||||
|
if not self._window:
|
||||||
|
glfw.terminate()
|
||||||
|
raise RuntimeError("Failed to create GLFW window")
|
||||||
|
|
||||||
|
glfw.make_context_current(self._window)
|
||||||
|
|
||||||
|
# Create VAO (required for core profile even if we don't use vertex attributes)
|
||||||
|
self._vao = gl.glGenVertexArrays(1)
|
||||||
|
gl.glBindVertexArray(self._vao)
|
||||||
|
|
||||||
|
elapsed = (time.perf_counter() - start) * 1000
|
||||||
|
|
||||||
|
# Log device info
|
||||||
|
renderer = gl.glGetString(gl.GL_RENDERER)
|
||||||
|
vendor = gl.glGetString(gl.GL_VENDOR)
|
||||||
|
version = gl.glGetString(gl.GL_VERSION)
|
||||||
|
renderer = renderer.decode() if renderer else "Unknown"
|
||||||
|
vendor = vendor.decode() if vendor else "Unknown"
|
||||||
|
version = version.decode() if version else "Unknown"
|
||||||
|
|
||||||
|
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
|
||||||
|
|
||||||
|
def make_current(self):
|
||||||
|
glfw.make_context_current(self._window)
|
||||||
|
gl.glBindVertexArray(self._vao)
|
||||||
|
|
||||||
|
|
||||||
def _create_software_gl_context() -> moderngl.Context:
|
def _compile_shader(source: str, shader_type: int) -> int:
|
||||||
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
|
"""Compile a shader and return its ID."""
|
||||||
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
shader = gl.glCreateShader(shader_type)
|
||||||
|
gl.glShaderSource(shader, source)
|
||||||
|
gl.glCompileShader(shader)
|
||||||
|
|
||||||
|
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
||||||
|
error = gl.glGetShaderInfoLog(shader).decode()
|
||||||
|
gl.glDeleteShader(shader)
|
||||||
|
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
||||||
|
|
||||||
|
return shader
|
||||||
|
|
||||||
|
|
||||||
|
def _create_program(vertex_source: str, fragment_source: str) -> int:
|
||||||
|
"""Create and link a shader program."""
|
||||||
|
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
|
||||||
try:
|
try:
|
||||||
ctx = moderngl.create_standalone_context(require=330)
|
fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
|
||||||
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
|
except RuntimeError:
|
||||||
return ctx
|
gl.glDeleteShader(vertex_shader)
|
||||||
finally:
|
raise
|
||||||
if original_env is None:
|
|
||||||
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
|
program = gl.glCreateProgram()
|
||||||
else:
|
gl.glAttachShader(program, vertex_shader)
|
||||||
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env
|
gl.glAttachShader(program, fragment_shader)
|
||||||
|
gl.glLinkProgram(program)
|
||||||
|
|
||||||
|
gl.glDeleteShader(vertex_shader)
|
||||||
|
gl.glDeleteShader(fragment_shader)
|
||||||
|
|
||||||
|
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
||||||
|
error = gl.glGetProgramInfoLog(program).decode()
|
||||||
|
gl.glDeleteProgram(program)
|
||||||
|
raise RuntimeError(f"Program linking failed:\n{error}")
|
||||||
|
|
||||||
|
return program
|
||||||
|
|
||||||
|
|
||||||
def _create_gl_context(force_software: bool = False) -> moderngl.Context:
|
def _render_shader_batch(
|
||||||
if force_software:
|
fragment_code: str,
|
||||||
try:
|
|
||||||
return _create_software_gl_context()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Failed to create software-rendered OpenGL context.\n"
|
|
||||||
"Ensure Mesa/llvmpipe is installed for software rendering support."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
# Try hardware rendering first, fall back to software
|
|
||||||
try:
|
|
||||||
ctx = moderngl.create_standalone_context(require=330)
|
|
||||||
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
|
|
||||||
return ctx
|
|
||||||
except Exception as hw_error:
|
|
||||||
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
|
|
||||||
logger.info("Attempting software rendering fallback...")
|
|
||||||
try:
|
|
||||||
return _create_software_gl_context()
|
|
||||||
except Exception as sw_error:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to create OpenGL context.\n"
|
|
||||||
f"Hardware error: {hw_error}\n\n"
|
|
||||||
f"Possible solutions:\n"
|
|
||||||
f"1. Install GPU drivers with OpenGL 3.3+ support\n"
|
|
||||||
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
|
|
||||||
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
|
|
||||||
) from sw_error
|
|
||||||
|
|
||||||
|
|
||||||
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
|
|
||||||
height, width = image.shape[:2]
|
|
||||||
channels = image.shape[2] if len(image.shape) > 2 else 1
|
|
||||||
|
|
||||||
components = min(channels, 4)
|
|
||||||
|
|
||||||
image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
|
|
||||||
|
|
||||||
# Flip vertically for OpenGL coordinate system (origin at bottom-left)
|
|
||||||
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))
|
|
||||||
|
|
||||||
texture = ctx.texture((width, height), components, image_uint8.tobytes())
|
|
||||||
texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
|
||||||
texture.repeat_x = False
|
|
||||||
texture.repeat_y = False
|
|
||||||
|
|
||||||
return texture
|
|
||||||
|
|
||||||
|
|
||||||
def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
|
|
||||||
width, height = fbo.size
|
|
||||||
|
|
||||||
data = fbo.read(components=channels, attachment=attachment)
|
|
||||||
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
|
|
||||||
|
|
||||||
image = np.ascontiguousarray(np.flipud(image))
|
|
||||||
|
|
||||||
return image.astype(np.float32) / 255.0
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
|
|
||||||
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
|
|
||||||
fragment_source = _convert_es_to_desktop_glsl(fragment_source)
|
|
||||||
|
|
||||||
try:
|
|
||||||
program = ctx.program(
|
|
||||||
vertex_shader=VERTEX_SHADER,
|
|
||||||
fragment_shader=fragment_source,
|
|
||||||
)
|
|
||||||
return program
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Fragment shader compilation failed.\n\n"
|
|
||||||
"Make sure your shader:\n"
|
|
||||||
"1. Uses #version 300 es (WebGL 2.0 compatible)\n"
|
|
||||||
"2. Has valid GLSL ES 3.00 syntax\n"
|
|
||||||
"3. Includes 'precision highp float;' after version\n"
|
|
||||||
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
|
|
||||||
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def _render_shader(
|
|
||||||
ctx: moderngl.Context,
|
|
||||||
program: moderngl.Program,
|
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
textures: list[moderngl.Texture],
|
image_batches: list[list[np.ndarray]],
|
||||||
uniforms: dict[str, int | float],
|
floats: list[float],
|
||||||
) -> list[np.ndarray]:
|
ints: list[int],
|
||||||
# Create output textures
|
) -> list[list[np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Render a fragment shader for multiple batches efficiently.
|
||||||
|
|
||||||
|
Compiles shader once, reuses framebuffer/textures across batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragment_code: User's fragment shader code
|
||||||
|
width: Output width
|
||||||
|
height: Output height
|
||||||
|
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||||
|
floats: List of float uniforms
|
||||||
|
ints: List of int uniforms
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||||
|
"""
|
||||||
|
if not image_batches:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ctx = GLContext()
|
||||||
|
ctx.make_current()
|
||||||
|
|
||||||
|
# Convert from GLSL ES to desktop GLSL 330
|
||||||
|
fragment_source = _convert_es_to_desktop(fragment_code)
|
||||||
|
|
||||||
|
# Track resources for cleanup
|
||||||
|
program = None
|
||||||
|
fbo = None
|
||||||
output_textures = []
|
output_textures = []
|
||||||
for _ in range(MAX_OUTPUTS):
|
input_textures = []
|
||||||
tex = ctx.texture((width, height), 4)
|
|
||||||
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
|
||||||
output_textures.append(tex)
|
|
||||||
|
|
||||||
fbo = ctx.framebuffer(color_attachments=output_textures)
|
num_inputs = len(image_batches[0])
|
||||||
|
|
||||||
# Full-screen quad vertices (position + texcoord)
|
|
||||||
vertices = np.array([
|
|
||||||
# Position (x, y), Texcoord (u, v)
|
|
||||||
-1.0, -1.0, 0.0, 0.0,
|
|
||||||
1.0, -1.0, 1.0, 0.0,
|
|
||||||
-1.0, 1.0, 0.0, 1.0,
|
|
||||||
1.0, 1.0, 1.0, 1.0,
|
|
||||||
], dtype='f4')
|
|
||||||
|
|
||||||
vbo = ctx.buffer(vertices.tobytes())
|
|
||||||
vao = ctx.vertex_array(
|
|
||||||
program,
|
|
||||||
[(vbo, '2f 2f', 'in_position', 'in_texcoord')],
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Bind textures
|
# Compile shaders (once for all batches)
|
||||||
for i, texture in enumerate(textures):
|
try:
|
||||||
texture.use(i)
|
program = _create_program(VERTEX_SHADER, fragment_source)
|
||||||
uniform_name = f'u_image{i}'
|
except RuntimeError:
|
||||||
if uniform_name in program:
|
logger.error(f"Fragment shader:\n{fragment_source}")
|
||||||
program[uniform_name].value = i
|
raise
|
||||||
|
|
||||||
# Set uniforms
|
gl.glUseProgram(program)
|
||||||
if 'u_resolution' in program:
|
|
||||||
program['u_resolution'].value = (float(width), float(height))
|
|
||||||
|
|
||||||
for name, value in uniforms.items():
|
# Create framebuffer with multiple color attachments (reused for all batches)
|
||||||
if name in program:
|
fbo = gl.glGenFramebuffers(1)
|
||||||
program[name].value = value
|
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
||||||
|
|
||||||
# Render
|
draw_buffers = []
|
||||||
fbo.use()
|
|
||||||
fbo.clear(0.0, 0.0, 0.0, 1.0)
|
|
||||||
vao.render(moderngl.TRIANGLE_STRIP)
|
|
||||||
|
|
||||||
# Read results from all attachments
|
|
||||||
results = []
|
|
||||||
for i in range(MAX_OUTPUTS):
|
for i in range(MAX_OUTPUTS):
|
||||||
results.append(_texture_to_image(fbo, attachment=i, channels=4))
|
tex = gl.glGenTextures(1)
|
||||||
return results
|
output_textures.append(tex)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
|
||||||
|
draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
|
||||||
|
|
||||||
|
gl.glDrawBuffers(MAX_OUTPUTS, draw_buffers)
|
||||||
|
|
||||||
|
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
|
||||||
|
raise RuntimeError("Framebuffer is not complete")
|
||||||
|
|
||||||
|
# Create input textures (reused for all batches)
|
||||||
|
for i in range(num_inputs):
|
||||||
|
tex = gl.glGenTextures(1)
|
||||||
|
input_textures.append(tex)
|
||||||
|
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_image{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, i)
|
||||||
|
|
||||||
|
# Set static uniforms (once for all batches)
|
||||||
|
loc = gl.glGetUniformLocation(program, "u_resolution")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform2f(loc, float(width), float(height))
|
||||||
|
|
||||||
|
for i, v in enumerate(floats):
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_float{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1f(loc, v)
|
||||||
|
|
||||||
|
for i, v in enumerate(ints):
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_int{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, v)
|
||||||
|
|
||||||
|
gl.glViewport(0, 0, width, height)
|
||||||
|
gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
|
||||||
|
|
||||||
|
# Process each batch
|
||||||
|
all_batch_outputs = []
|
||||||
|
for images in image_batches:
|
||||||
|
# Update input textures with this batch's images
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
|
||||||
|
|
||||||
|
# Flip vertically for GL coordinates, ensure RGBA
|
||||||
|
img_flipped = np.ascontiguousarray(img[::-1, :, :])
|
||||||
|
if img_flipped.shape[2] == 3:
|
||||||
|
img_flipped = np.ascontiguousarray(np.concatenate(
|
||||||
|
[img_flipped, np.ones((*img_flipped.shape[:2], 1), dtype=np.float32)],
|
||||||
|
axis=2,
|
||||||
|
))
|
||||||
|
|
||||||
|
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, img_flipped.shape[1], img_flipped.shape[0], 0, gl.GL_RGBA, gl.GL_FLOAT, img_flipped)
|
||||||
|
|
||||||
|
# Render
|
||||||
|
gl.glClearColor(0, 0, 0, 0)
|
||||||
|
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
|
||||||
|
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
||||||
|
|
||||||
|
# Read back outputs for this batch
|
||||||
|
batch_outputs = []
|
||||||
|
for tex in output_textures:
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
||||||
|
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
||||||
|
batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))
|
||||||
|
|
||||||
|
all_batch_outputs.append(batch_outputs)
|
||||||
|
|
||||||
|
return all_batch_outputs
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
vao.release()
|
# Unbind before deleting
|
||||||
vbo.release()
|
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
||||||
for tex in output_textures:
|
gl.glUseProgram(0)
|
||||||
tex.release()
|
|
||||||
fbo.release()
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_textures(
|
|
||||||
ctx: moderngl.Context,
|
|
||||||
image_list: list[torch.Tensor],
|
|
||||||
batch_idx: int,
|
|
||||||
) -> list[moderngl.Texture]:
|
|
||||||
textures = []
|
|
||||||
for img_tensor in image_list[:MAX_IMAGES]:
|
|
||||||
img_idx = min(batch_idx, img_tensor.shape[0] - 1)
|
|
||||||
img_np = img_tensor[img_idx].cpu().numpy()
|
|
||||||
textures.append(_image_to_texture(ctx, img_np))
|
|
||||||
return textures
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
|
|
||||||
uniforms: dict[str, int | float] = {}
|
|
||||||
for i, val in enumerate(int_list[:MAX_UNIFORMS]):
|
|
||||||
uniforms[f'u_int{i}'] = int(val)
|
|
||||||
for i, val in enumerate(float_list[:MAX_UNIFORMS]):
|
|
||||||
uniforms[f'u_float{i}'] = float(val)
|
|
||||||
return uniforms
|
|
||||||
|
|
||||||
|
|
||||||
def _release_textures(textures: list[moderngl.Texture]) -> None:
|
|
||||||
for texture in textures:
|
|
||||||
texture.release()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
|
|
||||||
ctx = _create_gl_context(force_software)
|
|
||||||
try:
|
|
||||||
yield ctx
|
|
||||||
finally:
|
|
||||||
ctx.release()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
|
|
||||||
program = _compile_shader(ctx, fragment_source)
|
|
||||||
try:
|
|
||||||
yield program
|
|
||||||
finally:
|
|
||||||
program.release()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _textures_context(
|
|
||||||
ctx: moderngl.Context,
|
|
||||||
image_list: list[torch.Tensor],
|
|
||||||
batch_idx: int,
|
|
||||||
) -> Generator[list[moderngl.Texture], None, None]:
|
|
||||||
textures = _prepare_textures(ctx, image_list, batch_idx)
|
|
||||||
try:
|
|
||||||
yield textures
|
|
||||||
finally:
|
|
||||||
_release_textures(textures)
|
|
||||||
|
|
||||||
|
if input_textures:
|
||||||
|
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||||
|
if output_textures:
|
||||||
|
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||||
|
if fbo is not None:
|
||||||
|
gl.glDeleteFramebuffers(1, [fbo])
|
||||||
|
if program is not None:
|
||||||
|
gl.glDeleteProgram(program)
|
||||||
|
|
||||||
class GLSLShader(io.ComfyNode):
|
class GLSLShader(io.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
# Create autogrow templates
|
|
||||||
image_template = io.Autogrow.TemplatePrefix(
|
image_template = io.Autogrow.TemplatePrefix(
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
prefix="image",
|
prefix="image",
|
||||||
@ -335,15 +401,22 @@ class GLSLShader(io.ComfyNode):
|
|||||||
io.DynamicCombo.Input(
|
io.DynamicCombo.Input(
|
||||||
"size_mode",
|
"size_mode",
|
||||||
options=[
|
options=[
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option("from_input", []),
|
||||||
"from_input",
|
|
||||||
[], # No extra inputs - uses first input image dimensions
|
|
||||||
),
|
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option(
|
||||||
"custom",
|
"custom",
|
||||||
[
|
[
|
||||||
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
io.Int.Input(
|
||||||
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
"width",
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=nodes.MAX_RESOLUTION,
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"height",
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=nodes.MAX_RESOLUTION,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -372,7 +445,9 @@ class GLSLShader(io.ComfyNode):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
image_list = [v for v in images.values() if v is not None]
|
image_list = [v for v in images.values() if v is not None]
|
||||||
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
|
float_list = (
|
||||||
|
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||||
|
)
|
||||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||||
|
|
||||||
if not image_list:
|
if not image_list:
|
||||||
@ -380,34 +455,44 @@ class GLSLShader(io.ComfyNode):
|
|||||||
|
|
||||||
# Determine output dimensions
|
# Determine output dimensions
|
||||||
if size_mode["size_mode"] == "custom":
|
if size_mode["size_mode"] == "custom":
|
||||||
out_width, out_height = size_mode["width"], size_mode["height"]
|
out_width = size_mode["width"]
|
||||||
|
out_height = size_mode["height"]
|
||||||
else:
|
else:
|
||||||
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]
|
out_height, out_width = image_list[0].shape[1:3]
|
||||||
|
|
||||||
batch_size = image_list[0].shape[0]
|
batch_size = image_list[0].shape[0]
|
||||||
uniforms = _prepare_uniforms(int_list, float_list)
|
|
||||||
|
|
||||||
with _gl_context(force_software=args.cpu) as ctx:
|
# Prepare batches
|
||||||
with _shader_program(ctx, fragment_shader) as program:
|
image_batches = []
|
||||||
# Collect outputs for each render target across all batches
|
for batch_idx in range(batch_size):
|
||||||
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]
|
batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
|
||||||
|
image_batches.append(batch_images)
|
||||||
|
|
||||||
for b in range(batch_size):
|
all_batch_outputs = _render_shader_batch(
|
||||||
with _textures_context(ctx, image_list, b) as textures:
|
fragment_shader,
|
||||||
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
|
out_width,
|
||||||
for i, result in enumerate(results):
|
out_height,
|
||||||
all_outputs[i].append(torch.from_numpy(result))
|
image_batches,
|
||||||
|
float_list,
|
||||||
|
int_list,
|
||||||
|
)
|
||||||
|
|
||||||
# Stack batches for each output
|
# Collect outputs into tensors
|
||||||
output_values = []
|
all_outputs = [[] for _ in range(MAX_OUTPUTS)]
|
||||||
for i in range(MAX_OUTPUTS):
|
for batch_outputs in all_batch_outputs:
|
||||||
output_batch = torch.stack(all_outputs[i], dim=0)
|
for i, out_img in enumerate(batch_outputs):
|
||||||
output_values.append(output_batch)
|
all_outputs[i].append(torch.from_numpy(out_img))
|
||||||
|
|
||||||
return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))
|
output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
|
||||||
|
return io.NodeOutput(
|
||||||
|
*output_tensors,
|
||||||
|
ui=cls._build_ui_output(image_list, output_tensors[0]),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
|
def _build_ui_output(
|
||||||
|
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
||||||
|
) -> dict[str, list]:
|
||||||
"""Build UI output with input and output images for client-side shader execution."""
|
"""Build UI output with input and output images for client-side shader execution."""
|
||||||
combined_inputs = torch.cat(image_list, dim=0)
|
combined_inputs = torch.cat(image_list, dim=0)
|
||||||
input_images_ui = ui.ImageSaveHelper.save_images(
|
input_images_ui = ui.ImageSaveHelper.save_images(
|
||||||
|
|||||||
@ -29,4 +29,6 @@ kornia>=0.7.1
|
|||||||
spandrel
|
spandrel
|
||||||
pydantic~=2.0
|
pydantic~=2.0
|
||||||
pydantic-settings~=2.0
|
pydantic-settings~=2.0
|
||||||
moderngl
|
PyOpenGL
|
||||||
|
PyOpenGL-accelerate
|
||||||
|
glfw
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user