mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 16:50:17 +08:00
732 lines
25 KiB
Python
732 lines
25 KiB
Python
import os
|
|
import sys
|
|
import re
|
|
import logging
|
|
import ctypes.util
|
|
import importlib.util
|
|
from typing import TypedDict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import nodes
|
|
from comfy_api.latest import ComfyExtension, io, ui
|
|
from typing_extensions import override
|
|
from utils.install_util import get_missing_requirements_message
|
|
|
|
# Module-level logger: headless-backend diagnostics, shader source on compile
# failure, and OpenGL context initialization timing are reported through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _check_opengl_availability():
    """Fail fast at import time when the GLSL node cannot possibly work.

    Checks, without importing them, that the ``glfw`` and ``PyOpenGL`` packages
    are installed.  On Linux, when neither ``DISPLAY`` nor ``WAYLAND_DISPLAY``
    is set, it additionally requires that a headless backend library (EGL or
    OSMesa) can be located with ``ctypes.util.find_library``.

    Raises:
        RuntimeError: if a required Python package is missing, or if a
            display-less Linux host has neither libEGL nor libOSMesa.
    """
    missing = []

    # find_spec() locates a package without importing it.
    if importlib.util.find_spec("glfw") is None:
        missing.append("glfw")
    if importlib.util.find_spec("OpenGL") is None:
        missing.append("PyOpenGL")

    if missing:
        raise RuntimeError(
            f"OpenGL dependencies not available (missing: {', '.join(missing)}).\n"
            f"{get_missing_requirements_message()}\n"
        )

    # Without a display server, rendering can only go through EGL (GPU) or OSMesa (CPU).
    if sys.platform.startswith("linux"):
        has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
        if not has_display:
            has_egl = ctypes.util.find_library("EGL")
            has_osmesa = ctypes.util.find_library("OSMesa")

            if not has_egl and not has_osmesa:
                # The message itself carries the install hints; nothing else is
                # printed after this exception aborts the module import.
                raise RuntimeError(
                    "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
                    "Install one of these backends:\n"
                    " Headless with GPU: sudo apt install libegl1 libgl1-mesa-dri\n"
                    " Headless (CPU): sudo apt install libosmesa6"
                )
            logger.debug(
                "Headless mode: EGL=%s, OSMesa=%s",
                "yes" if has_egl else "no",
                "yes" if has_osmesa else "no",
            )
|
|
|
|
|
|
# Run the dependency check at import time: if it raises, importing this module
# fails immediately instead of the node failing later when a shader executes.
_check_opengl_availability()
|
|
|
|
# Lazily populated module handles, all None until a GL context exists:
# `gl` is filled in by _import_opengl(); `glfw` and `EGL` are filled in by
# GLContext.__init__ depending on which backend came up.
gl = glfw = EGL = None
|
|
|
|
|
|
def _import_opengl():
    """Return the ``OpenGL.GL`` module, importing and caching it on first use.

    Intended to be called once a context has been created (GLContext.__init__
    does so).  The module is cached in the module-global ``gl``.
    """
    global gl
    if gl is not None:
        return gl
    import OpenGL.GL as _gl_module
    gl = _gl_module
    return gl
|
|
|
|
|
|
# Value delivered for the "size_mode" DynamicCombo input: the selected option
# name plus that option's sub-inputs.  `width`/`height` are only read when
# size_mode == "custom" (see GLSLShader.execute).
SizeModeInput = TypedDict(
    "SizeModeInput",
    {"size_mode": str, "width": int, "height": int},
)
|
|
|
|
|
|
# Per-node limits shared by the node schema and the renderer:
#   MAX_IMAGES   - sampler inputs u_image0-4
#   MAX_UNIFORMS - scalar inputs u_float0-4 and u_int0-4
#   MAX_OUTPUTS  - color outputs fragColor0-3 (MRT)
MAX_IMAGES, MAX_UNIFORMS, MAX_OUTPUTS = 5, 5, 4
|
|
|
|
# Vertex shader using the gl_VertexID trick - no VBO needed (a VAO must still
# be bound on core-profile contexts).
# Draws a single oversized triangle whose clipped interior covers the viewport:
#
#   (-1,3)
#     |\
#     | \
#     |  \      <- the visible area is the square from (-1,-1) to (1,1);
#     |   \        the parts of the triangle outside it are clipped away
#     |    \
#   (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1),
# so texture coordinates span [0, 1] across the visible area.
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
    vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
    v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
    gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""
|
|
|
|
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
|
precision highp float;
|
|
|
|
uniform sampler2D u_image0;
|
|
uniform vec2 u_resolution;
|
|
|
|
in vec2 v_texCoord;
|
|
layout(location = 0) out vec4 fragColor0;
|
|
|
|
void main() {
|
|
fragColor0 = texture(u_image0, v_texCoord);
|
|
}
|
|
"""
|
|
|
|
|
|
# A `#version` directive on its own line, with any profile suffix (es/core/compatibility).
# Anchored to line start so text such as "// see #version 300 es" inside a
# comment is left alone.
_GLSL_VERSION_LINE_RE = re.compile(r"^[ \t]*#[ \t]*version\b[^\n]*\n?", re.IGNORECASE | re.MULTILINE)
# GLSL ES precision statements, e.g. `precision highp float;`.
_GLSL_PRECISION_RE = re.compile(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?")


def _convert_es_to_desktop(source: str) -> str:
    """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.

    Any existing ``#version`` line is removed entirely, whatever its profile
    suffix.  Previously only an ``es`` suffix was consumed, so a shader
    starting with ``#version 330 core`` kept a stray ``core`` token and failed
    to compile.  Precision statements are stripped, and
    ``#version 330 core`` is prepended.

    Args:
        source: Fragment shader source as entered by the user.

    Returns:
        The source with a single ``#version 330 core`` directive at the top.
    """
    source = _GLSL_VERSION_LINE_RE.sub("", source)
    source = _GLSL_PRECISION_RE.sub("", source)
    return "#version 330 core\n" + source
|
|
|
|
|
|
def _init_glfw():
    """Create a hidden GLFW window whose OpenGL 3.3 core context is made current.

    Returns:
        (window, glfw_module)

    Raises:
        RuntimeError: if ``glfw.init()`` or window creation fails.  After a
            successful ``init()``, GLFW is terminated again on any failure.
    """
    import glfw as _glfw

    if not _glfw.init():
        raise RuntimeError("glfw.init() failed")

    try:
        _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)  # offscreen use only
        _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
        _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
        _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
        if sys.platform == "darwin":
            # macOS only provides 3.2+ core contexts that are forward-compatible;
            # without this hint window creation can fail there.
            _glfw.window_hint(_glfw.OPENGL_FORWARD_COMPAT, _glfw.TRUE)

        # Small placeholder window: rendering targets a per-call framebuffer object.
        window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
        if not window:
            raise RuntimeError("glfw.create_window() failed")

        _glfw.make_context_current(window)
        return window, _glfw
    except Exception:
        _glfw.terminate()
        raise
|
|
|
|
|
|
def _init_egl():
    """Initialize EGL for headless rendering.

    Creates an OpenGL 3.3 core context bound to a 64x64 pbuffer surface on the
    default EGL display and makes it current.

    Returns:
        (display, context, surface, EGL_module)

    Raises:
        RuntimeError (or a PyOpenGL error): if any EGL step fails.  Objects
            created before the failure are released best-effort, so the
            original error is what propagates.
    """
    from OpenGL import EGL as _EGL
    from OpenGL.EGL import (
        eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
        eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
        eglTerminate, eglDestroyContext, eglDestroySurface,
        EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
        EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
        EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
        EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
    )

    def _release(fn, *args):
        """Call an EGL teardown function; log instead of raising on failure."""
        try:
            fn(*args)
        except Exception:
            logger.debug("EGL cleanup call %s failed", getattr(fn, "__name__", repr(fn)), exc_info=True)

    # Each handle is recorded only once it is known to be valid.  The failure
    # path therefore never passes an EGL_NO_* sentinel, or a display that was
    # never initialized, to a destroy/terminate call.
    display = None
    context = None
    surface = None

    try:
        candidate_display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
        if candidate_display == _EGL.EGL_NO_DISPLAY:
            raise RuntimeError("eglGetDisplay() failed")

        major, minor = _EGL.EGLint(), _EGL.EGLint()
        if not eglInitialize(candidate_display, major, minor):
            raise RuntimeError("eglInitialize() failed")
        display = candidate_display  # initialized: must be terminated on failure

        # 8-bit RGBA pbuffer-capable config for desktop OpenGL, no depth buffer.
        config_attribs = [
            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
            EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
            EGL_DEPTH_SIZE, 0, EGL_NONE
        ]
        configs = (_EGL.EGLConfig * 1)()
        num_configs = _EGL.EGLint()
        if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
            raise RuntimeError("eglChooseConfig() failed")
        config = configs[0]

        # Desktop OpenGL (not GLES) is needed for the "#version 330 core" shaders.
        if not eglBindAPI(EGL_OPENGL_API):
            raise RuntimeError("eglBindAPI() failed")

        context_attribs = [
            _EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
            _EGL.EGL_CONTEXT_MINOR_VERSION, 3,
            _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
            EGL_NONE
        ]
        new_context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
        if new_context == EGL_NO_CONTEXT:
            raise RuntimeError("eglCreateContext() failed")
        context = new_context

        # Small placeholder surface; rendering targets a per-call framebuffer object.
        pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
        new_surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
        if new_surface == _EGL.EGL_NO_SURFACE:
            raise RuntimeError("eglCreatePbufferSurface() failed")
        surface = new_surface

        if not eglMakeCurrent(display, surface, surface, context):
            raise RuntimeError("eglMakeCurrent() failed")

        return display, context, surface, _EGL

    except Exception:
        # Release in reverse order of creation.  Failures here are logged and
        # swallowed so they cannot replace the exception being re-raised.
        if display is not None:
            if surface is not None:
                _release(eglDestroySurface, display, surface)
            if context is not None:
                _release(eglDestroyContext, display, context)
            _release(eglTerminate, display)
        raise
|
|
|
|
|
|
def _init_osmesa():
    """Create an OSMesa (software rendering) context and make it current.

    Returns:
        (context, buffer): ``buffer`` is the 64x64 RGBA ``c_ubyte`` array bound
        as the context's color buffer.  The caller keeps it so it can be passed
        again to ``OSMesaMakeCurrent`` when the context is re-bound.

    Raises:
        RuntimeError: if the context cannot be created or made current (the
            context is destroyed in the latter case).
    """
    import ctypes

    # NOTE(review): PyOpenGL presumably reads PYOPENGL_PLATFORM only when its
    # platform layer is first loaded.  If OpenGL was already imported (e.g. by
    # an earlier EGL attempt), this assignment may have no effect - confirm.
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"

    from OpenGL import GL as _gl
    from OpenGL.osmesa import (
        OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
        OSMESA_RGBA,
    )

    ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
    if not ctx:
        raise RuntimeError("OSMesaCreateContextExt() failed")

    side = 64
    pixel_buffer = (ctypes.c_ubyte * (side * side * 4))()

    if OSMesaMakeCurrent(ctx, pixel_buffer, _gl.GL_UNSIGNED_BYTE, side, side):
        return ctx, pixel_buffer

    OSMesaDestroyContext(ctx)
    raise RuntimeError("OSMesaMakeCurrent() failed")
|
|
|
|
|
|
class GLContext:
    """Process-wide singleton that owns the OpenGL context used for shader execution.

    Backends are tried in order: GLFW (desktop) -> EGL (headless GPU) -> OSMesa
    (software); the first one that succeeds is kept.  The instance is marked
    initialized only after a backend is up and OpenGL is imported.  A failed
    attempt therefore raises the full diagnostic again on the next
    ``GLContext()`` call, instead of handing out a half-initialized instance.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if GLContext._initialized:
            return

        global glfw, EGL

        import time

        start = time.perf_counter()

        # Reset every handle so a retry after a failed attempt starts clean and
        # make_current() never meets a missing attribute.
        self._backend = None
        self._window = None
        self._egl_display = None
        self._egl_context = None
        self._egl_surface = None
        self._osmesa_ctx = None
        self._osmesa_buffer = None
        self._vao = None

        # Try backends in order: GLFW -> EGL -> OSMesa, keeping the first that works.
        errors = []

        try:
            self._window, glfw = _init_glfw()
            self._backend = "glfw"
        except Exception as e:
            errors.append(("GLFW", e))

        if self._backend is None:
            try:
                self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
                self._backend = "egl"
            except Exception as e:
                errors.append(("EGL", e))

        if self._backend is None:
            try:
                self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
                self._backend = "osmesa"
            except Exception as e:
                errors.append(("OSMesa", e))

        if self._backend is None:
            # _initialized stays False: the next GLContext() retries and, if it
            # fails again, reports this same diagnostic.
            error_details = "\n".join(f" {name}: {err}" for name, err in errors)
            raise RuntimeError(
                f"Failed to create OpenGL context.\n\n"
                f"Backend errors:\n{error_details}\n\n"
                f"{self._platform_help()}\n\n"
                "Python packages: pip install PyOpenGL PyOpenGL-accelerate glfw"
            )

        # Import OpenGL.GL only now that a context is current.
        _import_opengl()

        self._vao = self._create_vao()

        # Only a fully set-up context is reused by later GLContext() calls.
        GLContext._initialized = True

        elapsed = (time.perf_counter() - start) * 1000
        renderer = self._gl_string(gl.GL_RENDERER)
        vendor = self._gl_string(gl.GL_VENDOR)
        version = self._gl_string(gl.GL_VERSION)
        logger.info(
            "GLSL context initialized in %.1fms (%s) - %s (%s), GL %s",
            elapsed, self._backend, renderer, vendor, version,
        )

    @staticmethod
    def _platform_help() -> str:
        """Return platform-specific setup hints for the context-creation error."""
        if sys.platform == "win32":
            return (
                "Windows: Ensure GPU drivers are installed and display is available.\n"
                " CPU-only/headless mode is not supported on Windows."
            )
        if sys.platform == "darwin":
            return "macOS: Ensure display is available. For headless, try virtual display."
        # libgl1-mesa-glx / libegl1-mesa were transitional packages that newer
        # Debian/Ubuntu releases no longer ship; name the current packages.
        return (
            "Linux: Install one of these backends:\n"
            " Desktop: sudo apt install libgl1 libglx-mesa0 libglfw3\n"
            " Headless with GPU: sudo apt install libegl1 libgl1-mesa-dri\n"
            " Headless (CPU): sudo apt install libosmesa6"
        )

    @staticmethod
    def _create_vao():
        """Create and bind a vertex array object; return its id, or None if unsupported.

        Core-profile contexts need a bound VAO to draw.  An OSMesa
        compatibility context from older Mesa may not support VAOs, in which
        case drawing proceeds without one.
        """
        vao = None  # bound up front so the cleanup below never hits an unbound name
        try:
            vao = gl.glGenVertexArrays(1)
            gl.glBindVertexArray(vao)
            return vao
        except Exception:
            logger.debug("VAO unavailable; rendering without one", exc_info=True)
            if vao:
                try:
                    gl.glDeleteVertexArrays(1, [vao])
                except Exception:
                    pass  # best effort: the object is unusable either way
            return None

    @staticmethod
    def _gl_string(name) -> str:
        """Return a decoded glGetString() value, or "Unknown" if the driver returns nothing."""
        value = gl.glGetString(name)
        return value.decode() if value else "Unknown"

    def make_current(self):
        """Make this context current again and re-bind its VAO, if one was created."""
        if self._backend == "glfw":
            glfw.make_context_current(self._window)
        elif self._backend == "egl":
            from OpenGL.EGL import eglMakeCurrent
            eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
        elif self._backend == "osmesa":
            from OpenGL.osmesa import OSMesaMakeCurrent
            # 64x64 matches the buffer allocated by _init_osmesa().
            OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)

        if self._vao is not None:
            gl.glBindVertexArray(self._vao)
|
|
|
|
|
|
def _compile_shader(source: str, shader_type: int) -> int:
    """Compile a single GLSL shader stage.

    Args:
        source: GLSL source text.
        shader_type: GL stage enum, e.g. ``gl.GL_VERTEX_SHADER``.

    Returns:
        The GL shader object id.

    Raises:
        RuntimeError: carrying the driver's info log if compilation fails; the
            shader object is deleted before raising.
    """
    shader_id = gl.glCreateShader(shader_type)
    gl.glShaderSource(shader_id, source)
    gl.glCompileShader(shader_id)

    compiled = gl.glGetShaderiv(shader_id, gl.GL_COMPILE_STATUS) == gl.GL_TRUE
    if compiled:
        return shader_id

    info_log = gl.glGetShaderInfoLog(shader_id).decode()
    gl.glDeleteShader(shader_id)
    raise RuntimeError(f"Shader compilation failed:\n{info_log}")
|
|
|
|
|
|
def _create_program(vertex_source: str, fragment_source: str) -> int:
    """Compile both shader stages and link them into a GL program.

    Shader objects are always deleted before returning, including when a GL
    call raises part-way (PyOpenGL can raise on GL errors).  The program is
    deleted on any failure.

    Args:
        vertex_source: Desktop GLSL vertex shader source.
        fragment_source: Desktop GLSL fragment shader source.

    Returns:
        The linked program id.

    Raises:
        RuntimeError: on compile failure (from _compile_shader) or link failure
            (with the program info log).
    """
    vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
    try:
        fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
    except Exception:
        gl.glDeleteShader(vertex_shader)
        raise

    program = None
    try:
        program = gl.glCreateProgram()
        gl.glAttachShader(program, vertex_shader)
        gl.glAttachShader(program, fragment_shader)
        gl.glLinkProgram(program)
    except Exception:
        if program is not None:
            gl.glDeleteProgram(program)
        raise
    finally:
        # After linking (or on error) the shader objects are no longer needed.
        gl.glDeleteShader(vertex_shader)
        gl.glDeleteShader(fragment_shader)

    if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
        error = gl.glGetProgramInfoLog(program).decode()
        gl.glDeleteProgram(program)
        raise RuntimeError(f"Program linking failed:\n{error}")

    return program
|
|
|
|
|
|
def _to_gl_rgba(img: np.ndarray) -> np.ndarray:
    """Return an (H, W, C) image as a C-contiguous float32 RGBA array flipped for GL.

    Rows are reversed for GL's coordinate convention.  3-channel input gets an
    opaque alpha channel and 4-channel input is passed through.  Any other
    channel count is rejected, because the upload declares the data as
    GL_RGBA/GL_FLOAT and would read past the end of a smaller buffer.
    """
    h, w, c = img.shape
    flipped = img[::-1, :, :]
    if c == 4:
        return np.ascontiguousarray(flipped, dtype=np.float32)
    if c == 3:
        rgba = np.empty((h, w, 4), dtype=np.float32)
        rgba[:, :, :3] = flipped
        rgba[:, :, 3] = 1.0
        return rgba
    raise ValueError(f"GLSL Shader: input images must have 3 (RGB) or 4 (RGBA) channels, got {c}")


def _set_static_uniforms(program: int, width: int, height: int, floats: list[float], ints: list[int]) -> None:
    """Set u_resolution, u_float{i} and u_int{i} on the currently used *program*.

    Uniforms the shader does not declare (location -1, e.g. optimized out) are skipped.
    """
    loc = gl.glGetUniformLocation(program, "u_resolution")
    if loc >= 0:
        gl.glUniform2f(loc, float(width), float(height))

    for i, v in enumerate(floats):
        loc = gl.glGetUniformLocation(program, f"u_float{i}")
        if loc >= 0:
            gl.glUniform1f(loc, v)

    for i, v in enumerate(ints):
        loc = gl.glGetUniformLocation(program, f"u_int{i}")
        if loc >= 0:
            gl.glUniform1i(loc, v)


def _read_output_textures(output_textures: list[int], width: int, height: int) -> list[np.ndarray]:
    """Read each color texture back as a top-down (height, width, 4) float32 array."""
    outputs = []
    for tex in output_textures:
        gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
        # glGetTexImage is synchronous, so it implicitly waits for rendering.
        data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
        img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
        outputs.append(np.ascontiguousarray(img[::-1, :, :]))
    return outputs


def _render_shader_batch(
    fragment_code: str,
    width: int,
    height: int,
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
) -> list[list[np.ndarray]]:
    """
    Render a fragment shader for multiple batches efficiently.

    Compiles the shader once and reuses the framebuffer and textures across
    batches.  Every GL object created here is released before returning.

    Args:
        fragment_code: User's fragment shader code (GLSL ES; converted to desktop GLSL 330)
        width: Output width
        height: Output height
        image_batches: List of batches, each a list of input images (H, W, 3 or 4) float32;
            every batch must contain the same number of images as the first
        floats: Values for u_float0, u_float1, ...
        ints: Values for u_int0, u_int1, ...

    Returns:
        List of batch outputs, each a list of MAX_OUTPUTS images (height, width, 4)
        float32 read from RGBA32F render targets.  An empty ``image_batches``
        returns ``[]`` without touching OpenGL.

    Raises:
        RuntimeError: on shader compile/link failure or an incomplete framebuffer.
        ValueError: if an input image does not have 3 or 4 channels.
    """
    if not image_batches:
        return []

    ctx = GLContext()
    ctx.make_current()

    # Convert from GLSL ES to desktop GLSL 330
    fragment_source = _convert_es_to_desktop(fragment_code)

    # Track resources for cleanup
    program = None
    fbo = None
    output_textures = []
    input_textures = []

    num_inputs = len(image_batches[0])

    try:
        # Compile shaders (once for all batches)
        try:
            program = _create_program(VERTEX_SHADER, fragment_source)
        except RuntimeError:
            logger.error("Fragment shader:\n%s", fragment_source)
            raise

        gl.glUseProgram(program)

        # Framebuffer with MAX_OUTPUTS RGBA32F color attachments (MRT), reused for all batches.
        fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)

        draw_buffers = []
        for i in range(MAX_OUTPUTS):
            tex = gl.glGenTextures(1)
            output_textures.append(tex)  # recorded immediately so `finally` frees it
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
            draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)

        gl.glDrawBuffers(MAX_OUTPUTS, draw_buffers)

        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")

        # Input textures: texture unit i feeds sampler u_image{i}; reused for all batches.
        for i in range(num_inputs):
            tex = gl.glGenTextures(1)
            input_textures.append(tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + i)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            loc = gl.glGetUniformLocation(program, f"u_image{i}")
            if loc >= 0:
                gl.glUniform1i(loc, i)

        # Uniforms that do not change between batches
        _set_static_uniforms(program, width, height, floats, ints)

        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly

        all_batch_outputs = []
        for images in image_batches:
            # Upload this batch's images (re-binding each unit, since readback
            # below rebinds the active unit to an output texture).
            for i, img in enumerate(images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + i)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
                img_upload = _to_gl_rgba(img)
                h, w = img_upload.shape[:2]
                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)

            gl.glClearColor(0, 0, 0, 0)
            gl.glClear(gl.GL_COLOR_BUFFER_BIT)
            gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)

            all_batch_outputs.append(_read_output_textures(output_textures, width, height))

        return all_batch_outputs

    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)

        if input_textures:
            gl.glDeleteTextures(len(input_textures), input_textures)
        if output_textures:
            gl.glDeleteTextures(len(output_textures), output_textures)
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        if program is not None:
            gl.glDeleteProgram(program)
|
|
|
|
class GLSLShader(io.ComfyNode):
    """Node that runs a user-written GLSL ES 3.00 fragment shader over image batches.

    Exposes up to MAX_IMAGES sampler inputs, MAX_UNIFORMS float and int
    uniforms, and one image output per MRT color attachment (MAX_OUTPUTS).
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )

        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )

        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=(
                "Apply GLSL fragment shaders to images. "
                f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
                f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
            ),
            inputs=[
                io.String.Input(
                    "fragment_shader",
                    default=DEFAULT_FRAGMENT_SHADER,
                    multiline=True,
                    tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
                ),
                io.DynamicCombo.Input(
                    "size_mode",
                    options=[
                        io.DynamicCombo.Option("from_input", []),
                        io.DynamicCombo.Option(
                            "custom",
                            [
                                io.Int.Input(
                                    "width",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                                io.Int.Input(
                                    "height",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                            ],
                        ),
                    ],
                    tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
                ),
                io.Autogrow.Input("images", template=image_template),
                io.Autogrow.Input("floats", template=float_template),
                io.Autogrow.Input("ints", template=int_template),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0"),
                io.Image.Output(display_name="IMAGE1"),
                io.Image.Output(display_name="IMAGE2"),
                io.Image.Output(display_name="IMAGE3"),
            ],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        """Render the shader once per batch item and return the MAX_OUTPUTS image batches.

        The batch size comes from the first connected image.  Other image
        inputs with fewer items reuse their last item, so a single image can be
        combined with a batch instead of raising an IndexError.  Missing
        float/int values default to 0.

        Raises:
            ValueError: if no image is connected or the first image batch is empty.
        """
        image_list = [v for v in images.values() if v is not None]
        float_list = (
            [v if v is not None else 0.0 for v in floats.values()] if floats else []
        )
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Determine output dimensions
        if size_mode["size_mode"] == "custom":
            out_width = size_mode["width"]
            out_height = size_mode["height"]
        else:
            out_height, out_width = image_list[0].shape[1:3]

        batch_size = image_list[0].shape[0]
        if batch_size == 0:
            raise ValueError("The first input image batch is empty")

        # One list of per-input (H, W, C) float32 arrays per batch item.
        image_batches = []
        for batch_idx in range(batch_size):
            batch_images = [
                img_tensor[min(batch_idx, img_tensor.shape[0] - 1)].cpu().numpy().astype(np.float32)
                for img_tensor in image_list
            ]
            image_batches.append(batch_images)

        all_batch_outputs = _render_shader_batch(
            fragment_shader,
            out_width,
            out_height,
            image_batches,
            float_list,
            int_list,
        )

        # Regroup from per-batch lists into one (B, H, W, 4) tensor per output.
        all_outputs = [[] for _ in range(MAX_OUTPUTS)]
        for batch_outputs in all_batch_outputs:
            for i, out_img in enumerate(batch_outputs):
                all_outputs[i].append(torch.from_numpy(out_img))

        output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
        return io.NodeOutput(
            *output_tensors,
            ui=cls._build_ui_output(image_list, output_tensors[0]),
        )

    @classmethod
    def _build_ui_output(
        cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
    ) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution.

        Each input is saved separately.  Inputs may differ in height, width or
        channel count (each is uploaded as its own texture), and concatenating
        them along the batch dimension would fail for such a mix.  Entries keep
        the order a concatenation would give: input 0's batch, then input 1's, and so on.
        """
        input_images_ui = []
        for image in image_list:
            input_images_ui.extend(
                ui.ImageSaveHelper.save_images(
                    image,
                    filename_prefix="GLSLShader_input",
                    folder_type=io.FolderType.temp,
                    cls=None,
                    compress_level=1,
                )
            )

        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}
|
|
|
|
|
|
class GLSLExtension(ComfyExtension):
    """ComfyUI extension that registers the GLSL shader node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes: list[type[io.ComfyNode]] = [GLSLShader]
        return node_classes
|
|
|
|
|
|
async def comfy_entrypoint() -> GLSLExtension:
    """Module entry point: return a new GLSLExtension instance."""
    extension = GLSLExtension()
    return extension
|