fix line endings
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
pythongosssss 2026-01-28 11:02:17 -08:00
parent cee092213e
commit aaea976f36

View File

@ -1,439 +1,439 @@
import os import os
import re import re
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import TypedDict, Generator from typing import TypedDict, Generator
import numpy as np import numpy as np
import torch import torch
import nodes import nodes
from comfy_api.latest import ComfyExtension, io, ui from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args from comfy.cli_args import args
from typing_extensions import override from typing_extensions import override
from utils.install_util import get_missing_requirements_message from utils.install_util import get_missing_requirements_message
class SizeModeInput(TypedDict): class SizeModeInput(TypedDict):
size_mode: str size_mode: str
width: int width: int
height: int height: int
MAX_IMAGES = 5 # u_image0-4 MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT) MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import moderngl import moderngl
except ImportError as e: except ImportError as e:
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
# Default NOOP fragment shader that passes through the input image unchanged # Default NOOP fragment shader that passes through the input image unchanged
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc. # For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
DEFAULT_FRAGMENT_SHADER = """#version 300 es DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float; precision highp float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution; uniform vec2 u_resolution;
in vec2 v_texCoord; in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0; layout(location = 0) out vec4 fragColor0;
void main() { void main() {
fragColor0 = texture(u_image0, v_texCoord); fragColor0 = texture(u_image0, v_texCoord);
} }
""" """
# Simple vertex shader for full-screen quad # Simple vertex shader for full-screen quad
VERTEX_SHADER = """#version 330 VERTEX_SHADER = """#version 330
in vec2 in_position; in vec2 in_position;
in vec2 in_texcoord; in vec2 in_texcoord;
out vec2 v_texCoord; out vec2 v_texCoord;
void main() { void main() {
gl_Position = vec4(in_position, 0.0, 1.0); gl_Position = vec4(in_position, 0.0, 1.0);
v_texCoord = in_texcoord; v_texCoord = in_texcoord;
} }
""" """
def _convert_es_to_desktop_glsl(source: str) -> str: def _convert_es_to_desktop_glsl(source: str) -> str:
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility.""" """Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
return re.sub(r'#version\s+300\s+es', '#version 330', source) return re.sub(r'#version\s+300\s+es', '#version 330', source)
def _create_software_gl_context() -> moderngl.Context: def _create_software_gl_context() -> moderngl.Context:
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE") original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1" os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
try: try:
ctx = moderngl.create_standalone_context(require=330) ctx = moderngl.create_standalone_context(require=330)
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}") logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
return ctx return ctx
finally: finally:
if original_env is None: if original_env is None:
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None) os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
else: else:
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env
def _create_gl_context(force_software: bool = False) -> moderngl.Context: def _create_gl_context(force_software: bool = False) -> moderngl.Context:
if force_software: if force_software:
try: try:
return _create_software_gl_context() return _create_software_gl_context()
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
"Failed to create software-rendered OpenGL context.\n" "Failed to create software-rendered OpenGL context.\n"
"Ensure Mesa/llvmpipe is installed for software rendering support." "Ensure Mesa/llvmpipe is installed for software rendering support."
) from e ) from e
# Try hardware rendering first, fall back to software # Try hardware rendering first, fall back to software
try: try:
ctx = moderngl.create_standalone_context(require=330) ctx = moderngl.create_standalone_context(require=330)
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}") logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
return ctx return ctx
except Exception as hw_error: except Exception as hw_error:
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}") logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
logger.info("Attempting software rendering fallback...") logger.info("Attempting software rendering fallback...")
try: try:
return _create_software_gl_context() return _create_software_gl_context()
except Exception as sw_error: except Exception as sw_error:
raise RuntimeError( raise RuntimeError(
f"Failed to create OpenGL context.\n" f"Failed to create OpenGL context.\n"
f"Hardware error: {hw_error}\n\n" f"Hardware error: {hw_error}\n\n"
f"Possible solutions:\n" f"Possible solutions:\n"
f"1. Install GPU drivers with OpenGL 3.3+ support\n" f"1. Install GPU drivers with OpenGL 3.3+ support\n"
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n" f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available" f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
) from sw_error ) from sw_error
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture: def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
height, width = image.shape[:2] height, width = image.shape[:2]
channels = image.shape[2] if len(image.shape) > 2 else 1 channels = image.shape[2] if len(image.shape) > 2 else 1
components = min(channels, 4) components = min(channels, 4)
image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8) image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
# Flip vertically for OpenGL coordinate system (origin at bottom-left) # Flip vertically for OpenGL coordinate system (origin at bottom-left)
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8)) image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))
texture = ctx.texture((width, height), components, image_uint8.tobytes()) texture = ctx.texture((width, height), components, image_uint8.tobytes())
texture.filter = (moderngl.LINEAR, moderngl.LINEAR) texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
texture.repeat_x = False texture.repeat_x = False
texture.repeat_y = False texture.repeat_y = False
return texture return texture
def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray: def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
width, height = fbo.size width, height = fbo.size
data = fbo.read(components=channels, attachment=attachment) data = fbo.read(components=channels, attachment=attachment)
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels)) image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
image = np.ascontiguousarray(np.flipud(image)) image = np.ascontiguousarray(np.flipud(image))
return image.astype(np.float32) / 255.0 return image.astype(np.float32) / 255.0
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program: def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL # Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
fragment_source = _convert_es_to_desktop_glsl(fragment_source) fragment_source = _convert_es_to_desktop_glsl(fragment_source)
try: try:
program = ctx.program( program = ctx.program(
vertex_shader=VERTEX_SHADER, vertex_shader=VERTEX_SHADER,
fragment_shader=fragment_source, fragment_shader=fragment_source,
) )
return program return program
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
"Fragment shader compilation failed.\n\n" "Fragment shader compilation failed.\n\n"
"Make sure your shader:\n" "Make sure your shader:\n"
"1. Uses #version 300 es (WebGL 2.0 compatible)\n" "1. Uses #version 300 es (WebGL 2.0 compatible)\n"
"2. Has valid GLSL ES 3.00 syntax\n" "2. Has valid GLSL ES 3.00 syntax\n"
"3. Includes 'precision highp float;' after version\n" "3. Includes 'precision highp float;' after version\n"
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n" "4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)" "5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
) from e ) from e
def _render_shader( def _render_shader(
ctx: moderngl.Context, ctx: moderngl.Context,
program: moderngl.Program, program: moderngl.Program,
width: int, width: int,
height: int, height: int,
textures: list[moderngl.Texture], textures: list[moderngl.Texture],
uniforms: dict[str, int | float], uniforms: dict[str, int | float],
) -> list[np.ndarray]: ) -> list[np.ndarray]:
# Create output textures # Create output textures
output_textures = [] output_textures = []
for _ in range(MAX_OUTPUTS): for _ in range(MAX_OUTPUTS):
tex = ctx.texture((width, height), 4) tex = ctx.texture((width, height), 4)
tex.filter = (moderngl.LINEAR, moderngl.LINEAR) tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
output_textures.append(tex) output_textures.append(tex)
fbo = ctx.framebuffer(color_attachments=output_textures) fbo = ctx.framebuffer(color_attachments=output_textures)
# Full-screen quad vertices (position + texcoord) # Full-screen quad vertices (position + texcoord)
vertices = np.array([ vertices = np.array([
# Position (x, y), Texcoord (u, v) # Position (x, y), Texcoord (u, v)
-1.0, -1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0,
1.0, -1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0,
-1.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
], dtype='f4') ], dtype='f4')
vbo = ctx.buffer(vertices.tobytes()) vbo = ctx.buffer(vertices.tobytes())
vao = ctx.vertex_array( vao = ctx.vertex_array(
program, program,
[(vbo, '2f 2f', 'in_position', 'in_texcoord')], [(vbo, '2f 2f', 'in_position', 'in_texcoord')],
) )
try: try:
# Bind textures # Bind textures
for i, texture in enumerate(textures): for i, texture in enumerate(textures):
texture.use(i) texture.use(i)
uniform_name = f'u_image{i}' uniform_name = f'u_image{i}'
if uniform_name in program: if uniform_name in program:
program[uniform_name].value = i program[uniform_name].value = i
# Set uniforms # Set uniforms
if 'u_resolution' in program: if 'u_resolution' in program:
program['u_resolution'].value = (float(width), float(height)) program['u_resolution'].value = (float(width), float(height))
for name, value in uniforms.items(): for name, value in uniforms.items():
if name in program: if name in program:
program[name].value = value program[name].value = value
# Render # Render
fbo.use() fbo.use()
fbo.clear(0.0, 0.0, 0.0, 1.0) fbo.clear(0.0, 0.0, 0.0, 1.0)
vao.render(moderngl.TRIANGLE_STRIP) vao.render(moderngl.TRIANGLE_STRIP)
# Read results from all attachments # Read results from all attachments
results = [] results = []
for i in range(MAX_OUTPUTS): for i in range(MAX_OUTPUTS):
results.append(_texture_to_image(fbo, attachment=i, channels=4)) results.append(_texture_to_image(fbo, attachment=i, channels=4))
return results return results
finally: finally:
vao.release() vao.release()
vbo.release() vbo.release()
for tex in output_textures: for tex in output_textures:
tex.release() tex.release()
fbo.release() fbo.release()
def _prepare_textures( def _prepare_textures(
ctx: moderngl.Context, ctx: moderngl.Context,
image_list: list[torch.Tensor], image_list: list[torch.Tensor],
batch_idx: int, batch_idx: int,
) -> list[moderngl.Texture]: ) -> list[moderngl.Texture]:
textures = [] textures = []
for img_tensor in image_list[:MAX_IMAGES]: for img_tensor in image_list[:MAX_IMAGES]:
img_idx = min(batch_idx, img_tensor.shape[0] - 1) img_idx = min(batch_idx, img_tensor.shape[0] - 1)
img_np = img_tensor[img_idx].cpu().numpy() img_np = img_tensor[img_idx].cpu().numpy()
textures.append(_image_to_texture(ctx, img_np)) textures.append(_image_to_texture(ctx, img_np))
return textures return textures
def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]: def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
uniforms: dict[str, int | float] = {} uniforms: dict[str, int | float] = {}
for i, val in enumerate(int_list[:MAX_UNIFORMS]): for i, val in enumerate(int_list[:MAX_UNIFORMS]):
uniforms[f'u_int{i}'] = int(val) uniforms[f'u_int{i}'] = int(val)
for i, val in enumerate(float_list[:MAX_UNIFORMS]): for i, val in enumerate(float_list[:MAX_UNIFORMS]):
uniforms[f'u_float{i}'] = float(val) uniforms[f'u_float{i}'] = float(val)
return uniforms return uniforms
def _release_textures(textures: list[moderngl.Texture]) -> None: def _release_textures(textures: list[moderngl.Texture]) -> None:
for texture in textures: for texture in textures:
texture.release() texture.release()
@contextmanager @contextmanager
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]: def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
ctx = _create_gl_context(force_software) ctx = _create_gl_context(force_software)
try: try:
yield ctx yield ctx
finally: finally:
ctx.release() ctx.release()
@contextmanager @contextmanager
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]: def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
program = _compile_shader(ctx, fragment_source) program = _compile_shader(ctx, fragment_source)
try: try:
yield program yield program
finally: finally:
program.release() program.release()
@contextmanager @contextmanager
def _textures_context( def _textures_context(
ctx: moderngl.Context, ctx: moderngl.Context,
image_list: list[torch.Tensor], image_list: list[torch.Tensor],
batch_idx: int, batch_idx: int,
) -> Generator[list[moderngl.Texture], None, None]: ) -> Generator[list[moderngl.Texture], None, None]:
textures = _prepare_textures(ctx, image_list, batch_idx) textures = _prepare_textures(ctx, image_list, batch_idx)
try: try:
yield textures yield textures
finally: finally:
_release_textures(textures) _release_textures(textures)
class GLSLShader(io.ComfyNode): class GLSLShader(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
# Create autogrow templates # Create autogrow templates
image_template = io.Autogrow.TemplatePrefix( image_template = io.Autogrow.TemplatePrefix(
io.Image.Input("image"), io.Image.Input("image"),
prefix="image", prefix="image",
min=1, min=1,
max=MAX_IMAGES, max=MAX_IMAGES,
) )
float_template = io.Autogrow.TemplatePrefix( float_template = io.Autogrow.TemplatePrefix(
io.Float.Input("float", default=0.0), io.Float.Input("float", default=0.0),
prefix="u_float", prefix="u_float",
min=0, min=0,
max=MAX_UNIFORMS, max=MAX_UNIFORMS,
) )
int_template = io.Autogrow.TemplatePrefix( int_template = io.Autogrow.TemplatePrefix(
io.Int.Input("int", default=0), io.Int.Input("int", default=0),
prefix="u_int", prefix="u_int",
min=0, min=0,
max=MAX_UNIFORMS, max=MAX_UNIFORMS,
) )
return io.Schema( return io.Schema(
node_id="GLSLShader", node_id="GLSLShader",
display_name="GLSL Shader", display_name="GLSL Shader",
category="image/shader", category="image/shader",
description=( description=(
f"Apply GLSL fragment shaders to images. " f"Apply GLSL fragment shaders to images. "
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), " f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. " f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}." f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
), ),
inputs=[ inputs=[
io.String.Input( io.String.Input(
"fragment_shader", "fragment_shader",
default=DEFAULT_FRAGMENT_SHADER, default=DEFAULT_FRAGMENT_SHADER,
multiline=True, multiline=True,
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)", tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
), ),
io.DynamicCombo.Input( io.DynamicCombo.Input(
"size_mode", "size_mode",
options=[ options=[
io.DynamicCombo.Option( io.DynamicCombo.Option(
"from_input", "from_input",
[], # No extra inputs - uses first input image dimensions [], # No extra inputs - uses first input image dimensions
), ),
io.DynamicCombo.Option( io.DynamicCombo.Option(
"custom", "custom",
[ [
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION), io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION), io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
], ],
), ),
], ],
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size", tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
), ),
io.Autogrow.Input("images", template=image_template), io.Autogrow.Input("images", template=image_template),
io.Autogrow.Input("floats", template=float_template), io.Autogrow.Input("floats", template=float_template),
io.Autogrow.Input("ints", template=int_template), io.Autogrow.Input("ints", template=int_template),
], ],
outputs=[ outputs=[
io.Image.Output(display_name="IMAGE0"), io.Image.Output(display_name="IMAGE0"),
io.Image.Output(display_name="IMAGE1"), io.Image.Output(display_name="IMAGE1"),
io.Image.Output(display_name="IMAGE2"), io.Image.Output(display_name="IMAGE2"),
io.Image.Output(display_name="IMAGE3"), io.Image.Output(display_name="IMAGE3"),
], ],
) )
@classmethod @classmethod
def execute( def execute(
cls, cls,
fragment_shader: str, fragment_shader: str,
size_mode: SizeModeInput, size_mode: SizeModeInput,
images: io.Autogrow.Type, images: io.Autogrow.Type,
floats: io.Autogrow.Type = None, floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None, ints: io.Autogrow.Type = None,
**kwargs, **kwargs,
) -> io.NodeOutput: ) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None] image_list = [v for v in images.values() if v is not None]
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else [] float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
if not image_list: if not image_list:
raise ValueError("At least one input image is required") raise ValueError("At least one input image is required")
# Determine output dimensions # Determine output dimensions
if size_mode["size_mode"] == "custom": if size_mode["size_mode"] == "custom":
out_width, out_height = size_mode["width"], size_mode["height"] out_width, out_height = size_mode["width"], size_mode["height"]
else: else:
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2] out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]
batch_size = image_list[0].shape[0] batch_size = image_list[0].shape[0]
uniforms = _prepare_uniforms(int_list, float_list) uniforms = _prepare_uniforms(int_list, float_list)
with _gl_context(force_software=args.cpu) as ctx: with _gl_context(force_software=args.cpu) as ctx:
with _shader_program(ctx, fragment_shader) as program: with _shader_program(ctx, fragment_shader) as program:
# Collect outputs for each render target across all batches # Collect outputs for each render target across all batches
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)] all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]
for b in range(batch_size): for b in range(batch_size):
with _textures_context(ctx, image_list, b) as textures: with _textures_context(ctx, image_list, b) as textures:
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms) results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
for i, result in enumerate(results): for i, result in enumerate(results):
all_outputs[i].append(torch.from_numpy(result)) all_outputs[i].append(torch.from_numpy(result))
# Stack batches for each output # Stack batches for each output
output_values = [] output_values = []
for i in range(MAX_OUTPUTS): for i in range(MAX_OUTPUTS):
output_batch = torch.stack(all_outputs[i], dim=0) output_batch = torch.stack(all_outputs[i], dim=0)
output_values.append(output_batch) output_values.append(output_batch)
return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0])) return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))
@classmethod @classmethod
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]: def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
"""Build UI output with input and output images for client-side shader execution.""" """Build UI output with input and output images for client-side shader execution."""
combined_inputs = torch.cat(image_list, dim=0) combined_inputs = torch.cat(image_list, dim=0)
input_images_ui = ui.ImageSaveHelper.save_images( input_images_ui = ui.ImageSaveHelper.save_images(
combined_inputs, combined_inputs,
filename_prefix="GLSLShader_input", filename_prefix="GLSLShader_input",
folder_type=io.FolderType.temp, folder_type=io.FolderType.temp,
cls=None, cls=None,
compress_level=1, compress_level=1,
) )
output_images_ui = ui.ImageSaveHelper.save_images( output_images_ui = ui.ImageSaveHelper.save_images(
output_batch, output_batch,
filename_prefix="GLSLShader_output", filename_prefix="GLSLShader_output",
folder_type=io.FolderType.temp, folder_type=io.FolderType.temp,
cls=None, cls=None,
compress_level=1, compress_level=1,
) )
return {"input_images": input_images_ui, "images": output_images_ui} return {"input_images": input_images_ui, "images": output_images_ui}
class GLSLExtension(ComfyExtension): class GLSLExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [GLSLShader] return [GLSLShader]
async def comfy_entrypoint() -> GLSLExtension: async def comfy_entrypoint() -> GLSLExtension:
return GLSLExtension() return GLSLExtension()