mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 18:02:37 +08:00
initial SUPIR support
This commit is contained in:
parent
2d4970ff67
commit
7b35ae9440
@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
|||||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||||
for layer in ts:
|
for layer in ts:
|
||||||
|
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||||
|
found_patched = False
|
||||||
|
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||||
|
if isinstance(layer, class_type):
|
||||||
|
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||||
|
found_patched = True
|
||||||
|
break
|
||||||
|
if found_patched:
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(layer, VideoResBlock):
|
if isinstance(layer, VideoResBlock):
|
||||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||||
elif isinstance(layer, TimestepBlock):
|
elif isinstance(layer, TimestepBlock):
|
||||||
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
|||||||
elif isinstance(layer, Upsample):
|
elif isinstance(layer, Upsample):
|
||||||
x = layer(x, output_shape=output_shape)
|
x = layer(x, output_shape=output_shape)
|
||||||
else:
|
else:
|
||||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
|
||||||
found_patched = False
|
|
||||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
|
||||||
if isinstance(layer, class_type):
|
|
||||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
|
||||||
found_patched = True
|
|
||||||
break
|
|
||||||
if found_patched:
|
|
||||||
continue
|
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
|||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||||
h = apply_control(h, control, 'middle')
|
h = apply_control(h, control, 'middle')
|
||||||
|
|
||||||
|
if "middle_block_after_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["middle_block_after_patch"]
|
||||||
|
for p in patch:
|
||||||
|
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||||
|
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||||
|
h = out["h"]
|
||||||
|
|
||||||
for id, module in enumerate(self.output_blocks):
|
for id, module in enumerate(self.output_blocks):
|
||||||
transformer_options["block"] = ("output", id)
|
transformer_options["block"] = ("output", id)
|
||||||
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
h, hsp = p(h, hsp, transformer_options)
|
h, hsp = p(h, hsp, transformer_options)
|
||||||
|
|
||||||
h = th.cat([h, hsp], dim=1)
|
if hsp is not None:
|
||||||
del hsp
|
h = th.cat([h, hsp], dim=1)
|
||||||
|
del hsp
|
||||||
if len(hs) > 0:
|
if len(hs) > 0:
|
||||||
output_shape = hs[-1].shape
|
output_shape = hs[-1].shape
|
||||||
else:
|
else:
|
||||||
|
|||||||
0
comfy/ldm/supir/__init__.py
Normal file
0
comfy/ldm/supir/__init__.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
@ -0,0 +1,226 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroSFT(nn.Module):
    """Zero-initialized SFT (spatial feature transform) adapter used by SUPIR.

    Normalizes a UNet feature map and modulates it per-pixel with a
    gamma/beta pair predicted from the control feature ``c``, then blends
    the modulated result with the untouched input by ``control_scale``.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()

        kernel = 3
        pad = kernel // 2

        # Parameter-free normalization over the (possibly concatenated) feature map.
        self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)

        hidden = 128

        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, hidden, kernel_size=kernel, padding=pad, dtype=dtype, device=device),
            nn.SiLU(),
        )
        # "zero_*" convs are expected to start at zero in the checkpoint so the
        # adapter is initially an identity perturbation.
        self.zero_mul = operations.Conv2d(hidden, norm_nc + concat_channels, kernel_size=kernel, padding=pad, dtype=dtype, device=device)
        self.zero_add = operations.Conv2d(hidden, norm_nc + concat_channels, kernel_size=kernel, padding=pad, dtype=dtype, device=device)

        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
        self.pre_concat = concat_channels != 0

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Apply the SFT modulation of ``h`` by control ``c``.

        c: control feature map. h: UNet feature map (modified in place by the
        zero_conv add). h_ori: optional skip feature concatenated before or
        after modulation depending on ``pre_concat``.
        """
        concat_first = h_ori is not None and self.pre_concat
        # NOTE: when not concatenating first, h_raw aliases h, so the in-place
        # add below is intentionally reflected in the lerp baseline.
        h_raw = torch.cat([h_ori, h], dim=1) if concat_first else h

        h = h.add_(self.zero_conv(c))
        if concat_first:
            h = torch.cat([h_ori, h], dim=1)

        shared = self.mlp_shared(c)
        gamma = self.zero_mul(shared)
        beta = self.zero_add(shared)
        h = self.param_free_norm(h)
        # Equivalent to h * (1 + gamma) + beta.
        h = torch.addcmul(h + beta, h, gamma)

        if h_ori is not None and not self.pre_concat:
            h = torch.cat([h_ori, h], dim=1)
        return torch.lerp(h_raw, h, control_scale)
|
||||||
|
|
||||||
|
|
||||||
|
class _CrossAttnInner(nn.Module):
    """Inner cross-attention module matching the state_dict layout of the original CrossAttention."""

    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        # Single-element Sequential keeps the original "to_out.0.*" key prefix.
        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
        )

    def forward(self, x, context):
        """Cross-attend queries from ``x`` to keys/values from ``context``."""
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        out = optimized_attention(q, k, v, self.heads)
        return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroCrossAttn(nn.Module):
    """Cross-attention adapter: feature map ``x`` attends to control ``context``.

    Residual output scaled by ``control_scale``; head count is derived from
    query_dim with a fixed head size of 64.
    """

    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        dim_head = 64
        self.attn = _CrossAttnInner(query_dim, context_dim, query_dim // dim_head, dim_head,
                                    dtype=dtype, device=device, operations=operations)
        self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
        self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)

    def forward(self, context, x, control_scale=1):
        b, c, height, width = x.shape
        residual = x

        # Flatten NCHW maps to (batch, seq, channels) token sequences for attention.
        query_seq = self.norm1(x).flatten(2).transpose(1, 2)
        context_seq = self.norm2(context).flatten(2).transpose(1, 2)
        attended = self.attn(query_seq, context_seq)
        attended = attended.transpose(1, 2).unflatten(2, (height, width))

        return residual + attended * control_scale
|
||||||
|
|
||||||
|
|
||||||
|
class GLVControl(nn.Module):
    """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""

    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        num_head_channels=64,
        transformer_depth=(1, 2, 10),
        context_dim=2048,
        adm_in_channels=2816,
        use_linear_in_transformer=True,
        use_checkpoint=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.model_channels = model_channels
        emb_dim = model_channels * 4

        def _embed_mlp(in_dim):
            # Linear -> SiLU -> Linear; indices (0, 2) match checkpoint keys.
            return nn.Sequential(
                operations.Linear(in_dim, emb_dim, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(emb_dim, emb_dim, dtype=dtype, device=device),
            )

        self.time_embed = _embed_mlp(model_channels)
        # Extra Sequential nesting preserves the "label_emb.0.*" key prefix.
        self.label_emb = nn.Sequential(_embed_mlp(adm_in_channels))

        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
            )
        ])
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            out_ch = mult * model_channels
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(ch, emb_dim, 0, out_channels=out_ch,
                             dtype=dtype, device=device, operations=operations)
                ]
                ch = out_ch
                if ds in attention_resolutions:
                    layers.append(
                        SpatialTransformer(ch, ch // num_head_channels, num_head_channels,
                                           depth=transformer_depth[level], context_dim=context_dim,
                                           use_linear=use_linear_in_transformer,
                                           use_checkpoint=use_checkpoint,
                                           dtype=dtype, device=device, operations=operations)
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
            # Downsample between levels (not after the deepest one).
            if level + 1 < len(channel_mult):
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
                    )
                )
                ds *= 2

        heads = ch // num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, emb_dim, 0, dtype=dtype, device=device, operations=operations),
            SpatialTransformer(ch, heads, num_head_channels,
                               depth=transformer_depth[-1], context_dim=context_dim,
                               use_linear=use_linear_in_transformer,
                               use_checkpoint=use_checkpoint,
                               dtype=dtype, device=device, operations=operations),
            ResBlock(ch, emb_dim, 0, dtype=dtype, device=device, operations=operations),
        )

        # Projects the LQ hint latent into model_channels space.
        self.input_hint_block = TimestepEmbedSequential(
            operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
        )

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Encode control features.

        x: hint (LQ image) latent; xt: noisy latent being denoised.
        Returns the list of per-input-block features plus the middle-block
        feature, in forward order.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb) + self.label_emb(y)

        hint = self.input_hint_block(x, emb, context)

        features = []
        h = xt
        for block in self.input_blocks:
            h = block(h, emb, context)
            if hint is not None:
                # Hint is injected only after the very first block.
                h += hint
                hint = None
            features.append(h)
        features.append(self.middle_block(h, emb, context))
        return features
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIR(nn.Module):
    """
    SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
    State dict keys match the original SUPIR checkpoint layout:
        control_model.*    -> GLVControl
        project_modules.*  -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
    """

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()

        self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)

        channel_scale = 2
        cond_channels = [320] * 4 + [640] * 3 + [1280] * 3
        proj_channels = [int(c * channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]

        self.project_modules = nn.ModuleList(
            ZeroSFT(proj, cond, concat_channels=concat,
                    dtype=dtype, device=device, operations=operations)
            for proj, cond, concat in zip(proj_channels, cond_channels, concat_channels)
        )

        # Insert cross-attention adapters; descending order (6 then 3) keeps
        # the earlier index valid after the first insertion.
        for idx in (6, 3):
            self.project_modules.insert(idx, ZeroCrossAttn(
                cond_channels[idx], concat_channels[idx],
                dtype=dtype, device=device, operations=operations,
            ))
|
||||||
105
comfy/ldm/supir/supir_patch.py
Normal file
105
comfy/ldm/supir/supir_patch.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRPatch:
    """
    Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
    Runs GLVControl lazily on first patch invocation per step, applies adapters through
    middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
    """
    SIGMA_MAX = 14.6146  # assumed max sigma of the SDXL-style schedule SUPIR targets

    def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
        self.model_patch = model_patch          # CoreModelPatcher wrapping GLVControl
        self.project_modules = project_modules  # nn.ModuleList of ZeroSFT/ZeroCrossAttn
        self.hint_latent = hint_latent          # encoded LQ image latent
        self.strength_start = strength_start
        self.strength_end = strength_end
        self.cached_features = None             # per-forward GLVControl outputs
        self.adapter_idx = 0                    # walks project_modules back-to-front
        self.control_idx = 0                    # walks cached_features back-to-front
        self.current_control_idx = 0
        # Fix: current_scale used to be assigned only inside middle_after(); if
        # output_block()/pre_upsample() ever ran first it raised AttributeError.
        # Initialize to the end strength as a safe default.
        self.current_scale = strength_end
        self.active = True

    def _ensure_features(self, kwargs):
        """Run GLVControl on first call per step, cache results."""
        if self.cached_features is not None:
            return
        x = kwargs["x"]
        batch_size = x.shape[0]
        comfy.model_management.load_models_gpu([self.model_patch])
        hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
        if hint.shape[0] < batch_size:
            # Tile the hint across the batch (e.g. cond + uncond).
            hint = hint.repeat(batch_size // hint.shape[0], 1, 1, 1)[:batch_size]
        self.cached_features = self.model_patch.model.control_model(
            hint, kwargs["timesteps"], x,
            kwargs["context"], kwargs["y"]
        )
        # Adapters/features are consumed from the deepest block outward.
        self.adapter_idx = len(self.project_modules) - 1
        self.control_idx = len(self.cached_features) - 1

    def _get_control_scale(self, kwargs):
        """Linearly interpolate strength between start (high sigma) and end (low sigma)."""
        if self.strength_start == self.strength_end:
            return self.strength_end
        sigma = kwargs["transformer_options"].get("sigmas")
        if sigma is None:
            return self.strength_end
        s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
        t = min(s / self.SIGMA_MAX, 1.0)
        return t * (self.strength_start - self.strength_end) + self.strength_end

    def middle_after(self, kwargs):
        """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
        self.cached_features = None  # reset from previous step
        self.current_scale = self._get_control_scale(kwargs)
        self.active = self.current_scale > 0
        if not self.active:
            return {"h": kwargs["h"]}
        self._ensure_features(kwargs)
        h = kwargs["h"]
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return {"h": h}

    def output_block(self, h, hsp, transformer_options):
        """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
        if not self.active:
            return h, hsp
        # Remember this feature index for a possible pre_upsample adapter.
        self.current_control_idx = self.control_idx
        h = self.project_modules[self.adapter_idx](
            self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
        )
        self.adapter_idx -= 1
        self.control_idx -= 1
        return h, None

    def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
        """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
        block_type, _ = transformer_options["block"]
        if block_type == "output" and self.active and self.cached_features is not None:
            x = self.project_modules[self.adapter_idx](
                self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
            )
            self.adapter_idx -= 1
        return layer(x, output_shape=output_shape)

    def to(self, device_or_dtype):
        """Device moves invalidate the feature cache and move the hint latent."""
        if isinstance(device_or_dtype, torch.device):
            self.cached_features = None
            if self.hint_latent is not None:
                self.hint_latent = self.hint_latent.to(device_or_dtype)
        return self

    def models(self):
        """Models that must be loaded for this patch to run."""
        return [self.model_patch]

    def register(self, model_patcher):
        """Register all patches on a cloned model patcher."""
        model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
        model_patcher.set_model_output_block_patch(self.output_block)
        model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
|
||||||
@ -509,6 +509,10 @@ class ModelPatcher:
|
|||||||
def set_model_noise_refiner_patch(self, patch):
|
def set_model_noise_refiner_patch(self, patch):
|
||||||
self.set_model_patch(patch, "noise_refiner")
|
self.set_model_patch(patch, "noise_refiner")
|
||||||
|
|
||||||
|
def set_model_middle_block_after_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "middle_block_after_patch")
|
||||||
|
|
||||||
|
|
||||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||||
rope_options["scale_x"] = scale_x
|
rope_options["scale_x"] = scale_x
|
||||||
|
|||||||
@ -7,7 +7,10 @@ import comfy.model_management
|
|||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import comfy.ldm.lumina.controlnet
|
import comfy.ldm.lumina.controlnet
|
||||||
|
import comfy.ldm.supir.supir_modules
|
||||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||||
|
|
||||||
|
|
||||||
class BlockWiseControlBlock(torch.nn.Module):
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
@ -266,6 +269,13 @@ class ModelPatchLoader:
|
|||||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||||
device=comfy.model_management.unet_offload_device(),
|
device=comfy.model_management.unet_offload_device(),
|
||||||
operations=comfy.ops.manual_cast)
|
operations=comfy.ops.manual_cast)
|
||||||
|
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace = {}
|
||||||
|
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||||
|
prefix_replace["model.control_model."] = "control_model."
|
||||||
|
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||||
|
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||||
|
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||||
@ -565,9 +575,66 @@ class MultiTalkModelPatch(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SUPIRApply(io.ComfyNode):
    """Node that attaches a SUPIR control patch (and optional restore-CFG) to a model."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="SUPIRApply",
            category="model_patches/supir",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.ModelPatch.Input("model_patch"),
                io.Latent.Input("latent"),
                io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the start of sampling (high sigma)."),
                io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
                               tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
                io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
                               tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
                io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
                               tooltip="Sigma threshold below which restore_cfg is disabled."),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type,
                strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
        model_patched = model.clone()
        # Encode the LQ latent into the model's internal latent scale.
        hint_latent = model.get_model_object("latent_format").process_in(latent["samples"])
        supir_patch = SUPIRPatch(model_patch, model_patch.model.project_modules,
                                 hint_latent, strength_start, strength_end)
        supir_patch.register(model_patched)

        if restore_cfg > 0.0:
            x_center = hint_latent.clone()
            sigma_max = 14.6146

            def restore_cfg_function(args):
                # Nudge denoised output toward the input latent, strongest at high sigma.
                denoised = args["denoised"]
                sigma = args["sigma"]
                s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
                if s > restore_cfg_s_tmin:
                    ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
                    if ref.shape[0] < denoised.shape[0]:
                        ref = ref.repeat(denoised.shape[0] // ref.shape[0], 1, 1, 1)[:denoised.shape[0]]
                    sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
                    denoised = denoised - (denoised - ref) * ((sigma_val / sigma_max) ** restore_cfg)
                return denoised

            model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)

        return io.NodeOutput(model_patched)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelPatchLoader": ModelPatchLoader,
|
"ModelPatchLoader": ModelPatchLoader,
|
||||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||||
"ZImageFunControlnet": ZImageFunControlnet,
|
"ZImageFunControlnet": ZImageFunControlnet,
|
||||||
"USOStyleReference": USOStyleReference,
|
"USOStyleReference": USOStyleReference,
|
||||||
|
"SUPIRApply": SUPIRApply,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user