mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
feat: add curve inputs and raise uniform limit for GLSL shader node (#13158)
* feat: add curve inputs and raise uniform limit for GLSL shader node * allow arbitrary size for curve
This commit is contained in:
parent
359559c913
commit
1dc64f3526
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0;
|
||||||
|
uniform float u_float1;
|
||||||
|
uniform float u_float2;
|
||||||
|
uniform float u_float3;
|
||||||
|
uniform float u_float4;
|
||||||
|
uniform float u_float5;
|
||||||
|
uniform float u_float6;
|
||||||
|
uniform float u_float7;
|
||||||
|
uniform float u_float8;
|
||||||
|
uniform bool u_bool0;
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
vec3 rgb2hsl(vec3 c) {
|
||||||
|
float maxC = max(c.r, max(c.g, c.b));
|
||||||
|
float minC = min(c.r, min(c.g, c.b));
|
||||||
|
float l = (maxC + minC) * 0.5;
|
||||||
|
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||||
|
float d = maxC - minC;
|
||||||
|
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||||
|
float h;
|
||||||
|
if (maxC == c.r) {
|
||||||
|
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||||
|
} else if (maxC == c.g) {
|
||||||
|
h = (c.b - c.r) / d + 2.0;
|
||||||
|
} else {
|
||||||
|
h = (c.r - c.g) / d + 4.0;
|
||||||
|
}
|
||||||
|
h /= 6.0;
|
||||||
|
return vec3(h, s, l);
|
||||||
|
}
|
||||||
|
|
||||||
|
float hue2rgb(float p, float q, float t) {
|
||||||
|
if (t < 0.0) t += 1.0;
|
||||||
|
if (t > 1.0) t -= 1.0;
|
||||||
|
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||||
|
if (t < 1.0 / 2.0) return q;
|
||||||
|
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsl2rgb(vec3 hsl) {
|
||||||
|
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||||
|
if (s == 0.0) return vec3(l);
|
||||||
|
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||||
|
float p = 2.0 * l - q;
|
||||||
|
return vec3(
|
||||||
|
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||||
|
hue2rgb(p, q, h),
|
||||||
|
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 tex = texture(u_image0, v_texCoord);
|
||||||
|
vec3 color = tex.rgb;
|
||||||
|
|
||||||
|
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||||
|
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||||
|
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||||
|
|
||||||
|
float maxC = max(color.r, max(color.g, color.b));
|
||||||
|
float minC = min(color.r, min(color.g, color.b));
|
||||||
|
float lightness = (maxC + minC) * 0.5;
|
||||||
|
|
||||||
|
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||||
|
const float a = 0.25;
|
||||||
|
const float b = 0.333;
|
||||||
|
const float scale = 0.7;
|
||||||
|
|
||||||
|
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||||
|
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
|
||||||
|
color += sw * shadows + mw * midtones + hw * highlights;
|
||||||
|
|
||||||
|
if (u_bool0) {
|
||||||
|
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||||
|
hsl.z = lightness;
|
||||||
|
color = hsl2rgb(hsl);
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||||
|
}
|
||||||
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||||
|
uniform sampler2D u_curve1; // Red channel curve
|
||||||
|
uniform sampler2D u_curve2; // Green channel curve
|
||||||
|
uniform sampler2D u_curve3; // Blue channel curve
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||||
|
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||||
|
// index = value * (n_samples - 1)
|
||||||
|
// f = fract(index)
|
||||||
|
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||||
|
//
|
||||||
|
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||||
|
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||||
|
float applyCurve(sampler2D curve, float value) {
|
||||||
|
value = clamp(value, 0.0, 1.0);
|
||||||
|
|
||||||
|
float pos = value * 255.0;
|
||||||
|
int lo = int(floor(pos));
|
||||||
|
int hi = min(lo + 1, 255);
|
||||||
|
float f = pos - float(lo);
|
||||||
|
|
||||||
|
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||||
|
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||||
|
|
||||||
|
return a + f * (b - a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
// GIMP order: per-channel curves first, then RGB master curve.
|
||||||
|
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||||
|
// dest = colors_curve( channel_curve( src ) )
|
||||||
|
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
|
||||||
|
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
|
||||||
|
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
|
||||||
|
|
||||||
|
fragColor0 = vec4(color.rgb, color.a);
|
||||||
|
}
|
||||||
1
blueprints/Color Balance.json
Normal file
1
blueprints/Color Balance.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Curves.json
Normal file
1
blueprints/Color Curves.json
Normal file
File diff suppressed because one or more lines are too long
@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
MAX_IMAGES = 5 # u_image0-4
|
MAX_IMAGES = 5 # u_image0-4
|
||||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||||
|
MAX_BOOLS = 10 # u_bool0-9
|
||||||
|
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||||
|
|
||||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||||
@ -497,6 +499,8 @@ def _render_shader_batch(
|
|||||||
image_batches: list[list[np.ndarray]],
|
image_batches: list[list[np.ndarray]],
|
||||||
floats: list[float],
|
floats: list[float],
|
||||||
ints: list[int],
|
ints: list[int],
|
||||||
|
bools: list[bool] | None = None,
|
||||||
|
curves: list[np.ndarray] | None = None,
|
||||||
) -> list[list[np.ndarray]]:
|
) -> list[list[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
Render a fragment shader for multiple batches efficiently.
|
Render a fragment shader for multiple batches efficiently.
|
||||||
@ -511,6 +515,8 @@ def _render_shader_batch(
|
|||||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||||
floats: List of float uniforms
|
floats: List of float uniforms
|
||||||
ints: List of int uniforms
|
ints: List of int uniforms
|
||||||
|
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||||
|
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||||
@ -533,11 +539,17 @@ def _render_shader_batch(
|
|||||||
# Detect multi-pass rendering
|
# Detect multi-pass rendering
|
||||||
num_passes = _detect_pass_count(fragment_code)
|
num_passes = _detect_pass_count(fragment_code)
|
||||||
|
|
||||||
|
if bools is None:
|
||||||
|
bools = []
|
||||||
|
if curves is None:
|
||||||
|
curves = []
|
||||||
|
|
||||||
# Track resources for cleanup
|
# Track resources for cleanup
|
||||||
program = None
|
program = None
|
||||||
fbo = None
|
fbo = None
|
||||||
output_textures = []
|
output_textures = []
|
||||||
input_textures = []
|
input_textures = []
|
||||||
|
curve_textures = []
|
||||||
ping_pong_textures = []
|
ping_pong_textures = []
|
||||||
ping_pong_fbos = []
|
ping_pong_fbos = []
|
||||||
|
|
||||||
@ -624,6 +636,28 @@ def _render_shader_batch(
|
|||||||
if loc >= 0:
|
if loc >= 0:
|
||||||
gl.glUniform1i(loc, v)
|
gl.glUniform1i(loc, v)
|
||||||
|
|
||||||
|
for i, v in enumerate(bools):
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, 1 if v else 0)
|
||||||
|
|
||||||
|
# Create 1D LUT textures for curves (bound after image texture units)
|
||||||
|
for i, lut in enumerate(curves):
|
||||||
|
tex = gl.glGenTextures(1)
|
||||||
|
curve_textures.append(tex)
|
||||||
|
unit = MAX_IMAGES + i
|
||||||
|
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, unit)
|
||||||
|
|
||||||
# Get u_pass uniform location for multi-pass
|
# Get u_pass uniform location for multi-pass
|
||||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||||
|
|
||||||
@ -718,6 +752,8 @@ def _render_shader_batch(
|
|||||||
|
|
||||||
for tex in input_textures:
|
for tex in input_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(int(tex))
|
||||||
|
for tex in curve_textures:
|
||||||
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in output_textures:
|
for tex in output_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in ping_pong_textures:
|
for tex in ping_pong_textures:
|
||||||
@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
|
|||||||
max=MAX_UNIFORMS,
|
max=MAX_UNIFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bool_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Boolean.Input("bool", default=False),
|
||||||
|
prefix="u_bool",
|
||||||
|
min=0,
|
||||||
|
max=MAX_BOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
curve_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
prefix="u_curve",
|
||||||
|
min=0,
|
||||||
|
max=MAX_CURVES,
|
||||||
|
)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="GLSLShader",
|
node_id="GLSLShader",
|
||||||
display_name="GLSL Shader",
|
display_name="GLSL Shader",
|
||||||
@ -762,6 +812,7 @@ class GLSLShader(io.ComfyNode):
|
|||||||
"Apply GLSL ES fragment shaders to images. "
|
"Apply GLSL ES fragment shaders to images. "
|
||||||
"u_resolution (vec2) is always available."
|
"u_resolution (vec2) is always available."
|
||||||
),
|
),
|
||||||
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input(
|
io.String.Input(
|
||||||
"fragment_shader",
|
"fragment_shader",
|
||||||
@ -796,6 +847,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
|
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||||
|
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||||
@ -813,13 +866,19 @@ class GLSLShader(io.ComfyNode):
|
|||||||
images: io.Autogrow.Type,
|
images: io.Autogrow.Type,
|
||||||
floats: io.Autogrow.Type = None,
|
floats: io.Autogrow.Type = None,
|
||||||
ints: io.Autogrow.Type = None,
|
ints: io.Autogrow.Type = None,
|
||||||
|
bools: io.Autogrow.Type = None,
|
||||||
|
curves: io.Autogrow.Type = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
|
|
||||||
image_list = [v for v in images.values() if v is not None]
|
image_list = [v for v in images.values() if v is not None]
|
||||||
float_list = (
|
float_list = (
|
||||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||||
)
|
)
|
||||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||||
|
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||||
|
|
||||||
|
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||||
|
|
||||||
if not image_list:
|
if not image_list:
|
||||||
raise ValueError("At least one input image is required")
|
raise ValueError("At least one input image is required")
|
||||||
@ -846,6 +905,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
image_batches,
|
image_batches,
|
||||||
float_list,
|
float_list,
|
||||||
int_list,
|
int_list,
|
||||||
|
bool_list,
|
||||||
|
curve_luts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect outputs into tensors
|
# Collect outputs into tensors
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user