convert to using PyOpenGL and glfw

This commit is contained in:
pythongosssss 2026-01-28 20:48:20 -08:00
parent aaea976f36
commit a4317314d2
3 changed files with 347 additions and 264 deletions

View File

@ -25,10 +25,6 @@ jobs:
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y libx11-dev
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View File

@ -1,18 +1,59 @@
import os import os
from comfy.cli_args import args
if args.cpu:
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")
elif not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"):
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
import re import re
import logging import logging
from contextlib import contextmanager from typing import TypedDict
from typing import TypedDict, Generator
import numpy as np import numpy as np
import torch import torch
import nodes import nodes
from comfy_api.latest import ComfyExtension, io, ui from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
from typing_extensions import override from typing_extensions import override
from utils.install_util import get_missing_requirements_message from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
try:
import glfw
import OpenGL.GL as gl
except ImportError as e:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
"Install with: pip install PyOpenGL PyOpenGL-accelerate glfw"
) from e
except AttributeError as e:
# This happens when PyOpenGL can't load the requested platform (e.g., OSMesa not installed)
platform = os.environ.get("PYOPENGL_PLATFORM", "default")
if platform == "osmesa":
raise RuntimeError(
"OSMesa (software rendering) requested but not installed.\n"
"OSMesa is required for --cpu mode.\n\n"
"Install OSMesa:\n"
" e.g. Ubuntu/Debian: sudo apt install libosmesa6-dev\n"
"Or disable CPU mode to use hardware rendering."
) from e
elif platform == "egl":
raise RuntimeError(
"EGL (headless rendering) requested but not available.\n"
"EGL is used for headless GPU rendering without a display.\n\n"
"Install EGL:\n"
" e.g. Ubuntu/Debian: sudo apt install libegl1-mesa-dev libgles2-mesa-dev\n"
"Or set DISPLAY/WAYLAND_DISPLAY environment variable if you have a display."
) from e
else:
raise RuntimeError(
f"OpenGL initialization failed (platform: {platform}).\n"
"Ensure OpenGL drivers are installed and working."
) from e
class SizeModeInput(TypedDict): class SizeModeInput(TypedDict):
size_mode: str size_mode: str
@ -24,15 +65,25 @@ MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT) MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
logger = logging.getLogger(__name__) # Vertex shader using gl_VertexID trick - no VBO needed.
# Draws a single triangle that covers the entire screen:
#
# (-1,3)
# /|
# / | <- visible area is the unit square from (-1,-1) to (1,1)
# / | parts outside get clipped away
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""
try:
import moderngl
except ImportError as e:
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
# Default NOOP fragment shader that passes through the input image unchanged
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
DEFAULT_FRAGMENT_SHADER = """#version 300 es DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float; precision highp float;
@ -48,252 +99,267 @@ void main() {
""" """
# Simple vertex shader for full-screen quad def _convert_es_to_desktop(source: str) -> str:
VERTEX_SHADER = """#version 330 """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
in vec2 in_position; source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
in vec2 in_texcoord; # Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
out vec2 v_texCoord; # Prepend desktop GLSL version
return "#version 330 core\n" + source
void main() {
gl_Position = vec4(in_position, 0.0, 1.0);
v_texCoord = in_texcoord;
}
"""
def _convert_es_to_desktop_glsl(source: str) -> str: class GLContext:
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility.""" """Manages OpenGL context and resources for shader execution."""
return re.sub(r'#version\s+300\s+es', '#version 330', source)
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if GLContext._initialized:
return
GLContext._initialized = True
import time
start = time.perf_counter()
if not glfw.init():
raise RuntimeError("Failed to initialize GLFW")
glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3)
glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3)
glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
self._window = glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not self._window:
glfw.terminate()
raise RuntimeError("Failed to create GLFW window")
glfw.make_context_current(self._window)
# Create VAO (required for core profile even if we don't use vertex attributes)
self._vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(self._vao)
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")
def make_current(self):
glfw.make_context_current(self._window)
gl.glBindVertexArray(self._vao)
def _create_software_gl_context() -> moderngl.Context: def _compile_shader(source: str, shader_type: int) -> int:
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE") """Compile a shader and return its ID."""
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1" shader = gl.glCreateShader(shader_type)
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
return shader
def _create_program(vertex_source: str, fragment_source: str) -> int:
"""Create and link a shader program."""
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
try: try:
ctx = moderngl.create_standalone_context(require=330) fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}") except RuntimeError:
return ctx gl.glDeleteShader(vertex_shader)
finally: raise
if original_env is None:
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None) program = gl.glCreateProgram()
else: gl.glAttachShader(program, vertex_shader)
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env gl.glAttachShader(program, fragment_shader)
gl.glLinkProgram(program)
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
return program
def _create_gl_context(force_software: bool = False) -> moderngl.Context: def _render_shader_batch(
if force_software: fragment_code: str,
try:
return _create_software_gl_context()
except Exception as e:
raise RuntimeError(
"Failed to create software-rendered OpenGL context.\n"
"Ensure Mesa/llvmpipe is installed for software rendering support."
) from e
# Try hardware rendering first, fall back to software
try:
ctx = moderngl.create_standalone_context(require=330)
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
return ctx
except Exception as hw_error:
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
logger.info("Attempting software rendering fallback...")
try:
return _create_software_gl_context()
except Exception as sw_error:
raise RuntimeError(
f"Failed to create OpenGL context.\n"
f"Hardware error: {hw_error}\n\n"
f"Possible solutions:\n"
f"1. Install GPU drivers with OpenGL 3.3+ support\n"
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
) from sw_error
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
height, width = image.shape[:2]
channels = image.shape[2] if len(image.shape) > 2 else 1
components = min(channels, 4)
image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
# Flip vertically for OpenGL coordinate system (origin at bottom-left)
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))
texture = ctx.texture((width, height), components, image_uint8.tobytes())
texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
texture.repeat_x = False
texture.repeat_y = False
return texture
def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
width, height = fbo.size
data = fbo.read(components=channels, attachment=attachment)
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
image = np.ascontiguousarray(np.flipud(image))
return image.astype(np.float32) / 255.0
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
fragment_source = _convert_es_to_desktop_glsl(fragment_source)
try:
program = ctx.program(
vertex_shader=VERTEX_SHADER,
fragment_shader=fragment_source,
)
return program
except Exception as e:
raise RuntimeError(
"Fragment shader compilation failed.\n\n"
"Make sure your shader:\n"
"1. Uses #version 300 es (WebGL 2.0 compatible)\n"
"2. Has valid GLSL ES 3.00 syntax\n"
"3. Includes 'precision highp float;' after version\n"
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
) from e
def _render_shader(
ctx: moderngl.Context,
program: moderngl.Program,
width: int, width: int,
height: int, height: int,
textures: list[moderngl.Texture], image_batches: list[list[np.ndarray]],
uniforms: dict[str, int | float], floats: list[float],
) -> list[np.ndarray]: ints: list[int],
# Create output textures ) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
Compiles shader once, reuses framebuffer/textures across batches.
Args:
fragment_code: User's fragment shader code
width: Output width
height: Output height
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
"""
if not image_batches:
return []
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Track resources for cleanup
program = None
fbo = None
output_textures = [] output_textures = []
for _ in range(MAX_OUTPUTS): input_textures = []
tex = ctx.texture((width, height), 4)
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
output_textures.append(tex)
fbo = ctx.framebuffer(color_attachments=output_textures) num_inputs = len(image_batches[0])
# Full-screen quad vertices (position + texcoord)
vertices = np.array([
# Position (x, y), Texcoord (u, v)
-1.0, -1.0, 0.0, 0.0,
1.0, -1.0, 1.0, 0.0,
-1.0, 1.0, 0.0, 1.0,
1.0, 1.0, 1.0, 1.0,
], dtype='f4')
vbo = ctx.buffer(vertices.tobytes())
vao = ctx.vertex_array(
program,
[(vbo, '2f 2f', 'in_position', 'in_texcoord')],
)
try: try:
# Bind textures # Compile shaders (once for all batches)
for i, texture in enumerate(textures): try:
texture.use(i) program = _create_program(VERTEX_SHADER, fragment_source)
uniform_name = f'u_image{i}' except RuntimeError:
if uniform_name in program: logger.error(f"Fragment shader:\n{fragment_source}")
program[uniform_name].value = i raise
# Set uniforms gl.glUseProgram(program)
if 'u_resolution' in program:
program['u_resolution'].value = (float(width), float(height))
for name, value in uniforms.items(): # Create framebuffer with multiple color attachments (reused for all batches)
if name in program: fbo = gl.glGenFramebuffers(1)
program[name].value = value gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
# Render draw_buffers = []
fbo.use()
fbo.clear(0.0, 0.0, 0.0, 1.0)
vao.render(moderngl.TRIANGLE_STRIP)
# Read results from all attachments
results = []
for i in range(MAX_OUTPUTS): for i in range(MAX_OUTPUTS):
results.append(_texture_to_image(fbo, attachment=i, channels=4)) tex = gl.glGenTextures(1)
return results output_textures.append(tex)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
gl.glDrawBuffers(MAX_OUTPUTS, draw_buffers)
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
raise RuntimeError("Framebuffer is not complete")
# Create input textures (reused for all batches)
for i in range(num_inputs):
tex = gl.glGenTextures(1)
input_textures.append(tex)
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_image{i}")
if loc >= 0:
gl.glUniform1i(loc, i)
# Set static uniforms (once for all batches)
loc = gl.glGetUniformLocation(program, "u_resolution")
if loc >= 0:
gl.glUniform2f(loc, float(width), float(height))
for i, v in enumerate(floats):
loc = gl.glGetUniformLocation(program, f"u_float{i}")
if loc >= 0:
gl.glUniform1f(loc, v)
for i, v in enumerate(ints):
loc = gl.glGetUniformLocation(program, f"u_int{i}")
if loc >= 0:
gl.glUniform1i(loc, v)
gl.glViewport(0, 0, width, height)
gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
# Process each batch
all_batch_outputs = []
for images in image_batches:
# Update input textures with this batch's images
for i, img in enumerate(images):
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
# Flip vertically for GL coordinates, ensure RGBA
img_flipped = np.ascontiguousarray(img[::-1, :, :])
if img_flipped.shape[2] == 3:
img_flipped = np.ascontiguousarray(np.concatenate(
[img_flipped, np.ones((*img_flipped.shape[:2], 1), dtype=np.float32)],
axis=2,
))
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, img_flipped.shape[1], img_flipped.shape[0], 0, gl.GL_RGBA, gl.GL_FLOAT, img_flipped)
# Render
gl.glClearColor(0, 0, 0, 0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))
all_batch_outputs.append(batch_outputs)
return all_batch_outputs
finally: finally:
vao.release() # Unbind before deleting
vbo.release() gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
for tex in output_textures: gl.glUseProgram(0)
tex.release()
fbo.release()
def _prepare_textures(
ctx: moderngl.Context,
image_list: list[torch.Tensor],
batch_idx: int,
) -> list[moderngl.Texture]:
textures = []
for img_tensor in image_list[:MAX_IMAGES]:
img_idx = min(batch_idx, img_tensor.shape[0] - 1)
img_np = img_tensor[img_idx].cpu().numpy()
textures.append(_image_to_texture(ctx, img_np))
return textures
def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
uniforms: dict[str, int | float] = {}
for i, val in enumerate(int_list[:MAX_UNIFORMS]):
uniforms[f'u_int{i}'] = int(val)
for i, val in enumerate(float_list[:MAX_UNIFORMS]):
uniforms[f'u_float{i}'] = float(val)
return uniforms
def _release_textures(textures: list[moderngl.Texture]) -> None:
for texture in textures:
texture.release()
@contextmanager
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
ctx = _create_gl_context(force_software)
try:
yield ctx
finally:
ctx.release()
@contextmanager
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
program = _compile_shader(ctx, fragment_source)
try:
yield program
finally:
program.release()
@contextmanager
def _textures_context(
ctx: moderngl.Context,
image_list: list[torch.Tensor],
batch_idx: int,
) -> Generator[list[moderngl.Texture], None, None]:
textures = _prepare_textures(ctx, image_list, batch_idx)
try:
yield textures
finally:
_release_textures(textures)
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
if program is not None:
gl.glDeleteProgram(program)
class GLSLShader(io.ComfyNode): class GLSLShader(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
# Create autogrow templates
image_template = io.Autogrow.TemplatePrefix( image_template = io.Autogrow.TemplatePrefix(
io.Image.Input("image"), io.Image.Input("image"),
prefix="image", prefix="image",
@ -335,15 +401,22 @@ class GLSLShader(io.ComfyNode):
io.DynamicCombo.Input( io.DynamicCombo.Input(
"size_mode", "size_mode",
options=[ options=[
io.DynamicCombo.Option( io.DynamicCombo.Option("from_input", []),
"from_input",
[], # No extra inputs - uses first input image dimensions
),
io.DynamicCombo.Option( io.DynamicCombo.Option(
"custom", "custom",
[ [
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION), io.Int.Input(
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION), "width",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
io.Int.Input(
"height",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
], ],
), ),
], ],
@ -372,7 +445,9 @@ class GLSLShader(io.ComfyNode):
**kwargs, **kwargs,
) -> io.NodeOutput: ) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None] image_list = [v for v in images.values() if v is not None]
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else [] float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
if not image_list: if not image_list:
@ -380,34 +455,44 @@ class GLSLShader(io.ComfyNode):
# Determine output dimensions # Determine output dimensions
if size_mode["size_mode"] == "custom": if size_mode["size_mode"] == "custom":
out_width, out_height = size_mode["width"], size_mode["height"] out_width = size_mode["width"]
out_height = size_mode["height"]
else: else:
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2] out_height, out_width = image_list[0].shape[1:3]
batch_size = image_list[0].shape[0] batch_size = image_list[0].shape[0]
uniforms = _prepare_uniforms(int_list, float_list)
with _gl_context(force_software=args.cpu) as ctx: # Prepare batches
with _shader_program(ctx, fragment_shader) as program: image_batches = []
# Collect outputs for each render target across all batches for batch_idx in range(batch_size):
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)] batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
image_batches.append(batch_images)
for b in range(batch_size): all_batch_outputs = _render_shader_batch(
with _textures_context(ctx, image_list, b) as textures: fragment_shader,
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms) out_width,
for i, result in enumerate(results): out_height,
all_outputs[i].append(torch.from_numpy(result)) image_batches,
float_list,
int_list,
)
# Stack batches for each output # Collect outputs into tensors
output_values = [] all_outputs = [[] for _ in range(MAX_OUTPUTS)]
for i in range(MAX_OUTPUTS): for batch_outputs in all_batch_outputs:
output_batch = torch.stack(all_outputs[i], dim=0) for i, out_img in enumerate(batch_outputs):
output_values.append(output_batch) all_outputs[i].append(torch.from_numpy(out_img))
return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0])) output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
return io.NodeOutput(
*output_tensors,
ui=cls._build_ui_output(image_list, output_tensors[0]),
)
@classmethod @classmethod
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]: def _build_ui_output(
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution.""" """Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0) combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images( input_images_ui = ui.ImageSaveHelper.save_images(

View File

@ -29,4 +29,6 @@ kornia>=0.7.1
spandrel spandrel
pydantic~=2.0 pydantic~=2.0
pydantic-settings~=2.0 pydantic-settings~=2.0
moderngl PyOpenGL
PyOpenGL-accelerate
glfw