mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 21:12:30 +08:00
initial SUPIR support
This commit is contained in:
parent
2d4970ff67
commit
7b35ae9440
@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
|
||||
found_patched = False
|
||||
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
|
||||
if isinstance(layer, class_type):
|
||||
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
|
||||
found_patched = True
|
||||
break
|
||||
if found_patched:
|
||||
continue
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
if "middle_block_after_patch" in transformer_patches:
|
||||
patch = transformer_patches["middle_block_after_patch"]
|
||||
for p in patch:
|
||||
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
|
||||
"timesteps": timesteps, "transformer_options": transformer_options})
|
||||
h = out["h"]
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
|
||||
for p in patch:
|
||||
h, hsp = p(h, hsp, transformer_options)
|
||||
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if hsp is not None:
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if len(hs) > 0:
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
|
||||
0
comfy/ldm/supir/__init__.py
Normal file
0
comfy/ldm/supir/__init__.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
226
comfy/ldm/supir/supir_modules.py
Normal file
@ -0,0 +1,226 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
class ZeroSFT(nn.Module):
|
||||
def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
ks = 3
|
||||
pw = ks // 2
|
||||
|
||||
self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)
|
||||
|
||||
nhidden = 128
|
||||
|
||||
self.mlp_shared = nn.Sequential(
|
||||
operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||
self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
|
||||
|
||||
self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
|
||||
self.pre_concat = bool(concat_channels != 0)
|
||||
|
||||
def forward(self, c, h, h_ori=None, control_scale=1):
|
||||
if h_ori is not None and self.pre_concat:
|
||||
h_raw = torch.cat([h_ori, h], dim=1)
|
||||
else:
|
||||
h_raw = h
|
||||
|
||||
h = h.add_(self.zero_conv(c))
|
||||
if h_ori is not None and self.pre_concat:
|
||||
h = torch.cat([h_ori, h], dim=1)
|
||||
actv = self.mlp_shared(c)
|
||||
gamma = self.zero_mul(actv)
|
||||
beta = self.zero_add(actv)
|
||||
h = self.param_free_norm(h)
|
||||
h = torch.addcmul(h + beta, h, gamma)
|
||||
if h_ori is not None and not self.pre_concat:
|
||||
h = torch.cat([h_ori, h], dim=1)
|
||||
return torch.lerp(h_raw, h, control_scale)
|
||||
|
||||
|
||||
class _CrossAttnInner(nn.Module):
|
||||
"""Inner cross-attention module matching the state_dict layout of the original CrossAttention."""
|
||||
def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, context):
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
return self.to_out(optimized_attention(q, k, v, self.heads))
|
||||
|
||||
|
||||
class ZeroCrossAttn(nn.Module):
|
||||
def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
heads = query_dim // 64
|
||||
dim_head = 64
|
||||
self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations)
|
||||
self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
|
||||
self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, context, x, control_scale=1):
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
|
||||
x = self.attn(
|
||||
self.norm1(x).flatten(2).transpose(1, 2),
|
||||
self.norm2(context).flatten(2).transpose(1, 2),
|
||||
).transpose(1, 2).unflatten(2, (h, w))
|
||||
|
||||
return x_in + x * control_scale
|
||||
|
||||
|
||||
class GLVControl(nn.Module):
|
||||
"""SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
model_channels=320,
|
||||
num_res_blocks=2,
|
||||
attention_resolutions=(4, 2),
|
||||
channel_mult=(1, 2, 4),
|
||||
num_head_channels=64,
|
||||
transformer_depth=(1, 2, 10),
|
||||
context_dim=2048,
|
||||
adm_in_channels=2816,
|
||||
use_linear_in_transformer=True,
|
||||
use_checkpoint=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
time_embed_dim = model_channels * 4
|
||||
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
])
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_heads = ch // num_head_channels
|
||||
layers.append(
|
||||
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||
depth=transformer_depth[level], context_dim=context_dim,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
if level != len(channel_mult) - 1:
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
|
||||
)
|
||||
)
|
||||
ds *= 2
|
||||
|
||||
num_heads = ch // num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||
SpatialTransformer(ch, num_heads, num_head_channels,
|
||||
depth=transformer_depth[-1], context_dim=context_dim,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dtype=dtype, device=device, operations=operations),
|
||||
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
|
||||
)
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb) + self.label_emb(y)
|
||||
|
||||
guided_hint = self.input_hint_block(x, emb, context)
|
||||
|
||||
hs = []
|
||||
h = xt
|
||||
for module in self.input_blocks:
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
hs.append(h)
|
||||
return hs
|
||||
|
||||
|
||||
class SUPIR(nn.Module):
|
||||
"""
|
||||
SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
|
||||
State dict keys match the original SUPIR checkpoint layout:
|
||||
control_model.* -> GLVControl
|
||||
project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||
"""
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
project_channel_scale = 2
|
||||
cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
|
||||
project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
|
||||
concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
|
||||
cross_attn_insert_idx = [6, 3]
|
||||
|
||||
self.project_modules = nn.ModuleList()
|
||||
for i in range(len(cond_output_channels)):
|
||||
self.project_modules.append(ZeroSFT(
|
||||
project_channels[i], cond_output_channels[i],
|
||||
concat_channels=concat_channels[i],
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
))
|
||||
|
||||
for i in cross_attn_insert_idx:
|
||||
self.project_modules.insert(i, ZeroCrossAttn(
|
||||
cond_output_channels[i], concat_channels[i],
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
))
|
||||
105
comfy/ldm/supir/supir_patch.py
Normal file
105
comfy/ldm/supir/supir_patch.py
Normal file
@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
|
||||
|
||||
|
||||
class SUPIRPatch:
|
||||
"""
|
||||
Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
|
||||
Runs GLVControl lazily on first patch invocation per step, applies adapters through
|
||||
middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
|
||||
"""
|
||||
SIGMA_MAX = 14.6146
|
||||
|
||||
def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
|
||||
self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl
|
||||
self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn
|
||||
self.hint_latent = hint_latent # encoded LQ image latent
|
||||
self.strength_start = strength_start
|
||||
self.strength_end = strength_end
|
||||
self.cached_features = None
|
||||
self.adapter_idx = 0
|
||||
self.control_idx = 0
|
||||
self.current_control_idx = 0
|
||||
self.active = True
|
||||
|
||||
def _ensure_features(self, kwargs):
|
||||
"""Run GLVControl on first call per step, cache results."""
|
||||
if self.cached_features is not None:
|
||||
return
|
||||
x = kwargs["x"]
|
||||
batch_size = x.shape[0]
|
||||
comfy.model_management.load_models_gpu([self.model_patch])
|
||||
hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
|
||||
if hint.shape[0] < batch_size:
|
||||
hint = hint.repeat(batch_size // hint.shape[0], 1, 1, 1)[:batch_size]
|
||||
self.cached_features = self.model_patch.model.control_model(
|
||||
hint, kwargs["timesteps"], x,
|
||||
kwargs["context"], kwargs["y"]
|
||||
)
|
||||
self.adapter_idx = len(self.project_modules) - 1
|
||||
self.control_idx = len(self.cached_features) - 1
|
||||
|
||||
def _get_control_scale(self, kwargs):
|
||||
if self.strength_start == self.strength_end:
|
||||
return self.strength_end
|
||||
sigma = kwargs["transformer_options"].get("sigmas")
|
||||
if sigma is None:
|
||||
return self.strength_end
|
||||
s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
|
||||
t = min(s / self.SIGMA_MAX, 1.0)
|
||||
return t * (self.strength_start - self.strength_end) + self.strength_end
|
||||
|
||||
def middle_after(self, kwargs):
|
||||
"""middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
|
||||
self.cached_features = None # reset from previous step
|
||||
self.current_scale = self._get_control_scale(kwargs)
|
||||
self.active = self.current_scale > 0
|
||||
if not self.active:
|
||||
return {"h": kwargs["h"]}
|
||||
self._ensure_features(kwargs)
|
||||
h = kwargs["h"]
|
||||
h = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.control_idx], h, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
self.control_idx -= 1
|
||||
return {"h": h}
|
||||
|
||||
def output_block(self, h, hsp, transformer_options):
|
||||
"""output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
|
||||
if not self.active:
|
||||
return h, hsp
|
||||
self.current_control_idx = self.control_idx
|
||||
h = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
self.control_idx -= 1
|
||||
return h, None
|
||||
|
||||
def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
|
||||
"""forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
|
||||
block_type, _ = transformer_options["block"]
|
||||
if block_type == "output" and self.active and self.cached_features is not None:
|
||||
x = self.project_modules[self.adapter_idx](
|
||||
self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
|
||||
)
|
||||
self.adapter_idx -= 1
|
||||
return layer(x, output_shape=output_shape)
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
self.cached_features = None
|
||||
if self.hint_latent is not None:
|
||||
self.hint_latent = self.hint_latent.to(device_or_dtype)
|
||||
return self
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
def register(self, model_patcher):
|
||||
"""Register all patches on a cloned model patcher."""
|
||||
model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
|
||||
model_patcher.set_model_output_block_patch(self.output_block)
|
||||
model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")
|
||||
@ -509,6 +509,10 @@ class ModelPatcher:
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_middle_block_after_patch(self, patch):
|
||||
self.set_model_patch(patch, "middle_block_after_patch")
|
||||
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
|
||||
@ -7,7 +7,10 @@ import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.latent_formats
|
||||
import comfy.ldm.lumina.controlnet
|
||||
import comfy.ldm.supir.supir_modules
|
||||
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.supir.supir_patch import SUPIRPatch
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@ -266,6 +269,13 @@ class ModelPatchLoader:
|
||||
out_dim=sd["audio_proj.norm.weight"].shape[0],
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
operations=comfy.ops.manual_cast)
|
||||
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace = {}
|
||||
if 'model.control_model.input_hint_block.0.weight' in sd:
|
||||
prefix_replace["model.control_model."] = "control_model."
|
||||
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||
@ -565,9 +575,66 @@ class MultiTalkModelPatch(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class SUPIRApply(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SUPIRApply",
|
||||
category="model_patches/supir",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.ModelPatch.Input("model_patch"),
|
||||
io.Latent.Input("latent"),
|
||||
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the start of sampling (high sigma)."),
|
||||
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||
tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."),
|
||||
io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True,
|
||||
tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."),
|
||||
io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True,
|
||||
tooltip="Sigma threshold below which restore_cfg is disabled."),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type,
|
||||
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
||||
model_patched = model.clone()
|
||||
hint_latent = model.get_model_object("latent_format").process_in(latent["samples"])
|
||||
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
||||
patch.register(model_patched)
|
||||
|
||||
if restore_cfg > 0.0:
|
||||
x_center = hint_latent.clone()
|
||||
sigma_max = 14.6146
|
||||
|
||||
def restore_cfg_function(args):
|
||||
denoised = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
if sigma.dim() > 0:
|
||||
s = sigma[0].item()
|
||||
else:
|
||||
s = sigma.item()
|
||||
if s > restore_cfg_s_tmin:
|
||||
ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
|
||||
if ref.shape[0] < denoised.shape[0]:
|
||||
ref = ref.repeat(denoised.shape[0] // ref.shape[0], 1, 1, 1)[:denoised.shape[0]]
|
||||
sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
|
||||
d_center = denoised - ref
|
||||
denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)
|
||||
return denoised
|
||||
|
||||
model_patched.set_model_sampler_post_cfg_function(restore_cfg_function)
|
||||
|
||||
return io.NodeOutput(model_patched)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelPatchLoader": ModelPatchLoader,
|
||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||
"ZImageFunControlnet": ZImageFunControlnet,
|
||||
"USOStyleReference": USOStyleReference,
|
||||
"SUPIRApply": SUPIRApply,
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user