mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'master' into execution_model_inversion
This commit is contained in:
commit
5ab1565418
11
README.md
11
README.md
@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||
@ -95,16 +95,15 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
|
||||
Put your VAE in: models/vae
|
||||
|
||||
Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead.
|
||||
|
||||
### AMD GPUs (Linux only)
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
|
||||
|
||||
This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements:
|
||||
This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
|
||||
|
||||
### NVIDIA
|
||||
|
||||
@ -112,7 +111,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
|
||||
|
||||
This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements:
|
||||
This is the command to install pytorch nightly instead which might have performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
x = self.embeddings(input_tokens)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
|
||||
@ -166,7 +166,7 @@ class ControlNet(ControlBase):
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
context = cond['c_crossattn']
|
||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(dtype)
|
||||
@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
controlnet_config = None
|
||||
supported_inference_dtypes = None
|
||||
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
controlnet_config = model_config.unet_config
|
||||
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
if supported_inference_dtypes is None:
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
else:
|
||||
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
|
||||
@ -2,7 +2,8 @@ import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -22,7 +23,7 @@ def default(val, d):
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
self.proj = ops.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
@ -35,14 +36,14 @@ class FeedForward(nn.Module):
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
ops.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
ops.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
|
||||
query_dim=query_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim,
|
||||
context_dim=query_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head)
|
||||
dim_head=d_head,
|
||||
operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj
|
||||
# feature
|
||||
self.linear = nn.Linear(context_dim, query_dim)
|
||||
self.linear = ops.Linear(context_dim, query_dim)
|
||||
|
||||
self.attn = CrossAttention(
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
|
||||
query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
|
||||
self.ff = FeedForward(query_dim, glu=True)
|
||||
|
||||
self.norm1 = nn.LayerNorm(query_dim)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm1 = ops.LayerNorm(query_dim)
|
||||
self.norm2 = ops.LayerNorm(query_dim)
|
||||
|
||||
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
||||
@ -201,11 +204,11 @@ class PositionNet(nn.Module):
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
|
||||
|
||||
self.linears = nn.Sequential(
|
||||
nn.Linear(self.in_dim + self.position_dim, 512),
|
||||
ops.Linear(self.in_dim + self.position_dim, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, 512),
|
||||
ops.Linear(512, 512),
|
||||
nn.SiLU(),
|
||||
nn.Linear(512, out_dim),
|
||||
ops.Linear(512, out_dim),
|
||||
)
|
||||
|
||||
self.null_positive_feature = torch.nn.Parameter(
|
||||
@ -215,16 +218,15 @@ class PositionNet(nn.Module):
|
||||
|
||||
def forward(self, boxes, masks, positive_embeddings):
|
||||
B, N, _ = boxes.shape
|
||||
dtype = self.linears[0].weight.dtype
|
||||
masks = masks.unsqueeze(-1).to(dtype)
|
||||
positive_embeddings = positive_embeddings.to(dtype)
|
||||
masks = masks.unsqueeze(-1)
|
||||
positive_embeddings = positive_embeddings
|
||||
|
||||
# embedding position (it may includes padding as placeholder)
|
||||
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
positive_null = self.null_positive_feature.view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
|
||||
|
||||
# replace padding with learnable null embedding
|
||||
positive_embeddings = positive_embeddings * \
|
||||
@ -251,7 +253,7 @@ class Gligen(nn.Module):
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return module(x, objs.to(device=x.device, dtype=x.dtype))
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
|
||||
@ -37,3 +37,41 @@ class SDXL(LatentFormat):
|
||||
class SD_X4(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.08333
|
||||
self.latent_rgb_factors = [
|
||||
[-0.2340, -0.3863, -0.3257],
|
||||
[ 0.0994, 0.0885, -0.0908],
|
||||
[-0.2833, -0.2349, -0.3741],
|
||||
[ 0.2523, -0.0055, -0.1651]
|
||||
]
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
[-0.0326, -0.0204, -0.0127],
|
||||
[-0.1592, -0.0427, 0.0216],
|
||||
[ 0.0873, 0.0638, -0.0020],
|
||||
[-0.0602, 0.0442, 0.1304],
|
||||
[ 0.0800, -0.0313, -0.1796],
|
||||
[-0.0810, -0.0638, -0.1581],
|
||||
[ 0.1791, 0.1180, 0.0967],
|
||||
[ 0.0740, 0.1416, 0.0432],
|
||||
[-0.1745, -0.1888, -0.1373],
|
||||
[ 0.2412, 0.1577, 0.0928],
|
||||
[ 0.1908, 0.0998, 0.0682],
|
||||
[ 0.0209, 0.0365, -0.0092],
|
||||
[ 0.0448, -0.0650, -0.1728],
|
||||
[-0.1658, -0.1045, -0.1308],
|
||||
[ 0.0542, 0.1545, 0.1325],
|
||||
[-0.0352, -0.1672, -0.2541]
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
[ 0.1121, 0.2006, 0.1023],
|
||||
[-0.2093, -0.0222, -0.0195],
|
||||
[-0.3087, -0.1535, 0.0366],
|
||||
[ 0.0290, -0.1574, -0.4078]
|
||||
]
|
||||
|
||||
161
comfy/ldm/cascade/common.py
Normal file
161
comfy/ldm/cascade/common.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class OptimizedAttention(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = nhead
|
||||
|
||||
self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q, k, v):
|
||||
q = self.to_q(q)
|
||||
k = self.to_k(k)
|
||||
v = self.to_v(v)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
|
||||
return self.out_proj(out)
|
||||
|
||||
class Attention2D(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, kv, self_attn=False):
|
||||
orig_shape = x.shape
|
||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||
if self_attn:
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
|
||||
|
||||
def LayerNorm2d_op(operations):
|
||||
class LayerNorm2d(operations.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return LayerNorm2d
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
||||
def __init__(self, dim, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
|
||||
super().__init__()
|
||||
self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
|
||||
# self.depthwise = SAMBlock(c, num_heads, expansion)
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.channelwise = nn.Sequential(
|
||||
operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, x_skip=None):
|
||||
x_res = x
|
||||
x = self.norm(self.depthwise(x))
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x + x_res
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = self_attn
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
|
||||
self.kv_mapper = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
return x
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.channelwise = nn.Sequential(
|
||||
operations.Linear(c, c * 4, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout),
|
||||
operations.Linear(c * 4, c, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
|
||||
self.conds = conds
|
||||
for cname in conds:
|
||||
setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, t):
|
||||
t = t.chunk(len(self.conds) + 1, dim=1)
|
||||
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
||||
for i, c in enumerate(self.conds):
|
||||
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
||||
a, b = a + ac, b + bc
|
||||
return x * (1 + a) + b
|
||||
258
comfy/ldm/cascade/stage_a.py
Normal file
258
comfy/ldm/cascade/stage_a.py
Normal file
@ -0,0 +1,258 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
class vector_quantize(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, codebook):
|
||||
with torch.no_grad():
|
||||
codebook_sqr = torch.sum(codebook ** 2, dim=1)
|
||||
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
|
||||
|
||||
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
|
||||
_, indices = dist.min(dim=1)
|
||||
|
||||
ctx.save_for_backward(indices, codebook)
|
||||
ctx.mark_non_differentiable(indices)
|
||||
|
||||
nn = torch.index_select(codebook, 0, indices)
|
||||
return nn, indices
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output, grad_indices):
|
||||
grad_inputs, grad_codebook = None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_inputs = grad_output.clone()
|
||||
if ctx.needs_input_grad[1]:
|
||||
# Gradient wrt. the codebook
|
||||
indices, codebook = ctx.saved_tensors
|
||||
|
||||
grad_codebook = torch.zeros_like(codebook)
|
||||
grad_codebook.index_add_(0, indices, grad_output)
|
||||
|
||||
return (grad_inputs, grad_codebook)
|
||||
|
||||
|
||||
class VectorQuantize(nn.Module):
|
||||
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
|
||||
"""
|
||||
Takes an input of variable size (as long as the last dimension matches the embedding size).
|
||||
Returns one tensor containing the nearest neigbour embeddings to each of the inputs,
|
||||
with the same size as the input, vq and commitment components for the loss as a touple
|
||||
in the second output and the indices of the quantized vectors in the third:
|
||||
quantized, (vq_loss, commit_loss), indices
|
||||
"""
|
||||
super(VectorQuantize, self).__init__()
|
||||
|
||||
self.codebook = nn.Embedding(k, embedding_size)
|
||||
self.codebook.weight.data.uniform_(-1./k, 1./k)
|
||||
self.vq = vector_quantize.apply
|
||||
|
||||
self.ema_decay = ema_decay
|
||||
self.ema_loss = ema_loss
|
||||
if ema_loss:
|
||||
self.register_buffer('ema_element_count', torch.ones(k))
|
||||
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
|
||||
|
||||
def _laplace_smoothing(self, x, epsilon):
|
||||
n = torch.sum(x)
|
||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
||||
|
||||
def _updateEMA(self, z_e_x, indices):
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
|
||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
|
||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
|
||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
|
||||
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
|
||||
def idx2vq(self, idx, dim=-1):
|
||||
q_idx = self.codebook(idx)
|
||||
if dim != -1:
|
||||
q_idx = q_idx.movedim(-1, dim)
|
||||
return q_idx
|
||||
|
||||
def forward(self, x, get_losses=True, dim=-1):
|
||||
if dim != -1:
|
||||
x = x.movedim(dim, -1)
|
||||
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
|
||||
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
|
||||
vq_loss, commit_loss = None, None
|
||||
if self.ema_loss and self.training:
|
||||
self._updateEMA(z_e_x.detach(), indices.detach())
|
||||
# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
|
||||
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
|
||||
if get_losses:
|
||||
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
|
||||
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
|
||||
|
||||
z_q_x = z_q_x.view(x.shape)
|
||||
if dim != -1:
|
||||
z_q_x = z_q_x.movedim(-1, dim)
|
||||
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, c, c_hidden):
|
||||
super().__init__()
|
||||
# depthwise/attention
|
||||
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.depthwise = nn.Sequential(
|
||||
nn.ReplicationPad2d(1),
|
||||
nn.Conv2d(c, c, kernel_size=3, groups=c)
|
||||
)
|
||||
|
||||
# channelwise
|
||||
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c, c_hidden),
|
||||
nn.GELU(),
|
||||
nn.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
def forward(self, x):
|
||||
mods = self.gammas
|
||||
|
||||
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
|
||||
try:
|
||||
x = x + self.depthwise(x_temp) * mods[2]
|
||||
except: #operation not implemented for bf16
|
||||
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
|
||||
x = x + self.depthwise[1](x_temp) * mods[2]
|
||||
|
||||
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
|
||||
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StageA(nn.Module):
|
||||
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
|
||||
scale_factor=0.43): # 0.3764
|
||||
super().__init__()
|
||||
self.c_latent = c_latent
|
||||
self.scale_factor = scale_factor
|
||||
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
||||
|
||||
# Encoder blocks
|
||||
self.in_block = nn.Sequential(
|
||||
nn.PixelUnshuffle(2),
|
||||
nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
|
||||
)
|
||||
down_blocks = []
|
||||
for i in range(levels):
|
||||
if i > 0:
|
||||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
|
||||
block = ResBlock(c_levels[i], c_levels[i] * 4)
|
||||
down_blocks.append(block)
|
||||
down_blocks.append(nn.Sequential(
|
||||
nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
))
|
||||
self.down_blocks = nn.Sequential(*down_blocks)
|
||||
self.down_blocks[0]
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
|
||||
|
||||
# Decoder blocks
|
||||
up_blocks = [nn.Sequential(
|
||||
nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
|
||||
)]
|
||||
for i in range(levels):
|
||||
for j in range(bottleneck_blocks if i == 0 else 1):
|
||||
block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
|
||||
up_blocks.append(block)
|
||||
if i < levels - 1:
|
||||
up_blocks.append(
|
||||
nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
|
||||
padding=1))
|
||||
self.up_blocks = nn.Sequential(*up_blocks)
|
||||
self.out_block = nn.Sequential(
|
||||
nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
|
||||
nn.PixelShuffle(2),
|
||||
)
|
||||
|
||||
def encode(self, x, quantize=False):
|
||||
x = self.in_block(x)
|
||||
x = self.down_blocks(x)
|
||||
if quantize:
|
||||
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
||||
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
|
||||
else:
|
||||
return x / self.scale_factor
|
||||
|
||||
def decode(self, x):
|
||||
x = x * self.scale_factor
|
||||
x = self.up_blocks(x)
|
||||
x = self.out_block(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, quantize=False):
|
||||
qe, x, _, vq_loss = self.encode(x, quantize)
|
||||
x = self.decode(qe)
|
||||
return x, vq_loss
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
|
||||
super().__init__()
|
||||
d = max(depth - 3, 3)
|
||||
layers = [
|
||||
nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
|
||||
nn.LeakyReLU(0.2),
|
||||
]
|
||||
for i in range(depth - 1):
|
||||
c_in = c_hidden // (2 ** max((d - i), 0))
|
||||
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
|
||||
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
|
||||
layers.append(nn.InstanceNorm2d(c_out))
|
||||
layers.append(nn.LeakyReLU(0.2))
|
||||
self.encoder = nn.Sequential(*layers)
|
||||
self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
|
||||
self.logits = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
x = self.encoder(x)
|
||||
if cond is not None:
|
||||
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
|
||||
x = torch.cat([x, cond], dim=1)
|
||||
x = self.shuffle(x)
|
||||
x = self.logits(x)
|
||||
return x
|
||||
257
comfy/ldm/cascade/stage_b.py
Normal file
257
comfy/ldm/cascade/stage_b.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
|
||||
class StageB(nn.Module):
|
||||
def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
|
||||
nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
|
||||
block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
|
||||
c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
|
||||
t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.c_r = c_r
|
||||
self.t_conds = t_conds
|
||||
self.c_clip_seq = c_clip_seq
|
||||
if not isinstance(dropout, list):
|
||||
dropout = [dropout] * len(c_hidden)
|
||||
if not isinstance(self_attn, list):
|
||||
self_attn = [self_attn] * len(c_hidden)
|
||||
|
||||
# CONDITIONING
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
|
||||
nn.GELU(),
|
||||
operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == 'C':
|
||||
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'A':
|
||||
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'F':
|
||||
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'T':
|
||||
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
raise Exception(f'Block type {block_type} not supported')
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(c_hidden)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(blocks[0][i]):
|
||||
for block_type in level_config[i]:
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[0][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(c_hidden))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
|
||||
))
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(blocks[1][::-1][i]):
|
||||
for k, block_type in enumerate(level_config[i]):
|
||||
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
|
||||
self_attn=self_attn[i])
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[1][::-1][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
# --- WEIGHT INIT ---
|
||||
# self.apply(self._init_weights) # General init
|
||||
# nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
|
||||
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
|
||||
# elif isinstance(block, TimestepBlock):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
# if m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
|
||||
r = r * max_positions
|
||||
half_dim = self.c_r // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
if self.c_r % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode='constant')
|
||||
return emb
|
||||
|
||||
def gen_c_embeddings(self, clip):
|
||||
if len(clip.shape) == 2:
|
||||
clip = clip.unsqueeze(1)
|
||||
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
x = block(x)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
||||
if pixels is None:
|
||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
t_cond = kwargs.get(c, torch.zeros_like(r))
|
||||
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
|
||||
clip = self.gen_c_embeddings(clip)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
x = x + self.effnet_mapper(
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
level_outputs = self._down_encode(x, r_embed, clip)
|
||||
x = self._up_decode(level_outputs, r_embed, clip)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
|
||||
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
|
||||
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
|
||||
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
||||
271
comfy/ldm/cascade/stage_c.py
Normal file
271
comfy/ldm/cascade/stage_c.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import math
|
||||
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||
# from .controlnet import ControlNetDeliverer
|
||||
|
||||
class UpDownBlock2d(nn.Module):
|
||||
def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert mode in ['up', 'down']
|
||||
interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
|
||||
align_corners=True) if enabled else nn.Identity()
|
||||
mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
|
||||
self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class StageC(nn.Module):
|
||||
def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
|
||||
blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
|
||||
c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
|
||||
dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.c_r = c_r
|
||||
self.t_conds = t_conds
|
||||
self.c_clip_seq = c_clip_seq
|
||||
if not isinstance(dropout, list):
|
||||
dropout = [dropout] * len(c_hidden)
|
||||
if not isinstance(self_attn, list):
|
||||
self_attn = [self_attn] * len(c_hidden)
|
||||
|
||||
# CONDITIONING
|
||||
self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
|
||||
self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
|
||||
self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
|
||||
)
|
||||
|
||||
def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == 'C':
|
||||
return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'A':
|
||||
return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'F':
|
||||
return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
|
||||
elif block_type == 'T':
|
||||
return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
raise Exception(f'Block type {block_type} not supported')
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(c_hidden)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
|
||||
))
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(blocks[0][i]):
|
||||
for block_type in level_config[i]:
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[0][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(c_hidden))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
|
||||
))
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(blocks[1][::-1][i]):
|
||||
for k, block_type in enumerate(level_config[i]):
|
||||
c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
|
||||
block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
|
||||
self_attn=self_attn[i])
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
if block_repeat is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(block_repeat[1][::-1][i] - 1):
|
||||
block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
|
||||
operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
# --- WEIGHT INIT ---
|
||||
# self.apply(self._init_weights) # General init
|
||||
# nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
|
||||
# nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
|
||||
# torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
# nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
#
|
||||
# # blocks
|
||||
# for level_block in self.down_blocks + self.up_blocks:
|
||||
# for block in level_block:
|
||||
# if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
|
||||
# block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
|
||||
# elif isinstance(block, TimestepBlock):
|
||||
# for layer in block.modules():
|
||||
# if isinstance(layer, nn.Linear):
|
||||
# nn.init.constant_(layer.weight, 0)
|
||||
#
|
||||
# def _init_weights(self, m):
|
||||
# if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# torch.nn.init.xavier_uniform_(m.weight)
|
||||
# if m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
|
||||
r = r * max_positions
|
||||
half_dim = self.c_r // 2
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
if self.c_r % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode='constant')
|
||||
return emb
|
||||
|
||||
def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
|
||||
clip_txt = self.clip_txt_mapper(clip_txt)
|
||||
if len(clip_txt_pooled.shape) == 2:
|
||||
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
|
||||
if len(clip_img.shape) == 2:
|
||||
clip_img = clip_img.unsqueeze(1)
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
if cnet is not None:
|
||||
next_cnet = cnet()
|
||||
if next_cnet is not None:
|
||||
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
x = block(x)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, ResBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
ResBlock)):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
if cnet is not None:
|
||||
next_cnet = cnet()
|
||||
if next_cnet is not None:
|
||||
x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
|
||||
align_corners=True)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, AttnBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
AttnBlock)):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, TimestepBlock) or (
|
||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||
TimestepBlock)):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
|
||||
# Process the conditioning embeddings
|
||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||
for c in self.t_conds:
|
||||
t_cond = kwargs.get(c, torch.zeros_like(r))
|
||||
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
|
||||
clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(x)
|
||||
if cnet is not None:
|
||||
cnet = ControlNetDeliverer(cnet)
|
||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
||||
return self.clf(x)
|
||||
|
||||
def update_weights_ema(self, src_model, beta=0.999):
|
||||
for self_params, src_params in zip(self.parameters(), src_model.parameters()):
|
||||
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
|
||||
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
|
||||
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
|
||||
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
def __init__(self, c_latent=16):
|
||||
super().__init__()
|
||||
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
|
||||
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
|
||||
|
||||
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
|
||||
class Previewer(nn.Module):
|
||||
def __init__(self, c_in=16, c_hidden=512, c_out=3):
|
||||
super().__init__()
|
||||
self.blocks = nn.Sequential(
|
||||
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 2),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.BatchNorm2d(c_hidden // 4),
|
||||
|
||||
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return (self.blocks(x) - 0.5) * 2.0
|
||||
|
||||
class StageC_coder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.previewer = Previewer()
|
||||
self.encoder = EfficientNetEncoder()
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.previewer(x)
|
||||
@ -114,7 +114,12 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
else:
|
||||
sim += mask
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
sim.add_(mask)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
@ -165,6 +170,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
if query_chunk_size is None:
|
||||
query_chunk_size = 512
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
@ -223,6 +235,13 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
bs = 1
|
||||
else:
|
||||
bs = mask.shape[0]
|
||||
mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
|
||||
@ -98,7 +98,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
betas = torch.clamp(betas, min=0, max=0.999)
|
||||
|
||||
elif schedule == "squaredcos_cap_v2": # used for karlo prior
|
||||
# return early
|
||||
@ -113,7 +113,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
return betas
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
import comfy.model_management
|
||||
@ -12,9 +14,10 @@ class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -27,6 +30,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -35,7 +41,7 @@ def model_sampling(model_config, model_type):
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
@ -48,7 +54,7 @@ class BaseModel(torch.nn.Module):
|
||||
operations = comfy.ops.manual_cast
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -153,6 +159,10 @@ class BaseModel(torch.nn.Module):
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
|
||||
cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
|
||||
if cross_attn_cnet is not None:
|
||||
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
|
||||
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
@ -423,3 +433,52 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
if "unclip_conditioning" in kwargs:
|
||||
embeds = []
|
||||
for unclip_cond in kwargs["unclip_conditioning"]:
|
||||
weight = unclip_cond["strength"]
|
||||
embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
|
||||
clip_img = torch.cat(embeds, dim=1)
|
||||
else:
|
||||
clip_img = torch.zeros((1, 1, 768))
|
||||
out["clip_img"] = comfy.conds.CONDRegular(clip_img)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class StableCascade_B(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||
self.diffusion_model.eval().requires_grad_(False)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"]
|
||||
if clip_text_pooled is not None:
|
||||
out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||
|
||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||
return out
|
||||
|
||||
@ -28,9 +28,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
def detect_unet_config(state_dict, key_prefix):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||
unet_config = {}
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: #stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: #stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
|
||||
if w.shape[-1] == 640:
|
||||
unet_config['c_hidden'] = [320, 640, 1280, 1280]
|
||||
unet_config['nhead'] = [-1, -1, 20, 20]
|
||||
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
||||
elif w.shape[-1] == 576: #stage b lite
|
||||
unet_config['c_hidden'] = [320, 576, 1152, 1152]
|
||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
|
||||
return unet_config
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
@ -45,7 +74,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -159,8 +187,8 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
@ -206,7 +234,7 @@ def convert_config(unet_config):
|
||||
return new_config
|
||||
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
@ -313,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
def model_config_from_diffusers_unet(state_dict):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@ -496,21 +496,32 @@ def unet_dtype(device=None, model_params=0):
|
||||
return torch.float8_e4m3fn
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device):
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
||||
if fp16_supported:
|
||||
bf16_supported = should_use_bf16(inference_device)
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
@ -546,10 +557,8 @@ def text_encoder_dtype(device=None):
|
||||
if is_device_cpu(device):
|
||||
return torch.float16
|
||||
|
||||
if should_use_fp16(device, prioritize_performance=False):
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
return torch.float16
|
||||
|
||||
|
||||
def intermediate_device():
|
||||
if args.gpu_only:
|
||||
@ -686,19 +695,22 @@ def mps_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def is_device_cpu(device):
|
||||
def is_device_type(device, type):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'cpu'):
|
||||
if (device.type == type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_device_cpu(device):
|
||||
return is_device_type(device, 'cpu')
|
||||
|
||||
def is_device_mps(device):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == 'mps'):
|
||||
return True
|
||||
return False
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
global directml_enabled
|
||||
|
||||
if device is not None:
|
||||
@ -708,9 +720,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
if FORCE_FP16:
|
||||
return True
|
||||
|
||||
if device is not None: #TODO
|
||||
if device is not None:
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
@ -718,16 +730,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
return False #TODO ?
|
||||
if mps_mode():
|
||||
return True
|
||||
|
||||
if cpu_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
if props.major < 6:
|
||||
return False
|
||||
|
||||
@ -740,7 +758,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
|
||||
if fp16_works:
|
||||
if fp16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
@ -756,6 +774,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
|
||||
return True
|
||||
|
||||
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
if device is not None:
|
||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||
return False
|
||||
|
||||
if device is not None: #TODO not sure about mps bf16 support
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
return True
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda")
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
|
||||
if bf16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
@ -42,8 +41,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
@ -58,8 +56,8 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
self.register_buffer('sigmas', sigmas.float())
|
||||
self.register_buffer('log_sigmas', sigmas.log().float())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
@ -134,3 +132,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(sampling_settings.get("shift", 1.0))
|
||||
|
||||
def set_parameters(self, shift=1.0, cosine_s=8e-3):
|
||||
self.shift = shift
|
||||
self.cosine_s = torch.tensor(cosine_s)
|
||||
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
|
||||
|
||||
#This part is just for compatibility with some schedulers in the codebase
|
||||
self.num_timesteps = 10000
|
||||
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
|
||||
for x in range(self.num_timesteps):
|
||||
t = (x + 1) / self.num_timesteps
|
||||
sigmas[x] = self.sigma(t)
|
||||
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def sigma(self, timestep):
|
||||
alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
|
||||
|
||||
if self.shift != 1.0:
|
||||
var = alpha_cumprod
|
||||
logSNR = (var/(1-var)).log()
|
||||
logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
|
||||
alpha_cumprod = logSNR.sigmoid()
|
||||
|
||||
alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
|
||||
return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
|
||||
|
||||
def timestep(self, sigma):
|
||||
var = 1 / ((sigma * sigma) + 1)
|
||||
var = var.clamp(0, 1.0)
|
||||
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
|
||||
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return t
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent))
|
||||
|
||||
49
comfy/ops.py
49
comfy/ops.py
@ -1,3 +1,21 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
@ -78,7 +96,11 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -87,6 +109,28 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input, output_size=None):
|
||||
num_spatial_dims = 2
|
||||
output_padding = self._output_padding(
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class LayerNorm(disable_weight_init.LayerNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
@ -295,7 +295,7 @@ def simple_scheduler(model, steps):
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ss = len(s.sigmas) // steps
|
||||
ss = max(len(s.sigmas) // steps, 1)
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
@ -652,6 +652,7 @@ def sampler_object(name):
|
||||
class KSampler:
|
||||
SCHEDULERS = SCHEDULER_NAMES
|
||||
SAMPLERS = SAMPLER_NAMES
|
||||
DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -670,7 +671,7 @@ class KSampler:
|
||||
sigmas = None
|
||||
|
||||
discard_penultimate_sigma = False
|
||||
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
|
||||
if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
|
||||
steps += 1
|
||||
discard_penultimate_sigma = True
|
||||
|
||||
|
||||
128
comfy/sd.py
128
comfy/sd.py
@ -1,7 +1,11 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
|
||||
from comfy import model_management
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
|
||||
import yaml
|
||||
|
||||
import comfy.utils
|
||||
@ -134,8 +138,11 @@ class CLIP:
|
||||
tokens = self.tokenize(text)
|
||||
return self.encode_from_tokens(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
return self.cond_stage_model.state_dict()
|
||||
@ -155,7 +162,10 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -168,6 +178,34 @@ class VAE:
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
#TODO
|
||||
#self.memory_used_encode
|
||||
#self.memory_used_decode
|
||||
self.process_input = lambda image: image
|
||||
self.process_output = lambda image: image
|
||||
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["encoder.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["previewer.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@ -175,6 +213,7 @@ class VAE:
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
@ -200,18 +239,27 @@ class VAE:
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
|
||||
y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
|
||||
y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
|
||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -220,7 +268,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -235,10 +283,10 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
@ -252,6 +300,7 @@ class VAE:
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
@ -261,7 +310,7 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
@ -271,6 +320,7 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
@ -297,8 +347,11 @@ def load_style_model(ckpt_path):
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
class CLIPType(Enum):
|
||||
STABLE_DIFFUSION = 1
|
||||
STABLE_CASCADE = 2
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
@ -314,8 +367,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
@ -438,15 +495,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clip_target = None
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
@ -462,18 +516,24 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
w = WeightsLoader()
|
||||
clip_target = model_config.clip_target()
|
||||
if clip_target is not None:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
sd = model_config.process_clip_state_dict(sd)
|
||||
load_model_weights(w, sd)
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
print("clip missing:", m)
|
||||
|
||||
if len(u) > 0:
|
||||
print("clip unexpected:", u)
|
||||
else:
|
||||
print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
@ -492,16 +552,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
|
||||
@ -513,8 +572,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
print(diffusers_keys[k], k)
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
|
||||
@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.special_tokens = special_tokens
|
||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.enable_attention_masks = False
|
||||
self.enable_attention_masks = enable_attention_masks
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
if layer == "hidden":
|
||||
|
||||
@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
||||
|
||||
@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
replace_prefix["cond_stage_model."] = "clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
|
||||
keys_to_replace = {}
|
||||
replace_prefix = {}
|
||||
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
|
||||
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
|
||||
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
|
||||
@ -306,5 +305,66 @@ class SD_X4Upscaler(SD20):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'c',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent_formats.SC_Prior
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoder."]
|
||||
clip_vision_prefix = "clip_l_vision."
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
key_list = list(state_dict.keys())
|
||||
for y in ["weight", "bias"]:
|
||||
suffix = "in_proj_{}".format(y)
|
||||
keys = filter(lambda a: a.endswith(suffix), key_list)
|
||||
for k_from in keys:
|
||||
weights = state_dict.pop(k_from)
|
||||
prefix = k_from[:-(len(suffix) + 1)]
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["to_q", "to_k", "to_v"]
|
||||
k_to = "{}.{}.{}".format(prefix, p[x], y)
|
||||
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_C(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||
|
||||
class Stable_Cascade_B(Stable_Cascade_C):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'b',
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
latent_format = latent_formats.SC_B
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
clip_vision_prefix = None
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -21,13 +21,16 @@ class BASE:
|
||||
noise_aug_config = None
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
vae_key_prefix = ["first_stage_model."]
|
||||
text_encoder_key_prefix = ["cond_stage_model."]
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
for k in s.unet_config:
|
||||
if s.unet_config[k] != unet_config[k]:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -53,6 +56,7 @@ class BASE:
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
@ -62,7 +66,7 @@ class BASE:
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
replace_prefix = {"": self.text_encoder_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_clip_vision_state_dict_for_saving(self, state_dict):
|
||||
@ -76,8 +80,9 @@ class BASE:
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def process_vae_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_manual_cast(self, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
@ -169,6 +169,8 @@ UNET_MAP_BASIC = {
|
||||
}
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
if "num_res_blocks" not in unet_config:
|
||||
return {}
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
@ -413,6 +415,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||
for y in range(0, s.shape[2], tile_y - overlap):
|
||||
for x in range(0, s.shape[3], tile_x - overlap):
|
||||
x = max(0, min(s.shape[-1] - overlap, x))
|
||||
y = max(0, min(s.shape[-2] - overlap, y))
|
||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
|
||||
25
comfy_extras/nodes_cond.py
Normal file
25
comfy_extras/nodes_cond.py
Normal file
@ -0,0 +1,25 @@
|
||||
|
||||
|
||||
class CLIPTextEncodeControlnet:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "_for_testing/conditioning"
|
||||
|
||||
def encode(self, clip, conditioning, text):
|
||||
tokens = clip.tokenize(text)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['cross_attn_controlnet'] = cond
|
||||
n[1]['pooled_output_controlnet'] = pooled
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
|
||||
}
|
||||
@ -48,6 +48,25 @@ class RepeatImageBatch:
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
|
||||
class ImageFromBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "frombatch"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def frombatch(self, image, batch_index, length):
|
||||
s_in = image
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return (s,)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -170,6 +189,7 @@ class SaveAnimatedPNG:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"ImageFromBatch": ImageFromBatch,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
}
|
||||
|
||||
@ -126,7 +126,7 @@ class LatentBatchSeedBehavior:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"seed_behavior": (["random", "fixed"],),}}
|
||||
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
@ -17,6 +17,10 @@ class LCM(comfy.model_sampling.EPS):
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class X0(comfy.model_sampling.EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
return model_output
|
||||
|
||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||
original_timesteps = 50
|
||||
|
||||
@ -68,7 +72,7 @@ class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["eps", "v_prediction", "lcm"],),
|
||||
"sampling": (["eps", "v_prediction", "lcm", "x0"],),
|
||||
"zsnr": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
|
||||
@ -88,6 +92,8 @@ class ModelSamplingDiscrete:
|
||||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
elif sampling == "x0":
|
||||
sampling_type = X0
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
@ -99,6 +105,32 @@ class ModelSamplingDiscrete:
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingStableCascade:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, shift):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = comfy.model_sampling.StableCascadeSampling
|
||||
sampling_type = comfy.model_sampling.EPS
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -171,5 +203,6 @@ class RescaleCFG:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
"ModelSamplingStableCascade": ModelSamplingStableCascade,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
|
||||
109
comfy_extras/nodes_stable_cascade.py
Normal file
109
comfy_extras/nodes_stable_cascade.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Stability AI
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, width, height, compression, batch_size=1):
|
||||
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
||||
return ({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageC_VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, image, vae, compression):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
out_width = (width // compression) * vae.downscale_ratio
|
||||
out_height = (height // compression) * vae.downscale_ratio
|
||||
|
||||
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
|
||||
|
||||
c_latent = vae.encode(s[:,:,:,:3])
|
||||
b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
|
||||
return ({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageB_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "conditioning": ("CONDITIONING",),
|
||||
"stage_c": ("LATENT",),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
|
||||
FUNCTION = "set_prior"
|
||||
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def set_prior(self, conditioning, stage_c):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['stable_cascade_prior'] = stage_c['samples']
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
||||
}
|
||||
@ -6,6 +6,8 @@ class Example:
|
||||
-------------
|
||||
INPUT_TYPES (dict):
|
||||
Tell the main program input parameters of nodes.
|
||||
IS_CHANGED:
|
||||
optional method to control when the node is re executed.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@ -89,6 +91,17 @@ class Example:
|
||||
image = 1.0 - image
|
||||
return (image,)
|
||||
|
||||
"""
|
||||
The node will always be re executed if any of the inputs change but
|
||||
this method can be used to force the node to execute again even when the inputs don't change.
|
||||
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
|
||||
executed, if it is different the node will be executed again.
|
||||
This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
|
||||
changes between executions the LoadImage node is executed again.
|
||||
"""
|
||||
#@classmethod
|
||||
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
|
||||
# return ""
|
||||
|
||||
# A dictionary that contains all nodes you want to export with their names
|
||||
# NOTE: names should be globally unique
|
||||
|
||||
49
custom_nodes/websocket_image_save.py.disabled
Normal file
49
custom_nodes/websocket_image_save.py.disabled
Normal file
@ -0,0 +1,49 @@
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import struct
|
||||
import comfy.utils
|
||||
import time
|
||||
|
||||
#You can use this node to save full size images through the websocket, the
|
||||
#images will be sent in exactly the same format as the image previews: as
|
||||
#binary images on the websocket with a 8 byte header indicating the type
|
||||
#of binary message (first 4 bytes) and the image format (next 4 bytes).
|
||||
|
||||
#The reason this node is disabled by default is because there is a small
|
||||
#issue when using it with the default ComfyUI web interface: When generating
|
||||
#batches only the last image will be shown in the UI.
|
||||
|
||||
#Note that no metadata will be put in the images saved with this node.
|
||||
|
||||
class SaveImageWebsocket:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def save_images(self, images):
|
||||
pbar = comfy.utils.ProgressBar(images.shape[0])
|
||||
step = 0
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
|
||||
step += 1
|
||||
|
||||
return {}
|
||||
|
||||
def IS_CHANGED(s, images):
|
||||
return time.time()
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveImageWebsocket": SaveImageWebsocket,
|
||||
}
|
||||
57
nodes.py
57
nodes.py
@ -184,6 +184,26 @@ class ConditioningSetAreaPercentage:
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class ConditioningSetAreaStrength:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['strength'] = strength
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
|
||||
class ConditioningSetMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -289,18 +309,7 @@ class VAEEncode:
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
@staticmethod
|
||||
def vae_encode_crop_pixels(pixels):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
||||
return pixels
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
pixels = self.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
return ({"samples":t}, )
|
||||
|
||||
@ -316,7 +325,6 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels, tile_size):
|
||||
pixels = VAEEncode.vae_encode_crop_pixels(pixels)
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
|
||||
return ({"samples":t}, )
|
||||
|
||||
@ -330,14 +338,14 @@ class VAEEncodeForInpaint:
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask, grow_mask_by=6):
|
||||
x = (pixels.shape[1] // 8) * 8
|
||||
y = (pixels.shape[2] // 8) * 8
|
||||
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
|
||||
|
||||
pixels = pixels.clone()
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
x_offset = (pixels.shape[1] % 8) // 2
|
||||
y_offset = (pixels.shape[2] % 8) // 2
|
||||
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
|
||||
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
|
||||
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
||||
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
|
||||
|
||||
@ -834,15 +842,20 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
if type == "stable_cascade":
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
class DualCLIPLoader:
|
||||
@ -1414,7 +1427,7 @@ class SaveImage:
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
for image in images:
|
||||
for (batch_number, image) in enumerate(images):
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
metadata = None
|
||||
@ -1426,7 +1439,8 @@ class SaveImage:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
|
||||
results.append({
|
||||
"filename": file,
|
||||
@ -1754,6 +1768,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ConditioningConcat": ConditioningConcat,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
|
||||
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
|
||||
"ConditioningSetMask": ConditioningSetMask,
|
||||
"KSamplerAdvanced": KSamplerAdvanced,
|
||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||
@ -1944,6 +1959,8 @@ def init_custom_nodes():
|
||||
"nodes_stable3d.py",
|
||||
"nodes_sdupscale.py",
|
||||
"nodes_photomaker.py",
|
||||
"nodes_cond.py",
|
||||
"nodes_stable_cascade.py",
|
||||
]
|
||||
|
||||
for node_file in extras_files:
|
||||
|
||||
@ -910,6 +910,9 @@ export class GroupNodeHandler {
|
||||
const self = this;
|
||||
const onNodeCreated = this.node.onNodeCreated;
|
||||
this.node.onNodeCreated = function () {
|
||||
if (!this.widgets) {
|
||||
return;
|
||||
}
|
||||
const config = self.groupData.nodeData.config;
|
||||
if (config) {
|
||||
for (const n in config) {
|
||||
|
||||
@ -48,7 +48,7 @@
|
||||
list-style: none;
|
||||
}
|
||||
.comfy-group-manage-list-items {
|
||||
max-height: 70vh;
|
||||
max-height: calc(100% - 40px);
|
||||
overflow-y: scroll;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ async function uploadMask(filepath, formData) {
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
|
||||
function prepare_mask(image, maskCanvas, maskCtx) {
|
||||
function prepare_mask(image, maskCanvas, maskCtx, maskColor) {
|
||||
// paste mask data into alpha channel
|
||||
maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
|
||||
@ -74,9 +74,9 @@ function prepare_mask(image, maskCanvas, maskCtx) {
|
||||
else
|
||||
maskData.data[i+3] = 255;
|
||||
|
||||
maskData.data[i] = 0;
|
||||
maskData.data[i+1] = 0;
|
||||
maskData.data[i+2] = 0;
|
||||
maskData.data[i] = maskColor.r;
|
||||
maskData.data[i+1] = maskColor.g;
|
||||
maskData.data[i+2] = maskColor.b;
|
||||
}
|
||||
|
||||
maskCtx.globalCompositeOperation = 'source-over';
|
||||
@ -110,6 +110,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
createButton(name, callback) {
|
||||
var button = document.createElement("button");
|
||||
button.style.pointerEvents = "auto";
|
||||
button.innerText = name;
|
||||
button.addEventListener("click", callback);
|
||||
return button;
|
||||
@ -146,6 +147,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
divElement.style.display = "flex";
|
||||
divElement.style.position = "relative";
|
||||
divElement.style.top = "2px";
|
||||
divElement.style.pointerEvents = "auto";
|
||||
self.brush_slider_input = document.createElement('input');
|
||||
self.brush_slider_input.setAttribute('type', 'range');
|
||||
self.brush_slider_input.setAttribute('min', '1');
|
||||
@ -173,6 +175,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
bottom_panel.style.left = "20px";
|
||||
bottom_panel.style.right = "20px";
|
||||
bottom_panel.style.height = "50px";
|
||||
bottom_panel.style.pointerEvents = "none";
|
||||
|
||||
var brush = document.createElement("div");
|
||||
brush.id = "brush";
|
||||
@ -191,14 +194,29 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
this.element.appendChild(bottom_panel);
|
||||
document.body.appendChild(brush);
|
||||
|
||||
var clearButton = this.createLeftButton("Clear", () => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
});
|
||||
|
||||
this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
self.brush_size = event.target.value;
|
||||
self.updateBrushPreview(self, null, null);
|
||||
});
|
||||
var clearButton = this.createLeftButton("Clear",
|
||||
() => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
});
|
||||
|
||||
this.colorButton = this.createLeftButton(this.getColorButtonText(), () => {
|
||||
if (self.brush_color_mode === "black") {
|
||||
self.brush_color_mode = "white";
|
||||
}
|
||||
else if (self.brush_color_mode === "white") {
|
||||
self.brush_color_mode = "negative";
|
||||
}
|
||||
else {
|
||||
self.brush_color_mode = "black";
|
||||
}
|
||||
|
||||
self.updateWhenBrushColorModeChanged();
|
||||
});
|
||||
|
||||
var cancelButton = this.createRightButton("Cancel", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
@ -219,6 +237,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
bottom_panel.appendChild(this.saveButton);
|
||||
bottom_panel.appendChild(cancelButton);
|
||||
bottom_panel.appendChild(this.brush_size_slider);
|
||||
bottom_panel.appendChild(this.colorButton);
|
||||
|
||||
imgCanvas.style.position = "absolute";
|
||||
maskCanvas.style.position = "absolute";
|
||||
@ -228,6 +247,10 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
maskCanvas.style.top = imgCanvas.style.top;
|
||||
maskCanvas.style.left = imgCanvas.style.left;
|
||||
|
||||
const maskCanvasStyle = this.getMaskCanvasStyle();
|
||||
maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode;
|
||||
maskCanvas.style.opacity = maskCanvasStyle.opacity;
|
||||
}
|
||||
|
||||
async show() {
|
||||
@ -313,7 +336,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
|
||||
prepare_mask(mask_image, this.maskCanvas, maskCtx);
|
||||
prepare_mask(mask_image, this.maskCanvas, maskCtx, this.getMaskColor());
|
||||
}
|
||||
|
||||
async setImages(imgCanvas) {
|
||||
@ -439,7 +462,84 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
}
|
||||
}
|
||||
|
||||
getMaskCanvasStyle() {
|
||||
if (this.brush_color_mode === "negative") {
|
||||
return {
|
||||
mixBlendMode: "difference",
|
||||
opacity: "1",
|
||||
};
|
||||
}
|
||||
else {
|
||||
return {
|
||||
mixBlendMode: "initial",
|
||||
opacity: "0.7",
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
getMaskColor() {
|
||||
if (this.brush_color_mode === "black") {
|
||||
return { r: 0, g: 0, b: 0 };
|
||||
}
|
||||
if (this.brush_color_mode === "white") {
|
||||
return { r: 255, g: 255, b: 255 };
|
||||
}
|
||||
if (this.brush_color_mode === "negative") {
|
||||
// negative effect only works with white color
|
||||
return { r: 255, g: 255, b: 255 };
|
||||
}
|
||||
|
||||
return { r: 0, g: 0, b: 0 };
|
||||
}
|
||||
|
||||
getMaskFillStyle() {
|
||||
const maskColor = this.getMaskColor();
|
||||
|
||||
return "rgb(" + maskColor.r + "," + maskColor.g + "," + maskColor.b + ")";
|
||||
}
|
||||
|
||||
getColorButtonText() {
|
||||
let colorCaption = "unknown";
|
||||
|
||||
if (this.brush_color_mode === "black") {
|
||||
colorCaption = "black";
|
||||
}
|
||||
else if (this.brush_color_mode === "white") {
|
||||
colorCaption = "white";
|
||||
}
|
||||
else if (this.brush_color_mode === "negative") {
|
||||
colorCaption = "negative";
|
||||
}
|
||||
|
||||
return "Color: " + colorCaption;
|
||||
}
|
||||
|
||||
updateWhenBrushColorModeChanged() {
|
||||
this.colorButton.innerText = this.getColorButtonText();
|
||||
|
||||
// update mask canvas css styles
|
||||
|
||||
const maskCanvasStyle = this.getMaskCanvasStyle();
|
||||
this.maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode;
|
||||
this.maskCanvas.style.opacity = maskCanvasStyle.opacity;
|
||||
|
||||
// update mask canvas rgb colors
|
||||
|
||||
const maskColor = this.getMaskColor();
|
||||
|
||||
const maskData = this.maskCtx.getImageData(0, 0, this.maskCanvas.width, this.maskCanvas.height);
|
||||
|
||||
for (let i = 0; i < maskData.data.length; i += 4) {
|
||||
maskData.data[i] = maskColor.r;
|
||||
maskData.data[i+1] = maskColor.g;
|
||||
maskData.data[i+2] = maskColor.b;
|
||||
}
|
||||
|
||||
this.maskCtx.putImageData(maskData, 0, 0);
|
||||
}
|
||||
|
||||
brush_size = 10;
|
||||
brush_color_mode = "black";
|
||||
drawing_mode = false;
|
||||
lastx = -1;
|
||||
lasty = -1;
|
||||
@ -518,6 +618,19 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
event.preventDefault();
|
||||
self.pan_move(self, event);
|
||||
}
|
||||
|
||||
let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1;
|
||||
|
||||
if(event.shiftKey && left_button_down) {
|
||||
self.drawing_mode = false;
|
||||
|
||||
const y = event.clientY;
|
||||
let delta = (self.zoom_lasty - y)*0.005;
|
||||
self.zoom_ratio = Math.max(Math.min(10.0, self.last_zoom_ratio - delta), 0.2);
|
||||
|
||||
this.invalidatePanZoom();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
pan_move(self, event) {
|
||||
@ -535,7 +648,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
}
|
||||
|
||||
draw_move(self, event) {
|
||||
if(event.ctrlKey) {
|
||||
if(event.ctrlKey || event.shiftKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -546,7 +659,10 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
|
||||
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
|
||||
let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1;
|
||||
let right_button_down = [2, 5, 32].includes(event.buttons);
|
||||
|
||||
if (!event.altKey && left_button_down) {
|
||||
var diff = performance.now() - self.lasttime;
|
||||
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
@ -581,7 +697,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
if(diff > 20 && !this.drawing_mode)
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.fillStyle = this.getMaskFillStyle();
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
@ -591,7 +707,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.fillStyle = this.getMaskFillStyle();
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
@ -613,7 +729,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||
else if((event.altKey && left_button_down) || right_button_down) {
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
|
||||
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
|
||||
@ -687,13 +803,20 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
self.drawing_mode = true;
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
if(event.shiftKey) {
|
||||
self.zoom_lasty = event.clientY;
|
||||
self.last_zoom_ratio = self.zoom_ratio;
|
||||
return;
|
||||
}
|
||||
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
|
||||
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
|
||||
|
||||
self.maskCtx.beginPath();
|
||||
if (event.button == 0) {
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
if (!event.altKey && event.button == 0) {
|
||||
self.maskCtx.fillStyle = this.getMaskFillStyle();
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
} else {
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
|
||||
102
web/extensions/core/simpleTouchSupport.js
Normal file
102
web/extensions/core/simpleTouchSupport.js
Normal file
@ -0,0 +1,102 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
let touchZooming;
|
||||
let touchCount = 0;
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.SimpleTouchSupport",
|
||||
setup() {
|
||||
let zoomPos;
|
||||
let touchTime;
|
||||
let lastTouch;
|
||||
|
||||
function getMultiTouchPos(e) {
|
||||
return Math.hypot(e.touches[0].clientX - e.touches[1].clientX, e.touches[0].clientY - e.touches[1].clientY);
|
||||
}
|
||||
|
||||
app.canvasEl.addEventListener(
|
||||
"touchstart",
|
||||
(e) => {
|
||||
touchCount++;
|
||||
lastTouch = null;
|
||||
if (e.touches?.length === 1) {
|
||||
// Store start time for press+hold for context menu
|
||||
touchTime = new Date();
|
||||
lastTouch = e.touches[0];
|
||||
} else {
|
||||
touchTime = null;
|
||||
if (e.touches?.length === 2) {
|
||||
// Store center pos for zoom
|
||||
zoomPos = getMultiTouchPos(e);
|
||||
app.canvas.pointer_is_down = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
true
|
||||
);
|
||||
|
||||
app.canvasEl.addEventListener("touchend", (e) => {
|
||||
touchZooming = false;
|
||||
touchCount = e.touches?.length ?? touchCount - 1;
|
||||
if (touchTime && !e.touches?.length) {
|
||||
if (new Date() - touchTime > 600) {
|
||||
try {
|
||||
// hack to get litegraph to use this event
|
||||
e.constructor = CustomEvent;
|
||||
} catch (error) {}
|
||||
e.clientX = lastTouch.clientX;
|
||||
e.clientY = lastTouch.clientY;
|
||||
|
||||
app.canvas.pointer_is_down = true;
|
||||
app.canvas._mousedown_callback(e);
|
||||
}
|
||||
touchTime = null;
|
||||
}
|
||||
});
|
||||
|
||||
app.canvasEl.addEventListener(
|
||||
"touchmove",
|
||||
(e) => {
|
||||
touchTime = null;
|
||||
if (e.touches?.length === 2) {
|
||||
app.canvas.pointer_is_down = false;
|
||||
touchZooming = true;
|
||||
LiteGraph.closeAllContextMenus();
|
||||
app.canvas.search_box?.close();
|
||||
const newZoomPos = getMultiTouchPos(e);
|
||||
|
||||
const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2;
|
||||
const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2;
|
||||
|
||||
let scale = app.canvas.ds.scale;
|
||||
const diff = zoomPos - newZoomPos;
|
||||
if (diff > 0.5) {
|
||||
scale *= 1 / 1.07;
|
||||
} else if (diff < -0.5) {
|
||||
scale *= 1.07;
|
||||
}
|
||||
app.canvas.ds.changeScale(scale, [midX, midY]);
|
||||
app.canvas.setDirty(true, true);
|
||||
zoomPos = newZoomPos;
|
||||
}
|
||||
},
|
||||
true
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
|
||||
LGraphCanvas.prototype.processMouseDown = function (e) {
|
||||
if (touchZooming || touchCount) {
|
||||
return;
|
||||
}
|
||||
return processMouseDown.apply(this, arguments);
|
||||
};
|
||||
|
||||
const processMouseMove = LGraphCanvas.prototype.processMouseMove;
|
||||
LGraphCanvas.prototype.processMouseMove = function (e) {
|
||||
if (touchZooming || touchCount > 1) {
|
||||
return;
|
||||
}
|
||||
return processMouseMove.apply(this, arguments);
|
||||
};
|
||||
@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) {
|
||||
}
|
||||
|
||||
function hideWidget(node, widget, suffix = "") {
|
||||
if (widget.type?.startsWith(CONVERTED_TYPE)) return;
|
||||
widget.origType = widget.type;
|
||||
widget.origComputeSize = widget.computeSize;
|
||||
widget.origSerializeValue = widget.serializeValue;
|
||||
@ -260,6 +261,12 @@ app.registerExtension({
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
// Add menu options to conver to/from widgets
|
||||
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
|
||||
nodeType.prototype.convertWidgetToInput = function (widget) {
|
||||
const config = getConfig.call(this, widget.name) ?? [widget.type, widget.options || {}];
|
||||
if (!isConvertableWidget(widget, config)) return false;
|
||||
convertToInput(this, widget, config);
|
||||
return true;
|
||||
};
|
||||
nodeType.prototype.getExtraMenuOptions = function (_, options) {
|
||||
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : undefined;
|
||||
|
||||
|
||||
@ -11549,7 +11549,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
dialog.close();
|
||||
} else if (e.keyCode == 13) {
|
||||
if (selected) {
|
||||
select(selected.innerHTML);
|
||||
select(unescape(selected.dataset["type"]));
|
||||
} else if (first) {
|
||||
select(first);
|
||||
} else {
|
||||
@ -11910,7 +11910,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
var ctor = LiteGraph.registered_node_types[ type ];
|
||||
if(filter && ctor.filter != filter )
|
||||
return false;
|
||||
if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1)
|
||||
if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1 && (!ctor.title || ctor.title.toLowerCase().indexOf(str) === -1))
|
||||
return false;
|
||||
|
||||
// filter by slot IN, OUT types
|
||||
@ -11964,7 +11964,18 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
if (!first) {
|
||||
first = type;
|
||||
}
|
||||
help.innerText = type;
|
||||
|
||||
const nodeType = LiteGraph.registered_node_types[type];
|
||||
if (nodeType?.title) {
|
||||
help.innerText = nodeType?.title;
|
||||
const typeEl = document.createElement("span");
|
||||
typeEl.className = "litegraph lite-search-item-type";
|
||||
typeEl.textContent = type;
|
||||
help.append(typeEl);
|
||||
} else {
|
||||
help.innerText = type;
|
||||
}
|
||||
|
||||
help.dataset["type"] = escape(type);
|
||||
help.className = "litegraph lite-search-item";
|
||||
if (className) {
|
||||
|
||||
@ -184,6 +184,7 @@
|
||||
color: white;
|
||||
padding-left: 10px;
|
||||
margin-right: 5px;
|
||||
max-width: 300px;
|
||||
}
|
||||
|
||||
.litegraph.litesearchbox .name {
|
||||
@ -227,6 +228,18 @@
|
||||
color: black;
|
||||
}
|
||||
|
||||
.litegraph.lite-search-item-type {
|
||||
display: inline-block;
|
||||
background: rgba(0,0,0,0.2);
|
||||
margin-left: 5px;
|
||||
font-size: 14px;
|
||||
padding: 2px 5px;
|
||||
position: relative;
|
||||
top: -2px;
|
||||
opacity: 0.8;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* DIALOGs ******/
|
||||
|
||||
.litegraph .dialog {
|
||||
|
||||
@ -5,6 +5,7 @@ class ComfyApi extends EventTarget {
|
||||
super();
|
||||
this.api_host = location.host;
|
||||
this.api_base = location.pathname.split('/').slice(0, -1).join('/');
|
||||
this.initialClientId = sessionStorage.getItem("clientId");
|
||||
}
|
||||
|
||||
apiURL(route) {
|
||||
@ -118,7 +119,8 @@ class ComfyApi extends EventTarget {
|
||||
case "status":
|
||||
if (msg.data.sid) {
|
||||
this.clientId = msg.data.sid;
|
||||
window.name = this.clientId;
|
||||
window.name = this.clientId; // use window name so it isnt reused when duplicating tabs
|
||||
sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow
|
||||
}
|
||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||
break;
|
||||
|
||||
@ -1499,12 +1499,17 @@ export class ComfyApp {
|
||||
// Load previous workflow
|
||||
let restored = false;
|
||||
try {
|
||||
const json = localStorage.getItem("workflow");
|
||||
if (json) {
|
||||
const workflow = JSON.parse(json);
|
||||
await this.loadGraphData(workflow);
|
||||
restored = true;
|
||||
}
|
||||
const loadWorkflow = async (json) => {
|
||||
if (json) {
|
||||
const workflow = JSON.parse(json);
|
||||
await this.loadGraphData(workflow);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
const clientId = api.initialClientId ?? api.clientId;
|
||||
restored =
|
||||
(clientId && (await loadWorkflow(sessionStorage.getItem(`workflow:${clientId}`)))) ||
|
||||
(await loadWorkflow(localStorage.getItem("workflow")));
|
||||
} catch (err) {
|
||||
console.error("Error loading previous workflow", err);
|
||||
}
|
||||
@ -1515,7 +1520,13 @@ export class ComfyApp {
|
||||
}
|
||||
|
||||
// Save current workflow automatically
|
||||
setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000);
|
||||
setInterval(() => {
|
||||
const workflow = JSON.stringify(this.graph.serialize());
|
||||
localStorage.setItem("workflow", workflow);
|
||||
if (api.clientId) {
|
||||
sessionStorage.setItem(`workflow:${api.clientId}`, workflow);
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
this.#addDrawNodeHandler();
|
||||
this.#addDrawGroupsHandler();
|
||||
@ -2096,6 +2107,8 @@ export class ComfyApp {
|
||||
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
|
||||
} else if (pngInfo.prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.prompt));
|
||||
} else if (pngInfo.Prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node.
|
||||
}
|
||||
}
|
||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||
@ -2149,8 +2162,17 @@ export class ComfyApp {
|
||||
if (value instanceof Array) {
|
||||
const [fromId, fromSlot] = value;
|
||||
const fromNode = app.graph.getNodeById(fromId);
|
||||
const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
|
||||
if (toSlot !== -1) {
|
||||
let toSlot = node.inputs?.findIndex((inp) => inp.name === input);
|
||||
if (toSlot == null || toSlot === -1) {
|
||||
try {
|
||||
// Target has no matching input, most likely a converted widget
|
||||
const widget = node.widgets?.find((w) => w.name === input);
|
||||
if (widget && node.convertWidgetToInput?.(widget)) {
|
||||
toSlot = node.inputs?.length - 1;
|
||||
}
|
||||
} catch (error) {}
|
||||
}
|
||||
if (toSlot != null || toSlot !== -1) {
|
||||
fromNode.connect(fromSlot, node, toSlot);
|
||||
}
|
||||
} else {
|
||||
|
||||
@ -24,7 +24,7 @@ export function getPngMetadata(file) {
|
||||
const length = dataView.getUint32(offset);
|
||||
// Get the chunk type
|
||||
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
|
||||
if (type === "tEXt" || type == "comf") {
|
||||
if (type === "tEXt" || type == "comf" || type === "iTXt") {
|
||||
// Get the keyword
|
||||
let keyword_end = offset + 8;
|
||||
while (pngData[keyword_end] !== 0) {
|
||||
@ -33,7 +33,7 @@ export function getPngMetadata(file) {
|
||||
const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end));
|
||||
// Get the text
|
||||
const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length);
|
||||
const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('')
|
||||
const contentJson = new TextDecoder("utf-8").decode(contentArraySegment);
|
||||
txt_chunks[keyword] = contentJson;
|
||||
}
|
||||
|
||||
|
||||
@ -401,18 +401,42 @@ export class ComfyUI {
|
||||
}
|
||||
});
|
||||
|
||||
this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
|
||||
$el("div.drag-handle", {
|
||||
this.menuHamburger = $el(
|
||||
"div.comfy-menu-hamburger",
|
||||
{
|
||||
parent: document.body,
|
||||
onclick: () => {
|
||||
this.menuContainer.style.display = "block";
|
||||
this.menuHamburger.style.display = "none";
|
||||
},
|
||||
},
|
||||
[$el("div"), $el("div"), $el("div")]
|
||||
);
|
||||
|
||||
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
|
||||
$el("div.drag-handle.comfy-menu-header", {
|
||||
style: {
|
||||
overflow: "hidden",
|
||||
position: "relative",
|
||||
width: "100%",
|
||||
cursor: "default"
|
||||
}
|
||||
}, [
|
||||
}, [
|
||||
$el("span.drag-handle"),
|
||||
$el("span", {$: (q) => (this.queueSize = q)}),
|
||||
$el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}),
|
||||
$el("span.comfy-menu-queue-size", { $: (q) => (this.queueSize = q) }),
|
||||
$el("div.comfy-menu-actions", [
|
||||
$el("button.comfy-settings-btn", {
|
||||
textContent: "⚙️",
|
||||
onclick: () => this.settings.show(),
|
||||
}),
|
||||
$el("button.comfy-close-menu-btn", {
|
||||
textContent: "\u00d7",
|
||||
onclick: () => {
|
||||
this.menuContainer.style.display = "none";
|
||||
this.menuHamburger.style.display = "flex";
|
||||
},
|
||||
}),
|
||||
]),
|
||||
]),
|
||||
$el("button.comfy-queue-btn", {
|
||||
id: "queue-button",
|
||||
|
||||
@ -16,7 +16,17 @@ export class ComfySettingsDialog extends ComfyDialog {
|
||||
},
|
||||
[
|
||||
$el("table.comfy-modal-content.comfy-table", [
|
||||
$el("caption", { textContent: "Settings" }),
|
||||
$el(
|
||||
"caption",
|
||||
{ textContent: "Settings" },
|
||||
$el("button.comfy-btn", {
|
||||
type: "button",
|
||||
textContent: "\u00d7",
|
||||
onclick: () => {
|
||||
this.element.close();
|
||||
},
|
||||
})
|
||||
),
|
||||
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
|
||||
$el("button", {
|
||||
type: "button",
|
||||
|
||||
@ -81,6 +81,9 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
|
||||
|
||||
const isCombo = targetWidget.type === "combo";
|
||||
let comboFilter;
|
||||
if (isCombo) {
|
||||
valueControl.options.values.push("increment-wrap");
|
||||
}
|
||||
if (isCombo && options.addFilterList !== false) {
|
||||
comboFilter = node.addWidget(
|
||||
"string",
|
||||
@ -128,6 +131,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
|
||||
case "increment":
|
||||
current_index += 1;
|
||||
break;
|
||||
case "increment-wrap":
|
||||
current_index += 1;
|
||||
if ( current_index >= current_length ) {
|
||||
current_index = 0;
|
||||
}
|
||||
break;
|
||||
case "decrement":
|
||||
current_index -= 1;
|
||||
break;
|
||||
@ -295,7 +304,7 @@ export const ComfyWidgets = {
|
||||
let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding")
|
||||
if (precision == 0) precision = undefined;
|
||||
const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding);
|
||||
return { widget: node.addWidget(widgetType, inputName, val,
|
||||
return { widget: node.addWidget(widgetType, inputName, val,
|
||||
function (v) {
|
||||
if (config.round) {
|
||||
this.value = Math.round(v/config.round)*config.round;
|
||||
|
||||
171
web/style.css
171
web/style.css
@ -82,6 +82,24 @@ body {
|
||||
margin: 3px 3px 3px 4px;
|
||||
}
|
||||
|
||||
.comfy-menu-hamburger {
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
z-index: 9999;
|
||||
right: 10px;
|
||||
width: 30px;
|
||||
display: none;
|
||||
gap: 8px;
|
||||
flex-direction: column;
|
||||
cursor: pointer;
|
||||
}
|
||||
.comfy-menu-hamburger div {
|
||||
height: 3px;
|
||||
width: 100%;
|
||||
border-radius: 20px;
|
||||
background-color: white;
|
||||
}
|
||||
|
||||
.comfy-menu {
|
||||
font-size: 15px;
|
||||
position: absolute;
|
||||
@ -101,6 +119,44 @@ body {
|
||||
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.comfy-menu-header {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.comfy-menu-actions {
|
||||
display: flex;
|
||||
gap: 3px;
|
||||
align-items: center;
|
||||
height: 20px;
|
||||
position: relative;
|
||||
top: -1px;
|
||||
font-size: 22px;
|
||||
}
|
||||
|
||||
.comfy-menu .comfy-menu-actions button {
|
||||
background-color: rgba(0, 0, 0, 0);
|
||||
padding: 0;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-size: inherit;
|
||||
}
|
||||
|
||||
.comfy-menu .comfy-menu-actions .comfy-settings-btn {
|
||||
font-size: 0.6em;
|
||||
}
|
||||
|
||||
button.comfy-close-menu-btn {
|
||||
font-size: 1em;
|
||||
line-height: 12px;
|
||||
color: #ccc;
|
||||
position: relative;
|
||||
top: -1px;
|
||||
}
|
||||
|
||||
.comfy-menu-queue-size {
|
||||
flex: auto;
|
||||
}
|
||||
|
||||
.comfy-menu button,
|
||||
.comfy-modal button {
|
||||
font-size: 20px;
|
||||
@ -121,7 +177,6 @@ body {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch,
|
||||
.comfy-btn,
|
||||
.comfy-menu > button,
|
||||
.comfy-menu-btns button,
|
||||
@ -140,17 +195,12 @@ body {
|
||||
.comfy-menu-btns button:hover,
|
||||
.comfy-menu .comfy-list button:hover,
|
||||
.comfy-modal button:hover,
|
||||
.comfy-settings-btn:hover {
|
||||
.comfy-menu-actions button:hover {
|
||||
filter: brightness(1.2);
|
||||
will-change: transform;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.comfy-menu span.drag-handle {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
span.drag-handle {
|
||||
width: 10px;
|
||||
height: 20px;
|
||||
@ -215,15 +265,6 @@ span.drag-handle::after {
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
button.comfy-settings-btn {
|
||||
background-color: rgba(0, 0, 0, 0);
|
||||
font-size: 12px;
|
||||
padding: 0;
|
||||
position: absolute;
|
||||
right: 0;
|
||||
border: none;
|
||||
}
|
||||
|
||||
button.comfy-queue-btn {
|
||||
margin: 6px 0 !important;
|
||||
}
|
||||
@ -269,7 +310,19 @@ button.comfy-queue-btn {
|
||||
}
|
||||
|
||||
.comfy-menu span.drag-handle {
|
||||
visibility: hidden
|
||||
display: none;
|
||||
}
|
||||
|
||||
.comfy-menu-queue-size {
|
||||
flex: unset;
|
||||
}
|
||||
|
||||
.comfy-menu-header {
|
||||
justify-content: space-between;
|
||||
}
|
||||
.comfy-menu-actions {
|
||||
gap: 10px;
|
||||
font-size: 28px;
|
||||
}
|
||||
}
|
||||
|
||||
@ -320,7 +373,7 @@ dialog::backdrop {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
#comfy-settings-dialog button {
|
||||
#comfy-settings-dialog tbody button, #comfy-settings-dialog table > button {
|
||||
background-color: var(--bg-color);
|
||||
border: 1px var(--border-color) solid;
|
||||
border-radius: 0;
|
||||
@ -343,12 +396,33 @@ dialog::backdrop {
|
||||
}
|
||||
|
||||
.comfy-table caption {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background-color: var(--bg-color);
|
||||
color: var(--input-text);
|
||||
font-size: 1rem;
|
||||
font-weight: bold;
|
||||
padding: 8px;
|
||||
text-align: center;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.comfy-table caption .comfy-btn {
|
||||
position: absolute;
|
||||
top: -2px;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
cursor: pointer;
|
||||
border: none;
|
||||
height: 100%;
|
||||
border-radius: 0;
|
||||
aspect-ratio: 1/1;
|
||||
user-select: none;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.comfy-table caption .comfy-btn:focus {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.comfy-table tr:nth-child(even) {
|
||||
@ -389,11 +463,13 @@ dialog::backdrop {
|
||||
z-index: 9999 !important;
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
filter: brightness(95%);
|
||||
will-change: transform;
|
||||
}
|
||||
|
||||
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
filter: brightness(155%);
|
||||
will-change: transform;
|
||||
color: var(--input-text);
|
||||
}
|
||||
|
||||
@ -435,43 +511,6 @@ dialog::backdrop {
|
||||
margin-left: 5px;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch {
|
||||
border-width: 2px;
|
||||
display: flex;
|
||||
background-color: var(--comfy-input-bg);
|
||||
margin: 2px 0;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch label {
|
||||
padding: 2px 0px 3px 6px;
|
||||
flex: auto;
|
||||
border-radius: 8px;
|
||||
align-items: center;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch label:first-child {
|
||||
border-top-left-radius: 8px;
|
||||
border-bottom-left-radius: 8px;
|
||||
}
|
||||
.comfy-toggle-switch label:last-child {
|
||||
border-top-right-radius: 8px;
|
||||
border-bottom-right-radius: 8px;
|
||||
}
|
||||
|
||||
.comfy-toggle-switch .comfy-toggle-selected {
|
||||
background-color: var(--comfy-menu-bg);
|
||||
}
|
||||
|
||||
#extraOptions {
|
||||
padding: 4px;
|
||||
background-color: var(--bg-color);
|
||||
margin-bottom: 4px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Search box */
|
||||
|
||||
.litegraph.litesearchbox {
|
||||
@ -491,10 +530,30 @@ dialog::backdrop {
|
||||
color: var(--input-text);
|
||||
background-color: var(--comfy-input-bg);
|
||||
filter: brightness(80%);
|
||||
will-change: transform;
|
||||
padding-left: 0.2em;
|
||||
}
|
||||
|
||||
.litegraph.lite-search-item.generic_type {
|
||||
color: var(--input-text);
|
||||
filter: brightness(50%);
|
||||
will-change: transform;
|
||||
}
|
||||
|
||||
@media only screen and (max-width: 450px) {
|
||||
#comfy-settings-dialog .comfy-table tbody {
|
||||
display: grid;
|
||||
}
|
||||
#comfy-settings-dialog .comfy-table tr {
|
||||
display: grid;
|
||||
}
|
||||
#comfy-settings-dialog tr > td:first-child {
|
||||
text-align: center;
|
||||
border-bottom: none;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
#comfy-settings-dialog tr > td:not(:first-child) {
|
||||
text-align: center;
|
||||
border-top: none;
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user