Merge branch 'comfyanonymous:master' into seedControls

Commit 77a4e42fa0 by FizzleDorf, 2023-04-06 11:32:49 -04:00, committed by GitHub.
38 changed files with 1612 additions and 286 deletions


@@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
+- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.

comfy/clip_vision.py (new file)

@@ -0,0 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
from .utils import load_torch_file, transformers_convert
import os

class ClipVisionModel():
    def __init__(self, json_config):
        config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModelWithProjection(config)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
                                            do_normalize=True,
                                            do_resize=True,
                                            image_mean=[0.48145466, 0.4578275, 0.40821073],
                                            image_std=[0.26862954, 0.26130258, 0.27577711],
                                            resample=3, #bicubic
                                            size=224)

    def load_sd(self, sd):
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        inputs = self.processor(images=[image[0]], return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs

def convert_to_transformers(sd):
    sd_k = sd.keys()
    if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
        keys_to_replace = {
            "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
            "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
            "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
            "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
            "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
            "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
            "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
        }

        for x in keys_to_replace:
            if x in sd_k:
                sd[keys_to_replace[x]] = sd.pop(x)

        if "embedder.model.visual.proj" in sd_k:
            sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)

        sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
    return sd

def load_clipvision_from_sd(sd):
    sd = convert_to_transformers(sd)
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    clip = ClipVisionModel(json_config)
    clip.load_sd(sd)
    return clip

def load(ckpt_path):
    sd = load_torch_file(ckpt_path)
    return load_clipvision_from_sd(sd)
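For orientation, a minimal usage sketch of the new loader; the checkpoint path and image tensor below are hypothetical placeholders, not part of this commit:

```python
# Hypothetical example: load a CLIP vision checkpoint and encode one image.
# ComfyUI IMAGE tensors are [B, H, W, C] floats in 0..1.
import torch
from comfy import clip_vision

clip = clip_vision.load("models/clip_vision/clip_vit_h.safetensors")  # placeholder path
image = torch.rand(1, 512, 512, 3)
outputs = clip.encode_image(image)
print(outputs.image_embeds.shape)  # projected embedding later used as unCLIP "adm" conditioning
```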


@@ -0,0 +1,18 @@
{
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "gelu",
  "hidden_size": 1280,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 32,
  "patch_size": 14,
  "projection_dim": 1024,
  "torch_dtype": "float32"
}


@@ -1,8 +1,4 @@
 {
-  "_name_or_path": "openai/clip-vit-large-patch14",
-  "architectures": [
-    "CLIPVisionModel"
-  ],
   "attention_dropout": 0.0,
   "dropout": 0.0,
   "hidden_act": "quick_gelu",
@@ -18,6 +14,5 @@
   "num_hidden_layers": 24,
   "patch_size": 14,
   "projection_dim": 768,
-  "torch_dtype": "float32",
-  "transformers_version": "4.24.0"
+  "torch_dtype": "float32"
 }


@@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
         log = super().log_images(*args, **kwargs)
         log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
         return log
+
+
+class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
+    def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
+                 freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.embed_key = embedding_key
+        self.embedding_dropout = embedding_dropout
+        # self._init_embedder(embedder_config, freeze_embedder)
+        self._init_noise_aug(noise_aug_config)
+
+    def _init_embedder(self, config, freeze=True):
+        embedder = instantiate_from_config(config)
+        if freeze:
+            self.embedder = embedder.eval()
+            self.embedder.train = disabled_train
+            for param in self.embedder.parameters():
+                param.requires_grad = False
+
+    def _init_noise_aug(self, config):
+        if config is not None:
+            # use the KARLO schedule for noise augmentation on CLIP image embeddings
+            noise_augmentor = instantiate_from_config(config)
+            assert isinstance(noise_augmentor, nn.Module)
+            noise_augmentor = noise_augmentor.eval()
+            noise_augmentor.train = disabled_train
+            self.noise_augmentor = noise_augmentor
+        else:
+            self.noise_augmentor = None
+
+    def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
+        outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
+        z, c = outputs[0], outputs[1]
+        img = batch[self.embed_key][:bs]
+        img = rearrange(img, 'b h w c -> b c h w')
+        c_adm = self.embedder(img)
+        if self.noise_augmentor is not None:
+            c_adm, noise_level_emb = self.noise_augmentor(c_adm)
+            # assume this gives embeddings of noise levels
+            c_adm = torch.cat((c_adm, noise_level_emb), 1)
+        if self.training:
+            c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
+                                                                               device=c_adm.device)[:, None]) * c_adm
+        all_conds = {"c_crossattn": [c], "c_adm": c_adm}
+        noutputs = [z, all_conds]
+        noutputs.extend(outputs[2:])
+        return noutputs
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, **kwargs):
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
+                                           return_original_cond=True)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        assert self.model.conditioning_key is not None
+        assert self.cond_stage_key in ["caption", "txt"]
+        xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+        log["conditioning"] = xc
+        uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
+        unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
+        uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+        ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
+        with ema_scope(f"Sampling"):
+            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
+                                             ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc_, )
+        x_samples_cfg = self.decode_first_stage(samples_cfg)
+        log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+        return log


@@ -307,7 +307,16 @@ def model_wrapper(
             else:
                 x_in = torch.cat([x] * 2)
                 t_in = torch.cat([t_continuous] * 2)
-                c_in = torch.cat([unconditional_condition, condition])
+                if isinstance(condition, dict):
+                    assert isinstance(unconditional_condition, dict)
+                    c_in = dict()
+                    for k in condition:
+                        if isinstance(condition[k], list):
+                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+                        else:
+                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+                else:
+                    c_in = torch.cat([unconditional_condition, condition])
                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                 return noise_uncond + guidance_scale * (noise - noise_uncond)


@@ -3,7 +3,6 @@ import torch
 from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver

 MODEL_TYPES = {
     "eps": "noise",
     "v": "v"
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                if isinstance(ctmp, torch.Tensor):
+                    cbs = ctmp.shape[0]
+                    if cbs != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
             else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+                if isinstance(conditioning, torch.Tensor):
+                    if conditioning.shape[0] != batch_size:
+                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

         # sampling
         C, H, W = shape
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
                         )

         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
+                              lower_order_final=True)

         return x.to(device), None


@@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management

-if model_management.xformers_enabled():
+if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
@@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if model_management.xformers_enabled() and attn_type == "vanilla":
+    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"


@@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
         return count_flops_attn(model, _x, y)


+class Timestep(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t):
+        return timestep_embedding(t, self.dim)
+
+
 class UNetModel(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
@@ -470,6 +479,7 @@ class UNetModel(nn.Module):
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        adm_in_channels=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -538,6 +548,15 @@ class UNetModel(nn.Module):
             elif self.num_classes == "continuous":
                 print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
             else:
                 raise ValueError()


@@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
         betas = 1 - alphas[1:] / alphas[:-1]
         betas = np.clip(betas, a_min=0, a_max=0.999)

+    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
+        # return early
+        return betas_for_alpha_bar(
+            n_timestep,
+            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+        )
+
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
@@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)

+
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -267,4 +275,4 @@ class HybridConditioner(nn.Module):
 def noise_like(shape, device, repeat=False):
     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()


@@ -0,0 +1,59 @@
from typing import List, Tuple, Union

import torch
import torch.nn as nn


#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = enhance_normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = enhance_normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape

    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
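A quick sanity check of the helper above; the module path is inferred from the relative import in the encoders module later in this diff and should be treated as an assumption:

```python
# Normalize a random batch with the CLIP mean/std used elsewhere in this commit.
import torch
from ldm.modules.encoders.kornia_functions import enhance_normalize  # assumed path

x = torch.rand(2, 3, 224, 224)
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
std = torch.tensor([0.26862954, 0.26130258, 0.27577711])
print(enhance_normalize(x, mean, std).shape)  # torch.Size([2, 3, 224, 224])
```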


@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from . import kornia_functions
 from torch.utils.checkpoint import checkpoint

 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
         c = batch[key][:, None]
         if self.ucg_rate > 0. and not disable_dropout:
             mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
-            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
             c = c.long()
         c = self.embedding(c)
         return c
@@ -57,18 +58,20 @@ def disabled_train(self, mode=True):

 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
-    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
+                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
         if freeze:
             self.freeze()

     def freeze(self):
         self.transformer = self.transformer.eval()
-        #self.train = disabled_train
+        # self.train = disabled_train
         for param in self.parameters():
             param.requires_grad = False
@@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "pooled",
         "hidden"
     ]
+
     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
@@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def freeze(self):
         self.transformer = self.transformer.eval()
-        #self.train = disabled_train
+        # self.train = disabled_train
         for param in self.parameters():
             param.requires_grad = False
@@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
         if self.layer == "last":
             z = outputs.last_hidden_state
         elif self.layer == "pooled":
@@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return self(text)
+
+class ClipImageEmbedder(nn.Module):
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=True,
+            ucg_rate=0.
+    ):
+        super().__init__()
+        from clip import load as load_clip
+        self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        # x = kornia_functions.geometry_resize(x, (224, 224),
+        #                                      interpolation='bicubic', align_corners=True,
+        #                                      antialias=self.antialias)
+        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+        x = (x + 1.) / 2.
+        # re-normalize according to clip
+        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x, no_dropout=False):
+        # x is assumed to be in range [-1,1]
+        out = self.model.encode_image(self.preprocess(x))
+        out = out.to(x.dtype)
+        if self.ucg_rate > 0. and not no_dropout:
+            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+        return out
+
+
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
     """
     LAYERS = [
-        #"pooled",
+        # "pooled",
         "last",
         "penultimate"
     ]
+
     def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
@@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         x = self.model.ln_final(x)
         return x

-    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
         for i, r in enumerate(self.model.transformer.resblocks):
             if i == len(self.model.transformer.resblocks) - self.layer_idx:
                 break
@@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         return self(text)
+
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+        super().__init__()
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+                                                            pretrained=version, )
+        del model.transformer
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "penultimate":
+            raise NotImplementedError()
+            self.layer_idx = 1
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        # x = kornia.geometry.resize(x, (224, 224),
+        #                            interpolation='bicubic', align_corners=True,
+        #                            antialias=self.antialias)
+        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, image, no_dropout=False):
+        z = self.encode_with_vision_transformer(image)
+        if self.ucg_rate > 0. and not no_dropout:
+            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        img = self.preprocess(img)
+        x = self.model.visual(img)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
 class FrozenCLIPT5Encoder(AbstractEncoder):
     def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
-              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

     def encode(self, text):
         return self(text)
@@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
         clip_z = self.clip_encoder.encode(text)
         t5_z = self.t5_encoder.encode(text)
         return [clip_z, t5_z]


@@ -0,0 +1,35 @@
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
import torch

class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        # re-normalize to centered mean and unit variance
        x = (x - self.data_mean) * 1. / self.data_std
        return x

    def unscale(self, x):
        # back to original data stats
        x = (x * self.data_std) + self.data_mean
        return x

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        x = self.scale(x)
        z = self.q_sample(x, noise_level)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level
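A minimal sketch of how this augmentor is constructed and called; the schedule values mirror what load_checkpoint_guess_config builds later in this diff, but the embedding here is a zero placeholder:

```python
# Apply KARLO-style noise augmentation to a (placeholder) CLIP image embedding.
import torch
from ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation

aug = CLIPEmbeddingNoiseAugmentation(
    noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
    timestep_dim=1024)
emb = torch.zeros(1, 1024)                       # stand-in for clip.encode_image(...).image_embeds
noisy, level_emb = aug(emb, noise_level=torch.tensor([100]))
print(noisy.shape, level_emb.shape)              # both torch.Size([1, 1024])
```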


@@ -1,4 +1,4 @@
#Taken from: https://github.com/dbolya/tomesd

 import torch
 from typing import Tuple, Callable
@@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
     return x


+def mps_gather_workaround(input, dim, index):
+    if input.shape[-1] == 1:
+        return torch.gather(
+            input.unsqueeze(-1),
+            dim - 1 if dim < 0 else dim,
+            index.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        return torch.gather(input, dim, index)
+
+
 def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                      w: int, h: int, sx: int, sy: int, r: int,
                                      no_rand: bool = False) -> Tuple[Callable, Callable]:
     """
     Partitions the tokens into src and dst and merges r tokens from src to dst.
     Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

     Args:
      - metric [B, N, C]: metric to use for similarity
      - w: image width in tokens
@@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     if r <= 0:
         return do_nothing, do_nothing

+    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
+
     with torch.no_grad():
         hsy, wsx = h // sy, w // sx

         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
-        idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
         if no_rand:
-            rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
+            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
-            rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
+            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

-        idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
-        idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
-        rand_idx = idx_buffer.argsort(dim=1)
-        num_dst = int((1 / (sx*sy)) * N)
+        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
+        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
+        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
+        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
+
+        # Image is not divisible by sx or sy so we need to move it into a new buffer
+        if (hsy * sy) < h or (wsx * sx) < w:
+            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
+            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
+        else:
+            idx_buffer = idx_buffer_view
+
+        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
+        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
+
+        # We're finished with these
+        del idx_buffer, idx_buffer_view
+
+        # rand_idx is currently dst|src, so split them
+        num_dst = hsy * wsx
         a_idx = rand_idx[:, num_dst:, :]  # src
         b_idx = rand_idx[:, :num_dst, :]  # dst

         def split(x):
             C = x.shape[-1]
-            src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
-            dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
+            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
+            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
             return src, dst

         # Cosine similarity between A and B
         metric = metric / metric.norm(dim=-1, keepdim=True)
         a, b = split(metric)
         scores = a @ b.transpose(-1, -2)
@@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         # Can't reduce more than the # tokens in src
         r = min(a.shape[1], r)

         # Find the most similar greedily
         node_max, node_idx = scores.max(dim=-1)
         edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

         unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
         src_idx = edge_idx[..., :r, :]  # Merged Tokens
-        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
+        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
-        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
-        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
+        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
+        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
         return torch.cat([unm, dst], dim=1)
@@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
         _, _, c = unm.shape

-        src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
+        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

         # Combine back to the original shape
         out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
         out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
-        out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
+        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

         return out
@@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
 def get_functions(x, ratio, original_shape):
     b, c, original_h, original_w = original_shape
     original_tokens = original_h * original_w
-    downsample = int(math.sqrt(original_tokens // x.shape[1]))
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
     stride_x = 2
     stride_y = 2
     max_downsample = 1

     if downsample <= max_downsample:
-        w = original_w // downsample
-        h = original_h // downsample
+        w = int(math.ceil(original_w / downsample))
+        h = int(math.ceil(original_h / downsample))
         r = int(x.shape[1] * ratio)
         no_rand = False
         m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)


@@ -199,11 +199,25 @@ def get_autocast_device(dev):
         return dev.type
     return "cuda"


 def xformers_enabled():
     if vram_state == CPU:
         return False
     return XFORMERS_IS_AVAILBLE


+def xformers_enabled_vae():
+    enabled = xformers_enabled()
+    if not enabled:
+        return False
+    try:
+        #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
+        if xformers.version.__version__ == "0.0.18":
+            return False
+    except:
+        pass
+    return enabled
+
+
 def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION


@@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         if 'strength' in cond[1]:
             strength = cond[1]['strength']

+        adm_cond = None
+        if 'adm' in cond[1]:
+            adm_cond = cond[1]['adm']
+
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
         mult = torch.ones_like(input_x) * strength
@@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 cropped.append(cr)
            conditionning['c_concat'] = torch.cat(cropped, dim=1)

+        if adm_cond is not None:
+            conditionning['c_adm'] = adm_cond
+
         control = None
         if 'control' in cond[1]:
             control = cond[1]['control']
@@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
            if 'c_concat' in c1:
                if c1['c_concat'].shape != c2['c_concat'].shape:
                    return False
+           if 'c_adm' in c1:
+               if c1['c_adm'].shape != c2['c_adm'].shape:
+                   return False
            return True

        def can_concat_cond(c1, c2):
@@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
        def cond_cat(c_list):
            c_crossattn = []
            c_concat = []
+           c_adm = []
            for x in c_list:
                if 'c_crossattn' in x:
                    c_crossattn.append(x['c_crossattn'])
                if 'c_concat' in x:
                    c_concat.append(x['c_concat'])
+               if 'c_adm' in x:
+                   c_adm.append(x['c_adm'])
            out = {}
            if len(c_crossattn) > 0:
                out['c_crossattn'] = [torch.cat(c_crossattn)]
            if len(c_concat) > 0:
                out['c_concat'] = [torch.cat(c_concat)]
+           if len(c_adm) > 0:
+               out['c_adm'] = torch.cat(c_adm)
            return out

        def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
@@ -327,6 +342,39 @@ def apply_control_net_to_equal_area(conds, uncond):
             n['control'] = cond_cnets[x]
             uncond[temp[1]] = [o[0], n]
+
+def encode_adm(noise_augmentor, conds, batch_size, device):
+    for t in range(len(conds)):
+        x = conds[t]
+        if 'adm' in x[1]:
+            adm_inputs = []
+            weights = []
+            noise_aug = []
+            adm_in = x[1]["adm"]
+            for adm_c in adm_in:
+                adm_cond = adm_c[0].image_embeds
+                weight = adm_c[1]
+                noise_augment = adm_c[2]
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                weights.append(weight)
+                noise_aug.append(noise_augment)
+                adm_inputs.append(adm_out)
+
+            if len(noise_aug) > 1:
+                adm_out = torch.stack(adm_inputs).sum(0)
+                #TODO: add a way to control this
+                noise_augment = 0.05
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1)
+        else:
+            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+        x[1] = x[1].copy()
+        x[1]["adm"] = torch.cat([adm_out] * batch_size)
+
+    return conds
+
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -422,10 +470,14 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext

+        if hasattr(self.model, 'noise_augmentor'): #unclip
+            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
+            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None
-        if hasattr(self.model, 'concat_keys'):
+        if hasattr(self.model, 'concat_keys'): #inpaint
             cond_concat = []
             for ck in self.model.concat_keys:
                 if denoise_mask is not None:
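For reference, a small sketch of the conditioning layout encode_adm expects: each cond entry is a [tensor, options] pair whose optional "adm" list holds (clip_vision_output, weight, noise_augment) tuples, with the CLIP output exposing image_embeds. The names below are hypothetical stand-ins, not values produced by this commit:

```python
# Hypothetical stand-ins illustrating the structure read by encode_adm above.
import torch
from types import SimpleNamespace

text_cond = torch.zeros(1, 77, 1024)                           # placeholder cross-attention cond
clip_out = SimpleNamespace(image_embeds=torch.zeros(1, 1024))  # mimics a CLIP vision output

# adm_c[0].image_embeds, adm_c[1] = weight, adm_c[2] = noise augmentation strength
positive = [[text_cond, {"adm": [(clip_out, 1.0, 0.1)]}]]
```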


@@ -12,20 +12,7 @@ from .cldm import cldm
 from .t2i_adapter import adapter
 from . import utils
+from . import clip_vision

-def load_torch_file(ckpt):
-    if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
-        sd = safetensors.torch.load_file(ckpt, device="cpu")
-    else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-        if "state_dict" in pl_sd:
-            sd = pl_sd["state_dict"]
-        else:
-            sd = pl_sd
-    return sd

 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
     m, u = model.load_state_dict(sd, strict=False)
@@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
         if x in sd:
             sd[keys_to_replace[x]] = sd.pop(x)

-    resblock_to_replace = {
-        "ln_1": "layer_norm1",
-        "ln_2": "layer_norm2",
-        "mlp.c_fc": "mlp.fc1",
-        "mlp.c_proj": "mlp.fc2",
-        "attn.out_proj": "self_attn.out_proj",
-    }
-
-    for resblock in range(24):
-        for x in resblock_to_replace:
-            for y in ["weight", "bias"]:
-                k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
-                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
-                if k in sd:
-                    sd[k_to] = sd.pop(k)
-
-        for y in ["weight", "bias"]:
-            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
-            if k_from in sd:
-                weights = sd.pop(k_from)
-                for x in range(3):
-                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
-                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
-                    sd[k_to] = weights[1024*x:1024*(x + 1)]
+    sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)

     for x in load_state_dict_to:
         x.load_state_dict(sd, strict=False)
@@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
 }

 def load_lora(path, to_load):
-    lora = load_torch_file(path)
+    lora = utils.load_torch_file(path)
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -599,7 +563,7 @@ class ControlNet:
         return out

 def load_controlnet(ckpt_path, model=None):
-    controlnet_data = load_torch_file(ckpt_path)
+    controlnet_data = utils.load_torch_file(ckpt_path)
     pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
     pth = False
     sd2 = False
@@ -793,7 +757,7 @@ class StyleModel:

 def load_style_model(ckpt_path):
-    model_data = load_torch_file(ckpt_path)
+    model_data = utils.load_torch_file(ckpt_path)
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@@ -804,7 +768,7 @@ def load_style_model(ckpt_path):

 def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = load_torch_file(ckpt_path)
+    clip_data = utils.load_torch_file(ckpt_path)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
         config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
             load_state_dict_to = [w]

     model = instantiate_from_config(config["model"])
-    sd = load_torch_file(ckpt_path)
+    sd = utils.load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

     if fp16:
@@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e

     return (ModelPatcher(model), clip, vae)

-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
-    sd = load_torch_file(ckpt_path)
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+    sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
+    clipvision = None
     vae = None
     fp16 = model_management.should_use_fp16()
@@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
             w.cond_stage_model = clip.cond_stage_model
             load_state_dict_to = [w]

+    clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
+    noise_aug_config = None
+    if clipvision_key in sd_keys:
+        size = sd[clipvision_key].shape[1]
+
+        if output_clipvision:
+            clipvision = clip_vision.load_clipvision_from_sd(sd)
+
+        noise_aug_key = "noise_augmentor.betas"
+        if noise_aug_key in sd_keys:
+            noise_aug_config = {}
+            params = {}
+            noise_schedule_config = {}
+            noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
+            noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
+            params["noise_schedule_config"] = noise_schedule_config
+            noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
+            if size == 1280: #h
+                params["timestep_dim"] = 1024
+            elif size == 1024: #l
+                params["timestep_dim"] = 768
+            noise_aug_config['params'] = params
+
     sd_config = {
         "linear_start": 0.00085,
         "linear_end": 0.012,
@@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
     model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}

-    if unet_config["in_channels"] > 4: #inpainting model
+    if noise_aug_config is not None: #SD2.x unclip model
+        sd_config["noise_aug_config"] = noise_aug_config
+        sd_config["image_size"] = 96
+        sd_config["embedding_dropout"] = 0.25
+        sd_config["conditioning_key"] = 'crossattn-adm'
+        model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
+    elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
         model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
@@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     else:
         unet_config["num_heads"] = 8 #SD1.x

+    unclip = 'model.diffusion_model.label_emb.0.0.weight'
+    if unclip in sd_keys:
+        unet_config["num_classes"] = "sequential"
+        unet_config["adm_in_channels"] = sd[unclip].shape[1]
+
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
@@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e

     if fp16:
         model = model.half()
-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model), clip, vae, clipvision)


@@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 if isinstance(y, int):
                     tokens_temp += [y]
                 else:
-                    embedding_weights += [y]
-                    tokens_temp += [next_new_token]
-                    next_new_token += 1
+                    if y.shape[0] == current_embeds.weight.shape[1]:
+                        embedding_weights += [y]
+                        tokens_temp += [next_new_token]
+                        next_new_token += 1
+                    else:
+                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             out_tokens += [tokens_temp]

         if len(embedding_weights) > 0:


@@ -1,5 +1,47 @@
 import torch

+def load_torch_file(ckpt):
+    if ckpt.lower().endswith(".safetensors"):
+        import safetensors.torch
+        sd = safetensors.torch.load_file(ckpt, device="cpu")
+    else:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+        if "global_step" in pl_sd:
+            print(f"Global Step: {pl_sd['global_step']}")
+        if "state_dict" in pl_sd:
+            sd = pl_sd["state_dict"]
+        else:
+            sd = pl_sd
+    return sd
+
+def transformers_convert(sd, prefix_from, prefix_to, number):
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(number):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+
+    return sd
+
 def common_upscale(samples, width, height, upscale_method, crop):
     if crop == "center":
         old_width = samples.shape[3]


@@ -1,32 +0,0 @@
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from comfy.sd import load_torch_file
import os

class ClipVisionModel():
    def __init__(self):
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
        config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModel(config)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
                                            do_normalize=True,
                                            do_resize=True,
                                            image_mean=[0.48145466, 0.4578275, 0.40821073],
                                            image_std=[0.26862954, 0.26130258, 0.27577711],
                                            resample=3, #bicubic
                                            size=224)

    def load_sd(self, sd):
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        inputs = self.processor(images=[image[0]], return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs

def load(ckpt_path):
    clip_data = load_torch_file(ckpt_path)
    clip = ClipVisionModel()
    clip.load_sd(clip_data)
    return clip


@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import comfy.utils
class Blend:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
return img1 * img2
elif mode == "screen":
return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
else:
raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
class Blur:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
class Quantize:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "colors": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 256,
                    "step": 1
                }),
                "dither": (["none", "floyd-steinberg"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "quantize"

    CATEGORY = "image/postprocessing"

    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
        batch_size, height, width, _ = image.shape
        result = torch.zeros_like(image)

        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE

        for b in range(batch_size):
            tensor_image = image[b]
            img = (tensor_image * 255).to(torch.uint8).numpy()
            pil_image = Image.fromarray(img, mode='RGB')

            palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)

            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
            result[b] = quantized_array

        return (result,)
class Sharpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 5.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sharpen"

    CATEGORY = "image/postprocessing"

    def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
        if sharpen_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = sharpen_radius * 2 + 1
        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
        center = kernel_size // 2
        kernel[center, center] = kernel_size**2
        kernel *= alpha
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)

        return (result,)

NODE_CLASS_MAPPINGS = {
    "ImageBlend": Blend,
    "ImageBlur": Blur,
    "ImageQuantize": Quantize,
    "ImageSharpen": Sharpen,
}
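As a quick sanity check of the tensor layout these nodes expect, they can be exercised directly in Python. A minimal sketch, assuming the module's own imports (torch, F, PIL, numpy) are available; the values are arbitrary:

    import torch

    # ComfyUI IMAGE tensors are (batch, height, width, channels) floats in [0, 1].
    img_a = torch.rand(1, 64, 64, 3)
    img_b = torch.rand(1, 64, 64, 3)

    (blended,) = Blend().blend_images(img_a, img_b, blend_factor=0.5, blend_mode="screen")
    (blurred,) = Blur().blur(blended, blur_radius=3, sigma=1.5)
    (sharpened,) = Sharpen().sharpen(blurred, sharpen_radius=1, alpha=1.0)

    print(sharpened.shape)  # torch.Size([1, 64, 64, 3])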


@@ -1,6 +1,5 @@
 import os
 from comfy_extras.chainner_models import model_loading
-from comfy.sd import load_torch_file
 import model_management
 import torch
 import comfy.utils
@@ -18,7 +17,7 @@ class UpscaleModelLoader:
     def load_model(self, model_name):
         model_path = folder_paths.get_full_path("upscale_models", model_name)
-        sd = load_torch_file(model_path)
+        sd = comfy.utils.load_torch_file(model_path)
         out = model_loading.load_state_dict(sd).eval()
         return (out, )


@@ -11,6 +11,8 @@ class Example:
     ----------
     RETURN_TYPES (`tuple`):
         The type of each element in the output tuple.
+    RETURN_NAMES (`tuple`):
+        Optional: The name of each output in the output tuple.
     FUNCTION (`str`):
         The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
     OUTPUT_NODE ([`bool`]):
@@ -61,6 +63,8 @@ class Example:
         }

     RETURN_TYPES = ("IMAGE",)
+    #RETURN_NAMES = ("image_output_name",)
     FUNCTION = "test"

     #OUTPUT_NODE = False


@@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

+output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+
+if not os.path.exists(input_directory):
+    os.makedirs(input_directory)
+
+def set_output_directory(output_dir):
+    global output_directory
+    output_directory = output_dir
+
+def get_output_directory():
+    global output_directory
+    return output_directory
+
+def get_temp_directory():
+    global temp_directory
+    return temp_directory
+
+def get_input_directory():
+    global input_directory
+    return input_directory
+
+#NOTE: used in http server so don't put folders that should not be accessed remotely
+def get_directory_by_type(type_name):
+    if type_name == "output":
+        return get_output_directory()
+    if type_name == "temp":
+        return get_temp_directory()
+    if type_name == "input":
+        return get_input_directory()
+    return None
+
 def add_model_folder_path(folder_name, full_folder_path):
     global folder_names_and_paths
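A minimal sketch of how the new helper is meant to be used; the resolve_view_path wrapper is illustrative (not part of this diff) and mirrors the commonpath check the /view route performs later in this commit:

    import os
    import folder_paths

    def resolve_view_path(type_name, filename):
        # Only "output", "temp" and "input" resolve to real folders; anything else returns None.
        base_dir = folder_paths.get_directory_by_type(type_name)
        if base_dir is None:
            return None
        full_path = os.path.abspath(os.path.join(base_dir, filename))
        # Reject paths that escape the chosen directory (same idea as the server's commonpath check).
        if os.path.commonpath((full_path, os.path.abspath(base_dir))) != os.path.abspath(base_dir):
            return None
        return full_path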

main.py

@@ -11,9 +11,15 @@ if os.name == "nt":
 if __name__ == "__main__":
     if '--help' in sys.argv:
+        print()
         print("Valid Command line Arguments:")
         print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
         print("\t--port 8188\t\t\tSet the listen port.")
+        print()
+        print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
+        print("\t--output-directory path/to/output\tSet the ComfyUI output directory.")
+        print()
+        print()
         print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
         print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
         print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
@@ -40,6 +46,7 @@ if __name__ == "__main__":
     except:
         pass

+    from nodes import init_custom_nodes
     import execution
     import server
     import folder_paths
@@ -98,6 +105,8 @@ if __name__ == "__main__":
     server = server.PromptServer(loop)
     q = execution.PromptQueue(server)

+    init_custom_nodes()
+    server.add_routes()
     hijack_progress(server)

     threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
@@ -113,7 +122,6 @@ if __name__ == "__main__":
     except:
         address = '127.0.0.1'

     dont_print = False
     if '--dont-print-server' in sys.argv:
         dont_print = True
@@ -127,6 +135,14 @@ if __name__ == "__main__":
         for i in indices:
             load_extra_path_config(sys.argv[i])

+    try:
+        output_dir = sys.argv[sys.argv.index('--output-directory') + 1]
+        output_dir = os.path.abspath(output_dir)
+        print("setting output directory to:", output_dir)
+        folder_paths.set_output_directory(output_dir)
+    except:
+        pass
+
     port = 8188
     try:
         p_index = sys.argv.index('--port')

@@ -18,7 +18,7 @@ import comfy.samplers
 import comfy.sd
 import comfy.utils
-import comfy_extras.clip_vision
+import comfy.clip_vision

 import model_management
 import importlib
@@ -197,7 +197,7 @@ class CheckpointLoader:
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
         config_path = folder_paths.get_full_path("configs", config_name)
@@ -219,6 +219,21 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out

+class unCLIPCheckpointLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                             }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders"
+
+    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return out
+
 class CLIPSetLastLayer:
     @classmethod
     def INPUT_TYPES(s):
@@ -370,7 +385,7 @@ class CLIPVisionLoader:
     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip_vision", clip_name)
-        clip_vision = comfy_extras.clip_vision.load(clip_path)
+        clip_vision = comfy.clip_vision.load(clip_path)
         return (clip_vision,)

 class CLIPVisionEncode:
@@ -382,7 +397,7 @@ class CLIPVisionEncode:
     RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
     FUNCTION = "encode"

-    CATEGORY = "conditioning/style_model"
+    CATEGORY = "conditioning"

     def encode(self, clip_vision, image):
         output = clip_vision.encode_image(image)
@@ -424,6 +439,33 @@ class StyleModelApply:
             c.append(n)
         return (c, )

+class unCLIPConditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_adm"
+
+    CATEGORY = "conditioning"
+
+    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
+        c = []
+        for t in conditioning:
+            o = t[1].copy()
+            x = (clip_vision_output, strength, noise_augmentation)
+            if "adm" in o:
+                o["adm"] = o["adm"][:] + [x]
+            else:
+                o["adm"] = [x]
+            n = [t[0], o]
+            c.append(n)
+        return (c, )
+
 class EmptyLatentImage:
     def __init__(self, device="cpu"):
         self.device = device
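For orientation, conditioning here is a list of [cond_tensor, options_dict] pairs, and apply_adm only touches the dict. A minimal self-contained sketch; the dummy placeholders stand in for the real tensors and CLIP vision output and are not part of the codebase:

    class _DummyVisionOutput:
        pass

    vision_out = _DummyVisionOutput()
    conditioning = [["cond_tensor_placeholder", {}]]  # normally produced by CLIPTextEncode

    (new_cond,) = unCLIPConditioning().apply_adm(conditioning, vision_out, strength=1.0, noise_augmentation=0.0)
    # Each entry now carries an "adm" list of (clip_vision_output, strength, noise_augmentation)
    # tuples for the sampler to consume.
    assert new_cond[0][1]["adm"] == [(vision_out, 1.0, 0.0)]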
@@ -735,7 +777,7 @@ class KSamplerAdvanced:

 class SaveImage:
     def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+        self.output_dir = folder_paths.get_output_directory()
         self.type = "output"

     @classmethod
@@ -787,9 +829,6 @@ class SaveImage:
             os.makedirs(full_output_folder, exist_ok=True)
             counter = 1

-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir)
-
         results = list()
         for image in images:
             i = 255. * image.cpu().numpy()
@@ -814,7 +853,7 @@ class SaveImage:

 class PreviewImage(SaveImage):
     def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+        self.output_dir = folder_paths.get_temp_directory()
         self.type = "temp"

     @classmethod
} }
class LoadImage: class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
if not os.path.exists(s.input_dir): input_dir = folder_paths.get_input_directory()
os.makedirs(s.input_dir)
return {"required": return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )}, {"image": (sorted(os.listdir(input_dir)), )},
} }
CATEGORY = "image" CATEGORY = "image"
@@ -839,7 +876,8 @@ class LoadImage:
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"
     def load_image(self, image):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         i = Image.open(image_path)
         image = i.convert("RGB")
         image = np.array(image).astype(np.float32) / 255.0
@@ -853,18 +891,19 @@ class LoadImage:
     @classmethod
     def IS_CHANGED(s, image):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()

 class LoadImageMask:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
     @classmethod
     def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
         return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), ),
+                    {"image": (sorted(os.listdir(input_dir)), ),
                      "channel": (["alpha", "red", "green", "blue"], ),}
                 }
@@ -873,7 +912,8 @@ class LoadImageMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "load_image"
     def load_image(self, image, channel):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         i = Image.open(image_path)
         mask = None
         c = channel[0].upper()
@@ -888,7 +928,8 @@ class LoadImageMask:
     @classmethod
     def IS_CHANGED(s, image, channel):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
@@ -996,7 +1037,6 @@ class ImagePadForOutpaint:

 NODE_CLASS_MAPPINGS = {
     "KSampler": KSampler,
-    "CheckpointLoader": CheckpointLoader,
     "CheckpointLoaderSimple": CheckpointLoaderSimple,
     "CLIPTextEncode": CLIPTextEncode,
     "CLIPSetLastLayer": CLIPSetLastLayer,
@@ -1025,6 +1065,7 @@ NODE_CLASS_MAPPINGS = {
     "CLIPLoader": CLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
+    "unCLIPConditioning": unCLIPConditioning,
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
@@ -1033,6 +1074,8 @@ NODE_CLASS_MAPPINGS = {
     "VAEDecodeTiled": VAEDecodeTiled,
     "VAEEncodeTiled": VAEEncodeTiled,
     "TomePatchModel": TomePatchModel,
+    "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+    "CheckpointLoader": CheckpointLoader,
 }

 def load_custom_node(module_path):
@@ -1067,6 +1110,7 @@ def load_custom_nodes():
         if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
         load_custom_node(module_path)

-load_custom_nodes()
+def init_custom_nodes():
+    load_custom_nodes()
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))


@@ -47,7 +47,7 @@
     " !git pull\n",
     "\n",
     "!echo -= Install dependencies =-\n",
-    "!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117"
+    "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
    ]
   },
   {


@@ -42,6 +42,7 @@ class PromptServer():
         self.web_root = os.path.join(os.path.dirname(
             os.path.realpath(__file__)), "web")
         routes = web.RouteTableDef()
+        self.routes = routes
         self.last_node_id = None
         self.client_id = None
@@ -88,7 +89,7 @@ class PromptServer():
         @routes.post("/upload/image")
         async def upload_image(request):
-            upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+            upload_dir = folder_paths.get_input_directory()

             if not os.path.exists(upload_dir):
                 os.makedirs(upload_dir)
@@ -121,10 +122,10 @@ class PromptServer():
         async def view_image(request):
             if "filename" in request.rel_url.query:
                 type = request.rel_url.query.get("type", "output")
-                if type not in ["output", "input", "temp"]:
+                output_dir = folder_paths.get_directory_by_type(type)
+                if output_dir is None:
                     return web.Response(status=400)
-                output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)

                 if "subfolder" in request.rel_url.query:
                     full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                     if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
@@ -239,8 +240,9 @@ class PromptServer():
                 self.prompt_queue.delete_history_item(id_to_delete)
                 return web.Response(status=200)

-        self.app.add_routes(routes)
+    def add_routes(self):
+        self.app.add_routes(self.routes)
         self.app.add_routes([
             web.static('/', self.web_root),
         ])
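Exposing self.routes and deferring add_routes() until after init_custom_nodes() means custom node packages can register their own endpoints before the route table is attached to the app. A rough sketch; how an extension obtains the PromptServer object is outside this diff, so the helper and route name below are purely illustrative:

    from aiohttp import web

    def register_extension_routes(prompt_server):
        # Attach an extra endpoint to the shared RouteTableDef; it becomes live
        # once main.py calls server.add_routes().
        @prompt_server.routes.get("/my_extension/ping")
        async def ping(request):
            return web.json_response({"status": "ok"})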


@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
}
return ctx;
};
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});


@@ -30,7 +30,8 @@ app.registerExtension({
            }

            // Overwrite the value in the serialized workflow pnginfo
-           workflowNode.widgets_values[widgetIndex] = prompt;
+           if (workflowNode?.widgets_values)
+               workflowNode.widgets_values[widgetIndex] = prompt;

            return prompt;
        };


@@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
 // Inverts the scrolling of context menus

 const id = "Comfy.InvertMenuScrolling";
-const ctxMenu = LiteGraph.ContextMenu;
 app.registerExtension({
    name: id,
    init() {
+       const ctxMenu = LiteGraph.ContextMenu;
        const replace = () => {
            LiteGraph.ContextMenu = function (values, options) {
                options = options || {};


@ -11,11 +11,14 @@ app.registerExtension({
this.properties = {}; this.properties = {};
} }
this.properties.showOutputText = RerouteNode.defaultVisibility; this.properties.showOutputText = RerouteNode.defaultVisibility;
this.properties.horizontal = false;
this.addInput("", "*"); this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onConnectionsChange = function (type, index, connected, link_info) { this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
// Prevent multiple connections to different types when we have no input // Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) { if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types // Ignore wildcard nodes as these will be updated to real types
@ -43,12 +46,19 @@ app.registerExtension({
const node = app.graph.getNodeById(link.origin_id); const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type; const type = node.constructor.type;
if (type === "Reroute") { if (type === "Reroute") {
// Move the previous node if (node === this) {
currentNode = node; // We've found a circle
currentNode.disconnectInput(link.target_slot);
currentNode = null;
}
else {
// Move the previous node
currentNode = node;
}
} else { } else {
// We've found the end // We've found the end
inputNode = currentNode; inputNode = currentNode;
inputType = node.outputs[link.origin_slot].type; inputType = node.outputs[link.origin_slot]?.type ?? null;
break; break;
} }
} else { } else {
@ -80,7 +90,7 @@ app.registerExtension({
updateNodes.push(node); updateNodes.push(node);
} else { } else {
// We've found an output // We've found an output
const nodeOutType = node.inputs[link.target_slot].type; const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) { if (inputType && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it // The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot); node.disconnectInput(link.target_slot);
@ -105,6 +115,7 @@ app.registerExtension({
node.__outputType = displayType; node.__outputType = displayType;
node.outputs[0].name = node.properties.showOutputText ? displayType : ""; node.outputs[0].name = node.properties.showOutputText ? displayType : "";
node.size = node.computeSize(); node.size = node.computeSize();
node.applyOrientation();
for (const l of node.outputs[0].links || []) { for (const l of node.outputs[0].links || []) {
const link = app.graph.links[l]; const link = app.graph.links[l];
@ -146,6 +157,7 @@ app.registerExtension({
this.outputs[0].name = ""; this.outputs[0].name = "";
} }
this.size = this.computeSize(); this.size = this.computeSize();
this.applyOrientation();
app.graph.setDirtyCanvas(true, true); app.graph.setDirtyCanvas(true, true);
}, },
}, },
@ -154,9 +166,32 @@ app.registerExtension({
callback: () => { callback: () => {
RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility); RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility);
}, },
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
this.properties.horizontal = !this.properties.horizontal;
this.applyOrientation();
},
} }
); );
} }
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];
} else {
delete this.inputs[0].pos;
}
app.graph.setDirtyCanvas(true, true);
}
computeSize() { computeSize() {
return [ return [


@ -0,0 +1,21 @@
import { app } from "/scripts/app.js";
// Adds defaults for quickly adding nodes with middle click on the input/output
app.registerExtension({
name: "Comfy.SlotDefaults",
init() {
LiteGraph.middle_click_slot_add_default_node = true;
LiteGraph.slot_types_default_in = {
MODEL: "CheckpointLoaderSimple",
LATENT: "EmptyLatentImage",
VAE: "VAELoader",
};
LiteGraph.slot_types_default_out = {
LATENT: "VAEDecode",
IMAGE: "SaveImage",
CLIP: "CLIPTextEncode",
};
},
});


@ -0,0 +1,89 @@
import { app } from "/scripts/app.js";
// Shift + drag/resize to snap to grid
app.registerExtension({
name: "Comfy.SnapToGrid",
init() {
// Add setting to control grid size
app.ui.settings.addSetting({
id: "Comfy.SnapToGrid.GridSize",
name: "Grid Size",
type: "number",
attrs: {
min: 1,
max: 500,
},
tooltip:
"When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
defaultValue: LiteGraph.CANVAS_GRID_SIZE,
onChange(value) {
LiteGraph.CANVAS_GRID_SIZE = +value;
},
});
// After moving a node, if the shift key is down align it to grid
const onNodeMoved = app.canvas.onNodeMoved;
app.canvas.onNodeMoved = function (node) {
const r = onNodeMoved?.apply(this, arguments);
if (app.shiftDown) {
// Ensure all selected nodes are realigned
for (const id in this.selected_nodes) {
this.selected_nodes[id].alignToGrid();
}
}
return r;
};
// When a node is added, add a resize handler to it so we can fix align the size with the grid
const onNodeAdded = app.graph.onNodeAdded;
app.graph.onNodeAdded = function (node) {
const onResize = node.onResize;
node.onResize = function () {
if (app.shiftDown) {
const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE);
const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE);
node.size[0] = w;
node.size[1] = h;
}
return onResize?.apply(this, arguments);
};
return onNodeAdded?.apply(this, arguments);
};
// Draw a preview of where the node will go if holding shift and the node is selected
const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) {
if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) {
const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE);
const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE);
const shiftX = x - node.pos[0];
let shiftY = y - node.pos[1];
let w, h;
if (node.flags.collapsed) {
w = node._collapsed_width;
h = LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
} else {
w = node.size[0];
h = node.size[1];
let titleMode = node.constructor.title_mode;
if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) {
h += LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
}
}
const f = ctx.fillStyle;
ctx.fillStyle = "rgba(100, 100, 100, 0.5)";
ctx.fillRect(shiftX, shiftY, w, h);
ctx.fillStyle = f;
}
return origDrawNode.apply(this, arguments);
};
},
});


@@ -25,7 +25,7 @@ function hideWidget(node, widget, suffix = "") {
        if (link == null) {
            return undefined;
        }
-       return widget.value;
+       return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
    };

    // Hide any linked widgets, e.g. seed+seedControl


@ -5,10 +5,20 @@ import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js";
class ComfyApp { class ComfyApp {
/**
* List of {number, batchCount} entries to queue
*/
#queueItems = [];
/**
* If the queue is currently being processed
*/
#processingQueue = false;
constructor() { constructor() {
this.ui = new ComfyUI(this); this.ui = new ComfyUI(this);
this.extensions = []; this.extensions = [];
this.nodeOutputs = {}; this.nodeOutputs = {};
this.shiftDown = false;
} }
/** /**
@ -102,6 +112,46 @@ class ComfyApp {
}; };
} }
#addNodeKeyHandler(node) {
const app = this;
const origNodeOnKeyDown = node.prototype.onKeyDown;
node.prototype.onKeyDown = function(e) {
if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
return false;
}
if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
return;
}
let handled = false;
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
if (e.key === "ArrowLeft") {
this.imageIndex -= 1;
} else if (e.key === "ArrowRight") {
this.imageIndex += 1;
}
this.imageIndex %= this.imgs.length;
if (this.imageIndex < 0) {
this.imageIndex = this.imgs.length + this.imageIndex;
}
handled = true;
} else if (e.key === "Escape") {
this.imageIndex = null;
handled = true;
}
if (handled === true) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
}
}
/** /**
* Adds Custom drawing logic for nodes * Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images * e.g. Draws images and handles thumbnail navigation on nodes that output images
@ -628,11 +678,16 @@ class ComfyApp {
#addKeyboardHandler() { #addKeyboardHandler() {
window.addEventListener("keydown", (e) => { window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
// Queue prompt using ctrl or command + enter // Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0); this.queuePrompt(e.shiftKey ? -1 : 0);
} }
}); });
window.addEventListener("keyup", (e) => {
this.shiftDown = e.shiftKey;
});
} }
/** /**
@ -667,6 +722,9 @@ class ComfyApp {
const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph)); const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph));
this.ctx = canvasEl.getContext("2d"); this.ctx = canvasEl.getContext("2d");
LiteGraph.release_link_on_empty_shows_menu = true;
LiteGraph.alt_drag_do_clone_nodes = true;
this.graph.start(); this.graph.start();
function resizeCanvas() { function resizeCanvas() {
@ -785,6 +843,7 @@ class ComfyApp {
this.#addNodeContextMenuHandler(node); this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app); this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node); LiteGraph.registerNodeType(nodeId, node);
@ -919,31 +978,47 @@ class ComfyApp {
} }
async queuePrompt(number, batchCount = 1) { async queuePrompt(number, batchCount = 1) {
for (let i = 0; i < batchCount; i++) { this.#queueItems.push({ number, batchCount });
const p = await this.graphToPrompt();
try { // Only have one action process the items so each one gets a unique seed correctly
await api.queuePrompt(number, p); if (this.#processingQueue) {
} catch (error) { return;
this.ui.dialog.show(error.response || error.toString()); }
return;
} this.#processingQueue = true;
try {
while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop());
for (const n of p.workflow.nodes) { for (let i = 0; i < batchCount; i++) {
const node = graph.getNodeById(n.id); const p = await this.graphToPrompt();
if (node.widgets) {
for (const widget of node.widgets) { try {
// Allow widgets to run callbacks after a prompt has been queued await api.queuePrompt(number, p);
// e.g. random seed after every gen } catch (error) {
if (widget.afterQueued) { this.ui.dialog.show(error.response || error.toString());
widget.afterQueued(); break;
}
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks after a prompt has been queued
// e.g. random seed after every gen
if (widget.afterQueued) {
widget.afterQueued();
}
}
} }
} }
this.canvas.draw(true, true);
await this.ui.queue.update();
} }
} }
} finally {
this.canvas.draw(true, true); this.#processingQueue = false;
await this.ui.queue.update();
} }
} }


@ -35,21 +35,86 @@ export function $el(tag, propsOrChildren, children) {
return element; return element;
} }
function dragElement(dragEl) { function dragElement(dragEl, settings) {
var posDiffX = 0, var posDiffX = 0,
posDiffY = 0, posDiffY = 0,
posStartX = 0, posStartX = 0,
posStartY = 0, posStartY = 0,
newPosX = 0, newPosX = 0,
newPosY = 0; newPosY = 0;
if (dragEl.getElementsByClassName('drag-handle')[0]) { if (dragEl.getElementsByClassName("drag-handle")[0]) {
// if present, the handle is where you move the DIV from: // if present, the handle is where you move the DIV from:
dragEl.getElementsByClassName('drag-handle')[0].onmousedown = dragMouseDown; dragEl.getElementsByClassName("drag-handle")[0].onmousedown = dragMouseDown;
} else { } else {
// otherwise, move the DIV from anywhere inside the DIV: // otherwise, move the DIV from anywhere inside the DIV:
dragEl.onmousedown = dragMouseDown; dragEl.onmousedown = dragMouseDown;
} }
// When the element resizes (e.g. view queue) ensure it is still in the windows bounds
const resizeObserver = new ResizeObserver(() => {
ensureInBounds();
}).observe(dragEl);
function ensureInBounds() {
if (dragEl.classList.contains("comfy-menu-manual-pos")) {
newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft));
newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop));
positionElement();
}
}
function positionElement() {
const halfWidth = document.body.clientWidth / 2;
const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth;
// set the element's new position:
if (anchorRight) {
dragEl.style.left = "unset";
dragEl.style.right = document.body.clientWidth - newPosX - dragEl.clientWidth + "px";
} else {
dragEl.style.left = newPosX + "px";
dragEl.style.right = "unset";
}
dragEl.style.top = newPosY + "px";
dragEl.style.bottom = "unset";
if (savePos) {
localStorage.setItem(
"Comfy.MenuPosition",
JSON.stringify({
x: dragEl.offsetLeft,
y: dragEl.offsetTop,
})
);
}
}
function restorePos() {
let pos = localStorage.getItem("Comfy.MenuPosition");
if (pos) {
pos = JSON.parse(pos);
newPosX = pos.x;
newPosY = pos.y;
positionElement();
ensureInBounds();
}
}
let savePos = undefined;
settings.addSetting({
id: "Comfy.MenuPosition",
name: "Save menu position",
type: "boolean",
defaultValue: savePos,
onChange(value) {
if (savePos === undefined && value) {
restorePos();
}
savePos = value;
},
});
function dragMouseDown(e) { function dragMouseDown(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
@ -64,18 +129,25 @@ function dragElement(dragEl) {
function elementDrag(e) { function elementDrag(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
dragEl.classList.add("comfy-menu-manual-pos");
// calculate the new cursor position: // calculate the new cursor position:
posDiffX = e.clientX - posStartX; posDiffX = e.clientX - posStartX;
posDiffY = e.clientY - posStartY; posDiffY = e.clientY - posStartY;
posStartX = e.clientX; posStartX = e.clientX;
posStartY = e.clientY; posStartY = e.clientY;
newPosX = Math.min((document.body.clientWidth - dragEl.clientWidth), Math.max(0, (dragEl.offsetLeft + posDiffX)));
newPosY = Math.min((document.body.clientHeight - dragEl.clientHeight), Math.max(0, (dragEl.offsetTop + posDiffY))); newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft + posDiffX));
// set the element's new position: newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop + posDiffY));
dragEl.style.top = newPosY + "px";
dragEl.style.left = newPosX + "px"; positionElement();
} }
window.addEventListener("resize", () => {
ensureInBounds();
});
function closeDragElement() { function closeDragElement() {
// stop moving when mouse button is released: // stop moving when mouse button is released:
document.onmouseup = null; document.onmouseup = null;
@@ -90,7 +162,7 @@ class ComfyDialog {
            $el("p", { $: (p) => (this.textElement = p) }),
            $el("button", {
                type: "button",
-               textContent: "CLOSE",
+               textContent: "Close",
                onclick: () => this.close(),
            }),
        ]),
@@ -125,7 +197,7 @@ class ComfySettingsDialog extends ComfyDialog {
        localStorage[settingId] = JSON.stringify(value);
    }

-   addSetting({ id, name, type, defaultValue, onChange }) {
+   addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) {
        if (!id) {
            throw new Error("Settings must have an ID");
        }
@ -152,42 +224,83 @@ class ComfySettingsDialog extends ComfyDialog {
value = v; value = v;
}; };
let element;
value = this.getSettingValue(id, defaultValue);
if (typeof type === "function") { if (typeof type === "function") {
return type(name, setter, value); element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
...attrs
}),
]),
]);
break;
case "number":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type,
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
default:
console.warn("Unsupported setting type, defaulting to text");
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
}
}
if(tooltip) {
element.title = tooltip;
} }
switch (type) { return element;
case "boolean":
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
}),
]),
]);
default:
console.warn("Unsupported setting type, defaulting to text");
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
}),
]),
]);
}
}, },
}); });
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
} }
show() { show() {
super.show(); super.show();
Object.assign(this.textElement.style, {
display: "flex",
flexDirection: "column",
gap: "10px"
});
this.textElement.replaceChildren(...this.settings.map((s) => s.render())); this.textElement.replaceChildren(...this.settings.map((s) => s.render()));
} }
} }
@ -300,6 +413,13 @@ export class ComfyUI {
this.history.update(); this.history.update();
}); });
const confirmClear = this.settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", { const fileInput = $el("input", {
type: "file", type: "file",
accept: ".json,image/png", accept: ".json,image/png",
@ -311,39 +431,57 @@ export class ComfyUI {
}); });
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
$el("span.drag-handle"), $el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }), $el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
]), ]),
$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }), $el("button.comfy-queue-btn", {
textContent: "Queue Prompt",
onclick: () => app.queuePrompt(0, this.batchCount),
}),
$el("div", {}, [ $el("div", {}, [
$el("label", { innerHTML: "Extra options"}, [ $el("label", { innerHTML: "Extra options" }, [
$el("input", { type: "checkbox", $el("input", {
onchange: (i) => { type: "checkbox",
document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none"; onchange: (i) => {
this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1; document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
document.getElementById('autoQueueCheckbox').checked = false; this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
} document.getElementById("autoQueueCheckbox").checked = false;
}) },
])
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById('batchCountInputRange').value = this.batchCount;
}
}), }),
$el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount, ]),
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", {
id: "batchCountInputNumber",
type: "number",
value: this.batchCount,
min: "1",
style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById("batchCountInputRange").value = this.batchCount;
},
}),
$el("input", {
id: "batchCountInputRange",
type: "range",
min: "1",
max: "100",
value: this.batchCount,
oninput: (i) => { oninput: (i) => {
this.batchCount = i.srcElement.value; this.batchCount = i.srcElement.value;
document.getElementById('batchCountInputNumber').value = i.srcElement.value; document.getElementById("batchCountInputNumber").value = i.srcElement.value;
} },
}),
$el("input", {
id: "autoQueueCheckbox",
type: "checkbox",
checked: false,
title: "automatically queue prompt when the queue size hits 0",
}), }),
$el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "automatically queue prompt when the queue size hits 0",
})
]), ]),
]), ]),
$el("div.comfy-menu-btns", [ $el("div.comfy-menu-btns", [
@ -389,13 +527,19 @@ export class ComfyUI {
$el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => { $el("button", { textContent: "Clear", onclick: () => {
app.clean(); if (!confirmClear.value || confirm("Clear workflow?")) {
app.graph.clear(); app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
}}), }}),
$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),
]); ]);
dragElement(this.menuContainer); dragElement(this.menuContainer, this.settings);
this.setStatus({ exec_info: { queue_remaining: "X" } }); this.setStatus({ exec_info: { queue_remaining: "X" } });
} }
@ -403,10 +547,14 @@ export class ComfyUI {
setStatus(status) { setStatus(status) {
this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR"); this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR");
if (status) { if (status) {
if (this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && document.getElementById('autoQueueCheckbox').checked) { if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked
) {
app.queuePrompt(0, this.batchCount); app.queuePrompt(0, this.batchCount);
} }
this.lastQueueSize = status.exec_info.queue_remaining this.lastQueueSize = status.exec_info.queue_remaining;
} }
} }
} }


@ -39,18 +39,19 @@ body {
position: fixed; /* Stay in place */ position: fixed; /* Stay in place */
z-index: 100; /* Sit on top */ z-index: 100; /* Sit on top */
padding: 30px 30px 10px 30px; padding: 30px 30px 10px 30px;
background-color: #ff0000; /* Modal background */ background-color: #353535; /* Modal background */
color: #ff4444;
box-shadow: 0px 0px 20px #888888; box-shadow: 0px 0px 20px #888888;
border-radius: 10px; border-radius: 10px;
text-align: center;
top: 50%; top: 50%;
left: 50%; left: 50%;
max-width: 80vw; max-width: 80vw;
max-height: 80vh; max-height: 80vh;
transform: translate(-50%, -50%); transform: translate(-50%, -50%);
overflow: hidden; overflow: hidden;
min-width: 60%;
justify-content: center; justify-content: center;
font-family: monospace;
font-size: 15px;
} }
.comfy-modal-content { .comfy-modal-content {
@ -70,31 +71,11 @@ body {
margin: 3px 3px 3px 4px; margin: 3px 3px 3px 4px;
} }
.comfy-modal button {
cursor: pointer;
color: #aaaaaa;
border: none;
background-color: transparent;
font-size: 24px;
font-weight: bold;
width: 100%;
}
.comfy-modal button:hover,
.comfy-modal button:focus {
color: #000;
text-decoration: none;
cursor: pointer;
}
.comfy-menu { .comfy-menu {
width: 200px;
font-size: 15px; font-size: 15px;
position: absolute; position: absolute;
top: 50%; top: 50%;
right: 0%; right: 0%;
background-color: white;
color: #000;
text-align: center; text-align: center;
z-index: 100; z-index: 100;
width: 170px; width: 170px;
@ -109,7 +90,8 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
} }
.comfy-menu button { .comfy-menu button,
.comfy-modal button {
font-size: 20px; font-size: 20px;
} }
@ -130,7 +112,8 @@ body {
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button { .comfy-menu .comfy-list button,
.comfy-modal button{
color: #ddd; color: #ddd;
background-color: #222; background-color: #222;
border-radius: 8px; border-radius: 8px;
@ -220,11 +203,22 @@ button.comfy-queue-btn {
} }
.comfy-modal.comfy-settings { .comfy-modal.comfy-settings {
background-color: var(--bg-color); text-align: center;
color: var(--fg-color); font-family: sans-serif;
color: #999;
z-index: 99; z-index: 99;
} }
.comfy-modal input,
.comfy-modal select {
color: #ddd;
background-color: #222;
border-radius: 8px;
border-color: #4e4e4e;
border-style: solid;
font-size: inherit;
}
@media only screen and (max-height: 850px) { @media only screen and (max-height: 850px) {
.comfy-menu { .comfy-menu {
top: 0 !important; top: 0 !important;
@ -237,3 +231,28 @@ button.comfy-queue-btn {
visibility:hidden visibility:hidden
} }
} }
.graphdialog {
min-height: 1em;
}
.graphdialog .name {
font-size: 14px;
font-family: sans-serif;
color: #999999;
}
.graphdialog button {
margin-top: unset;
vertical-align: unset;
height: 1.6em;
padding-right: 8px;
}
.graphdialog input, .graphdialog textarea, .graphdialog select {
background-color: #222;
border: 2px solid;
border-color: #444444;
color: #ddd;
border-radius: 12px 0 0 12px;
}