Merge remote-tracking branch 'upstream/master'

commit a52b976dd5 by Silversith, 2023-04-07 14:39:46 +02:00
45 changed files with 2350 additions and 444 deletions


@@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
 - Loading full workflows (with seeds) from generated PNG files.
@@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
+- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.

comfy/cli_args.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")
args = parser.parse_args()
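
A minimal usage sketch (not part of this commit) of how other modules can consume the parsed `args` object defined above; the `comfy.cli_args` import path and the `pick_vram_mode` helper are illustrative assumptions, while the flag names come straight from the parser definitions.

```python
# Hypothetical consumer of comfy/cli_args.py: argparse turns --lowvram, --cpu, etc.
# into attributes on the shared `args` object.
from comfy.cli_args import args  # import path is an assumption


def pick_vram_mode():
    # vram_group is mutually exclusive, so at most one of these flags is set.
    if args.cpu:
        return "cpu"
    if args.novram:
        return "novram"
    if args.lowvram:
        return "lowvram"
    if args.highvram:
        return "highvram"
    return "normal"


print(f"Listening on {args.listen}:{args.port}, vram mode: {pick_vram_mode()}")
```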

comfy/clip_vision.py (new file, 62 lines)

@@ -0,0 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
from .utils import load_torch_file, transformers_convert
import os


class ClipVisionModel():
    def __init__(self, json_config):
        config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModelWithProjection(config)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
                                            do_normalize=True,
                                            do_resize=True,
                                            image_mean=[ 0.48145466,0.4578275,0.40821073],
                                            image_std=[0.26862954,0.26130258,0.27577711],
                                            resample=3, #bicubic
                                            size=224)

    def load_sd(self, sd):
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        inputs = self.processor(images=[image[0]], return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs


def convert_to_transformers(sd):
    sd_k = sd.keys()
    if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
        keys_to_replace = {
            "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
            "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
            "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
            "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
            "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
            "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
            "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
        }

        for x in keys_to_replace:
            if x in sd_k:
                sd[keys_to_replace[x]] = sd.pop(x)

        if "embedder.model.visual.proj" in sd_k:
            sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)

        sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
    return sd


def load_clipvision_from_sd(sd):
    sd = convert_to_transformers(sd)
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    clip = ClipVisionModel(json_config)
    clip.load_sd(sd)
    return clip


def load(ckpt_path):
    sd = load_torch_file(ckpt_path)
    return load_clipvision_from_sd(sd)
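
A small usage sketch (not part of this commit) for the loader above: the checkpoint path is hypothetical, and the image is assumed to be a ComfyUI-style batch of H x W x C floats, since `encode_image()` hands `image[0]` to the `CLIPImageProcessor`.

```python
# Hypothetical usage of comfy/clip_vision.py: load a CLIP vision checkpoint
# and encode a single image batch.
import torch
from comfy import clip_vision  # import path is an assumption

clip = clip_vision.load("models/clip_vision/clip_vision_h.safetensors")  # hypothetical path

image = torch.rand(1, 512, 512, 3)  # batch of H x W x C images (assumed layout)
outputs = clip.encode_image(image)
print(outputs.image_embeds.shape)  # projected embedding from CLIPVisionModelWithProjection
```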


@@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1280,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 32,
"patch_size": 14,
"projection_dim": 1024,
"torch_dtype": "float32"
}


@@ -1,8 +1,4 @@
 {
-"_name_or_path": "openai/clip-vit-large-patch14",
-"architectures": [
-"CLIPVisionModel"
-],
 "attention_dropout": 0.0,
 "dropout": 0.0,
 "hidden_act": "quick_gelu",
@@ -18,6 +14,5 @@
 "num_hidden_layers": 24,
 "patch_size": 14,
 "projection_dim": 768,
-"torch_dtype": "float32",
-"transformers_version": "4.24.0"
+"torch_dtype": "float32"
 }

comfy/diffusers_convert.py (new file, 362 lines)

@@ -0,0 +1,362 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("out.0.weight", "conv_norm_out.weight"),
("out.0.bias", "conv_norm_out.bias"),
("out.2.weight", "conv_out.weight"),
("out.2.bias", "conv_out.bias"),
]
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
def convert_unet_state_dict(unet_state_dict):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k: k for k in unet_state_dict.keys()}
for sd_name, hf_name in unet_conversion_map:
mapping[hf_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, hf_part in unet_conversion_map_resnet:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
# down_blocks have two resnets
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
weights_to_convert = ["q", "k", "v", "proj_out"]
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
# =========================#
# Text Encoder Conversion #
# =========================#
textenc_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
def convert_text_enc_state_dict_v20(text_enc_dict):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
or k.endswith(".self_attn.v_proj.weight")
):
k_pre = k[: -len(".q_proj.weight")]
k_code = k[-len("q_proj.weight")]
if k_pre not in capture_qkv_weight:
capture_qkv_weight[k_pre] = [None, None, None]
capture_qkv_weight[k_pre][code2idx[k_code]] = v
continue
if (
k.endswith(".self_attn.q_proj.bias")
or k.endswith(".self_attn.k_proj.bias")
or k.endswith(".self_attn.v_proj.bias")
):
k_pre = k[: -len(".q_proj.bias")]
k_code = k[-len("q_proj.bias")]
if k_pre not in capture_qkv_bias:
capture_qkv_bias[k_pre] = [None, None, None]
capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v
for k_pre, tensors in capture_qkv_weight.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
for k_pre, tensors in capture_qkv_bias.items():
if None in tensors:
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
return new_state_dict
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae
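
A usage sketch (not part of this commit) for `load_diffusers()` above: it expects a diffusers-format directory with `unet/`, `vae/`, `text_encoder/` and `scheduler/` subfolders and returns the same `(ModelPatcher, CLIP, VAE)` triple that the signature promises; both paths below are hypothetical.

```python
# Hypothetical call into the converter above; requires a ComfyUI environment
# (folder_paths must be able to resolve the v1/v2 inference configs).
from comfy import diffusers_convert  # import path is an assumption

model_patcher, clip, vae = diffusers_convert.load_diffusers(
    "models/diffusers/my-model",              # hypothetical diffusers checkpoint directory
    fp16=True,
    output_vae=True,
    output_clip=True,
    embedding_directory="models/embeddings",  # hypothetical textual-inversion directory
)
```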


@@ -78,7 +78,7 @@ class DDIMSampler(object):
 dynamic_threshold=None,
 ucg_schedule=None,
 denoise_function=None,
-cond_concat=None,
+extra_args=None,
 to_zero=True,
 end_step=None,
 **kwargs
@@ -101,7 +101,7 @@ class DDIMSampler(object):
 dynamic_threshold=dynamic_threshold,
 ucg_schedule=ucg_schedule,
 denoise_function=denoise_function,
-cond_concat=cond_concat,
+extra_args=extra_args,
 to_zero=to_zero,
 end_step=end_step
 )
@@ -174,7 +174,7 @@ class DDIMSampler(object):
 dynamic_threshold=dynamic_threshold,
 ucg_schedule=ucg_schedule,
 denoise_function=None,
-cond_concat=None
+extra_args=None
 )
 return samples, intermediates
@@ -185,7 +185,7 @@ class DDIMSampler(object):
 mask=None, x0=None, img_callback=None, log_every_t=100,
 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
 unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
+ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
 device = self.model.betas.device
 b = shape[0]
 if x_T is None:
@@ -225,7 +225,7 @@ class DDIMSampler(object):
 corrector_kwargs=corrector_kwargs,
 unconditional_guidance_scale=unconditional_guidance_scale,
 unconditional_conditioning=unconditional_conditioning,
-dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
+dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
 img, pred_x0 = outs
 if callback: callback(i)
 if img_callback: img_callback(pred_x0, i)
@@ -249,11 +249,11 @@ class DDIMSampler(object):
 def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
 unconditional_guidance_scale=1., unconditional_conditioning=None,
-dynamic_threshold=None, denoise_function=None, cond_concat=None):
+dynamic_threshold=None, denoise_function=None, extra_args=None):
 b, *_, device = *x.shape, x.device
 if denoise_function is not None:
-model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
+model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
 elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
 model_output = self.model.apply_model(x, t, c)
 else:


@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
 self.conditioning_key = conditioning_key
 assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
-def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
+def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
 if self.conditioning_key is None:
-out = self.diffusion_model(x, t, control=control)
+out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'concat':
 xc = torch.cat([x] + c_concat, dim=1)
-out = self.diffusion_model(xc, t, control=control)
+out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'crossattn':
 if not self.sequential_cross_attn:
 cc = torch.cat(c_crossattn, 1)
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
 # TorchScript changes names of the arguments
 # with argument cc defined as context=cc scripted model will produce
 # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
-out = self.scripted_diffusion_model(x, t, cc, control=control)
+out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
 else:
-out = self.diffusion_model(x, t, context=cc, control=control)
+out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'hybrid':
 xc = torch.cat([x] + c_concat, dim=1)
 cc = torch.cat(c_crossattn, 1)
-out = self.diffusion_model(xc, t, context=cc, control=control)
+out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'hybrid-adm':
 assert c_adm is not None
 xc = torch.cat([x] + c_concat, dim=1)
 cc = torch.cat(c_crossattn, 1)
-out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
+out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'crossattn-adm':
 assert c_adm is not None
 cc = torch.cat(c_crossattn, 1)
-out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
+out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
 elif self.conditioning_key == 'adm':
 cc = c_crossattn[0]
-out = self.diffusion_model(x, t, y=cc, control=control)
+out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
 else:
 raise NotImplementedError()
@@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
 log = super().log_images(*args, **kwargs)
 log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
 return log
+class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
+def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
+freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
+super().__init__(*args, **kwargs)
+self.embed_key = embedding_key
+self.embedding_dropout = embedding_dropout
+# self._init_embedder(embedder_config, freeze_embedder)
+self._init_noise_aug(noise_aug_config)
+def _init_embedder(self, config, freeze=True):
+embedder = instantiate_from_config(config)
+if freeze:
+self.embedder = embedder.eval()
+self.embedder.train = disabled_train
+for param in self.embedder.parameters():
+param.requires_grad = False
+def _init_noise_aug(self, config):
+if config is not None:
+# use the KARLO schedule for noise augmentation on CLIP image embeddings
+noise_augmentor = instantiate_from_config(config)
+assert isinstance(noise_augmentor, nn.Module)
+noise_augmentor = noise_augmentor.eval()
+noise_augmentor.train = disabled_train
+self.noise_augmentor = noise_augmentor
+else:
+self.noise_augmentor = None
+def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
+outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
+z, c = outputs[0], outputs[1]
+img = batch[self.embed_key][:bs]
+img = rearrange(img, 'b h w c -> b c h w')
+c_adm = self.embedder(img)
+if self.noise_augmentor is not None:
+c_adm, noise_level_emb = self.noise_augmentor(c_adm)
+# assume this gives embeddings of noise levels
+c_adm = torch.cat((c_adm, noise_level_emb), 1)
+if self.training:
+c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
+device=c_adm.device)[:, None]) * c_adm
+all_conds = {"c_crossattn": [c], "c_adm": c_adm}
+noutputs = [z, all_conds]
+noutputs.extend(outputs[2:])
+return noutputs
+@torch.no_grad()
+def log_images(self, batch, N=8, n_row=4, **kwargs):
+log = dict()
+z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
+return_original_cond=True)
+log["inputs"] = x
+log["reconstruction"] = xrec
+assert self.model.conditioning_key is not None
+assert self.cond_stage_key in ["caption", "txt"]
+xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+log["conditioning"] = xc
+uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
+unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
+uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
+with ema_scope(f"Sampling"):
+samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
+ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
+unconditional_guidance_scale=unconditional_guidance_scale,
+unconditional_conditioning=uc_, )
+x_samples_cfg = self.decode_first_stage(samples_cfg)
+log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+return log


@@ -307,7 +307,16 @@ def model_wrapper(
 else:
 x_in = torch.cat([x] * 2)
 t_in = torch.cat([t_continuous] * 2)
-c_in = torch.cat([unconditional_condition, condition])
+if isinstance(condition, dict):
+assert isinstance(unconditional_condition, dict)
+c_in = dict()
+for k in condition:
+if isinstance(condition[k], list):
+c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+else:
+c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+else:
+c_in = torch.cat([unconditional_condition, condition])
 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
 return noise_uncond + guidance_scale * (noise - noise_uncond)
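
A standalone sketch (not part of this commit) of the classifier-free-guidance batching added above: when the conditioning is a dict, every entry (including list-valued ones) is concatenated with its unconditional counterpart along the batch dimension so one model call covers both branches; the tensor shapes are illustrative assumptions.

```python
import torch

# Toy cond/uncond dicts shaped like ComfyUI conditioning (shapes are assumptions).
condition = {"c_crossattn": [torch.randn(1, 77, 768)], "c_adm": torch.randn(1, 1280)}
unconditional_condition = {"c_crossattn": [torch.zeros(1, 77, 768)], "c_adm": torch.zeros(1, 1280)}

c_in = dict()
for k in condition:
    if isinstance(condition[k], list):
        c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]])
                   for i in range(len(condition[k]))]
    else:
        c_in[k] = torch.cat([unconditional_condition[k], condition[k]])

print(c_in["c_crossattn"][0].shape)  # torch.Size([2, 77, 768])
print(c_in["c_adm"].shape)           # torch.Size([2, 1280])
```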


@@ -3,7 +3,6 @@ import torch
 from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
 MODEL_TYPES = {
 "eps": "noise",
 "v": "v"
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
 ):
 if conditioning is not None:
 if isinstance(conditioning, dict):
-cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-if cbs != batch_size:
-print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ctmp = conditioning[list(conditioning.keys())[0]]
+while isinstance(ctmp, list): ctmp = ctmp[0]
+if isinstance(ctmp, torch.Tensor):
+cbs = ctmp.shape[0]
+if cbs != batch_size:
+print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+elif isinstance(conditioning, list):
+for ctmp in conditioning:
+if ctmp.shape[0] != batch_size:
+print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
 else:
-if conditioning.shape[0] != batch_size:
-print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+if isinstance(conditioning, torch.Tensor):
+if conditioning.shape[0] != batch_size:
+print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
 # sampling
 C, H, W = shape
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
 )
 dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
+lower_order_final=True)
 return x.to(device), None


@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 import model_management
+from . import tomesd
 if model_management.xformers_enabled():
 import xformers
@@ -20,6 +21,8 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+from cli_args import args
 def exists(val):
 return val is not None
@@ -473,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):
 return self.to_out(out)
-import sys
 if model_management.xformers_enabled():
 print("Using xformers cross attention")
 CrossAttention = MemoryEfficientCrossAttention
@@ -481,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
 print("Using pytorch cross attention")
 CrossAttention = CrossAttentionPytorch
 else:
-if "--use-split-cross-attention" in sys.argv:
+if args.use_split_cross_attention:
 print("Using split optimization for cross attention")
 CrossAttention = CrossAttentionDoggettx
 else:
@@ -504,12 +506,22 @@ class BasicTransformerBlock(nn.Module):
 self.norm3 = nn.LayerNorm(dim)
 self.checkpoint = checkpoint
-def forward(self, x, context=None):
+def forward(self, x, context=None, transformer_options={}):
-return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-def _forward(self, x, context=None):
+def _forward(self, x, context=None, transformer_options={}):
-x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
-x = self.attn2(self.norm2(x), context=context) + x
+n = self.norm1(x)
+if "tomesd" in transformer_options:
+m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
+n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+else:
+n = self.attn1(n, context=context if self.disable_self_attn else None)
+x += n
+n = self.norm2(x)
+n = self.attn2(n, context=context)
+x += n
 x = self.ff(self.norm3(x)) + x
 return x
@@ -557,7 +569,7 @@ class SpatialTransformer(nn.Module):
 self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
 self.use_linear = use_linear
-def forward(self, x, context=None):
+def forward(self, x, context=None, transformer_options={}):
 # note: if no context is given, cross-attention defaults to self-attention
 if not isinstance(context, list):
 context = [context]
@@ -570,7 +582,7 @@ class SpatialTransformer(nn.Module):
 if self.use_linear:
 x = self.proj_in(x)
 for i, block in enumerate(self.transformer_blocks):
-x = block(x, context=context[i])
+x = block(x, context=context[i], transformer_options=transformer_options)
 if self.use_linear:
 x = self.proj_out(x)
 x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()


@@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management
-if model_management.xformers_enabled():
+if model_management.xformers_enabled_vae():
 import xformers
 import xformers.ops
@@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
 assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-if model_management.xformers_enabled() and attn_type == "vanilla":
+if model_management.xformers_enabled_vae() and attn_type == "vanilla":
 attn_type = "vanilla-xformers"
 if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
 attn_type = "vanilla-pytorch"


@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 support it as an extra input.
 """
-def forward(self, x, emb, context=None):
+def forward(self, x, emb, context=None, transformer_options={}):
 for layer in self:
 if isinstance(layer, TimestepBlock):
 x = layer(x, emb)
 elif isinstance(layer, SpatialTransformer):
-x = layer(x, context)
+x = layer(x, context, transformer_options)
 else:
 x = layer(x)
 return x
@@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
 return count_flops_attn(model, _x, y)
+class Timestep(nn.Module):
+def __init__(self, dim):
+super().__init__()
+self.dim = dim
+def forward(self, t):
+return timestep_embedding(t, self.dim)
 class UNetModel(nn.Module):
 """
 The full UNet model with attention and timestep embedding.
@@ -470,6 +479,7 @@ class UNetModel(nn.Module):
 num_attention_blocks=None,
 disable_middle_self_attn=False,
 use_linear_in_transformer=False,
+adm_in_channels=None,
 ):
 super().__init__()
 if use_spatial_transformer:
@@ -538,6 +548,15 @@ class UNetModel(nn.Module):
 elif self.num_classes == "continuous":
 print("setting up linear c_adm embedding layer")
 self.label_emb = nn.Linear(1, time_embed_dim)
+elif self.num_classes == "sequential":
+assert adm_in_channels is not None
+self.label_emb = nn.Sequential(
+nn.Sequential(
+linear(adm_in_channels, time_embed_dim),
+nn.SiLU(),
+linear(time_embed_dim, time_embed_dim),
+)
+)
 else:
 raise ValueError()
@@ -753,7 +772,7 @@ class UNetModel(nn.Module):
 self.middle_block.apply(convert_module_to_f32)
 self.output_blocks.apply(convert_module_to_f32)
-def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
+def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
 """
 Apply the model to an input batch.
 :param x: an [N x C x ...] Tensor of inputs.
@@ -762,6 +781,7 @@ class UNetModel(nn.Module):
 :param y: an [N] Tensor of labels, if class-conditional.
 :return: an [N x C x ...] Tensor of outputs.
 """
+transformer_options["original_shape"] = list(x.shape)
 assert (y is not None) == (
 self.num_classes is not None
 ), "must specify y if and only if the model is class-conditional"
@@ -775,13 +795,13 @@ class UNetModel(nn.Module):
 h = x.type(self.dtype)
 for id, module in enumerate(self.input_blocks):
-h = module(h, emb, context)
+h = module(h, emb, context, transformer_options)
 if control is not None and 'input' in control and len(control['input']) > 0:
 ctrl = control['input'].pop()
 if ctrl is not None:
 h += ctrl
 hs.append(h)
-h = self.middle_block(h, emb, context)
+h = self.middle_block(h, emb, context, transformer_options)
 if control is not None and 'middle' in control and len(control['middle']) > 0:
 h += control['middle'].pop()
@@ -793,7 +813,7 @@ class UNetModel(nn.Module):
 hsp += ctrl
 h = th.cat([h, hsp], dim=1)
 del hsp
-h = module(h, emb, context)
+h = module(h, emb, context, transformer_options)
 h = h.type(x.dtype)
 if self.predict_codebook_ids:
 return self.id_predictor(h)


@@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
 betas = 1 - alphas[1:] / alphas[:-1]
 betas = np.clip(betas, a_min=0, a_max=0.999)
+elif schedule == "squaredcos_cap_v2": # used for karlo prior
+# return early
+return betas_for_alpha_bar(
+n_timestep,
+lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+)
 elif schedule == "sqrt_linear":
 betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
 elif schedule == "sqrt":
@@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
 def forward(self, x):
 return super().forward(x.float()).type(x.dtype)
 def conv_nd(dims, *args, **kwargs):
 """
 Create a 1D, 2D, or 3D convolution module.
@@ -267,4 +275,4 @@ class HybridConditioner(nn.Module):
 def noise_like(shape, device, repeat=False):
 repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
 noise = lambda: torch.randn(shape, device=device)
 return repeat_noise() if repeat else noise()


@@ -0,0 +1,59 @@
from typing import List, Tuple, Union

import torch
import torch.nn as nn


#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.
    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape
    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
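
A usage sketch (not part of this commit) for `enhance_normalize()` above, using the CLIP mean/std values that the image embedders in this commit register as buffers; the import path is an assumption.

```python
import torch
from comfy.ldm.modules.encoders import kornia_functions  # import path is an assumption

x = torch.rand(2, 3, 224, 224)  # images already scaled to [0, 1]
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
std = torch.tensor([0.26862954, 0.26130258, 0.27577711])
out = kornia_functions.enhance_normalize(x, mean, std)
print(out.shape)  # torch.Size([2, 3, 224, 224])
```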


@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from . import kornia_functions
 from torch.utils.checkpoint import checkpoint
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
 c = batch[key][:, None]
 if self.ucg_rate > 0. and not disable_dropout:
 mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
-c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
 c = c.long()
 c = self.embedding(c)
 return c
@@ -57,18 +58,20 @@ def disabled_train(self, mode=True):
 class FrozenT5Embedder(AbstractEncoder):
 """Uses the T5 transformer encoder for text"""
-def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
+freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
 super().__init__()
 self.tokenizer = T5Tokenizer.from_pretrained(version)
 self.transformer = T5EncoderModel.from_pretrained(version)
 self.device = device
 self.max_length = max_length # TODO: typical value?
 if freeze:
 self.freeze()
 def freeze(self):
 self.transformer = self.transformer.eval()
-#self.train = disabled_train
+# self.train = disabled_train
 for param in self.parameters():
 param.requires_grad = False
@@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 "pooled",
 "hidden"
 ]
 def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
 freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
 super().__init__()
@@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 def freeze(self):
 self.transformer = self.transformer.eval()
-#self.train = disabled_train
+# self.train = disabled_train
 for param in self.parameters():
 param.requires_grad = False
@@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
 tokens = batch_encoding["input_ids"].to(self.device)
-outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
 if self.layer == "last":
 z = outputs.last_hidden_state
 elif self.layer == "pooled":
@@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
 return self(text)
+class ClipImageEmbedder(nn.Module):
+def __init__(
+self,
+model,
+jit=False,
+device='cuda' if torch.cuda.is_available() else 'cpu',
+antialias=True,
+ucg_rate=0.
+):
+super().__init__()
+from clip import load as load_clip
+self.model, _ = load_clip(name=model, device=device, jit=jit)
+self.antialias = antialias
+self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+self.ucg_rate = ucg_rate
+def preprocess(self, x):
+# normalize to [0,1]
+# x = kornia_functions.geometry_resize(x, (224, 224),
+# interpolation='bicubic', align_corners=True,
+# antialias=self.antialias)
+x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+x = (x + 1.) / 2.
+# re-normalize according to clip
+x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+return x
+def forward(self, x, no_dropout=False):
+# x is assumed to be in range [-1,1]
+out = self.model.encode_image(self.preprocess(x))
+out = out.to(x.dtype)
+if self.ucg_rate > 0. and not no_dropout:
+out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+return out
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
 """
 Uses the OpenCLIP transformer encoder for text
 """
 LAYERS = [
-#"pooled",
+# "pooled",
 "last",
 "penultimate"
 ]
 def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
 freeze=True, layer="last"):
 super().__init__()
@@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
 x = self.model.ln_final(x)
 return x
-def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
 for i, r in enumerate(self.model.transformer.resblocks):
 if i == len(self.model.transformer.resblocks) - self.layer_idx:
 break
@@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
 return self(text)
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+"""
+Uses the OpenCLIP vision transformer encoder for images
+"""
+def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+super().__init__()
+model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+pretrained=version, )
+del model.transformer
+self.model = model
+self.device = device
+self.max_length = max_length
+if freeze:
+self.freeze()
+self.layer = layer
+if self.layer == "penultimate":
+raise NotImplementedError()
+self.layer_idx = 1
+self.antialias = antialias
+self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+self.ucg_rate = ucg_rate
+def preprocess(self, x):
+# normalize to [0,1]
+# x = kornia.geometry.resize(x, (224, 224),
+# interpolation='bicubic', align_corners=True,
+# antialias=self.antialias)
+x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+x = (x + 1.) / 2.
+# renormalize according to clip
+x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+return x
+def freeze(self):
+self.model = self.model.eval()
+for param in self.parameters():
+param.requires_grad = False
+def forward(self, image, no_dropout=False):
+z = self.encode_with_vision_transformer(image)
+if self.ucg_rate > 0. and not no_dropout:
+z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+return z
+def encode_with_vision_transformer(self, img):
+img = self.preprocess(img)
+x = self.model.visual(img)
+return x
+def encode(self, text):
+return self(text)
 class FrozenCLIPT5Encoder(AbstractEncoder):
 def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
 clip_max_length=77, t5_max_length=77):
 super().__init__()
 self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
 self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
-f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
 def encode(self, text):
 return self(text)
@@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
 clip_z = self.clip_encoder.encode(text)
 t5_z = self.t5_encoder.encode(text)
 return [clip_z, t5_z]

View File

@ -0,0 +1,35 @@
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
import torch
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
super().__init__(*args, **kwargs)
if clip_stats_path is None:
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
else:
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim)
def scale(self, x):
# re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std
return x
def unscale(self, x):
# back to original data stats
x = (x * self.data_std) + self.data_mean
return x
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
x = self.scale(x)
z = self.q_sample(x, noise_level)
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level
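For context, a rough standalone sketch of what the scale / q_sample / unscale sequence does to a CLIP image embedding. The real class reads its mean/std and beta schedule from the checkpoint; the stats, the 0.99 decay and the shapes below are placeholders, not the actual schedule:

```python
import torch

timestep_dim = 1024                        # SD2.x unCLIP-H; 768 for the L variant
data_mean = torch.zeros(1, timestep_dim)   # placeholder stats; real ones come from clip_stats_path
data_std = torch.ones(1, timestep_dim)

emb = torch.randn(1, timestep_dim)
noise_level = torch.tensor([100])

x = (emb - data_mean) / data_std                                   # scale(): normalize to the stats' frame
alpha = 0.99 ** noise_level.float()                                # stand-in for the real cumulative schedule
z = alpha.sqrt() * x + (1 - alpha).sqrt() * torch.randn_like(x)    # roughly what q_sample() does
z = z * data_std + data_mean                                       # unscale(): back to the original frame
print(z.shape)                                                     # torch.Size([1, 1024])
```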
comfy/ldm/modules/tomesd.py (new file)
@ -0,0 +1,144 @@
#Taken from: https://github.com/dbolya/tomesd
import torch
from typing import Tuple, Callable
import math
def do_nothing(x: torch.Tensor, mode:str=None):
return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
else:
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- h: image height in tokens
- sx: stride in the x dimension for dst, must divide w
- sy: stride in the y dimension for dst, must divide h
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
"""
B, N, _ = metric.shape
if r <= 0:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
if no_rand:
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not be divisible by sx and sy, so we need to work on a view of the top left of the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
# Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
def split(x):
C = x.shape[-1]
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)
# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1)
def unmerge(x: torch.Tensor) -> torch.Tensor:
unm_len = unm_idx.shape[1]
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out
return merge, unmerge
def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2
stride_y = 2
max_downsample = 1
if downsample <= max_downsample:
w = int(math.ceil(original_w / downsample))
h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio)
no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
return m, u
nothing = lambda y: y
return nothing, nothing
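A short usage sketch for the merge/unmerge pair returned above. The import path is assumed from the file name shown in this diff, and the shapes are illustrative:

```python
import torch
from comfy.ldm.modules.tomesd import get_functions  # import path assumed from the file added above

x = torch.randn(1, 64 * 64, 320)                     # [batch, tokens, channels] at a 64x64 token grid
merge, unmerge = get_functions(x, 0.5, (1, 4, 64, 64))
merged = merge(x)                                     # similar src tokens are averaged into their dst token
print(merged.shape)                                   # torch.Size([1, 2048, 320])
restored = unmerge(merged)                            # merged values are copied back to the original positions
print(restored.shape)                                 # torch.Size([1, 4096, 320])
```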
@ -1,36 +1,42 @@
import psutil
from enum import Enum
from cli_args import args
CPU = 0 class VRAMState(Enum):
NO_VRAM = 1 CPU = 0
LOW_VRAM = 2 NO_VRAM = 1
NORMAL_VRAM = 3 LOW_VRAM = 2
HIGH_VRAM = 4 NORMAL_VRAM = 3
MPS = 5 HIGH_VRAM = 4
MPS = 5
accelerate_enabled = False # Determine VRAM State
vram_state = NORMAL_VRAM vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
total_vram = 0 total_vram = 0
total_vram_available_mb = -1 total_vram_available_mb = -1
import sys accelerate_enabled = False
import psutil xpu_available = False
forced_cpu = "--cpu" in sys.argv
set_vram_to = NORMAL_VRAM
try: try:
import torch import torch
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
forced_normal_vram = "--normalvram" in sys.argv if not args.normalvram and not args.cpu:
if not forced_normal_vram and not forced_cpu:
if total_vram <= 4096: if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336: elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = HIGH_VRAM vram_state = VRAMState.HIGH_VRAM
except: except:
pass pass
@ -39,34 +45,37 @@ try:
except: except:
OOM_EXCEPTION = Exception OOM_EXCEPTION = Exception
if "--disable-xformers" in sys.argv: if args.disable_xformers:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILABLE = False
else: else:
try: try:
import xformers import xformers
import xformers.ops import xformers.ops
XFORMERS_IS_AVAILBLE = True XFORMERS_IS_AVAILABLE = True
except: except:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILABLE = False
ENABLE_PYTORCH_ATTENTION = False ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
if "--use-pytorch-cross-attention" in sys.argv: if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False
XFORMERS_IS_AVAILBLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False
if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if "--lowvram" in sys.argv: if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
vram_state = HIGH_VRAM
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
import accelerate import accelerate
accelerate_enabled = True accelerate_enabled = True
@ -81,14 +90,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try: try:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
vram_state = MPS vram_state = VRAMState.MPS
except: except:
pass pass
if forced_cpu: if args.cpu:
vram_state = CPU vram_state = VRAMState.CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) print(f"Set vram state to: {vram_state.name}")
current_loaded_model = None current_loaded_model = None
@ -109,12 +118,12 @@ def unload_model():
model_accelerated = False model_accelerated = False
#never unload models from GPU on high vram #never unload models from GPU on high vram
if vram_state != HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
current_loaded_model.model.cpu() current_loaded_model.model.cpu()
current_loaded_model.unpatch_model() current_loaded_model.unpatch_model()
current_loaded_model = None current_loaded_model = None
if vram_state != HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
if len(current_gpu_controlnets) > 0: if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets: for n in current_gpu_controlnets:
n.cpu() n.cpu()
@ -135,32 +144,32 @@ def load_model_gpu(model):
model.unpatch_model() model.unpatch_model()
raise e raise e
current_loaded_model = model current_loaded_model = model
if vram_state == CPU: if vram_state == VRAMState.CPU:
pass pass
elif vram_state == MPS: elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps") mps_device = torch.device("mps")
real_model.to(mps_device) real_model.to(mps_device)
pass pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False model_accelerated = False
real_model.cuda() real_model.to(get_torch_device())
else: else:
if vram_state == NO_VRAM: if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == LOW_VRAM: elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
model_accelerated = True model_accelerated = True
return current_loaded_model return current_loaded_model
def load_controlnet_gpu(models): def load_controlnet_gpu(models):
global current_gpu_controlnets global current_gpu_controlnets
global vram_state global vram_state
if vram_state == CPU: if vram_state == VRAMState.CPU:
return return
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return return
@ -176,38 +185,57 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model): def load_if_low_vram(model):
global vram_state global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cuda() return model.to(get_torch_device())
return model return model
def unload_if_low_vram(model): def unload_if_low_vram(model):
global vram_state global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM: if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu() return model.cpu()
return model return model
def get_torch_device(): def get_torch_device():
if vram_state == MPS: global xpu_available
if vram_state == VRAMState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == CPU: if vram_state == VRAMState.CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
return torch.cuda.current_device() if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
return dev.type return dev.type
return "cuda" return "cuda"
def xformers_enabled(): def xformers_enabled():
if vram_state == CPU: if vram_state == VRAMState.CPU:
return False return False
return XFORMERS_IS_AVAILBLE return XFORMERS_IS_AVAILABLE
def xformers_enabled_vae():
enabled = xformers_enabled()
if not enabled:
return False
try:
#0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
if xformers.version.__version__ == "0.0.18":
return False
except:
pass
return enabled
def pytorch_attention_enabled(): def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -215,12 +243,16 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
stats = torch.cuda.memory_stats(dev) if xpu_available:
mem_active = stats['active_bytes.all.current'] mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_free_torch = mem_free_total
mem_free_cuda, _ = torch.cuda.mem_get_info(dev) else:
mem_free_torch = mem_reserved - mem_active stats = torch.cuda.memory_stats(dev)
mem_free_total = mem_free_cuda + mem_free_torch mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
if torch_free_too: if torch_free_too:
return (mem_free_total, mem_free_torch) return (mem_free_total, mem_free_torch)
@ -229,7 +261,7 @@ def get_free_memory(dev=None, torch_free_too=False):
def maximum_batch_area(): def maximum_batch_area():
global vram_state global vram_state
if vram_state == NO_VRAM: if vram_state == VRAMState.NO_VRAM:
return 0 return 0
memory_free = get_free_memory() / (1024 * 1024) memory_free = get_free_memory() / (1024 * 1024)
@ -238,14 +270,18 @@ def maximum_batch_area():
def cpu_mode(): def cpu_mode():
global vram_state global vram_state
return vram_state == CPU return vram_state == VRAMState.CPU
def mps_mode(): def mps_mode():
global vram_state global vram_state
return vram_state == MPS return vram_state == VRAMState.MPS
def should_use_fp16(): def should_use_fp16():
if cpu_mode() or mps_mode(): global xpu_available
if FORCE_FP32:
return False
if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ? return False #TODO ?
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns predicted noise #Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'strength' in cond[1]: if 'strength' in cond[1]:
strength = cond[1]['strength'] strength = cond[1]['strength']
adm_cond = None
if 'adm' in cond[1]:
adm_cond = cond[1]['adm']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength mult = torch.ones_like(input_x) * strength
@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cropped.append(cr) cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1) conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
control = None control = None
if 'control' in cond[1]: if 'control' in cond[1]:
control = cond[1]['control'] control = cond[1]['control']
@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if 'c_concat' in c1: if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape: if c1['c_concat'].shape != c2['c_concat'].shape:
return False return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
return False
return True return True
def can_concat_cond(c1, c2): def can_concat_cond(c1, c2):
@ -92,19 +102,24 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def cond_cat(c_list): def cond_cat(c_list):
c_crossattn = [] c_crossattn = []
c_concat = [] c_concat = []
c_adm = []
for x in c_list: for x in c_list:
if 'c_crossattn' in x: if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn']) c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x: if 'c_concat' in x:
c_concat.append(x['c_concat']) c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {} out = {}
if len(c_crossattn) > 0: if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)] out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0: if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)] out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
return out return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.ones_like(x_in)/100000.0
@ -169,6 +184,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None: if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
if 'transformer_options' in model_options:
c['transformer_options'] = model_options['transformer_options']
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x del input_x
@ -192,7 +210,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area() max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
@ -209,8 +227,8 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
return out return out
@ -218,11 +236,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
if denoise_mask is not None: if denoise_mask is not None:
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
if denoise_mask is not None: if denoise_mask is not None:
out *= denoise_mask out *= denoise_mask
@ -324,13 +342,46 @@ def apply_control_net_to_equal_area(conds, uncond):
n['control'] = cond_cnets[x] n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n] uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm"] = torch.cat([adm_out] * batch_size)
return conds
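Illustrative only: the `adm` entries consumed here are the `(clip_vision_output, strength, noise_augmentation)` tuples appended by the unCLIPConditioning node further down in this diff, and `noise_augmentation` maps onto the augmentor's discrete noise levels like this (1000 levels is an assumption about the checkpoint):

```python
max_noise_level = 1000  # assumed; in practice taken from the checkpoint's noise_augmentor
for noise_augment in (0.0, 0.05, 0.5, 1.0):
    noise_level = round((max_noise_level - 1) * noise_augment)
    print(f"noise_augmentation={noise_augment} -> noise_level={noise_level}")  # 0, 50, 500, 999
```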
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]

-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
@ -350,6 +401,7 @@ class KSampler:
         self.sigma_max=float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)
         self.denoise = denoise
+        self.model_options = model_options

     def _calculate_sigmas(self, steps):
         sigmas = None
@ -418,10 +470,14 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext

-        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+        if hasattr(self.model, 'noise_augmentor'): #unclip
+            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
+            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None
-        if hasattr(self.model, 'concat_keys'):
+        if hasattr(self.model, 'concat_keys'): #inpaint
             cond_concat = []
             for ck in self.model.concat_keys:
                 if denoise_mask is not None:
@ -467,7 +523,7 @@ class KSampler:
                     x_T=z_enc,
                     x0=latent_image,
                     denoise_function=sampling_function,
-                    cond_concat=cond_concat,
+                    extra_args=extra_args,
                     mask=noise_mask,
                     to_zero=sigmas[-1]==0,
                     end_step=sigmas.shape[0] - 1)
@ -1,5 +1,6 @@
import torch import torch
import contextlib import contextlib
import copy
import sd1_clip import sd1_clip
import sd2_clip import sd2_clip
@ -11,20 +12,7 @@ from .cldm import cldm
from .t2i_adapter import adapter from .t2i_adapter import adapter
from . import utils from . import utils
from . import clip_vision
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
return sd
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
@ -52,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if x in sd: if x in sd:
sd[keys_to_replace[x]] = sd.pop(x) sd[keys_to_replace[x]] = sd.pop(x)
resblock_to_replace = { sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
"mlp.c_fc": "mlp.fc1",
"mlp.c_proj": "mlp.fc2",
"attn.out_proj": "self_attn.out_proj",
}
for resblock in range(24):
for x in resblock_to_replace:
for y in ["weight", "bias"]:
k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
if k in sd:
sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]:
k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
if k_from in sd:
weights = sd.pop(k_from)
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
sd[k_to] = weights[1024*x:1024*(x + 1)]
for x in load_state_dict_to: for x in load_state_dict_to:
x.load_state_dict(sd, strict=False) x.load_state_dict(sd, strict=False)
@ -122,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
} }
def load_lora(path, to_load): def load_lora(path, to_load):
lora = load_torch_file(path) lora = utils.load_torch_file(path)
patch_dict = {} patch_dict = {}
loaded_keys = set() loaded_keys = set()
for x in to_load: for x in to_load:
@ -274,12 +239,20 @@ class ModelPatcher:
self.model = model self.model = model
self.patches = [] self.patches = []
self.backup = {} self.backup = {}
self.model_options = {"transformer_options":{}}
def clone(self): def clone(self):
n = ModelPatcher(self.model) n = ModelPatcher(self.model)
n.patches = self.patches[:] n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
return n return n
def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def model_dtype(self):
return self.model.diffusion_model.dtype
def add_patches(self, patches, strength=1.0): def add_patches(self, patches, strength=1.0):
p = {} p = {}
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
@ -590,7 +563,7 @@ class ControlNet:
return out return out
def load_controlnet(ckpt_path, model=None): def load_controlnet(ckpt_path, model=None):
controlnet_data = load_torch_file(ckpt_path) controlnet_data = utils.load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False pth = False
sd2 = False sd2 = False
@ -784,7 +757,7 @@ class StyleModel:
def load_style_model(ckpt_path): def load_style_model(ckpt_path):
model_data = load_torch_file(ckpt_path) model_data = utils.load_torch_file(ckpt_path)
keys = model_data.keys() keys = model_data.keys()
if "style_embedding" in keys: if "style_embedding" in keys:
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@ -795,7 +768,7 @@ def load_style_model(ckpt_path):
def load_clip(ckpt_path, embedding_directory=None): def load_clip(ckpt_path, embedding_directory=None):
clip_data = load_torch_file(ckpt_path) clip_data = utils.load_torch_file(ckpt_path)
config = {} config = {}
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@ -838,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
load_state_dict_to = [w] load_state_dict_to = [w]
model = instantiate_from_config(config["model"]) model = instantiate_from_config(config["model"])
sd = load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16: if fp16:
@ -847,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
return (ModelPatcher(model), clip, vae) return (ModelPatcher(model), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
clip = None clip = None
clipvision = None
vae = None vae = None
fp16 = model_management.should_use_fp16() fp16 = model_management.should_use_fp16()
@ -875,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w] load_state_dict_to = [w]
clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
noise_aug_config = None
if clipvision_key in sd_keys:
size = sd[clipvision_key].shape[1]
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd)
noise_aug_key = "noise_augmentor.betas"
if noise_aug_key in sd_keys:
noise_aug_config = {}
params = {}
noise_schedule_config = {}
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
params["noise_schedule_config"] = noise_schedule_config
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
if size == 1280: #h
params["timestep_dim"] = 1024
elif size == 1024: #l
params["timestep_dim"] = 768
noise_aug_config['params'] = params
sd_config = { sd_config = {
"linear_start": 0.00085, "linear_start": 0.00085,
"linear_end": 0.012, "linear_end": 0.012,
@ -923,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
if unet_config["in_channels"] > 4: #inpainting model if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid" sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None sd_config["finetune_keys"] = None
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
@ -935,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
else: else:
unet_config["num_heads"] = 8 #SD1.x unet_config["num_heads"] = 8 #SD1.x
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k] out = sd[k]
@ -947,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
if fp16: if fp16:
model = model.half() model = model.half()
return (ModelPatcher(model), clip, vae) return (ModelPatcher(model), clip, vae, clipvision)
@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 if isinstance(y, int):
                     tokens_temp += [y]
                 else:
-                    embedding_weights += [y]
-                    tokens_temp += [next_new_token]
-                    next_new_token += 1
+                    if y.shape[0] == current_embeds.weight.shape[1]:
+                        embedding_weights += [y]
+                        tokens_temp += [next_new_token]
+                        next_new_token += 1
+                    else:
+                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             out_tokens += [tokens_temp]

         if len(embedding_weights) > 0:
@ -1,5 +1,47 @@
 import torch
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
return sd
def transformers_convert(sd, prefix_from, prefix_to, number):
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
"mlp.c_fc": "mlp.fc1",
"mlp.c_proj": "mlp.fc2",
"attn.out_proj": "self_attn.out_proj",
}
for resblock in range(number):
for x in resblock_to_replace:
for y in ["weight", "bias"]:
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
if k in sd:
sd[k_to] = sd.pop(k)
for y in ["weight", "bias"]:
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
if k_from in sd:
weights = sd.pop(k_from)
shape_from = weights.shape[0] // 3
for x in range(3):
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
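A small sketch of what transformers_convert does to a state dict, assuming the function is importable as comfy.utils.transformers_convert (the module name other files in this diff use); the tensors are dummies:

```python
import torch
from comfy.utils import transformers_convert  # module path as referenced elsewhere in this diff

sd = {
    "cond_stage_model.model.transformer.resblocks.0.ln_1.weight": torch.zeros(1024),
    "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(3 * 1024, 1024),
}
sd = transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 1)
# ln_1 is renamed to layer_norm1 and the fused in_proj weight is split into q/k/v projections:
for k, v in sorted(sd.items()):
    print(k, tuple(v.shape))
```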
 def common_upscale(samples, width, height, upscale_method, crop):
     if crop == "center":
         old_width = samples.shape[3]
@ -1,32 +0,0 @@
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from comfy.sd import load_torch_file
import os
class ClipVisionModel():
def __init__(self):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
config = CLIPVisionConfig.from_json_file(json_config)
self.model = CLIPVisionModel(config)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)
def load_sd(self, sd):
self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
inputs = self.processor(images=[image[0]], return_tensors="pt")
outputs = self.model(**inputs)
return outputs
def load(ckpt_path):
clip_data = load_torch_file(ckpt_path)
clip = ClipVisionModel()
clip.load_sd(clip_data)
return clip
@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import comfy.utils
class Blend:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"image2": ("IMAGE",),
"blend_factor": ("FLOAT", {
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images"
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,)
def blend_mode(self, img1, img2, mode):
if mode == "normal":
return img2
elif mode == "multiply":
return img1 * img2
elif mode == "screen":
return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
else:
raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
class Blur:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"blur_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur"
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
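Blend and Blur above (and the nodes that follow) all use the same tensor convention, which the permute comments hint at: IMAGE tensors are float, shaped [batch, height, width, channels], with values in [0, 1]. A self-contained reminder of that layout, with illustrative sizes:

```python
import numpy as np
import torch
from PIL import Image

pil = Image.new("RGB", (64, 64), (255, 0, 0))
image = torch.from_numpy(np.array(pil).astype(np.float32) / 255.0)[None,]   # -> [1, 64, 64, 3]

# Back to PIL, e.g. after one of these nodes has processed the batch:
out = Image.fromarray((image[0].numpy() * 255).astype(np.uint8))
print(image.shape, out.size)
```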
class Quantize:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"colors": ("INT", {
"default": 256,
"min": 1,
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize"
CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array
return (result,)
class Sharpen:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"sharpen_radius": ("INT", {
"default": 1,
"min": 1,
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 5.0,
"step": 0.1
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen"
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)
return (result,)
NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
"ImageQuantize": Quantize,
"ImageSharpen": Sharpen,
}
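A simplified, hypothetical node ("ToyInvert") showing how the pieces above fit together: NODE_CLASS_MAPPINGS registers the class, INPUT_TYPES declares its inputs, and the method named by FUNCTION does the work. The real executor additionally validates inputs and wires node outputs to inputs:

```python
import torch

class ToyInvert:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "image/postprocessing"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    def invert(self, image: torch.Tensor):
        return (1.0 - image,)

NODE_CLASS_MAPPINGS = {"ToyInvert": ToyInvert}

node_cls = NODE_CLASS_MAPPINGS["ToyInvert"]
node = node_cls()
out, = getattr(node, node_cls.FUNCTION)(image=torch.rand(1, 64, 64, 3))
print(out.shape)
```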
@ -1,6 +1,5 @@
 import os
 from comfy_extras.chainner_models import model_loading
-from comfy.sd import load_torch_file
 import model_management
 import torch
 import comfy.utils
@ -18,7 +17,7 @@ class UpscaleModelLoader:
     def load_model(self, model_name):
         model_path = folder_paths.get_full_path("upscale_models", model_name)
-        sd = load_torch_file(model_path)
+        sd = comfy.utils.load_torch_file(model_path)
         out = model_loading.load_state_dict(sd).eval()
         return (out, )
@ -11,6 +11,8 @@ class Example:
     ----------
     RETURN_TYPES (`tuple`):
         The type of each element in the output tuple.
+    RETURN_NAMES (`tuple`):
+        Optional: The name of each output in the output tuple.
     FUNCTION (`str`):
         The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
     OUTPUT_NODE ([`bool`]):
@ -61,6 +63,8 @@ class Example:
     }
     RETURN_TYPES = ("IMAGE",)
+    #RETURN_NAMES = ("image_output_name",)
     FUNCTION = "test"
     #OUTPUT_NODE = False
@ -24,10 +24,45 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
if not os.path.exists(input_directory):
os.makedirs(input_directory)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def get_output_directory():
global output_directory
return output_directory
def get_temp_directory():
global temp_directory
return temp_directory
def get_input_directory():
global input_directory
return input_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
if type_name == "output":
return get_output_directory()
if type_name == "temp":
return get_temp_directory()
if type_name == "input":
return get_input_directory()
return None
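A brief usage sketch of the new folder_paths helpers, assuming it is run from the ComfyUI root so the top-level module is importable; the output path is illustrative:

```python
import folder_paths  # top-level module, as imported elsewhere in this diff

folder_paths.set_output_directory("/tmp/comfy-out")    # hypothetical path
print(folder_paths.get_output_directory())
print(folder_paths.get_directory_by_type("input"))
print(folder_paths.get_directory_by_type("unknown"))   # None for unrecognized types
```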
 def add_model_folder_path(folder_name, full_folder_path):
     global folder_names_and_paths
main.py
@ -1,49 +1,32 @@
-import os
-import sys
-import shutil
-import threading
 import asyncio
+import itertools
+import os
+import shutil
+import threading
+
+from comfy.cli_args import args

 if os.name == "nt":
     import logging
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

 if __name__ == "__main__":
-    if '--help' in sys.argv:
-        print("Valid Command line Arguments:")
-        print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
-        print("\t--port 8188\t\t\tSet the listen port.")
-        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
-        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
-        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
-        print("\t--disable-xformers\t\tdisables xformers")
-        print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
-        print()
-        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
-        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
-        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
-        print("\t--novram\t\t\tWhen lowvram isn't enough.")
-        print()
-        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
-        exit()
-
-    if '--dont-upcast-attention' in sys.argv:
+    if args.dont_upcast_attention:
         print("disabling upcasting of attention")
         os.environ['ATTN_PRECISION'] = "fp16"

-    try:
-        index = sys.argv.index('--cuda-device')
-        device = sys.argv[index + 1]
-        os.environ['CUDA_VISIBLE_DEVICES'] = device
-        print("Set cuda device to:", device)
-    except:
-        pass
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        print("Set cuda device to:", args.cuda_device)
+
+import yaml

 import execution
-import server
 import folder_paths
-import yaml
+import server
+from nodes import init_custom_nodes

 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
@ -98,47 +81,36 @@ if __name__ == "__main__":
     server = server.PromptServer(loop)
     q = execution.PromptQueue(server)

+    init_custom_nodes()
+    server.add_routes()
     hijack_progress(server)

     threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
-    try:
-        address = '0.0.0.0'
-        p_index = sys.argv.index('--listen')
-        try:
-            ip = sys.argv[p_index + 1]
-            if ip[:2] != '--':
-                address = ip
-        except:
-            pass
-    except:
-        address = '127.0.0.1'

-    dont_print = False
-    if '--dont-print-server' in sys.argv:
-        dont_print = True
+    address = args.listen
+    dont_print = args.dont_print_server

     extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
     if os.path.isfile(extra_model_paths_config_path):
         load_extra_path_config(extra_model_paths_config_path)

-    if '--extra-model-paths-config' in sys.argv:
-        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
-        for i in indices:
-            load_extra_path_config(sys.argv[i])
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)

-    port = 8188
-    try:
-        p_index = sys.argv.index('--port')
-        port = int(sys.argv[p_index + 1])
-    except:
-        pass
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        print(f"Setting output directory to: {output_dir}")
+        folder_paths.set_output_directory(output_dir)

-    if '--quick-test-for-ci' in sys.argv:
+    port = args.port
+
+    if args.quick_test_for_ci:
         exit(0)

     call_on_start = None
-    if "--windows-standalone-build" in sys.argv:
+    if args.windows_standalone_build:
         def startup_server(address, port):
             import webbrowser
             webbrowser.open("http://{}:{}".format(address, port))
nodes.py
@ -4,21 +4,22 @@ import os
 import sys
 import json
 import hashlib
+import copy
 import traceback

 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 import numpy as np

 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))

+import comfy.diffusers_convert
 import comfy.samplers
 import comfy.sd
 import comfy.utils
-import comfy_extras.clip_vision
+import comfy.clip_vision
 import model_management
 import importlib
@ -197,7 +198,7 @@ class CheckpointLoader:
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
         config_path = folder_paths.get_full_path("configs", config_name)
@ -219,6 +220,45 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out
class DiffusersLoader:
    @classmethod
    def INPUT_TYPES(cls):
        paths = []
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if os.path.exists(search_path):
                paths += next(os.walk(search_path))[1]
        return {"required": {"model_path": (paths,), }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "advanced/loaders"

    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if os.path.exists(search_path):
                paths = next(os.walk(search_path))[1]
                if model_path in paths:
                    model_path = os.path.join(search_path, model_path)
                    break

        return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))

class unCLIPCheckpointLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                              }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return out
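DiffusersLoader discovers models by listing the immediate subdirectories of each configured "diffusers" search path. A tiny standalone illustration of the `next(os.walk(path))[1]` idiom it relies on (the folder names below are made up):

```python
# next(os.walk(path))[1] yields only the immediate subdirectory names of `path`,
# each of which is expected to be a diffusers-format model folder.
import os
import tempfile

root = tempfile.mkdtemp()
for name in ("stable-diffusion-v1-5", "my-finetune"):
    os.makedirs(os.path.join(root, name, "unet"))

subdirs = next(os.walk(root))[1]
print(sorted(subdirs))  # ['my-finetune', 'stable-diffusion-v1-5']
```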
class CLIPSetLastLayer:
    @classmethod
    def INPUT_TYPES(s):
@@ -254,6 +294,22 @@ class LoraLoader:
        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
        return (model_lora, clip_lora)
class TomePatchModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, ratio):
        m = model.clone()
        m.set_model_tomesd(ratio)
        return (m, )
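TomePatchModel never mutates the incoming model: it clones it and sets the ToMe ratio on the clone, so the original stays reusable and changing the ratio only re-executes downstream nodes. A rough sketch of that clone-then-patch contract using a dummy object (comfy's real model patcher is more involved; this class is purely illustrative):

```python
# DummyPatchable stands in for the patchable model object the node receives.
import copy

class DummyPatchable:
    def __init__(self):
        self.patches = {}

    def clone(self):
        c = DummyPatchable()
        c.patches = copy.copy(self.patches)  # shallow copy, as a patcher would
        return c

    def set_model_tomesd(self, ratio):
        self.patches["tomesd_ratio"] = ratio

base = DummyPatchable()
patched = base.clone()
patched.set_model_tomesd(0.5)

print(base.patches)     # {} -- the original object is left untouched
print(patched.patches)  # {'tomesd_ratio': 0.5}
```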
class VAELoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -354,7 +410,7 @@ class CLIPVisionLoader:
    def load_clip(self, clip_name):
        clip_path = folder_paths.get_full_path("clip_vision", clip_name)
-        clip_vision = comfy_extras.clip_vision.load(clip_path)
+        clip_vision = comfy.clip_vision.load(clip_path)
        return (clip_vision,)
class CLIPVisionEncode:
@@ -366,7 +422,7 @@ class CLIPVisionEncode:
    RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
    FUNCTION = "encode"

-    CATEGORY = "conditioning/style_model"
+    CATEGORY = "conditioning"

    def encode(self, clip_vision, image):
        output = clip_vision.encode_image(image)
@@ -408,6 +464,33 @@ class StyleModelApply:
        c.append(n)
        return (c, )
class unCLIPConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_adm"

    CATEGORY = "conditioning"

    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
        c = []
        for t in conditioning:
            o = t[1].copy()
            x = (clip_vision_output, strength, noise_augmentation)
            if "adm" in o:
                o["adm"] = o["adm"][:] + [x]
            else:
                o["adm"] = [x]
            n = [t[0], o]
            c.append(n)
        return (c, )
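apply_adm copies each conditioning dict and appends a `(clip_vision_output, strength, noise_augmentation)` tuple under the "adm" key, so chaining two unCLIPConditioning nodes stacks two image embeds instead of overwriting one. A standalone rerun of that accumulation with placeholder values:

```python
# Strings stand in for real tensors / CLIP vision outputs; the logic mirrors
# the apply_adm method above.
def apply_adm(conditioning, clip_vision_output, strength, noise_augmentation):
    c = []
    for t in conditioning:
        o = t[1].copy()
        x = (clip_vision_output, strength, noise_augmentation)
        o["adm"] = o["adm"][:] + [x] if "adm" in o else [x]
        c.append([t[0], o])
    return c

cond = [["cond_tensor", {"pooled_output": "pooled"}]]
cond = apply_adm(cond, "image_embed_a", 1.0, 0.0)
cond = apply_adm(cond, "image_embed_b", 0.5, 0.1)
print(cond[0][1]["adm"])
# [('image_embed_a', 1.0, 0.0), ('image_embed_b', 0.5, 0.1)]
```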
class EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device
@@ -646,7 +729,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
        model_management.load_controlnet_gpu(control_net_models)

    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
    else:
        #other samplers
        pass
@@ -719,7 +802,7 @@ class KSamplerAdvanced:

class SaveImage:
    def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
@@ -771,9 +854,6 @@ class SaveImage:
        os.makedirs(full_output_folder, exist_ok=True)
        counter = 1

-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir)
-
        results = list()
        for image in images:
            i = 255. * image.cpu().numpy()
@@ -798,7 +878,7 @@ class SaveImage:

class PreviewImage(SaveImage):
    def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"

    @classmethod
@@ -809,13 +889,11 @@ class PreviewImage(SaveImage):
        }

class LoadImage:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
-        if not os.path.exists(s.input_dir):
-            os.makedirs(s.input_dir)
+        input_dir = folder_paths.get_input_directory()
        return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), )},
+                    {"image": (sorted(os.listdir(input_dir)), )},
                }

    CATEGORY = "image"
@@ -823,7 +901,8 @@ class LoadImage:
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"
    def load_image(self, image):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        i = Image.open(image_path)
        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
@@ -837,18 +916,19 @@ class LoadImage:
    @classmethod
    def IS_CHANGED(s, image):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()
class LoadImageMask:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
        return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), ),
+                    {"image": (sorted(os.listdir(input_dir)), ),
                     "channel": (["alpha", "red", "green", "blue"], ),}
                }
@@ -857,7 +937,8 @@ class LoadImageMask:
    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"
    def load_image(self, image, channel):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        i = Image.open(image_path)
        mask = None
        c = channel[0].upper()
@@ -872,7 +953,8 @@ class LoadImageMask:
    @classmethod
    def IS_CHANGED(s, image, channel):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
@@ -980,7 +1062,6 @@ class ImagePadForOutpaint:

NODE_CLASS_MAPPINGS = {
    "KSampler": KSampler,
-    "CheckpointLoader": CheckpointLoader,
    "CheckpointLoaderSimple": CheckpointLoaderSimple,
    "CLIPTextEncode": CLIPTextEncode,
    "CLIPSetLastLayer": CLIPSetLastLayer,
@@ -1009,6 +1090,7 @@ NODE_CLASS_MAPPINGS = {
    "CLIPLoader": CLIPLoader,
    "CLIPVisionEncode": CLIPVisionEncode,
    "StyleModelApply": StyleModelApply,
+    "unCLIPConditioning": unCLIPConditioning,
    "ControlNetApply": ControlNetApply,
    "ControlNetLoader": ControlNetLoader,
    "DiffControlNetLoader": DiffControlNetLoader,
@@ -1016,6 +1098,10 @@ NODE_CLASS_MAPPINGS = {
    "CLIPVisionLoader": CLIPVisionLoader,
    "VAEDecodeTiled": VAEDecodeTiled,
    "VAEEncodeTiled": VAEEncodeTiled,
+    "TomePatchModel": TomePatchModel,
+    "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+    "CheckpointLoader": CheckpointLoader,
+    "DiffusersLoader": DiffusersLoader,
}
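NODE_CLASS_MAPPINGS is the registry the rest of the system reads, which is why the new loaders only need an entry here to show up in the UI. A hypothetical custom node following the same contract (the class, mapping key, and file name are invented; custom node modules are expected to expose their own NODE_CLASS_MAPPINGS dict for load_custom_node to pick up):

```python
# Hypothetical custom_nodes/example_invert.py
class InvertImageExample:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "image"

    def invert(self, image):
        # IMAGE tensors are float batches in [0, 1], so inversion is 1.0 - x.
        return (1.0 - image,)

NODE_CLASS_MAPPINGS = {
    "InvertImageExample": InvertImageExample,
}
```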
def load_custom_node(module_path):
@@ -1050,7 +1136,8 @@ def load_custom_nodes():
        if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
        load_custom_node(module_path)

-load_custom_nodes()
-load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
-load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "silver_custom.py"))
+def init_custom_nodes():
+    load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "silver_custom.py"))
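The SaveImage/PreviewImage/LoadImage changes above all converge on one pattern: resolve directories through folder_paths at call time (so --output-directory and the temp/input overrides take effect) and key IS_CHANGED on file content rather than file name. A small self-contained sketch of both halves, with a stand-in folder_paths namespace (only the getters used in this diff are stubbed):

```python
import hashlib
import os
import tempfile
from types import SimpleNamespace

# Stand-in for the real folder_paths module.
_base = tempfile.mkdtemp()
folder_paths = SimpleNamespace(
    get_output_directory=lambda: os.path.join(_base, "output"),
    get_input_directory=lambda: os.path.join(_base, "input"),
)

def output_path_for(filename):
    # Ask folder_paths on every call instead of caching a hard-coded path at import time.
    out_dir = folder_paths.get_output_directory()
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, filename)

def content_fingerprint(image_name):
    # Same idea as LoadImage.IS_CHANGED: hashing the bytes means editing a file
    # in place (same name, new pixels) still invalidates the cached node output.
    image_path = os.path.join(folder_paths.get_input_directory(), image_name)
    m = hashlib.sha256()
    with open(image_path, "rb") as f:
        m.update(f.read())
    return m.digest().hex()

os.makedirs(folder_paths.get_input_directory(), exist_ok=True)
with open(os.path.join(folder_paths.get_input_directory(), "example.png"), "wb") as f:
    f.write(b"not really a png")

print(output_path_for("ComfyUI_00001_.png"))
print(content_fingerprint("example.png"))
```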

View File

@@ -47,7 +47,7 @@
        " !git pull\n",
        "\n",
        "!echo -= Install dependencies =-\n",
-       "!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117"
+       "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
       ]
      },
      {

View File

@@ -4,7 +4,7 @@ torchsde
einops
open-clip-torch
transformers>=4.25.1
-safetensors
+safetensors>=0.3.0
pytorch_lightning
aiohttp
accelerate

View File

@@ -18,6 +18,7 @@ except ImportError:
    sys.exit()

import mimetypes
+from comfy.cli_args import args

@web.middleware
@@ -27,6 +28,23 @@ async def cache_control(request: web.Request, handler):
    response.headers.setdefault('Cache-Control', 'no-cache')
    return response
def create_cors_middleware(allowed_origin: str):
    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method == "OPTIONS":
            # Pre-flight request. Reply successfully:
            response = web.Response()
        else:
            response = await handler(request)

        response.headers['Access-Control-Allow-Origin'] = allowed_origin
        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        return response

    return cors_middleware
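The middleware answers OPTIONS pre-flights itself and stamps the CORS headers onto every other response. A rough client-side check (assumes a local instance started with `--enable-cors-header`; the path is irrelevant for the pre-flight because the middleware short-circuits before routing):

```python
# Sketch only: probe a locally running server's pre-flight behaviour.
import urllib.request

req = urllib.request.Request("http://127.0.0.1:8188/prompt", method="OPTIONS")
with urllib.request.urlopen(req, timeout=5) as resp:
    print(resp.status)  # 200 -- answered directly by cors_middleware
    print(resp.headers.get("Access-Control-Allow-Origin"))
    print(resp.headers.get("Access-Control-Allow-Methods"))
```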
class PromptServer():
    def __init__(self, loop):
        PromptServer.instance = self
@@ -37,11 +55,17 @@ class PromptServer():
        self.loop = loop
        self.messages = asyncio.Queue()
        self.number = 0

-        self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
+        middlewares = [cache_control]
+        if args.enable_cors_header:
+            middlewares.append(create_cors_middleware(args.enable_cors_header))
+
+        self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
+        self.routes = routes
        self.last_node_id = None
        self.client_id = None
@@ -88,7 +112,7 @@ class PromptServer():
        @routes.post("/upload/image")
        async def upload_image(request):
-            upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+            upload_dir = folder_paths.get_input_directory()

            if not os.path.exists(upload_dir):
                os.makedirs(upload_dir)
@@ -130,10 +154,10 @@ class PromptServer():
        async def view_image(request):
            if "filename" in request.rel_url.query:
                type = request.rel_url.query.get("type", "output")
-                if type not in ["output", "input", "temp"]:
+                output_dir = folder_paths.get_directory_by_type(type)
+                if output_dir is None:
                    return web.Response(status=400)
-                output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)

                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
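The subfolder handling is a containment check: after joining the requested subfolder onto the directory returned by get_directory_by_type, the absolute result must still share that directory as its common path, otherwise the request is rejected. A standalone illustration with made-up paths:

```python
import os

def is_inside(base_dir, requested_subfolder):
    full = os.path.abspath(os.path.join(base_dir, requested_subfolder))
    return os.path.commonpath((full, base_dir)) == base_dir

base = os.path.abspath("output")
print(is_inside(base, "batch_01"))    # True  -> request may proceed
print(is_inside(base, "../secrets"))  # False -> request is rejected (400)
```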
@@ -271,8 +295,9 @@ class PromptServer():
            self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

-        self.app.add_routes(routes)
+    def add_routes(self):
+        self.app.add_routes(self.routes)
        self.app.add_routes([
            web.static('/', self.web_root),
        ])

View File

@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
}
return ctx;
};
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});

View File

@@ -30,7 +30,8 @@ app.registerExtension({
		}

		// Overwrite the value in the serialized workflow pnginfo
-		workflowNode.widgets_values[widgetIndex] = prompt;
+		if (workflowNode?.widgets_values)
+			workflowNode.widgets_values[widgetIndex] = prompt;

		return prompt;
	};

View File

@@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
// Inverts the scrolling of context menus

const id = "Comfy.InvertMenuScrolling";
-const ctxMenu = LiteGraph.ContextMenu;
app.registerExtension({
	name: id,
	init() {
+		const ctxMenu = LiteGraph.ContextMenu;
		const replace = () => {
			LiteGraph.ContextMenu = function (values, options) {
				options = options || {};

View File

@ -11,11 +11,14 @@ app.registerExtension({
this.properties = {}; this.properties = {};
} }
this.properties.showOutputText = RerouteNode.defaultVisibility; this.properties.showOutputText = RerouteNode.defaultVisibility;
this.properties.horizontal = false;
this.addInput("", "*"); this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onConnectionsChange = function (type, index, connected, link_info) { this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
// Prevent multiple connections to different types when we have no input // Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) { if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types // Ignore wildcard nodes as these will be updated to real types
@ -43,12 +46,19 @@ app.registerExtension({
const node = app.graph.getNodeById(link.origin_id); const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type; const type = node.constructor.type;
if (type === "Reroute") { if (type === "Reroute") {
// Move the previous node if (node === this) {
currentNode = node; // We've found a circle
currentNode.disconnectInput(link.target_slot);
currentNode = null;
}
else {
// Move the previous node
currentNode = node;
}
} else { } else {
// We've found the end // We've found the end
inputNode = currentNode; inputNode = currentNode;
inputType = node.outputs[link.origin_slot].type; inputType = node.outputs[link.origin_slot]?.type ?? null;
break; break;
} }
} else { } else {
@ -80,7 +90,7 @@ app.registerExtension({
updateNodes.push(node); updateNodes.push(node);
} else { } else {
// We've found an output // We've found an output
const nodeOutType = node.inputs[link.target_slot].type; const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) { if (inputType && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it // The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot); node.disconnectInput(link.target_slot);
@ -105,6 +115,7 @@ app.registerExtension({
node.__outputType = displayType; node.__outputType = displayType;
node.outputs[0].name = node.properties.showOutputText ? displayType : ""; node.outputs[0].name = node.properties.showOutputText ? displayType : "";
node.size = node.computeSize(); node.size = node.computeSize();
node.applyOrientation();
for (const l of node.outputs[0].links || []) { for (const l of node.outputs[0].links || []) {
const link = app.graph.links[l]; const link = app.graph.links[l];
@ -146,6 +157,7 @@ app.registerExtension({
this.outputs[0].name = ""; this.outputs[0].name = "";
} }
this.size = this.computeSize(); this.size = this.computeSize();
this.applyOrientation();
app.graph.setDirtyCanvas(true, true); app.graph.setDirtyCanvas(true, true);
}, },
}, },
@ -154,9 +166,32 @@ app.registerExtension({
callback: () => { callback: () => {
RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility); RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility);
}, },
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
this.properties.horizontal = !this.properties.horizontal;
this.applyOrientation();
},
} }
); );
} }
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];
} else {
delete this.inputs[0].pos;
}
app.graph.setDirtyCanvas(true, true);
}
computeSize() { computeSize() {
return [ return [

View File

@ -0,0 +1,21 @@
import { app } from "/scripts/app.js";
// Adds defaults for quickly adding nodes with middle click on the input/output
app.registerExtension({
name: "Comfy.SlotDefaults",
init() {
LiteGraph.middle_click_slot_add_default_node = true;
LiteGraph.slot_types_default_in = {
MODEL: "CheckpointLoaderSimple",
LATENT: "EmptyLatentImage",
VAE: "VAELoader",
};
LiteGraph.slot_types_default_out = {
LATENT: "VAEDecode",
IMAGE: "SaveImage",
CLIP: "CLIPTextEncode",
};
},
});

View File

@ -0,0 +1,89 @@
import { app } from "/scripts/app.js";
// Shift + drag/resize to snap to grid
app.registerExtension({
name: "Comfy.SnapToGrid",
init() {
// Add setting to control grid size
app.ui.settings.addSetting({
id: "Comfy.SnapToGrid.GridSize",
name: "Grid Size",
type: "number",
attrs: {
min: 1,
max: 500,
},
tooltip:
"When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
defaultValue: LiteGraph.CANVAS_GRID_SIZE,
onChange(value) {
LiteGraph.CANVAS_GRID_SIZE = +value;
},
});
// After moving a node, if the shift key is down align it to grid
const onNodeMoved = app.canvas.onNodeMoved;
app.canvas.onNodeMoved = function (node) {
const r = onNodeMoved?.apply(this, arguments);
if (app.shiftDown) {
// Ensure all selected nodes are realigned
for (const id in this.selected_nodes) {
this.selected_nodes[id].alignToGrid();
}
}
return r;
};
// When a node is added, add a resize handler to it so we can fix align the size with the grid
const onNodeAdded = app.graph.onNodeAdded;
app.graph.onNodeAdded = function (node) {
const onResize = node.onResize;
node.onResize = function () {
if (app.shiftDown) {
const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE);
const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE);
node.size[0] = w;
node.size[1] = h;
}
return onResize?.apply(this, arguments);
};
return onNodeAdded?.apply(this, arguments);
};
// Draw a preview of where the node will go if holding shift and the node is selected
const origDrawNode = LGraphCanvas.prototype.drawNode;
LGraphCanvas.prototype.drawNode = function (node, ctx) {
if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) {
const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE);
const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE);
const shiftX = x - node.pos[0];
let shiftY = y - node.pos[1];
let w, h;
if (node.flags.collapsed) {
w = node._collapsed_width;
h = LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
} else {
w = node.size[0];
h = node.size[1];
let titleMode = node.constructor.title_mode;
if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) {
h += LiteGraph.NODE_TITLE_HEIGHT;
shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
}
}
const f = ctx.fillStyle;
ctx.fillStyle = "rgba(100, 100, 100, 0.5)";
ctx.fillRect(shiftX, shiftY, w, h);
ctx.fillStyle = f;
}
return origDrawNode.apply(this, arguments);
};
},
});

View File

@@ -20,7 +20,7 @@ function hideWidget(node, widget, suffix = "") {
		if (link == null) {
			return undefined;
		}
-		return widget.value;
+		return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
	};

	// Hide any linked widgets, e.g. seed+randomize
@@ -101,7 +101,7 @@ app.registerExtension({
						callback: () => convertToWidget(this, w),
					});
				} else {
-					const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}];
+					const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}];
					if (isConvertableWidget(w, config)) {
						toInput.push({
							content: `Convert ${w.name} to input`,

View File

@ -5,10 +5,20 @@ import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js";
class ComfyApp { class ComfyApp {
/**
* List of {number, batchCount} entries to queue
*/
#queueItems = [];
/**
* If the queue is currently being processed
*/
#processingQueue = false;
constructor() { constructor() {
this.ui = new ComfyUI(this); this.ui = new ComfyUI(this);
this.extensions = []; this.extensions = [];
this.nodeOutputs = {}; this.nodeOutputs = {};
this.shiftDown = false;
} }
/** /**
@ -102,6 +112,46 @@ class ComfyApp {
}; };
} }
#addNodeKeyHandler(node) {
const app = this;
const origNodeOnKeyDown = node.prototype.onKeyDown;
node.prototype.onKeyDown = function(e) {
if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
return false;
}
if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
return;
}
let handled = false;
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
if (e.key === "ArrowLeft") {
this.imageIndex -= 1;
} else if (e.key === "ArrowRight") {
this.imageIndex += 1;
}
this.imageIndex %= this.imgs.length;
if (this.imageIndex < 0) {
this.imageIndex = this.imgs.length + this.imageIndex;
}
handled = true;
} else if (e.key === "Escape") {
this.imageIndex = null;
handled = true;
}
if (handled === true) {
e.preventDefault();
e.stopImmediatePropagation();
return false;
}
}
}
/** /**
* Adds Custom drawing logic for nodes * Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images * e.g. Draws images and handles thumbnail navigation on nodes that output images
@ -628,11 +678,16 @@ class ComfyApp {
#addKeyboardHandler() { #addKeyboardHandler() {
window.addEventListener("keydown", (e) => { window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
// Queue prompt using ctrl or command + enter // Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0); this.queuePrompt(e.shiftKey ? -1 : 0);
} }
}); });
window.addEventListener("keyup", (e) => {
this.shiftDown = e.shiftKey;
});
} }
/** /**
@ -667,6 +722,9 @@ class ComfyApp {
const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph)); const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph));
this.ctx = canvasEl.getContext("2d"); this.ctx = canvasEl.getContext("2d");
LiteGraph.release_link_on_empty_shows_menu = true;
LiteGraph.alt_drag_do_clone_nodes = true;
this.graph.start(); this.graph.start();
function resizeCanvas() { function resizeCanvas() {
@ -785,6 +843,7 @@ class ComfyApp {
this.#addNodeContextMenuHandler(node); this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app); this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node); LiteGraph.registerNodeType(nodeId, node);
@ -802,7 +861,7 @@ class ComfyApp {
this.clean(); this.clean();
if (!graphData) { if (!graphData) {
graphData = defaultGraph; graphData = structuredClone(defaultGraph);
} }
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
@@ -915,31 +974,47 @@ class ComfyApp {
	}

	async queuePrompt(number, batchCount = 1) {
-		for (let i = 0; i < batchCount; i++) {
-			const p = await this.graphToPrompt();
-
-			try {
-				await api.queuePrompt(number, p);
-			} catch (error) {
-				this.ui.dialog.show(error.response || error.toString());
-				return;
-			}
-
-			for (const n of p.workflow.nodes) {
-				const node = graph.getNodeById(n.id);
-				if (node.widgets) {
-					for (const widget of node.widgets) {
-						// Allow widgets to run callbacks after a prompt has been queued
-						// e.g. random seed after every gen
-						if (widget.afterQueued) {
-							widget.afterQueued();
-						}
-					}
-				}
-			}
-
-			this.canvas.draw(true, true);
-			await this.ui.queue.update();
-		}
+		this.#queueItems.push({ number, batchCount });
+
+		// Only have one action process the items so each one gets a unique seed correctly
+		if (this.#processingQueue) {
+			return;
+		}
+
+		this.#processingQueue = true;
+		try {
+			while (this.#queueItems.length) {
+				({ number, batchCount } = this.#queueItems.pop());
+
+				for (let i = 0; i < batchCount; i++) {
+					const p = await this.graphToPrompt();
+
+					try {
+						await api.queuePrompt(number, p);
+					} catch (error) {
+						this.ui.dialog.show(error.response || error.toString());
+						break;
+					}
+
+					for (const n of p.workflow.nodes) {
+						const node = graph.getNodeById(n.id);
+						if (node.widgets) {
+							for (const widget of node.widgets) {
+								// Allow widgets to run callbacks after a prompt has been queued
+								// e.g. random seed after every gen
+								if (widget.afterQueued) {
+									widget.afterQueued();
+								}
+							}
+						}
+					}
+
+					this.canvas.draw(true, true);
+					await this.ui.queue.update();
+				}
+			}
+		} finally {
+			this.#processingQueue = false;
+		}
	}
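queuePrompt now funnels every request through a single drain loop guarded by #processingQueue, so two rapid clicks can no longer interleave their graphToPrompt calls and reuse the same seed. The same single-consumer guard, sketched in Python asyncio purely for illustration (all names invented):

```python
# Callers only enqueue; whoever finds the flag unset drains the queue,
# so item processing never interleaves even with concurrent callers.
import asyncio

pending = []
processing = False

async def process(item):
    await asyncio.sleep(0.01)  # stands in for building and submitting a prompt
    print("processed", item)

async def enqueue(item):
    global processing
    pending.append(item)
    if processing:             # someone else is already draining
        return
    processing = True
    try:
        while pending:
            await process(pending.pop())
    finally:
        processing = False

async def main():
    await asyncio.gather(*(enqueue(i) for i in range(3)))

asyncio.run(main())
```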

View File

@ -35,21 +35,86 @@ export function $el(tag, propsOrChildren, children) {
return element; return element;
} }
function dragElement(dragEl) { function dragElement(dragEl, settings) {
var posDiffX = 0, var posDiffX = 0,
posDiffY = 0, posDiffY = 0,
posStartX = 0, posStartX = 0,
posStartY = 0, posStartY = 0,
newPosX = 0, newPosX = 0,
newPosY = 0; newPosY = 0;
if (dragEl.getElementsByClassName('drag-handle')[0]) { if (dragEl.getElementsByClassName("drag-handle")[0]) {
// if present, the handle is where you move the DIV from: // if present, the handle is where you move the DIV from:
dragEl.getElementsByClassName('drag-handle')[0].onmousedown = dragMouseDown; dragEl.getElementsByClassName("drag-handle")[0].onmousedown = dragMouseDown;
} else { } else {
// otherwise, move the DIV from anywhere inside the DIV: // otherwise, move the DIV from anywhere inside the DIV:
dragEl.onmousedown = dragMouseDown; dragEl.onmousedown = dragMouseDown;
} }
// When the element resizes (e.g. view queue) ensure it is still in the windows bounds
const resizeObserver = new ResizeObserver(() => {
ensureInBounds();
}).observe(dragEl);
function ensureInBounds() {
if (dragEl.classList.contains("comfy-menu-manual-pos")) {
newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft));
newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop));
positionElement();
}
}
function positionElement() {
const halfWidth = document.body.clientWidth / 2;
const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth;
// set the element's new position:
if (anchorRight) {
dragEl.style.left = "unset";
dragEl.style.right = document.body.clientWidth - newPosX - dragEl.clientWidth + "px";
} else {
dragEl.style.left = newPosX + "px";
dragEl.style.right = "unset";
}
dragEl.style.top = newPosY + "px";
dragEl.style.bottom = "unset";
if (savePos) {
localStorage.setItem(
"Comfy.MenuPosition",
JSON.stringify({
x: dragEl.offsetLeft,
y: dragEl.offsetTop,
})
);
}
}
function restorePos() {
let pos = localStorage.getItem("Comfy.MenuPosition");
if (pos) {
pos = JSON.parse(pos);
newPosX = pos.x;
newPosY = pos.y;
positionElement();
ensureInBounds();
}
}
let savePos = undefined;
settings.addSetting({
id: "Comfy.MenuPosition",
name: "Save menu position",
type: "boolean",
defaultValue: savePos,
onChange(value) {
if (savePos === undefined && value) {
restorePos();
}
savePos = value;
},
});
function dragMouseDown(e) { function dragMouseDown(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
@ -64,18 +129,25 @@ function dragElement(dragEl) {
function elementDrag(e) { function elementDrag(e) {
e = e || window.event; e = e || window.event;
e.preventDefault(); e.preventDefault();
dragEl.classList.add("comfy-menu-manual-pos");
// calculate the new cursor position: // calculate the new cursor position:
posDiffX = e.clientX - posStartX; posDiffX = e.clientX - posStartX;
posDiffY = e.clientY - posStartY; posDiffY = e.clientY - posStartY;
posStartX = e.clientX; posStartX = e.clientX;
posStartY = e.clientY; posStartY = e.clientY;
newPosX = Math.min((document.body.clientWidth - dragEl.clientWidth), Math.max(0, (dragEl.offsetLeft + posDiffX)));
newPosY = Math.min((document.body.clientHeight - dragEl.clientHeight), Math.max(0, (dragEl.offsetTop + posDiffY))); newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft + posDiffX));
// set the element's new position: newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop + posDiffY));
dragEl.style.top = newPosY + "px";
dragEl.style.left = newPosX + "px"; positionElement();
} }
window.addEventListener("resize", () => {
ensureInBounds();
});
function closeDragElement() { function closeDragElement() {
// stop moving when mouse button is released: // stop moving when mouse button is released:
document.onmouseup = null; document.onmouseup = null;
@ -90,7 +162,7 @@ class ComfyDialog {
$el("p", { $: (p) => (this.textElement = p) }), $el("p", { $: (p) => (this.textElement = p) }),
$el("button", { $el("button", {
type: "button", type: "button",
textContent: "CLOSE", textContent: "Close",
onclick: () => this.close(), onclick: () => this.close(),
}), }),
]), ]),
@ -125,7 +197,7 @@ class ComfySettingsDialog extends ComfyDialog {
localStorage[settingId] = JSON.stringify(value); localStorage[settingId] = JSON.stringify(value);
} }
addSetting({ id, name, type, defaultValue, onChange }) { addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) {
if (!id) { if (!id) {
throw new Error("Settings must have an ID"); throw new Error("Settings must have an ID");
} }
@ -152,42 +224,83 @@ class ComfySettingsDialog extends ComfyDialog {
value = v; value = v;
}; };
let element;
value = this.getSettingValue(id, defaultValue);
if (typeof type === "function") { if (typeof type === "function") {
return type(name, setter, value); element = type(name, setter, value, attrs);
} else {
switch (type) {
case "boolean":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
...attrs
}),
]),
]);
break;
case "number":
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type,
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
default:
console.warn("Unsupported setting type, defaulting to text");
element = $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
...attrs
}),
]),
]);
break;
}
}
if(tooltip) {
element.title = tooltip;
} }
switch (type) { return element;
case "boolean":
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
type: "checkbox",
checked: !!value,
oninput: (e) => {
setter(e.target.checked);
},
}),
]),
]);
default:
console.warn("Unsupported setting type, defaulting to text");
return $el("div", [
$el("label", { textContent: name || id }, [
$el("input", {
value,
oninput: (e) => {
setter(e.target.value);
},
}),
]),
]);
}
}, },
}); });
const self = this;
return {
get value() {
return self.getSettingValue(id, defaultValue);
},
set value(v) {
self.setSettingValue(id, v);
},
};
} }
show() { show() {
super.show(); super.show();
Object.assign(this.textElement.style, {
display: "flex",
flexDirection: "column",
gap: "10px"
});
this.textElement.replaceChildren(...this.settings.map((s) => s.render())); this.textElement.replaceChildren(...this.settings.map((s) => s.render()));
} }
} }
@ -225,10 +338,10 @@ class ComfyList {
$el("button", { $el("button", {
textContent: "Load", textContent: "Load",
onclick: () => { onclick: () => {
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) { if (item.outputs) {
app.nodeOutputs = item.outputs; app.nodeOutputs = item.outputs;
} }
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
}, },
}), }),
$el("button", { $el("button", {
@ -300,6 +413,13 @@ export class ComfyUI {
this.history.update(); this.history.update();
}); });
const confirmClear = this.settings.addSetting({
id: "Comfy.ConfirmClear",
name: "Require confirmation when clearing workflow",
type: "boolean",
defaultValue: true,
});
const fileInput = $el("input", { const fileInput = $el("input", {
type: "file", type: "file",
accept: ".json,image/png", accept: ".json,image/png",
@ -311,39 +431,57 @@ export class ComfyUI {
}); });
this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
$el("span.drag-handle"), $el("span.drag-handle"),
$el("span", { $: (q) => (this.queueSize = q) }), $el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
]), ]),
$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }), $el("button.comfy-queue-btn", {
textContent: "Queue Prompt",
onclick: () => app.queuePrompt(0, this.batchCount),
}),
$el("div", {}, [ $el("div", {}, [
$el("label", { innerHTML: "Extra options"}, [ $el("label", { innerHTML: "Extra options" }, [
$el("input", { type: "checkbox", $el("input", {
onchange: (i) => { type: "checkbox",
document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none"; onchange: (i) => {
this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1; document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
document.getElementById('autoQueueCheckbox').checked = false; this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
} document.getElementById("autoQueueCheckbox").checked = false;
}) },
])
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById('batchCountInputRange').value = this.batchCount;
}
}), }),
$el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount, ]),
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", {
id: "batchCountInputNumber",
type: "number",
value: this.batchCount,
min: "1",
style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById("batchCountInputRange").value = this.batchCount;
},
}),
$el("input", {
id: "batchCountInputRange",
type: "range",
min: "1",
max: "100",
value: this.batchCount,
oninput: (i) => { oninput: (i) => {
this.batchCount = i.srcElement.value; this.batchCount = i.srcElement.value;
document.getElementById('batchCountInputNumber').value = i.srcElement.value; document.getElementById("batchCountInputNumber").value = i.srcElement.value;
} },
}),
$el("input", {
id: "autoQueueCheckbox",
type: "checkbox",
checked: false,
title: "automatically queue prompt when the queue size hits 0",
}), }),
$el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "automatically queue prompt when the queue size hits 0",
})
]), ]),
]), ]),
$el("div.comfy-menu-btns", [ $el("div.comfy-menu-btns", [
@ -389,14 +527,19 @@ export class ComfyUI {
$el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { textContent: "Clear", onclick: () => { $el("button", { textContent: "Clear", onclick: () => {
app.clean(); if (!confirmClear.value || confirm("Clear workflow?")) {
app.graph.clear(); app.clean();
app.graph.clear();
}
}}),
$el("button", { textContent: "Load Default", onclick: () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
}
}}), }}),
$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),
$el("button", { textContent: "Delete Images", onclick: () => api.deleteAllImages() }),
]); ]);
dragElement(this.menuContainer); dragElement(this.menuContainer, this.settings);
this.setStatus({ exec_info: { queue_remaining: "X" } }); this.setStatus({ exec_info: { queue_remaining: "X" } });
} }
@ -404,10 +547,14 @@ export class ComfyUI {
setStatus(status) { setStatus(status) {
this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR"); this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR");
if (status) { if (status) {
if (this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && document.getElementById('autoQueueCheckbox').checked) { if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked
) {
app.queuePrompt(0, this.batchCount); app.queuePrompt(0, this.batchCount);
} }
this.lastQueueSize = status.exec_info.queue_remaining this.lastQueueSize = status.exec_info.queue_remaining;
} }
} }
} }

View File

@@ -306,7 +306,7 @@ export const ComfyWidgets = {
		const fileInput = document.createElement("input");
		Object.assign(fileInput, {
			type: "file",
-			accept: "image/jpeg,image/png",
+			accept: "image/jpeg,image/png,image/webp",
			style: "display: none",
			onchange: async () => {
				if (fileInput.files.length) {

View File

@ -39,18 +39,19 @@ body {
position: fixed; /* Stay in place */ position: fixed; /* Stay in place */
z-index: 100; /* Sit on top */ z-index: 100; /* Sit on top */
padding: 30px 30px 10px 30px; padding: 30px 30px 10px 30px;
background-color: #ff0000; /* Modal background */ background-color: #353535; /* Modal background */
color: #ff4444;
box-shadow: 0px 0px 20px #888888; box-shadow: 0px 0px 20px #888888;
border-radius: 10px; border-radius: 10px;
text-align: center;
top: 50%; top: 50%;
left: 50%; left: 50%;
max-width: 80vw; max-width: 80vw;
max-height: 80vh; max-height: 80vh;
transform: translate(-50%, -50%); transform: translate(-50%, -50%);
overflow: hidden; overflow: hidden;
min-width: 60%;
justify-content: center; justify-content: center;
font-family: monospace;
font-size: 15px;
} }
.comfy-modal-content { .comfy-modal-content {
@ -70,31 +71,11 @@ body {
margin: 3px 3px 3px 4px; margin: 3px 3px 3px 4px;
} }
.comfy-modal button {
cursor: pointer;
color: #aaaaaa;
border: none;
background-color: transparent;
font-size: 24px;
font-weight: bold;
width: 100%;
}
.comfy-modal button:hover,
.comfy-modal button:focus {
color: #000;
text-decoration: none;
cursor: pointer;
}
.comfy-menu { .comfy-menu {
width: 200px;
font-size: 15px; font-size: 15px;
position: absolute; position: absolute;
top: 50%; top: 50%;
right: 0%; right: 0%;
background-color: white;
color: #000;
text-align: center; text-align: center;
z-index: 100; z-index: 100;
width: 170px; width: 170px;
@ -109,7 +90,8 @@ body {
box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
} }
.comfy-menu button { .comfy-menu button,
.comfy-modal button {
font-size: 20px; font-size: 20px;
} }
@ -130,7 +112,8 @@ body {
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button { .comfy-menu .comfy-list button,
.comfy-modal button{
color: #ddd; color: #ddd;
background-color: #222; background-color: #222;
border-radius: 8px; border-radius: 8px;
@ -220,11 +203,22 @@ button.comfy-queue-btn {
} }
.comfy-modal.comfy-settings { .comfy-modal.comfy-settings {
background-color: var(--bg-color); text-align: center;
color: var(--fg-color); font-family: sans-serif;
color: #999;
z-index: 99; z-index: 99;
} }
.comfy-modal input,
.comfy-modal select {
color: #ddd;
background-color: #222;
border-radius: 8px;
border-color: #4e4e4e;
border-style: solid;
font-size: inherit;
}
@media only screen and (max-height: 850px) { @media only screen and (max-height: 850px) {
.comfy-menu { .comfy-menu {
top: 0 !important; top: 0 !important;
@ -237,3 +231,28 @@ button.comfy-queue-btn {
visibility:hidden visibility:hidden
} }
} }
.graphdialog {
min-height: 1em;
}
.graphdialog .name {
font-size: 14px;
font-family: sans-serif;
color: #999999;
}
.graphdialog button {
margin-top: unset;
vertical-align: unset;
height: 1.6em;
padding-right: 8px;
}
.graphdialog input, .graphdialog textarea, .graphdialog select {
background-color: #222;
border: 2px solid;
border-color: #444444;
color: #ddd;
border-radius: 12px 0 0 12px;
}