mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-12 19:42:37 +08:00
color correct
This commit is contained in:
parent
85b8ee1390
commit
30c87e2a37
@ -1203,6 +1203,70 @@ class Color(ComfyTypeIO):
|
|||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return super().as_dict()
|
return super().as_dict()
|
||||||
|
|
||||||
|
@comfytype(io_type="COLOR_CORRECT")
|
||||||
|
class ColorCorrect(ComfyTypeIO):
|
||||||
|
Type = dict
|
||||||
|
|
||||||
|
class Input(WidgetInput):
|
||||||
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
|
socketless: bool=True, default: dict=None, advanced: bool=None):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||||
|
if default is None:
|
||||||
|
self.default = {
|
||||||
|
"temperature": 0,
|
||||||
|
"hue": 0,
|
||||||
|
"brightness": 0,
|
||||||
|
"contrast": 0,
|
||||||
|
"saturation": 0,
|
||||||
|
"gamma": 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return super().as_dict()
|
||||||
|
|
||||||
|
@comfytype(io_type="COLOR_BALANCE")
|
||||||
|
class ColorBalance(ComfyTypeIO):
|
||||||
|
Type = dict
|
||||||
|
|
||||||
|
class Input(WidgetInput):
|
||||||
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
|
socketless: bool=True, default: dict=None, advanced: bool=None):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||||
|
if default is None:
|
||||||
|
self.default = {
|
||||||
|
"shadows_red": 0,
|
||||||
|
"shadows_green": 0,
|
||||||
|
"shadows_blue": 0,
|
||||||
|
"midtones_red": 0,
|
||||||
|
"midtones_green": 0,
|
||||||
|
"midtones_blue": 0,
|
||||||
|
"highlights_red": 0,
|
||||||
|
"highlights_green": 0,
|
||||||
|
"highlights_blue": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return super().as_dict()
|
||||||
|
|
||||||
|
@comfytype(io_type="COLOR_CURVES")
|
||||||
|
class ColorCurves(ComfyTypeIO):
|
||||||
|
Type = dict
|
||||||
|
|
||||||
|
class Input(WidgetInput):
|
||||||
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
|
socketless: bool=True, default: dict=None, advanced: bool=None):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||||
|
if default is None:
|
||||||
|
self.default = {
|
||||||
|
"rgb": [[0, 0], [1, 1]],
|
||||||
|
"red": [[0, 0], [1, 1]],
|
||||||
|
"green": [[0, 0], [1, 1]],
|
||||||
|
"blue": [[0, 0], [1, 1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return super().as_dict()
|
||||||
|
|
||||||
@comfytype(io_type="BOUNDING_BOX")
|
@comfytype(io_type="BOUNDING_BOX")
|
||||||
class BoundingBox(ComfyTypeIO):
|
class BoundingBox(ComfyTypeIO):
|
||||||
Type = dict
|
Type = dict
|
||||||
@ -2141,4 +2205,7 @@ __all__ = [
|
|||||||
"PriceBadgeDepends",
|
"PriceBadgeDepends",
|
||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
|
"ColorCorrect",
|
||||||
|
"ColorBalance",
|
||||||
|
"ColorCurves"
|
||||||
]
|
]
|
||||||
|
|||||||
78
comfy_extras/nodes_color_balance.py
Normal file
78
comfy_extras/nodes_color_balance.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
|
|
||||||
|
|
||||||
|
def _smoothstep(edge0: float, edge1: float, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
t = torch.clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0)
|
||||||
|
return t * t * (3.0 - 2.0 * t)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorBalanceNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorBalance",
|
||||||
|
display_name="Color Balance",
|
||||||
|
category="image/adjustment",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.ColorBalance.Input("settings"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
|
||||||
|
shadows_red = settings.get("shadows_red", 0)
|
||||||
|
shadows_green = settings.get("shadows_green", 0)
|
||||||
|
shadows_blue = settings.get("shadows_blue", 0)
|
||||||
|
midtones_red = settings.get("midtones_red", 0)
|
||||||
|
midtones_green = settings.get("midtones_green", 0)
|
||||||
|
midtones_blue = settings.get("midtones_blue", 0)
|
||||||
|
highlights_red = settings.get("highlights_red", 0)
|
||||||
|
highlights_green = settings.get("highlights_green", 0)
|
||||||
|
highlights_blue = settings.get("highlights_blue", 0)
|
||||||
|
|
||||||
|
result = image.clone().float()
|
||||||
|
|
||||||
|
# Compute per-pixel luminance
|
||||||
|
luminance = (
|
||||||
|
0.2126 * result[..., 0]
|
||||||
|
+ 0.7152 * result[..., 1]
|
||||||
|
+ 0.0722 * result[..., 2]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute tonal range weights
|
||||||
|
shadow_weight = 1.0 - _smoothstep(0.0, 0.5, luminance)
|
||||||
|
highlight_weight = _smoothstep(0.5, 1.0, luminance)
|
||||||
|
midtone_weight = 1.0 - shadow_weight - highlight_weight
|
||||||
|
|
||||||
|
# Apply offsets per channel
|
||||||
|
for ch, (s, m, h) in enumerate([
|
||||||
|
(shadows_red, midtones_red, highlights_red),
|
||||||
|
(shadows_green, midtones_green, highlights_green),
|
||||||
|
(shadows_blue, midtones_blue, highlights_blue),
|
||||||
|
]):
|
||||||
|
offset = (
|
||||||
|
shadow_weight * (s / 100.0)
|
||||||
|
+ midtone_weight * (m / 100.0)
|
||||||
|
+ highlight_weight * (h / 100.0)
|
||||||
|
)
|
||||||
|
result[..., ch] = result[..., ch] + offset
|
||||||
|
|
||||||
|
result = torch.clamp(result, 0, 1)
|
||||||
|
return io.NodeOutput(result, ui=ui.PreviewImage(result))
|
||||||
|
|
||||||
|
|
||||||
|
class ColorBalanceExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [ColorBalanceNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ColorBalanceExtension:
|
||||||
|
return ColorBalanceExtension()
|
||||||
88
comfy_extras/nodes_color_correct.py
Normal file
88
comfy_extras/nodes_color_correct.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCorrectNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorCorrect",
|
||||||
|
display_name="Color Correct",
|
||||||
|
category="image/adjustment",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.ColorCorrect.Input("settings"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
|
||||||
|
temperature = settings.get("temperature", 0)
|
||||||
|
hue = settings.get("hue", 0)
|
||||||
|
brightness = settings.get("brightness", 0)
|
||||||
|
contrast = settings.get("contrast", 0)
|
||||||
|
saturation = settings.get("saturation", 0)
|
||||||
|
gamma = settings.get("gamma", 1.0)
|
||||||
|
|
||||||
|
result = image.clone()
|
||||||
|
|
||||||
|
# Brightness: scale RGB values
|
||||||
|
if brightness != 0:
|
||||||
|
factor = 1.0 + brightness / 100.0
|
||||||
|
result = result * factor
|
||||||
|
|
||||||
|
# Contrast: adjust around midpoint
|
||||||
|
if contrast != 0:
|
||||||
|
factor = 1.0 + contrast / 100.0
|
||||||
|
mean = result[..., :3].mean()
|
||||||
|
result[..., :3] = (result[..., :3] - mean) * factor + mean
|
||||||
|
|
||||||
|
# Temperature: shift warm (red+) / cool (blue+)
|
||||||
|
if temperature != 0:
|
||||||
|
temp_factor = temperature / 100.0
|
||||||
|
result[..., 0] = result[..., 0] + temp_factor * 0.1 # Red
|
||||||
|
result[..., 2] = result[..., 2] - temp_factor * 0.1 # Blue
|
||||||
|
|
||||||
|
# Gamma correction
|
||||||
|
if gamma != 1.0:
|
||||||
|
result[..., :3] = torch.pow(torch.clamp(result[..., :3], 0, 1), 1.0 / gamma)
|
||||||
|
|
||||||
|
# Saturation: convert to HSV-like space
|
||||||
|
if saturation != 0:
|
||||||
|
factor = 1.0 + saturation / 100.0
|
||||||
|
gray = result[..., :3].mean(dim=-1, keepdim=True)
|
||||||
|
result[..., :3] = gray + (result[..., :3] - gray) * factor
|
||||||
|
|
||||||
|
# Hue rotation: rotate in RGB space using rotation matrix
|
||||||
|
if hue != 0:
|
||||||
|
angle = np.radians(hue)
|
||||||
|
cos_a = np.cos(angle)
|
||||||
|
sin_a = np.sin(angle)
|
||||||
|
# Rodrigues' rotation formula around (1,1,1)/sqrt(3) axis
|
||||||
|
k = 1.0 / 3.0
|
||||||
|
rotation = torch.tensor([
|
||||||
|
[cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3)],
|
||||||
|
[k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a), k * (1 - cos_a) - sin_a / np.sqrt(3)],
|
||||||
|
[k * (1 - cos_a) - sin_a / np.sqrt(3), k * (1 - cos_a) + sin_a / np.sqrt(3), cos_a + k * (1 - cos_a)]
|
||||||
|
], dtype=result.dtype, device=result.device)
|
||||||
|
rgb = result[..., :3]
|
||||||
|
result[..., :3] = torch.matmul(rgb, rotation.T)
|
||||||
|
|
||||||
|
result = torch.clamp(result, 0, 1)
|
||||||
|
return io.NodeOutput(result, ui=ui.PreviewImage(result))
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCorrectExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [ColorCorrectNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ColorCorrectExtension:
|
||||||
|
return ColorCorrectExtension()
|
||||||
137
comfy_extras/nodes_color_curves.py
Normal file
137
comfy_extras/nodes_color_curves.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
|
|
||||||
|
|
||||||
|
def _monotone_cubic_hermite(xs, ys, x_query):
|
||||||
|
"""Evaluate monotone cubic Hermite interpolation at x_query points."""
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(x_query)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(x_query, ys[0])
|
||||||
|
|
||||||
|
# Compute slopes
|
||||||
|
deltas = np.diff(ys) / np.maximum(np.diff(xs), 1e-10)
|
||||||
|
|
||||||
|
# Compute tangents (Fritsch-Carlson)
|
||||||
|
slopes = np.zeros(n)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
# Enforce monotonicity
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0
|
||||||
|
slopes[i + 1] = 0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha ** 2 + beta ** 2
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / np.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
result = np.zeros_like(x_query, dtype=np.float64)
|
||||||
|
indices = np.searchsorted(xs, x_query, side='right') - 1
|
||||||
|
indices = np.clip(indices, 0, n - 2)
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
mask = indices == i
|
||||||
|
if not np.any(mask):
|
||||||
|
continue
|
||||||
|
dx = xs[i + 1] - xs[i]
|
||||||
|
if dx == 0:
|
||||||
|
result[mask] = ys[i]
|
||||||
|
continue
|
||||||
|
t = (x_query[mask] - xs[i]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
result[mask] = h00 * ys[i] + h10 * dx * slopes[i] + h01 * ys[i + 1] + h11 * dx * slopes[i + 1]
|
||||||
|
|
||||||
|
# Clamp edges
|
||||||
|
result[x_query <= xs[0]] = ys[0]
|
||||||
|
result[x_query >= xs[-1]] = ys[-1]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _build_lut(points):
|
||||||
|
"""Build a 256-entry LUT from curve control points in [0,1] space."""
|
||||||
|
if not points or len(points) < 2:
|
||||||
|
return np.arange(256, dtype=np.float64) / 255.0
|
||||||
|
|
||||||
|
pts = sorted(points, key=lambda p: p[0])
|
||||||
|
xs = np.array([p[0] for p in pts], dtype=np.float64)
|
||||||
|
ys = np.array([p[1] for p in pts], dtype=np.float64)
|
||||||
|
|
||||||
|
x_query = np.linspace(0, 1, 256)
|
||||||
|
lut = _monotone_cubic_hermite(xs, ys, x_query)
|
||||||
|
return np.clip(lut, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCurvesNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorCurves",
|
||||||
|
display_name="Color Curves",
|
||||||
|
category="image/adjustment",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.ColorCurves.Input("settings"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image: torch.Tensor, settings: dict) -> io.NodeOutput:
|
||||||
|
rgb_pts = settings.get("rgb", [[0, 0], [1, 1]])
|
||||||
|
red_pts = settings.get("red", [[0, 0], [1, 1]])
|
||||||
|
green_pts = settings.get("green", [[0, 0], [1, 1]])
|
||||||
|
blue_pts = settings.get("blue", [[0, 0], [1, 1]])
|
||||||
|
|
||||||
|
rgb_lut = _build_lut(rgb_pts)
|
||||||
|
red_lut = _build_lut(red_pts)
|
||||||
|
green_lut = _build_lut(green_pts)
|
||||||
|
blue_lut = _build_lut(blue_pts)
|
||||||
|
|
||||||
|
# Convert to numpy for LUT application
|
||||||
|
img_np = image.cpu().numpy().copy()
|
||||||
|
|
||||||
|
# Apply per-channel curves then RGB master curve
|
||||||
|
for ch, ch_lut in enumerate([red_lut, green_lut, blue_lut]):
|
||||||
|
# Per-channel curve
|
||||||
|
indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32)
|
||||||
|
img_np[..., ch] = ch_lut[indices]
|
||||||
|
# RGB master curve
|
||||||
|
indices = np.clip(img_np[..., ch] * 255, 0, 255).astype(np.int32)
|
||||||
|
img_np[..., ch] = rgb_lut[indices]
|
||||||
|
|
||||||
|
result = torch.from_numpy(np.clip(img_np, 0, 1)).to(image.device, dtype=image.dtype)
|
||||||
|
return io.NodeOutput(result, ui=ui.PreviewImage(result))
|
||||||
|
|
||||||
|
|
||||||
|
class ColorCurvesExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [ColorCurvesNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ColorCurvesExtension:
|
||||||
|
return ColorCurvesExtension()
|
||||||
3
nodes.py
3
nodes.py
@ -2435,6 +2435,9 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_lora_debug.py",
|
"nodes_lora_debug.py",
|
||||||
"nodes_color.py",
|
"nodes_color.py",
|
||||||
"nodes_toolkit.py",
|
"nodes_toolkit.py",
|
||||||
|
"nodes_color_correct.py",
|
||||||
|
"nodes_color_balance.py",
|
||||||
|
"nodes_color_curves.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user