"""GLSL fragment-shader node rendered offscreen through OpenGL.

The module checks for the OpenGL Python packages at import time, lazily
creates a single shared GL context (GLFW, then EGL, then OSMesa), and exposes
the ``GLSLShader`` node which runs a user supplied fragment shader over
batches of input images.
"""
import os
import sys
import re
import logging
import ctypes.util
import importlib.util
from typing import TypedDict

import numpy as np
import torch

import nodes
from comfy_api.latest import ComfyExtension, io, ui
from typing_extensions import override
from utils.install_util import get_missing_requirements_message

logger = logging.getLogger(__name__)


def _check_opengl_availability():
    """Early check for OpenGL availability.

    Uses ``importlib.util.find_spec`` so that neither ``glfw`` nor ``OpenGL``
    is actually imported here.

    Raises:
        RuntimeError: If the ``glfw`` or ``OpenGL`` Python package cannot be
            found.
    """
    missing = []

    # Check Python packages (using find_spec to avoid importing)
    if importlib.util.find_spec("glfw") is None:
        missing.append("glfw")

    if importlib.util.find_spec("OpenGL") is None:
        missing.append("PyOpenGL")

    if missing:
        raise RuntimeError(
            f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
        )

    # On Linux without display, check if headless backends are available
    if sys.platform.startswith("linux"):
        has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
        if not has_display:
            # Check for EGL or OSMesa libraries
            has_egl = ctypes.util.find_library("EGL")
            has_osmesa = ctypes.util.find_library("OSMesa")

            # Error disabled for CI as it fails this check
            # if not has_egl and not has_osmesa:
            #     raise RuntimeError(
            #         "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
            #         "See error below for installation instructions."
            #     )
            logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")


# Run early check at import time
_check_opengl_availability()

# OpenGL modules - initialized lazily when context is created
gl = None
glfw = None
EGL = None


def _import_opengl():
    """Import ``OpenGL.GL`` into the module-level ``gl`` name and return it.

    Called after the context is created. Repeated calls return the module
    that was imported the first time.
    """
    global gl
    if gl is None:
        import OpenGL.GL as _gl
        gl = _gl
    return gl


class SizeModeInput(TypedDict):
    """Value of the ``size_mode`` dynamic combo input of ``GLSLShader``."""

    size_mode: str
    width: int
    height: int


MAX_IMAGES = 5  # u_image0-4
MAX_UNIFORMS = 5  # u_float0-4, u_int0-4
MAX_OUTPUTS = 4  # fragColor0-3 (MRT)

# Vertex shader using gl_VertexID trick - no VBO needed.
# Draws a single triangle that covers the entire screen:
#
#     (-1,3)
#       |\
#       | \   <- visible area is the clip-space square from (-1,-1) to (1,1)
#       |  \     parts outside get clipped away
#   (-1,-1)---(3,-1)
#
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
VERTEX_SHADER = """#version 330 core
out vec2 v_texCoord;
void main() {
    vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
    v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
    gl_Position = vec4(verts[gl_VertexID], 0, 1);
}
"""

DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform vec2 u_resolution;

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

void main() {
    fragColor0 = texture(u_image0, v_texCoord);
}
"""


def _convert_es_to_desktop(source: str) -> str:
    """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.

    Args:
        source: Fragment shader source, possibly carrying a ``#version``
            directive and ``precision`` statements.

    Returns:
        The source with any ``#version`` directive and precision qualifiers
        stripped and ``#version 330 core`` prepended.
    """
    # Remove any existing #version directive
    source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
    # Remove precision qualifiers (not needed in desktop GLSL)
    source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
    # Prepend desktop GLSL version
    return "#version 330 core\n" + source


def _detect_output_count(source: str) -> int:
    """Detect how many fragColor outputs are used in the shader.

    The highest ``fragColor<N>`` index found anywhere in the text decides the
    count, so the result is ``N + 1`` capped at ``MAX_OUTPUTS``.

    Returns the count of outputs needed (1 to MAX_OUTPUTS).
    """
    matches = re.findall(r"fragColor(\d+)", source)
    if not matches:
        return 1  # Default to 1 output if none found
    max_index = max(int(m) for m in matches)
    return min(max_index + 1, MAX_OUTPUTS)


def _init_glfw():
    """Initialize GLFW with a hidden 64x64 window and a 3.3 core context.

    Returns:
        ``(window, glfw_module)``; the window's context is current on return.

    Raises:
        RuntimeError: If ``glfw.init()`` or window creation fails. GLFW is
            terminated again before any error raised after ``init()``
            propagates.
    """
    import glfw as _glfw

    if not _glfw.init():
        raise RuntimeError("glfw.init() failed")

    try:
        _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
        _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
        _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
        _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)

        window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
        if not window:
            raise RuntimeError("glfw.create_window() failed")

        _glfw.make_context_current(window)
        return window, _glfw
    except Exception:
        _glfw.terminate()
        raise


def _init_egl():
    """Initialize EGL for headless rendering.

    Creates a 64x64 pbuffer surface and an OpenGL 3.3 core context and makes
    them current.

    Returns:
        ``(display, context, surface, EGL_module)``.

    Raises:
        RuntimeError: If any EGL setup step fails. Handles that were created
            before the failure are destroyed first.
    """
    from OpenGL import EGL as _EGL
    from OpenGL.EGL import (
        eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
        eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
        eglTerminate, eglDestroyContext, eglDestroySurface,
        EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
        EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
        EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
        EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
    )

    # None means "nothing to clean up" for the except handler below.
    display = None
    context = None
    surface = None

    try:
        display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
        if display == _EGL.EGL_NO_DISPLAY:
            raise RuntimeError("eglGetDisplay() failed")

        major, minor = _EGL.EGLint(), _EGL.EGLint()
        if not eglInitialize(display, major, minor):
            display = None  # Not initialized, don't terminate
            raise RuntimeError("eglInitialize() failed")

        config_attribs = [
            EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
            EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
            EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
            EGL_DEPTH_SIZE, 0, EGL_NONE
        ]
        configs = (_EGL.EGLConfig * 1)()
        num_configs = _EGL.EGLint()
        if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
            raise RuntimeError("eglChooseConfig() failed")
        config = configs[0]

        if not eglBindAPI(EGL_OPENGL_API):
            raise RuntimeError("eglBindAPI() failed")

        context_attribs = [
            _EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
            _EGL.EGL_CONTEXT_MINOR_VERSION, 3,
            _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
            EGL_NONE
        ]
        context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
        if context == EGL_NO_CONTEXT:
            context = None  # Nothing was created, don't destroy the sentinel
            raise RuntimeError("eglCreateContext() failed")

        pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
        surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
        if surface == _EGL.EGL_NO_SURFACE:
            surface = None  # Nothing was created, don't destroy the sentinel
            raise RuntimeError("eglCreatePbufferSurface() failed")

        if not eglMakeCurrent(display, surface, surface, context):
            raise RuntimeError("eglMakeCurrent() failed")

        return display, context, surface, _EGL

    except Exception:
        # Clean up any resources on failure
        if surface is not None:
            eglDestroySurface(display, surface)
        if context is not None:
            eglDestroyContext(display, context)
        if display is not None:
            eglTerminate(display)
        raise


def _init_osmesa():
    """Initialize OSMesa for software rendering.

    Sets ``PYOPENGL_PLATFORM`` to ``"osmesa"`` in the process environment and
    binds a new OSMesa context to a 64x64 RGBA byte buffer.

    Returns:
        ``(context, buffer)``.

    Raises:
        RuntimeError: If context creation or ``OSMesaMakeCurrent`` fails.
    """
    import ctypes

    os.environ["PYOPENGL_PLATFORM"] = "osmesa"

    from OpenGL import GL as _gl
    from OpenGL.osmesa import (
        OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
        OSMESA_RGBA,
    )

    ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
    if not ctx:
        raise RuntimeError("OSMesaCreateContextExt() failed")

    width, height = 64, 64
    buffer = (ctypes.c_ubyte * (width * height * 4))()

    if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
        OSMesaDestroyContext(ctx)
        raise RuntimeError("OSMesaMakeCurrent() failed")

    return ctx, buffer


class GLContext:
    """Manages OpenGL context and resources for shader execution.

    Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).

    The class hands out one shared instance: ``GLContext()`` always returns
    the same object, and the backend is only set up until the first attempt
    that succeeds.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Create the GL context on first successful use.

        Raises:
            RuntimeError: If none of the backends could be initialized. The
                instance is not marked as initialized in that case, so a
                later ``GLContext()`` call retries and reports the same
                error instead of returning a half-built object.
        """
        if GLContext._initialized:
            return

        global glfw, EGL

        import time
        start = time.perf_counter()

        self._backend = None
        self._window = None
        self._egl_display = None
        self._egl_context = None
        self._egl_surface = None
        self._osmesa_ctx = None
        self._osmesa_buffer = None
        self._vao = None

        # Try backends in order: GLFW → EGL → OSMesa
        errors = []

        try:
            self._window, glfw = _init_glfw()
            self._backend = "glfw"
        except Exception as e:
            errors.append(("GLFW", e))

        if self._backend is None:
            try:
                self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
                self._backend = "egl"
            except Exception as e:
                errors.append(("EGL", e))

        if self._backend is None:
            try:
                self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
                self._backend = "osmesa"
            except Exception as e:
                errors.append(("OSMesa", e))

        if self._backend is None:
            if sys.platform == "win32":
                platform_help = (
                    "Windows: Ensure GPU drivers are installed and display is available.\n"
                    " CPU-only/headless mode is not supported on Windows."
                )
            elif sys.platform == "darwin":
                platform_help = (
                    "macOS: Ensure display is available. For headless, try virtual display."
                )
            else:
                platform_help = (
                    "Linux: Install one of these backends:\n"
                    " Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
                    " Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
                    " Headless (CPU): sudo apt install libosmesa6"
                )

            error_details = "\n".join(f" {name}: {err}" for name, err in errors)
            raise RuntimeError(
                f"Failed to create OpenGL context.\n\n"
                f"Backend errors:\n{error_details}\n\n"
                f"{platform_help}\n\n"
                "Python packages: pip install PyOpenGL PyOpenGL-accelerate glfw"
            )

        # Now import OpenGL.GL (after context is current)
        _import_opengl()

        # Create VAO (required for core profile, but OSMesa may use compat profile)
        # vao is pre-bound so the handler below can test it even when
        # glGenVertexArrays itself is what raised.
        vao = None
        try:
            vao = gl.glGenVertexArrays(1)
            gl.glBindVertexArray(vao)
            self._vao = vao  # Only store after successful bind
        except Exception:
            # OSMesa with older Mesa may not support VAOs
            # Clean up if we created but couldn't bind
            if vao:
                try:
                    gl.glDeleteVertexArrays(1, [vao])
                except Exception:
                    pass

        # The context is usable from here on; only now stop re-running setup.
        GLContext._initialized = True

        elapsed = (time.perf_counter() - start) * 1000

        # Log device info
        renderer = gl.glGetString(gl.GL_RENDERER)
        vendor = gl.glGetString(gl.GL_VENDOR)
        version = gl.glGetString(gl.GL_VERSION)
        renderer = renderer.decode() if renderer else "Unknown"
        vendor = vendor.decode() if vendor else "Unknown"
        version = version.decode() if version else "Unknown"

        logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")

    def make_current(self):
        """Make this context current for the active backend and bind the VAO, if any."""
        if self._backend == "glfw":
            glfw.make_context_current(self._window)
        elif self._backend == "egl":
            from OpenGL.EGL import eglMakeCurrent
            eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
        elif self._backend == "osmesa":
            from OpenGL.osmesa import OSMesaMakeCurrent
            OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)

        if self._vao is not None:
            gl.glBindVertexArray(self._vao)


def _compile_shader(source: str, shader_type: int) -> int:
    """Compile a shader and return its ID.

    Raises:
        RuntimeError: If compilation fails; the message carries the shader
            info log and the shader object is deleted first.
    """
    shader = gl.glCreateShader(shader_type)
    gl.glShaderSource(shader, source)
    gl.glCompileShader(shader)

    if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
        error = gl.glGetShaderInfoLog(shader).decode()
        gl.glDeleteShader(shader)
        raise RuntimeError(f"Shader compilation failed:\n{error}")

    return shader


def _create_program(vertex_source: str, fragment_source: str) -> int:
    """Create and link a shader program.

    Both shader objects are deleted once linking has been attempted, whether
    or not it succeeded.

    Raises:
        RuntimeError: If either shader fails to compile or the program fails
            to link.
    """
    vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
    try:
        fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
    except RuntimeError:
        gl.glDeleteShader(vertex_shader)
        raise

    program = gl.glCreateProgram()
    gl.glAttachShader(program, vertex_shader)
    gl.glAttachShader(program, fragment_shader)
    gl.glLinkProgram(program)

    gl.glDeleteShader(vertex_shader)
    gl.glDeleteShader(fragment_shader)

    if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
        error = gl.glGetProgramInfoLog(program).decode()
        gl.glDeleteProgram(program)
        raise RuntimeError(f"Program linking failed:\n{error}")

    return program


def _render_shader_batch(
    fragment_code: str,
    width: int,
    height: int,
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
) -> list[list[np.ndarray]]:
    """
    Render a fragment shader for multiple batches efficiently.

    Compiles shader once, reuses framebuffer/textures across batches.

    Args:
        fragment_code: User's fragment shader code
        width: Output width
        height: Output height
        image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
        floats: List of float uniforms
        ints: List of int uniforms

    Returns:
        List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1].
        Every batch entry has MAX_OUTPUTS images; outputs the shader does not
        use are filled with zeros.

    Raises:
        RuntimeError: If the shader fails to compile or link, or the
            framebuffer is not complete.
    """
    if not image_batches:
        return []

    ctx = GLContext()
    ctx.make_current()

    # Convert from GLSL ES to desktop GLSL 330
    fragment_source = _convert_es_to_desktop(fragment_code)

    # Detect how many outputs the shader actually uses
    num_outputs = _detect_output_count(fragment_code)

    # Track resources for cleanup
    program = None
    fbo = None
    output_textures = []
    input_textures = []

    # The first batch decides how many input textures are created.
    num_inputs = len(image_batches[0])

    try:
        # Compile shaders (once for all batches)
        try:
            program = _create_program(VERTEX_SHADER, fragment_source)
        except RuntimeError:
            logger.error(f"Fragment shader:\n{fragment_source}")
            raise

        gl.glUseProgram(program)

        # Create framebuffer with only the needed color attachments
        fbo = gl.glGenFramebuffers(1)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)

        draw_buffers = []
        for i in range(num_outputs):
            tex = gl.glGenTextures(1)
            output_textures.append(tex)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
            draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)

        gl.glDrawBuffers(num_outputs, draw_buffers)

        if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
            raise RuntimeError("Framebuffer is not complete")

        # Create input textures (reused for all batches)
        for i in range(num_inputs):
            tex = gl.glGenTextures(1)
            input_textures.append(tex)
            gl.glActiveTexture(gl.GL_TEXTURE0 + i)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

            # Sampler u_image<i> reads from texture unit i.
            loc = gl.glGetUniformLocation(program, f"u_image{i}")
            if loc >= 0:
                gl.glUniform1i(loc, i)

        # Set static uniforms (once for all batches)
        loc = gl.glGetUniformLocation(program, "u_resolution")
        if loc >= 0:
            gl.glUniform2f(loc, float(width), float(height))

        for i, v in enumerate(floats):
            loc = gl.glGetUniformLocation(program, f"u_float{i}")
            if loc >= 0:
                gl.glUniform1f(loc, v)

        for i, v in enumerate(ints):
            loc = gl.glGetUniformLocation(program, f"u_int{i}")
            if loc >= 0:
                gl.glUniform1i(loc, v)

        gl.glViewport(0, 0, width, height)
        gl.glDisable(gl.GL_BLEND)  # Ensure no alpha blending - write output directly

        # Process each batch
        all_batch_outputs = []
        for images in image_batches:
            # Update input textures with this batch's images
            for i, img in enumerate(images):
                gl.glActiveTexture(gl.GL_TEXTURE0 + i)
                gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])

                # Flip vertically for GL coordinates, ensure RGBA
                h, w, c = img.shape
                if c == 3:
                    img_upload = np.empty((h, w, 4), dtype=np.float32)
                    img_upload[:, :, :3] = img[::-1, :, :]
                    img_upload[:, :, 3] = 1.0
                else:
                    img_upload = np.ascontiguousarray(img[::-1, :, :])

                gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)

            # Render
            gl.glClearColor(0, 0, 0, 0)
            gl.glClear(gl.GL_COLOR_BUFFER_BIT)
            gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)

            # Read back outputs for this batch
            # (glGetTexImage is synchronous, implicitly waits for rendering)
            batch_outputs = []
            for tex in output_textures:
                gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
                data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
                img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
                # Undo the vertical flip applied on upload.
                batch_outputs.append(np.ascontiguousarray(img[::-1, :, :]))

            # Pad with black images for unused outputs
            black_img = np.zeros((height, width, 4), dtype=np.float32)
            for _ in range(num_outputs, MAX_OUTPUTS):
                batch_outputs.append(black_img)

            all_batch_outputs.append(batch_outputs)

        return all_batch_outputs

    finally:
        # Unbind before deleting
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
        gl.glUseProgram(0)

        if input_textures:
            gl.glDeleteTextures(len(input_textures), input_textures)
        if output_textures:
            gl.glDeleteTextures(len(output_textures), output_textures)
        if fbo is not None:
            gl.glDeleteFramebuffers(1, [fbo])
        if program is not None:
            gl.glDeleteProgram(program)


class GLSLShader(io.ComfyNode):
    """Node that applies a GLSL fragment shader to one or more input images."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's inputs (shader, size mode, images, uniforms) and its four image outputs."""
        image_template = io.Autogrow.TemplatePrefix(
            io.Image.Input("image"),
            prefix="image",
            min=1,
            max=MAX_IMAGES,
        )

        float_template = io.Autogrow.TemplatePrefix(
            io.Float.Input("float", default=0.0),
            prefix="u_float",
            min=0,
            max=MAX_UNIFORMS,
        )

        int_template = io.Autogrow.TemplatePrefix(
            io.Int.Input("int", default=0),
            prefix="u_int",
            min=0,
            max=MAX_UNIFORMS,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
            category="image/shader",
            description=(
                f"Apply GLSL fragment shaders to images. "
                f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
                f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
            ),
            inputs=[
                io.String.Input(
                    "fragment_shader",
                    default=DEFAULT_FRAGMENT_SHADER,
                    multiline=True,
                    tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
                ),
                io.DynamicCombo.Input(
                    "size_mode",
                    options=[
                        io.DynamicCombo.Option("from_input", []),
                        io.DynamicCombo.Option(
                            "custom",
                            [
                                io.Int.Input(
                                    "width",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                                io.Int.Input(
                                    "height",
                                    default=512,
                                    min=1,
                                    max=nodes.MAX_RESOLUTION,
                                ),
                            ],
                        ),
                    ],
                    tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
                ),
                io.Autogrow.Input("images", template=image_template),
                io.Autogrow.Input("floats", template=float_template),
                io.Autogrow.Input("ints", template=int_template),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0"),
                io.Image.Output(display_name="IMAGE1"),
                io.Image.Output(display_name="IMAGE2"),
                io.Image.Output(display_name="IMAGE3"),
            ],
        )

    @classmethod
    def execute(
        cls,
        fragment_shader: str,
        size_mode: SizeModeInput,
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:
        """Run the shader over every batch item and return MAX_OUTPUTS image tensors.

        Inputs that are ``None`` are dropped (images) or replaced by zero
        (float and int uniforms). The batch size is taken from the first
        image.

        Raises:
            ValueError: If no input image is provided.
        """
        image_list = [v for v in images.values() if v is not None]
        float_list = (
            [v if v is not None else 0.0 for v in floats.values()] if floats else []
        )
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []

        if not image_list:
            raise ValueError("At least one input image is required")

        # Determine output dimensions
        if size_mode["size_mode"] == "custom":
            out_width = size_mode["width"]
            out_height = size_mode["height"]
        else:
            out_height, out_width = image_list[0].shape[1:3]

        batch_size = image_list[0].shape[0]

        # Prepare batches
        image_batches = []
        for batch_idx in range(batch_size):
            batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
            image_batches.append(batch_images)

        all_batch_outputs = _render_shader_batch(
            fragment_shader,
            out_width,
            out_height,
            image_batches,
            float_list,
            int_list,
        )

        # Collect outputs into tensors
        all_outputs = [[] for _ in range(MAX_OUTPUTS)]
        for batch_outputs in all_batch_outputs:
            for i, out_img in enumerate(batch_outputs):
                all_outputs[i].append(torch.from_numpy(out_img))

        output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
        return io.NodeOutput(
            *output_tensors,
            ui=cls._build_ui_output(image_list, output_tensors[0]),
        )

    @classmethod
    def _build_ui_output(
        cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
    ) -> dict[str, list]:
        """Build UI output with input and output images for client-side shader execution."""
        combined_inputs = torch.cat(image_list, dim=0)
        input_images_ui = ui.ImageSaveHelper.save_images(
            combined_inputs,
            filename_prefix="GLSLShader_input",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        output_images_ui = ui.ImageSaveHelper.save_images(
            output_batch,
            filename_prefix="GLSLShader_output",
            folder_type=io.FolderType.temp,
            cls=None,
            compress_level=1,
        )

        return {"input_images": input_images_ui, "images": output_images_ui}


class GLSLExtension(ComfyExtension):
    """Extension that registers the ``GLSLShader`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [GLSLShader]


async def comfy_entrypoint() -> GLSLExtension:
    """Return the extension instance for this module."""
    return GLSLExtension()