diff --git a/README.md b/README.md
index 093873921..a94a212ad 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
+- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@@ -95,16 +95,15 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
-Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead.
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7```
-This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements:
+This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
### NVIDIA
@@ -112,7 +111,7 @@ Nvidia users should install stable pytorch using this command:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```
-This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements:
+This is the command to install pytorch nightly instead which might have performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 09e7bbca1..9b82a246b 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
- mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 82170431e..416197586 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -166,7 +166,7 @@ class ControlNet(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
- context = cond['c_crossattn']
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
@@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
return ControlLora(controlnet_data)
controlnet_config = None
+ supported_inference_dtypes = None
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
return net
if controlnet_config is None:
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+ supported_inference_dtypes = model_config.supported_inference_dtypes
+ controlnet_config = model_config.unet_config
+
load_device = comfy.model_management.get_torch_device()
+ if supported_inference_dtypes is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ else:
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
+ controlnet_config["dtype"] = unet_dtype
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 71892dfb1..592522767 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -2,7 +2,8 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention
from inspect import isfunction
-
+import comfy.ops
+ops = comfy.ops.manual_cast
def exists(val):
return val is not None
@@ -22,7 +23,7 @@ def default(val, d):
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
+ self.proj = ops.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -35,14 +36,14 @@ class FeedForward(nn.Module):
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
+ ops.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
+ ops.Linear(inner_dim, dim_out)
)
def forward(self, x):
@@ -57,11 +58,12 @@ class GatedCrossAttentionDense(nn.Module):
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
- dim_head=d_head)
+ dim_head=d_head,
+ operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -87,17 +89,18 @@ class GatedSelfAttentionDense(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
- self.linear = nn.Linear(context_dim, query_dim)
+ self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
- dim_head=d_head)
+ dim_head=d_head,
+ operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -126,14 +129,14 @@ class GatedSelfAttentionDense2(nn.Module):
# we need a linear projection since we need cat visual feature and obj
# feature
- self.linear = nn.Linear(context_dim, query_dim)
+ self.linear = ops.Linear(context_dim, query_dim)
self.attn = CrossAttention(
- query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
+ query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
self.ff = FeedForward(query_dim, glu=True)
- self.norm1 = nn.LayerNorm(query_dim)
- self.norm2 = nn.LayerNorm(query_dim)
+ self.norm1 = ops.LayerNorm(query_dim)
+ self.norm2 = ops.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
@@ -201,11 +204,11 @@ class PositionNet(nn.Module):
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
- nn.Linear(self.in_dim + self.position_dim, 512),
+ ops.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
- nn.Linear(512, 512),
+ ops.Linear(512, 512),
nn.SiLU(),
- nn.Linear(512, out_dim),
+ ops.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
@@ -215,16 +218,15 @@ class PositionNet(nn.Module):
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
- dtype = self.linears[0].weight.dtype
- masks = masks.unsqueeze(-1).to(dtype)
- positive_embeddings = positive_embeddings.to(dtype)
+ masks = masks.unsqueeze(-1)
+ positive_embeddings = positive_embeddings
# embedding position (it may includes padding as placeholder)
- xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
- positive_null = self.null_positive_feature.view(1, 1, -1)
- xyxy_null = self.null_position_feature.view(1, 1, -1)
+ positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
+ xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
@@ -251,7 +253,7 @@ class Gligen(nn.Module):
def func(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
- return module(x, objs)
+ return module(x, objs.to(device=x.device, dtype=x.dtype))
return func
def set_position(self, latent_image_shape, position_params, device):
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 2252a075e..03fd59e3d 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -37,3 +37,41 @@ class SDXL(LatentFormat):
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
+ self.latent_rgb_factors = [
+ [-0.2340, -0.3863, -0.3257],
+ [ 0.0994, 0.0885, -0.0908],
+ [-0.2833, -0.2349, -0.3741],
+ [ 0.2523, -0.0055, -0.1651]
+ ]
+
+class SC_Prior(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.latent_rgb_factors = [
+ [-0.0326, -0.0204, -0.0127],
+ [-0.1592, -0.0427, 0.0216],
+ [ 0.0873, 0.0638, -0.0020],
+ [-0.0602, 0.0442, 0.1304],
+ [ 0.0800, -0.0313, -0.1796],
+ [-0.0810, -0.0638, -0.1581],
+ [ 0.1791, 0.1180, 0.0967],
+ [ 0.0740, 0.1416, 0.0432],
+ [-0.1745, -0.1888, -0.1373],
+ [ 0.2412, 0.1577, 0.0928],
+ [ 0.1908, 0.0998, 0.0682],
+ [ 0.0209, 0.0365, -0.0092],
+ [ 0.0448, -0.0650, -0.1728],
+ [-0.1658, -0.1045, -0.1308],
+ [ 0.0542, 0.1545, 0.1325],
+ [-0.0352, -0.1672, -0.2541]
+ ]
+
+class SC_B(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.latent_rgb_factors = [
+ [ 0.1121, 0.2006, 0.1023],
+ [-0.2093, -0.0222, -0.0195],
+ [-0.3087, -0.1535, 0.0366],
+ [ 0.0290, -0.1574, -0.4078]
+ ]
diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py
new file mode 100644
index 000000000..124902c09
--- /dev/null
+++ b/comfy/ldm/cascade/common.py
@@ -0,0 +1,161 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import torch
+import torch.nn as nn
+from comfy.ldm.modules.attention import optimized_attention
+
+class Linear(torch.nn.Linear):
+ def reset_parameters(self):
+ return None
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+class OptimizedAttention(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.heads = nhead
+
+ self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ def forward(self, q, k, v):
+ q = self.to_q(q)
+ k = self.to_k(k)
+ v = self.to_v(v)
+
+ out = optimized_attention(q, k, v, self.heads)
+
+ return self.out_proj(out)
+
+class Attention2D(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
+ if self_attn:
+ kv = torch.cat([x, kv], dim=1)
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = self.attn(x, kv, kv)
+ x = x.permute(0, 2, 1).view(*orig_shape)
+ return x
+
+
+def LayerNorm2d_op(operations):
+ class LayerNorm2d(operations.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return LayerNorm2d
+
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim, dtype=None, device=None):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
+ super().__init__()
+ self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_cond, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+
+
+class FeedForwardBlock(nn.Module):
+ def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x):
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
+ super().__init__()
+ self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
+
+ def forward(self, x, t):
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py
new file mode 100644
index 000000000..260ccfc0b
--- /dev/null
+++ b/comfy/ldm/cascade/stage_a.py
@@ -0,0 +1,258 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+class vector_quantize(Function):
+ @staticmethod
+ def forward(ctx, x, codebook):
+ with torch.no_grad():
+ codebook_sqr = torch.sum(codebook ** 2, dim=1)
+ x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
+
+ dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
+ _, indices = dist.min(dim=1)
+
+ ctx.save_for_backward(indices, codebook)
+ ctx.mark_non_differentiable(indices)
+
+ nn = torch.index_select(codebook, 0, indices)
+ return nn, indices
+
+ @staticmethod
+ def backward(ctx, grad_output, grad_indices):
+ grad_inputs, grad_codebook = None, None
+
+ if ctx.needs_input_grad[0]:
+ grad_inputs = grad_output.clone()
+ if ctx.needs_input_grad[1]:
+ # Gradient wrt. the codebook
+ indices, codebook = ctx.saved_tensors
+
+ grad_codebook = torch.zeros_like(codebook)
+ grad_codebook.index_add_(0, indices, grad_output)
+
+ return (grad_inputs, grad_codebook)
+
+
+class VectorQuantize(nn.Module):
+ def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
+ """
+ Takes an input of variable size (as long as the last dimension matches the embedding size).
+ Returns one tensor containing the nearest neigbour embeddings to each of the inputs,
+ with the same size as the input, vq and commitment components for the loss as a touple
+ in the second output and the indices of the quantized vectors in the third:
+ quantized, (vq_loss, commit_loss), indices
+ """
+ super(VectorQuantize, self).__init__()
+
+ self.codebook = nn.Embedding(k, embedding_size)
+ self.codebook.weight.data.uniform_(-1./k, 1./k)
+ self.vq = vector_quantize.apply
+
+ self.ema_decay = ema_decay
+ self.ema_loss = ema_loss
+ if ema_loss:
+ self.register_buffer('ema_element_count', torch.ones(k))
+ self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
+
+ def _laplace_smoothing(self, x, epsilon):
+ n = torch.sum(x)
+ return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
+
+ def _updateEMA(self, z_e_x, indices):
+ mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
+ elem_count = mask.sum(dim=0)
+ weight_sum = torch.mm(mask.t(), z_e_x)
+
+ self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
+ self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
+ self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
+
+ self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
+
+ def idx2vq(self, idx, dim=-1):
+ q_idx = self.codebook(idx)
+ if dim != -1:
+ q_idx = q_idx.movedim(-1, dim)
+ return q_idx
+
+ def forward(self, x, get_losses=True, dim=-1):
+ if dim != -1:
+ x = x.movedim(dim, -1)
+ z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
+ z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
+ vq_loss, commit_loss = None, None
+ if self.ema_loss and self.training:
+ self._updateEMA(z_e_x.detach(), indices.detach())
+ # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
+ z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
+ if get_losses:
+ vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
+ commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
+
+ z_q_x = z_q_x.view(x.shape)
+ if dim != -1:
+ z_q_x = z_q_x.movedim(-1, dim)
+ return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_hidden):
+ super().__init__()
+ # depthwise/attention
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.depthwise = nn.Sequential(
+ nn.ReplicationPad2d(1),
+ nn.Conv2d(c, c, kernel_size=3, groups=c)
+ )
+
+ # channelwise
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(c, c_hidden),
+ nn.GELU(),
+ nn.Linear(c_hidden, c),
+ )
+
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
+
+ # Init weights
+ def _basic_init(module):
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def _norm(self, x, norm):
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ def forward(self, x):
+ mods = self.gammas
+
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
+ try:
+ x = x + self.depthwise(x_temp) * mods[2]
+ except: #operation not implemented for bf16
+ x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
+ x = x + self.depthwise[1](x_temp) * mods[2]
+
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
+
+ return x
+
+
+class StageA(nn.Module):
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
+ scale_factor=0.43): # 0.3764
+ super().__init__()
+ self.c_latent = c_latent
+ self.scale_factor = scale_factor
+ c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
+
+ # Encoder blocks
+ self.in_block = nn.Sequential(
+ nn.PixelUnshuffle(2),
+ nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+ )
+ down_blocks = []
+ for i in range(levels):
+ if i > 0:
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+ block = ResBlock(c_levels[i], c_levels[i] * 4)
+ down_blocks.append(block)
+ down_blocks.append(nn.Sequential(
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
+ ))
+ self.down_blocks = nn.Sequential(*down_blocks)
+ self.down_blocks[0]
+
+ self.codebook_size = codebook_size
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
+
+ # Decoder blocks
+ up_blocks = [nn.Sequential(
+ nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+ )]
+ for i in range(levels):
+ for j in range(bottleneck_blocks if i == 0 else 1):
+ block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
+ up_blocks.append(block)
+ if i < levels - 1:
+ up_blocks.append(
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+ padding=1))
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_block = nn.Sequential(
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+ nn.PixelShuffle(2),
+ )
+
+ def encode(self, x, quantize=False):
+ x = self.in_block(x)
+ x = self.down_blocks(x)
+ if quantize:
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
+ return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+ else:
+ return x / self.scale_factor
+
+ def decode(self, x):
+ x = x * self.scale_factor
+ x = self.up_blocks(x)
+ x = self.out_block(x)
+ return x
+
+ def forward(self, x, quantize=False):
+ qe, x, _, vq_loss = self.encode(x, quantize)
+ x = self.decode(qe)
+ return x, vq_loss
+
+
+class Discriminator(nn.Module):
+ def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
+ super().__init__()
+ d = max(depth - 3, 3)
+ layers = [
+ nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+ nn.LeakyReLU(0.2),
+ ]
+ for i in range(depth - 1):
+ c_in = c_hidden // (2 ** max((d - i), 0))
+ c_out = c_hidden // (2 ** max((d - 1 - i), 0))
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.InstanceNorm2d(c_out))
+ layers.append(nn.LeakyReLU(0.2))
+ self.encoder = nn.Sequential(*layers)
+ self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+ self.logits = nn.Sigmoid()
+
+ def forward(self, x, cond=None):
+ x = self.encoder(x)
+ if cond is not None:
+ cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
+ x = torch.cat([x, cond], dim=1)
+ x = self.shuffle(x)
+ x = self.logits(x)
+ return x
diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py
new file mode 100644
index 000000000..6d2c22231
--- /dev/null
+++ b/comfy/ldm/cascade/stage_b.py
@@ -0,0 +1,257 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import math
+import numpy as np
+import torch
+from torch import nn
+from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
+
+class StageB(nn.Module):
+ def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
+ nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
+ c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
+ t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dtype = dtype
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.effnet_mapper = nn.Sequential(
+ operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
+ nn.GELU(),
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+ self.pixels_mapper = nn.Sequential(
+ operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
+ nn.GELU(),
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+ self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ # self.apply(self._init_weights) # General init
+ # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
+ # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
+ #
+ # # blocks
+ # for level_block in self.down_blocks + self.up_blocks:
+ # for block in level_block:
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ # elif isinstance(block, TimestepBlock):
+ # for layer in block.modules():
+ # if isinstance(layer, nn.Linear):
+ # nn.init.constant_(layer.weight, 0)
+ #
+ # def _init_weights(self, m):
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
+ # torch.nn.init.xavier_uniform_(m.weight)
+ # if m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip):
+ if len(clip.shape) == 2:
+ clip = clip.unsqueeze(1)
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+ if pixels is None:
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
+
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
+ clip = self.gen_c_embeddings(clip)
+
+ # Model Blocks
+ x = self.embedding(x)
+ x = x + self.effnet_mapper(
+ nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
+ x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ level_outputs = self._down_encode(x, r_embed, clip)
+ x = self._up_decode(level_outputs, r_embed, clip)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
new file mode 100644
index 000000000..08e33aded
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c.py
@@ -0,0 +1,271 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import torch
+from torch import nn
+import numpy as np
+import math
+from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
+# from .controlnet import ControlNetDeliverer
+
+class UpDownBlock2d(nn.Module):
+ def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert mode in ['up', 'down']
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
+ align_corners=True) if enabled else nn.Identity()
+ mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class StageC(nn.Module):
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
+ dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
+ dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dtype = dtype
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ # self.apply(self._init_weights) # General init
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
+ #
+ # # blocks
+ # for level_block in self.down_blocks + self.up_blocks:
+ # for block in level_block:
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ # elif isinstance(block, TimestepBlock):
+ # for layer in block.modules():
+ # if isinstance(layer, nn.Linear):
+ # nn.init.constant_(layer.weight, 0)
+ #
+ # def _init_weights(self, m):
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
+ # torch.nn.init.xavier_uniform_(m.weight)
+ # if m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
+ clip_txt = self.clip_txt_mapper(clip_txt)
+ if len(clip_txt_pooled.shape) == 2:
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
+ if len(clip_img.shape) == 2:
+ clip_img = clip_img.unsqueeze(1)
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip, cnet=None):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
+
+ # Model Blocks
+ x = self.embedding(x)
+ if cnet is not None:
+ cnet = ControlNetDeliverer(cnet)
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py
new file mode 100644
index 000000000..0cb7c49fc
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c_coder.py
@@ -0,0 +1,95 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+import torch
+import torchvision
+from torch import nn
+
+
+# EfficientNet
+class EfficientNetEncoder(nn.Module):
+ def __init__(self, c_latent=16):
+ super().__init__()
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
+ self.mapper = nn.Sequential(
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
+ )
+ self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
+ self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
+
+ def forward(self, x):
+ x = x * 0.5 + 0.5
+ x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+ o = self.mapper(self.backbone(x))
+ return o
+
+
+# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
+class Previewer(nn.Module):
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
+ super().__init__()
+ self.blocks = nn.Sequential(
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return (self.blocks(x) - 0.5) * 2.0
+
+class StageC_coder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.previewer = Previewer()
+ self.encoder = EfficientNetEncoder()
+
+ def encode(self, x):
+ return self.encoder(x)
+
+ def decode(self, x):
+ return self.previewer(x)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 9c9cb761d..48399bc07 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -114,7 +114,12 @@ def attention_basic(q, k, v, heads, mask=None):
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
- sim += mask
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+ sim.add_(mask)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@@ -165,6 +170,13 @@ def attention_sub_quad(query, key, value, heads, mask=None):
if query_chunk_size is None:
query_chunk_size = 512
+ if mask is not None:
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
hidden_states = efficient_dot_product_attention(
query,
key,
@@ -223,6 +235,13 @@ def attention_split(q, k, v, heads, mask=None):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+ if mask is not None:
+ if len(mask.shape) == 2:
+ bs = 1
+ else:
+ bs = mask.shape[0]
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index ac7e27173..5a6aa7d77 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -98,7 +98,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
+ betas = torch.clamp(betas, min=0, max=0.999)
elif schedule == "squaredcos_cap_v2": # used for karlo prior
# return early
@@ -113,7 +113,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
+ return betas
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8a843a98c..421f271b2 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,5 +1,7 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from comfy.ldm.cascade.stage_c import StageC
+from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
@@ -12,9 +14,10 @@ class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
+ STABLE_CASCADE = 4
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type):
@@ -27,6 +30,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
+ elif model_type == ModelType.STABLE_CASCADE:
+ c = EPS
+ s = StableCascadeSampling
class ModelSampling(s, c):
pass
@@ -35,7 +41,7 @@ def model_sampling(model_config, model_type):
class BaseModel(torch.nn.Module):
- def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
unet_config = model_config.unet_config
@@ -48,7 +54,7 @@ class BaseModel(torch.nn.Module):
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
- self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
+ self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -153,6 +159,10 @@ class BaseModel(torch.nn.Module):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
+ cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
+ if cross_attn_cnet is not None:
+ out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
+
return out
def load_model_weights(self, sd, unet_prefix=""):
@@ -423,3 +433,52 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
+
+class StableCascade_C(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=StageC)
+ self.diffusion_model.eval().requires_grad_(False)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+ clip_text_pooled = kwargs["pooled_output"]
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+ if "unclip_conditioning" in kwargs:
+ embeds = []
+ for unclip_cond in kwargs["unclip_conditioning"]:
+ weight = unclip_cond["strength"]
+ embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
+ clip_img = torch.cat(embeds, dim=1)
+ else:
+ clip_img = torch.zeros((1, 1, 768))
+ out["clip_img"] = comfy.conds.CONDRegular(clip_img)
+ out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+ out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
+ return out
+
+
+class StableCascade_B(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=StageB)
+ self.diffusion_model.eval().requires_grad_(False)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+ noise = kwargs.get("noise", None)
+
+ clip_text_pooled = kwargs["pooled_output"]
+ if clip_text_pooled is not None:
+ out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+ #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+ prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
+
+ out["effnet"] = comfy.conds.CONDRegular(prior)
+ out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+ return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index ea824c44c..8fca6d8c8 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -28,9 +28,38 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
-def detect_unet_config(state_dict, key_prefix, dtype):
+def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
+ if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
+ unet_config = {}
+ text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
+ if text_mapper_name in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'c'
+ w = state_dict[text_mapper_name]
+ if w.shape[0] == 1536: #stage c lite
+ unet_config['c_cond'] = 1536
+ unet_config['c_hidden'] = [1536, 1536]
+ unet_config['nhead'] = [24, 24]
+ unet_config['blocks'] = [[4, 12], [12, 4]]
+ elif w.shape[0] == 2048: #stage c full
+ unet_config['c_cond'] = 2048
+ elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'b'
+ w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
+ if w.shape[-1] == 640:
+ unet_config['c_hidden'] = [320, 640, 1280, 1280]
+ unet_config['nhead'] = [-1, -1, 20, 20]
+ unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
+ unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
+ elif w.shape[-1] == 576: #stage b lite
+ unet_config['c_hidden'] = [320, 576, 1152, 1152]
+ unet_config['nhead'] = [-1, 9, 18, 18]
+ unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
+ unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
+
+ return unet_config
+
unet_config = {
"use_checkpoint": False,
"image_size": 32,
@@ -45,7 +74,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else:
unet_config["adm_in_channels"] = None
- unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -159,8 +187,8 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config)
return None
-def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
- unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
+ unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
@@ -206,7 +234,7 @@ def convert_config(unet_config):
return new_config
-def unet_config_from_diffusers_unet(state_dict, dtype):
+def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
@@ -313,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config)
return None
-def model_config_from_diffusers_unet(state_dict, dtype):
- unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
+def model_config_from_diffusers_unet(state_dict):
+ unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e12146d11..adcc0e8ac 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
-def unet_dtype(device=None, model_params=0):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -496,21 +496,32 @@ def unet_dtype(device=None, model_params=0):
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
- if should_use_fp16(device=device, model_params=model_params):
- return torch.float16
+ if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+ if torch.float16 in supported_dtypes:
+ return torch.float16
+ if should_use_bf16(device, model_params=model_params, manual_cast=True):
+ if torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
return torch.float32
# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
- fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+ fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
- if fp16_supported:
+ bf16_supported = should_use_bf16(inference_device)
+ if bf16_supported and weight_dtype == torch.bfloat16:
+ return None
+
+ if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
+
+ elif bf16_supported and torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
else:
return torch.float32
@@ -546,10 +557,8 @@ def text_encoder_dtype(device=None):
if is_device_cpu(device):
return torch.float16
- if should_use_fp16(device, prioritize_performance=False):
- return torch.float16
- else:
- return torch.float32
+ return torch.float16
+
def intermediate_device():
if args.gpu_only:
@@ -686,19 +695,22 @@ def mps_mode():
global cpu_state
return cpu_state == CPUState.MPS
-def is_device_cpu(device):
+def is_device_type(device, type):
if hasattr(device, 'type'):
- if (device.type == 'cpu'):
+ if (device.type == type):
return True
return False
+def is_device_cpu(device):
+ return is_device_type(device, 'cpu')
+
def is_device_mps(device):
- if hasattr(device, 'type'):
- if (device.type == 'mps'):
- return True
- return False
+ return is_device_type(device, 'mps')
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def is_device_cuda(device):
+ return is_device_type(device, 'cuda')
+
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
if device is not None:
@@ -708,9 +720,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if FORCE_FP16:
return True
- if device is not None: #TODO
+ if device is not None:
if is_device_mps(device):
- return False
+ return True
if FORCE_FP32:
return False
@@ -718,16 +730,22 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if directml_enabled:
return False
- if cpu_mode() or mps_mode():
- return False #TODO ?
+ if mps_mode():
+ return True
+
+ if cpu_mode():
+ return False
if is_intel_xpu():
return True
- if torch.cuda.is_bf16_supported():
+ if torch.version.hip:
return True
props = torch.cuda.get_device_properties("cuda")
+ if props.major >= 8:
+ return True
+
if props.major < 6:
return False
@@ -740,7 +758,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if x in props.name.lower():
fp16_works = True
- if fp16_works:
+ if fp16_works or manual_cast:
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -756,6 +774,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return True
+def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+ if device is not None:
+ if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
+ return False
+
+ if device is not None: #TODO not sure about mps bf16 support
+ if is_device_mps(device):
+ return False
+
+ if FORCE_FP32:
+ return False
+
+ if directml_enabled:
+ return False
+
+ if cpu_mode() or mps_mode():
+ return False
+
+ if is_intel_xpu():
+ return True
+
+ if device is None:
+ device = torch.device("cuda")
+
+ props = torch.cuda.get_device_properties(device)
+ if props.major >= 8:
+ return True
+
+ bf16_works = torch.cuda.is_bf16_supported()
+
+ if bf16_works or manual_cast:
+ free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+ if (not prioritize_performance) or model_params * 4 > free_model_memory:
+ return True
+
+ return False
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index cc8745c10..97e91a01d 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -1,5 +1,4 @@
import torch
-import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
@@ -42,8 +41,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
- alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
- # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -58,8 +56,8 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
- self.register_buffer('sigmas', sigmas)
- self.register_buffer('log_sigmas', sigmas.log())
+ self.register_buffer('sigmas', sigmas.float())
+ self.register_buffer('log_sigmas', sigmas.log().float())
@property
def sigma_min(self):
@@ -134,3 +132,56 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
+
+class StableCascadeSampling(ModelSamplingDiscrete):
+ def __init__(self, model_config=None):
+ super().__init__()
+
+ if model_config is not None:
+ sampling_settings = model_config.sampling_settings
+ else:
+ sampling_settings = {}
+
+ self.set_parameters(sampling_settings.get("shift", 1.0))
+
+ def set_parameters(self, shift=1.0, cosine_s=8e-3):
+ self.shift = shift
+ self.cosine_s = torch.tensor(cosine_s)
+ self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
+
+ #This part is just for compatibility with some schedulers in the codebase
+ self.num_timesteps = 10000
+ sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
+ for x in range(self.num_timesteps):
+ t = (x + 1) / self.num_timesteps
+ sigmas[x] = self.sigma(t)
+
+ self.set_sigmas(sigmas)
+
+ def sigma(self, timestep):
+ alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod)
+
+ if self.shift != 1.0:
+ var = alpha_cumprod
+ logSNR = (var/(1-var)).log()
+ logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift))
+ alpha_cumprod = logSNR.sigmoid()
+
+ alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)
+ return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
+
+ def timestep(self, sigma):
+ var = 1 / ((sigma * sigma) + 1)
+ var = var.clamp(0, 1.0)
+ s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
+ t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+ return t
+
+ def percent_to_sigma(self, percent):
+ if percent <= 0.0:
+ return 999999999.9
+ if percent >= 1.0:
+ return 0.0
+
+ percent = 1.0 - percent
+ return self.sigma(torch.tensor(percent))
diff --git a/comfy/ops.py b/comfy/ops.py
index f674b47f7..517688e8b 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -1,3 +1,21 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
import torch
import comfy.model_management
@@ -78,7 +96,11 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
- weight, bias = cast_bias_weight(self, input)
+ if self.weight is not None:
+ weight, bias = cast_bias_weight(self, input)
+ else:
+ weight = None
+ bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@@ -87,6 +109,28 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
+ class ConvTranspose2d(torch.nn.ConvTranspose2d):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input, output_size=None):
+ num_spatial_dims = 2
+ output_padding = self._output_padding(
+ input, output_size, self.stride, self.padding, self.kernel_size,
+ num_spatial_dims, self.dilation)
+
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.conv_transpose2d(
+ input, weight, bias, self.stride, self.padding,
+ output_padding, self.groups, self.dilation)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
@@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True
+
+ class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
+ comfy_cast_weights = True
diff --git a/comfy/samplers.py b/comfy/samplers.py
index f4c3e268f..c795f208d 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -295,7 +295,7 @@ def simple_scheduler(model, steps):
def ddim_scheduler(model, steps):
s = model.model_sampling
sigs = []
- ss = len(s.sigmas) // steps
+ ss = max(len(s.sigmas) // steps, 1)
x = 1
while x < len(s.sigmas):
sigs += [float(s.sigmas[x])]
@@ -652,6 +652,7 @@ def sampler_object(name):
class KSampler:
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
+ DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
@@ -670,7 +671,7 @@ class KSampler:
sigmas = None
discard_penultimate_sigma = False
- if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
+ if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
steps += 1
discard_penultimate_sigma = True
diff --git a/comfy/sd.py b/comfy/sd.py
index 9ca9d1d12..7a77bb177 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,7 +1,11 @@
import torch
+from enum import Enum
from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
+
import yaml
import comfy.utils
@@ -134,8 +138,11 @@ class CLIP:
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
- def load_sd(self, sd):
- return self.cond_stage_model.load_sd(sd)
+ def load_sd(self, sd, full_model=False):
+ if full_model:
+ return self.cond_stage_model.load_state_dict(sd, strict=False)
+ else:
+ return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
@@ -155,7 +162,10 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
+ self.upscale_ratio = 8
self.latent_channels = 4
+ self.process_input = lambda image: image * 2.0 - 1.0
+ self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -168,6 +178,34 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
+ elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
+ self.first_stage_model = StageA()
+ self.downscale_ratio = 4
+ self.upscale_ratio = 4
+ #TODO
+ #self.memory_used_encode
+ #self.memory_used_decode
+ self.process_input = lambda image: image
+ self.process_output = lambda image: image
+ elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+ self.first_stage_model = StageC_coder()
+ self.downscale_ratio = 32
+ self.latent_channels = 16
+ new_sd = {}
+ for k in sd:
+ new_sd["encoder.{}".format(k)] = sd[k]
+ sd = new_sd
+ elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+ self.first_stage_model = StageC_coder()
+ self.latent_channels = 16
+ new_sd = {}
+ for k in sd:
+ new_sd["previewer.{}".format(k)] = sd[k]
+ sd = new_sd
+ elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+ self.first_stage_model = StageC_coder()
+ self.downscale_ratio = 32
+ self.latent_channels = 16
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -175,6 +213,7 @@ class VAE:
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
+ self.upscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
@@ -200,18 +239,27 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+ def vae_encode_crop_pixels(self, pixels):
+ x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+ y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+ if pixels.shape[1] != x or pixels.shape[2] != y:
+ x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+ y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+ pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+ return pixels
+
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
- output = torch.clamp((
- (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
- comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
- comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
- / 3.0) / 2.0, min=0.0, max=1.0)
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ output = self.process_output(
+ (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+ comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+ comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
+ / 3.0)
return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -220,7 +268,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
- encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -235,10 +283,10 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
- pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+ pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
+ pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
@@ -252,6 +300,7 @@ class VAE:
return output.movedim(1,-1)
def encode(self, pixel_samples):
+ pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
@@ -261,7 +310,7 @@ class VAE:
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
- pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
+ pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
@@ -271,6 +320,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+ pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@@ -297,8 +347,11 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data)
return StyleModel(model)
+class CLIPType(Enum):
+ STABLE_DIFFUSION = 1
+ STABLE_CASCADE = 2
-def load_clip(ckpt_paths, embedding_directory=None):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -314,8 +367,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
- clip_target.clip = sdxl_clip.SDXLRefinerClipModel
- clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+ if clip_type == CLIPType.STABLE_CASCADE:
+ clip_target.clip = sdxl_clip.StableCascadeClipModel
+ clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+ else:
+ clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+ clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
@@ -438,15 +495,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
- unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
- class WeightsLoader(torch.nn.Module):
- pass
-
- model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
- model_config.set_manual_cast(manual_cast_dtype)
+ model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -462,18 +516,24 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
- vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+ vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
- w = WeightsLoader()
clip_target = model_config.clip_target()
if clip_target is not None:
- clip = CLIP(clip_target, embedding_directory=embedding_directory)
- w.cond_stage_model = clip.cond_stage_model
- sd = model_config.process_clip_state_dict(sd)
- load_model_weights(w, sd)
+ clip_sd = model_config.process_clip_state_dict(sd)
+ if len(clip_sd) > 0:
+ clip = CLIP(clip_target, embedding_directory=embedding_directory)
+ m, u = clip.load_sd(clip_sd, full_model=True)
+ if len(m) > 0:
+ print("clip missing:", m)
+
+ if len(u) > 0:
+ print("clip unexpected:", u)
+ else:
+ print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
left_over = sd.keys()
if len(left_over) > 0:
@@ -492,16 +552,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
- if "input_blocks.0.0.weight" in sd: #ldm
- model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
+ if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+ model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
- model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
+ model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@@ -513,8 +572,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
+
offload_device = model_management.unet_offload_device()
- model_config.set_manual_cast(manual_cast_dtype)
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 65ea909fe..8287ad2e8 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
- special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
+ special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
- self.enable_attention_masks = False
+ self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index b35056bb9..3ce5c7e05 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+
+
+class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, tokenizer_path=None, embedding_directory=None):
+ super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+
+class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None):
+ super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+
+class StableCascadeClipG(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+ textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+ super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
+ special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+
+ def load_sd(self, sd):
+ return super().load_sd(sd)
+
+class StableCascadeClipModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None):
+ super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1d442d4dd..5bb98d88a 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
- replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+ replace_prefix["cond_stage_model."] = "clip_l."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
replace_prefix = {}
- replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
-
- state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
+ replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
+ replace_prefix["cond_stage_model.model."] = "clip_h."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+ state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
@@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
replace_prefix = {}
+ replace_prefix["conditioner.embedders.0.model."] = "clip_g."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
- state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
- keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
-
+ state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
@@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
keys_to_replace = {}
replace_prefix = {}
- replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
- state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
- keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
- keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+ replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
+ replace_prefix["conditioner.embedders.1.model."] = "clip_g."
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+
+ state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+ keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
- state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
@@ -306,5 +305,66 @@ class SD_X4Upscaler(SD20):
out = model_base.SD_X4Upscaler(self, device=device)
return out
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
+class Stable_Cascade_C(supported_models_base.BASE):
+ unet_config = {
+ "stable_cascade_stage": 'c',
+ }
+
+ unet_extra_config = {}
+
+ latent_format = latent_formats.SC_Prior
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ sampling_settings = {
+ "shift": 2.0,
+ }
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoder."]
+ clip_vision_prefix = "clip_l_vision."
+
+ def process_unet_state_dict(self, state_dict):
+ key_list = list(state_dict.keys())
+ for y in ["weight", "bias"]:
+ suffix = "in_proj_{}".format(y)
+ keys = filter(lambda a: a.endswith(suffix), key_list)
+ for k_from in keys:
+ weights = state_dict.pop(k_from)
+ prefix = k_from[:-(len(suffix) + 1)]
+ shape_from = weights.shape[0] // 3
+ for x in range(3):
+ p = ["to_q", "to_k", "to_v"]
+ k_to = "{}.{}.{}".format(prefix, p[x], y)
+ state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+ return state_dict
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.StableCascade_C(self, device=device)
+ return out
+
+ def clip_target(self):
+ return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
+
+class Stable_Cascade_B(Stable_Cascade_C):
+ unet_config = {
+ "stable_cascade_stage": 'b',
+ }
+
+ unet_extra_config = {}
+
+ latent_format = latent_formats.SC_B
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ sampling_settings = {
+ "shift": 1.0,
+ }
+
+ clip_vision_prefix = None
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.StableCascade_B(self, device=device)
+ return out
+
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 5baf4bca6..4d7e25936 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -21,13 +21,16 @@ class BASE:
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
+ vae_key_prefix = ["first_stage_model."]
+ text_encoder_key_prefix = ["cond_stage_model."]
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
- if s.unet_config[k] != unet_config[k]:
+ if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return True
@@ -53,6 +56,7 @@ class BASE:
return out
def process_clip_state_dict(self, state_dict):
+ state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
def process_unet_state_dict(self, state_dict):
@@ -62,7 +66,7 @@ class BASE:
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
- replace_prefix = {"": "cond_stage_model."}
+ replace_prefix = {"": self.text_encoder_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -76,8 +80,9 @@ class BASE:
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def process_vae_state_dict_for_saving(self, state_dict):
- replace_prefix = {"": "first_stage_model."}
+ replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
- def set_manual_cast(self, manual_cast_dtype):
+ def set_inference_dtype(self, dtype, manual_cast_dtype):
+ self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
diff --git a/comfy/utils.py b/comfy/utils.py
index f8026ddab..04cf76ed6 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -169,6 +169,8 @@ UNET_MAP_BASIC = {
}
def unet_to_diffusers(unet_config):
+ if "num_res_blocks" not in unet_config:
+ return {}
num_res_blocks = unet_config["num_res_blocks"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"][:]
@@ -413,6 +415,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
+ x = max(0, min(s.shape[-1] - overlap, x))
+ y = max(0, min(s.shape[-2] - overlap, y))
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).to(output_device)
diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py
new file mode 100644
index 000000000..646fefa17
--- /dev/null
+++ b/comfy_extras/nodes_cond.py
@@ -0,0 +1,25 @@
+
+
+class CLIPTextEncodeControlnet:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}}
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "encode"
+
+ CATEGORY = "_for_testing/conditioning"
+
+ def encode(self, clip, conditioning, text):
+ tokens = clip.tokenize(text)
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
+ c = []
+ for t in conditioning:
+ n = [t[0], t[1].copy()]
+ n[1]['cross_attn_controlnet'] = cond
+ n[1]['pooled_output_controlnet'] = pooled
+ c.append(n)
+ return (c, )
+
+NODE_CLASS_MAPPINGS = {
+ "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+}
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index aa80f5269..8f638bf8f 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -48,6 +48,25 @@ class RepeatImageBatch:
s = image.repeat((amount, 1,1,1))
return (s,)
+class ImageFromBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "image": ("IMAGE",),
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "frombatch"
+
+ CATEGORY = "image/batch"
+
+ def frombatch(self, image, batch_index, length):
+ s_in = image
+ batch_index = min(s_in.shape[0] - 1, batch_index)
+ length = min(s_in.shape[0] - batch_index, length)
+ s = s_in[batch_index:batch_index + length].clone()
+ return (s,)
+
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -170,6 +189,7 @@ class SaveAnimatedPNG:
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
+ "ImageFromBatch": ImageFromBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index b7fd8cd68..eabae0885 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -126,7 +126,7 @@ class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
- "seed_behavior": (["random", "fixed"],),}}
+ "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index 541ce8fa5..1b3f3945e 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -17,6 +17,10 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
+class X0(comfy.model_sampling.EPS):
+ def calculate_denoised(self, sigma, model_output, model_input):
+ return model_output
+
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
@@ -68,7 +72,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
- "sampling": (["eps", "v_prediction", "lcm"],),
+ "sampling": (["eps", "v_prediction", "lcm", "x0"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
@@ -88,6 +92,8 @@ class ModelSamplingDiscrete:
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
+ elif sampling == "x0":
+ sampling_type = X0
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
@@ -99,6 +105,32 @@ class ModelSamplingDiscrete:
m.add_object_patch("model_sampling", model_sampling)
return (m, )
+class ModelSamplingStableCascade:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
+ }}
+
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "advanced/model"
+
+ def patch(self, model, shift):
+ m = model.clone()
+
+ sampling_base = comfy.model_sampling.StableCascadeSampling
+ sampling_type = comfy.model_sampling.EPS
+
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
+ pass
+
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
+ model_sampling.set_parameters(shift)
+ m.add_object_patch("model_sampling", model_sampling)
+ return (m, )
+
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
@@ -171,5 +203,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
+ "ModelSamplingStableCascade": ModelSamplingStableCascade,
"RescaleCFG": RescaleCFG,
}
diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py
new file mode 100644
index 000000000..b795d0083
--- /dev/null
+++ b/comfy_extras/nodes_stable_cascade.py
@@ -0,0 +1,109 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+import torch
+import nodes
+import comfy.utils
+
+
+class StableCascade_EmptyLatentImage:
+ def __init__(self, device="cpu"):
+ self.device = device
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
+ "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
+ "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
+ }}
+ RETURN_TYPES = ("LATENT", "LATENT")
+ RETURN_NAMES = ("stage_c", "stage_b")
+ FUNCTION = "generate"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def generate(self, width, height, compression, batch_size=1):
+ c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
+ b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
+ return ({
+ "samples": c_latent,
+ }, {
+ "samples": b_latent,
+ })
+
+class StableCascade_StageC_VAEEncode:
+ def __init__(self, device="cpu"):
+ self.device = device
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "image": ("IMAGE",),
+ "vae": ("VAE", ),
+ "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
+ }}
+ RETURN_TYPES = ("LATENT", "LATENT")
+ RETURN_NAMES = ("stage_c", "stage_b")
+ FUNCTION = "generate"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def generate(self, image, vae, compression):
+ width = image.shape[-2]
+ height = image.shape[-3]
+ out_width = (width // compression) * vae.downscale_ratio
+ out_height = (height // compression) * vae.downscale_ratio
+
+ s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
+
+ c_latent = vae.encode(s[:,:,:,:3])
+ b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4])
+ return ({
+ "samples": c_latent,
+ }, {
+ "samples": b_latent,
+ })
+
+class StableCascade_StageB_Conditioning:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "conditioning": ("CONDITIONING",),
+ "stage_c": ("LATENT",),
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+
+ FUNCTION = "set_prior"
+
+ CATEGORY = "_for_testing/stable_cascade"
+
+ def set_prior(self, conditioning, stage_c):
+ c = []
+ for t in conditioning:
+ d = t[1].copy()
+ d['stable_cascade_prior'] = stage_c['samples']
+ n = [t[0], d]
+ c.append(n)
+ return (c, )
+
+NODE_CLASS_MAPPINGS = {
+ "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
+ "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
+ "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
+}
diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example
index 733014f3c..7ce271ec6 100644
--- a/custom_nodes/example_node.py.example
+++ b/custom_nodes/example_node.py.example
@@ -6,6 +6,8 @@ class Example:
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
+ IS_CHANGED:
+ optional method to control when the node is re executed.
Attributes
----------
@@ -89,6 +91,17 @@ class Example:
image = 1.0 - image
return (image,)
+ """
+ The node will always be re executed if any of the inputs change but
+ this method can be used to force the node to execute again even when the inputs don't change.
+ You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
+ executed, if it is different the node will be executed again.
+ This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
+ changes between executions the LoadImage node is executed again.
+ """
+ #@classmethod
+ #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
+ # return ""
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py.disabled
new file mode 100644
index 000000000..b85a5de8b
--- /dev/null
+++ b/custom_nodes/websocket_image_save.py.disabled
@@ -0,0 +1,49 @@
+from PIL import Image, ImageOps
+from io import BytesIO
+import numpy as np
+import struct
+import comfy.utils
+import time
+
+#You can use this node to save full size images through the websocket, the
+#images will be sent in exactly the same format as the image previews: as
+#binary images on the websocket with a 8 byte header indicating the type
+#of binary message (first 4 bytes) and the image format (next 4 bytes).
+
+#The reason this node is disabled by default is because there is a small
+#issue when using it with the default ComfyUI web interface: When generating
+#batches only the last image will be shown in the UI.
+
+#Note that no metadata will be put in the images saved with this node.
+
+class SaveImageWebsocket:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"images": ("IMAGE", ),}
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "save_images"
+
+ OUTPUT_NODE = True
+
+ CATEGORY = "image"
+
+ def save_images(self, images):
+ pbar = comfy.utils.ProgressBar(images.shape[0])
+ step = 0
+ for image in images:
+ i = 255. * image.cpu().numpy()
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+ pbar.update_absolute(step, images.shape[0], ("PNG", img, None))
+ step += 1
+
+ return {}
+
+ def IS_CHANGED(s, images):
+ return time.time()
+
+NODE_CLASS_MAPPINGS = {
+ "SaveImageWebsocket": SaveImageWebsocket,
+}
diff --git a/nodes.py b/nodes.py
index 4ad35f79b..a577c2126 100644
--- a/nodes.py
+++ b/nodes.py
@@ -184,6 +184,26 @@ class ConditioningSetAreaPercentage:
c.append(n)
return (c, )
+class ConditioningSetAreaStrength:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"conditioning": ("CONDITIONING", ),
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "append"
+
+ CATEGORY = "conditioning"
+
+ def append(self, conditioning, strength):
+ c = []
+ for t in conditioning:
+ n = [t[0], t[1].copy()]
+ n[1]['strength'] = strength
+ c.append(n)
+ return (c, )
+
+
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
@@ -289,18 +309,7 @@ class VAEEncode:
CATEGORY = "latent"
- @staticmethod
- def vae_encode_crop_pixels(pixels):
- x = (pixels.shape[1] // 8) * 8
- y = (pixels.shape[2] // 8) * 8
- if pixels.shape[1] != x or pixels.shape[2] != y:
- x_offset = (pixels.shape[1] % 8) // 2
- y_offset = (pixels.shape[2] % 8) // 2
- pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
- return pixels
-
def encode(self, vae, pixels):
- pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
@@ -316,7 +325,6 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing"
def encode(self, vae, pixels, tile_size):
- pixels = VAEEncode.vae_encode_crop_pixels(pixels)
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
return ({"samples":t}, )
@@ -330,14 +338,14 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6):
- x = (pixels.shape[1] // 8) * 8
- y = (pixels.shape[2] // 8) * 8
+ x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+ y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
- x_offset = (pixels.shape[1] % 8) // 2
- y_offset = (pixels.shape[2] % 8) // 2
+ x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+ y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
@@ -834,15 +842,20 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
+ "type": (["stable_diffusion", "stable_cascade"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
- def load_clip(self, clip_name):
+ def load_clip(self, clip_name, type="stable_diffusion"):
+ clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+ if type == "stable_cascade":
+ clip_type = comfy.sd.CLIPType.STABLE_CASCADE
+
clip_path = folder_paths.get_full_path("clip", clip_name)
- clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@@ -1414,7 +1427,7 @@ class SaveImage:
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
- for image in images:
+ for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None
@@ -1426,7 +1439,8 @@ class SaveImage:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
- file = f"{filename}_{counter:05}_.png"
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+ file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
@@ -1754,6 +1768,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
+ "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
@@ -1944,6 +1959,8 @@ def init_custom_nodes():
"nodes_stable3d.py",
"nodes_sdupscale.py",
"nodes_photomaker.py",
+ "nodes_cond.py",
+ "nodes_stable_cascade.py",
]
for node_file in extras_files:
diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js
index b78d33aac..157167edf 100644
--- a/web/extensions/core/groupNode.js
+++ b/web/extensions/core/groupNode.js
@@ -910,6 +910,9 @@ export class GroupNodeHandler {
const self = this;
const onNodeCreated = this.node.onNodeCreated;
this.node.onNodeCreated = function () {
+ if (!this.widgets) {
+ return;
+ }
const config = self.groupData.nodeData.config;
if (config) {
for (const n in config) {
diff --git a/web/extensions/core/groupNodeManage.css b/web/extensions/core/groupNodeManage.css
index 5ac89aee3..5470ecb5e 100644
--- a/web/extensions/core/groupNodeManage.css
+++ b/web/extensions/core/groupNodeManage.css
@@ -48,7 +48,7 @@
list-style: none;
}
.comfy-group-manage-list-items {
- max-height: 70vh;
+ max-height: calc(100% - 40px);
overflow-y: scroll;
overflow-x: hidden;
}
diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js
index bb2f16d42..4f69ac760 100644
--- a/web/extensions/core/maskeditor.js
+++ b/web/extensions/core/maskeditor.js
@@ -62,7 +62,7 @@ async function uploadMask(filepath, formData) {
ClipspaceDialog.invalidatePreview();
}
-function prepare_mask(image, maskCanvas, maskCtx) {
+function prepare_mask(image, maskCanvas, maskCtx, maskColor) {
// paste mask data into alpha channel
maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
@@ -74,9 +74,9 @@ function prepare_mask(image, maskCanvas, maskCtx) {
else
maskData.data[i+3] = 255;
- maskData.data[i] = 0;
- maskData.data[i+1] = 0;
- maskData.data[i+2] = 0;
+ maskData.data[i] = maskColor.r;
+ maskData.data[i+1] = maskColor.g;
+ maskData.data[i+2] = maskColor.b;
}
maskCtx.globalCompositeOperation = 'source-over';
@@ -110,6 +110,7 @@ class MaskEditorDialog extends ComfyDialog {
createButton(name, callback) {
var button = document.createElement("button");
+ button.style.pointerEvents = "auto";
button.innerText = name;
button.addEventListener("click", callback);
return button;
@@ -146,6 +147,7 @@ class MaskEditorDialog extends ComfyDialog {
divElement.style.display = "flex";
divElement.style.position = "relative";
divElement.style.top = "2px";
+ divElement.style.pointerEvents = "auto";
self.brush_slider_input = document.createElement('input');
self.brush_slider_input.setAttribute('type', 'range');
self.brush_slider_input.setAttribute('min', '1');
@@ -173,6 +175,7 @@ class MaskEditorDialog extends ComfyDialog {
bottom_panel.style.left = "20px";
bottom_panel.style.right = "20px";
bottom_panel.style.height = "50px";
+ bottom_panel.style.pointerEvents = "none";
var brush = document.createElement("div");
brush.id = "brush";
@@ -191,14 +194,29 @@ class MaskEditorDialog extends ComfyDialog {
this.element.appendChild(bottom_panel);
document.body.appendChild(brush);
+ var clearButton = this.createLeftButton("Clear", () => {
+ self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
+ });
+
this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null);
});
- var clearButton = this.createLeftButton("Clear",
- () => {
- self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
- });
+
+ this.colorButton = this.createLeftButton(this.getColorButtonText(), () => {
+ if (self.brush_color_mode === "black") {
+ self.brush_color_mode = "white";
+ }
+ else if (self.brush_color_mode === "white") {
+ self.brush_color_mode = "negative";
+ }
+ else {
+ self.brush_color_mode = "black";
+ }
+
+ self.updateWhenBrushColorModeChanged();
+ });
+
var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
@@ -219,6 +237,7 @@ class MaskEditorDialog extends ComfyDialog {
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(this.brush_size_slider);
+ bottom_panel.appendChild(this.colorButton);
imgCanvas.style.position = "absolute";
maskCanvas.style.position = "absolute";
@@ -228,6 +247,10 @@ class MaskEditorDialog extends ComfyDialog {
maskCanvas.style.top = imgCanvas.style.top;
maskCanvas.style.left = imgCanvas.style.left;
+
+ const maskCanvasStyle = this.getMaskCanvasStyle();
+ maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode;
+ maskCanvas.style.opacity = maskCanvasStyle.opacity;
}
async show() {
@@ -313,7 +336,7 @@ class MaskEditorDialog extends ComfyDialog {
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
- prepare_mask(mask_image, this.maskCanvas, maskCtx);
+ prepare_mask(mask_image, this.maskCanvas, maskCtx, this.getMaskColor());
}
async setImages(imgCanvas) {
@@ -439,7 +462,84 @@ class MaskEditorDialog extends ComfyDialog {
}
}
+ getMaskCanvasStyle() {
+ if (this.brush_color_mode === "negative") {
+ return {
+ mixBlendMode: "difference",
+ opacity: "1",
+ };
+ }
+ else {
+ return {
+ mixBlendMode: "initial",
+ opacity: "0.7",
+ };
+ }
+ }
+
+ getMaskColor() {
+ if (this.brush_color_mode === "black") {
+ return { r: 0, g: 0, b: 0 };
+ }
+ if (this.brush_color_mode === "white") {
+ return { r: 255, g: 255, b: 255 };
+ }
+ if (this.brush_color_mode === "negative") {
+ // negative effect only works with white color
+ return { r: 255, g: 255, b: 255 };
+ }
+
+ return { r: 0, g: 0, b: 0 };
+ }
+
+ getMaskFillStyle() {
+ const maskColor = this.getMaskColor();
+
+ return "rgb(" + maskColor.r + "," + maskColor.g + "," + maskColor.b + ")";
+ }
+
+ getColorButtonText() {
+ let colorCaption = "unknown";
+
+ if (this.brush_color_mode === "black") {
+ colorCaption = "black";
+ }
+ else if (this.brush_color_mode === "white") {
+ colorCaption = "white";
+ }
+ else if (this.brush_color_mode === "negative") {
+ colorCaption = "negative";
+ }
+
+ return "Color: " + colorCaption;
+ }
+
+ updateWhenBrushColorModeChanged() {
+ this.colorButton.innerText = this.getColorButtonText();
+
+ // update mask canvas css styles
+
+ const maskCanvasStyle = this.getMaskCanvasStyle();
+ this.maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode;
+ this.maskCanvas.style.opacity = maskCanvasStyle.opacity;
+
+ // update mask canvas rgb colors
+
+ const maskColor = this.getMaskColor();
+
+ const maskData = this.maskCtx.getImageData(0, 0, this.maskCanvas.width, this.maskCanvas.height);
+
+ for (let i = 0; i < maskData.data.length; i += 4) {
+ maskData.data[i] = maskColor.r;
+ maskData.data[i+1] = maskColor.g;
+ maskData.data[i+2] = maskColor.b;
+ }
+
+ this.maskCtx.putImageData(maskData, 0, 0);
+ }
+
brush_size = 10;
+ brush_color_mode = "black";
drawing_mode = false;
lastx = -1;
lasty = -1;
@@ -518,6 +618,19 @@ class MaskEditorDialog extends ComfyDialog {
event.preventDefault();
self.pan_move(self, event);
}
+
+ let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1;
+
+ if(event.shiftKey && left_button_down) {
+ self.drawing_mode = false;
+
+ const y = event.clientY;
+ let delta = (self.zoom_lasty - y)*0.005;
+ self.zoom_ratio = Math.max(Math.min(10.0, self.last_zoom_ratio - delta), 0.2);
+
+ this.invalidatePanZoom();
+ return;
+ }
}
pan_move(self, event) {
@@ -535,7 +648,7 @@ class MaskEditorDialog extends ComfyDialog {
}
draw_move(self, event) {
- if(event.ctrlKey) {
+ if(event.ctrlKey || event.shiftKey) {
return;
}
@@ -546,7 +659,10 @@ class MaskEditorDialog extends ComfyDialog {
self.updateBrushPreview(self);
- if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
+ let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1;
+ let right_button_down = [2, 5, 32].includes(event.buttons);
+
+ if (!event.altKey && left_button_down) {
var diff = performance.now() - self.lasttime;
const maskRect = self.maskCanvas.getBoundingClientRect();
@@ -581,7 +697,7 @@ class MaskEditorDialog extends ComfyDialog {
if(diff > 20 && !this.drawing_mode)
requestAnimationFrame(() => {
self.maskCtx.beginPath();
- self.maskCtx.fillStyle = "rgb(0,0,0)";
+ self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
@@ -591,7 +707,7 @@ class MaskEditorDialog extends ComfyDialog {
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
- self.maskCtx.fillStyle = "rgb(0,0,0)";
+ self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
var dx = x - self.lastx;
@@ -613,7 +729,7 @@ class MaskEditorDialog extends ComfyDialog {
self.lasttime = performance.now();
}
- else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
+ else if((event.altKey && left_button_down) || right_button_down) {
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
@@ -687,13 +803,20 @@ class MaskEditorDialog extends ComfyDialog {
self.drawing_mode = true;
event.preventDefault();
+
+ if(event.shiftKey) {
+ self.zoom_lasty = event.clientY;
+ self.last_zoom_ratio = self.zoom_ratio;
+ return;
+ }
+
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
self.maskCtx.beginPath();
- if (event.button == 0) {
- self.maskCtx.fillStyle = "rgb(0,0,0)";
+ if (!event.altKey && event.button == 0) {
+ self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
} else {
self.maskCtx.globalCompositeOperation = "destination-out";
diff --git a/web/extensions/core/simpleTouchSupport.js b/web/extensions/core/simpleTouchSupport.js
new file mode 100644
index 000000000..041fc2c4c
--- /dev/null
+++ b/web/extensions/core/simpleTouchSupport.js
@@ -0,0 +1,102 @@
+import { app } from "../../scripts/app.js";
+
+let touchZooming;
+let touchCount = 0;
+
+app.registerExtension({
+ name: "Comfy.SimpleTouchSupport",
+ setup() {
+ let zoomPos;
+ let touchTime;
+ let lastTouch;
+
+ function getMultiTouchPos(e) {
+ return Math.hypot(e.touches[0].clientX - e.touches[1].clientX, e.touches[0].clientY - e.touches[1].clientY);
+ }
+
+ app.canvasEl.addEventListener(
+ "touchstart",
+ (e) => {
+ touchCount++;
+ lastTouch = null;
+ if (e.touches?.length === 1) {
+ // Store start time for press+hold for context menu
+ touchTime = new Date();
+ lastTouch = e.touches[0];
+ } else {
+ touchTime = null;
+ if (e.touches?.length === 2) {
+ // Store center pos for zoom
+ zoomPos = getMultiTouchPos(e);
+ app.canvas.pointer_is_down = false;
+ }
+ }
+ },
+ true
+ );
+
+ app.canvasEl.addEventListener("touchend", (e) => {
+ touchZooming = false;
+ touchCount = e.touches?.length ?? touchCount - 1;
+ if (touchTime && !e.touches?.length) {
+ if (new Date() - touchTime > 600) {
+ try {
+ // hack to get litegraph to use this event
+ e.constructor = CustomEvent;
+ } catch (error) {}
+ e.clientX = lastTouch.clientX;
+ e.clientY = lastTouch.clientY;
+
+ app.canvas.pointer_is_down = true;
+ app.canvas._mousedown_callback(e);
+ }
+ touchTime = null;
+ }
+ });
+
+ app.canvasEl.addEventListener(
+ "touchmove",
+ (e) => {
+ touchTime = null;
+ if (e.touches?.length === 2) {
+ app.canvas.pointer_is_down = false;
+ touchZooming = true;
+ LiteGraph.closeAllContextMenus();
+ app.canvas.search_box?.close();
+ const newZoomPos = getMultiTouchPos(e);
+
+ const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2;
+ const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2;
+
+ let scale = app.canvas.ds.scale;
+ const diff = zoomPos - newZoomPos;
+ if (diff > 0.5) {
+ scale *= 1 / 1.07;
+ } else if (diff < -0.5) {
+ scale *= 1.07;
+ }
+ app.canvas.ds.changeScale(scale, [midX, midY]);
+ app.canvas.setDirty(true, true);
+ zoomPos = newZoomPos;
+ }
+ },
+ true
+ );
+ },
+});
+
+const processMouseDown = LGraphCanvas.prototype.processMouseDown;
+LGraphCanvas.prototype.processMouseDown = function (e) {
+ if (touchZooming || touchCount) {
+ return;
+ }
+ return processMouseDown.apply(this, arguments);
+};
+
+const processMouseMove = LGraphCanvas.prototype.processMouseMove;
+LGraphCanvas.prototype.processMouseMove = function (e) {
+ if (touchZooming || touchCount > 1) {
+ return;
+ }
+ return processMouseMove.apply(this, arguments);
+};
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index f89c731e6..0f41d10f8 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) {
}
function hideWidget(node, widget, suffix = "") {
+ if (widget.type?.startsWith(CONVERTED_TYPE)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
@@ -260,6 +261,12 @@ app.registerExtension({
async beforeRegisterNodeDef(nodeType, nodeData, app) {
// Add menu options to conver to/from widgets
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
+ nodeType.prototype.convertWidgetToInput = function (widget) {
+ const config = getConfig.call(this, widget.name) ?? [widget.type, widget.options || {}];
+ if (!isConvertableWidget(widget, config)) return false;
+ convertToInput(this, widget, config);
+ return true;
+ };
nodeType.prototype.getExtraMenuOptions = function (_, options) {
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : undefined;
diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index 080e0ef47..4ff05ae81 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -11549,7 +11549,7 @@ LGraphNode.prototype.executeAction = function(action)
dialog.close();
} else if (e.keyCode == 13) {
if (selected) {
- select(selected.innerHTML);
+ select(unescape(selected.dataset["type"]));
} else if (first) {
select(first);
} else {
@@ -11910,7 +11910,7 @@ LGraphNode.prototype.executeAction = function(action)
var ctor = LiteGraph.registered_node_types[ type ];
if(filter && ctor.filter != filter )
return false;
- if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1)
+ if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1 && (!ctor.title || ctor.title.toLowerCase().indexOf(str) === -1))
return false;
// filter by slot IN, OUT types
@@ -11964,7 +11964,18 @@ LGraphNode.prototype.executeAction = function(action)
if (!first) {
first = type;
}
- help.innerText = type;
+
+ const nodeType = LiteGraph.registered_node_types[type];
+ if (nodeType?.title) {
+ help.innerText = nodeType?.title;
+ const typeEl = document.createElement("span");
+ typeEl.className = "litegraph lite-search-item-type";
+ typeEl.textContent = type;
+ help.append(typeEl);
+ } else {
+ help.innerText = type;
+ }
+
help.dataset["type"] = escape(type);
help.className = "litegraph lite-search-item";
if (className) {
diff --git a/web/lib/litegraph.css b/web/lib/litegraph.css
index 918858f41..5524e24ba 100644
--- a/web/lib/litegraph.css
+++ b/web/lib/litegraph.css
@@ -184,6 +184,7 @@
color: white;
padding-left: 10px;
margin-right: 5px;
+ max-width: 300px;
}
.litegraph.litesearchbox .name {
@@ -227,6 +228,18 @@
color: black;
}
+.litegraph.lite-search-item-type {
+ display: inline-block;
+ background: rgba(0,0,0,0.2);
+ margin-left: 5px;
+ font-size: 14px;
+ padding: 2px 5px;
+ position: relative;
+ top: -2px;
+ opacity: 0.8;
+ border-radius: 4px;
+ }
+
/* DIALOGs ******/
.litegraph .dialog {
diff --git a/web/scripts/api.js b/web/scripts/api.js
index ae3fbd13a..c43255949 100644
--- a/web/scripts/api.js
+++ b/web/scripts/api.js
@@ -5,6 +5,7 @@ class ComfyApi extends EventTarget {
super();
this.api_host = location.host;
this.api_base = location.pathname.split('/').slice(0, -1).join('/');
+ this.initialClientId = sessionStorage.getItem("clientId");
}
apiURL(route) {
@@ -118,7 +119,8 @@ class ComfyApi extends EventTarget {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
- window.name = this.clientId;
+ window.name = this.clientId; // use window name so it isnt reused when duplicating tabs
+ sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
diff --git a/web/scripts/app.js b/web/scripts/app.js
index d16878454..41b769dd6 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1499,12 +1499,17 @@ export class ComfyApp {
// Load previous workflow
let restored = false;
try {
- const json = localStorage.getItem("workflow");
- if (json) {
- const workflow = JSON.parse(json);
- await this.loadGraphData(workflow);
- restored = true;
- }
+ const loadWorkflow = async (json) => {
+ if (json) {
+ const workflow = JSON.parse(json);
+ await this.loadGraphData(workflow);
+ return true;
+ }
+ };
+ const clientId = api.initialClientId ?? api.clientId;
+ restored =
+ (clientId && (await loadWorkflow(sessionStorage.getItem(`workflow:${clientId}`)))) ||
+ (await loadWorkflow(localStorage.getItem("workflow")));
} catch (err) {
console.error("Error loading previous workflow", err);
}
@@ -1515,7 +1520,13 @@ export class ComfyApp {
}
// Save current workflow automatically
- setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000);
+ setInterval(() => {
+ const workflow = JSON.stringify(this.graph.serialize());
+ localStorage.setItem("workflow", workflow);
+ if (api.clientId) {
+ sessionStorage.setItem(`workflow:${api.clientId}`, workflow);
+ }
+ }, 1000);
this.#addDrawNodeHandler();
this.#addDrawGroupsHandler();
@@ -2096,6 +2107,8 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
} else if (pngInfo.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
+ } else if (pngInfo.Prompt) {
+ this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node.
}
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
@@ -2149,8 +2162,17 @@ export class ComfyApp {
if (value instanceof Array) {
const [fromId, fromSlot] = value;
const fromNode = app.graph.getNodeById(fromId);
- const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
- if (toSlot !== -1) {
+ let toSlot = node.inputs?.findIndex((inp) => inp.name === input);
+ if (toSlot == null || toSlot === -1) {
+ try {
+ // Target has no matching input, most likely a converted widget
+ const widget = node.widgets?.find((w) => w.name === input);
+ if (widget && node.convertWidgetToInput?.(widget)) {
+ toSlot = node.inputs?.length - 1;
+ }
+ } catch (error) {}
+ }
+ if (toSlot != null || toSlot !== -1) {
fromNode.connect(fromSlot, node, toSlot);
}
} else {
diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js
index 83a4ebc86..169609209 100644
--- a/web/scripts/pnginfo.js
+++ b/web/scripts/pnginfo.js
@@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset);
// Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
- if (type === "tEXt" || type == "comf") {
+ if (type === "tEXt" || type == "comf" || type === "iTXt") {
// Get the keyword
let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) {
@@ -33,7 +33,7 @@ export function getPngMetadata(file) {
const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end));
// Get the text
const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length);
- const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('')
+ const contentJson = new TextDecoder("utf-8").decode(contentArraySegment);
txt_chunks[keyword] = contentJson;
}
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index d69434993..027bf4a3f 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -401,18 +401,42 @@ export class ComfyUI {
}
});
- this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [
- $el("div.drag-handle", {
+ this.menuHamburger = $el(
+ "div.comfy-menu-hamburger",
+ {
+ parent: document.body,
+ onclick: () => {
+ this.menuContainer.style.display = "block";
+ this.menuHamburger.style.display = "none";
+ },
+ },
+ [$el("div"), $el("div"), $el("div")]
+ );
+
+ this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
+ $el("div.drag-handle.comfy-menu-header", {
style: {
overflow: "hidden",
position: "relative",
width: "100%",
cursor: "default"
}
- }, [
+ }, [
$el("span.drag-handle"),
- $el("span", {$: (q) => (this.queueSize = q)}),
- $el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}),
+ $el("span.comfy-menu-queue-size", { $: (q) => (this.queueSize = q) }),
+ $el("div.comfy-menu-actions", [
+ $el("button.comfy-settings-btn", {
+ textContent: "⚙️",
+ onclick: () => this.settings.show(),
+ }),
+ $el("button.comfy-close-menu-btn", {
+ textContent: "\u00d7",
+ onclick: () => {
+ this.menuContainer.style.display = "none";
+ this.menuHamburger.style.display = "flex";
+ },
+ }),
+ ]),
]),
$el("button.comfy-queue-btn", {
id: "queue-button",
diff --git a/web/scripts/ui/settings.js b/web/scripts/ui/settings.js
index 1cdba5cfe..9e9d13af0 100644
--- a/web/scripts/ui/settings.js
+++ b/web/scripts/ui/settings.js
@@ -16,7 +16,17 @@ export class ComfySettingsDialog extends ComfyDialog {
},
[
$el("table.comfy-modal-content.comfy-table", [
- $el("caption", { textContent: "Settings" }),
+ $el(
+ "caption",
+ { textContent: "Settings" },
+ $el("button.comfy-btn", {
+ type: "button",
+ textContent: "\u00d7",
+ onclick: () => {
+ this.element.close();
+ },
+ })
+ ),
$el("tbody", { $: (tbody) => (this.textElement = tbody) }),
$el("button", {
type: "button",
diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js
index 0529b1d80..678b1b8ec 100644
--- a/web/scripts/widgets.js
+++ b/web/scripts/widgets.js
@@ -81,6 +81,9 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
const isCombo = targetWidget.type === "combo";
let comboFilter;
+ if (isCombo) {
+ valueControl.options.values.push("increment-wrap");
+ }
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget(
"string",
@@ -128,6 +131,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
case "increment":
current_index += 1;
break;
+ case "increment-wrap":
+ current_index += 1;
+ if ( current_index >= current_length ) {
+ current_index = 0;
+ }
+ break;
case "decrement":
current_index -= 1;
break;
@@ -295,7 +304,7 @@ export const ComfyWidgets = {
let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding")
if (precision == 0) precision = undefined;
const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding);
- return { widget: node.addWidget(widgetType, inputName, val,
+ return { widget: node.addWidget(widgetType, inputName, val,
function (v) {
if (config.round) {
this.value = Math.round(v/config.round)*config.round;
diff --git a/web/style.css b/web/style.css
index 44ee60198..cf7a8b9ea 100644
--- a/web/style.css
+++ b/web/style.css
@@ -82,6 +82,24 @@ body {
margin: 3px 3px 3px 4px;
}
+.comfy-menu-hamburger {
+ position: fixed;
+ top: 10px;
+ z-index: 9999;
+ right: 10px;
+ width: 30px;
+ display: none;
+ gap: 8px;
+ flex-direction: column;
+ cursor: pointer;
+}
+.comfy-menu-hamburger div {
+ height: 3px;
+ width: 100%;
+ border-radius: 20px;
+ background-color: white;
+}
+
.comfy-menu {
font-size: 15px;
position: absolute;
@@ -101,6 +119,44 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
}
+.comfy-menu-header {
+ display: flex;
+}
+
+.comfy-menu-actions {
+ display: flex;
+ gap: 3px;
+ align-items: center;
+ height: 20px;
+ position: relative;
+ top: -1px;
+ font-size: 22px;
+}
+
+.comfy-menu .comfy-menu-actions button {
+ background-color: rgba(0, 0, 0, 0);
+ padding: 0;
+ border: none;
+ cursor: pointer;
+ font-size: inherit;
+}
+
+.comfy-menu .comfy-menu-actions .comfy-settings-btn {
+ font-size: 0.6em;
+}
+
+button.comfy-close-menu-btn {
+ font-size: 1em;
+ line-height: 12px;
+ color: #ccc;
+ position: relative;
+ top: -1px;
+}
+
+.comfy-menu-queue-size {
+ flex: auto;
+}
+
.comfy-menu button,
.comfy-modal button {
font-size: 20px;
@@ -121,7 +177,6 @@ body {
width: 100%;
}
-.comfy-toggle-switch,
.comfy-btn,
.comfy-menu > button,
.comfy-menu-btns button,
@@ -140,17 +195,12 @@ body {
.comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover,
.comfy-modal button:hover,
-.comfy-settings-btn:hover {
+.comfy-menu-actions button:hover {
filter: brightness(1.2);
+ will-change: transform;
cursor: pointer;
}
-.comfy-menu span.drag-handle {
- position: absolute;
- top: 0;
- left: 0;
-}
-
span.drag-handle {
width: 10px;
height: 20px;
@@ -215,15 +265,6 @@ span.drag-handle::after {
font-size: 12px;
}
-button.comfy-settings-btn {
- background-color: rgba(0, 0, 0, 0);
- font-size: 12px;
- padding: 0;
- position: absolute;
- right: 0;
- border: none;
-}
-
button.comfy-queue-btn {
margin: 6px 0 !important;
}
@@ -269,7 +310,19 @@ button.comfy-queue-btn {
}
.comfy-menu span.drag-handle {
- visibility: hidden
+ display: none;
+ }
+
+ .comfy-menu-queue-size {
+ flex: unset;
+ }
+
+ .comfy-menu-header {
+ justify-content: space-between;
+ }
+ .comfy-menu-actions {
+ gap: 10px;
+ font-size: 28px;
}
}
@@ -320,7 +373,7 @@ dialog::backdrop {
text-align: right;
}
-#comfy-settings-dialog button {
+#comfy-settings-dialog tbody button, #comfy-settings-dialog table > button {
background-color: var(--bg-color);
border: 1px var(--border-color) solid;
border-radius: 0;
@@ -343,12 +396,33 @@ dialog::backdrop {
}
.comfy-table caption {
+ position: sticky;
+ top: 0;
background-color: var(--bg-color);
color: var(--input-text);
font-size: 1rem;
font-weight: bold;
padding: 8px;
text-align: center;
+ border-bottom: 1px solid var(--border-color);
+}
+
+.comfy-table caption .comfy-btn {
+ position: absolute;
+ top: -2px;
+ right: 0;
+ bottom: 0;
+ cursor: pointer;
+ border: none;
+ height: 100%;
+ border-radius: 0;
+ aspect-ratio: 1/1;
+ user-select: none;
+ font-size: 20px;
+}
+
+.comfy-table caption .comfy-btn:focus {
+ outline: none;
}
.comfy-table tr:nth-child(even) {
@@ -389,11 +463,13 @@ dialog::backdrop {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
filter: brightness(95%);
+ will-change: transform;
}
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
background-color: var(--comfy-menu-bg) !important;
filter: brightness(155%);
+ will-change: transform;
color: var(--input-text);
}
@@ -435,43 +511,6 @@ dialog::backdrop {
margin-left: 5px;
}
-.comfy-toggle-switch {
- border-width: 2px;
- display: flex;
- background-color: var(--comfy-input-bg);
- margin: 2px 0;
- white-space: nowrap;
-}
-
-.comfy-toggle-switch label {
- padding: 2px 0px 3px 6px;
- flex: auto;
- border-radius: 8px;
- align-items: center;
- display: flex;
- justify-content: center;
-}
-
-.comfy-toggle-switch label:first-child {
- border-top-left-radius: 8px;
- border-bottom-left-radius: 8px;
-}
-.comfy-toggle-switch label:last-child {
- border-top-right-radius: 8px;
- border-bottom-right-radius: 8px;
-}
-
-.comfy-toggle-switch .comfy-toggle-selected {
- background-color: var(--comfy-menu-bg);
-}
-
-#extraOptions {
- padding: 4px;
- background-color: var(--bg-color);
- margin-bottom: 4px;
- border-radius: 4px;
-}
-
/* Search box */
.litegraph.litesearchbox {
@@ -491,10 +530,30 @@ dialog::backdrop {
color: var(--input-text);
background-color: var(--comfy-input-bg);
filter: brightness(80%);
+ will-change: transform;
padding-left: 0.2em;
}
.litegraph.lite-search-item.generic_type {
color: var(--input-text);
filter: brightness(50%);
+ will-change: transform;
+}
+
+@media only screen and (max-width: 450px) {
+ #comfy-settings-dialog .comfy-table tbody {
+ display: grid;
+ }
+ #comfy-settings-dialog .comfy-table tr {
+ display: grid;
+ }
+ #comfy-settings-dialog tr > td:first-child {
+ text-align: center;
+ border-bottom: none;
+ padding-bottom: 0;
+ }
+ #comfy-settings-dialog tr > td:not(:first-child) {
+ text-align: center;
+ border-top: none;
+ }
}