From 7b35ae94404237c119d9b90c46a247b5bc225cc7 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:16:12 +0200 Subject: [PATCH] initial SUPIR support --- .../modules/diffusionmodules/openaimodel.py | 30 ++- comfy/ldm/supir/__init__.py | 0 comfy/ldm/supir/supir_modules.py | 226 ++++++++++++++++++ comfy/ldm/supir/supir_patch.py | 105 ++++++++ comfy/model_patcher.py | 4 + comfy_extras/nodes_model_patch.py | 67 ++++++ 6 files changed, 421 insertions(+), 11 deletions(-) create mode 100644 comfy/ldm/supir/__init__.py create mode 100644 comfy/ldm/supir/supir_modules.py create mode 100644 comfy/ldm/supir/supir_patch.py diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 295310df6..4b92c44cf 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -34,6 +34,16 @@ class TimestepBlock(nn.Module): #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue + if isinstance(layer, VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, TimestepBlock): @@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: - if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: - found_patched = False - for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: - if isinstance(layer, class_type): - x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) - found_patched = True - break - if found_patched: - continue x = layer(x) return x @@ -894,6 +895,12 @@ class UNetModel(nn.Module): h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + if "middle_block_after_patch" in transformer_patches: + patch = transformer_patches["middle_block_after_patch"] + for p in patch: + out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y, + "timesteps": timesteps, "transformer_options": transformer_options}) + h = out["h"] for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) @@ -905,8 +912,9 @@ class UNetModel(nn.Module): for p in patch: h, hsp = p(h, hsp, transformer_options) - h = th.cat([h, hsp], dim=1) - del hsp + if hsp is not None: + h = th.cat([h, hsp], dim=1) + del hsp if len(hs) > 0: output_shape = hs[-1].shape else: diff --git a/comfy/ldm/supir/__init__.py b/comfy/ldm/supir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/supir/supir_modules.py b/comfy/ldm/supir/supir_modules.py new file mode 100644 index 000000000..8fee6ac93 --- /dev/null +++ b/comfy/ldm/supir/supir_modules.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer +from comfy.ldm.modules.attention import optimized_attention + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None): + super().__init__() + + ks = 3 + pw = ks // 2 + + self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device) + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device), + nn.SiLU() + ) + self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + + self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device) + self.pre_concat = bool(concat_channels != 0) + + def forward(self, c, h, h_ori=None, control_scale=1): + if h_ori is not None and self.pre_concat: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + h = h.add_(self.zero_conv(c)) + if h_ori is not None and self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + h = self.param_free_norm(h) + h = torch.addcmul(h + beta, h, gamma) + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return torch.lerp(h_raw, h, control_scale) + + +class _CrossAttnInner(nn.Module): + """Inner cross-attention module matching the state_dict layout of the original CrossAttention.""" + def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), + ) + + def forward(self, x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return self.to_out(optimized_attention(q, k, v, self.heads)) + + +class ZeroCrossAttn(nn.Module): + def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None): + super().__init__() + heads = query_dim // 64 + dim_head = 64 + self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations) + self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device) + self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device) + + def forward(self, context, x, control_scale=1): + b, c, h, w = x.shape + x_in = x + + x = self.attn( + self.norm1(x).flatten(2).transpose(1, 2), + self.norm2(context).flatten(2).transpose(1, 2), + ).transpose(1, 2).unflatten(2, (h, w)) + + return x_in + x * control_scale + + +class GLVControl(nn.Module): + """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only).""" + def __init__( + self, + in_channels=4, + model_channels=320, + num_res_blocks=2, + attention_resolutions=(4, 2), + channel_mult=(1, 2, 4), + num_head_channels=64, + transformer_depth=(1, 2, 10), + context_dim=2048, + adm_in_channels=2816, + use_linear_in_transformer=True, + use_checkpoint=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.model_channels = model_channels + time_embed_dim = model_channels * 4 + + self.time_embed = nn.Sequential( + operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + + self.label_emb = nn.Sequential( + nn.Sequential( + operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + ) + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + ]) + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels, + dtype=dtype, device=device, operations=operations) + ] + ch = mult * model_channels + if ds in attention_resolutions: + num_heads = ch // num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[level], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential( + Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations) + ) + ) + ds *= 2 + + num_heads = ch // num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[-1], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations), + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + ) + + self.input_hint_block = TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x, timesteps, xt, context=None, y=None, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + self.label_emb(y) + + guided_hint = self.input_hint_block(x, emb, context) + + hs = [] + h = xt + for module in self.input_blocks: + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + hs.append(h) + return hs + + +class SUPIR(nn.Module): + """ + SUPIR model containing GLVControl (control encoder) and project_modules (adapters). + State dict keys match the original SUPIR checkpoint layout: + control_model.* -> GLVControl + project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn + """ + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + + self.control_model = GLVControl(dtype=dtype, device=device, operations=operations) + + project_channel_scale = 2 + cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3 + project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3] + concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0] + cross_attn_insert_idx = [6, 3] + + self.project_modules = nn.ModuleList() + for i in range(len(cond_output_channels)): + self.project_modules.append(ZeroSFT( + project_channels[i], cond_output_channels[i], + concat_channels=concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) + + for i in cross_attn_insert_idx: + self.project_modules.insert(i, ZeroCrossAttn( + cond_output_channels[i], concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) diff --git a/comfy/ldm/supir/supir_patch.py b/comfy/ldm/supir/supir_patch.py new file mode 100644 index 000000000..8c210971a --- /dev/null +++ b/comfy/ldm/supir/supir_patch.py @@ -0,0 +1,105 @@ +import torch +import comfy.model_management +from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample + + +class SUPIRPatch: + """ + Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters). + Runs GLVControl lazily on first patch invocation per step, applies adapters through + middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch. + """ + SIGMA_MAX = 14.6146 + + def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end): + self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl + self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn + self.hint_latent = hint_latent # encoded LQ image latent + self.strength_start = strength_start + self.strength_end = strength_end + self.cached_features = None + self.adapter_idx = 0 + self.control_idx = 0 + self.current_control_idx = 0 + self.active = True + + def _ensure_features(self, kwargs): + """Run GLVControl on first call per step, cache results.""" + if self.cached_features is not None: + return + x = kwargs["x"] + batch_size = x.shape[0] + comfy.model_management.load_models_gpu([self.model_patch]) + hint = self.hint_latent.to(device=x.device, dtype=x.dtype) + if hint.shape[0] < batch_size: + hint = hint.repeat(batch_size // hint.shape[0], 1, 1, 1)[:batch_size] + self.cached_features = self.model_patch.model.control_model( + hint, kwargs["timesteps"], x, + kwargs["context"], kwargs["y"] + ) + self.adapter_idx = len(self.project_modules) - 1 + self.control_idx = len(self.cached_features) - 1 + + def _get_control_scale(self, kwargs): + if self.strength_start == self.strength_end: + return self.strength_end + sigma = kwargs["transformer_options"].get("sigmas") + if sigma is None: + return self.strength_end + s = sigma[0].item() if sigma.dim() > 0 else sigma.item() + t = min(s / self.SIGMA_MAX, 1.0) + return t * (self.strength_start - self.strength_end) + self.strength_end + + def middle_after(self, kwargs): + """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block.""" + self.cached_features = None # reset from previous step + self.current_scale = self._get_control_scale(kwargs) + self.active = self.current_scale > 0 + if not self.active: + return {"h": kwargs["h"]} + self._ensure_features(kwargs) + h = kwargs["h"] + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return {"h": h} + + def output_block(self, h, hsp, transformer_options): + """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat.""" + if not self.active: + return h, hsp + self.current_control_idx = self.control_idx + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return h, None + + def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw): + """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample.""" + block_type, _ = transformer_options["block"] + if block_type == "output" and self.active and self.cached_features is not None: + x = self.project_modules[self.adapter_idx]( + self.cached_features[self.current_control_idx], x, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + return layer(x, output_shape=output_shape) + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.cached_features = None + if self.hint_latent is not None: + self.hint_latent = self.hint_latent.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + def register(self, model_patcher): + """Register all patches on a cloned model patcher.""" + model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch") + model_patcher.set_model_output_block_patch(self.output_block) + model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c26d37db2..8159c9d74 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -509,6 +509,10 @@ class ModelPatcher: def set_model_noise_refiner_patch(self, patch): self.set_model_patch(patch, "noise_refiner") + def set_model_middle_block_after_patch(self, patch): + self.set_model_patch(patch, "middle_block_after_patch") + + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 176e6bc2f..f1ba2ed18 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,7 +7,10 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +import comfy.ldm.supir.supir_modules from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.supir.supir_patch import SUPIRPatch class BlockWiseControlBlock(torch.nn.Module): @@ -266,6 +269,13 @@ class ModelPatchLoader: out_dim=sd["audio_proj.norm.weight"].shape[0], device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) + elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: + prefix_replace = {} + if 'model.control_model.input_hint_block.0.weight' in sd: + prefix_replace["model.control_model."] = "control_model." + prefix_replace["model.diffusion_model.project_modules."] = "project_modules." + sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True) + model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model.load_state_dict(sd, assign=model_patcher.is_dynamic()) @@ -565,9 +575,66 @@ class MultiTalkModelPatch(torch.nn.Module): ) +class SUPIRApply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SUPIRApply", + category="model_patches/supir", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Latent.Input("latent"), + io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the start of sampling (high sigma)."), + io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."), + io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True, + tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."), + io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Sigma threshold below which restore_cfg is disabled."), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type, + strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput: + model_patched = model.clone() + hint_latent = model.get_model_object("latent_format").process_in(latent["samples"]) + patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end) + patch.register(model_patched) + + if restore_cfg > 0.0: + x_center = hint_latent.clone() + sigma_max = 14.6146 + + def restore_cfg_function(args): + denoised = args["denoised"] + sigma = args["sigma"] + if sigma.dim() > 0: + s = sigma[0].item() + else: + s = sigma.item() + if s > restore_cfg_s_tmin: + ref = x_center.to(device=denoised.device, dtype=denoised.dtype) + if ref.shape[0] < denoised.shape[0]: + ref = ref.repeat(denoised.shape[0] // ref.shape[0], 1, 1, 1)[:denoised.shape[0]] + sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma + d_center = denoised - ref + denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg) + return denoised + + model_patched.set_model_sampler_post_cfg_function(restore_cfg_function) + + return io.NodeOutput(model_patched) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, + "SUPIRApply": SUPIRApply, }