mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Merge with master
This commit is contained in:
commit
7520691021
@ -12,7 +12,7 @@ This UI will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
|
||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||
|
||||
@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
x = self.embeddings(input_tokens)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
|
||||
@ -319,9 +319,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
unet_dtype = model_management.unet_dtype()
|
||||
controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
||||
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@ -381,12 +382,20 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = model_management.unet_dtype()
|
||||
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
model_config = model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ops.manual_cast
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
@ -4,7 +4,8 @@ import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -24,7 +25,7 @@ def default(val, d):
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
@ -37,14 +38,14 @@ class FeedForward(nn.Module):
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -59,11 +60,12 @@ class GatedCrossAttentionDense(nn.Module):
|
||||
query_dim=query_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -89,17 +91,18 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim,
|
||||
context_dim=query_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -128,14 +131,14 @@ class GatedSelfAttentionDense2(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -203,11 +206,11 @@ class PositionNet(nn.Module):
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||
|
||||
self.linears = nn.Sequential(
|
||||
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||
ops.Linear(self.in_dim + self.position_dim, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, 512),
|
||||
ops.Linear(512, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, out_dim),
|
||||
ops.Linear(512, out_dim),
|
||||
)
|
||||
|
||||
self.null_positive_feature = torch.nn.Parameter(
|
||||
@ -217,16 +220,15 @@ class PositionNet(nn.Module):
|
||||
|
||||
def forward(self, boxes, masks, positive_embeddings):
|
||||
B, N, _ = boxes.shape
|
||||
dtype = self.linears[0].weight.dtype
|
||||
masks = masks.unsqueeze(-1).to(dtype)
|
||||
positive_embeddings = positive_embeddings.to(dtype)
|
||||
masks = masks.unsqueeze(-1)
|
||||
positive_embeddings = positive_embeddings
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
|
||||
# replace padding with learnable null embedding
|
||||
positive_embeddings = positive_embeddings * \
|
||||
@ -253,7 +255,7 @@ class Gligen(nn.Module):
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
|
||||
@ -37,3 +37,11 @@ class SDXL(LatentFormat):
|
||||
class SD_X4(LatentFormat):
    """Latent format whose ``scale_factor`` is fixed at 0.08333."""

    def __init__(self):
        # NOTE(review): the base-class initialiser is not invoked here; the
        # base class is outside this view, so confirm that is intentional.
        self.scale_factor = 0.08333
|
||||
|
||||
class SC_Prior(LatentFormat):
    """Latent format with a unit ``scale_factor``.

    Presumably used for the Stable Cascade prior latents -- TODO confirm
    against the code that selects this format.
    """

    def __init__(self):
        # A scale factor of 1.0 is the only state this format sets.
        self.scale_factor = 1.0
|
||||
|
||||
class SC_B(LatentFormat):
    """Latent format with a unit ``scale_factor``.

    Presumably used for the Stable Cascade stage B latents -- TODO confirm
    against the code that selects this format.
    """

    def __init__(self):
        # A scale factor of 1.0 is the only state this format sets.
        self.scale_factor = 1.0
|
||||
|
||||
161
comfy/ldm/cascade/common.py
Normal file
161
comfy/ldm/cascade/common.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose parameter initialisation is disabled."""

    def reset_parameters(self):
        # Deliberate no-op: ``torch.nn.Linear.__init__`` calls this hook to
        # initialise weight and bias, so overriding it leaves them as allocated.
        pass
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` whose parameter initialisation is disabled."""

    def reset_parameters(self):
        # Deliberate no-op: the parent constructor calls this hook to
        # initialise weight and bias, so overriding it leaves them as allocated.
        pass
|
||||
|
||||
class OptimizedAttention(nn.Module):
    """Attention layer with separate query/key/value/output projections.

    The four projections are ``operations.Linear(c, c)`` layers with bias; the
    attention itself is delegated to ``optimized_attention``.

    Args:
        c: Feature width of the inputs and of every projection.
        nhead: Number of heads handed to ``optimized_attention``.
        dropout: Accepted for signature compatibility; not used by this class.
        dtype, device: Forwarded to every ``operations.Linear``.
        operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        self.to_q = operations.Linear(c, c, **linear_kwargs)
        self.to_k = operations.Linear(c, c, **linear_kwargs)
        self.to_v = operations.Linear(c, c, **linear_kwargs)

        self.out_proj = operations.Linear(c, c, **linear_kwargs)

    def forward(self, q, k, v):
        """Project ``q``/``k``/``v``, attend, and apply the output projection."""
        queries, keys, values = self.to_q(q), self.to_k(k), self.to_v(v)
        attended = optimized_attention(queries, keys, values, self.heads)
        return self.out_proj(attended)
|
||||
|
||||
class Attention2D(nn.Module):
    """Applies ``OptimizedAttention`` to a 4-D feature map.

    Args:
        c: Channel count of the feature map.
        nhead: Number of attention heads.
        dropout: Accepted for signature compatibility; not forwarded.
        dtype, device, operations: Forwarded to ``OptimizedAttention``.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)

    def forward(self, x, kv, self_attn=False):
        """Attend from the flattened feature map ``x`` to ``kv``.

        When ``self_attn`` is true, the flattened ``x`` tokens are prepended to
        ``kv`` so the map also attends to itself. The result is reshaped back
        to the shape ``x`` had on entry.
        """
        input_shape = x.shape
        # Flatten everything after the channel axis: B x C x H x W -> B x (H*W) x C
        tokens = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        if self_attn:
            kv = torch.cat([tokens, kv], dim=1)
        attended = self.attn(tokens, kv, kv)
        return attended.permute(0, 2, 1).view(*input_shape)
|
||||
|
||||
|
||||
def LayerNorm2d_op(operations):
    """Return a channels-first LayerNorm class built on ``operations.LayerNorm``.

    The returned class accepts the same constructor arguments as
    ``operations.LayerNorm``. Its ``forward`` moves axis 1 of a 4-D input to
    the last position, normalises, and moves it back.
    """

    class LayerNorm2d(operations.LayerNorm):
        def forward(self, x):
            channels_last = x.permute(0, 2, 3, 1)
            normalised = super().forward(channels_last)
            return normalised.permute(0, 3, 1, 2)

    return LayerNorm2d
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
    """Global response normalisation over axes 1 and 2 of a channels-last input.

    from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    ``gamma`` and ``beta`` are zero-initialised parameters of shape
    ``(1, 1, 1, dim)``, so a freshly constructed layer returns its input.
    """

    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(*param_shape, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.zeros(*param_shape, dtype=dtype, device=device))

    def forward(self, x):
        """Return ``gamma * (x * Nx) + beta + x`` with parameters cast to ``x``."""
        # Parameters are cast per call so they follow the input's device/dtype.
        scale = self.gamma.to(device=x.device, dtype=x.dtype)
        shift = self.beta.to(device=x.device, dtype=x.dtype)
        # L2 norm over axes 1 and 2, then divided by its mean over the last axis.
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative_norm = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return scale * (x * relative_norm) + shift + x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: depthwise conv, LayerNorm, then a channelwise MLP.

    Args:
        c: Channel count of the input and output.
        c_skip: Extra channels contributed by the optional ``x_skip`` input.
        kernel_size: Depthwise convolution kernel size (padding is half of it).
        dropout: Dropout probability inside the channelwise MLP.
        dtype, device: Forwarded to every parameterised layer.
        operations: Namespace providing ``Conv2d``, ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        hidden = c * 4

        # groups=c makes the convolution depthwise.
        self.depthwise = operations.Conv2d(
            c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, **factory
        )
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x, x_skip=None):
        """Return ``x`` plus the block's update, concatenating ``x_skip`` if given."""
        residual = x
        hidden = self.norm(self.depthwise(x))
        if x_skip is not None:
            hidden = torch.cat([hidden, x_skip], dim=1)
        # The MLP works on the last axis, so run it channels-last and permute back.
        hidden = self.channelwise(hidden.permute(0, 2, 3, 1))
        return hidden.permute(0, 3, 1, 2) + residual
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Residual attention block over a feature map and a conditioning sequence.

    Args:
        c: Channel count of the feature map.
        c_cond: Feature width of the conditioning input before mapping.
        nhead: Number of attention heads.
        self_attn: Whether the feature map is also attended to.
        dropout: Passed positionally to ``Attention2D``.
        dtype, device, operations: Forwarded to the sub-layers.
    """

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.attention = Attention2D(c, nhead, dropout, operations=operations, **factory)
        # Maps the conditioning features from width c_cond to width c.
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, **factory),
        )

    def forward(self, x, kv):
        """Return ``x`` plus attention from the normalised ``x`` to mapped ``kv``."""
        context = self.kv_mapper(kv)
        attended = self.attention(self.norm(x), context, self_attn=self.self_attn)
        return x + attended
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
    """Residual channelwise MLP applied after a channels-first LayerNorm.

    Args:
        c: Channel count of the input and output.
        dropout: Dropout probability inside the MLP.
        dtype, device: Forwarded to every parameterised layer.
        operations: Namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        hidden = c * 4
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x):
        """Return ``x`` plus the MLP output computed on the normalised input."""
        # The MLP works on the last axis, so run it channels-last and permute back.
        channels_last = self.norm(x).permute(0, 2, 3, 1)
        update = self.channelwise(channels_last).permute(0, 3, 1, 2)
        return x + update
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
    """Scale-and-shift modulation of a feature map from an embedding tensor.

    One ``operations.Linear(c_timestep, 2 * c)`` mapper is created for the
    base embedding and one more per name in ``conds`` (registered as
    ``mapper_<name>``).

    Args:
        c: Channel count of the feature map being modulated.
        c_timestep: Width of each embedding chunk.
        conds: Names of the additional conditioning embeddings.
        dtype, device: Forwarded to every ``operations.Linear``.
        operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(self, c, c_timestep, conds=('sca',), dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        # Keep a private list: the previous mutable default (``['sca']``) was
        # stored by reference, so mutating one instance's ``conds`` changed the
        # default seen by every later instance.
        self.conds = list(conds)
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        """Return ``x * (1 + a) + b``, with ``a``/``b`` summed over all mappers.

        ``t`` is chunked along dim 1 into ``len(conds) + 1`` pieces: the first
        feeds ``mapper`` and piece ``i + 1`` feeds the mapper of ``conds[i]``.
        """
        chunks = t.chunk(len(self.conds) + 1, dim=1)
        # Each mapper output is split in half along channels into scale/shift.
        a, b = self.mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)
        for i, cname in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{cname}")(chunks[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
|
||||
258
comfy/ldm/cascade/stage_a.py
Normal file
258
comfy/ldm/cascade/stage_a.py
Normal file
@ -0,0 +1,258 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
class vector_quantize(Function):
    """Autograd function mapping each input row to its nearest codebook row."""

    @staticmethod
    def forward(ctx, x, codebook):
        """Return ``(nearest_rows, indices)`` for the rows of ``x``.

        ``x`` and ``codebook`` are 2-D with matching second dimension;
        ``indices`` is marked non-differentiable.
        """
        with torch.no_grad():
            # Squared distance expanded as |e|^2 + |x|^2 - 2 * x . e
            squared_norms = torch.sum(codebook ** 2, dim=1) + torch.sum(x ** 2, dim=1, keepdim=True)
            distances = torch.addmm(squared_norms, x, codebook.t(), alpha=-2.0, beta=1.0)
            indices = distances.min(dim=1)[1]

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            nearest = torch.index_select(codebook, 0, indices)
            return nearest, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        """Return gradients for ``x`` and ``codebook`` (``None`` when not needed).

        The input receives a copy of ``grad_output``; the codebook gradient
        accumulates ``grad_output`` rows into the rows that were selected.
        """
        wants_input_grad, wants_codebook_grad = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

        grad_inputs = grad_output.clone() if wants_input_grad else None

        grad_codebook = None
        if wants_codebook_grad:
            indices, codebook = ctx.saved_tensors
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
|
||||
|
||||
|
||||
class VectorQuantize(nn.Module):
    """Vector-quantisation layer backed by an ``nn.Embedding`` codebook."""

    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super().__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1./k, 1./k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # Running statistics used by _updateEMA to rebuild the codebook.
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        """Return ``x`` smoothed by ``epsilon`` and rescaled to its original sum."""
        total = torch.sum(x)
        smoothed = (x + epsilon) / (total + x.size(0) * epsilon)
        return smoothed * total

    def _updateEMA(self, z_e_x, indices):
        """Update the EMA buffers from one batch and overwrite the codebook."""
        one_hot = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        batch_counts = one_hot.sum(dim=0)
        batch_sums = torch.mm(one_hot.t(), z_e_x)

        blended_counts = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * batch_counts)
        self.ema_element_count = self._laplace_smoothing(blended_counts, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * batch_sums)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for ``idx``; move the vector axis to ``dim``."""
        vectors = self.codebook(idx)
        return vectors if dim == -1 else vectors.movedim(-1, dim)

    def forward(self, x, get_losses=True, dim=-1):
        """Quantise ``x`` along axis ``dim``.

        Returns ``quantized, (vq_loss, commit_loss), indices``; both losses are
        ``None`` when ``get_losses`` is false.
        """
        embedding_last = dim == -1
        if not embedding_last:
            x = x.movedim(dim, -1)
        flat = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        quantized, indices = self.vq(flat, self.codebook.weight.detach())

        if self.ema_loss and self.training:
            self._updateEMA(flat.detach(), indices.detach())

        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        graded = torch.index_select(self.codebook.weight, dim=0, index=indices)
        vq_loss = commit_loss = None
        if get_losses:
            vq_loss = (graded - flat.detach()).pow(2).mean()
            commit_loss = (flat - graded.detach()).pow(2).mean()

        quantized = quantized.view(x.shape)
        if not embedding_last:
            quantized = quantized.movedim(-1, dim)
        return quantized, (vq_loss, commit_loss), indices.view(x.shape[:-1])
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block with a depthwise stage and a channelwise MLP stage.

    Each stage normalises with a parameter-free LayerNorm over the channel
    axis, applies a scale/shift taken from ``gammas``, and is added back to
    the input scaled by a further ``gammas`` entry. ``gammas`` starts at zero,
    so a freshly constructed block returns its input unchanged.

    Args:
        c: Channel count of the input and output.
        c_hidden: Hidden width of the channelwise MLP.
    """

    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        # Six learned scalars: (scale, shift, gate) for each of the two stages.
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        """Apply ``norm`` over axis 1 of ``x`` by running it channels-last."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Run the depthwise stage, then the channelwise stage, on ``x``."""
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except RuntimeError:
            # operation not implemented for bf16: pad in float32, cast back,
            # then convolve. Only RuntimeError (which includes
            # NotImplementedError) is caught, so KeyboardInterrupt/SystemExit
            # are no longer swallowed as the previous bare ``except`` did.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
|
||||
|
||||
|
||||
class StageA(nn.Module):
    """Convolutional autoencoder with an optional vector-quantised bottleneck.

    Args:
        levels: Number of resolution levels in the encoder and decoder.
        bottleneck_blocks: Number of ``ResBlock``s at the first decoder level.
        c_hidden: Channel width at the deepest level; each shallower level
            halves it.
        c_latent: Channel count of the latent.
        codebook_size: Number of entries in the quantiser codebook.
        scale_factor: Latents are divided by this on ``encode`` and multiplied
            by it on ``decode``.
    """

    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
                 scale_factor=0.43):  # 0.3764
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for _ in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode ``x`` to a latent.

        Returns the scaled latent tensor when ``quantize`` is false. When it
        is true, returns ``(quantized, unquantized, indices, loss)`` where
        ``loss`` is ``vq_loss + 0.25 * commit_loss``.
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
        else:
            return x / self.scale_factor

    def decode(self, x):
        """Decode a latent produced by ``encode`` back to the input space."""
        x = x * self.scale_factor
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Encode then decode ``x``; return ``(reconstruction, vq_loss)``.

        ``vq_loss`` is ``None`` when ``quantize`` is false.
        """
        # ``encode`` returns a 4-tuple only when quantizing; with
        # quantize=False it returns a single tensor, which must not be
        # unpacked into four names.
        if quantize:
            latent, _, _, vq_loss = self.encode(x, quantize)
        else:
            latent, vq_loss = self.encode(x, quantize), None
        return self.decode(latent), vq_loss
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Strided-convolution discriminator with an optional conditioning vector.

    Args:
        c_in: Channel count of the input.
        c_cond: Width of the conditioning vector (0 disables the extra head
            input channels).
        c_hidden: Maximum channel width of the encoder.
        depth: Number of stride-2 convolutions in the encoder.
    """

    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)

        def width(level):
            # Channel width halves per level above zero; capped at c_hidden.
            return c_hidden // (2 ** max(level, 0))

        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            ch_in, ch_out = width(d - i), width(d - 1 - i)
            layers += [
                nn.utils.spectral_norm(nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1)),
                nn.InstanceNorm2d(ch_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*layers)

        head_in = c_hidden + c_cond if c_cond > 0 else c_hidden
        self.shuffle = nn.Conv2d(head_in, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Return sigmoid scores for ``x``, optionally conditioned on ``cond``."""
        features = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector over the spatial grid and
            # append it as extra channels.
            tiled = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, features.size(-2), features.size(-1))
            features = torch.cat([features, tiled], dim=1)
        return self.logits(self.shuffle(features))
|
||||
257
comfy/ldm/cascade/stage_b.py
Normal file
257
comfy/ldm/cascade/stage_b.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
|
||||
class StageB(nn.Module):
|
||||
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
|
||||
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
|
||||
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
|
||||
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
|
||||
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.c_r = c_r
|
||||
self.t_conds = t_conds
|
||||
self.c_clip_seq = c_clip_seq
|
||||
if not isinstance(dropout, list):
|
||||
dropout = [dropout] * len(c_hidden)
|
||||
if not isinstance(self_attn, list):
|
||||
self_attn = [self_attn] * len(c_hidden)
|
||||
|
||||
# CONDITIONING
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == 'C':
|
||||
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'A':
|
||||
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'F':
|
||||
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'T':
|
||||
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
raise Exception(f'Block type {block_type} not supported')
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(c_hidden)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(blocks[0][i]):
|
||||
for block_type in level_config[i]:
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[0][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(c_hidden))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(blocks[1][::-1][i]):
|
||||
for k, block_type in enumerate(level_config[i]):
|
||||
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
|
||||
self_attn=self_attn[i])
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[1][::-1][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
# --- WEIGHT INIT ---
|
||||
# self.apply(self._init_weights) # General init
|
||||
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
|
||||
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
|
||||
# elif isinstance(block, TimestepBlock):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
# if m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
    """Build a sinusoidal embedding of width ``self.c_r`` for the 1-D tensor ``r``.

    ``r`` is scaled by ``max_positions`` and multiplied with ``self.c_r // 2``
    geometrically spaced frequencies; the sines and cosines of the result are
    concatenated along dim 1. When ``self.c_r`` is odd, one zero column is
    appended so the output width equals ``self.c_r``.
    """
    scaled = r * max_positions
    n_freqs = self.c_r // 2
    log_step = math.log(max_positions) / (n_freqs - 1)
    freqs = torch.exp(torch.arange(n_freqs, device=scaled.device).float() * -log_step)
    angles = scaled[:, None] * freqs[None, :]
    embedding = torch.cat([angles.sin(), angles.cos()], dim=1)
    if self.c_r % 2 == 1:
        # Odd width: sin/cos only cover c_r - 1 columns, so zero-pad the last one.
        embedding = nn.functional.pad(embedding, (0, 1), mode='constant')
    return embedding
|
||||
|
||||
def gen_c_embeddings(self, clip):
    """Project ``clip`` with ``self.clip_mapper`` and normalise it with ``self.clip_norm``.

    A 2-D input gets a length-1 dim inserted at position 1 first. Every entry
    along dim 1 is expanded into ``self.c_clip_seq`` entries by reshaping the
    mapper output to ``(clip.size(0), clip.size(1) * c_clip_seq, -1)``.
    """
    if clip.dim() == 2:
        clip = clip.unsqueeze(1)
    batch, tokens = clip.size(0), clip.size(1)
    mapped = self.clip_mapper(clip)
    mapped = mapped.view(batch, tokens * self.c_clip_seq, -1)
    return self.clip_norm(mapped)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
x = block(x)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
    """Run the model on ``x`` with timestep ``r`` and the given conditionings.

    ``effnet`` and ``pixels`` are mapped to the embedding width and added to the
    embedded input after bilinear resizing to its spatial size; ``pixels``
    defaults to an all-zero ``(batch, 3, 8, 8)`` tensor. For every name in
    ``self.t_conds`` the matching keyword argument (zeros shaped like ``r`` when
    absent) is embedded like ``r`` and concatenated onto the timestep embedding.
    """
    if pixels is None:
        pixels = x.new_zeros(x.size(0), 3, 8, 8)

    # Process the conditioning embeddings
    embeds = [self.gen_r_embedding(r).to(dtype=x.dtype)]
    for cond_name in self.t_conds:
        cond_value = kwargs[cond_name] if cond_name in kwargs else torch.zeros_like(r)
        embeds.append(self.gen_r_embedding(cond_value).to(dtype=x.dtype))
    r_embed = torch.cat(embeds, dim=1)
    clip = self.gen_c_embeddings(clip)

    # Model Blocks
    x = self.embedding(x)
    spatial = x.shape[-2:]
    effnet_resized = nn.functional.interpolate(effnet, size=spatial, mode='bilinear', align_corners=True)
    x = x + self.effnet_mapper(effnet_resized)
    pixels_mapped = self.pixels_mapper(pixels)
    x = x + nn.functional.interpolate(pixels_mapped, size=spatial, mode='bilinear', align_corners=True)

    level_outputs = self._down_encode(x, r_embed, clip)
    decoded = self._up_decode(level_outputs, r_embed, clip)
    return self.clf(decoded)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
    """Blend ``src_model``'s parameters and buffers into this model in place.

    Every parameter, then every buffer, becomes
    ``own * beta + source * (1 - beta)``; the source value is cloned and moved
    to the device of the tensor being updated. Tensors are paired positionally.
    """
    for accessor in ("parameters", "buffers"):
        own_tensors = getattr(self, accessor)()
        src_tensors = getattr(src_model, accessor)()
        for own, src in zip(own_tensors, src_tensors):
            incoming = src.data.clone().to(own.device)
            own.data = own.data * beta + incoming * (1 - beta)
|
||||
271
comfy/ldm/cascade/stage_c.py
Normal file
271
comfy/ldm/cascade/stage_c.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import math
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
# from .controlnet import ControlNetDeliverer
|
||||
|
||||
class UpDownBlock2d(nn.Module):
    """1x1 channel-mapping convolution combined with an optional 2x bilinear resample.

    In ``'up'`` mode the resample runs before the convolution, in ``'down'``
    mode after it. With ``enabled=False`` the resample is replaced by an
    identity, so only the channel mapping remains.
    """

    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        upsampling = mode == 'up'
        if enabled:
            resample = nn.Upsample(scale_factor=2 if upsampling else 0.5, mode='bilinear',
                                   align_corners=True)
        else:
            resample = nn.Identity()
        projection = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        # The position inside the list decides the parameter names (blocks.0 / blocks.1).
        ordered = [resample, projection] if upsampling else [projection, resample]
        self.blocks = nn.ModuleList(ordered)

    def forward(self, x):
        """Apply the two stages in their stored order."""
        out = x
        for stage in self.blocks:
            out = stage(out)
        return out
|
||||
|
||||
|
||||
class StageC(nn.Module):
    """Stage C network of Stable Cascade.

    The input is embedded with a pixel-unshuffle and a 1x1 convolution, run
    through a stack of "down" levels and a mirrored stack of "up" levels, and
    projected back with a 1x1 convolution and a pixel-shuffle. The blocks of
    level ``i`` are chosen by the characters of ``level_config[i]``:
    ``'C'`` ResBlock, ``'A'`` AttnBlock, ``'F'`` FeedForwardBlock and
    ``'T'`` TimestepBlock.
    """

    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        """Build all sub-modules.

        ``c_hidden``, ``nhead``, ``dropout`` and ``self_attn`` are indexed per
        level (scalar ``dropout`` / ``self_attn`` values are broadcast to every
        level). ``blocks[0]`` / ``blocks[1]`` give the number of ``level_config``
        repetitions per down / up level, and ``block_repeat`` the number of
        passes over each level. ``operations`` supplies the ``Linear``,
        ``Conv2d`` and ``LayerNorm`` layer classes. The list defaults are only
        read, never mutated. ``stable_cascade_stage`` is accepted but unused.
        """
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        def make_repeat_mappers(repeats, c):
            # One 1x1 conv between each pair of consecutive passes over a level. An (empty)
            # list is returned even without `block_repeat`, because the encode/decode loops
            # zip over the mapper lists and would otherwise see zero levels.
            mappers = nn.ModuleList()
            if repeats is not None:
                for _ in range(repeats - 1):
                    mappers.append(operations.Conv2d(c, c, kernel_size=1, dtype=dtype, device=device))
            return mappers

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            repeats = block_repeat[0][i] if block_repeat is not None else None
            self.down_repeat_mappers.append(make_repeat_mappers(repeats, c_hidden[i]))

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the very first block of every level except the deepest takes a skip input.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            repeats = block_repeat[1][::-1][i] if block_repeat is not None else None
            self.up_repeat_mappers.append(make_repeat_mappers(repeats, c_hidden[i]))

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )
        # NOTE: no explicit weight initialisation is performed here.

    @staticmethod
    def _is_block(block, kind):
        """Return True if ``block`` or the module it exposes as ``_fsdp_wrapped_module`` is a ``kind``."""
        wrapped = getattr(block, '_fsdp_wrapped_module', None)
        return isinstance(block, kind) or isinstance(wrapped, kind)

    def gen_r_embedding(self, r, max_positions=10000):
        """Build a sinusoidal embedding of width ``self.c_r`` for the 1-D tensor ``r``.

        The sines and cosines of ``r * max_positions`` times ``self.c_r // 2``
        geometrically spaced frequencies are concatenated along dim 1; an odd
        ``self.c_r`` gets one trailing zero column.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Map the three CLIP inputs to a common width, concatenate along dim 1 and normalise.

        2-D ``clip_txt_pooled`` / ``clip_img`` inputs get a length-1 dim inserted
        at position 1; each of their entries along dim 1 is expanded into
        ``self.c_clip_seq`` entries.
        """
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Run the down levels and return one output per level, last level first.

        ``cnet``, when given, is called with no arguments before every ResBlock;
        a non-None result is bilinearly resized to ``x`` and added to it.
        """
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if self._is_block(block, ResBlock):
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True)
                        x = block(x)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Run the up levels starting from ``level_outputs[0]`` and return the final tensor.

        For every level after the first, the first block (when it is a ResBlock)
        receives ``level_outputs[i]`` as skip input, with ``x`` resized to the
        skip's spatial size if needed. ``cnet`` is consumed as in ``_down_encode``.
        """
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if self._is_block(block, ResBlock):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True)
                        x = block(x, skip)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
        """Run the model on ``x`` with timestep ``r`` and the CLIP conditionings.

        For every name in ``self.t_conds`` the matching keyword argument (zeros
        shaped like ``r`` when absent) is embedded like ``r`` and concatenated
        onto the timestep embedding.

        Raises:
            NotImplementedError: if ``cnet`` is not None. The wrapper it needs
                (``ControlNetDeliverer``) is not importable in this module, so
                the call previously failed with a NameError.
        """
        if cnet is not None:
            raise NotImplementedError(
                "StageC.forward: ControlNet input (cnet) is not supported because "
                "ControlNetDeliverer is not available in this module")

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """Blend ``src_model``'s parameters and buffers into this model: ``own * beta + src * (1 - beta)``."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
||||
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
    """Encode images with the EfficientNetV2-S feature extractor into ``c_latent`` channels."""

    def __init__(self, c_latent=16):
        super().__init__()
        features = torchvision.models.efficientnet_v2_s().features
        self.backbone = features.eval()
        projection = nn.Conv2d(1280, c_latent, kernel_size=1, bias=False)
        # Non-affine batch norm: normalise the projected features to mean 0 and std 1.
        normalise = nn.BatchNorm2d(c_latent, affine=False)
        self.mapper = nn.Sequential(projection, normalise)
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        """Rescale ``x`` via ``x * 0.5 + 0.5``, standardise per channel and encode it."""
        unit_range = x * 0.5 + 0.5
        channel_mean = self.mean.view([3, 1, 1])
        channel_std = self.std.view([3, 1, 1])
        standardised = (unit_range - channel_mean) / channel_std
        return self.mapper(self.backbone(standardised))
|
||||
|
||||
|
||||
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
|
||||
class Previewer(nn.Module):
    """Small convolutional decoder that upsamples its input 8x (three stride-2 transposed convs)."""

    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        half = c_hidden // 2
        quarter = c_hidden // 4

        def stage(conv):
            # Every convolution except the last is followed by GELU and a batch norm
            # over its output channels.
            return [conv, nn.GELU(), nn.BatchNorm2d(conv.out_channels)]

        layers = []
        layers += stage(nn.Conv2d(c_in, c_hidden, kernel_size=1))  # 16 channels to 512 channels
        layers += stage(nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1))
        layers += stage(nn.ConvTranspose2d(c_hidden, half, kernel_size=2, stride=2))  # 16 -> 32
        layers += stage(nn.Conv2d(half, half, kernel_size=3, padding=1))
        layers += stage(nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2))  # 32 -> 64
        layers += stage(nn.Conv2d(quarter, quarter, kernel_size=3, padding=1))
        layers += stage(nn.ConvTranspose2d(quarter, quarter, kernel_size=2, stride=2))  # 64 -> 128
        layers += stage(nn.Conv2d(quarter, quarter, kernel_size=3, padding=1))
        layers.append(nn.Conv2d(quarter, c_out, kernel_size=1))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` and apply the affine map ``(y - 0.5) * 2.0`` to the result."""
        decoded = self.blocks(x)
        return (decoded - 0.5) * 2.0
|
||||
|
||||
class StageC_coder(nn.Module):
    """Pairs an ``EfficientNetEncoder`` (encode) with a ``Previewer`` (decode) in one module."""

    def __init__(self):
        super().__init__()
        self.previewer = Previewer()
        self.encoder = EfficientNetEncoder()

    def encode(self, x):
        """Return ``self.encoder`` applied to ``x``."""
        encoded = self.encoder(x)
        return encoded

    def decode(self, x):
        """Return ``self.previewer`` applied to ``x``."""
        decoded = self.previewer(x)
        return decoded
|
||||
@ -113,7 +113,12 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
else:
|
||||
sim += mask
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
sim.add_(mask)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
@ -164,6 +169,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
if query_chunk_size is None:
|
||||
query_chunk_size = 512
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
@ -222,6 +234,13 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
|
||||
@ -5,6 +5,8 @@ from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmenta
|
||||
from . import model_management
|
||||
from . import conds
|
||||
from . import ops
|
||||
from .ldm.cascade.stage_c import StageC
|
||||
from .ldm.cascade.stage_b import StageB
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
|
||||
@ -12,9 +14,10 @@ class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
|
||||
|
||||
from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -27,6 +30,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -35,7 +41,7 @@ def model_sampling(model_config, model_type):
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
@ -48,7 +54,7 @@ class BaseModel(torch.nn.Module):
|
||||
operations = ops.manual_cast
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -426,3 +432,56 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['c_concat'] = conds.CONDNoiseShape(image)
|
||||
out['y'] = conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class StableCascade_C(BaseModel):
    """``BaseModel`` whose diffusion model is a ``StageC`` network, frozen in eval mode."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Assemble the extra conditioning dict passed to the Stage C model.

        Requires ``pooled_output`` in ``kwargs``. The image embedding comes from
        ``unclip_conditioning`` when present (each entry's image embeds scaled by
        its strength, concatenated along dim 1) and is otherwise all zeros.
        ``sca`` and ``crp`` are always zero.
        """
        out = {}
        pooled = kwargs["pooled_output"]
        if pooled is not None:
            out['clip_text_pooled'] = conds.CONDRegular(pooled)

        if "unclip_conditioning" in kwargs:
            weighted_embeds = [
                entry["clip_vision_output"].image_embeds.unsqueeze(0) * entry["strength"]
                for entry in kwargs["unclip_conditioning"]
            ]
            clip_img = torch.cat(weighted_embeds, dim=1)
        else:
            clip_img = torch.zeros((1, 1, 768))
        out["clip_img"] = conds.CONDRegular(clip_img)
        out["sca"] = conds.CONDRegular(torch.zeros((1,)))
        out["crp"] = conds.CONDRegular(torch.zeros((1,)))

        text_tokens = kwargs.get("cross_attn", None)
        if text_tokens is not None:
            out['clip_text'] = conds.CONDCrossAttn(text_tokens)
        return out
|
||||
|
||||
|
||||
class StableCascade_B(BaseModel):
    """``BaseModel`` whose diffusion model is a ``StageB`` network, frozen in eval mode."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageB)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Assemble the extra conditioning dict passed to the Stage B model.

        Requires ``pooled_output`` in ``kwargs``. ``effnet`` is taken from
        ``stable_cascade_prior`` when that key is supplied; otherwise a zero
        tensor sized from ``noise`` is used, so ``noise`` is only needed in
        that fallback case. ``sca`` is always zero.
        """
        out = {}
        noise = kwargs.get("noise", None)

        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled)

        if "stable_cascade_prior" in kwargs:
            prior = kwargs["stable_cascade_prior"]
        else:
            # Build the placeholder only when no prior was given: this avoids a needless
            # allocation per call and a failure on a missing `noise` when a prior exists.
            # size of prior doesn't really matter if zeros because it gets resized but it should still get batched
            prior = torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42),
                                dtype=noise.dtype, layout=noise.layout, device=noise.device)

        out["effnet"] = conds.CONDRegular(prior)
        out["sca"] = conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip'] = conds.CONDCrossAttn(cross_attn)
        return out
|
||||
|
||||
@ -28,9 +28,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
unet_config = {}
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: #stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: #stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
|
||||
if w.shape[-1] == 640:
|
||||
unet_config['c_hidden'] = [320, 640, 1280, 1280]
|
||||
unet_config['nhead'] = [-1, -1, 20, 20]
|
||||
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
||||
elif w.shape[-1] == 576: #stage b lite
|
||||
unet_config['c_hidden'] = [320, 576, 1152, 1152]
|
||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
|
||||
return unet_config
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@ -45,7 +74,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -159,8 +187,8 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return supported_models_base.BASE(unet_config)
|
||||
@ -206,7 +234,7 @@ def convert_config(unet_config):
|
||||
return new_config
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
@ -313,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
def model_config_from_diffusers_unet(state_dict):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
@ -496,7 +496,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@ -506,11 +506,15 @@ def unet_dtype(device=None, model_params=0):
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
return torch.float16
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device):
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
@ -518,8 +522,15 @@ def unet_manual_cast(weight_dtype, inference_device):
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
||||
if fp16_supported:
|
||||
bf16_supported = should_use_bf16(inference_device)
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
@ -694,17 +705,20 @@ def mps_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def is_device_cpu(device):
|
||||
def is_device_type(device, type):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'cpu'):
|
||||
if (device.type == type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_device_cpu(device):
|
||||
return is_device_type(device, 'cpu')
|
||||
|
||||
def is_device_mps(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'mps'):
|
||||
return True
|
||||
return False
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
global directml_enabled
|
||||
@ -716,9 +730,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if FORCE_FP16:
|
||||
return True
|
||||
|
||||
if device is not None: #TODO
|
||||
if device is not None:
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
@ -726,8 +740,11 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
return False #TODO ?
|
||||
if mps_mode():
|
||||
return True
|
||||
|
||||
if cpu_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
@ -767,6 +784,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return True
|
||||
|
||||
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Return True when bf16 should be used, False otherwise.

    The answer is False for a CPU or MPS ``device``, when ``FORCE_FP32`` or
    ``directml_enabled`` is set, or when running in CPU/MPS mode. It is True
    on Intel XPU and on CUDA devices whose ``major`` property is >= 8.
    Otherwise bf16 is chosen only if it is reported as supported (or
    ``manual_cast`` is set) and either performance is not prioritized or
    ``model_params * 4`` exceeds the estimated free memory for the model.
    """
    if device is not None:
        #TODO ? bf16 works on CPU but is extremely slow
        #TODO not sure about mps bf16 support
        if is_device_cpu(device) or is_device_mps(device):
            return False

    if FORCE_FP32 or directml_enabled:
        return False

    if cpu_mode() or mps_mode():
        return False

    if is_intel_xpu():
        return True

    # With no explicit device, query the default CUDA device.
    target = torch.device("cuda") if device is None else device
    if torch.cuda.get_device_properties(target).major >= 8:
        return True

    if torch.cuda.is_bf16_supported() or manual_cast:
        budget = (get_free_memory() * 0.9 - minimum_inference_memory())
        if (not prioritize_performance) or model_params * 4 > budget:
            return True

    return False
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
with model_management_lock:
|
||||
global cpu_state
|
||||
|
||||
@ -132,3 +132,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
    """Sigma/timestep conversions based on a squared-cosine alpha_cumprod curve."""

    def __init__(self, model_config=None):
        super().__init__()

        # Without a model config, every sampling setting takes its default.
        sampling_settings = model_config.sampling_settings if model_config is not None else {}
        self.set_parameters(sampling_settings.get("shift", 1.0))

    def set_parameters(self, shift=1.0, cosine_s=8e-3):
        """Store ``shift`` and ``cosine_s`` and rebuild the sigma table."""
        self.shift = shift
        self.cosine_s = torch.tensor(cosine_s)
        # Value of the squared-cosine curve at timestep 0; sigma() divides by it.
        self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2

        #This part is just for compatibility with some schedulers in the codebase
        self.num_timesteps = 1000
        table = torch.empty((self.num_timesteps), dtype=torch.float32)
        for step in range(self.num_timesteps):
            table[step] = self.sigma(step / self.num_timesteps)

        self.set_sigmas(table)

    def sigma(self, timestep):
        """Return the sigma that corresponds to ``timestep``."""
        phase = (timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5
        alpha_cumprod = (torch.cos(phase) ** 2 / self._init_alpha_cumprod)

        if self.shift != 1.0:
            # Offset the log-SNR by 2 * log(1 / shift), then map back via sigmoid.
            log_snr = (alpha_cumprod / (1 - alpha_cumprod)).log()
            log_snr += 2 * torch.log(1.0 / torch.tensor(self.shift))
            alpha_cumprod = log_snr.sigmoid()

        alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
        return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

    def timestep(self, sigma):
        """Return the timestep that corresponds to the tensor ``sigma``."""
        alpha = (1 / ((sigma * sigma) + 1)).clamp(0, 1.0)
        offset = self.cosine_s.to(alpha.device)
        base = self._init_alpha_cumprod.to(alpha.device)
        return (((alpha * base) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + offset) - offset

    def percent_to_sigma(self, percent):
        """Map ``percent`` to a sigma; <= 0 gives 999999999.9 and >= 1 gives 0.0."""
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0

        return self.sigma(torch.tensor(1.0 - percent))
|
||||
|
||||
@ -292,18 +292,7 @@ class VAEEncode:
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
@staticmethod
|
||||
def vae_encode_crop_pixels(pixels):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
||||
return pixels
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
pixels = self.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
return ({"samples":t}, )
|
||||
|
||||
@ -319,7 +308,6 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels, tile_size):
|
||||
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
|
||||
return ({"samples":t}, )
|
||||
|
||||
@ -333,14 +321,14 @@ class VAEEncodeForInpaint:
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask, grow_mask_by=6):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
||||
pixels = pixels.clone()
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
|
||||
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
|
||||
|
||||
@ -837,15 +825,20 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"),),
|
||||
"type": (["stable_diffusion", "stable_cascade"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
if type == "stable_cascade":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
class DualCLIPLoader:
|
||||
@ -1418,7 +1411,7 @@ class SaveImage:
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
for image in images:
|
||||
for (batch_number, image) in enumerate(images):
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
metadata = None
|
||||
@ -1430,7 +1423,8 @@ class SaveImage:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
abs_path = os.path.join(full_output_folder, file)
|
||||
img.save(abs_path, pnginfo=metadata, compress_level=self.compress_level)
|
||||
results.append({
|
||||
|
||||
49
comfy/ops.py
49
comfy/ops.py
@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from . import model_management
|
||||
|
||||
@ -78,7 +96,11 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -87,6 +109,28 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
    """``torch.nn.ConvTranspose2d`` with a no-op ``reset_parameters``.

    When ``comfy_cast_weights`` is true, ``forward`` goes through
    ``forward_comfy_cast_weights``, which obtains weight and bias from
    ``cast_bias_weight`` before running the transposed convolution.
    """

    comfy_cast_weights = False

    def reset_parameters(self):
        # Intentionally does nothing.
        return None

    def forward_comfy_cast_weights(self, input, output_size=None):
        # 2 is the number of spatial dimensions of this layer.
        extra_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,
            2, self.dilation)

        cast_weight, cast_bias = cast_bias_weight(self, input)
        return torch.nn.functional.conv_transpose2d(
            input, cast_weight, cast_bias, self.stride, self.padding,
            extra_padding, self.groups, self.dilation)

    def forward(self, *args, **kwargs):
        if not self.comfy_cast_weights:
            return super().forward(*args, **kwargs)
        return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class LayerNorm(disable_weight_init.LayerNorm):
    # Same layer as disable_weight_init.LayerNorm, with the cast flag switched on.
    comfy_cast_weights = True
|
||||
|
||||
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
    """``disable_weight_init.ConvTranspose2d`` with the cast flag switched on."""

    comfy_cast_weights = True
|
||||
|
||||
123
comfy/sd.py
123
comfy/sd.py
@ -1,7 +1,11 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
|
||||
from . import model_management
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
|
||||
import yaml
|
||||
|
||||
from . import utils
|
||||
@ -134,8 +138,11 @@ class CLIP:
|
||||
tokens = self.tokenize(text)
|
||||
return self.encode_from_tokens(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
@ -155,7 +162,10 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -168,6 +178,34 @@ class VAE:
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = taesd.TAESD()
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
#TODO
|
||||
#self.memory_used_encode
|
||||
#self.memory_used_decode
|
||||
self.process_input = lambda image: image
|
||||
self.process_output = lambda image: image
|
||||
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["encoder.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["previewer.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@ -175,6 +213,7 @@ class VAE:
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
@ -200,18 +239,27 @@ class VAE:
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
    """Center-crop dims 1 and 2 of ``pixels`` to multiples of ``self.downscale_ratio``.

    ``pixels`` is indexed with four axes; dims 1 and 2 are presumably height
    and width (channels-last) -- TODO confirm against callers. The input
    object itself is returned when both dims are already exact multiples.
    """
    ratio = self.downscale_ratio
    size_1, size_2 = pixels.shape[1], pixels.shape[2]
    keep_1 = (size_1 // ratio) * ratio
    keep_2 = (size_2 // ratio) * ratio

    if size_1 == keep_1 and size_2 == keep_2:
        return pixels

    # Split the leftover evenly; any odd remainder pixel is trimmed from the far edge.
    start_1 = (size_1 % ratio) // 2
    start_2 = (size_2 % ratio) // 2
    return pixels[:, start_1:keep_1 + start_1, start_2:keep_2 + start_2, :]
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -220,7 +268,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -235,10 +283,10 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
@ -252,6 +300,7 @@ class VAE:
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
@ -261,7 +310,7 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
@ -271,6 +320,7 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
@ -297,8 +347,11 @@ def load_style_model(ckpt_path):
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
class CLIPType(Enum):
    """Text-encoder family selector accepted by ``load_clip``."""

    # Default value of load_clip's clip_type parameter.
    STABLE_DIFFUSION = 1
    # Selects the Stable Cascade clip model and tokenizer in load_clip.
    STABLE_CASCADE = 2
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(utils.load_torch_file(p, safe_load=True))
|
||||
@ -314,8 +367,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
@ -438,15 +495,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clip_target = None
|
||||
|
||||
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
@ -467,14 +521,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
clip_target = model_config.clip_target()
|
||||
if clip_target is not None:
|
||||
sd = model_config.process_clip_state_dict(sd)
|
||||
if any(k.startswith('cond_stage_model.') for k in sd):
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_model_weights(w, sd)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
print("clip missing:", m)
|
||||
|
||||
if len(u) > 0:
|
||||
print("clip unexpected:", u)
|
||||
else:
|
||||
print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
@ -495,16 +552,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
@ -516,8 +572,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
print(diffusers_keys[k], k)
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
|
||||
@ -68,7 +68,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@ -91,7 +91,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.special_tokens = special_tokens
|
||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.enable_attention_masks = False
|
||||
self.enable_attention_masks = enable_attention_masks
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
if layer == "hidden":
|
||||
|
||||
@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
    """Text-encoder wrapper that registers ``SDXLClipG`` under the name "g"."""

    def __init__(self, device="cpu", dtype=None):
        # Everything is delegated to SD1ClipModel; only the clip name and
        # the concrete encoder class are fixed here.
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="g",
            clip_model=SDXLClipG,
        )
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` preset used for the Stable Cascade text encoder.

    Forwards ``tokenizer_path`` and ``embedding_directory`` unchanged and
    pins ``pad_with_end=True``, ``embedding_size=1280`` and
    ``embedding_key='clip_g'``.
    """

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        super().__init__(
            tokenizer_path,
            pad_with_end=True,
            embedding_directory=embedding_directory,
            embedding_size=1280,
            embedding_key='clip_g',
        )
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` configured with clip name "g" and the Stable Cascade
    CLIP-G tokenizer class."""

    def __init__(self, embedding_directory=None):
        super().__init__(
            embedding_directory=embedding_directory,
            clip_name="g",
            tokenizer=StableCascadeClipGTokenizer,
        )
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
    """``SDClipModel`` preset for Stable Cascade.

    Loads its text-model configuration from ``clip_config_bigg.json`` located
    next to this module, disables ``layer_norm_hidden_state`` and enables
    attention masks.
    """

    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
        # NOTE: ``max_length`` is accepted for signature compatibility but is
        # not forwarded to the base class.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        textmodel_json_config = os.path.join(module_dir, "clip_config_bigg.json")
        token_ids = {"start": 49406, "end": 49407, "pad": 49407}
        super().__init__(
            device=device,
            freeze=freeze,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=textmodel_json_config,
            dtype=dtype,
            special_tokens=token_ids,
            layer_norm_hidden_state=False,
            enable_attention_masks=True,
        )

    def load_sd(self, sd):
        """Load the state dict ``sd`` by deferring to the base implementation."""
        return super().load_sd(sd)
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper that registers ``StableCascadeClipG`` under
    the clip name "g"."""

    def __init__(self, device="cpu", dtype=None):
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="g",
            clip_model=StableCascadeClipG,
        )
|
||||
|
||||
@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
replace_prefix["cond_stage_model."] = "clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
|
||||
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
|
||||
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
@ -306,5 +305,66 @@ class SD_X4Upscaler(SD20):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
    """Model configuration for Stable Cascade stage C."""

    # Selected when the detected unet config reports stage 'c'.
    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    # float16 is intentionally absent from this list for stage C.
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 2.0,
    }

    # Checkpoint key prefixes for the sub-models of this architecture.
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoder."]
    clip_vision_prefix = "clip_l_vision."

    def process_unet_state_dict(self, state_dict):
        """Split fused ``in_proj_weight`` / ``in_proj_bias`` entries in place.

        Every key ending in ``in_proj_weight`` or ``in_proj_bias`` is removed
        and replaced by three entries ``<prefix>.to_q.<kind>``,
        ``<prefix>.to_k.<kind>`` and ``<prefix>.to_v.<kind>``, each holding one
        consecutive slice of ``shape[0] // 3`` rows of the fused tensor.
        Returns the same (mutated) ``state_dict``.
        """
        # Snapshot the keys first: the dict is modified while we walk them.
        original_keys = tuple(state_dict)
        projections = ("to_q", "to_k", "to_v")
        for kind in ("weight", "bias"):
            suffix = "in_proj_" + kind
            for fused_key in original_keys:
                if not fused_key.endswith(suffix):
                    continue
                fused = state_dict.pop(fused_key)
                # Drop the suffix together with the separator preceding it.
                base = fused_key[:-(len(suffix) + 1)]
                rows = fused.shape[0] // 3
                for idx, name in enumerate(projections):
                    target = ".".join((base, name, kind))
                    state_dict[target] = fused[rows * idx:rows * (idx + 1)]
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        """Build the stage C model; ``state_dict`` and ``prefix`` are unused here."""
        return model_base.StableCascade_C(self, device=device)

    def clip_target(self):
        """Return the tokenizer / text-encoder pair for this model."""
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||
|
||||
class Stable_Cascade_B(Stable_Cascade_C):
    """Model configuration for Stable Cascade stage B.

    Inherits the state-dict processing and CLIP target from
    ``Stable_Cascade_C`` and overrides the settings below.
    """

    # Selected when the detected unet config reports stage 'b'.
    unet_config = {
        "stable_cascade_stage": 'b',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_B
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 1.0,
    }

    # Overrides the parent's clip-vision prefix with None.
    clip_vision_prefix = None

    def get_model(self, state_dict, prefix="", device=None):
        """Build the stage B model; ``state_dict`` and ``prefix`` are unused here."""
        return model_base.StableCascade_B(self, device=device)
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -22,13 +22,15 @@ class BASE:
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
vae_key_prefix = ["first_stage_model."]
|
||||
text_encoder_key_prefix = ["cond_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
for k in s.unet_config:
|
||||
if s.unet_config[k] != unet_config[k]:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -54,6 +56,7 @@ class BASE:
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
@ -63,7 +66,7 @@ class BASE:
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
replace_prefix = {"": self.text_encoder_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||
@ -77,8 +80,9 @@ class BASE:
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_manual_cast(self, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
@ -171,6 +171,8 @@ UNET_MAP_BASIC = {
|
||||
}
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
if "num_res_blocks" not in unet_config:
|
||||
return {}
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
|
||||
@ -24,7 +24,7 @@ export function getPngMetadata(file) {
|
||||
const length = dataView.getUint32(offset);
|
||||
// Get the chunk type
|
||||
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
|
||||
if (type === "tEXt" || type == "comf") {
|
||||
if (type === "tEXt" || type == "comf" || type === "iTXt") {
|
||||
// Get the keyword
|
||||
let keyword_end = offset + 8;
|
||||
while (pngData[keyword_end] !== 0) {
|
||||
@ -33,7 +33,7 @@ export function getPngMetadata(file) {
|
||||
const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end));
|
||||
// Get the text
|
||||
const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length);
|
||||
const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('')
|
||||
const contentJson = new TextDecoder("utf-8").decode(contentArraySegment);
|
||||
txt_chunks[keyword] = contentJson;
|
||||
}
|
||||
|
||||
|
||||
@ -98,6 +98,32 @@ class ModelSamplingDiscrete:
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingStableCascade:
    """Node that patches a model's ``model_sampling`` object with a
    Stable Cascade sampler configured with a user-chosen ``shift``."""

    @classmethod
    def INPUT_TYPES(s):
        shift_options = {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}
        required = {
            "model": ("MODEL",),
            "shift": ("FLOAT", shift_options),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, shift):
        """Return a one-tuple holding a clone of ``model`` with patched sampling.

        The input ``model`` itself is not patched; the clone receives a new
        ``model_sampling`` object whose parameters are set from ``shift``.
        """
        patched = model.clone()

        base_cls = comfy.model_sampling.StableCascadeSampling
        type_cls = comfy.model_sampling.EPS

        # Combine the schedule base with the prediction type in one class.
        class ModelSamplingAdvanced(base_cls, type_cls):
            pass

        sampler = ModelSamplingAdvanced(model.model.model_config)
        sampler.set_parameters(shift)
        patched.add_object_patch("model_sampling", sampler)
        return (patched, )
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -170,5 +196,6 @@ class RescaleCFG:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
"ModelSamplingStableCascade": ModelSamplingStableCascade,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
|
||||
109
comfy_extras/nodes_stable_cascade.py
Normal file
109
comfy_extras/nodes_stable_cascade.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage:
    """Node producing a pair of all-zero latents for stage C and stage B."""

    def __init__(self, device="cpu"):
        # Stored only; generate() does not read it.
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        resolution = {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}
        required = {
            "width": ("INT", dict(resolution)),
            "height": ("INT", dict(resolution)),
            "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("stage_c", "stage_b")
    FUNCTION = "generate"

    CATEGORY = "_for_testing/stable_cascade"

    def generate(self, width, height, compression, batch_size=1):
        """Return ``(stage_c, stage_b)`` latent dicts filled with zeros.

        The stage C tensor has 16 channels and spatial size
        ``height // compression`` by ``width // compression``; the stage B
        tensor has 4 channels and spatial size ``height // 4`` by ``width // 4``.
        """
        stage_c_shape = [batch_size, 16, height // compression, width // compression]
        stage_b_shape = [batch_size, 4, height // 4, width // 4]
        stage_c = {"samples": torch.zeros(stage_c_shape)}
        stage_b = {"samples": torch.zeros(stage_b_shape)}
        return (stage_c, stage_b)
|
||||
|
||||
class StableCascade_StageC_VAEEncode:
    """Node that encodes an image into a stage C latent and pairs it with an
    all-zero stage B latent."""

    def __init__(self, device="cpu"):
        # Stored only; generate() does not read it.
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "image": ("IMAGE",),
            "vae": ("VAE", ),
            "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
        }
        return {"required": required}

    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("stage_c", "stage_b")
    FUNCTION = "generate"

    CATEGORY = "_for_testing/stable_cascade"

    def generate(self, image, vae, compression):
        """Return ``(stage_c, stage_b)`` latent dicts for ``image``.

        The image is resized (bicubic, centre crop mode) so that each side
        equals ``(side // compression) * vae.downscale_ratio``, its first
        three entries along the last axis are encoded with ``vae``, and a
        zero stage B latent of size ``height // 4`` by ``width // 4`` is
        created with the same batch size as the encoded latent.
        """
        # Width and height are read from axes -2 and -3 of the image tensor.
        width = image.shape[-2]
        height = image.shape[-3]
        ratio = vae.downscale_ratio
        out_width = (width // compression) * ratio
        out_height = (height // compression) * ratio

        # common_upscale is fed the last axis moved to position 1; move it
        # back afterwards.
        channels_first = image.movedim(-1,1)
        resized = comfy.utils.common_upscale(channels_first, out_width, out_height, "bicubic", "center")
        s = resized.movedim(1,-1)

        c_latent = vae.encode(s[:,:,:,:3])
        b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
        return ({"samples": c_latent}, {"samples": b_latent})
|
||||
|
||||
class StableCascade_StageB_Conditioning:
    """Node that attaches a stage C latent to every conditioning entry under
    the ``stable_cascade_prior`` key."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "conditioning": ("CONDITIONING",),
            "stage_c": ("LATENT",),
        }
        return {"required": required}

    RETURN_TYPES = ("CONDITIONING",)

    FUNCTION = "set_prior"

    CATEGORY = "_for_testing/stable_cascade"

    def set_prior(self, conditioning, stage_c):
        """Return a one-tuple with a new conditioning list carrying the prior.

        Each entry's options dict (index 1) is shallow-copied before the
        ``stable_cascade_prior`` key is set, so the input entries are not
        mutated. ``stage_c['samples']`` is only looked up per entry, so an
        empty ``conditioning`` never touches ``stage_c``.
        """
        def with_prior(entry):
            options = entry[1].copy()
            options['stable_cascade_prior'] = stage_c['samples']
            return [entry[0], options]

        return ([with_prior(entry) for entry in conditioning], )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user