Merge branch 'comfyanonymous:master' into seedControls

This commit is contained in:
FizzleDorf 2023-04-06 11:32:49 -04:00 committed by GitHub
commit 77a4e42fa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1612 additions and 286 deletions

View File

@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.

62
comfy/clip_vision.py Normal file
View File

@ -0,0 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
from .utils import load_torch_file, transformers_convert
import os
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
self.model = CLIPVisionModelWithProjection(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)
def load_sd(self, sd):
self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
inputs = self.processor(images=[image[0]], return_tensors="pt")
outputs = self.model(**inputs)
return outputs
def convert_to_transformers(sd):
sd_k = sd.keys()
if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
keys_to_replace = {
"embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
"embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
"embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
"embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
"embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
"embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
"embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
}
for x in keys_to_replace:
if x in sd_k:
sd[keys_to_replace[x]] = sd.pop(x)
if "embedder.model.visual.proj" in sd_k:
sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)
sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
return sd
def load_clipvision_from_sd(sd):
sd = convert_to_transformers(sd)
if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
clip = ClipVisionModel(json_config)
clip.load_sd(sd)
return clip
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_clipvision_from_sd(sd)

View File

@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1280,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 32,
"patch_size": 14,
"projection_dim": 1024,
"torch_dtype": "float32"
}

View File

@ -1,8 +1,4 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPVisionModel"
],
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
@ -18,6 +14,5 @@
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0"
"torch_dtype": "float32"
}

View File

@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
return log
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embed_key = embedding_key
self.embedding_dropout = embedding_dropout
# self._init_embedder(embedder_config, freeze_embedder)
self._init_noise_aug(noise_aug_config)
def _init_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config)
if freeze:
self.embedder = embedder.eval()
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
def _init_noise_aug(self, config):
if config is not None:
# use the KARLO schedule for noise augmentation on CLIP image embeddings
noise_augmentor = instantiate_from_config(config)
assert isinstance(noise_augmentor, nn.Module)
noise_augmentor = noise_augmentor.eval()
noise_augmentor.train = disabled_train
self.noise_augmentor = noise_augmentor
else:
self.noise_augmentor = None
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
z, c = outputs[0], outputs[1]
img = batch[self.embed_key][:bs]
img = rearrange(img, 'b h w c -> b c h w')
c_adm = self.embedder(img)
if self.noise_augmentor is not None:
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
# assume this gives embeddings of noise levels
c_adm = torch.cat((c_adm, noise_level_emb), 1)
if self.training:
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
device=c_adm.device)[:, None]) * c_adm
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
noutputs = [z, all_conds]
noutputs.extend(outputs[2:])
return noutputs
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, **kwargs):
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
return_original_cond=True)
log["inputs"] = x
log["reconstruction"] = xrec
assert self.model.conditioning_key is not None
assert self.cond_stage_key in ["caption", "txt"]
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
with ema_scope(f"Sampling"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_, )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log

View File

@ -307,7 +307,16 @@ def model_wrapper(
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
c_in = dict()
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
else:
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
else:
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)

View File

@ -3,7 +3,6 @@ import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
@ -51,12 +50,20 @@ class DPMSolverSampler(object):
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
@ -83,6 +90,7 @@ class DPMSolverSampler(object):
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None
return x.to(device), None

View File

@ -9,7 +9,7 @@ from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
import model_management
if model_management.xformers_enabled():
if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if model_management.xformers_enabled() and attn_type == "vanilla":
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"

View File

@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
return count_flops_attn(model, _x, y)
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
return timestep_embedding(t, self.dim)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
@ -470,6 +479,7 @@ class UNetModel(nn.Module):
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
):
super().__init__()
if use_spatial_transformer:
@ -538,6 +548,15 @@ class UNetModel(nn.Module):
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()

View File

@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "squaredcos_cap_v2": # used for karlo prior
# return early
return betas_for_alpha_bar(
n_timestep,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@ -267,4 +275,4 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
return repeat_noise() if repeat else noise()

View File

@ -0,0 +1,59 @@
from typing import List, Tuple, Union
import torch
import torch.nn as nn
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
r"""Normalize an image/video tensor with mean and standard deviation.
.. math::
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
Args:
data: Image tensor of size :math:`(B, C, *)`.
mean: Mean for each channel.
std: Standard deviations for each channel.
Return:
Normalised tensor with same size as input :math:`(B, C, *)`.
Examples:
>>> x = torch.rand(1, 4, 3, 3)
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(4)
>>> std = 255. * torch.ones(4)
>>> out = normalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3])
"""
shape = data.shape
if len(mean.shape) == 0 or mean.shape[0] == 1:
mean = mean.expand(shape[1])
if len(std.shape) == 0 or std.shape[0] == 1:
std = std.expand(shape[1])
# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
if mean.shape:
mean = mean[..., :, None]
if std.shape:
std = std[..., :, None]
out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std
return out.view(shape)

View File

@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from . import kornia_functions
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
@ -57,18 +58,20 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
return self(text)
class ClipImageEmbedder(nn.Module):
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=True,
ucg_rate=0.
):
super().__init__()
from clip import load as load_clip
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
# x = kornia_functions.geometry_resize(x, (224, 224),
# interpolation='bicubic', align_corners=True,
# antialias=self.antialias)
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
x = (x + 1.) / 2.
# re-normalize according to clip
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
#"pooled",
# "pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
pretrained=version, )
del model.transformer
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
# x = kornia.geometry.resize(x, (224, 224),
# interpolation='bicubic', align_corners=True,
# antialias=self.antialias)
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
clip_max_length=77, t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
def encode(self, text):
return self(text)
@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]

View File

@ -0,0 +1,35 @@
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
import torch
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
super().__init__(*args, **kwargs)
if clip_stats_path is None:
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
else:
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim)
def scale(self, x):
# re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std
return x
def unscale(self, x):
# back to original data stats
x = (x * self.data_std) + self.data_mean
return x
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
x = self.scale(x)
z = self.q_sample(x, noise_level)
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level

View File

@ -1,4 +1,4 @@
#Taken from: https://github.com/dbolya/tomesd
import torch
from typing import Tuple, Callable
@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
else:
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
@ -28,33 +38,49 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if r <= 0:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
if no_rand:
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
rand_idx = idx_buffer.argsort(dim=1)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
num_dst = int((1 / (sx*sy)) * N)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out
@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.sqrt(original_tokens // x.shape[1]))
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2
stride_y = 2
max_downsample = 1
if downsample <= max_downsample:
w = original_w // downsample
h = original_h // downsample
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio)
no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)

View File

@ -199,11 +199,25 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def xformers_enabled():
if vram_state == CPU:
return False
return XFORMERS_IS_AVAILBLE
def xformers_enabled_vae():
enabled = xformers_enabled()
if not enabled:
return False
try:
#0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
if xformers.version.__version__ == "0.0.18":
return False
except:
pass
return enabled
def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION

View File

@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'strength' in cond[1]:
strength = cond[1]['strength']
adm_cond = None
if 'adm' in cond[1]:
adm_cond = cond[1]['adm']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength
@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
control = None
if 'control' in cond[1]:
control = cond[1]['control']
@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
return False
return True
def can_concat_cond(c1, c2):
@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
@ -327,6 +342,39 @@ def apply_control_net_to_equal_area(conds, uncond):
n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm"] = torch.cat([adm_out] * batch_size)
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@ -422,10 +470,14 @@ class KSampler:
else:
precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
cond_concat = None
if hasattr(self.model, 'concat_keys'):
if hasattr(self.model, 'concat_keys'): #inpaint
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None:

View File

@ -12,20 +12,7 @@ from .cldm import cldm
from .t2i_adapter import adapter
from . import utils
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
return sd
from . import clip_vision
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False)
@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if x in sd:
sd[keys_to_replace[x]] = sd.pop(x)
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
"mlp.c_fc": "mlp.fc1",
"mlp.c_proj": "mlp.fc2",
"attn.out_proj": "self_attn.out_proj",
}
for resblock in range(24):
for x in resblock_to_replace:
for y in ["weight", "bias"]:
k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
if k in sd:
sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]:
k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
if k_from in sd:
weights = sd.pop(k_from)
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
sd[k_to] = weights[1024*x:1024*(x + 1)]
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
for x in load_state_dict_to:
x.load_state_dict(sd, strict=False)
@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
}
def load_lora(path, to_load):
lora = load_torch_file(path)
lora = utils.load_torch_file(path)
patch_dict = {}
loaded_keys = set()
for x in to_load:
@ -599,7 +563,7 @@ class ControlNet:
return out
def load_controlnet(ckpt_path, model=None):
controlnet_data = load_torch_file(ckpt_path)
controlnet_data = utils.load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False
sd2 = False
@ -793,7 +757,7 @@ class StyleModel:
def load_style_model(ckpt_path):
model_data = load_torch_file(ckpt_path)
model_data = utils.load_torch_file(ckpt_path)
keys = model_data.keys()
if "style_embedding" in keys:
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@ -804,7 +768,7 @@ def load_style_model(ckpt_path):
def load_clip(ckpt_path, embedding_directory=None):
clip_data = load_torch_file(ckpt_path)
clip_data = utils.load_torch_file(ckpt_path)
config = {}
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
sd = load_torch_file(ckpt_path)
sd = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
return (ModelPatcher(model), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
sd = load_torch_file(ckpt_path)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
clip = None
clipvision = None
vae = None
fp16 = model_management.should_use_fp16()
@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
noise_aug_config = None
if clipvision_key in sd_keys:
size = sd[clipvision_key].shape[1]
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd)
noise_aug_key = "noise_augmentor.betas"
if noise_aug_key in sd_keys:
noise_aug_config = {}
params = {}
noise_schedule_config = {}
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
params["noise_schedule_config"] = noise_schedule_config
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
if size == 1280: #h
params["timestep_dim"] = 1024
elif size == 1024: #l
params["timestep_dim"] = 768
noise_aug_config['params'] = params
sd_config = {
"linear_start": 0.00085,
"linear_end": 0.012,
@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
if unet_config["in_channels"] > 4: #inpainting model
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
else:
unet_config["num_heads"] = 8 #SD1.x
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k]
@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
return (ModelPatcher(model), clip, vae, clipvision)

View File

@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if isinstance(y, int):
tokens_temp += [y]
else:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
if y.shape[0] == current_embeds.weight.shape[1]:
embedding_weights += [y]
tokens_temp += [next_new_token]
next_new_token += 1
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
out_tokens += [tokens_temp]
if len(embedding_weights) > 0:

View File

@ -1,5 +1,47 @@
import torch
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
return sd
def transformers_convert(sd, prefix_from, prefix_to, number):
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
"mlp.c_fc": "mlp.fc1",
"mlp.c_proj": "mlp.fc2",
"attn.out_proj": "self_attn.out_proj",
}
for resblock in range(number):
for x in resblock_to_replace:
for y in ["weight", "bias"]:
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
if k in sd:
sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]:
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
if k_from in sd:
weights = sd.pop(k_from)
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]

View File

@ -1,32 +0,0 @@
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from comfy.sd import load_torch_file
import os
class ClipVisionModel():
def __init__(self):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
config = CLIPVisionConfig.from_json_file(json_config)
self.model = CLIPVisionModel(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)
def load_sd(self, sd):
self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
inputs = self.processor(images=[image[0]], return_tensors="pt")
outputs = self.model(**inputs)
return outputs
def load(ckpt_path):
clip_data = load_torch_file(ckpt_path)
clip = ClipVisionModel()
clip.load_sd(clip_data)
return clip

View File

@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import comfy.utils
class Blend:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
return img1 * img2
elif mode == "screen":
return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
else:
raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
class Blur:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
class Quantize:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"colors": ("INT", {
"default": 256,
"min": 1,
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize"
CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array
return (result,)
class Sharpen:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 5.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen"
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)
return (result,)
NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
"ImageQuantize": Quantize,
"ImageSharpen": Sharpen,
}

View File

@ -1,6 +1,5 @@
import os
from comfy_extras.chainner_models import model_loading
from comfy.sd import load_torch_file
import model_management
import torch
import comfy.utils
@ -18,7 +17,7 @@ class UpscaleModelLoader:
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = load_torch_file(model_path)
sd = comfy.utils.load_torch_file(model_path)
out = model_loading.load_state_dict(sd).eval()
return (out, )

View File

@ -11,6 +11,8 @@ class Example:
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tulple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tulple.
FUNCTION (`str`):
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
@ -61,6 +63,8 @@ class Example:
}
RETURN_TYPES = ("IMAGE",)
#RETURN_NAMES = ("image_output_name",)
FUNCTION = "test"
#OUTPUT_NODE = False

View File

@ -27,6 +27,40 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
if not os.path.exists(input_directory):
os.makedirs(input_directory)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def get_output_directory():
global output_directory
return output_directory
def get_temp_directory():
global temp_directory
return temp_directory
def get_input_directory():
global input_directory
return input_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
if type_name == "output":
return get_output_directory()
if type_name == "temp":
return get_temp_directory()
if type_name == "input":
return get_input_directory()
return None
def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths

18
main.py
View File

@ -11,9 +11,15 @@ if os.name == "nt":
if __name__ == "__main__":
if '--help' in sys.argv:
print()
print("Valid Command line Arguments:")
print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
print("\t--port 8188\t\t\tSet the listen port.")
print()
print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.")
print("\t--output-directory path/to/output\tSet the ComfyUI output directory.")
print()
print()
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
@ -40,6 +46,7 @@ if __name__ == "__main__":
except:
pass
from nodes import init_custom_nodes
import execution
import server
import folder_paths
@ -98,6 +105,8 @@ if __name__ == "__main__":
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
init_custom_nodes()
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
@ -113,7 +122,6 @@ if __name__ == "__main__":
except:
address = '127.0.0.1'
dont_print = False
if '--dont-print-server' in sys.argv:
dont_print = True
@ -127,6 +135,14 @@ if __name__ == "__main__":
for i in indices:
load_extra_path_config(sys.argv[i])
try:
output_dir = sys.argv[sys.argv.index('--output-directory') + 1]
output_dir = os.path.abspath(output_dir)
print("setting output directory to:", output_dir)
folder_paths.set_output_directory(output_dir)
except:
pass
port = 8188
try:
p_index = sys.argv.index('--port')

View File

@ -18,7 +18,7 @@ import comfy.samplers
import comfy.sd
import comfy.utils
import comfy_extras.clip_vision
import comfy.clip_vision
import model_management
import importlib
@ -197,7 +197,7 @@ class CheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders"
CATEGORY = "advanced/loaders"
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name)
@ -219,6 +219,21 @@ class CheckpointLoaderSimple:
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class unCLIPCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class CLIPSetLastLayer:
@classmethod
def INPUT_TYPES(s):
@ -370,7 +385,7 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_vision = comfy_extras.clip_vision.load(clip_path)
clip_vision = comfy.clip_vision.load(clip_path)
return (clip_vision,)
class CLIPVisionEncode:
@ -382,7 +397,7 @@ class CLIPVisionEncode:
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning/style_model"
CATEGORY = "conditioning"
def encode(self, clip_vision, image):
output = clip_vision.encode_image(image)
@ -424,6 +439,33 @@ class StyleModelApply:
c.append(n)
return (c, )
class unCLIPConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_adm"
CATEGORY = "conditioning"
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
c = []
for t in conditioning:
o = t[1].copy()
x = (clip_vision_output, strength, noise_augmentation)
if "adm" in o:
o["adm"] = o["adm"][:] + [x]
else:
o["adm"] = [x]
n = [t[0], o]
c.append(n)
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
@ -735,7 +777,7 @@ class KSamplerAdvanced:
class SaveImage:
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
@classmethod
@ -787,9 +829,6 @@ class SaveImage:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -814,7 +853,7 @@ class SaveImage:
class PreviewImage(SaveImage):
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
@classmethod
@ -825,13 +864,11 @@ class PreviewImage(SaveImage):
}
class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
if not os.path.exists(s.input_dir):
os.makedirs(s.input_dir)
input_dir = folder_paths.get_input_directory()
return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )},
{"image": (sorted(os.listdir(input_dir)), )},
}
CATEGORY = "image"
@ -839,7 +876,8 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = os.path.join(self.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
@ -853,18 +891,19 @@ class LoadImage:
@classmethod
def IS_CHANGED(s, image):
image_path = os.path.join(s.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
return {"required":
{"image": (sorted(os.listdir(s.input_dir)), ),
{"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),}
}
@ -873,7 +912,8 @@ class LoadImageMask:
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
i = Image.open(image_path)
mask = None
c = channel[0].upper()
@ -888,7 +928,8 @@ class LoadImageMask:
@classmethod
def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image)
input_dir = folder_paths.get_input_directory()
image_path = os.path.join(input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
@ -996,7 +1037,6 @@ class ImagePadForOutpaint:
NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
"CheckpointLoader": CheckpointLoader,
"CheckpointLoaderSimple": CheckpointLoaderSimple,
"CLIPTextEncode": CLIPTextEncode,
"CLIPSetLastLayer": CLIPSetLastLayer,
@ -1025,6 +1065,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPLoader": CLIPLoader,
"CLIPVisionEncode": CLIPVisionEncode,
"StyleModelApply": StyleModelApply,
"unCLIPConditioning": unCLIPConditioning,
"ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader,
@ -1033,6 +1074,8 @@ NODE_CLASS_MAPPINGS = {
"VAEDecodeTiled": VAEDecodeTiled,
"VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"CheckpointLoader": CheckpointLoader,
}
def load_custom_node(module_path):
@ -1067,6 +1110,7 @@ def load_custom_nodes():
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))

View File

@ -47,7 +47,7 @@
" !git pull\n",
"\n",
"!echo -= Install dependencies =-\n",
"!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117"
"!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
]
},
{

View File

@ -42,6 +42,7 @@ class PromptServer():
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
self.client_id = None
@ -88,7 +89,7 @@ class PromptServer():
@routes.post("/upload/image")
async def upload_image(request):
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
upload_dir = folder_paths.get_input_directory()
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
@ -121,10 +122,10 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output")
if type not in ["output", "input", "temp"]:
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
if "subfolder" in request.rel_url.query:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
@ -239,8 +240,9 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
self.app.add_routes(routes)
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root),
])

View File

@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
}
return ctx;
};
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});

View File

@ -30,7 +30,8 @@ app.registerExtension({
}
// Overwrite the value in the serialized workflow pnginfo
workflowNode.widgets_values[widgetIndex] = prompt;
if (workflowNode?.widgets_values)
workflowNode.widgets_values[widgetIndex] = prompt;
return prompt;
};

View File

@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
// Inverts the scrolling of context menus
const id = "Comfy.InvertMenuScrolling";
const ctxMenu = LiteGraph.ContextMenu;
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
const replace = () => {
LiteGraph.ContextMenu = function (values, options) {
options = options || {};

View File

@ -11,11 +11,14 @@ app.registerExtension({
this.properties = {};
}
this.properties.showOutputText = RerouteNode.defaultVisibility;
this.properties.horizontal = false;
this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
// Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types
@ -43,12 +46,19 @@ app.registerExtension({
const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type;
if (type === "Reroute") {
// Move the previous node
currentNode = node;
if (node === this) {
// We've found a circle
currentNode.disconnectInput(link.target_slot);
currentNode = null;
}
else {
// Move the previous node
currentNode = node;
}
} else {
// We've found the end
inputNode = currentNode;
inputType = node.outputs[link.origin_slot].type;
inputType = node.outputs[link.origin_slot]?.type ?? null;
break;
}
} else {
@ -80,7 +90,7 @@ app.registerExtension({
updateNodes.push(node);
} else {
// We've found an output
const nodeOutType = node.inputs[link.target_slot].type;
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot);
@ -105,6 +115,7 @@ app.registerExtension({
node.__outputType = displayType;
node.outputs[0].name = node.properties.showOutputText ? displayType : "";
node.size = node.computeSize();
node.applyOrientation();
for (const l of node.outputs[0].links || []) {
const link = app.graph.links[l];
@ -146,6 +157,7 @@ app.registerExtension({
this.outputs[0].name = "";
}
this.size = this.computeSize();
this.applyOrientation();
app.graph.setDirtyCanvas(true, true);
},
},
@ -154,9 +166,32 @@ app.registerExtension({
callback: () => {
RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility);
},
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
this.properties.horizontal = !this.properties.horizontal;
this.applyOrientation();
},
}
);
}
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];
} else {
delete this.inputs[0].pos;
}
app.graph.setDirtyCanvas(true, true);
}
computeSize() {
return [

View File

@ -0,0 +1,21 @@
import { app } from "/scripts/app.js";
// Adds defaults for quickly adding nodes with middle click on the input/output
app.registerExtension({
name: "Comfy.SlotDefaults",
init() {
LiteGraph.middle_click_slot_add_default_node = true;
LiteGraph.slot_types_default_in = {
MODEL: "CheckpointLoaderSimple",
LATENT: "EmptyLatentImage",
VAE: "VAELoader",
};
LiteGraph.slot_types_default_out = {
LATENT: "VAEDecode",
IMAGE: "SaveImage",
CLIP: "CLIPTextEncode",
};
},
});

View File

@ -0,0 +1,89 @@
import { app } from "/scripts/app.js";
// Shift + drag/resize to snap to grid
app.registerExtension({
name: "Comfy.SnapToGrid",
init() {
// Add setting to control grid size
app.ui.settings.addSetting({
id: "Comfy.SnapToGrid.GridSize",
name: "Grid Size",
type: "number",
attrs: {
min: 1,
max: 500,
},
tooltip:
"When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
defaultValue: LiteGraph.CANVAS_GRID_SIZE,
onChange(value) {
LiteGraph.CANVAS_GRID_SIZE = +value;
},
});
// After moving a node, if the shift key is down align it to grid
const onNodeMoved = app.canvas.onNodeMoved;
app.canvas.onNodeMoved = function (node) {
const r = onNodeMoved?.apply(this, arguments);
if (app.shiftDown) {
// Ensure all selected nodes are realigned
for (const id in this.selected_nodes) {
this.selected_nodes[id].alignToGrid();
}
}
return r;
};
// When a node is added, add a resize handler to it so we can fix align the size with the grid
const onNodeAdded = app.graph.onNodeAdded;
app.graph.onNodeAdded = function (node) {
const onResize = node.onResize;
node.onResize = function () {
if (app.shiftDown) {
const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE);
const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE);
node.size[0] = w;
node.size[1] = h;
}
return onResize?.apply(this, arguments);
};
return onNodeAdded?.apply(this, arguments);
};
// Draw a preview of where the node will go if holding shift and the node is selected
const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) {
if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) {
const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE);
const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE);
const shiftX = x - node.pos[0];
let shiftY = y - node.pos[1];
let w, h;
if (node.flags.collapsed) {
w = node._collapsed_width;
h = LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
} else {
w = node.size[0];
h = node.size[1];
let titleMode = node.constructor.title_mode;
if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) {
h += LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
}
}
const f = ctx.fillStyle;
ctx.fillStyle = "rgba(100, 100, 100, 0.5)";
ctx.fillRect(shiftX, shiftY, w, h);
ctx.fillStyle = f;
}
return origDrawNode.apply(this, arguments);
};
},
});

View File

@ -25,7 +25,7 @@ function hideWidget(node, widget, suffix = "") {
if (link == null) {
return undefined;
}
return widget.value;
return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
};
// Hide any linked widgets, e.g. seed+seedControl

View File

@ -5,10 +5,20 @@ import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";
class ComfyApp {
/**
* List of {number, batchCount} entries to queue
*/
#queueItems = [];
/**
* If the queue is currently being processed
*/
#processingQueue = false;
constructor() {
this.ui = new ComfyUI(this);
this.extensions = [];
this.nodeOutputs = {};
this.shiftDown = false;
}
/**
@ -102,6 +112,46 @@ class ComfyApp {
};
}
#addNodeKeyHandler(node) {
const app = this;
const origNodeOnKeyDown = node.prototype.onKeyDown;
node.prototype.onKeyDown = function(e) {
if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
return false;
}
if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
return;
}
let handled = false;
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
if (e.key === "ArrowLeft") {
this.imageIndex -= 1;
} else if (e.key === "ArrowRight") {
this.imageIndex += 1;
}
this.imageIndex %= this.imgs.length;
if (this.imageIndex < 0) {
this.imageIndex = this.imgs.length + this.imageIndex;
}
handled = true;
} else if (e.key === "Escape") {
this.imageIndex = null;
handled = true;
}
if (handled === true) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
}
}
/**
* Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images
@ -628,11 +678,16 @@ class ComfyApp {
#addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
// Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0);
}
});
window.addEventListener("keyup", (e) => {
this.shiftDown = e.shiftKey;
});
}
/**
@ -667,6 +722,9 @@ class ComfyApp {
const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph));
this.ctx = canvasEl.getContext("2d");
LiteGraph.release_link_on_empty_shows_menu = true;
LiteGraph.alt_drag_do_clone_nodes = true;
this.graph.start();
function resizeCanvas() {
@ -785,6 +843,7 @@ class ComfyApp {
this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node);
@ -919,31 +978,47 @@ class ComfyApp {
}
async queuePrompt(number, batchCount = 1) {
for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt();
this.#queueItems.push({ number, batchCount });
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
return;
}
// Only have one action process the items so each one gets a unique seed correctly
if (this.#processingQueue) {
return;
}
this.#processingQueue = true;
try {
while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop());
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks after a prompt has been queued
// e.g. random seed after every gen
if (widget.afterQueued) {
widget.afterQueued();
for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt();
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
break;
}
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks after a prompt has been queued
// e.g. random seed after every gen
if (widget.afterQueued) {
widget.afterQueued();
}
}
}
}
this.canvas.draw(true, true);
await this.ui.queue.update();
}
}
this.canvas.draw(true, true);
await this.ui.queue.update();
} finally {
this.#processingQueue = false;
}
}

View File

@ -35,21 +35,86 @@ export function $el(tag, propsOrChildren, children) {
return element;
}
function dragElement(dragEl) {
function dragElement(dragEl, settings) {
var posDiffX = 0,
posDiffY = 0,
posStartX = 0,
posStartY = 0,
newPosX = 0,
newPosY = 0;
if (dragEl.getElementsByClassName('drag-handle')[0]) {
if (dragEl.getElementsByClassName("drag-handle")[0]) {
// if present, the handle is where you move the DIV from:
dragEl.getElementsByClassName('drag-handle')[0].onmousedown = dragMouseDown;
dragEl.getElementsByClassName("drag-handle")[0].onmousedown = dragMouseDown;
} else {
// otherwise, move the DIV from anywhere inside the DIV:
dragEl.onmousedown = dragMouseDown;
}
// When the element resizes (e.g. view queue) ensure it is still in the windows bounds
const resizeObserver = new ResizeObserver(() => {
ensureInBounds();
}).observe(dragEl);
function ensureInBounds() {
if (dragEl.classList.contains("comfy-menu-manual-pos")) {
newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft));
newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop));
positionElement();
}
}
function positionElement() {
const halfWidth = document.body.clientWidth / 2;
const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth;
// set the element's new position:
if (anchorRight) {
dragEl.style.left = "unset";
dragEl.style.right = document.body.clientWidth - newPosX - dragEl.clientWidth + "px";
} else {
dragEl.style.left = newPosX + "px";
dragEl.style.right = "unset";
}
dragEl.style.top = newPosY + "px";
dragEl.style.bottom = "unset";
if (savePos) {
localStorage.setItem(
"Comfy.MenuPosition",
JSON.stringify({
x: dragEl.offsetLeft,
y: dragEl.offsetTop,
})
);
}
}
function restorePos() {
let pos = localStorage.getItem("Comfy.MenuPosition");
if (pos) {
pos = JSON.parse(pos);
newPosX = pos.x;
newPosY = pos.y;
positionElement();
ensureInBounds();
}
}
let savePos = undefined;
settings.addSetting({
id: "Comfy.MenuPosition",
name: "Save menu position",
type: "boolean",
defaultValue: savePos,
onChange(value) {
if (savePos === undefined && value) {
restorePos();
}
savePos = value;
},
});
function dragMouseDown(e) {
e = e || window.event;
e.preventDefault();
@ -64,18 +129,25 @@ function dragElement(dragEl) {
function elementDrag(e) {
e = e || window.event;
e.preventDefault();
dragEl.classList.add("comfy-menu-manual-pos");
// calculate the new cursor position:
posDiffX = e.clientX - posStartX;
posDiffY = e.clientY - posStartY;
posStartX = e.clientX;
posStartY = e.clientY;
newPosX = Math.min((document.body.clientWidth - dragEl.clientWidth), Math.max(0, (dragEl.offsetLeft + posDiffX)));
newPosY = Math.min((document.body.clientHeight - dragEl.clientHeight), Math.max(0, (dragEl.offsetTop + posDiffY)));
// set the element's new position:
dragEl.style.top = newPosY + "px";
dragEl.style.left = newPosX + "px";
newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft + posDiffX));
newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop + posDiffY));
positionElement();
}
window.addEventListener("resize", () => {
ensureInBounds();
});
function closeDragElement() {
// stop moving when mouse button is released:
document.onmouseup = null;
@ -90,7 +162,7 @@ class ComfyDialog {
$el("p", { $: (p) => (this.textElement = p) }),
$el("button", {
type: "button",
textContent: "CLOSE",
textContent: "Close",
onclick: () => this.close(),
}),
]),
@ -125,7 +197,7 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(value);
}
addSetting({ id, name, type, defaultValue, onChange }) {
addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) {
if (!id) {
throw new Error("Settings must have an ID");
}
@ -152,42 +224,83 @@ class ComfySettingsDialog extends ComfyDialog {
value = v;
};
let element;
value = this.getSettingValue(id, defaultValue);
if (typeof type === "function") {
return type(name, setter, value);
element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
...attrs
}),
]),
]);
break;
case "number":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type,
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
default:
console.warn("Unsupported setting type, defaulting to text");
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
}
}
if(tooltip) {
element.title = tooltip;
}
switch (type) {
case "boolean":
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
}),
]),
]);
default:
console.warn("Unsupported setting type, defaulting to text");
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
}),
]),
]);
}
return element;
},
});
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
}
show() {
super.show();
Object.assign(this.textElement.style, {
display: "flex",
flexDirection: "column",
gap: "10px"
});
this.textElement.replaceChildren(...this.settings.map((s) => s.render()));
}
}
@ -300,6 +413,13 @@ export class ComfyUI {
this.history.update();
});
const confirmClear = this.settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", {
type: "file",
accept: ".json,image/png",
@ -311,39 +431,57 @@ export class ComfyUI {
});
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [
$el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
$el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
]),
$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }),
$el("button.comfy-queue-btn", {
textContent: "Queue Prompt",
onclick: () => app.queuePrompt(0, this.batchCount),
}),
$el("div", {}, [
$el("label", { innerHTML: "Extra options"}, [
$el("input", { type: "checkbox",
onchange: (i) => {
document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none";
this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1;
document.getElementById('autoQueueCheckbox').checked = false;
}
})
])
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById('batchCountInputRange').value = this.batchCount;
}
$el("label", { innerHTML: "Extra options" }, [
$el("input", {
type: "checkbox",
onchange: (i) => {
document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
document.getElementById("autoQueueCheckbox").checked = false;
},
}),
$el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount,
]),
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", {
id: "batchCountInputNumber",
type: "number",
value: this.batchCount,
min: "1",
style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById("batchCountInputRange").value = this.batchCount;
},
}),
$el("input", {
id: "batchCountInputRange",
type: "range",
min: "1",
max: "100",
value: this.batchCount,
oninput: (i) => {
this.batchCount = i.srcElement.value;
document.getElementById('batchCountInputNumber').value = i.srcElement.value;
}
document.getElementById("batchCountInputNumber").value = i.srcElement.value;
},
}),
$el("input", {
id: "autoQueueCheckbox",
type: "checkbox",
checked: false,
title: "automatically queue prompt when the queue size hits 0",
}),
$el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "automatically queue prompt when the queue size hits 0",
})
]),
]),
$el("div.comfy-menu-btns", [
@ -389,13 +527,19 @@ export class ComfyUI {
$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => {
app.clean();
app.graph.clear();
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
}}),
$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),
]);
dragElement(this.menuContainer);
dragElement(this.menuContainer, this.settings);
this.setStatus({ exec_info: { queue_remaining: "X" } });
}
@ -403,10 +547,14 @@ export class ComfyUI {
setStatus(status) {
this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR");
if (status) {
if (this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && document.getElementById('autoQueueCheckbox').checked) {
if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked
) {
app.queuePrompt(0, this.batchCount);
}
this.lastQueueSize = status.exec_info.queue_remaining
this.lastQueueSize = status.exec_info.queue_remaining;
}
}
}

View File

@ -39,18 +39,19 @@ body {
position: fixed; /* Stay in place */
z-index: 100; /* Sit on top */
padding: 30px 30px 10px 30px;
background-color: #ff0000; /* Modal background */
background-color: #353535; /* Modal background */
color: #ff4444;
box-shadow: 0px 0px 20px #888888;
border-radius: 10px;
text-align: center;
top: 50%;
left: 50%;
max-width: 80vw;
max-height: 80vh;
transform: translate(-50%, -50%);
overflow: hidden;
min-width: 60%;
justify-content: center;
font-family: monospace;
font-size: 15px;
}
.comfy-modal-content {
@ -70,31 +71,11 @@ body {
margin: 3px 3px 3px 4px;
}
.comfy-modal button {
cursor: pointer;
color: #aaaaaa;
border: none;
background-color: transparent;
font-size: 24px;
font-weight: bold;
width: 100%;
}
.comfy-modal button:hover,
.comfy-modal button:focus {
color: #000;
text-decoration: none;
cursor: pointer;
}
.comfy-menu {
width: 200px;
font-size: 15px;
position: absolute;
top: 50%;
right: 0%;
background-color: white;
color: #000;
text-align: center;
z-index: 100;
width: 170px;
@ -109,7 +90,8 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
}
.comfy-menu button {
.comfy-menu button,
.comfy-modal button {
font-size: 20px;
}
@ -130,7 +112,8 @@ body {
.comfy-menu > button,
.comfy-menu-btns button,
.comfy-menu .comfy-list button {
.comfy-menu .comfy-list button,
.comfy-modal button{
color: #ddd;
background-color: #222;
border-radius: 8px;
@ -220,11 +203,22 @@ button.comfy-queue-btn {
}
.comfy-modal.comfy-settings {
background-color: var(--bg-color);
color: var(--fg-color);
text-align: center;
font-family: sans-serif;
color: #999;
z-index: 99;
}
.comfy-modal input,
.comfy-modal select {
color: #ddd;
background-color: #222;
border-radius: 8px;
border-color: #4e4e4e;
border-style: solid;
font-size: inherit;
}
@media only screen and (max-height: 850px) {
.comfy-menu {
top: 0 !important;
@ -237,3 +231,28 @@ button.comfy-queue-btn {
visibility:hidden
}
}
.graphdialog {
min-height: 1em;
}
.graphdialog .name {
font-size: 14px;
font-family: sans-serif;
color: #999999;
}
.graphdialog button {
margin-top: unset;
vertical-align: unset;
height: 1.6em;
padding-right: 8px;
}
.graphdialog input, .graphdialog textarea, .graphdialog select {
background-color: #222;
border: 2px solid;
border-color: #444444;
color: #ddd;
border-radius: 12px 0 0 12px;
}