Add ColorTransfer -node

This commit is contained in:
kijai 2026-04-01 18:02:23 +03:00
parent cea9dbd5d6
commit e433039102

View File

@ -6,6 +6,7 @@ from PIL import Image
import math import math
from enum import Enum from enum import Enum
from typing import TypedDict, Literal from typing import TypedDict, Literal
import kornia
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
return io.NodeOutput(batched) return io.NodeOutput(batched)
class ColorTransfer(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ColorTransfer",
category="image/postprocessing",
description="Match the colors of one image to another using various algorithms.",
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
inputs=[
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
io.DynamicCombo.Input("source_stats",
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
options=[
io.DynamicCombo.Option("per_frame", []),
io.DynamicCombo.Option("uniform", []),
io.DynamicCombo.Option("target_frame", [
io.Int.Input("target_index", default=0, min=0, max=10000,
tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
]),
]),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Image.Output(display_name="image"),
],
)
@staticmethod
def _to_lab(images, i, device):
return kornia.color.rgb_to_lab(
images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))
@staticmethod
def _pool_stats(images, device, is_reinhard, eps):
"""Two-pass pooled mean + std/cov across all frames."""
N, C = images.shape[0], images.shape[3]
HW = images.shape[1] * images.shape[2]
mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
for i in range(N):
mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
mean /= N
acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
for i in range(N):
centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
if is_reinhard:
acc += (centered * centered).mean(dim=-1, keepdim=True)
else:
acc += centered @ centered.T / HW
if is_reinhard:
return mean, torch.sqrt(acc / N).clamp_min_(eps)
return mean, acc / N
@staticmethod
def _frame_stats(lab_flat, hw, is_reinhard, eps):
"""Per-frame mean + std/cov."""
mean = lab_flat.mean(dim=-1, keepdim=True)
if is_reinhard:
return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
centered = lab_flat - mean
return mean, centered @ centered.T / hw
@staticmethod
def _mkl_matrix(cov_s, cov_r, eps):
"""Compute MKL 3x3 transform matrix from source and ref covariances."""
eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
mid = scaled_V.T @ cov_r @ scaled_V
eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
inv_sqrt_s = 1.0 / sqrt_val_s
inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
return inv_scaled_V @ M_half @ inv_scaled_V.T
@staticmethod
def _histogram_lut(src, ref, bins=256):
"""Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
ones_s = torch.ones_like(src)
ones_r = torch.ones_like(ref)
s_hist.scatter_add_(1, s_bins, ones_s)
r_hist.scatter_add_(1, r_bins, ones_r)
s_cdf = s_hist.cumsum(1)
s_cdf = s_cdf / s_cdf[:, -1:]
r_cdf = r_hist.cumsum(1)
r_cdf = r_cdf / r_cdf[:, -1:]
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)
@classmethod
def _pooled_cdf(cls, images, device, num_bins=256):
"""Build pooled CDF across all frames, one frame at a time."""
C = images.shape[3]
hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
for i in range(images.shape[0]):
frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
hist.scatter_add_(1, bins, torch.ones_like(frame))
cdf = hist.cumsum(1)
return cdf / cdf[:, -1:]
@classmethod
def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
"""Build per-frame or uniform LUT transform for histogram mode."""
if stats_mode == 'per_frame':
return None # LUT computed per-frame in the apply loop
r_cdf = cls._pooled_cdf(image_ref, device)
if stats_mode == 'target_frame':
ti = min(target_index, B - 1)
s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
else:
s_cdf = cls._pooled_cdf(image_target, device)
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0
@classmethod
def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
"""Build transform parameters for Lab-based methods. Returns a transform function."""
eps = 1e-6
B, H, W, C = image_target.shape
B_ref = image_ref.shape[0]
single_ref = B_ref == 1
HW = H * W
HW_ref = image_ref.shape[1] * image_ref.shape[2]
# Precompute ref stats
if single_ref or stats_mode in ('uniform', 'target_frame'):
ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
# Uniform/target_frame: precompute single affine transform
if stats_mode in ('uniform', 'target_frame'):
if stats_mode == 'target_frame':
ti = min(target_index, B - 1)
s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
else:
s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
if is_reinhard:
scale = ref_sc / s_sc
offset = ref_mean - scale * s_mean
return lambda src_flat, **_: src_flat * scale + offset
T = cls._mkl_matrix(s_sc, ref_sc, eps)
offset = ref_mean - T @ s_mean
return lambda src_flat, **_: T @ src_flat + offset
# per_frame
def per_frame_transform(src_flat, frame_idx):
s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
if single_ref:
r_mean, r_sc = ref_mean, ref_sc
else:
ri = min(frame_idx, B_ref - 1)
r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
centered = src_flat - s_mean
if is_reinhard:
return centered * (r_sc / s_sc) + r_mean
T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
return T @ centered + r_mean
return per_frame_transform
@classmethod
def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
stats_mode = source_stats["source_stats"]
target_index = source_stats.get("target_index", 0)
if strength == 0 or image_ref is None:
return io.NodeOutput(image_target)
device = comfy.model_management.get_torch_device()
intermediate_device = comfy.model_management.intermediate_device()
intermediate_dtype = comfy.model_management.intermediate_dtype()
B, H, W, C = image_target.shape
B_ref = image_ref.shape[0]
pbar = comfy.utils.ProgressBar(B)
out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype)
if method == 'histogram':
uniform_lut = cls._build_histogram_transform(
image_target, image_ref, device, stats_mode, target_index, B)
for i in range(B):
src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1)
src_flat = src.reshape(C, -1)
if uniform_lut is not None:
lut = uniform_lut
else:
ri = min(i, B_ref - 1)
ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
lut = cls._histogram_lut(src_flat, ref)
bin_idx = (src_flat * 255).long().clamp(0, 255)
matched = lut.gather(1, bin_idx).view(C, H, W)
result = matched if strength == 1.0 else torch.lerp(src, matched, strength)
out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
pbar.update(1)
else:
transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")
for i in range(B):
src_frame = cls._to_lab(image_target, i, device)
corrected = transform(src_frame.view(C, -1), frame_idx=i)
if strength == 1.0:
result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W))
else:
result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength))
out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
pbar.update(1)
return io.NodeOutput(out)
class PostProcessingExtension(ComfyExtension): class PostProcessingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
BatchImagesNode, BatchImagesNode,
BatchMasksNode, BatchMasksNode,
BatchLatentsNode, BatchLatentsNode,
ColorTransfer,
# BatchImagesMasksLatentsNode, # BatchImagesMasksLatentsNode,
] ]