mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
feat: SUPIR model support (CORE-17) (#13250)
This commit is contained in:
parent
3086026401
commit
b9dedea57d
@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
if "middle_block_after_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_block_after_patch"]
|
||||
for p in patch:
|
||||
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||
h = out["h"]
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
||||
for p in patch:
|
||||
h, hsp = p(h, hsp, transformer_options)
|
||||
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if hsp is not None:
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if len(hs) > 0:
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
|
||||
0
comfy/ldm/supir/__init__.py
Normal file
0
comfy/ldm/supir/__init__.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
@ -0,0 +1,226 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
class ZeroSFT(nn.Module):
|
||||
def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
ks = 3
|
||||
pw = ks // 2
|
||||
|
||||
self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)
|
||||
|
||||
nhidden = 128
|
||||
|
||||
self.mlp_shared = nn.Sequential(
|
||||
operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||
self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||
|
||||
self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
|
||||
self.pre_concat = bool(concat_channels != 0)
|
||||
|
||||
def forward(self, c, h, h_ori=None, control_scale=1):
|
||||
if h_ori is not None and self.pre_concat:
|
||||
h_raw = torch.cat([h_ori, h], dim=1)
|
||||
else:
|
||||
h_raw = h
|
||||
|
||||
h = h + self.zero_conv(c)
|
||||
if h_ori is not None and self.pre_concat:
|
||||
h = torch.cat([h_ori, h], dim=1)
|
||||
actv = self.mlp_shared(c)
|
||||
gamma = self.zero_mul(actv)
|
||||
beta = self.zero_add(actv)
|
||||
h = self.param_free_norm(h)
|
||||
h = torch.addcmul(h + beta, h, gamma)
|
||||
if h_ori is not None and not self.pre_concat:
|
||||
h = torch.cat([h_ori, h], dim=1)
|
||||
return torch.lerp(h_raw, h, control_scale)
|
||||
|
||||
|
||||
class _CrossAttnInner(nn.Module):
|
||||
"""Inner cross-attention module matching the state_dict layout of the original CrossAttention."""
|
||||
def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, context):
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
return self.to_out(optimized_attention(q, k, v, self.heads))
|
||||
|
||||
|
||||
class ZeroCrossAttn(nn.Module):
|
||||
def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
heads = query_dim // 64
|
||||
dim_head = 64
|
||||
self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations)
|
||||
self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
|
||||
self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, context, x, control_scale=1):
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
|
||||
x = self.attn(
|
||||
self.norm1(x).flatten(2).transpose(1, 2),
|
||||
self.norm2(context).flatten(2).transpose(1, 2),
|
||||
).transpose(1, 2).unflatten(2, (h, w))
|
||||
|
||||
return x_in + x * control_scale
|
||||
|
||||
|
||||
class GLVControl(nn.Module):
|
||||
"""SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
model_channels=320,
|
||||
num_res_blocks=2,
|
||||
attention_resolutions=(4, 2),
|
||||
channel_mult=(1, 2, 4),
|
||||
num_head_channels=64,
|
||||
transformer_depth=(1, 2, 10),
|
||||
context_dim=2048,
|
||||
adm_in_channels=2816,
|
||||
use_linear_in_transformer=True,
|
||||
use_checkpoint=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
time_embed_dim = model_channels * 4
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
])
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_heads = ch // num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||
depth=transformer_depth[level], context_dim=context_dim,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
if level != len(channel_mult) - 1:
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
|
||||
)
|
||||
)
|
||||
ds *= 2
|
||||
|
||||
num_heads = ch // num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||
depth=transformer_depth[-1], context_dim=context_dim,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dtype=dtype, device=device, operations=operations),
|
||||
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||
)
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb) + self.label_emb(y)
|
||||
|
||||
guided_hint = self.input_hint_block(x, emb, context)
|
||||
|
||||
hs = []
|
||||
h = xt
|
||||
for module in self.input_blocks:
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
hs.append(h)
|
||||
return hs
|
||||
|
||||
|
||||
class SUPIR(nn.Module):
|
||||
"""
|
||||
SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
|
||||
State dict keys match the original SUPIR checkpoint layout:
|
||||
control_model.* -> GLVControl
|
||||
project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||
"""
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
project_channel_scale = 2
|
||||
cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
|
||||
project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
|
||||
concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
|
||||
cross_attn_insert_idx = [6, 3]
|
||||
|
||||
self.project_modules = nn.ModuleList()
|
||||
for i in range(len(cond_output_channels)):
|
||||
self.project_modules.append(ZeroSFT(
|
||||
project_channels[i], cond_output_channels[i],
|
||||
concat_channels=concat_channels[i],
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
))
|
||||
|
||||
for i in cross_attn_insert_idx:
|
||||
self.project_modules.insert(i, ZeroCrossAttn(
|
||||
cond_output_channels[i], concat_channels[i],
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
))
|
||||
103
comfy/ldm/supir/supir_patch.py
Normal file
103
comfy/ldm/supir/supir_patch.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
|
||||
|
||||
|
||||
class SUPIRPatch:
|
||||
"""
|
||||
Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
|
||||
Runs GLVControl lazily on first patch invocation per step, applies adapters through
|
||||
middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
|
||||
"""
|
||||
SIGMA_MAX = 14.6146
|
||||
|
||||
def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
|
||||
self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl
|
||||
self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||
self.hint_latent = hint_latent # encoded LQ image latent
|
||||
self.strength_start = strength_start
|
||||
self.strength_end = strength_end
|
||||
self.cached_features = None
|
||||
self.adapter_idx = 0
|
||||
self.control_idx = 0
|
||||
self.current_control_idx = 0
|
||||
self.active = True
|
||||
|
||||
def _ensure_features(self, kwargs):
|
||||
"""Run GLVControl on first call per step, cache results."""
|
||||
if self.cached_features is not None:
|
||||
return
|
||||
x = kwargs["x"]
|
||||
b = x.shape[0]
|
||||
hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
|
||||
if hint.shape[0] != b:
|
||||
hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
|
||||
self.cached_features = self.model_patch.model.control_model(
|
||||
hint, kwargs["timesteps"], x,
|
||||
kwargs["context"], kwargs["y"]
|
||||
)
|
||||
self.adapter_idx = len(self.project_modules) - 1
|
||||
self.control_idx = len(self.cached_features) - 1
|
||||
|
||||
def _get_control_scale(self, kwargs):
|
||||
if self.strength_start == self.strength_end:
|
||||
return self.strength_end
|
||||
sigma = kwargs["transformer_options"].get("sigmas")
|
||||
if sigma is None:
|
||||
return self.strength_end
|
||||
s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
|
||||
t = min(s / self.SIGMA_MAX, 1.0)
|
||||
return t * (self.strength_start - self.strength_end) + self.strength_end
|
||||
|
||||
def middle_after(self, kwargs):
|
||||
"""middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
|
||||
self.cached_features = None # reset from previous step
|
||||
self.current_scale = self._get_control_scale(kwargs)
|
||||
self.active = self.current_scale > 0
|
||||
if not self.active:
|
||||
return {"h": kwargs["h"]}
|
||||
self._ensure_features(kwargs)
|
||||
h = kwargs["h"]
|
||||
h = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.control_idx], h, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
self.control_idx -= 1
|
||||
return {"h": h}
|
||||
|
||||
def output_block(self, h, hsp, transformer_options):
|
||||
"""output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
|
||||
if not self.active:
|
||||
return h, hsp
|
||||
self.current_control_idx = self.control_idx
|
||||
h = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
self.control_idx -= 1
|
||||
return h, None
|
||||
|
||||
def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
|
||||
"""forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
|
||||
block_type, _ = transformer_options["block"]
|
||||
if block_type == "output" and self.active and self.cached_features is not None:
|
||||
x = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
return layer(x, output_shape=output_shape)
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
self.cached_features = None
|
||||
if self.hint_latent is not None:
|
||||
self.hint_latent = self.hint_latent.to(device_or_dtype)
|
||||
return self
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
def register(self, model_patcher):
|
||||
"""Register all patches on a cloned model patcher."""
|
||||
model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
|
||||
model_patcher.set_model_output_block_patch(self.output_block)
|
||||
model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
|
||||
@ -506,6 +506,10 @@ class ModelPatcher:
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_middle_block_after_patch(self, patch):
|
||||
self.set_model_patch(patch, "middle_block_after_patch")
|
||||
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
|
||||
@ -7,7 +7,10 @@ import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
import comfy.ldm.supir.supir_modules
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
from comfy_api.latest import io
|
||||
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@ -266,6 +269,27 @@ class ModelPatchLoader:
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace = {}
|
||||
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace["model.control_model."] = "control_model."
|
||||
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||
else:
|
||||
prefix_replace["control_model."] = "control_model."
|
||||
prefix_replace["project_modules."] = "project_modules."
|
||||
|
||||
# Extract denoise_encoder weights before filter_keys discards them
|
||||
de_prefix = "first_stage_model.denoise_encoder."
|
||||
denoise_encoder_sd = {}
|
||||
for k in list(sd.keys()):
|
||||
if k.startswith(de_prefix):
|
||||
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
|
||||
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||
sd.pop("control_model.mask_LQ", None)
|
||||
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
if denoise_encoder_sd:
|
||||
model.denoise_encoder_sd = denoise_encoder_sd
|
||||
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||
@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class SUPIRApply(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SUPIRApply",
|
||||
category="model_patches/supir",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the start of sampling (high sigma)."),
|
||||
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
|
||||
io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
|
||||
tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
|
||||
io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
|
||||
tooltip="Sigma threshold below which restore_cfg is disabled."),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _encode_with_denoise_encoder(cls, vae, model_patch, image):
|
||||
"""Encode using denoise_encoder weights from SUPIR checkpoint if available."""
|
||||
denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
|
||||
if not denoise_sd:
|
||||
return vae.encode(image)
|
||||
|
||||
# Clone VAE patcher, apply denoise_encoder weights to clone, encode
|
||||
orig_patcher = vae.patcher
|
||||
vae.patcher = orig_patcher.clone()
|
||||
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
|
||||
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
|
||||
try:
|
||||
return vae.encode(image)
|
||||
finally:
|
||||
vae.patcher = orig_patcher
|
||||
|
||||
@classmethod
|
||||
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
|
||||
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
||||
model_patched = model.clone()
|
||||
hint_latent = model.get_model_object("latent_format").process_in(
|
||||
cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
|
||||
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
||||
patch.register(model_patched)
|
||||
|
||||
if restore_cfg > 0.0:
|
||||
# Round-trip to match original pipeline: decode hint, re-encode with regular VAE
|
||||
latent_format = model.get_model_object("latent_format")
|
||||
decoded = vae.decode(latent_format.process_out(hint_latent))
|
||||
x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
|
||||
sigma_max = 14.6146
|
||||
|
||||
def restore_cfg_function(args):
|
||||
denoised = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
if sigma.dim() > 0:
|
||||
s = sigma[0].item()
|
||||
else:
|
||||
s = sigma.item()
|
||||
if s > restore_cfg_s_tmin:
|
||||
ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
|
||||
b = denoised.shape[0]
|
||||
if ref.shape[0] != b:
|
||||
ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
|
||||
sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
|
||||
d_center = denoised - ref
|
||||
denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
|
||||
return denoised
|
||||
|
||||
model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
|
||||
|
||||
return io.NodeOutput(model_patched)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
"ZImageFunControlnet": ZImageFunControlnet,
|
||||
"USOStyleReference": USOStyleReference,
|
||||
"SUPIRApply": SUPIRApply,
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ from PIL import Image
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TypedDict, Literal
|
||||
import kornia
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
|
||||
class ColorTransfer(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
options=[
|
||||
io.DynamicCombo.Option("per_frame", []),
|
||||
io.DynamicCombo.Option("uniform", []),
|
||||
io.DynamicCombo.Option("target_frame", [
|
||||
io.Int.Input("target_index", default=0, min=0, max=10000,
|
||||
tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
|
||||
]),
|
||||
]),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_lab(images, i, device):
|
||||
return kornia.color.rgb_to_lab(
|
||||
images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))
|
||||
|
||||
@staticmethod
|
||||
def _pool_stats(images, device, is_reinhard, eps):
|
||||
"""Two-pass pooled mean + std/cov across all frames."""
|
||||
N, C = images.shape[0], images.shape[3]
|
||||
HW = images.shape[1] * images.shape[2]
|
||||
mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
|
||||
for i in range(N):
|
||||
mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
|
||||
mean /= N
|
||||
acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
|
||||
for i in range(N):
|
||||
centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
|
||||
if is_reinhard:
|
||||
acc += (centered * centered).mean(dim=-1, keepdim=True)
|
||||
else:
|
||||
acc += centered @ centered.T / HW
|
||||
if is_reinhard:
|
||||
return mean, torch.sqrt(acc / N).clamp_min_(eps)
|
||||
return mean, acc / N
|
||||
|
||||
@staticmethod
|
||||
def _frame_stats(lab_flat, hw, is_reinhard, eps):
|
||||
"""Per-frame mean + std/cov."""
|
||||
mean = lab_flat.mean(dim=-1, keepdim=True)
|
||||
if is_reinhard:
|
||||
return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
|
||||
centered = lab_flat - mean
|
||||
return mean, centered @ centered.T / hw
|
||||
|
||||
@staticmethod
|
||||
def _mkl_matrix(cov_s, cov_r, eps):
|
||||
"""Compute MKL 3x3 transform matrix from source and ref covariances."""
|
||||
eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
|
||||
sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
|
||||
|
||||
scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
|
||||
mid = scaled_V.T @ cov_r @ scaled_V
|
||||
eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
|
||||
sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
|
||||
|
||||
inv_sqrt_s = 1.0 / sqrt_val_s
|
||||
inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
|
||||
M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
|
||||
return inv_scaled_V @ M_half @ inv_scaled_V.T
|
||||
|
||||
@staticmethod
|
||||
def _histogram_lut(src, ref, bins=256):
|
||||
"""Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
|
||||
s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
|
||||
r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
|
||||
s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||
r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||
ones_s = torch.ones_like(src)
|
||||
ones_r = torch.ones_like(ref)
|
||||
s_hist.scatter_add_(1, s_bins, ones_s)
|
||||
r_hist.scatter_add_(1, r_bins, ones_r)
|
||||
s_cdf = s_hist.cumsum(1)
|
||||
s_cdf = s_cdf / s_cdf[:, -1:]
|
||||
r_cdf = r_hist.cumsum(1)
|
||||
r_cdf = r_cdf / r_cdf[:, -1:]
|
||||
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)
|
||||
|
||||
@classmethod
|
||||
def _pooled_cdf(cls, images, device, num_bins=256):
|
||||
"""Build pooled CDF across all frames, one frame at a time."""
|
||||
C = images.shape[3]
|
||||
hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
|
||||
for i in range(images.shape[0]):
|
||||
frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
|
||||
bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
|
||||
hist.scatter_add_(1, bins, torch.ones_like(frame))
|
||||
cdf = hist.cumsum(1)
|
||||
return cdf / cdf[:, -1:]
|
||||
|
||||
@classmethod
|
||||
def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
|
||||
"""Build per-frame or uniform LUT transform for histogram mode."""
|
||||
if stats_mode == 'per_frame':
|
||||
return None # LUT computed per-frame in the apply loop
|
||||
|
||||
r_cdf = cls._pooled_cdf(image_ref, device)
|
||||
if stats_mode == 'target_frame':
|
||||
ti = min(target_index, B - 1)
|
||||
s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
|
||||
else:
|
||||
s_cdf = cls._pooled_cdf(image_target, device)
|
||||
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0
|
||||
|
||||
@classmethod
|
||||
def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
|
||||
"""Build transform parameters for Lab-based methods. Returns a transform function."""
|
||||
eps = 1e-6
|
||||
B, H, W, C = image_target.shape
|
||||
B_ref = image_ref.shape[0]
|
||||
single_ref = B_ref == 1
|
||||
HW = H * W
|
||||
HW_ref = image_ref.shape[1] * image_ref.shape[2]
|
||||
|
||||
# Precompute ref stats
|
||||
if single_ref or stats_mode in ('uniform', 'target_frame'):
|
||||
ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
|
||||
|
||||
# Uniform/target_frame: precompute single affine transform
|
||||
if stats_mode in ('uniform', 'target_frame'):
|
||||
if stats_mode == 'target_frame':
|
||||
ti = min(target_index, B - 1)
|
||||
s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
|
||||
s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
|
||||
else:
|
||||
s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
|
||||
|
||||
if is_reinhard:
|
||||
scale = ref_sc / s_sc
|
||||
offset = ref_mean - scale * s_mean
|
||||
return lambda src_flat, **_: src_flat * scale + offset
|
||||
T = cls._mkl_matrix(s_sc, ref_sc, eps)
|
||||
offset = ref_mean - T @ s_mean
|
||||
return lambda src_flat, **_: T @ src_flat + offset
|
||||
|
||||
# per_frame
|
||||
def per_frame_transform(src_flat, frame_idx):
|
||||
s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
|
||||
|
||||
if single_ref:
|
||||
r_mean, r_sc = ref_mean, ref_sc
|
||||
else:
|
||||
ri = min(frame_idx, B_ref - 1)
|
||||
r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
|
||||
|
||||
centered = src_flat - s_mean
|
||||
if is_reinhard:
|
||||
return centered * (r_sc / s_sc) + r_mean
|
||||
T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
|
||||
return T @ centered + r_mean
|
||||
|
||||
return per_frame_transform
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
|
||||
stats_mode = source_stats["source_stats"]
|
||||
target_index = source_stats.get("target_index", 0)
|
||||
|
||||
if strength == 0 or image_ref is None:
|
||||
return io.NodeOutput(image_target)
|
||||
|
||||
device = comfy.model_management.get_torch_device()
|
||||
intermediate_device = comfy.model_management.intermediate_device()
|
||||
intermediate_dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
B, H, W, C = image_target.shape
|
||||
B_ref = image_ref.shape[0]
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype)
|
||||
|
||||
if method == 'histogram':
|
||||
uniform_lut = cls._build_histogram_transform(
|
||||
image_target, image_ref, device, stats_mode, target_index, B)
|
||||
|
||||
for i in range(B):
|
||||
src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1)
|
||||
src_flat = src.reshape(C, -1)
|
||||
if uniform_lut is not None:
|
||||
lut = uniform_lut
|
||||
else:
|
||||
ri = min(i, B_ref - 1)
|
||||
ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
|
||||
lut = cls._histogram_lut(src_flat, ref)
|
||||
bin_idx = (src_flat * 255).long().clamp(0, 255)
|
||||
matched = lut.gather(1, bin_idx).view(C, H, W)
|
||||
result = matched if strength == 1.0 else torch.lerp(src, matched, strength)
|
||||
out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
|
||||
pbar.update(1)
|
||||
else:
|
||||
transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")
|
||||
|
||||
for i in range(B):
|
||||
src_frame = cls._to_lab(image_target, i, device)
|
||||
corrected = transform(src_frame.view(C, -1), frame_idx=i)
|
||||
if strength == 1.0:
|
||||
result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W))
|
||||
else:
|
||||
result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength))
|
||||
out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
|
||||
pbar.update(1)
|
||||
|
||||
return io.NodeOutput(out)
|
||||
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
|
||||
BatchImagesNode,
|
||||
BatchMasksNode,
|
||||
BatchLatentsNode,
|
||||
ColorTransfer,
|
||||
# BatchImagesMasksLatentsNode,
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user