mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-22 00:12:34 +08:00
feat: SUPIR model support (CORE-17) (#13250)
This commit is contained in:
parent
3086026401
commit
b9dedea57d
@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
|||||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||||
for layer in ts:
|
for layer in ts:
|
||||||
|
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||||
|
found_patched = False
|
||||||
|
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||||
|
if isinstance(layer, class_type):
|
||||||
|
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||||
|
found_patched = True
|
||||||
|
break
|
||||||
|
if found_patched:
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(layer, VideoResBlock):
|
if isinstance(layer, VideoResBlock):
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
elif isinstance(layer, TimestepBlock):
|
elif isinstance(layer, TimestepBlock):
|
||||||
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
|||||||
elif isinstance(layer, Upsample):
|
elif isinstance(layer, Upsample):
|
||||||
x = layer(x, output_shape=output_shape)
|
x = layer(x, output_shape=output_shape)
|
||||||
else:
|
else:
|
||||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
|
||||||
found_patched = False
|
|
||||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
|
||||||
if isinstance(layer, class_type):
|
|
||||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
|
||||||
found_patched = True
|
|
||||||
break
|
|
||||||
if found_patched:
|
|
||||||
continue
|
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
|||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||||
h = apply_control(h, control, 'middle')
|
h = apply_control(h, control, 'middle')
|
||||||
|
|
||||||
|
if "middle_block_after_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["middle_block_after_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||||
|
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||||
|
h = out["h"]
|
||||||
|
|
||||||
for id, module in enumerate(self.output_blocks):
|
for id, module in enumerate(self.output_blocks):
|
||||||
transformer_options["block"] = ("output", id)
|
transformer_options["block"] = ("output", id)
|
||||||
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
h, hsp = p(h, hsp, transformer_options)
|
h, hsp = p(h, hsp, transformer_options)
|
||||||
|
|
||||||
h = th.cat([h, hsp], dim=1)
|
if hsp is not None:
|
||||||
del hsp
|
h = th.cat([h, hsp], dim=1)
|
||||||
|
del hsp
|
||||||
if len(hs) > 0:
|
if len(hs) > 0:
|
||||||
output_shape = hs[-1].shape
|
output_shape = hs[-1].shape
|
||||||
else:
|
else:
|
||||||
|
|||||||
0
comfy/ldm/supir/__init__.py
Normal file
0
comfy/ldm/supir/__init__.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroSFT(nn.Module):
|
||||||
|
def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
ks = 3
|
||||||
|
pw = ks // 2
|
||||||
|
|
||||||
|
self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
nhidden = 128
|
||||||
|
|
||||||
|
self.mlp_shared = nn.Sequential(
|
||||||
|
operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
|
||||||
|
nn.SiLU()
|
||||||
|
)
|
||||||
|
self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||||
|
self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
|
||||||
|
self.pre_concat = bool(concat_channels != 0)
|
||||||
|
|
||||||
|
def forward(self, c, h, h_ori=None, control_scale=1):
|
||||||
|
if h_ori is not None and self.pre_concat:
|
||||||
|
h_raw = torch.cat([h_ori, h], dim=1)
|
||||||
|
else:
|
||||||
|
h_raw = h
|
||||||
|
|
||||||
|
h = h + self.zero_conv(c)
|
||||||
|
if h_ori is not None and self.pre_concat:
|
||||||
|
h = torch.cat([h_ori, h], dim=1)
|
||||||
|
actv = self.mlp_shared(c)
|
||||||
|
gamma = self.zero_mul(actv)
|
||||||
|
beta = self.zero_add(actv)
|
||||||
|
h = self.param_free_norm(h)
|
||||||
|
h = torch.addcmul(h + beta, h, gamma)
|
||||||
|
if h_ori is not None and not self.pre_concat:
|
||||||
|
h = torch.cat([h_ori, h], dim=1)
|
||||||
|
return torch.lerp(h_raw, h, control_scale)
|
||||||
|
|
||||||
|
|
||||||
|
class _CrossAttnInner(nn.Module):
|
||||||
|
"""Inner cross-attention module matching the state_dict layout of the original CrossAttention."""
|
||||||
|
def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, context):
|
||||||
|
q = self.to_q(x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
return self.to_out(optimized_attention(q, k, v, self.heads))
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroCrossAttn(nn.Module):
|
||||||
|
def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
heads = query_dim // 64
|
||||||
|
dim_head = 64
|
||||||
|
self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
|
||||||
|
self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, context, x, control_scale=1):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
x_in = x
|
||||||
|
|
||||||
|
x = self.attn(
|
||||||
|
self.norm1(x).flatten(2).transpose(1, 2),
|
||||||
|
self.norm2(context).flatten(2).transpose(1, 2),
|
||||||
|
).transpose(1, 2).unflatten(2, (h, w))
|
||||||
|
|
||||||
|
return x_in + x * control_scale
|
||||||
|
|
||||||
|
|
||||||
|
class GLVControl(nn.Module):
|
||||||
|
"""SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=4,
|
||||||
|
model_channels=320,
|
||||||
|
num_res_blocks=2,
|
||||||
|
attention_resolutions=(4, 2),
|
||||||
|
channel_mult=(1, 2, 4),
|
||||||
|
num_head_channels=64,
|
||||||
|
transformer_depth=(1, 2, 10),
|
||||||
|
context_dim=2048,
|
||||||
|
adm_in_channels=2816,
|
||||||
|
use_linear_in_transformer=True,
|
||||||
|
use_checkpoint=False,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model_channels = model_channels
|
||||||
|
time_embed_dim = model_channels * 4
|
||||||
|
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.label_emb = nn.Sequential(
|
||||||
|
nn.Sequential(
|
||||||
|
operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
])
|
||||||
|
ch = model_channels
|
||||||
|
ds = 1
|
||||||
|
for level, mult in enumerate(channel_mult):
|
||||||
|
for nr in range(num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
]
|
||||||
|
ch = mult * model_channels
|
||||||
|
if ds in attention_resolutions:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
layers.append(
|
||||||
|
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||||
|
depth=transformer_depth[level], context_dim=context_dim,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
dtype=dtype, device=device, operations=operations)
|
||||||
|
)
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
if level != len(channel_mult) - 1:
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ds *= 2
|
||||||
|
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||||
|
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||||
|
depth=transformer_depth[-1], context_dim=context_dim,
|
||||||
|
use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
dtype=dtype, device=device, operations=operations),
|
||||||
|
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_hint_block = TimestepEmbedSequential(
|
||||||
|
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
|
||||||
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
|
emb = self.time_embed(t_emb) + self.label_emb(y)
|
||||||
|
|
||||||
|
guided_hint = self.input_hint_block(x, emb, context)
|
||||||
|
|
||||||
|
hs = []
|
||||||
|
h = xt
|
||||||
|
for module in self.input_blocks:
|
||||||
|
if guided_hint is not None:
|
||||||
|
h = module(h, emb, context)
|
||||||
|
h += guided_hint
|
||||||
|
guided_hint = None
|
||||||
|
else:
|
||||||
|
h = module(h, emb, context)
|
||||||
|
hs.append(h)
|
||||||
|
h = self.middle_block(h, emb, context)
|
||||||
|
hs.append(h)
|
||||||
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIR(nn.Module):
|
||||||
|
"""
|
||||||
|
SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
|
||||||
|
State dict keys match the original SUPIR checkpoint layout:
|
||||||
|
control_model.* -> GLVControl
|
||||||
|
project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||||
|
"""
|
||||||
|
def __init__(self, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
project_channel_scale = 2
|
||||||
|
cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
|
||||||
|
project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
|
||||||
|
concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
|
||||||
|
cross_attn_insert_idx = [6, 3]
|
||||||
|
|
||||||
|
self.project_modules = nn.ModuleList()
|
||||||
|
for i in range(len(cond_output_channels)):
|
||||||
|
self.project_modules.append(ZeroSFT(
|
||||||
|
project_channels[i], cond_output_channels[i],
|
||||||
|
concat_channels=concat_channels[i],
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
))
|
||||||
|
|
||||||
|
for i in cross_attn_insert_idx:
|
||||||
|
self.project_modules.insert(i, ZeroCrossAttn(
|
||||||
|
cond_output_channels[i], concat_channels[i],
|
||||||
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
))
|
||||||
103
comfy/ldm/supir/supir_patch.py
Normal file
103
comfy/ldm/supir/supir_patch.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import torch
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRPatch:
|
||||||
|
"""
|
||||||
|
Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
|
||||||
|
Runs GLVControl lazily on first patch invocation per step, applies adapters through
|
||||||
|
middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
|
||||||
|
"""
|
||||||
|
SIGMA_MAX = 14.6146
|
||||||
|
|
||||||
|
def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
|
||||||
|
self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl
|
||||||
|
self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||||
|
self.hint_latent = hint_latent # encoded LQ image latent
|
||||||
|
self.strength_start = strength_start
|
||||||
|
self.strength_end = strength_end
|
||||||
|
self.cached_features = None
|
||||||
|
self.adapter_idx = 0
|
||||||
|
self.control_idx = 0
|
||||||
|
self.current_control_idx = 0
|
||||||
|
self.active = True
|
||||||
|
|
||||||
|
def _ensure_features(self, kwargs):
|
||||||
|
"""Run GLVControl on first call per step, cache results."""
|
||||||
|
if self.cached_features is not None:
|
||||||
|
return
|
||||||
|
x = kwargs["x"]
|
||||||
|
b = x.shape[0]
|
||||||
|
hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
|
||||||
|
if hint.shape[0] != b:
|
||||||
|
hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
|
||||||
|
self.cached_features = self.model_patch.model.control_model(
|
||||||
|
hint, kwargs["timesteps"], x,
|
||||||
|
kwargs["context"], kwargs["y"]
|
||||||
|
)
|
||||||
|
self.adapter_idx = len(self.project_modules) - 1
|
||||||
|
self.control_idx = len(self.cached_features) - 1
|
||||||
|
|
||||||
|
def _get_control_scale(self, kwargs):
|
||||||
|
if self.strength_start == self.strength_end:
|
||||||
|
return self.strength_end
|
||||||
|
sigma = kwargs["transformer_options"].get("sigmas")
|
||||||
|
if sigma is None:
|
||||||
|
return self.strength_end
|
||||||
|
s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
|
||||||
|
t = min(s / self.SIGMA_MAX, 1.0)
|
||||||
|
return t * (self.strength_start - self.strength_end) + self.strength_end
|
||||||
|
|
||||||
|
def middle_after(self, kwargs):
|
||||||
|
"""middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
|
||||||
|
self.cached_features = None # reset from previous step
|
||||||
|
self.current_scale = self._get_control_scale(kwargs)
|
||||||
|
self.active = self.current_scale > 0
|
||||||
|
if not self.active:
|
||||||
|
return {"h": kwargs["h"]}
|
||||||
|
self._ensure_features(kwargs)
|
||||||
|
h = kwargs["h"]
|
||||||
|
h = self.project_modules[self.adapter_idx](
|
||||||
|
self.cached_features[self.control_idx], h, control_scale=self.current_scale
|
||||||
|
)
|
||||||
|
self.adapter_idx -= 1
|
||||||
|
self.control_idx -= 1
|
||||||
|
return {"h": h}
|
||||||
|
|
||||||
|
def output_block(self, h, hsp, transformer_options):
|
||||||
|
"""output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
|
||||||
|
if not self.active:
|
||||||
|
return h, hsp
|
||||||
|
self.current_control_idx = self.control_idx
|
||||||
|
h = self.project_modules[self.adapter_idx](
|
||||||
|
self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
|
||||||
|
)
|
||||||
|
self.adapter_idx -= 1
|
||||||
|
self.control_idx -= 1
|
||||||
|
return h, None
|
||||||
|
|
||||||
|
def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
|
||||||
|
"""forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
|
||||||
|
block_type, _ = transformer_options["block"]
|
||||||
|
if block_type == "output" and self.active and self.cached_features is not None:
|
||||||
|
x = self.project_modules[self.adapter_idx](
|
||||||
|
self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
|
||||||
|
)
|
||||||
|
self.adapter_idx -= 1
|
||||||
|
return layer(x, output_shape=output_shape)
|
||||||
|
|
||||||
|
def to(self, device_or_dtype):
|
||||||
|
if isinstance(device_or_dtype, torch.device):
|
||||||
|
self.cached_features = None
|
||||||
|
if self.hint_latent is not None:
|
||||||
|
self.hint_latent = self.hint_latent.to(device_or_dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def models(self):
|
||||||
|
return [self.model_patch]
|
||||||
|
|
||||||
|
def register(self, model_patcher):
|
||||||
|
"""Register all patches on a cloned model patcher."""
|
||||||
|
model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
|
||||||
|
model_patcher.set_model_output_block_patch(self.output_block)
|
||||||
|
model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
|
||||||
@ -506,6 +506,10 @@ class ModelPatcher:
|
|||||||
def set_model_noise_refiner_patch(self, patch):
|
def set_model_noise_refiner_patch(self, patch):
|
||||||
self.set_model_patch(patch, "noise_refiner")
|
self.set_model_patch(patch, "noise_refiner")
|
||||||
|
|
||||||
|
def set_model_middle_block_after_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "middle_block_after_patch")
|
||||||
|
|
||||||
|
|
||||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||||
rope_options["scale_x"] = scale_x
|
rope_options["scale_x"] = scale_x
|
||||||
|
|||||||
@ -7,7 +7,10 @@ import comfy.model_management
|
|||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import comfy.ldm.lumina.controlnet
|
import comfy.ldm.lumina.controlnet
|
||||||
|
import comfy.ldm.supir.supir_modules
|
||||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||||
|
from comfy_api.latest import io
|
||||||
|
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||||
|
|
||||||
|
|
||||||
class BlockWiseControlBlock(torch.nn.Module):
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
@ -266,6 +269,27 @@ class ModelPatchLoader:
|
|||||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||||
device=comfy.model_management.unet_offload_device(),
|
device=comfy.model_management.unet_offload_device(),
|
||||||
operations=comfy.ops.manual_cast)
|
operations=comfy.ops.manual_cast)
|
||||||
|
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace = {}
|
||||||
|
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace["model.control_model."] = "control_model."
|
||||||
|
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||||
|
else:
|
||||||
|
prefix_replace["control_model."] = "control_model."
|
||||||
|
prefix_replace["project_modules."] = "project_modules."
|
||||||
|
|
||||||
|
# Extract denoise_encoder weights before filter_keys discards them
|
||||||
|
de_prefix = "first_stage_model.denoise_encoder."
|
||||||
|
denoise_encoder_sd = {}
|
||||||
|
for k in list(sd.keys()):
|
||||||
|
if k.startswith(de_prefix):
|
||||||
|
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
|
||||||
|
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||||
|
sd.pop("control_model.mask_LQ", None)
|
||||||
|
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
if denoise_encoder_sd:
|
||||||
|
model.denoise_encoder_sd = denoise_encoder_sd
|
||||||
|
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||||
@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRApply(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SUPIRApply",
|
||||||
|
category="model_patches/supir",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.ModelPatch.Input("model_patch"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||||
|
tooltip="Control strength at the start of sampling (high sigma)."),
|
||||||
|
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||||
|
tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
|
||||||
|
io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
|
||||||
|
tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
|
||||||
|
io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
|
||||||
|
tooltip="Sigma threshold below which restore_cfg is disabled."),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _encode_with_denoise_encoder(cls, vae, model_patch, image):
|
||||||
|
"""Encode using denoise_encoder weights from SUPIR checkpoint if available."""
|
||||||
|
denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
|
||||||
|
if not denoise_sd:
|
||||||
|
return vae.encode(image)
|
||||||
|
|
||||||
|
# Clone VAE patcher, apply denoise_encoder weights to clone, encode
|
||||||
|
orig_patcher = vae.patcher
|
||||||
|
vae.patcher = orig_patcher.clone()
|
||||||
|
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
|
||||||
|
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
|
||||||
|
try:
|
||||||
|
return vae.encode(image)
|
||||||
|
finally:
|
||||||
|
vae.patcher = orig_patcher
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
|
||||||
|
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
||||||
|
model_patched = model.clone()
|
||||||
|
hint_latent = model.get_model_object("latent_format").process_in(
|
||||||
|
cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
|
||||||
|
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
||||||
|
patch.register(model_patched)
|
||||||
|
|
||||||
|
if restore_cfg > 0.0:
|
||||||
|
# Round-trip to match original pipeline: decode hint, re-encode with regular VAE
|
||||||
|
latent_format = model.get_model_object("latent_format")
|
||||||
|
decoded = vae.decode(latent_format.process_out(hint_latent))
|
||||||
|
x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
|
||||||
|
sigma_max = 14.6146
|
||||||
|
|
||||||
|
def restore_cfg_function(args):
|
||||||
|
denoised = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
if sigma.dim() > 0:
|
||||||
|
s = sigma[0].item()
|
||||||
|
else:
|
||||||
|
s = sigma.item()
|
||||||
|
if s > restore_cfg_s_tmin:
|
||||||
|
ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
|
||||||
|
b = denoised.shape[0]
|
||||||
|
if ref.shape[0] != b:
|
||||||
|
ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
|
||||||
|
sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
|
||||||
|
d_center = denoised - ref
|
||||||
|
denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
|
||||||
|
|
||||||
|
return io.NodeOutput(model_patched)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelPatchLoader": ModelPatchLoader,
|
"ModelPatchLoader": ModelPatchLoader,
|
||||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||||
"ZImageFunControlnet": ZImageFunControlnet,
|
"ZImageFunControlnet": ZImageFunControlnet,
|
||||||
"USOStyleReference": USOStyleReference,
|
"USOStyleReference": USOStyleReference,
|
||||||
|
"SUPIRApply": SUPIRApply,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from PIL import Image
|
|||||||
import math
|
import math
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypedDict, Literal
|
from typing import TypedDict, Literal
|
||||||
|
import kornia
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
|||||||
return io.NodeOutput(batched)
|
return io.NodeOutput(batched)
|
||||||
|
|
||||||
|
|
||||||
|
class ColorTransfer(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ColorTransfer",
|
||||||
|
category="image/postprocessing",
|
||||||
|
description="Match the colors of one image to another using various algorithms.",
|
||||||
|
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||||
|
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||||
|
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||||
|
io.DynamicCombo.Input("source_stats",
|
||||||
|
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||||
|
options=[
|
||||||
|
io.DynamicCombo.Option("per_frame", []),
|
||||||
|
io.DynamicCombo.Option("uniform", []),
|
||||||
|
io.DynamicCombo.Option("target_frame", [
|
||||||
|
io.Int.Input("target_index", default=0, min=0, max=10000,
|
||||||
|
tooltip="Frame index used as the source baseline for computing the transform to image_ref"),
|
||||||
|
]),
|
||||||
|
]),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(display_name="image"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_lab(images, i, device):
|
||||||
|
return kornia.color.rgb_to_lab(
|
||||||
|
images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pool_stats(images, device, is_reinhard, eps):
|
||||||
|
"""Two-pass pooled mean + std/cov across all frames."""
|
||||||
|
N, C = images.shape[0], images.shape[3]
|
||||||
|
HW = images.shape[1] * images.shape[2]
|
||||||
|
mean = torch.zeros(C, 1, device=device, dtype=torch.float32)
|
||||||
|
for i in range(N):
|
||||||
|
mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True)
|
||||||
|
mean /= N
|
||||||
|
acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32)
|
||||||
|
for i in range(N):
|
||||||
|
centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean
|
||||||
|
if is_reinhard:
|
||||||
|
acc += (centered * centered).mean(dim=-1, keepdim=True)
|
||||||
|
else:
|
||||||
|
acc += centered @ centered.T / HW
|
||||||
|
if is_reinhard:
|
||||||
|
return mean, torch.sqrt(acc / N).clamp_min_(eps)
|
||||||
|
return mean, acc / N
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _frame_stats(lab_flat, hw, is_reinhard, eps):
|
||||||
|
"""Per-frame mean + std/cov."""
|
||||||
|
mean = lab_flat.mean(dim=-1, keepdim=True)
|
||||||
|
if is_reinhard:
|
||||||
|
return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps)
|
||||||
|
centered = lab_flat - mean
|
||||||
|
return mean, centered @ centered.T / hw
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mkl_matrix(cov_s, cov_r, eps):
|
||||||
|
"""Compute MKL 3x3 transform matrix from source and ref covariances."""
|
||||||
|
eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s)
|
||||||
|
sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps)
|
||||||
|
|
||||||
|
scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0)
|
||||||
|
mid = scaled_V.T @ cov_r @ scaled_V
|
||||||
|
eig_val_m, eig_vec_m = torch.linalg.eigh(mid)
|
||||||
|
sqrt_m = torch.sqrt(eig_val_m.clamp_min(0))
|
||||||
|
|
||||||
|
inv_sqrt_s = 1.0 / sqrt_val_s
|
||||||
|
inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0)
|
||||||
|
M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T
|
||||||
|
return inv_scaled_V @ M_half @ inv_scaled_V.T
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _histogram_lut(src, ref, bins=256):
|
||||||
|
"""Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1]."""
|
||||||
|
s_bins = (src * (bins - 1)).long().clamp(0, bins - 1)
|
||||||
|
r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1)
|
||||||
|
s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||||
|
r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype)
|
||||||
|
ones_s = torch.ones_like(src)
|
||||||
|
ones_r = torch.ones_like(ref)
|
||||||
|
s_hist.scatter_add_(1, s_bins, ones_s)
|
||||||
|
r_hist.scatter_add_(1, r_bins, ones_r)
|
||||||
|
s_cdf = s_hist.cumsum(1)
|
||||||
|
s_cdf = s_cdf / s_cdf[:, -1:]
|
||||||
|
r_cdf = r_hist.cumsum(1)
|
||||||
|
r_cdf = r_cdf / r_cdf[:, -1:]
|
||||||
|
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _pooled_cdf(cls, images, device, num_bins=256):
|
||||||
|
"""Build pooled CDF across all frames, one frame at a time."""
|
||||||
|
C = images.shape[3]
|
||||||
|
hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32)
|
||||||
|
for i in range(images.shape[0]):
|
||||||
|
frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
|
||||||
|
bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1)
|
||||||
|
hist.scatter_add_(1, bins, torch.ones_like(frame))
|
||||||
|
cdf = hist.cumsum(1)
|
||||||
|
return cdf / cdf[:, -1:]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B):
|
||||||
|
"""Build per-frame or uniform LUT transform for histogram mode."""
|
||||||
|
if stats_mode == 'per_frame':
|
||||||
|
return None # LUT computed per-frame in the apply loop
|
||||||
|
|
||||||
|
r_cdf = cls._pooled_cdf(image_ref, device)
|
||||||
|
if stats_mode == 'target_frame':
|
||||||
|
ti = min(target_index, B - 1)
|
||||||
|
s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device)
|
||||||
|
else:
|
||||||
|
s_cdf = cls._pooled_cdf(image_target, device)
|
||||||
|
return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard):
|
||||||
|
"""Build transform parameters for Lab-based methods. Returns a transform function."""
|
||||||
|
eps = 1e-6
|
||||||
|
B, H, W, C = image_target.shape
|
||||||
|
B_ref = image_ref.shape[0]
|
||||||
|
single_ref = B_ref == 1
|
||||||
|
HW = H * W
|
||||||
|
HW_ref = image_ref.shape[1] * image_ref.shape[2]
|
||||||
|
|
||||||
|
# Precompute ref stats
|
||||||
|
if single_ref or stats_mode in ('uniform', 'target_frame'):
|
||||||
|
ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps)
|
||||||
|
|
||||||
|
# Uniform/target_frame: precompute single affine transform
|
||||||
|
if stats_mode in ('uniform', 'target_frame'):
|
||||||
|
if stats_mode == 'target_frame':
|
||||||
|
ti = min(target_index, B - 1)
|
||||||
|
s_lab = cls._to_lab(image_target, ti, device).view(C, -1)
|
||||||
|
s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps)
|
||||||
|
else:
|
||||||
|
s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps)
|
||||||
|
|
||||||
|
if is_reinhard:
|
||||||
|
scale = ref_sc / s_sc
|
||||||
|
offset = ref_mean - scale * s_mean
|
||||||
|
return lambda src_flat, **_: src_flat * scale + offset
|
||||||
|
T = cls._mkl_matrix(s_sc, ref_sc, eps)
|
||||||
|
offset = ref_mean - T @ s_mean
|
||||||
|
return lambda src_flat, **_: T @ src_flat + offset
|
||||||
|
|
||||||
|
# per_frame
|
||||||
|
def per_frame_transform(src_flat, frame_idx):
|
||||||
|
s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps)
|
||||||
|
|
||||||
|
if single_ref:
|
||||||
|
r_mean, r_sc = ref_mean, ref_sc
|
||||||
|
else:
|
||||||
|
ri = min(frame_idx, B_ref - 1)
|
||||||
|
r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps)
|
||||||
|
|
||||||
|
centered = src_flat - s_mean
|
||||||
|
if is_reinhard:
|
||||||
|
return centered * (r_sc / s_sc) + r_mean
|
||||||
|
T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps)
|
||||||
|
return T @ centered + r_mean
|
||||||
|
|
||||||
|
return per_frame_transform
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput:
|
||||||
|
stats_mode = source_stats["source_stats"]
|
||||||
|
target_index = source_stats.get("target_index", 0)
|
||||||
|
|
||||||
|
if strength == 0 or image_ref is None:
|
||||||
|
return io.NodeOutput(image_target)
|
||||||
|
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
intermediate_device = comfy.model_management.intermediate_device()
|
||||||
|
intermediate_dtype = comfy.model_management.intermediate_dtype()
|
||||||
|
|
||||||
|
B, H, W, C = image_target.shape
|
||||||
|
B_ref = image_ref.shape[0]
|
||||||
|
pbar = comfy.utils.ProgressBar(B)
|
||||||
|
out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype)
|
||||||
|
|
||||||
|
if method == 'histogram':
|
||||||
|
uniform_lut = cls._build_histogram_transform(
|
||||||
|
image_target, image_ref, device, stats_mode, target_index, B)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1)
|
||||||
|
src_flat = src.reshape(C, -1)
|
||||||
|
if uniform_lut is not None:
|
||||||
|
lut = uniform_lut
|
||||||
|
else:
|
||||||
|
ri = min(i, B_ref - 1)
|
||||||
|
ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1)
|
||||||
|
lut = cls._histogram_lut(src_flat, ref)
|
||||||
|
bin_idx = (src_flat * 255).long().clamp(0, 255)
|
||||||
|
matched = lut.gather(1, bin_idx).view(C, H, W)
|
||||||
|
result = matched if strength == 1.0 else torch.lerp(src, matched, strength)
|
||||||
|
out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
|
||||||
|
pbar.update(1)
|
||||||
|
else:
|
||||||
|
transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab")
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
src_frame = cls._to_lab(image_target, i, device)
|
||||||
|
corrected = transform(src_frame.view(C, -1), frame_idx=i)
|
||||||
|
if strength == 1.0:
|
||||||
|
result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W))
|
||||||
|
else:
|
||||||
|
result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength))
|
||||||
|
out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
return io.NodeOutput(out)
|
||||||
|
|
||||||
|
|
||||||
class PostProcessingExtension(ComfyExtension):
|
class PostProcessingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension):
|
|||||||
BatchImagesNode,
|
BatchImagesNode,
|
||||||
BatchMasksNode,
|
BatchMasksNode,
|
||||||
BatchLatentsNode,
|
BatchLatentsNode,
|
||||||
|
ColorTransfer,
|
||||||
# BatchImagesMasksLatentsNode,
|
# BatchImagesMasksLatentsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user