feat: add curve inputs and raise uniform limit for GLSL shader node (#13158)

* feat: add curve inputs and raise uniform limit for GLSL shader node

* allow arbitrary size for curve
This commit is contained in:
Terry Jia 2026-03-26 21:45:05 -04:00 committed by GitHub
parent 359559c913
commit 1dc64f3526
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 200 additions and 1 deletions

View File

@ -0,0 +1,90 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 rgb2hsl(vec3 c) {
float maxC = max(c.r, max(c.g, c.b));
float minC = min(c.r, min(c.g, c.b));
float l = (maxC + minC) * 0.5;
if (maxC == minC) return vec3(0.0, 0.0, l);
float d = maxC - minC;
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
float h;
if (maxC == c.r) {
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / d + 2.0;
} else {
h = (c.r - c.g) / d + 4.0;
}
h /= 6.0;
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
if (t < 0.0) t += 1.0;
if (t > 1.0) t -= 1.0;
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
if (t < 1.0 / 2.0) return q;
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
float h = hsl.x, s = hsl.y, l = hsl.z;
if (s == 0.0) return vec3(l);
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
float p = 2.0 * l - q;
return vec3(
hue2rgb(p, q, h + 1.0 / 3.0),
hue2rgb(p, q, h),
hue2rgb(p, q, h - 1.0 / 3.0)
);
}
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float lightness = (maxC + minC) * 0.5;
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
const float a = 0.25;
const float b = 0.333;
const float scale = 0.7;
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
color += sw * shadows + mw * midtones + hw * highlights;
if (u_bool0) {
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
hsl.z = lightness;
color = hsl2rgb(hsl);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@ -0,0 +1,46 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
// index = value * (n_samples - 1)
// f = fract(index)
// result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
value = clamp(value, 0.0, 1.0);
float pos = value * 255.0;
int lo = int(floor(pos));
int hi = min(lo + 1, 255);
float f = pos - float(lo);
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
return a + f * (b - a);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) )
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
fragColor0 = vec4(color.rgb, color.a);
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
MAX_IMAGES = 5 # u_image0-4 MAX_IMAGES = 5 # u_image0-4
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
MAX_BOOLS = 10 # u_bool0-9
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
MAX_OUTPUTS = 4 # fragColor0-3 (MRT) MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
# Vertex shader using gl_VertexID trick - no VBO needed. # Vertex shader using gl_VertexID trick - no VBO needed.
@ -497,6 +499,8 @@ def _render_shader_batch(
image_batches: list[list[np.ndarray]], image_batches: list[list[np.ndarray]],
floats: list[float], floats: list[float],
ints: list[int], ints: list[int],
bools: list[bool] | None = None,
curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]: ) -> list[list[np.ndarray]]:
""" """
Render a fragment shader for multiple batches efficiently. Render a fragment shader for multiple batches efficiently.
@ -511,6 +515,8 @@ def _render_shader_batch(
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1] image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
floats: List of float uniforms floats: List of float uniforms
ints: List of int uniforms ints: List of int uniforms
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
Returns: Returns:
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1] List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
@ -533,11 +539,17 @@ def _render_shader_batch(
# Detect multi-pass rendering # Detect multi-pass rendering
num_passes = _detect_pass_count(fragment_code) num_passes = _detect_pass_count(fragment_code)
if bools is None:
bools = []
if curves is None:
curves = []
# Track resources for cleanup # Track resources for cleanup
program = None program = None
fbo = None fbo = None
output_textures = [] output_textures = []
input_textures = [] input_textures = []
curve_textures = []
ping_pong_textures = [] ping_pong_textures = []
ping_pong_fbos = [] ping_pong_fbos = []
@ -624,6 +636,28 @@ def _render_shader_batch(
if loc >= 0: if loc >= 0:
gl.glUniform1i(loc, v) gl.glUniform1i(loc, v)
for i, v in enumerate(bools):
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
if loc >= 0:
gl.glUniform1i(loc, 1 if v else 0)
# Create 1D LUT textures for curves (bound after image texture units)
for i, lut in enumerate(curves):
tex = gl.glGenTextures(1)
curve_textures.append(tex)
unit = MAX_IMAGES + i
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
if loc >= 0:
gl.glUniform1i(loc, unit)
# Get u_pass uniform location for multi-pass # Get u_pass uniform location for multi-pass
pass_loc = gl.glGetUniformLocation(program, "u_pass") pass_loc = gl.glGetUniformLocation(program, "u_pass")
@ -718,6 +752,8 @@ def _render_shader_batch(
for tex in input_textures: for tex in input_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(int(tex))
for tex in curve_textures:
gl.glDeleteTextures(int(tex))
for tex in output_textures: for tex in output_textures:
gl.glDeleteTextures(int(tex)) gl.glDeleteTextures(int(tex))
for tex in ping_pong_textures: for tex in ping_pong_textures:
@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
max=MAX_UNIFORMS, max=MAX_UNIFORMS,
) )
bool_template = io.Autogrow.TemplatePrefix(
io.Boolean.Input("bool", default=False),
prefix="u_bool",
min=0,
max=MAX_BOOLS,
)
curve_template = io.Autogrow.TemplatePrefix(
io.Curve.Input("curve"),
prefix="u_curve",
min=0,
max=MAX_CURVES,
)
return io.Schema( return io.Schema(
node_id="GLSLShader", node_id="GLSLShader",
display_name="GLSL Shader", display_name="GLSL Shader",
@ -762,6 +812,7 @@ class GLSLShader(io.ComfyNode):
"Apply GLSL ES fragment shaders to images. " "Apply GLSL ES fragment shaders to images. "
"u_resolution (vec2) is always available." "u_resolution (vec2) is always available."
), ),
is_experimental=True,
inputs=[ inputs=[
io.String.Input( io.String.Input(
"fragment_shader", "fragment_shader",
@ -796,6 +847,8 @@ class GLSLShader(io.ComfyNode):
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"), io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"), io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"), io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
], ],
outputs=[ outputs=[
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"), io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
@ -813,13 +866,19 @@ class GLSLShader(io.ComfyNode):
images: io.Autogrow.Type, images: io.Autogrow.Type,
floats: io.Autogrow.Type = None, floats: io.Autogrow.Type = None,
ints: io.Autogrow.Type = None, ints: io.Autogrow.Type = None,
bools: io.Autogrow.Type = None,
curves: io.Autogrow.Type = None,
**kwargs, **kwargs,
) -> io.NodeOutput: ) -> io.NodeOutput:
image_list = [v for v in images.values() if v is not None] image_list = [v for v in images.values() if v is not None]
float_list = ( float_list = (
[v if v is not None else 0.0 for v in floats.values()] if floats else [] [v if v is not None else 0.0 for v in floats.values()] if floats else []
) )
int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
if not image_list: if not image_list:
raise ValueError("At least one input image is required") raise ValueError("At least one input image is required")
@ -846,6 +905,8 @@ class GLSLShader(io.ComfyNode):
image_batches, image_batches,
float_list, float_list,
int_list, int_list,
bool_list,
curve_luts,
) )
# Collect outputs into tensors # Collect outputs into tensors