diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py new file mode 100644 index 000000000..61118920c --- /dev/null +++ b/comfy_extras/nodes_glsl.py @@ -0,0 +1,500 @@ +import re +import logging +from typing import TypedDict + +import numpy as np +import torch + +import nodes +from comfy_api.latest import ComfyExtension, io, ui +from typing_extensions import override +from utils.install_util import get_missing_requirements_message + +logger = logging.getLogger(__name__) + +try: + import glfw + import OpenGL.GL as gl +except ImportError as e: + raise RuntimeError( + f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n" + "Install with: pip install PyOpenGL PyOpenGL-accelerate glfw" + ) from e +except AttributeError as e: + # This happens when PyOpenGL can't initialize (e.g., no display, missing libraries) + raise RuntimeError( + "OpenGL initialization failed.\n" + "Ensure OpenGL drivers are installed and a display is available.\n\n" + "For headless servers, you may need:\n" + " - EGL: sudo apt install libegl1-mesa-dev\n" + " - Or a virtual display: Xvfb :99 & export DISPLAY=:99" + ) from e + + +class SizeModeInput(TypedDict): + size_mode: str + width: int + height: int + + +MAX_IMAGES = 5 # u_image0-4 +MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 +MAX_OUTPUTS = 4 # fragColor0-3 (MRT) + +# Vertex shader using gl_VertexID trick - no VBO needed. +# Draws a single triangle that covers the entire screen: +# +# (-1,3) +# /| +# / | <- visible area is the unit square from (-1,-1) to (1,1) +# / | parts outside get clipped away +# (-1,-1)---(3,-1) +# +# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) +VERTEX_SHADER = """#version 330 core +out vec2 v_texCoord; +void main() { + vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); + v_texCoord = verts[gl_VertexID] * 0.5 + 0.5; + gl_Position = vec4(verts[gl_VertexID], 0, 1); +} +""" + +DEFAULT_FRAGMENT_SHADER = """#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +void main() { + fragColor0 = texture(u_image0, v_texCoord); +} +""" + + +def _convert_es_to_desktop(source: str) -> str: + """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" + # Remove any existing #version directive + source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) + # Remove precision qualifiers (not needed in desktop GLSL) + source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source) + # Prepend desktop GLSL version + return "#version 330 core\n" + source + + +class GLContext: + """Manages OpenGL context and resources for shader execution.""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if GLContext._initialized: + return + GLContext._initialized = True + + import time + start = time.perf_counter() + + if not glfw.init(): + raise RuntimeError("Failed to initialize GLFW") + + glfw.window_hint(glfw.VISIBLE, glfw.FALSE) + glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) + glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) + glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) + + self._window = glfw.create_window(64, 64, "ComfyUI GLSL", None, None) + if not self._window: + glfw.terminate() + raise RuntimeError("Failed to create GLFW window") + + glfw.make_context_current(self._window) + + # Create VAO (required for core profile even if we don't use vertex attributes) + self._vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(self._vao) + + elapsed = (time.perf_counter() - start) * 1000 + + # Log device info + renderer = gl.glGetString(gl.GL_RENDERER) + vendor = gl.glGetString(gl.GL_VENDOR) + version = gl.glGetString(gl.GL_VERSION) + renderer = renderer.decode() if renderer else "Unknown" + vendor = vendor.decode() if vendor else "Unknown" + version = version.decode() if version else "Unknown" + + logger.info(f"GLSL context initialized in {elapsed:.1f}ms - {renderer} ({vendor}), GL {version}") + + def make_current(self): + glfw.make_context_current(self._window) + gl.glBindVertexArray(self._vao) + + +def _compile_shader(source: str, shader_type: int) -> int: + """Compile a shader and return its ID.""" + shader = gl.glCreateShader(shader_type) + gl.glShaderSource(shader, source) + gl.glCompileShader(shader) + + if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: + error = gl.glGetShaderInfoLog(shader).decode() + gl.glDeleteShader(shader) + raise RuntimeError(f"Shader compilation failed:\n{error}") + + return shader + + +def _create_program(vertex_source: str, fragment_source: str) -> int: + """Create and link a shader program.""" + vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER) + try: + fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER) + except RuntimeError: + gl.glDeleteShader(vertex_shader) + raise + + program = gl.glCreateProgram() + gl.glAttachShader(program, vertex_shader) + gl.glAttachShader(program, fragment_shader) + gl.glLinkProgram(program) + + gl.glDeleteShader(vertex_shader) + gl.glDeleteShader(fragment_shader) + + if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: + error = gl.glGetProgramInfoLog(program).decode() + gl.glDeleteProgram(program) + raise RuntimeError(f"Program linking failed:\n{error}") + + return program + + +def _render_shader_batch( + fragment_code: str, + width: int, + height: int, + image_batches: list[list[np.ndarray]], + floats: list[float], + ints: list[int], +) -> list[list[np.ndarray]]: + """ + Render a fragment shader for multiple batches efficiently. + + Compiles shader once, reuses framebuffer/textures across batches. + + Args: + fragment_code: User's fragment shader code + width: Output width + height: Output height + image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1] + floats: List of float uniforms + ints: List of int uniforms + + Returns: + List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1] + """ + if not image_batches: + return [] + + ctx = GLContext() + ctx.make_current() + + # Convert from GLSL ES to desktop GLSL 330 + fragment_source = _convert_es_to_desktop(fragment_code) + + # Track resources for cleanup + program = None + fbo = None + output_textures = [] + input_textures = [] + + num_inputs = len(image_batches[0]) + + try: + # Compile shaders (once for all batches) + try: + program = _create_program(VERTEX_SHADER, fragment_source) + except RuntimeError: + logger.error(f"Fragment shader:\n{fragment_source}") + raise + + gl.glUseProgram(program) + + # Create framebuffer with multiple color attachments (reused for all batches) + fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + + draw_buffers = [] + for i in range(MAX_OUTPUTS): + tex = gl.glGenTextures(1) + output_textures.append(tex) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0) + draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i) + + gl.glDrawBuffers(MAX_OUTPUTS, draw_buffers) + + if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError("Framebuffer is not complete") + + # Create input textures (reused for all batches) + for i in range(num_inputs): + tex = gl.glGenTextures(1) + input_textures.append(tex) + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + loc = gl.glGetUniformLocation(program, f"u_image{i}") + if loc >= 0: + gl.glUniform1i(loc, i) + + # Set static uniforms (once for all batches) + loc = gl.glGetUniformLocation(program, "u_resolution") + if loc >= 0: + gl.glUniform2f(loc, float(width), float(height)) + + for i, v in enumerate(floats): + loc = gl.glGetUniformLocation(program, f"u_float{i}") + if loc >= 0: + gl.glUniform1f(loc, v) + + for i, v in enumerate(ints): + loc = gl.glGetUniformLocation(program, f"u_int{i}") + if loc >= 0: + gl.glUniform1i(loc, v) + + gl.glViewport(0, 0, width, height) + gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly + + # Process each batch + all_batch_outputs = [] + for images in image_batches: + # Update input textures with this batch's images + for i, img in enumerate(images): + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i]) + + # Flip vertically for GL coordinates, ensure RGBA + img_flipped = np.ascontiguousarray(img[::-1, :, :]) + if img_flipped.shape[2] == 3: + img_flipped = np.ascontiguousarray(np.concatenate( + [img_flipped, np.ones((*img_flipped.shape[:2], 1), dtype=np.float32)], + axis=2, + )) + + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, img_flipped.shape[1], img_flipped.shape[0], 0, gl.GL_RGBA, gl.GL_FLOAT, img_flipped) + + # Render + gl.glClearColor(0, 0, 0, 0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) + + # Read back outputs for this batch + batch_outputs = [] + for tex in output_textures: + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) + img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4) + batch_outputs.append(np.ascontiguousarray(img[::-1, :, :])) + + all_batch_outputs.append(batch_outputs) + + return all_batch_outputs + + finally: + # Unbind before deleting + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + gl.glUseProgram(0) + + if input_textures: + gl.glDeleteTextures(len(input_textures), input_textures) + if output_textures: + gl.glDeleteTextures(len(output_textures), output_textures) + if fbo is not None: + gl.glDeleteFramebuffers(1, [fbo]) + if program is not None: + gl.glDeleteProgram(program) + +class GLSLShader(io.ComfyNode): + + @classmethod + def define_schema(cls) -> io.Schema: + image_template = io.Autogrow.TemplatePrefix( + io.Image.Input("image"), + prefix="image", + min=1, + max=MAX_IMAGES, + ) + + float_template = io.Autogrow.TemplatePrefix( + io.Float.Input("float", default=0.0), + prefix="u_float", + min=0, + max=MAX_UNIFORMS, + ) + + int_template = io.Autogrow.TemplatePrefix( + io.Int.Input("int", default=0), + prefix="u_int", + min=0, + max=MAX_UNIFORMS, + ) + + return io.Schema( + node_id="GLSLShader", + display_name="GLSL Shader", + category="image/shader", + description=( + f"Apply GLSL fragment shaders to images. " + f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), " + f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. " + f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}." + ), + inputs=[ + io.String.Input( + "fragment_shader", + default=DEFAULT_FRAGMENT_SHADER, + multiline=True, + tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)", + ), + io.DynamicCombo.Input( + "size_mode", + options=[ + io.DynamicCombo.Option("from_input", []), + io.DynamicCombo.Option( + "custom", + [ + io.Int.Input( + "width", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + io.Int.Input( + "height", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + ], + ), + ], + tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size", + ), + io.Autogrow.Input("images", template=image_template), + io.Autogrow.Input("floats", template=float_template), + io.Autogrow.Input("ints", template=int_template), + ], + outputs=[ + io.Image.Output(display_name="IMAGE0"), + io.Image.Output(display_name="IMAGE1"), + io.Image.Output(display_name="IMAGE2"), + io.Image.Output(display_name="IMAGE3"), + ], + ) + + @classmethod + def execute( + cls, + fragment_shader: str, + size_mode: SizeModeInput, + images: io.Autogrow.Type, + floats: io.Autogrow.Type = None, + ints: io.Autogrow.Type = None, + **kwargs, + ) -> io.NodeOutput: + image_list = [v for v in images.values() if v is not None] + float_list = ( + [v if v is not None else 0.0 for v in floats.values()] if floats else [] + ) + int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] + + if not image_list: + raise ValueError("At least one input image is required") + + # Determine output dimensions + if size_mode["size_mode"] == "custom": + out_width = size_mode["width"] + out_height = size_mode["height"] + else: + out_height, out_width = image_list[0].shape[1:3] + + batch_size = image_list[0].shape[0] + + # Prepare batches + image_batches = [] + for batch_idx in range(batch_size): + batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list] + image_batches.append(batch_images) + + all_batch_outputs = _render_shader_batch( + fragment_shader, + out_width, + out_height, + image_batches, + float_list, + int_list, + ) + + # Collect outputs into tensors + all_outputs = [[] for _ in range(MAX_OUTPUTS)] + for batch_outputs in all_batch_outputs: + for i, out_img in enumerate(batch_outputs): + all_outputs[i].append(torch.from_numpy(out_img)) + + output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)] + return io.NodeOutput( + *output_tensors, + ui=cls._build_ui_output(image_list, output_tensors[0]), + ) + + @classmethod + def _build_ui_output( + cls, image_list: list[torch.Tensor], output_batch: torch.Tensor + ) -> dict[str, list]: + """Build UI output with input and output images for client-side shader execution.""" + combined_inputs = torch.cat(image_list, dim=0) + input_images_ui = ui.ImageSaveHelper.save_images( + combined_inputs, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + ) + + output_images_ui = ui.ImageSaveHelper.save_images( + output_batch, + filename_prefix="GLSLShader_output", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + ) + + return {"input_images": input_images_ui, "images": output_images_ui} + + +class GLSLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [GLSLShader] + + +async def comfy_entrypoint() -> GLSLExtension: + return GLSLExtension() diff --git a/nodes.py b/nodes.py index ad474d3cd..794969753 100644 --- a/nodes.py +++ b/nodes.py @@ -2432,6 +2432,7 @@ async def init_builtin_extra_nodes(): "nodes_wanmove.py", "nodes_image_compare.py", "nodes_zimage.py", + "nodes_glsl.py", "nodes_lora_debug.py" ] diff --git a/requirements.txt b/requirements.txt index 4ac94cb16..b25a79ab0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,6 @@ kornia>=0.7.1 spandrel pydantic~=2.0 pydantic-settings~=2.0 +PyOpenGL +PyOpenGL-accelerate +glfw