This commit is contained in:
pythongosssss 2026-01-30 07:20:47 +02:00 committed by GitHub
commit 20dafde12e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 756 additions and 0 deletions

752
comfy_extras/nodes_glsl.py Normal file
View File

@ -0,0 +1,752 @@
import os
import sys
import re
import logging
import ctypes.util
import importlib.util
from typing import TypedDict
import numpy as np
import torch
import nodes
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message
logger = logging.getLogger(__name__)
def _check_opengl_availability():
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
missing = []
# Check Python packages (using find_spec to avoid importing)
if importlib.util.find_spec("glfw") is None:
missing.append("glfw")
if importlib.util.find_spec("OpenGL") is None:
missing.append("PyOpenGL")
if missing:
raise RuntimeError(
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
)
# On Linux without display, check if headless backends are available
if sys.platform.startswith("linux"):
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
if not has_display:
# Check for EGL or OSMesa libraries
has_egl = ctypes.util.find_library("EGL")
has_osmesa = ctypes.util.find_library("OSMesa")
# Error disabled for CI as it fails this check
# if not has_egl and not has_osmesa:
# raise RuntimeError(
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
# "See error below for installation instructions."
# )
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
# Run early check at import time
_check_opengl_availability()
# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None
def _import_opengl():
"""Import OpenGL module. Called after context is created."""
global gl
if gl is None:
import OpenGL.GL as _gl
gl = _gl
return gl
class SizeModeInput(TypedDict):
size_mode: str
width: int
height: int
MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed.
# Draws a single triangle that covers the entire screen:
#
# (-1,3)
# /|
# / | <- visible area is the unit square from (-1,-1) to (1,1)
# / | parts outside get clipped away
# (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
fragColor0 = texture(u_image0, v_texCoord);
}
"""
def _convert_es_to_desktop(source: str) -> str:
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
# Remove any existing #version directive
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
# Remove precision qualifiers (not needed in desktop GLSL)
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
# Prepend desktop GLSL version
return "#version 330 core\n" + source
def _detect_output_count(source: str) -> int:
"""Detect how many fragColor outputs are used in the shader.
Returns the count of outputs needed (1 to MAX_OUTPUTS).
"""
matches = re.findall(r"fragColor(\d+)", source)
if not matches:
return 1 # Default to 1 output if none found
max_index = max(int(m) for m in matches)
return min(max_index + 1, MAX_OUTPUTS)
def _init_glfw():
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
import glfw as _glfw
if not _glfw.init():
raise RuntimeError("glfw.init() failed")
try:
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
if not window:
raise RuntimeError("glfw.create_window() failed")
_glfw.make_context_current(window)
return window, _glfw
except Exception:
_glfw.terminate()
raise
def _init_egl():
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
from OpenGL import EGL as _EGL
from OpenGL.EGL import (
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
eglTerminate, eglDestroyContext, eglDestroySurface,
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
)
display = None
context = None
surface = None
try:
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
if display == _EGL.EGL_NO_DISPLAY:
raise RuntimeError("eglGetDisplay() failed")
major, minor = _EGL.EGLint(), _EGL.EGLint()
if not eglInitialize(display, major, minor):
display = None # Not initialized, don't terminate
raise RuntimeError("eglInitialize() failed")
config_attribs = [
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
EGL_DEPTH_SIZE, 0, EGL_NONE
]
configs = (_EGL.EGLConfig * 1)()
num_configs = _EGL.EGLint()
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
raise RuntimeError("eglChooseConfig() failed")
config = configs[0]
if not eglBindAPI(EGL_OPENGL_API):
raise RuntimeError("eglBindAPI() failed")
context_attribs = [
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
EGL_NONE
]
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
if context == EGL_NO_CONTEXT:
raise RuntimeError("eglCreateContext() failed")
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
if surface == _EGL.EGL_NO_SURFACE:
raise RuntimeError("eglCreatePbufferSurface() failed")
if not eglMakeCurrent(display, surface, surface, context):
raise RuntimeError("eglMakeCurrent() failed")
return display, context, surface, _EGL
except Exception:
# Clean up any resources on failure
if surface is not None:
eglDestroySurface(display, surface)
if context is not None:
eglDestroyContext(display, context)
if display is not None:
eglTerminate(display)
raise
def _init_osmesa():
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
import ctypes
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
from OpenGL import GL as _gl
from OpenGL.osmesa import (
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
OSMESA_RGBA,
)
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
if not ctx:
raise RuntimeError("OSMesaCreateContextExt() failed")
width, height = 64, 64
buffer = (ctypes.c_ubyte * (width * height * 4))()
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
OSMesaDestroyContext(ctx)
raise RuntimeError("OSMesaMakeCurrent() failed")
return ctx, buffer
class GLContext:
"""Manages OpenGL context and resources for shader execution.
Tries backends in order: GLFW (desktop) EGL (headless GPU) OSMesa (software).
"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if GLContext._initialized:
return
GLContext._initialized = True
global glfw, EGL
import time
start = time.perf_counter()
self._backend = None
self._window = None
self._egl_display = None
self._egl_context = None
self._egl_surface = None
self._osmesa_ctx = None
self._osmesa_buffer = None
# Try backends in order: GLFW → EGL → OSMesa
errors = []
try:
self._window, glfw = _init_glfw()
self._backend = "glfw"
except Exception as e:
errors.append(("GLFW", e))
if self._backend is None:
try:
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
self._backend = "egl"
except Exception as e:
errors.append(("EGL", e))
if self._backend is None:
try:
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
self._backend = "osmesa"
except Exception as e:
errors.append(("OSMesa", e))
if self._backend is None:
if sys.platform == "win32":
platform_help = (
"Windows: Ensure GPU drivers are installed and display is available.\n"
" CPU-only/headless mode is not supported on Windows."
)
elif sys.platform == "darwin":
platform_help = (
"macOS: Ensure display is available. For headless, try virtual display."
)
else:
platform_help = (
"Linux: Install one of these backends:\n"
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
" Headless (CPU): sudo apt install libosmesa6"
)
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
raise RuntimeError(
f"Failed to create OpenGL context.\n\n"
f"Backend errors:\n{error_details}\n\n"
f"{platform_help}\n\n"
"Python packages: pip install PyOpenGL PyOpenGL-accelerate glfw"
)
# Now import OpenGL.GL (after context is current)
_import_opengl()
# Create VAO (required for core profile, but OSMesa may use compat profile)
self._vao = None
try:
vao = gl.glGenVertexArrays(1)
gl.glBindVertexArray(vao)
self._vao = vao # Only store after successful bind
except Exception:
# OSMesa with older Mesa may not support VAOs
# Clean up if we created but couldn't bind
if vao:
try:
gl.glDeleteVertexArrays(1, [vao])
except Exception:
pass
elapsed = (time.perf_counter() - start) * 1000
# Log device info
renderer = gl.glGetString(gl.GL_RENDERER)
vendor = gl.glGetString(gl.GL_VENDOR)
version = gl.glGetString(gl.GL_VERSION)
renderer = renderer.decode() if renderer else "Unknown"
vendor = vendor.decode() if vendor else "Unknown"
version = version.decode() if version else "Unknown"
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
def make_current(self):
if self._backend == "glfw":
glfw.make_context_current(self._window)
elif self._backend == "egl":
from OpenGL.EGL import eglMakeCurrent
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
elif self._backend == "osmesa":
from OpenGL.osmesa import OSMesaMakeCurrent
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
if self._vao is not None:
gl.glBindVertexArray(self._vao)
def _compile_shader(source: str, shader_type: int) -> int:
"""Compile a shader and return its ID."""
shader = gl.glCreateShader(shader_type)
gl.glShaderSource(shader, source)
gl.glCompileShader(shader)
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
error = gl.glGetShaderInfoLog(shader).decode()
gl.glDeleteShader(shader)
raise RuntimeError(f"Shader compilation failed:\n{error}")
return shader
def _create_program(vertex_source: str, fragment_source: str) -> int:
"""Create and link a shader program."""
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
try:
fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
except RuntimeError:
gl.glDeleteShader(vertex_shader)
raise
program = gl.glCreateProgram()
gl.glAttachShader(program, vertex_shader)
gl.glAttachShader(program, fragment_shader)
gl.glLinkProgram(program)
gl.glDeleteShader(vertex_shader)
gl.glDeleteShader(fragment_shader)
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
error = gl.glGetProgramInfoLog(program).decode()
gl.glDeleteProgram(program)
raise RuntimeError(f"Program linking failed:\n{error}")
return program
def _render_shader_batch(
fragment_code: str,
width: int,
height: int,
image_batches: list[list[np.ndarray]],
floats: list[float],
ints: list[int],
) -> list[list[np.ndarray]]:
"""
Render a fragment shader for multiple batches efficiently.
Compiles shader once, reuses framebuffer/textures across batches.
Args:
fragment_code: User's fragment shader code
width: Output width
height: Output height
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms
ints: List of int uniforms
Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
"""
if not image_batches:
return []
ctx = GLContext()
ctx.make_current()
# Convert from GLSL ES to desktop GLSL 330
fragment_source = _convert_es_to_desktop(fragment_code)
# Detect how many outputs the shader actually uses
num_outputs = _detect_output_count(fragment_code)
# Track resources for cleanup
program = None
fbo = None
output_textures = []
input_textures = []
num_inputs = len(image_batches[0])
try:
# Compile shaders (once for all batches)
try:
program = _create_program(VERTEX_SHADER, fragment_source)
except RuntimeError:
logger.error(f"Fragment shader:\n{fragment_source}")
raise
gl.glUseProgram(program)
# Create framebuffer with only the needed color attachments
fbo = gl.glGenFramebuffers(1)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
draw_buffers = []
for i in range(num_outputs):
tex = gl.glGenTextures(1)
output_textures.append(tex)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
gl.glDrawBuffers(num_outputs, draw_buffers)
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
raise RuntimeError("Framebuffer is not complete")
# Create input textures (reused for all batches)
for i in range(num_inputs):
tex = gl.glGenTextures(1)
input_textures.append(tex)
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_image{i}")
if loc >= 0:
gl.glUniform1i(loc, i)
# Set static uniforms (once for all batches)
loc = gl.glGetUniformLocation(program, "u_resolution")
if loc >= 0:
gl.glUniform2f(loc, float(width), float(height))
for i, v in enumerate(floats):
loc = gl.glGetUniformLocation(program, f"u_float{i}")
if loc >= 0:
gl.glUniform1f(loc, v)
for i, v in enumerate(ints):
loc = gl.glGetUniformLocation(program, f"u_int{i}")
if loc >= 0:
gl.glUniform1i(loc, v)
gl.glViewport(0, 0, width, height)
gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
# Process each batch
all_batch_outputs = []
for images in image_batches:
# Update input textures with this batch's images
for i, img in enumerate(images):
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
# Flip vertically for GL coordinates, ensure RGBA
h, w, c = img.shape
if c == 3:
img_upload = np.empty((h, w, 4), dtype=np.float32)
img_upload[:, :, :3] = img[::-1, :, :]
img_upload[:, :, 3] = 1.0
else:
img_upload = np.ascontiguousarray(img[::-1, :, :])
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)
# Render
gl.glClearColor(0, 0, 0, 0)
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
# Read back outputs for this batch
# (glGetTexImage is synchronous, implicitly waits for rendering)
batch_outputs = []
for tex in output_textures:
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))
# Pad with black images for unused outputs
black_img = np.zeros((height, width, 4), dtype=np.float32)
for _ in range(num_outputs, MAX_OUTPUTS):
batch_outputs.append(black_img)
all_batch_outputs.append(batch_outputs)
return all_batch_outputs
finally:
# Unbind before deleting
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
if input_textures:
gl.glDeleteTextures(len(input_textures), input_textures)
if output_textures:
gl.glDeleteTextures(len(output_textures), output_textures)
if fbo is not None:
gl.glDeleteFramebuffers(1, [fbo])
if program is not None:
gl.glDeleteProgram(program)
class GLSLShader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
image_template = io.Autogrow.TemplatePrefix(
io.Image.Input("image"),
prefix="image",
min=1,
max=MAX_IMAGES,
)
float_template = io.Autogrow.TemplatePrefix(
io.Float.Input("float", default=0.0),
prefix="u_float",
min=0,
max=MAX_UNIFORMS,
)
int_template = io.Autogrow.TemplatePrefix(
io.Int.Input("int", default=0),
prefix="u_int",
min=0,
max=MAX_UNIFORMS,
)
return io.Schema(
node_id="GLSLShader",
display_name="GLSL Shader",
category="image/shader",
description=(
f"Apply GLSL fragment shaders to images. "
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
),
inputs=[
io.String.Input(
"fragment_shader",
default=DEFAULT_FRAGMENT_SHADER,
multiline=True,
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
),
io.DynamicCombo.Input(
"size_mode",
options=[
io.DynamicCombo.Option("from_input", []),
io.DynamicCombo.Option(
"custom",
[
io.Int.Input(
"width",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
io.Int.Input(
"height",
default=512,
min=1,
max=nodes.MAX_RESOLUTION,
),
],
),
],
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
),
io.Autogrow.Input("images", template=image_template),
io.Autogrow.Input("floats", template=float_template),
io.Autogrow.Input("ints", template=int_template),
],
outputs=[
io.Image.Output(display_name="IMAGE0"),
io.Image.Output(display_name="IMAGE1"),
io.Image.Output(display_name="IMAGE2"),
io.Image.Output(display_name="IMAGE3"),
],
)
@classmethod
def execute(
cls,
fragment_shader: str,
size_mode: SizeModeInput,
images: io.Autogrow.Type,
floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None,
**kwargs,
) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None]
float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else []
)
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
if not image_list:
raise ValueError("At least one input image is required")
# Determine output dimensions
if size_mode["size_mode"] == "custom":
out_width = size_mode["width"]
out_height = size_mode["height"]
else:
out_height, out_width = image_list[0].shape[1:3]
batch_size = image_list[0].shape[0]
# Prepare batches
image_batches = []
for batch_idx in range(batch_size):
batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
image_batches.append(batch_images)
all_batch_outputs = _render_shader_batch(
fragment_shader,
out_width,
out_height,
image_batches,
float_list,
int_list,
)
# Collect outputs into tensors
all_outputs = [[] for _ in range(MAX_OUTPUTS)]
for batch_outputs in all_batch_outputs:
for i, out_img in enumerate(batch_outputs):
all_outputs[i].append(torch.from_numpy(out_img))
output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
return io.NodeOutput(
*output_tensors,
ui=cls._build_ui_output(image_list, output_tensors[0]),
)
@classmethod
def _build_ui_output(
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images(
combined_inputs,
filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
output_images_ui = ui.ImageSaveHelper.save_images(
output_batch,
filename_prefix="GLSLShader_output",
folder_type=io.FolderType.temp,
cls=None,
compress_level=1,
)
return {"input_images": input_images_ui, "images": output_images_ui}
class GLSLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [GLSLShader]
async def comfy_entrypoint() -> GLSLExtension:
return GLSLExtension()

View File

@ -2432,6 +2432,7 @@ async def init_builtin_extra_nodes():
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",
"nodes_lora_debug.py"
]

View File

@ -29,3 +29,6 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL
PyOpenGL-accelerate
glfw