mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 08:40:19 +08:00
fix line endings
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
parent
cee092213e
commit
aaea976f36
@ -1,439 +1,439 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TypedDict, Generator
|
from typing import TypedDict, Generator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_api.latest import ComfyExtension, io, ui
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from utils.install_util import get_missing_requirements_message
|
from utils.install_util import get_missing_requirements_message
|
||||||
|
|
||||||
|
|
||||||
class SizeModeInput(TypedDict):
|
class SizeModeInput(TypedDict):
|
||||||
size_mode: str
|
size_mode: str
|
||||||
width: int
|
width: int
|
||||||
height: int
|
height: int
|
||||||
|
|
||||||
|
|
||||||
MAX_IMAGES = 5 # u_image0-4
|
MAX_IMAGES = 5 # u_image0-4
|
||||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import moderngl
|
import moderngl
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
|
raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e
|
||||||
|
|
||||||
# Default NOOP fragment shader that passes through the input image unchanged
|
# Default NOOP fragment shader that passes through the input image unchanged
|
||||||
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
|
# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
|
||||||
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
||||||
precision highp float;
|
precision highp float;
|
||||||
|
|
||||||
uniform sampler2D u_image0;
|
uniform sampler2D u_image0;
|
||||||
uniform vec2 u_resolution;
|
uniform vec2 u_resolution;
|
||||||
|
|
||||||
in vec2 v_texCoord;
|
in vec2 v_texCoord;
|
||||||
layout(location = 0) out vec4 fragColor0;
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
fragColor0 = texture(u_image0, v_texCoord);
|
fragColor0 = texture(u_image0, v_texCoord);
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Simple vertex shader for full-screen quad
|
# Simple vertex shader for full-screen quad
|
||||||
VERTEX_SHADER = """#version 330
|
VERTEX_SHADER = """#version 330
|
||||||
|
|
||||||
in vec2 in_position;
|
in vec2 in_position;
|
||||||
in vec2 in_texcoord;
|
in vec2 in_texcoord;
|
||||||
|
|
||||||
out vec2 v_texCoord;
|
out vec2 v_texCoord;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
gl_Position = vec4(in_position, 0.0, 1.0);
|
gl_Position = vec4(in_position, 0.0, 1.0);
|
||||||
v_texCoord = in_texcoord;
|
v_texCoord = in_texcoord;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _convert_es_to_desktop_glsl(source: str) -> str:
|
def _convert_es_to_desktop_glsl(source: str) -> str:
|
||||||
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
|
"""Convert GLSL ES 3.00 shader to desktop GLSL 3.30 for ModernGL compatibility."""
|
||||||
return re.sub(r'#version\s+300\s+es', '#version 330', source)
|
return re.sub(r'#version\s+300\s+es', '#version 330', source)
|
||||||
|
|
||||||
|
|
||||||
def _create_software_gl_context() -> moderngl.Context:
|
def _create_software_gl_context() -> moderngl.Context:
|
||||||
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
|
original_env = os.environ.get("LIBGL_ALWAYS_SOFTWARE")
|
||||||
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
||||||
try:
|
try:
|
||||||
ctx = moderngl.create_standalone_context(require=330)
|
ctx = moderngl.create_standalone_context(require=330)
|
||||||
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
|
logger.info(f"Created software-rendered OpenGL context: {ctx.info['GL_RENDERER']}")
|
||||||
return ctx
|
return ctx
|
||||||
finally:
|
finally:
|
||||||
if original_env is None:
|
if original_env is None:
|
||||||
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
|
os.environ.pop("LIBGL_ALWAYS_SOFTWARE", None)
|
||||||
else:
|
else:
|
||||||
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env
|
os.environ["LIBGL_ALWAYS_SOFTWARE"] = original_env
|
||||||
|
|
||||||
|
|
||||||
def _create_gl_context(force_software: bool = False) -> moderngl.Context:
|
def _create_gl_context(force_software: bool = False) -> moderngl.Context:
|
||||||
if force_software:
|
if force_software:
|
||||||
try:
|
try:
|
||||||
return _create_software_gl_context()
|
return _create_software_gl_context()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Failed to create software-rendered OpenGL context.\n"
|
"Failed to create software-rendered OpenGL context.\n"
|
||||||
"Ensure Mesa/llvmpipe is installed for software rendering support."
|
"Ensure Mesa/llvmpipe is installed for software rendering support."
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
# Try hardware rendering first, fall back to software
|
# Try hardware rendering first, fall back to software
|
||||||
try:
|
try:
|
||||||
ctx = moderngl.create_standalone_context(require=330)
|
ctx = moderngl.create_standalone_context(require=330)
|
||||||
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
|
logger.info(f"Created OpenGL context: {ctx.info['GL_RENDERER']}")
|
||||||
return ctx
|
return ctx
|
||||||
except Exception as hw_error:
|
except Exception as hw_error:
|
||||||
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
|
logger.warning(f"Hardware OpenGL context creation failed: {hw_error}")
|
||||||
logger.info("Attempting software rendering fallback...")
|
logger.info("Attempting software rendering fallback...")
|
||||||
try:
|
try:
|
||||||
return _create_software_gl_context()
|
return _create_software_gl_context()
|
||||||
except Exception as sw_error:
|
except Exception as sw_error:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed to create OpenGL context.\n"
|
f"Failed to create OpenGL context.\n"
|
||||||
f"Hardware error: {hw_error}\n\n"
|
f"Hardware error: {hw_error}\n\n"
|
||||||
f"Possible solutions:\n"
|
f"Possible solutions:\n"
|
||||||
f"1. Install GPU drivers with OpenGL 3.3+ support\n"
|
f"1. Install GPU drivers with OpenGL 3.3+ support\n"
|
||||||
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
|
f"2. Install Mesa for software rendering (Linux: apt install libgl1-mesa-dri)\n"
|
||||||
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
|
f"3. On headless servers, ensure virtual framebuffer (Xvfb) or EGL is available"
|
||||||
) from sw_error
|
) from sw_error
|
||||||
|
|
||||||
|
|
||||||
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
|
def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Texture:
|
||||||
height, width = image.shape[:2]
|
height, width = image.shape[:2]
|
||||||
channels = image.shape[2] if len(image.shape) > 2 else 1
|
channels = image.shape[2] if len(image.shape) > 2 else 1
|
||||||
|
|
||||||
components = min(channels, 4)
|
components = min(channels, 4)
|
||||||
|
|
||||||
image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
|
image_uint8 = (np.clip(image, 0, 1) * 255).astype(np.uint8)
|
||||||
|
|
||||||
# Flip vertically for OpenGL coordinate system (origin at bottom-left)
|
# Flip vertically for OpenGL coordinate system (origin at bottom-left)
|
||||||
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))
|
image_uint8 = np.ascontiguousarray(np.flipud(image_uint8))
|
||||||
|
|
||||||
texture = ctx.texture((width, height), components, image_uint8.tobytes())
|
texture = ctx.texture((width, height), components, image_uint8.tobytes())
|
||||||
texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
||||||
texture.repeat_x = False
|
texture.repeat_x = False
|
||||||
texture.repeat_y = False
|
texture.repeat_y = False
|
||||||
|
|
||||||
return texture
|
return texture
|
||||||
|
|
||||||
|
|
||||||
def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
|
def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
|
||||||
width, height = fbo.size
|
width, height = fbo.size
|
||||||
|
|
||||||
data = fbo.read(components=channels, attachment=attachment)
|
data = fbo.read(components=channels, attachment=attachment)
|
||||||
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
|
image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
|
||||||
|
|
||||||
image = np.ascontiguousarray(np.flipud(image))
|
image = np.ascontiguousarray(np.flipud(image))
|
||||||
|
|
||||||
return image.astype(np.float32) / 255.0
|
return image.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
|
def _compile_shader(ctx: moderngl.Context, fragment_source: str) -> moderngl.Program:
|
||||||
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
|
# Convert user's GLSL ES 3.00 fragment shader to desktop GLSL 3.30 for ModernGL
|
||||||
fragment_source = _convert_es_to_desktop_glsl(fragment_source)
|
fragment_source = _convert_es_to_desktop_glsl(fragment_source)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
program = ctx.program(
|
program = ctx.program(
|
||||||
vertex_shader=VERTEX_SHADER,
|
vertex_shader=VERTEX_SHADER,
|
||||||
fragment_shader=fragment_source,
|
fragment_shader=fragment_source,
|
||||||
)
|
)
|
||||||
return program
|
return program
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Fragment shader compilation failed.\n\n"
|
"Fragment shader compilation failed.\n\n"
|
||||||
"Make sure your shader:\n"
|
"Make sure your shader:\n"
|
||||||
"1. Uses #version 300 es (WebGL 2.0 compatible)\n"
|
"1. Uses #version 300 es (WebGL 2.0 compatible)\n"
|
||||||
"2. Has valid GLSL ES 3.00 syntax\n"
|
"2. Has valid GLSL ES 3.00 syntax\n"
|
||||||
"3. Includes 'precision highp float;' after version\n"
|
"3. Includes 'precision highp float;' after version\n"
|
||||||
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
|
"4. Uses 'out vec4 fragColor' instead of gl_FragColor\n"
|
||||||
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
|
"5. Declares uniforms correctly (e.g., uniform sampler2D u_image0;)"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def _render_shader(
|
def _render_shader(
|
||||||
ctx: moderngl.Context,
|
ctx: moderngl.Context,
|
||||||
program: moderngl.Program,
|
program: moderngl.Program,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
textures: list[moderngl.Texture],
|
textures: list[moderngl.Texture],
|
||||||
uniforms: dict[str, int | float],
|
uniforms: dict[str, int | float],
|
||||||
) -> list[np.ndarray]:
|
) -> list[np.ndarray]:
|
||||||
# Create output textures
|
# Create output textures
|
||||||
output_textures = []
|
output_textures = []
|
||||||
for _ in range(MAX_OUTPUTS):
|
for _ in range(MAX_OUTPUTS):
|
||||||
tex = ctx.texture((width, height), 4)
|
tex = ctx.texture((width, height), 4)
|
||||||
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
||||||
output_textures.append(tex)
|
output_textures.append(tex)
|
||||||
|
|
||||||
fbo = ctx.framebuffer(color_attachments=output_textures)
|
fbo = ctx.framebuffer(color_attachments=output_textures)
|
||||||
|
|
||||||
# Full-screen quad vertices (position + texcoord)
|
# Full-screen quad vertices (position + texcoord)
|
||||||
vertices = np.array([
|
vertices = np.array([
|
||||||
# Position (x, y), Texcoord (u, v)
|
# Position (x, y), Texcoord (u, v)
|
||||||
-1.0, -1.0, 0.0, 0.0,
|
-1.0, -1.0, 0.0, 0.0,
|
||||||
1.0, -1.0, 1.0, 0.0,
|
1.0, -1.0, 1.0, 0.0,
|
||||||
-1.0, 1.0, 0.0, 1.0,
|
-1.0, 1.0, 0.0, 1.0,
|
||||||
1.0, 1.0, 1.0, 1.0,
|
1.0, 1.0, 1.0, 1.0,
|
||||||
], dtype='f4')
|
], dtype='f4')
|
||||||
|
|
||||||
vbo = ctx.buffer(vertices.tobytes())
|
vbo = ctx.buffer(vertices.tobytes())
|
||||||
vao = ctx.vertex_array(
|
vao = ctx.vertex_array(
|
||||||
program,
|
program,
|
||||||
[(vbo, '2f 2f', 'in_position', 'in_texcoord')],
|
[(vbo, '2f 2f', 'in_position', 'in_texcoord')],
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Bind textures
|
# Bind textures
|
||||||
for i, texture in enumerate(textures):
|
for i, texture in enumerate(textures):
|
||||||
texture.use(i)
|
texture.use(i)
|
||||||
uniform_name = f'u_image{i}'
|
uniform_name = f'u_image{i}'
|
||||||
if uniform_name in program:
|
if uniform_name in program:
|
||||||
program[uniform_name].value = i
|
program[uniform_name].value = i
|
||||||
|
|
||||||
# Set uniforms
|
# Set uniforms
|
||||||
if 'u_resolution' in program:
|
if 'u_resolution' in program:
|
||||||
program['u_resolution'].value = (float(width), float(height))
|
program['u_resolution'].value = (float(width), float(height))
|
||||||
|
|
||||||
for name, value in uniforms.items():
|
for name, value in uniforms.items():
|
||||||
if name in program:
|
if name in program:
|
||||||
program[name].value = value
|
program[name].value = value
|
||||||
|
|
||||||
# Render
|
# Render
|
||||||
fbo.use()
|
fbo.use()
|
||||||
fbo.clear(0.0, 0.0, 0.0, 1.0)
|
fbo.clear(0.0, 0.0, 0.0, 1.0)
|
||||||
vao.render(moderngl.TRIANGLE_STRIP)
|
vao.render(moderngl.TRIANGLE_STRIP)
|
||||||
|
|
||||||
# Read results from all attachments
|
# Read results from all attachments
|
||||||
results = []
|
results = []
|
||||||
for i in range(MAX_OUTPUTS):
|
for i in range(MAX_OUTPUTS):
|
||||||
results.append(_texture_to_image(fbo, attachment=i, channels=4))
|
results.append(_texture_to_image(fbo, attachment=i, channels=4))
|
||||||
return results
|
return results
|
||||||
finally:
|
finally:
|
||||||
vao.release()
|
vao.release()
|
||||||
vbo.release()
|
vbo.release()
|
||||||
for tex in output_textures:
|
for tex in output_textures:
|
||||||
tex.release()
|
tex.release()
|
||||||
fbo.release()
|
fbo.release()
|
||||||
|
|
||||||
|
|
||||||
def _prepare_textures(
|
def _prepare_textures(
|
||||||
ctx: moderngl.Context,
|
ctx: moderngl.Context,
|
||||||
image_list: list[torch.Tensor],
|
image_list: list[torch.Tensor],
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
) -> list[moderngl.Texture]:
|
) -> list[moderngl.Texture]:
|
||||||
textures = []
|
textures = []
|
||||||
for img_tensor in image_list[:MAX_IMAGES]:
|
for img_tensor in image_list[:MAX_IMAGES]:
|
||||||
img_idx = min(batch_idx, img_tensor.shape[0] - 1)
|
img_idx = min(batch_idx, img_tensor.shape[0] - 1)
|
||||||
img_np = img_tensor[img_idx].cpu().numpy()
|
img_np = img_tensor[img_idx].cpu().numpy()
|
||||||
textures.append(_image_to_texture(ctx, img_np))
|
textures.append(_image_to_texture(ctx, img_np))
|
||||||
return textures
|
return textures
|
||||||
|
|
||||||
|
|
||||||
def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
|
def _prepare_uniforms(int_list: list[int], float_list: list[float]) -> dict[str, int | float]:
|
||||||
uniforms: dict[str, int | float] = {}
|
uniforms: dict[str, int | float] = {}
|
||||||
for i, val in enumerate(int_list[:MAX_UNIFORMS]):
|
for i, val in enumerate(int_list[:MAX_UNIFORMS]):
|
||||||
uniforms[f'u_int{i}'] = int(val)
|
uniforms[f'u_int{i}'] = int(val)
|
||||||
for i, val in enumerate(float_list[:MAX_UNIFORMS]):
|
for i, val in enumerate(float_list[:MAX_UNIFORMS]):
|
||||||
uniforms[f'u_float{i}'] = float(val)
|
uniforms[f'u_float{i}'] = float(val)
|
||||||
return uniforms
|
return uniforms
|
||||||
|
|
||||||
|
|
||||||
def _release_textures(textures: list[moderngl.Texture]) -> None:
|
def _release_textures(textures: list[moderngl.Texture]) -> None:
|
||||||
for texture in textures:
|
for texture in textures:
|
||||||
texture.release()
|
texture.release()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
|
def _gl_context(force_software: bool = False) -> Generator[moderngl.Context, None, None]:
|
||||||
ctx = _create_gl_context(force_software)
|
ctx = _create_gl_context(force_software)
|
||||||
try:
|
try:
|
||||||
yield ctx
|
yield ctx
|
||||||
finally:
|
finally:
|
||||||
ctx.release()
|
ctx.release()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
|
def _shader_program(ctx: moderngl.Context, fragment_source: str) -> Generator[moderngl.Program, None, None]:
|
||||||
program = _compile_shader(ctx, fragment_source)
|
program = _compile_shader(ctx, fragment_source)
|
||||||
try:
|
try:
|
||||||
yield program
|
yield program
|
||||||
finally:
|
finally:
|
||||||
program.release()
|
program.release()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _textures_context(
|
def _textures_context(
|
||||||
ctx: moderngl.Context,
|
ctx: moderngl.Context,
|
||||||
image_list: list[torch.Tensor],
|
image_list: list[torch.Tensor],
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
) -> Generator[list[moderngl.Texture], None, None]:
|
) -> Generator[list[moderngl.Texture], None, None]:
|
||||||
textures = _prepare_textures(ctx, image_list, batch_idx)
|
textures = _prepare_textures(ctx, image_list, batch_idx)
|
||||||
try:
|
try:
|
||||||
yield textures
|
yield textures
|
||||||
finally:
|
finally:
|
||||||
_release_textures(textures)
|
_release_textures(textures)
|
||||||
|
|
||||||
|
|
||||||
class GLSLShader(io.ComfyNode):
|
class GLSLShader(io.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
# Create autogrow templates
|
# Create autogrow templates
|
||||||
image_template = io.Autogrow.TemplatePrefix(
|
image_template = io.Autogrow.TemplatePrefix(
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
prefix="image",
|
prefix="image",
|
||||||
min=1,
|
min=1,
|
||||||
max=MAX_IMAGES,
|
max=MAX_IMAGES,
|
||||||
)
|
)
|
||||||
|
|
||||||
float_template = io.Autogrow.TemplatePrefix(
|
float_template = io.Autogrow.TemplatePrefix(
|
||||||
io.Float.Input("float", default=0.0),
|
io.Float.Input("float", default=0.0),
|
||||||
prefix="u_float",
|
prefix="u_float",
|
||||||
min=0,
|
min=0,
|
||||||
max=MAX_UNIFORMS,
|
max=MAX_UNIFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
int_template = io.Autogrow.TemplatePrefix(
|
int_template = io.Autogrow.TemplatePrefix(
|
||||||
io.Int.Input("int", default=0),
|
io.Int.Input("int", default=0),
|
||||||
prefix="u_int",
|
prefix="u_int",
|
||||||
min=0,
|
min=0,
|
||||||
max=MAX_UNIFORMS,
|
max=MAX_UNIFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="GLSLShader",
|
node_id="GLSLShader",
|
||||||
display_name="GLSL Shader",
|
display_name="GLSL Shader",
|
||||||
category="image/shader",
|
category="image/shader",
|
||||||
description=(
|
description=(
|
||||||
f"Apply GLSL fragment shaders to images. "
|
f"Apply GLSL fragment shaders to images. "
|
||||||
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
|
f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
|
||||||
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
|
f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
|
||||||
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
|
f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
|
||||||
),
|
),
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input(
|
io.String.Input(
|
||||||
"fragment_shader",
|
"fragment_shader",
|
||||||
default=DEFAULT_FRAGMENT_SHADER,
|
default=DEFAULT_FRAGMENT_SHADER,
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
|
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
|
||||||
),
|
),
|
||||||
io.DynamicCombo.Input(
|
io.DynamicCombo.Input(
|
||||||
"size_mode",
|
"size_mode",
|
||||||
options=[
|
options=[
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option(
|
||||||
"from_input",
|
"from_input",
|
||||||
[], # No extra inputs - uses first input image dimensions
|
[], # No extra inputs - uses first input image dimensions
|
||||||
),
|
),
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option(
|
||||||
"custom",
|
"custom",
|
||||||
[
|
[
|
||||||
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
io.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||||
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
io.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
|
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
|
||||||
),
|
),
|
||||||
io.Autogrow.Input("images", template=image_template),
|
io.Autogrow.Input("images", template=image_template),
|
||||||
io.Autogrow.Input("floats", template=float_template),
|
io.Autogrow.Input("floats", template=float_template),
|
||||||
io.Autogrow.Input("ints", template=int_template),
|
io.Autogrow.Input("ints", template=int_template),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(display_name="IMAGE0"),
|
io.Image.Output(display_name="IMAGE0"),
|
||||||
io.Image.Output(display_name="IMAGE1"),
|
io.Image.Output(display_name="IMAGE1"),
|
||||||
io.Image.Output(display_name="IMAGE2"),
|
io.Image.Output(display_name="IMAGE2"),
|
||||||
io.Image.Output(display_name="IMAGE3"),
|
io.Image.Output(display_name="IMAGE3"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(
|
def execute(
|
||||||
cls,
|
cls,
|
||||||
fragment_shader: str,
|
fragment_shader: str,
|
||||||
size_mode: SizeModeInput,
|
size_mode: SizeModeInput,
|
||||||
images: io.Autogrow.Type,
|
images: io.Autogrow.Type,
|
||||||
floats: io.Autogrow.Type = None,
|
floats: io.Autogrow.Type = None,
|
||||||
ints: io.Autogrow.Type = None,
|
ints: io.Autogrow.Type = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
image_list = [v for v in images.values() if v is not None]
|
image_list = [v for v in images.values() if v is not None]
|
||||||
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
|
float_list = [v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||||
|
|
||||||
if not image_list:
|
if not image_list:
|
||||||
raise ValueError("At least one input image is required")
|
raise ValueError("At least one input image is required")
|
||||||
|
|
||||||
# Determine output dimensions
|
# Determine output dimensions
|
||||||
if size_mode["size_mode"] == "custom":
|
if size_mode["size_mode"] == "custom":
|
||||||
out_width, out_height = size_mode["width"], size_mode["height"]
|
out_width, out_height = size_mode["width"], size_mode["height"]
|
||||||
else:
|
else:
|
||||||
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]
|
out_height, out_width = image_list[0].shape[1], image_list[0].shape[2]
|
||||||
|
|
||||||
batch_size = image_list[0].shape[0]
|
batch_size = image_list[0].shape[0]
|
||||||
uniforms = _prepare_uniforms(int_list, float_list)
|
uniforms = _prepare_uniforms(int_list, float_list)
|
||||||
|
|
||||||
with _gl_context(force_software=args.cpu) as ctx:
|
with _gl_context(force_software=args.cpu) as ctx:
|
||||||
with _shader_program(ctx, fragment_shader) as program:
|
with _shader_program(ctx, fragment_shader) as program:
|
||||||
# Collect outputs for each render target across all batches
|
# Collect outputs for each render target across all batches
|
||||||
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]
|
all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
with _textures_context(ctx, image_list, b) as textures:
|
with _textures_context(ctx, image_list, b) as textures:
|
||||||
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
|
results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
all_outputs[i].append(torch.from_numpy(result))
|
all_outputs[i].append(torch.from_numpy(result))
|
||||||
|
|
||||||
# Stack batches for each output
|
# Stack batches for each output
|
||||||
output_values = []
|
output_values = []
|
||||||
for i in range(MAX_OUTPUTS):
|
for i in range(MAX_OUTPUTS):
|
||||||
output_batch = torch.stack(all_outputs[i], dim=0)
|
output_batch = torch.stack(all_outputs[i], dim=0)
|
||||||
output_values.append(output_batch)
|
output_values.append(output_batch)
|
||||||
|
|
||||||
return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))
|
return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
|
def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]:
|
||||||
"""Build UI output with input and output images for client-side shader execution."""
|
"""Build UI output with input and output images for client-side shader execution."""
|
||||||
combined_inputs = torch.cat(image_list, dim=0)
|
combined_inputs = torch.cat(image_list, dim=0)
|
||||||
input_images_ui = ui.ImageSaveHelper.save_images(
|
input_images_ui = ui.ImageSaveHelper.save_images(
|
||||||
combined_inputs,
|
combined_inputs,
|
||||||
filename_prefix="GLSLShader_input",
|
filename_prefix="GLSLShader_input",
|
||||||
folder_type=io.FolderType.temp,
|
folder_type=io.FolderType.temp,
|
||||||
cls=None,
|
cls=None,
|
||||||
compress_level=1,
|
compress_level=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_images_ui = ui.ImageSaveHelper.save_images(
|
output_images_ui = ui.ImageSaveHelper.save_images(
|
||||||
output_batch,
|
output_batch,
|
||||||
filename_prefix="GLSLShader_output",
|
filename_prefix="GLSLShader_output",
|
||||||
folder_type=io.FolderType.temp,
|
folder_type=io.FolderType.temp,
|
||||||
cls=None,
|
cls=None,
|
||||||
compress_level=1,
|
compress_level=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"input_images": input_images_ui, "images": output_images_ui}
|
return {"input_images": input_images_ui, "images": output_images_ui}
|
||||||
|
|
||||||
|
|
||||||
class GLSLExtension(ComfyExtension):
|
class GLSLExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [GLSLShader]
|
return [GLSLShader]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> GLSLExtension:
|
async def comfy_entrypoint() -> GLSLExtension:
|
||||||
return GLSLExtension()
|
return GLSLExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user