mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 21:13:33 +08:00
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
# Conflicts: # comfy_extras/nodes_glsl.py
773 lines
28 KiB
Python
773 lines
28 KiB
Python
import os
|
|
import sys
|
|
import re
|
|
import ctypes
|
|
import logging
|
|
from typing import TypedDict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import nodes
|
|
import comfy_angle
|
|
from comfy_api.latest import ComfyExtension, io, ui
|
|
from typing_extensions import override
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _preload_angle():
    """Load ANGLE's libEGL and libGLESv2 into the process before PyOpenGL does.

    On Windows the ANGLE library directory is added to the DLL search path and
    prepended to PATH, presumably so DLLs that ANGLE depends on resolve from
    the same directory. Elsewhere the libraries are opened with RTLD_GLOBAL so
    their symbols are visible to libraries loaded afterwards.

    The CDLL handles are discarded on purpose: ctypes never unloads a library,
    so both stay resident for the life of the process.
    """
    egl_path = comfy_angle.get_egl_path()
    gles_path = comfy_angle.get_glesv2_path()

    if sys.platform == "win32":
        # str() so this works whether comfy_angle returns a str or a
        # pathlib.Path (``Path + str`` would raise TypeError).
        angle_dir = str(comfy_angle.get_lib_dir())
        os.add_dll_directory(angle_dir)
        os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "")

    mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL
    ctypes.CDLL(str(egl_path), mode=mode)
    ctypes.CDLL(str(gles_path), mode=mode)
|
|
|
|
|
|
# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform
# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs.
_preload_angle()
# Select PyOpenGL's EGL platform unless the environment already names one.
# PyOpenGL consults PYOPENGL_PLATFORM when it picks its platform on import,
# so this has to happen before the OpenGL imports below.
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import OpenGL
# Turn off the optional OpenGL_accelerate extension; the flag only takes
# effect if set before the OpenGL submodules (EGL, GLES3) are imported.
OpenGL.USE_ACCELERATE = False
|
|
|
|
|
|
def _patch_find_library():
    """Point ctypes.util.find_library('EGL' / 'GLESv2') at ANGLE's libraries.

    PyOpenGL's EGL platform looks up 'EGL' and 'GLESv2' by short name via
    ctypes.util.find_library, but ANGLE ships as 'libEGL' and 'libGLESv2'.
    The patch returns the full ANGLE paths so PyOpenGL loads the same files
    that _preload_angle() loaded. Any other name is delegated to the original
    find_library. On Linux no patch is installed.
    """
    if sys.platform == "linux":
        return

    import ctypes.util

    original_find_library = ctypes.util.find_library

    def _patched(name):
        # find_library is documented to return a str (or None); str() keeps
        # that contract even if comfy_angle hands back a pathlib.Path.
        if name == 'EGL':
            return str(comfy_angle.get_egl_path())
        if name == 'GLESv2':
            return str(comfy_angle.get_glesv2_path())
        return original_find_library(name)

    ctypes.util.find_library = _patched
|
|
|
|
|
|
# Must run before the EGL / GLES3 imports below, so that PyOpenGL resolves
# the 'EGL' and 'GLESv2' library names to ANGLE on non-Linux platforms.
_patch_find_library()

from OpenGL import EGL
from OpenGL import GLES3 as gl
|
|
|
|
class SizeModeInput(TypedDict, total=True):
    """Value of GLSLShader's ``size_mode`` DynamicCombo input.

    ``size_mode`` names the selected option. ``width`` and ``height`` are only
    read by ``GLSLShader.execute`` when ``size_mode == "custom"``.
    """

    size_mode: str  # selected option: "from_input" or "custom"
    width: int  # custom output width in pixels
    height: int  # custom output height in pixels
|
|
|
|
|
|
# Upper bounds on the shader inputs/outputs GLSLShader exposes. They size the
# Autogrow inputs and tooltips; curve textures are bound starting at texture
# unit MAX_IMAGES, right after the image units.
MAX_IMAGES = 5     # sampler2D inputs u_image0 .. u_image4
MAX_UNIFORMS = 20  # u_float0..19 and, separately, u_int0..19
MAX_BOOLS = 10     # u_bool0 .. u_bool9
MAX_CURVES = 4     # 1D LUT textures u_curve0 .. u_curve3
MAX_OUTPUTS = 4    # MRT color outputs fragColor0 .. fragColor3
|
|
|
|
# Vertex shader using the gl_VertexID trick - no VBO needed.
# It emits a single triangle with clip-space corners (-1,-1), (3,-1) and
# (-1,3). That triangle fully contains the visible square from (-1,-1) to
# (1,1); its long edge (x + y = 2) passes exactly through the square's
# top-right corner (1,1). Everything outside the square is clipped away, so
# one 3-vertex draw covers every output pixel.
#
# v_texCoord = clip position * 0.5 + 0.5, which maps the visible range
# [-1, 1] onto texture coordinates [0, 1].
VERTEX_SHADER = """#version 300 es
out vec2 v_texCoord;
void main() {
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""

# Default text of GLSLShader's ``fragment_shader`` input: samples u_image0 at
# the interpolated texture coordinate and writes it unchanged to fragColor0.
DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
fragColor0 = texture(u_image0, v_texCoord);
}
"""
|
|
|
|
|
|
|
|
def _egl_attribs(*values):
    """Pack attribute/value pairs into a ctypes EGLint array ending in EGL_NONE.

    EGL attribute lists are flat ``key, value, key, value, ...`` sequences
    terminated by EGL_NONE; the result can be passed straight to EGL calls.
    """
    terminated = (*values, EGL.EGL_NONE)
    array_type = ctypes.c_int32 * len(terminated)
    return array_type(*terminated)
|
|
|
|
|
|
def _gl_str(name):
    """Return the OpenGL string parameter *name* (e.g. GL_RENDERER) as text.

    PyOpenGL may hand back ``bytes`` or a raw char pointer, so both are
    handled. Undecodable bytes are replaced. A null/empty result becomes
    "Unknown".
    """
    raw = gl.glGetString(name)
    if not raw:
        return "Unknown"
    data = raw if isinstance(raw, bytes) else ctypes.string_at(raw)
    return data.decode(errors="replace")
|
|
|
|
|
|
def _detect_output_count(source: str) -> int:
|
|
"""Detect how many fragColor outputs are used in the shader.
|
|
|
|
Returns the count of outputs needed (1 to MAX_OUTPUTS).
|
|
"""
|
|
matches = re.findall(r"fragColor(\d+)", source)
|
|
if not matches:
|
|
return 1 # Default to 1 output if none found
|
|
max_index = max(int(m) for m in matches)
|
|
return min(max_index + 1, MAX_OUTPUTS)
|
|
|
|
|
|
def _detect_pass_count(source: str) -> int:
|
|
"""Detect multi-pass rendering from #pragma passes N directive.
|
|
|
|
Returns the number of passes (1 if not specified).
|
|
"""
|
|
match = re.search(r'#pragma\s+passes\s+(\d+)', source)
|
|
if match:
|
|
return max(1, int(match.group(1)))
|
|
return 1
|
|
|
|
|
|
class GLContext:
    """Singleton OpenGL ES 3.0 context created through EGL/ANGLE.

    ``GLContext()`` always returns the same object, because ``__new__`` caches
    it in ``_instance``. ``__init__`` performs the EGL setup until it has
    succeeded once: display, 64x64 pbuffer surface, ES 3 context, and one
    bound VAO. On failure it tears down whatever was created and re-raises,
    so a later ``GLContext()`` call retries from scratch.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if GLContext._initialized:
            return

        import time
        start = time.perf_counter()

        self._display = None
        self._surface = None
        self._context = None
        self._vao = None

        try:
            self._display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY)
            if not self._display:
                raise RuntimeError("eglGetDisplay() returned no display")

            major, minor = ctypes.c_int32(0), ctypes.c_int32(0)
            if not EGL.eglInitialize(self._display, ctypes.byref(major), ctypes.byref(minor)):
                err = EGL.eglGetError()
                # The display was never initialized, so _cleanup must not
                # try to terminate it.
                self._display = None
                raise RuntimeError(f"eglInitialize() failed (EGL error: 0x{err:04X})")

            if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API):
                raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed")

            config = EGL.EGLConfig()
            n_configs = ctypes.c_int32(0)
            if not EGL.eglChooseConfig(
                self._display,
                _egl_attribs(
                    EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT,
                    EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
                    EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8,
                    EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8,
                ),
                ctypes.byref(config), 1, ctypes.byref(n_configs),
            ) or n_configs.value == 0:
                raise RuntimeError("eglChooseConfig() failed")

            # Small offscreen surface; rendering itself targets FBOs.
            self._surface = EGL.eglCreatePbufferSurface(
                self._display, config,
                _egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64),
            )
            if not self._surface:
                raise RuntimeError("eglCreatePbufferSurface() failed")

            self._context = EGL.eglCreateContext(
                self._display, config, EGL.EGL_NO_CONTEXT,
                _egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3),
            )
            if not self._context:
                raise RuntimeError("eglCreateContext() failed")

            if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
                raise RuntimeError("eglMakeCurrent() failed")

            # Attribute-less VAO: VERTEX_SHADER derives positions from
            # gl_VertexID, so no vertex buffers are ever attached.
            self._vao = gl.glGenVertexArrays(1)
            gl.glBindVertexArray(self._vao)

        except Exception:
            self._cleanup()
            raise

        elapsed = (time.perf_counter() - start) * 1000

        renderer = _gl_str(gl.GL_RENDERER)
        vendor = _gl_str(gl.GL_VENDOR)
        version = _gl_str(gl.GL_VERSION)

        GLContext._initialized = True
        logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}")

    def make_current(self):
        """Make this context current on the calling thread and rebind the VAO.

        Raises:
            RuntimeError: if eglMakeCurrent() reports failure. This matches the
                check in __init__ and stops rendering from continuing without
                a current context.
        """
        if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context):
            raise RuntimeError("eglMakeCurrent() failed")
        if self._vao is not None:
            gl.glBindVertexArray(self._vao)

    def _cleanup(self):
        """Best-effort release of every EGL/GL object created by __init__.

        Each step is attempted even if an earlier one failed. Failures are
        logged at debug level rather than raised: in this file the method is
        only called from __init__'s error path, where the original exception
        is the one that must propagate. All handles are reset to None.
        """
        if not self._display:
            return
        try:
            if self._vao is not None:
                gl.glDeleteVertexArrays(1, [self._vao])
        except Exception:
            logger.debug("GLSL context cleanup: glDeleteVertexArrays failed", exc_info=True)
        self._vao = None
        try:
            EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
        except Exception:
            logger.debug("GLSL context cleanup: eglMakeCurrent(NO_CONTEXT) failed", exc_info=True)
        try:
            if self._context:
                EGL.eglDestroyContext(self._display, self._context)
        except Exception:
            logger.debug("GLSL context cleanup: eglDestroyContext failed", exc_info=True)
        self._context = None
        try:
            if self._surface:
                EGL.eglDestroySurface(self._display, self._surface)
        except Exception:
            logger.debug("GLSL context cleanup: eglDestroySurface failed", exc_info=True)
        self._surface = None
        try:
            EGL.eglTerminate(self._display)
        except Exception:
            logger.debug("GLSL context cleanup: eglTerminate failed", exc_info=True)
        self._display = None
|
|
|
|
|
|
def _compile_shader(source: str, shader_type: int) -> int:
    """Compile *source* as a shader of *shader_type* and return its GL name.

    Raises:
        RuntimeError: if compilation fails. The message carries the driver's
            info log, and the failed shader object is deleted first.
    """
    shader_id = gl.glCreateShader(shader_type)
    gl.glShaderSource(shader_id, source)
    gl.glCompileShader(shader_id)

    if gl.glGetShaderiv(shader_id, gl.GL_COMPILE_STATUS):
        return shader_id

    log = gl.glGetShaderInfoLog(shader_id)
    if isinstance(log, bytes):
        log = log.decode(errors="replace")
    gl.glDeleteShader(shader_id)
    raise RuntimeError(f"Shader compilation failed:\n{log}")
|
|
|
|
|
|
def _create_program(vertex_source: str, fragment_source: str) -> int:
    """Compile both shaders, link them into a program, and return its GL name.

    The shader objects are always deleted before returning or raising. Once
    attached, deleting them only flags them for deletion together with the
    program. If anything fails after glCreateProgram, the program is deleted.

    Raises:
        RuntimeError: on a compile error (from _compile_shader) or a link
            error. The message includes the driver's info log.
    """
    vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
    fragment_shader = None
    try:
        fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)

        program = gl.glCreateProgram()
        try:
            gl.glAttachShader(program, vertex_shader)
            gl.glAttachShader(program, fragment_shader)
            gl.glLinkProgram(program)

            if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS):
                error = gl.glGetProgramInfoLog(program)
                if isinstance(error, bytes):
                    error = error.decode(errors="replace")
                raise RuntimeError(f"Program linking failed:\n{error}")
        except Exception:
            gl.glDeleteProgram(program)
            raise
    finally:
        # Runs on success and on any failure, so no shader object leaks even
        # if a GL call raises between compile and link.
        gl.glDeleteShader(vertex_shader)
        if fragment_shader is not None:
            gl.glDeleteShader(fragment_shader)

    return program
|
|
|
|
|
|
def _image_to_gl_rgba(img: np.ndarray) -> np.ndarray:
    """Return *img* as a contiguous float32 (H, W, 4) array flipped for GL.

    Rows are reversed because GL texture rows start at the bottom. A
    3-channel image gets an opaque alpha channel; a 4-channel image is used
    as-is.

    Raises:
        ValueError: if *img* is not (H, W, 3) or (H, W, 4). glTexImage2D is
            told the buffer holds H*W RGBA floats, so any other layout would
            feed the driver garbage or let it read past the end of the array.
    """
    if img.ndim != 3 or img.shape[2] not in (3, 4):
        raise ValueError(
            f"GLSL shader inputs must be (H, W, 3) or (H, W, 4) images, got shape {tuple(img.shape)}"
        )
    h, w, c = img.shape
    if c == 3:
        img_upload = np.empty((h, w, 4), dtype=np.float32)
        img_upload[:, :, :3] = img[::-1, :, :]
        img_upload[:, :, 3] = 1.0
        return img_upload
    return np.ascontiguousarray(img[::-1, :, :], dtype=np.float32)


def _render_shader_batch(
    fragment_code: str,
    width: int,
    height: int,
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
    bools: list[bool] | None = None,
    curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]:
    """
    Render a fragment shader for multiple batches efficiently.

    The shader is compiled once, and the framebuffer/textures are reused
    across batches. Multi-pass rendering is enabled via a #pragma passes N
    directive: pass 0 samples the original u_image0; every later pass sees
    the previous pass's output bound as u_image0 (unit 0) and sets u_pass.
    Only the last pass writes all fragColor outputs.

    Args:
        fragment_code: User's fragment shader code
        width: Output width
        height: Output height
        image_batches: List of batches, each batch is a list of input images (H, W, 3|4) float32
        floats: List of float uniforms (u_float0..)
        ints: List of int uniforms (u_int0..)
        bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
        curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N

    Returns:
        One list per batch of MAX_OUTPUTS (H, W, 4) float32 images. Outputs
        the shader does not use are black. The RGBA32F targets are not
        clamped, so values may fall outside [0, 1].

    Raises:
        RuntimeError: on shader compile/link errors or incomplete framebuffers.
        ValueError: if an input image is not 3- or 4-channel.
    """
    import time
    start_time = time.perf_counter()

    if not image_batches:
        return []

    ctx = GLContext()
    ctx.make_current()

    # Detect how many outputs the shader actually uses
    num_outputs = _detect_output_count(fragment_code)

    # Detect multi-pass rendering
    num_passes = _detect_pass_count(fragment_code)

    if bools is None:
        bools = []
    if curves is None:
        curves = []

    # Track resources for cleanup
    program = None
    fbo = None
    output_textures = []
    input_textures = []
    curve_textures = []
    ping_pong_textures = []
    ping_pong_fbos = []

    num_inputs = len(image_batches[0])

    try:
        # Compile shaders (once for all batches)
        try:
            program = _create_program(VERTEX_SHADER, fragment_code)
        except RuntimeError:
            logger.error(f"Fragment shader:\n{fragment_code}")
            raise

        gl.glUseProgram(program)

        # Create framebuffer with only the needed color attachments
        fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)

        draw_buffers = []
        for i in range(num_outputs):
            tex = gl.glGenTextures(1)
            output_textures.append(tex)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
            draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)

        gl.glDrawBuffers(num_outputs, draw_buffers)

        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")

        # Create ping-pong resources for multi-pass rendering: pass p
        # (except the last) writes ping_pong_textures[p % 2] and reads the
        # other one, so a texture is never sampled while it is being rendered.
        if num_passes > 1:
            for _ in range(2):
                pp_tex = gl.glGenTextures(1)
                ping_pong_textures.append(pp_tex)
                gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex)
                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
                gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

                pp_fbo = gl.glGenFramebuffers(1)
                ping_pong_fbos.append(pp_fbo)
                gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo)
                gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0)
                gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0])

                if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
                    raise RuntimeError("Ping-pong framebuffer is not complete")

        # Create input textures (reused for all batches); image i uses unit i
        for i in range(num_inputs):
            tex = gl.glGenTextures(1)
            input_textures.append(tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + i)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            loc = gl.glGetUniformLocation(program, f"u_image{i}")
            if loc >= 0:
                gl.glUniform1i(loc, i)

        # Set static uniforms (once for all batches); a location of -1 means
        # the shader does not declare (or the compiler removed) that uniform.
        loc = gl.glGetUniformLocation(program, "u_resolution")
        if loc >= 0:
            gl.glUniform2f(loc, float(width), float(height))

        for i, v in enumerate(floats):
            loc = gl.glGetUniformLocation(program, f"u_float{i}")
            if loc >= 0:
                gl.glUniform1f(loc, v)

        for i, v in enumerate(ints):
            loc = gl.glGetUniformLocation(program, f"u_int{i}")
            if loc >= 0:
                gl.glUniform1i(loc, v)

        for i, v in enumerate(bools):
            loc = gl.glGetUniformLocation(program, f"u_bool{i}")
            if loc >= 0:
                gl.glUniform1i(loc, 1 if v else 0)

        # Create 1D LUT textures for curves (bound after image texture units)
        for i, lut in enumerate(curves):
            tex = gl.glGenTextures(1)
            curve_textures.append(tex)
            unit = MAX_IMAGES + i
            gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            loc = gl.glGetUniformLocation(program, f"u_curve{i}")
            if loc >= 0:
                gl.glUniform1i(loc, unit)

        # Get u_pass uniform location for multi-pass
        pass_loc = gl.glGetUniformLocation(program, "u_pass")

        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly

        # Process each batch
        all_batch_outputs = []
        for images in image_batches:
            # Update input textures with this batch's images
            for i, img in enumerate(images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + i)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])

                # Flip vertically for GL coordinates, ensure RGBA float32
                img_upload = _image_to_gl_rgba(img)
                h, w = img_upload.shape[:2]
                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)

            if num_passes == 1:
                # Single pass - render directly to output FBO
                gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
                if pass_loc >= 0:
                    gl.glUniform1i(pass_loc, 0)
                gl.glClearColor(0, 0, 0, 0)
                gl.glClear(gl.GL_COLOR_BUFFER_BIT)
                gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
            else:
                # Multi-pass rendering with ping-pong
                for p in range(num_passes):
                    is_last_pass = (p == num_passes - 1)

                    # Set pass uniform
                    if pass_loc >= 0:
                        gl.glUniform1i(pass_loc, p)

                    if is_last_pass:
                        # Last pass renders to the main output FBO
                        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
                    else:
                        # Intermediate passes render to ping-pong FBO
                        target_fbo = ping_pong_fbos[p % 2]
                        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo)

                    # Set input texture for this pass
                    gl.glActiveTexture(gl.GL_TEXTURE0)
                    if p == 0:
                        # First pass reads from original input
                        gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0])
                    else:
                        # Subsequent passes read from previous pass output
                        source_tex = ping_pong_textures[(p - 1) % 2]
                        gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex)

                    gl.glClearColor(0, 0, 0, 0)
                    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
                    gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)

            # Read back outputs for this batch (flipping rows back to top-down)
            gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
            batch_outputs = []
            for i in range(num_outputs):
                gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i)
                buf = np.empty((height, width, 4), dtype=np.float32)
                gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf)
                batch_outputs.append(buf[::-1, :, :].copy())

            # Pad with black images for unused outputs
            black_img = np.zeros((height, width, 4), dtype=np.float32)
            for _ in range(num_outputs, MAX_OUTPUTS):
                batch_outputs.append(black_img)

            all_batch_outputs.append(batch_outputs)

        elapsed = (time.perf_counter() - start_time) * 1000
        num_batches = len(image_batches)
        pass_info = f", {num_passes} passes" if num_passes > 1 else ""
        logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})")

        return all_batch_outputs

    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)

        if input_textures:
            gl.glDeleteTextures(len(input_textures), input_textures)
        if curve_textures:
            gl.glDeleteTextures(len(curve_textures), curve_textures)
        if output_textures:
            gl.glDeleteTextures(len(output_textures), output_textures)
        if ping_pong_textures:
            gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures)
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        if ping_pong_fbos:
            gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos)
        if program is not None:
            gl.glDeleteProgram(program)
|
|
|
|
class GLSLShader(io.ComfyNode):
    """ComfyUI node that runs a user GLSL ES 3.00 fragment shader over images.

    Inputs are exposed to the shader as u_image*, u_float*, u_int*, u_bool*
    and u_curve* uniforms. Outputs are read from fragColor0..fragColor3 and
    returned as four IMAGE outputs.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )

        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )

        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        bool_template = io.Autogrow.TemplatePrefix(
            io.Boolean.Input("bool", default=False),
            prefix="u_bool",
            min=0,
            max=MAX_BOOLS,
        )

        curve_template = io.Autogrow.TemplatePrefix(
            io.Curve.Input("curve"),
            prefix="u_curve",
            min=0,
            max=MAX_CURVES,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=(
                "Apply GLSL ES fragment shaders to images. "
                "u_resolution (vec2) is always available."
            ),
            is_experimental=True,
            inputs=[
                io.String.Input(
                    "fragment_shader",
                    default=DEFAULT_FRAGMENT_SHADER,
                    multiline=True,
                    tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
                ),
                io.DynamicCombo.Input(
                    "size_mode",
                    options=[
                        io.DynamicCombo.Option("from_input", []),
                        io.DynamicCombo.Option(
                            "custom",
                            [
                                io.Int.Input(
                                    "width",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                                io.Int.Input(
                                    "height",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                            ],
                        ),
                    ],
                    tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
                ),
                io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
                io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
                io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
                io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
                io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
                io.Image.Output(display_name="IMAGE1", tooltip="Available via layout(location = 1) out vec4 fragColor1 in the shader code"),
                io.Image.Output(display_name="IMAGE2", tooltip="Available via layout(location = 2) out vec4 fragColor2 in the shader code"),
                io.Image.Output(display_name="IMAGE3", tooltip="Available via layout(location = 3) out vec4 fragColor3 in the shader code"),
            ],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        bools: io.Autogrow.Type = None,
        curves: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        """Render the shader once per batch element and return four IMAGE outputs.

        The first connected image decides the output batch size. In
        "from_input" mode it also decides the output height/width (taken from
        shape[1:3]). Missing uniform values default to 0.0 / 0 / False.

        Raises:
            ValueError: if no input image is connected.
        """
        image_list = [v for v in images.values() if v is not None]
        float_list = (
            [v if v is not None else 0.0 for v in floats.values()] if floats else []
        )
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
        bool_list = [v if v is not None else False for v in bools.values()] if bools else []

        curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Determine output dimensions
        if size_mode["size_mode"] == "custom":
            out_width = size_mode["width"]
            out_height = size_mode["height"]
        else:
            out_height, out_width = image_list[0].shape[1:3]

        batch_size = image_list[0].shape[0]

        # Prepare batches. Other inputs may carry fewer frames than the first
        # (e.g. a single overlay image for a whole batch); they repeat their
        # last frame instead of raising IndexError.
        image_batches = []
        for batch_idx in range(batch_size):
            batch_images = [
                img_tensor[min(batch_idx, img_tensor.shape[0] - 1)].cpu().numpy().astype(np.float32)
                for img_tensor in image_list
            ]
            image_batches.append(batch_images)

        all_batch_outputs = _render_shader_batch(
            fragment_shader,
            out_width,
            out_height,
            image_batches,
            float_list,
            int_list,
            bool_list,
            curve_luts,
        )

        # Collect outputs into tensors: one (B, H, W, 4) tensor per output slot
        all_outputs = [[] for _ in range(MAX_OUTPUTS)]
        for batch_outputs in all_batch_outputs:
            for i, out_img in enumerate(batch_outputs):
                all_outputs[i].append(torch.from_numpy(out_img))

        output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
        return io.NodeOutput(
            *output_tensors,
            ui=cls._build_ui_output(image_list, output_tensors[0]),
        )

    @classmethod
    def _build_ui_output(
        cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
    ) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution."""
        input_images_ui = []
        for img in image_list:
            input_images_ui.extend(ui.ImageSaveHelper.save_images(
                img,
                filename_prefix="GLSLShader_input",
                folder_type=io.FolderType.temp,
                cls=None,
                compress_level=1,
            ))

        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}
|
|
|
|
|
|
class GLSLExtension(ComfyExtension):
    """Extension object that registers this module's GLSL shader node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes: list[type[io.ComfyNode]] = [GLSLShader]
        return node_classes
|
|
|
|
|
|
async def comfy_entrypoint() -> GLSLExtension:
    """Module entry point; presumably looked up by ComfyUI's extension loader."""
    extension = GLSLExtension()
    return extension
|