mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
227 lines
9.2 KiB
Python
227 lines
9.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
|
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
|
|
|
|
class ZeroSFT(nn.Module):
    """Modulate a feature map ``h`` with a scale and a shift predicted from a control map ``c``.

    ``c`` is projected by a 1x1 convolution and added to ``h``; the sum is
    group-normalised (32 groups) and then transformed as
    ``normed * (1 + gamma) + beta``, where ``gamma`` and ``beta`` come from
    3x3 convolutions applied to a shared Conv+SiLU embedding of ``c``.

    Args:
        label_nc: Channel count of the control map ``c``.
        norm_nc: Channel count of the feature map ``h``.
        concat_channels: Channel count of the optional ``h_ori`` tensor that
            is concatenated in front of ``h`` *before* normalisation. ``0``
            switches that pre-concatenation off.
        dtype, device: Forwarded to every layer built through ``operations``.
        operations: Namespace supplying ``Conv2d`` and ``GroupNorm`` classes.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
        super().__init__()

        kernel = 3
        pad = kernel // 2
        hidden_channels = 128
        modulated_channels = norm_nc + concat_channels
        factory = {"dtype": dtype, "device": device}

        self.param_free_norm = operations.GroupNorm(32, modulated_channels, **factory)

        # Shared embedding of the control map feeding both the scale and shift heads.
        self.mlp_shared = nn.Sequential(
            operations.Conv2d(label_nc, hidden_channels, kernel_size=kernel, padding=pad, **factory),
            nn.SiLU(),
        )
        self.zero_mul = operations.Conv2d(hidden_channels, modulated_channels, kernel_size=kernel, padding=pad, **factory)
        self.zero_add = operations.Conv2d(hidden_channels, modulated_channels, kernel_size=kernel, padding=pad, **factory)

        # 1x1 projection (stride 1, no padding) of the control map onto h's channels.
        self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, **factory)
        self.pre_concat = bool(concat_channels != 0)

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Return the blend between the unmodified and the modulated features.

        Args:
            c: Control feature map.
            h: Feature map to modulate.
            h_ori: Optional tensor concatenated with ``h`` along ``dim=1``;
                before normalisation when ``self.pre_concat`` is set,
                after modulation otherwise.
            control_scale: ``torch.lerp`` weight; ``0`` returns the
                unmodified features, ``1`` the fully modulated ones.
        """
        concat_before = h_ori is not None and self.pre_concat
        concat_after = h_ori is not None and not self.pre_concat

        # Reference the output is interpolated from when control_scale < 1.
        baseline = torch.cat([h_ori, h], dim=1) if concat_before else h

        fused = h + self.zero_conv(c)
        if concat_before:
            fused = torch.cat([h_ori, fused], dim=1)

        shared = self.mlp_shared(c)
        scale = self.zero_mul(shared)
        shift = self.zero_add(shared)

        normed = self.param_free_norm(fused)
        # Equivalent to normed * (1 + scale) + shift.
        modulated = torch.addcmul(normed + shift, normed, scale)

        if concat_after:
            # NOTE(review): here `baseline` is plain `h`, so the two lerp operands
            # differ in channel count -- confirm callers only pass h_ori to
            # instances built with concat_channels != 0.
            modulated = torch.cat([h_ori, modulated], dim=1)

        return torch.lerp(baseline, modulated, control_scale)
|
|
|
|
|
|
class _CrossAttnInner(nn.Module):
    """Cross-attention core laid out to match the state_dict of the original CrossAttention.

    Parameters are registered as ``to_q``, ``to_k``, ``to_v`` (bias-free
    linear projections) and ``to_out.0`` (output projection with bias).
    """

    def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        attn_width = heads * dim_head
        factory = {"dtype": dtype, "device": device}

        self.to_q = operations.Linear(query_dim, attn_width, bias=False, **factory)
        self.to_k = operations.Linear(context_dim, attn_width, bias=False, **factory)
        self.to_v = operations.Linear(context_dim, attn_width, bias=False, **factory)
        # Kept inside a Sequential so the projection is stored under the key "to_out.0".
        self.to_out = nn.Sequential(
            operations.Linear(attn_width, query_dim, **factory),
        )

    def forward(self, x, context):
        """Attend from ``x`` (queries) to ``context`` (keys/values) and project back to ``query_dim``."""
        queries = self.to_q(x)
        keys = self.to_k(context)
        values = self.to_v(context)
        attended = optimized_attention(queries, keys, values, self.heads)
        return self.to_out(attended)
|
|
|
|
|
|
class ZeroCrossAttn(nn.Module):
    """Residual cross-attention from a feature map ``x`` to a ``context`` feature map.

    Both inputs are group-normalised (32 groups), flattened from dim 2
    onwards into token sequences and passed through ``_CrossAttnInner``.
    The attention output is reshaped back to the spatial size of ``x`` and
    added to ``x`` scaled by ``control_scale``.

    Args:
        context_dim: Channel count of ``context``.
        query_dim: Channel count of ``x``; the head count is ``query_dim // 64``.
        dtype, device, operations: Forwarded to the layers that are built.
    """

    def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
        super().__init__()
        head_width = 64
        num_heads = query_dim // head_width
        self.attn = _CrossAttnInner(
            query_dim, context_dim, num_heads, head_width,
            dtype=dtype, device=device, operations=operations,
        )
        self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
        self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)

    def forward(self, context, x, control_scale=1):
        """Return ``x + control_scale * attention(x -> context)``; ``x`` must be 4-D."""
        # The 4-way unpack also enforces that x has exactly four dimensions.
        _, _, height, width = x.shape
        residual = x

        query_tokens = self.norm1(x).flatten(2).transpose(1, 2)
        context_tokens = self.norm2(context).flatten(2).transpose(1, 2)

        attended = self.attn(query_tokens, context_tokens)
        # Tokens back to a channel-first map with x's spatial size.
        attended = attended.transpose(1, 2).unflatten(2, (height, width))

        return residual + attended * control_scale
|
|
|
|
|
|
class GLVControl(nn.Module):
    """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only).

    The module owns a timestep-embedding MLP, a label-embedding MLP, a
    stack of input blocks (ResBlocks, optional SpatialTransformers and
    Downsample layers), a middle block, and a single-convolution hint block.
    ``forward`` returns the output of every input block followed by the
    middle block output.
    """

    def __init__(
        self,
        in_channels=4,
        model_channels=320,
        num_res_blocks=2,
        attention_resolutions=(4, 2),
        channel_mult=(1, 2, 4),
        num_head_channels=64,
        transformer_depth=(1, 2, 10),
        context_dim=2048,
        adm_in_channels=2816,
        use_linear_in_transformer=True,
        use_checkpoint=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.model_channels = model_channels
        emb_channels = model_channels * 4

        factory = {"dtype": dtype, "device": device}
        block_kwargs = {"dtype": dtype, "device": device, "operations": operations}

        def build_transformer(channels, depth):
            # One attention stage; head count follows from the fixed per-head width.
            return SpatialTransformer(
                channels, channels // num_head_channels, num_head_channels,
                depth=depth, context_dim=context_dim,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
                **block_kwargs,
            )

        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, emb_channels, **factory),
            nn.SiLU(),
            operations.Linear(emb_channels, emb_channels, **factory),
        )

        # The extra Sequential nesting yields parameter keys of the form "label_emb.0.<idx>".
        self.label_emb = nn.Sequential(
            nn.Sequential(
                operations.Linear(adm_in_channels, emb_channels, **factory),
                nn.SiLU(),
                operations.Linear(emb_channels, emb_channels, **factory),
            )
        )

        stem = operations.Conv2d(in_channels, model_channels, 3, padding=1, **factory)
        self.input_blocks = nn.ModuleList([TimestepEmbedSequential(stem)])

        ch = model_channels
        ds = 1  # cumulative downsampling factor, compared against attention_resolutions
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                stage = [ResBlock(ch, emb_channels, 0, out_channels=mult * model_channels, **block_kwargs)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    stage.append(build_transformer(ch, transformer_depth[level]))
                self.input_blocks.append(TimestepEmbedSequential(*stage))

            # Every level except the last ends with a downsampling block.
            if level == len(channel_mult) - 1:
                continue
            self.input_blocks.append(
                TimestepEmbedSequential(
                    Downsample(ch, True, out_channels=ch, **block_kwargs)
                )
            )
            ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, emb_channels, 0, **block_kwargs),
            build_transformer(ch, transformer_depth[-1]),
            ResBlock(ch, emb_channels, 0, **block_kwargs),
        )

        self.input_hint_block = TimestepEmbedSequential(
            operations.Conv2d(in_channels, model_channels, 3, padding=1, **factory)
        )

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Run the encoder and collect every intermediate feature map.

        Args:
            x: Input to the hint block; its dtype is also used for the
                timestep embedding.
            timesteps: Passed to ``timestep_embedding``.
            xt: Input to the first input block.
            context: Forwarded to every block.
            y: Input to ``label_emb``. Although it defaults to ``None``, it
                is fed to the label MLP unconditionally.

        Returns:
            A list with one tensor per input block plus the middle block
            output as the final element.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb) + self.label_emb(y)

        pending_hint = self.input_hint_block(x, emb, context)

        features = []
        h = xt
        for block in self.input_blocks:
            h = block(h, emb, context)
            if pending_hint is not None:
                # The hint is injected exactly once, right after the first input block.
                h += pending_hint
                pending_hint = None
            features.append(h)

        h = self.middle_block(h, emb, context)
        features.append(h)
        return features
|
|
|
|
|
|
class SUPIR(nn.Module):
    """
    SUPIR model containing GLVControl (control encoder) and project_modules (adapters).

    State dict keys match the original SUPIR checkpoint layout:
      control_model.*    -> GLVControl
      project_modules.*  -> nn.ModuleList of ZeroSFT/ZeroCrossAttn

    ``project_modules`` is first filled with ten ZeroSFT adapters; two
    ZeroCrossAttn adapters are then inserted, which leaves them at list
    positions 3 and 7 of the final twelve-entry list.
    """

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        shared = {"dtype": dtype, "device": device, "operations": operations}

        self.control_model = GLVControl(**shared)

        project_channel_scale = 2
        cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
        project_channels = [int(width * project_channel_scale) for width in [160] * 4 + [320] * 3 + [640] * 3]
        concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
        cross_attn_insert_idx = [6, 3]

        # One ZeroSFT per (control width, feature width, concat width) triple.
        self.project_modules = nn.ModuleList([
            ZeroSFT(label_nc, norm_nc, concat_channels=extra_nc, **shared)
            for label_nc, norm_nc, extra_nc in zip(project_channels, cond_output_channels, concat_channels)
        ])

        # Inserting at 6 before 3 means the second insertion shifts the first one to index 7.
        for idx in cross_attn_insert_idx:
            adapter = ZeroCrossAttn(cond_output_channels[idx], concat_channels[idx], **shared)
            self.project_modules.insert(idx, adapter)
|