Support multiple outputs

This commit is contained in:
pythongosssss 2026-01-24 12:55:06 -08:00
parent 521ca3b5d2
commit 5b0fb64d20

View File

@ -20,8 +20,9 @@ class SizeModeInput(TypedDict):
height: int height: int
MAX_IMAGES = 5 # u_image0-4 MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,6 +32,7 @@ except ImportError as e:
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
# Default NOOP fragment shader that passes through the input image unchanged # Default NOOP fragment shader that passes through the input image unchanged
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
DEFAULT_FRAGMENT_SHADER = """#version 300 es DEFAULT_FRAGMENT_SHADER = """#version 300 es
precision highp float; precision highp float;
@ -38,10 +40,10 @@ uniform sampler2D u_image0;
uniform vec2 u_resolution; uniform vec2 u_resolution;
in vec2 v_texcoord; in vec2 v_texcoord;
out vec4 fragColor; layout(location = 0) out vec4 fragColor0;
void main() { void main() {
fragColor = texture(u_image0, v_texcoord); fragColor0 = texture(u_image0, v_texcoord);
} }
""" """
@ -130,10 +132,10 @@ def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Text
return texture return texture
def _texture_to_image(fbo: moderngl.Framebuffer, channels: int = 4) -> np.ndarray: def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
width, height = fbo.size width, height = fbo.size
data = fbo.read(components=channels) data = fbo.read(components=channels, attachment=attachment)
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels)) image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
image = np.ascontiguousarray(np.flipud(image)) image = np.ascontiguousarray(np.flipud(image))
@ -170,11 +172,15 @@ def _render_shader(
height: int, height: int,
textures: list[moderngl.Texture], textures: list[moderngl.Texture],
uniforms: dict[str, int | float], uniforms: dict[str, int | float],
) -> np.ndarray: ) -> list[np.ndarray]:
# Create output texture and framebuffer # Create output textures
output_texture = ctx.texture((width, height), 4) output_textures = []
output_texture.filter = (moderngl.LINEAR, moderngl.LINEAR) for _ in range(MAX_OUTPUTS):
fbo = ctx.framebuffer(color_attachments=[output_texture]) tex = ctx.texture((width, height), 4)
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
output_textures.append(tex)
fbo = ctx.framebuffer(color_attachments=output_textures)
# Full-screen quad vertices (position + texcoord) # Full-screen quad vertices (position + texcoord)
vertices = np.array([ vertices = np.array([
@ -212,12 +218,16 @@ def _render_shader(
fbo.clear(0.0, 0.0, 0.0, 1.0) fbo.clear(0.0, 0.0, 0.0, 1.0)
vao.render(moderngl.TRIANGLE_STRIP) vao.render(moderngl.TRIANGLE_STRIP)
# Read result # Read results from all attachments
return _texture_to_image(fbo, channels=4) results = []
for i in range(MAX_OUTPUTS):
results.append(_texture_to_image(fbo, attachment=i, channels=4))
return results
finally: finally:
vao.release() vao.release()
vbo.release() vbo.release()
output_texture.release() for tex in output_textures:
tex.release()
fbo.release() fbo.release()
@ -311,8 +321,9 @@ class GLSLShader(io.ComfyNode):
category="image/shader", category="image/shader",
description=( description=(
f"Apply GLSL fragment shaders to images. " f"Apply GLSL fragment shaders to images. "
f"Uniforms: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), " f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}." f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
), ),
inputs=[ inputs=[
io.String.Input( io.String.Input(
@ -343,7 +354,10 @@ class GLSLShader(io.ComfyNode):
io.Autogrow.Input("ints", template=int_template), io.Autogrow.Input("ints", template=int_template),
], ],
outputs=[ outputs=[
io.Image.Output(display_name="IMAGE"), io.Image.Output(display_name="IMAGE0"),
io.Image.Output(display_name="IMAGE1"),
io.Image.Output(display_name="IMAGE2"),
io.Image.Output(display_name="IMAGE3"),
], ],
) )
@ -375,17 +389,22 @@ class GLSLShader(io.ComfyNode):
with _gl_context(force_software=args.cpu) as ctx: with _gl_context(force_software=args.cpu) as ctx:
with _shader_program(ctx, fragment_shader) as program: with _shader_program(ctx, fragment_shader) as program:
output_images = [] # Collect outputs for each render target across all batches
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]
for b in range(batch_size): for b in range(batch_size):
with _textures_context(ctx, image_list, b) as textures: with _textures_context(ctx, image_list, b) as textures:
result = _render_shader(ctx, program, out_width, out_height, textures, uniforms) results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
output_images.append(torch.from_numpy(result)) for i, result in enumerate(results):
all_outputs[i].append(torch.from_numpy(result))
output_batch = torch.stack(output_images, dim=0) # Stack batches for each output
if output_batch.shape[-1] == 4: output_values = []
output_batch = output_batch[:, :, :, :3] for i in range(MAX_OUTPUTS):
output_batch = torch.stack(all_outputs[i], dim=0)
output_values.append(output_batch)
return io.NodeOutput(output_batch, ui=cls._build_ui_output(image_list, output_batch)) return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))
@classmethod @classmethod
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]: def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]: