Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)

Merge remote-tracking branch 'upstream/master'
Commit a52b976dd5
@@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
 - Works even if you don't have a GPU with: ```--cpu``` (slow)
-- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models.
+- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
 - Loading full workflows (with seeds) from generated PNG files.
@@ -24,6 +24,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
 - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
+- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.
comfy/cli_args.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import argparse

parser = argparse.ArgumentParser()

parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build.")

args = parser.parse_args()
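For orientation, here is a short hedged sketch of how other modules consume these parsed flags. The diff itself only shows attention.py doing `from cli_args import args`; the helper function and printout below are illustrative assumptions, not part of the commit.

```python
# Hypothetical consumer of the parsed flags (argparse turns --use-split-cross-attention
# into args.use_split_cross_attention, and the vram flags are mutually exclusive).
from cli_args import args

def pick_vram_mode():
    if args.novram:
        return "novram"
    if args.lowvram:
        return "lowvram"
    if args.highvram:
        return "highvram"
    return "normal"

print("listening on %s:%d, vram mode: %s" % (args.listen, args.port, pick_vram_mode()))
```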
comfy/clip_vision.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
from .utils import load_torch_file, transformers_convert
import os

class ClipVisionModel():
    def __init__(self, json_config):
        config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModelWithProjection(config)
        self.processor = CLIPImageProcessor(crop_size=224,
                                            do_center_crop=True,
                                            do_convert_rgb=True,
                                            do_normalize=True,
                                            do_resize=True,
                                            image_mean=[0.48145466, 0.4578275, 0.40821073],
                                            image_std=[0.26862954, 0.26130258, 0.27577711],
                                            resample=3, #bicubic
                                            size=224)

    def load_sd(self, sd):
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        inputs = self.processor(images=[image[0]], return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs

def convert_to_transformers(sd):
    sd_k = sd.keys()
    if "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight" in sd_k:
        keys_to_replace = {
            "embedder.model.visual.class_embedding": "vision_model.embeddings.class_embedding",
            "embedder.model.visual.conv1.weight": "vision_model.embeddings.patch_embedding.weight",
            "embedder.model.visual.positional_embedding": "vision_model.embeddings.position_embedding.weight",
            "embedder.model.visual.ln_post.bias": "vision_model.post_layernorm.bias",
            "embedder.model.visual.ln_post.weight": "vision_model.post_layernorm.weight",
            "embedder.model.visual.ln_pre.bias": "vision_model.pre_layrnorm.bias",
            "embedder.model.visual.ln_pre.weight": "vision_model.pre_layrnorm.weight",
        }

        for x in keys_to_replace:
            if x in sd_k:
                sd[keys_to_replace[x]] = sd.pop(x)

        if "embedder.model.visual.proj" in sd_k:
            sd['visual_projection.weight'] = sd.pop("embedder.model.visual.proj").transpose(0, 1)

        sd = transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
    return sd

def load_clipvision_from_sd(sd):
    sd = convert_to_transformers(sd)
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    clip = ClipVisionModel(json_config)
    clip.load_sd(sd)
    return clip

def load(ckpt_path):
    sd = load_torch_file(ckpt_path)
    return load_clipvision_from_sd(sd)
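A hedged usage sketch for this new loader follows; the import path, checkpoint path, and the random stand-in image tensor are illustrative assumptions rather than something this diff shows.

```python
# Illustrative only: load a CLIP vision checkpoint and encode one image.
import torch
import comfy.clip_vision  # assumed import path for the module added above

clip = comfy.clip_vision.load("models/clip_vision/clip_vit_h.safetensors")  # example path

# encode_image() feeds image[0] to the CLIPImageProcessor and runs
# CLIPVisionModelWithProjection on the processed pixels.
image = torch.rand(1, 224, 224, 3)  # stand-in for a [B, H, W, C] float image batch
out = clip.encode_image(image)
print(out.image_embeds.shape)  # projected CLIP image embedding
```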
comfy/clip_vision_config_h.json (new file, 18 lines)
@@ -0,0 +1,18 @@
{
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "gelu",
  "hidden_size": 1280,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 32,
  "patch_size": 14,
  "projection_dim": 1024,
  "torch_dtype": "float32"
}
@@ -1,8 +1,4 @@
 {
-  "_name_or_path": "openai/clip-vit-large-patch14",
-  "architectures": [
-    "CLIPVisionModel"
-  ],
   "attention_dropout": 0.0,
   "dropout": 0.0,
   "hidden_act": "quick_gelu",
@@ -18,6 +14,5 @@
   "num_hidden_layers": 24,
   "patch_size": 14,
   "projection_dim": 768,
-  "torch_dtype": "float32",
-  "transformers_version": "4.24.0"
+  "torch_dtype": "float32"
 }
comfy/diffusers_convert.py (new file, 362 lines)
@@ -0,0 +1,362 @@
import json
import os
import yaml

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file

# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py

# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3 - i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i + 1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}


def convert_text_enc_state_dict_v20(text_enc_dict):
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    if 'prediction_type' in diffusers_scheduler_conf:
        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = instantiate_from_config(config["model"])
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if fp16:
        model = model.half()

    return ModelPatcher(model), clip, vae
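As a quick, hedged illustration of calling the loader defined above: the folder path follows the standard diffusers export layout, and the path strings themselves are examples, not values taken from this commit.

```python
# Illustrative call site for load_diffusers(); it returns the same kind of triple
# (patched diffusion model, CLIP, VAE) that ComfyUI's regular checkpoint loader produces.
from comfy.diffusers_convert import load_diffusers  # module added in this commit

model_patcher, clip, vae = load_diffusers(
    "models/diffusers/my-sd-model",           # example path to a diffusers folder
    fp16=True,
    output_vae=True,
    output_clip=True,
    embedding_directory="models/embeddings",  # example path for textual inversions
)
```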
@@ -78,7 +78,7 @@ class DDIMSampler(object):
                dynamic_threshold=None,
                ucg_schedule=None,
                denoise_function=None,
-               cond_concat=None,
+               extra_args=None,
                to_zero=True,
                end_step=None,
                **kwargs
@@ -101,7 +101,7 @@ class DDIMSampler(object):
                                     dynamic_threshold=dynamic_threshold,
                                     ucg_schedule=ucg_schedule,
                                     denoise_function=denoise_function,
-                                    cond_concat=cond_concat,
+                                    extra_args=extra_args,
                                     to_zero=to_zero,
                                     end_step=end_step
                                     )
@@ -174,7 +174,7 @@ class DDIMSampler(object):
                                     dynamic_threshold=dynamic_threshold,
                                     ucg_schedule=ucg_schedule,
                                     denoise_function=None,
-                                    cond_concat=None
+                                    extra_args=None
                                     )
         return samples, intermediates

@@ -185,7 +185,7 @@ class DDIMSampler(object):
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
+                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -225,7 +225,7 @@ class DDIMSampler(object):
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning,
-                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
+                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -249,11 +249,11 @@ class DDIMSampler(object):
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None, denoise_function=None, cond_concat=None):
+                      dynamic_threshold=None, denoise_function=None, extra_args=None):
         b, *_, device = *x.shape, x.device

         if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
+            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
         elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
|
|||||||
self.conditioning_key = conditioning_key
|
self.conditioning_key = conditioning_key
|
||||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||||
|
|
||||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
|
||||||
if self.conditioning_key is None:
|
if self.conditioning_key is None:
|
||||||
out = self.diffusion_model(x, t, control=control)
|
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'concat':
|
elif self.conditioning_key == 'concat':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
out = self.diffusion_model(xc, t, control=control)
|
out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'crossattn':
|
elif self.conditioning_key == 'crossattn':
|
||||||
if not self.sequential_cross_attn:
|
if not self.sequential_cross_attn:
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
# TorchScript changes names of the arguments
|
# TorchScript changes names of the arguments
|
||||||
# with argument cc defined as context=cc scripted model will produce
|
# with argument cc defined as context=cc scripted model will produce
|
||||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'hybrid':
|
elif self.conditioning_key == 'hybrid':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'hybrid-adm':
|
elif self.conditioning_key == 'hybrid-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'crossattn-adm':
|
elif self.conditioning_key == 'crossattn-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'adm':
|
elif self.conditioning_key == 'adm':
|
||||||
cc = c_crossattn[0]
|
cc = c_crossattn[0]
|
||||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|||||||
log = super().log_images(*args, **kwargs)
|
log = super().log_images(*args, **kwargs)
|
||||||
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
||||||
return log
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
|
||||||
|
def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
|
||||||
|
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.embed_key = embedding_key
|
||||||
|
self.embedding_dropout = embedding_dropout
|
||||||
|
# self._init_embedder(embedder_config, freeze_embedder)
|
||||||
|
self._init_noise_aug(noise_aug_config)
|
||||||
|
|
||||||
|
def _init_embedder(self, config, freeze=True):
|
||||||
|
embedder = instantiate_from_config(config)
|
||||||
|
if freeze:
|
||||||
|
self.embedder = embedder.eval()
|
||||||
|
self.embedder.train = disabled_train
|
||||||
|
for param in self.embedder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def _init_noise_aug(self, config):
|
||||||
|
if config is not None:
|
||||||
|
# use the KARLO schedule for noise augmentation on CLIP image embeddings
|
||||||
|
noise_augmentor = instantiate_from_config(config)
|
||||||
|
assert isinstance(noise_augmentor, nn.Module)
|
||||||
|
noise_augmentor = noise_augmentor.eval()
|
||||||
|
noise_augmentor.train = disabled_train
|
||||||
|
self.noise_augmentor = noise_augmentor
|
||||||
|
else:
|
||||||
|
self.noise_augmentor = None
|
||||||
|
|
||||||
|
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
|
||||||
|
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
|
||||||
|
z, c = outputs[0], outputs[1]
|
||||||
|
img = batch[self.embed_key][:bs]
|
||||||
|
img = rearrange(img, 'b h w c -> b c h w')
|
||||||
|
c_adm = self.embedder(img)
|
||||||
|
if self.noise_augmentor is not None:
|
||||||
|
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
|
||||||
|
# assume this gives embeddings of noise levels
|
||||||
|
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||||
|
if self.training:
|
||||||
|
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
|
||||||
|
device=c_adm.device)[:, None]) * c_adm
|
||||||
|
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
|
||||||
|
noutputs = [z, all_conds]
|
||||||
|
noutputs.extend(outputs[2:])
|
||||||
|
return noutputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, N=8, n_row=4, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
|
||||||
|
return_original_cond=True)
|
||||||
|
log["inputs"] = x
|
||||||
|
log["reconstruction"] = xrec
|
||||||
|
assert self.model.conditioning_key is not None
|
||||||
|
assert self.cond_stage_key in ["caption", "txt"]
|
||||||
|
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
|
||||||
|
log["conditioning"] = xc
|
||||||
|
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
|
||||||
|
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
|
||||||
|
|
||||||
|
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
|
||||||
|
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
|
||||||
|
with ema_scope(f"Sampling"):
|
||||||
|
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
|
||||||
|
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=uc_, )
|
||||||
|
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||||
|
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||||
|
return log
|
||||||
|
|||||||
@@ -307,7 +307,16 @@ def model_wrapper(
             else:
                 x_in = torch.cat([x] * 2)
                 t_in = torch.cat([t_continuous] * 2)
-                c_in = torch.cat([unconditional_condition, condition])
+                if isinstance(condition, dict):
+                    assert isinstance(unconditional_condition, dict)
+                    c_in = dict()
+                    for k in condition:
+                        if isinstance(condition[k], list):
+                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+                        else:
+                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+                else:
+                    c_in = torch.cat([unconditional_condition, condition])
                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                 return noise_uncond + guidance_scale * (noise - noise_uncond)

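For intuition, here is a small self-contained sketch of what the new dict branch produces when batching conditional and unconditional inputs for classifier-free guidance; the tensor shapes are illustrative, the real conditioning dicts come from the samplers.

```python
# Illustrative only: how conditional/unconditional dicts are stacked into one CFG batch.
import torch

condition = {"c_crossattn": [torch.ones(1, 77, 768)], "c_adm": torch.ones(1, 1280)}
unconditional_condition = {"c_crossattn": [torch.zeros(1, 77, 768)], "c_adm": torch.zeros(1, 1280)}

c_in = {}
for k in condition:
    if isinstance(condition[k], list):
        c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
    else:
        c_in[k] = torch.cat([unconditional_condition[k], condition[k]])

print(c_in["c_crossattn"][0].shape)  # torch.Size([2, 77, 768]) -> uncond and cond in one batch
print(c_in["c_adm"].shape)           # torch.Size([2, 1280])
```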
@@ -3,7 +3,6 @@ import torch

 from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver

-
 MODEL_TYPES = {
     "eps": "noise",
     "v": "v"
@@ -51,12 +50,20 @@ class DPMSolverSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                if isinstance(ctmp, torch.Tensor):
+                    cbs = ctmp.shape[0]
+                    if cbs != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
             else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+                if isinstance(conditioning, torch.Tensor):
+                    if conditioning.shape[0] != batch_size:
+                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

         # sampling
         C, H, W = shape
@@ -83,6 +90,7 @@ class DPMSolverSampler(object):
                             )

         dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
+                              lower_order_final=True)

         return x.to(device), None
@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention

 import model_management

+from . import tomesd

 if model_management.xformers_enabled():
     import xformers
@@ -20,6 +21,8 @@ if model_management.xformers_enabled():
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")

+from cli_args import args
+
 def exists(val):
     return val is not None

@@ -473,7 +476,6 @@ class CrossAttentionPytorch(nn.Module):

         return self.to_out(out)

-import sys
 if model_management.xformers_enabled():
     print("Using xformers cross attention")
     CrossAttention = MemoryEfficientCrossAttention
@@ -481,7 +483,7 @@ elif model_management.pytorch_attention_enabled():
     print("Using pytorch cross attention")
     CrossAttention = CrossAttentionPytorch
 else:
-    if "--use-split-cross-attention" in sys.argv:
+    if args.use_split_cross_attention:
         print("Using split optimization for cross attention")
         CrossAttention = CrossAttentionDoggettx
     else:
@@ -504,12 +506,22 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint

-    def forward(self, x, context=None):
-        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+    def forward(self, x, context=None, transformer_options={}):
+        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

-    def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
-        x = self.attn2(self.norm2(x), context=context) + x
+    def _forward(self, x, context=None, transformer_options={}):
+        n = self.norm1(x)
+        if "tomesd" in transformer_options:
+            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
+            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+        else:
+            n = self.attn1(n, context=context if self.disable_self_attn else None)
+
+        x += n
+        n = self.norm2(x)
+        n = self.attn2(n, context=context)
+
+        x += n
         x = self.ff(self.norm3(x)) + x
         return x

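To make the new hook concrete, here is a small hedged sketch of the kind of transformer_options dict the block now expects; the values are illustrative, original_shape is filled in by UNetModel.forward later in this diff, and the ratio key is what tomesd.get_functions reads.

```python
# Illustrative only: the dict that flows from UNetModel.forward down to BasicTransformerBlock.
transformer_options = {
    "original_shape": [2, 4, 64, 64],  # set from the latent batch shape in UNetModel.forward
    "tomesd": {"ratio": 0.5},          # presence of this key enables token merging around attn1
}

# Inside _forward, the "tomesd" key selects the merged-attention path:
#   m, u = tomesd.get_functions(x, 0.5, [2, 4, 64, 64])
#   n = u(self.attn1(m(n), ...))
# and the plain self-attention path is used otherwise.
```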
@@ -557,7 +569,7 @@ class SpatialTransformer(nn.Module):
         self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
         self.use_linear = use_linear

-    def forward(self, x, context=None):
+    def forward(self, x, context=None, transformer_options={}):
         # note: if no context is given, cross-attention defaults to self-attention
         if not isinstance(context, list):
             context = [context]
@@ -570,7 +582,7 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i])
+            x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
@@ -9,7 +9,7 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management

-if model_management.xformers_enabled():
+if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops

@@ -364,7 +364,7 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):

 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if model_management.xformers_enabled() and attn_type == "vanilla":
+    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"
@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     support it as an extra input.
     """

-    def forward(self, x, emb, context=None):
+    def forward(self, x, emb, context=None, transformer_options={}):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
-                x = layer(x, context)
+                x = layer(x, context, transformer_options)
             else:
                 x = layer(x)
         return x
@@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
         return count_flops_attn(model, _x, y)


+class Timestep(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t):
+        return timestep_embedding(t, self.dim)
+
+
 class UNetModel(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
@@ -470,6 +479,7 @@ class UNetModel(nn.Module):
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        adm_in_channels=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -538,6 +548,15 @@ class UNetModel(nn.Module):
             elif self.num_classes == "continuous":
                 print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
             else:
                 raise ValueError()

@@ -753,7 +772,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.
@@ -762,6 +781,7 @@ class UNetModel(nn.Module):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        transformer_options["original_shape"] = list(x.shape)
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
@@ -775,13 +795,13 @@ class UNetModel(nn.Module):

         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
-        h = self.middle_block(h, emb, context)
+        h = self.middle_block(h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()

@@ -793,7 +813,7 @@ class UNetModel(nn.Module):
                     hsp += ctrl
                 h = th.cat([h, hsp], dim=1)
                 del hsp
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
@@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
         betas = 1 - alphas[1:] / alphas[:-1]
         betas = np.clip(betas, a_min=0, a_max=0.999)

+    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
+        # return early
+        return betas_for_alpha_bar(
+            n_timestep,
+            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+        )
+
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
@@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)

+
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -267,4 +275,4 @@ class HybridConditioner(nn.Module):
 def noise_like(shape, device, repeat=False):
     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
59  comfy/ldm/modules/encoders/kornia_functions.py  (new file)
@ -0,0 +1,59 @@


from typing import List, Tuple, Union

import torch
import torch.nn as nn

#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py

def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.
    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.
    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.
    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])
        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape
    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
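A short usage sketch for `enhance_normalize`; the import path is an assumption about how the module resolves at runtime, and the mean/std values are the standard CLIP statistics that the embedders below register as buffers:

```python
import torch
from ldm.modules.encoders import kornia_functions  # import path assumed from the file location

x = torch.rand(2, 3, 224, 224)  # images already scaled to [0, 1]
clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])

out = kornia_functions.enhance_normalize(x, clip_mean, clip_std)
print(out.shape)  # torch.Size([2, 3, 224, 224]), per-channel (x - mean) / std
```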
@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from . import kornia_functions
 from torch.utils.checkpoint import checkpoint

 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
             c = batch[key][:, None]
         if self.ucg_rate > 0. and not disable_dropout:
             mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
-            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
             c = c.long()
         c = self.embedding(c)
         return c
@ -57,18 +58,20 @@ def disabled_train(self, mode=True):

 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
-    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
+                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
         if freeze:
             self.freeze()

     def freeze(self):
         self.transformer = self.transformer.eval()
-        #self.train = disabled_train
+        # self.train = disabled_train
         for param in self.parameters():
             param.requires_grad = False

@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "pooled",
         "hidden"
     ]

     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):

     def freeze(self):
         self.transformer = self.transformer.eval()
-        #self.train = disabled_train
+        # self.train = disabled_train
         for param in self.parameters():
             param.requires_grad = False

@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
         if self.layer == "last":
             z = outputs.last_hidden_state
         elif self.layer == "pooled":
@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         return self(text)


+class ClipImageEmbedder(nn.Module):
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=True,
+            ucg_rate=0.
+    ):
+        super().__init__()
+        from clip import load as load_clip
+        self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        # x = kornia_functions.geometry_resize(x, (224, 224),
+        #                                      interpolation='bicubic', align_corners=True,
+        #                                      antialias=self.antialias)
+        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+        x = (x + 1.) / 2.
+        # re-normalize according to clip
+        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x, no_dropout=False):
+        # x is assumed to be in range [-1,1]
+        out = self.model.encode_image(self.preprocess(x))
+        out = out.to(x.dtype)
+        if self.ucg_rate > 0. and not no_dropout:
+            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+        return out
+
+
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
     """
     LAYERS = [
-        #"pooled",
+        # "pooled",
         "last",
         "penultimate"
     ]

     def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         x = self.model.ln_final(x)
         return x

-    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
         for i, r in enumerate(self.model.transformer.resblocks):
             if i == len(self.model.transformer.resblocks) - self.layer_idx:
                 break
@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         return self(text)


+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+        super().__init__()
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+                                                            pretrained=version, )
+        del model.transformer
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "penultimate":
+            raise NotImplementedError()
+            self.layer_idx = 1
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        # x = kornia.geometry.resize(x, (224, 224),
+        #                            interpolation='bicubic', align_corners=True,
+        #                            antialias=self.antialias)
+        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, image, no_dropout=False):
+        z = self.encode_with_vision_transformer(image)
+        if self.ucg_rate > 0. and not no_dropout:
+            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        img = self.preprocess(img)
+        x = self.model.visual(img)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
 class FrozenCLIPT5Encoder(AbstractEncoder):
     def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
-              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

     def encode(self, text):
         return self(text)
@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
         clip_z = self.clip_encoder.encode(text)
         t5_z = self.t5_encoder.encode(text)
         return [clip_z, t5_z]
-
-
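Both new image embedders expect inputs in [-1, 1], resize them to 224x224, apply the CLIP normalization above, and can randomly zero whole embeddings (`ucg_rate`) as unconditional-guidance dropout. The dropout mechanic in isolation, as a small runnable sketch with a stand-in embedding tensor:

```python
import torch

# Each sample's embedding is kept with probability (1 - ucg_rate), otherwise zeroed,
# mirroring the Bernoulli mask applied in ClipImageEmbedder / FrozenOpenCLIPImageEmbedder.
ucg_rate = 0.1
z = torch.randn(4, 1024)  # stand-in for pooled image embeddings
mask = torch.bernoulli((1. - ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None]
z_dropped = mask * z
print(z_dropped.abs().sum(dim=1))  # rows that were dropped sum to exactly zero
```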
35  comfy/ldm/modules/encoders/noise_aug_modules.py  (new file)
@ -0,0 +1,35 @@
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
import torch

class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        # re-normalize to centered mean and unit variance
        x = (x - self.data_mean) * 1. / self.data_std
        return x

    def unscale(self, x):
        # back to original data stats
        x = (x * self.data_std) + self.data_mean
        return x

    def forward(self, x, noise_level=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        x = self.scale(x)
        z = self.q_sample(x, noise_level)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level
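CLIPEmbeddingNoiseAugmentation renormalizes a CLIP embedding with the stored statistics, diffuses it to a chosen noise level with `q_sample`, undoes the renormalization, and also returns a timestep embedding of that level. A hedged sketch of how a caller such as the sampler's `encode_adm` is expected to use it; `aug` is assumed to be an already constructed instance:

```python
import torch

def noise_augment_image_embed(aug, image_embed, noise_augment=0.05):
    # aug: CLIPEmbeddingNoiseAugmentation instance (assumed constructed elsewhere)
    # image_embed: (B, timestep_dim) CLIP image embedding
    noise_level = round((aug.max_noise_level - 1) * noise_augment)
    noisy, level_emb = aug(image_embed, noise_level=torch.tensor([noise_level]))
    # the model's c_adm input is the noisy embedding concatenated with the level embedding
    return torch.cat((noisy, level_emb), 1)
```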
144  comfy/ldm/modules/tomesd.py  (new file)
@ -0,0 +1,144 @@
#Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable
import math

def do_nothing(x: torch.Tensor, mode:str=None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():

        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge


def get_functions(x, ratio, original_shape):
    b, c, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    stride_x = 2
    stride_y = 2
    max_downsample = 1

    if downsample <= max_downsample:
        w = int(math.ceil(original_w / downsample))
        h = int(math.ceil(original_h / downsample))
        r = int(x.shape[1] * ratio)
        no_rand = False
        m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
        return m, u

    nothing = lambda y: y
    return nothing, nothing
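A usage sketch of the token-merging helpers; the import path is assumed, and the shapes mimic the highest-resolution attention blocks of a 512x512 generation (64x64 latent), which is the only level `get_functions` merges at because `max_downsample` is 1:

```python
import torch
from ldm.modules import tomesd  # import path assumed

x = torch.randn(2, 64 * 64, 320)   # attention tokens: (batch, h*w, channels)
original_shape = (2, 4, 64, 64)    # latent shape the UNet was called with

merge, unmerge = tomesd.get_functions(x, ratio=0.4, original_shape=original_shape)
merged = merge(x)                  # about 40% of the src tokens folded into dst tokens
restored = unmerge(merged)         # merged values copied back to their original slots
print(x.shape, merged.shape, restored.shape)
```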
@ -1,36 +1,42 @@
+import psutil
+from enum import Enum
+from cli_args import args
+
-CPU = 0
-NO_VRAM = 1
-LOW_VRAM = 2
-NORMAL_VRAM = 3
-HIGH_VRAM = 4
-MPS = 5
+class VRAMState(Enum):
+    CPU = 0
+    NO_VRAM = 1
+    LOW_VRAM = 2
+    NORMAL_VRAM = 3
+    HIGH_VRAM = 4
+    MPS = 5

-accelerate_enabled = False
-vram_state = NORMAL_VRAM
+# Determine VRAM State
+vram_state = VRAMState.NORMAL_VRAM
+set_vram_to = VRAMState.NORMAL_VRAM

 total_vram = 0
 total_vram_available_mb = -1

-import sys
-import psutil
-
-forced_cpu = "--cpu" in sys.argv
-
-set_vram_to = NORMAL_VRAM
+accelerate_enabled = False
+xpu_available = False

 try:
     import torch
-    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
+    try:
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.is_available():
+            xpu_available = True
+            total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
+    except:
+        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    forced_normal_vram = "--normalvram" in sys.argv
-    if not forced_normal_vram and not forced_cpu:
+    if not args.normalvram and not args.cpu:
         if total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = LOW_VRAM
+            set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
             print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = HIGH_VRAM
+            vram_state = VRAMState.HIGH_VRAM
 except:
     pass

@ -39,34 +45,37 @@ try:
 except:
     OOM_EXCEPTION = Exception

-if "--disable-xformers" in sys.argv:
-    XFORMERS_IS_AVAILBLE = False
+if args.disable_xformers:
+    XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers
         import xformers.ops
-        XFORMERS_IS_AVAILBLE = True
+        XFORMERS_IS_AVAILABLE = True
     except:
-        XFORMERS_IS_AVAILBLE = False
+        XFORMERS_IS_AVAILABLE = False

-ENABLE_PYTORCH_ATTENTION = False
-if "--use-pytorch-cross-attention" in sys.argv:
+ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    ENABLE_PYTORCH_ATTENTION = True
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_IS_AVAILABLE = False

+if args.lowvram:
+    set_vram_to = VRAMState.LOW_VRAM
+elif args.novram:
+    set_vram_to = VRAMState.NO_VRAM
+elif args.highvram:
+    vram_state = VRAMState.HIGH_VRAM
+
+FORCE_FP32 = False
+if args.force_fp32:
+    print("Forcing FP32, if this improves things please report it.")
+    FORCE_FP32 = True

-if "--lowvram" in sys.argv:
-    set_vram_to = LOW_VRAM
-if "--novram" in sys.argv:
-    set_vram_to = NO_VRAM
-if "--highvram" in sys.argv:
-    vram_state = HIGH_VRAM
-
-if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
+if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
     try:
         import accelerate
         accelerate_enabled = True
@ -81,14 +90,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:

 try:
     if torch.backends.mps.is_available():
-        vram_state = MPS
+        vram_state = VRAMState.MPS
 except:
     pass

-if forced_cpu:
-    vram_state = CPU
+if args.cpu:
+    vram_state = VRAMState.CPU

-print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
+print(f"Set vram state to: {vram_state.name}")

 current_loaded_model = None
@ -109,12 +118,12 @@ def unload_model():
     model_accelerated = False

     #never unload models from GPU on high vram
-    if vram_state != HIGH_VRAM:
+    if vram_state != VRAMState.HIGH_VRAM:
         current_loaded_model.model.cpu()
     current_loaded_model.unpatch_model()
     current_loaded_model = None

-    if vram_state != HIGH_VRAM:
+    if vram_state != VRAMState.HIGH_VRAM:
         if len(current_gpu_controlnets) > 0:
             for n in current_gpu_controlnets:
                 n.cpu()
@ -135,32 +144,32 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e
     current_loaded_model = model
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         pass
-    elif vram_state == MPS:
+    elif vram_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
+    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        real_model.cuda()
+        real_model.to(get_torch_device())
     else:
-        if vram_state == NO_VRAM:
+        if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == LOW_VRAM:
+        elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model

 def load_controlnet_gpu(models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return

-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return

@ -176,38 +185,57 @@ def load_controlnet_gpu(models):

 def load_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
-        return model.cuda()
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
+        return model.to(get_torch_device())
     return model

 def unload_if_low_vram(model):
     global vram_state
-    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
+    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
         return model.cpu()
     return model

 def get_torch_device():
-    if vram_state == MPS:
+    global xpu_available
+    if vram_state == VRAMState.MPS:
         return torch.device("mps")
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return torch.device("cpu")
     else:
-        return torch.cuda.current_device()
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
     return "cuda"


 def xformers_enabled():
-    if vram_state == CPU:
+    if vram_state == VRAMState.CPU:
         return False
-    return XFORMERS_IS_AVAILBLE
+    return XFORMERS_IS_AVAILABLE


+def xformers_enabled_vae():
+    enabled = xformers_enabled()
+    if not enabled:
+        return False
+    try:
+        #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
+        if xformers.version.__version__ == "0.0.18":
+            return False
+    except:
+        pass
+    return enabled
+
 def pytorch_attention_enabled():
     return ENABLE_PYTORCH_ATTENTION

 def get_free_memory(dev=None, torch_free_too=False):
+    global xpu_available
     if dev is None:
         dev = get_torch_device()

@ -215,12 +243,16 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        stats = torch.cuda.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        if xpu_available:
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
+            mem_free_torch = mem_free_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_cuda + mem_free_torch

     if torch_free_too:
         return (mem_free_total, mem_free_torch)
@ -229,7 +261,7 @@ def get_free_memory(dev=None, torch_free_too=False):

 def maximum_batch_area():
     global vram_state
-    if vram_state == NO_VRAM:
+    if vram_state == VRAMState.NO_VRAM:
         return 0

     memory_free = get_free_memory() / (1024 * 1024)
@ -238,14 +270,18 @@ def maximum_batch_area():

 def cpu_mode():
     global vram_state
-    return vram_state == CPU
+    return vram_state == VRAMState.CPU

 def mps_mode():
     global vram_state
-    return vram_state == MPS
+    return vram_state == VRAMState.MPS

 def should_use_fp16():
-    if cpu_mode() or mps_mode():
+    global xpu_available
+    if FORCE_FP32:
+        return False
+
+    if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?

     if torch.cuda.is_bf16_supported():
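The refactor replaces bare integer constants and ad-hoc `sys.argv` scanning with a `VRAMState` enum plus parsed CLI args, and routes device selection through `get_torch_device()` so XPU and MPS are handled in one place. A simplified sketch of the resulting control flow, not the module's exact code:

```python
from enum import Enum
import torch

class VRAMState(Enum):
    CPU = 0
    NO_VRAM = 1
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    MPS = 5

def pick_device(state, xpu_available=False):
    # mirrors the shape of get_torch_device(): MPS and CPU are explicit,
    # everything else goes to XPU when available, otherwise CUDA
    if state == VRAMState.MPS:
        return torch.device("mps")
    if state == VRAMState.CPU:
        return torch.device("cpu")
    return torch.device("xpu") if xpu_available else torch.device("cuda")

print(pick_device(VRAMState.NORMAL_VRAM))  # cuda on a typical GPU box
```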
@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):

 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if 'strength' in cond[1]:
                 strength = cond[1]['strength']

+        adm_cond = None
+        if 'adm' in cond[1]:
+            adm_cond = cond[1]['adm']
+
         input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
         mult = torch.ones_like(input_x) * strength

@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 cropped.append(cr)
             conditionning['c_concat'] = torch.cat(cropped, dim=1)

+        if adm_cond is not None:
+            conditionning['c_adm'] = adm_cond
+
         control = None
         if 'control' in cond[1]:
             control = cond[1]['control']
@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         if 'c_concat' in c1:
             if c1['c_concat'].shape != c2['c_concat'].shape:
                 return False
+        if 'c_adm' in c1:
+            if c1['c_adm'].shape != c2['c_adm'].shape:
+                return False
         return True

     def can_concat_cond(c1, c2):
@ -92,19 +102,24 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     def cond_cat(c_list):
         c_crossattn = []
         c_concat = []
+        c_adm = []
         for x in c_list:
             if 'c_crossattn' in x:
                 c_crossattn.append(x['c_crossattn'])
             if 'c_concat' in x:
                 c_concat.append(x['c_concat'])
+            if 'c_adm' in x:
+                c_adm.append(x['c_adm'])
         out = {}
         if len(c_crossattn) > 0:
             out['c_crossattn'] = [torch.cat(c_crossattn)]
         if len(c_concat) > 0:
             out['c_concat'] = [torch.cat(c_concat)]
+        if len(c_adm) > 0:
+            out['c_adm'] = torch.cat(c_adm)
         return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0

@ -169,6 +184,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if control is not None:
                 c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))

+            if 'transformer_options' in model_options:
+                c['transformer_options'] = model_options['transformer_options']
+
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x

@ -192,7 +210,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con


     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
     return uncond + (cond - uncond) * cond_scale


@ -209,8 +227,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
         return out


@ -218,11 +236,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
         if denoise_mask is not None:
             out *= denoise_mask

@ -324,13 +342,46 @@ def apply_control_net_to_equal_area(conds, uncond):
             n['control'] = cond_cnets[x]
             uncond[temp[1]] = [o[0], n]

+def encode_adm(noise_augmentor, conds, batch_size, device):
+    for t in range(len(conds)):
+        x = conds[t]
+        if 'adm' in x[1]:
+            adm_inputs = []
+            weights = []
+            noise_aug = []
+            adm_in = x[1]["adm"]
+            for adm_c in adm_in:
+                adm_cond = adm_c[0].image_embeds
+                weight = adm_c[1]
+                noise_augment = adm_c[2]
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                weights.append(weight)
+                noise_aug.append(noise_augment)
+                adm_inputs.append(adm_out)
+
+            if len(noise_aug) > 1:
+                adm_out = torch.stack(adm_inputs).sum(0)
+                #TODO: add a way to control this
+                noise_augment = 0.05
+                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                adm_out = torch.cat((c_adm, noise_level_emb), 1)
+        else:
+            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+        x[1] = x[1].copy()
+        x[1]["adm"] = torch.cat([adm_out] * batch_size)
+
+    return conds
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]

-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
@ -350,6 +401,7 @@ class KSampler:
         self.sigma_max=float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)
         self.denoise = denoise
+        self.model_options = model_options

     def _calculate_sigmas(self, steps):
         sigmas = None
@ -418,10 +470,14 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext

-        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+        if hasattr(self.model, 'noise_augmentor'): #unclip
+            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
+            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None
-        if hasattr(self.model, 'concat_keys'):
+        if hasattr(self.model, 'concat_keys'): #inpaint
             cond_concat = []
             for ck in self.model.concat_keys:
                 if denoise_mask is not None:
@ -467,7 +523,7 @@ class KSampler:
                                  x_T=z_enc,
                                  x0=latent_image,
                                  denoise_function=sampling_function,
-                                 cond_concat=cond_concat,
+                                 extra_args=extra_args,
                                  mask=noise_mask,
                                  to_zero=sigmas[-1]==0,
                                  end_step=sigmas.shape[0] - 1)
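For unCLIP models, each conditioning entry can now carry an `adm` option: a list of `(clip_vision_output, strength, noise_augmentation)` tuples that `encode_adm` turns into the `c_adm` tensor batched through `cond_cat`. A hedged sketch of that data layout; the stand-in class below only mimics the `image_embeds` attribute the code reads:

```python
import torch

class FakeClipVisionOutput:
    # stand-in for the object produced by the CLIP vision encoder
    def __init__(self, dim=1024):
        self.image_embeds = torch.randn(1, dim)

# one positive cond entry: [cross-attention tensor, options dict]
positive = [
    [torch.randn(1, 77, 1024),
     {"adm": [(FakeClipVisionOutput(), 1.0, 0.1)]}],
]
# encode_adm(noise_augmentor, positive, batch_size, device) replaces the "adm" value
# with a (batch_size, 2 * timestep_dim) tensor that reaches the model call as c_adm.
```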
102  comfy/sd.py
@ -1,5 +1,6 @@
 import torch
 import contextlib
+import copy

 import sd1_clip
 import sd2_clip
@ -11,20 +12,7 @@ from .cldm import cldm
 from .t2i_adapter import adapter

 from . import utils
+from . import clip_vision

-def load_torch_file(ckpt):
-    if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
-        sd = safetensors.torch.load_file(ckpt, device="cpu")
-    else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-        if "state_dict" in pl_sd:
-            sd = pl_sd["state_dict"]
-        else:
-            sd = pl_sd
-    return sd

 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
     m, u = model.load_state_dict(sd, strict=False)
@ -52,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
         if x in sd:
             sd[keys_to_replace[x]] = sd.pop(x)

-    resblock_to_replace = {
-        "ln_1": "layer_norm1",
-        "ln_2": "layer_norm2",
-        "mlp.c_fc": "mlp.fc1",
-        "mlp.c_proj": "mlp.fc2",
-        "attn.out_proj": "self_attn.out_proj",
-    }
-
-    for resblock in range(24):
-        for x in resblock_to_replace:
-            for y in ["weight", "bias"]:
-                k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
-                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
-                if k in sd:
-                    sd[k_to] = sd.pop(k)
-
-        for y in ["weight", "bias"]:
-            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
-            if k_from in sd:
-                weights = sd.pop(k_from)
-                for x in range(3):
-                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
-                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
-                    sd[k_to] = weights[1024*x:1024*(x + 1)]
+    sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)

     for x in load_state_dict_to:
         x.load_state_dict(sd, strict=False)
@ -122,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
 }

 def load_lora(path, to_load):
-    lora = load_torch_file(path)
+    lora = utils.load_torch_file(path)
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@ -274,12 +239,20 @@ class ModelPatcher:
         self.model = model
         self.patches = []
         self.backup = {}
+        self.model_options = {"transformer_options":{}}

     def clone(self):
         n = ModelPatcher(self.model)
         n.patches = self.patches[:]
+        n.model_options = copy.deepcopy(self.model_options)
         return n

+    def set_model_tomesd(self, ratio):
+        self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
+
+    def model_dtype(self):
+        return self.model.diffusion_model.dtype
+
     def add_patches(self, patches, strength=1.0):
         p = {}
         model_sd = self.model.state_dict()
@ -590,7 +563,7 @@ class ControlNet:
         return out

 def load_controlnet(ckpt_path, model=None):
-    controlnet_data = load_torch_file(ckpt_path)
+    controlnet_data = utils.load_torch_file(ckpt_path)
     pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
     pth = False
     sd2 = False
@ -784,7 +757,7 @@ class StyleModel:


 def load_style_model(ckpt_path):
-    model_data = load_torch_file(ckpt_path)
+    model_data = utils.load_torch_file(ckpt_path)
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
@ -795,7 +768,7 @@ def load_style_model(ckpt_path):


 def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = load_torch_file(ckpt_path)
+    clip_data = utils.load_torch_file(ckpt_path)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
         config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
@ -838,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
         load_state_dict_to = [w]

     model = instantiate_from_config(config["model"])
-    sd = load_torch_file(ckpt_path)
+    sd = utils.load_torch_file(ckpt_path)
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

     if fp16:
@ -847,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     return (ModelPatcher(model), clip, vae)


-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
-    sd = load_torch_file(ckpt_path)
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+    sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
+    clipvision = None
     vae = None

     fp16 = model_management.should_use_fp16()
@ -875,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]

+    clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
+    noise_aug_config = None
+    if clipvision_key in sd_keys:
+        size = sd[clipvision_key].shape[1]
+
+        if output_clipvision:
+            clipvision = clip_vision.load_clipvision_from_sd(sd)
+
+        noise_aug_key = "noise_augmentor.betas"
+        if noise_aug_key in sd_keys:
+            noise_aug_config = {}
+            params = {}
+            noise_schedule_config = {}
+            noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
+            noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
+            params["noise_schedule_config"] = noise_schedule_config
+            noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
+            if size == 1280: #h
+                params["timestep_dim"] = 1024
+            elif size == 1024: #l
+                params["timestep_dim"] = 768
+            noise_aug_config['params'] = params
+
     sd_config = {
         "linear_start": 0.00085,
         "linear_end": 0.012,
@ -923,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
     model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}

-    if unet_config["in_channels"] > 4: #inpainting model
+    if noise_aug_config is not None: #SD2.x unclip model
+        sd_config["noise_aug_config"] = noise_aug_config
+        sd_config["image_size"] = 96
+        sd_config["embedding_dropout"] = 0.25
+        sd_config["conditioning_key"] = 'crossattn-adm'
+        model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
+    elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
         model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
@ -935,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     else:
         unet_config["num_heads"] = 8 #SD1.x

+    unclip = 'model.diffusion_model.label_emb.0.0.weight'
+    if unclip in sd_keys:
+        unet_config["num_classes"] = "sequential"
+        unet_config["adm_in_channels"] = sd[unclip].shape[1]
+
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
@ -947,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
     if fp16:
         model = model.half()

-    return (ModelPatcher(model), clip, vae)
+    return (ModelPatcher(model), clip, vae, clipvision)
@@ -74,9 +74,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                if isinstance(y, int):
                    tokens_temp += [y]
                else:
-                    embedding_weights += [y]
-                    tokens_temp += [next_new_token]
-                    next_new_token += 1
+                    if y.shape[0] == current_embeds.weight.shape[1]:
+                        embedding_weights += [y]
+                        tokens_temp += [next_new_token]
+                        next_new_token += 1
+                    else:
+                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
            out_tokens += [tokens_temp]

        if len(embedding_weights) > 0:
@@ -1,5 +1,47 @@
import torch

+def load_torch_file(ckpt):
+    if ckpt.lower().endswith(".safetensors"):
+        import safetensors.torch
+        sd = safetensors.torch.load_file(ckpt, device="cpu")
+    else:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+        if "global_step" in pl_sd:
+            print(f"Global Step: {pl_sd['global_step']}")
+        if "state_dict" in pl_sd:
+            sd = pl_sd["state_dict"]
+        else:
+            sd = pl_sd
+    return sd
+
+def transformers_convert(sd, prefix_from, prefix_to, number):
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(number):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+    return sd
+
def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
        old_width = samples.shape[3]
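A minimal usage sketch of the checkpoint-loading helper that now lives in comfy/utils.py; the checkpoint path below is a placeholder, not a file shipped with this repository, and the call assumes the repository root is on the Python path.

import comfy.utils

# load_torch_file dispatches on the extension: .safetensors goes through
# safetensors.torch.load_file, anything else through torch.load on the CPU.
sd = comfy.utils.load_torch_file("models/checkpoints/example.safetensors")  # placeholder path
print(len(sd), "tensors loaded")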
@ -1,32 +0,0 @@
|
|||||||
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
|
|
||||||
from comfy.sd import load_torch_file
|
|
||||||
import os
|
|
||||||
|
|
||||||
class ClipVisionModel():
|
|
||||||
def __init__(self):
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
|
|
||||||
config = CLIPVisionConfig.from_json_file(json_config)
|
|
||||||
self.model = CLIPVisionModel(config)
|
|
||||||
self.processor = CLIPImageProcessor(crop_size=224,
|
|
||||||
do_center_crop=True,
|
|
||||||
do_convert_rgb=True,
|
|
||||||
do_normalize=True,
|
|
||||||
do_resize=True,
|
|
||||||
image_mean=[ 0.48145466,0.4578275,0.40821073],
|
|
||||||
image_std=[0.26862954,0.26130258,0.27577711],
|
|
||||||
resample=3, #bicubic
|
|
||||||
size=224)
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
|
||||||
self.model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
def encode_image(self, image):
|
|
||||||
inputs = self.processor(images=[image[0]], return_tensors="pt")
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def load(ckpt_path):
|
|
||||||
clip_data = load_torch_file(ckpt_path)
|
|
||||||
clip = ClipVisionModel()
|
|
||||||
clip.load_sd(clip_data)
|
|
||||||
return clip
|
|
||||||
comfy_extras/nodes_post_processing.py (new file, 210 lines)
@@ -0,0 +1,210 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import comfy.utils


class Blend:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image1": ("IMAGE",),
                "image2": ("IMAGE",),
                "blend_factor": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blend_images"

    CATEGORY = "image/postprocessing"

    def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
        if image1.shape != image2.shape:
            image2 = image2.permute(0, 3, 1, 2)
            image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
            image2 = image2.permute(0, 2, 3, 1)

        blended_image = self.blend_mode(image1, image2, blend_mode)
        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
        blended_image = torch.clamp(blended_image, 0, 1)
        return (blended_image,)

    def blend_mode(self, img1, img2, mode):
        if mode == "normal":
            return img2
        elif mode == "multiply":
            return img1 * img2
        elif mode == "screen":
            return 1 - (1 - img1) * (1 - img2)
        elif mode == "overlay":
            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
        elif mode == "soft_light":
            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
        else:
            raise ValueError(f"Unsupported blend mode: {mode}")

    def g(self, x):
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))


class Blur:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "blur_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur"

    CATEGORY = "image/postprocessing"

    def gaussian_kernel(self, kernel_size: int, sigma: float):
        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
        d = torch.sqrt(x * x + y * y)
        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
        return g / g.sum()

    def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
        if blur_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
        kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
        blurred = blurred.permute(0, 2, 3, 1)

        return (blurred,)


class Quantize:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "colors": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 256,
                    "step": 1
                }),
                "dither": (["none", "floyd-steinberg"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "quantize"

    CATEGORY = "image/postprocessing"

    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
        batch_size, height, width, _ = image.shape
        result = torch.zeros_like(image)

        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE

        for b in range(batch_size):
            tensor_image = image[b]
            img = (tensor_image * 255).to(torch.uint8).numpy()
            pil_image = Image.fromarray(img, mode='RGB')

            palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)

            quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
            result[b] = quantized_array

        return (result,)


class Sharpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 31,
                    "step": 1
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 5.0,
                    "step": 0.1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "sharpen"

    CATEGORY = "image/postprocessing"

    def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
        if sharpen_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = sharpen_radius * 2 + 1
        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
        center = kernel_size // 2
        kernel[center, center] = kernel_size**2
        kernel *= alpha
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)

        return (result,)


NODE_CLASS_MAPPINGS = {
    "ImageBlend": Blend,
    "ImageBlur": Blur,
    "ImageQuantize": Quantize,
    "ImageSharpen": Sharpen,
}
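The new post-processing nodes all operate on image tensors shaped (batch, height, width, channels) with values in 0..1. A rough sketch of exercising the Blur node outside the graph, assuming the repository root is on the Python path so comfy_extras imports cleanly:

import torch
from comfy_extras.nodes_post_processing import Blur

image = torch.rand(1, 64, 64, 3)                 # (B, H, W, C) image in 0..1
blurred, = Blur().blur(image, blur_radius=2, sigma=1.0)
print(blurred.shape)                             # torch.Size([1, 64, 64, 3])

The blur builds a (C, 1, k, k) Gaussian kernel and applies it with a grouped conv2d, so each channel is filtered independently.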
@@ -1,6 +1,5 @@
import os
from comfy_extras.chainner_models import model_loading
-from comfy.sd import load_torch_file
import model_management
import torch
import comfy.utils
@@ -18,7 +17,7 @@ class UpscaleModelLoader:

    def load_model(self, model_name):
        model_path = folder_paths.get_full_path("upscale_models", model_name)
-        sd = load_torch_file(model_path)
+        sd = comfy.utils.load_torch_file(model_path)
        out = model_loading.load_state_dict(sd).eval()
        return (out, )

@@ -11,6 +11,8 @@ class Example:
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tulple.
+    RETURN_NAMES (`tuple`):
+        Optional: The name of each output in the output tulple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
@@ -61,6 +63,8 @@ class Example:
                }

    RETURN_TYPES = ("IMAGE",)
+    #RETURN_NAMES = ("image_output_name",)
+
    FUNCTION = "test"

    #OUTPUT_NODE = False
@@ -24,10 +24,45 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
+folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
+
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)

+output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+
+if not os.path.exists(input_directory):
+    os.makedirs(input_directory)
+
+def set_output_directory(output_dir):
+    global output_directory
+    output_directory = output_dir
+
+def get_output_directory():
+    global output_directory
+    return output_directory
+
+def get_temp_directory():
+    global temp_directory
+    return temp_directory
+
+def get_input_directory():
+    global input_directory
+    return input_directory
+
+
+#NOTE: used in http server so don't put folders that should not be accessed remotely
+def get_directory_by_type(type_name):
+    if type_name == "output":
+        return get_output_directory()
+    if type_name == "temp":
+        return get_temp_directory()
+    if type_name == "input":
+        return get_input_directory()
+    return None


def add_model_folder_path(folder_name, full_folder_path):
    global folder_names_and_paths
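A short sketch of the new directory helpers, assuming folder_paths is imported from the repository root; the output path is a placeholder.

import folder_paths

folder_paths.set_output_directory("/tmp/comfy_out")      # hypothetical path
print(folder_paths.get_output_directory())               # /tmp/comfy_out
print(folder_paths.get_directory_by_type("temp"))        # the temp directory
print(folder_paths.get_directory_by_type("bogus"))       # None, unknown types are rejected by the server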
main.py
@@ -1,49 +1,32 @@
-import os
-import sys
-import shutil
-
-import threading
import asyncio
+import itertools
+import os
+import shutil
+import threading
+
+from comfy.cli_args import args

if os.name == "nt":
    import logging
    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

if __name__ == "__main__":
-    if '--help' in sys.argv:
-        print("Valid Command line Arguments:")
-        print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.")
-        print("\t--port 8188\t\t\tSet the listen port.")
-        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
-        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
-        print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
-        print("\t--disable-xformers\t\tdisables xformers")
-        print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.")
-        print()
-        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
-        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
-        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
-        print("\t--novram\t\t\tWhen lowvram isn't enough.")
-        print()
-        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
-        exit()
-
-    if '--dont-upcast-attention' in sys.argv:
+    if args.dont_upcast_attention:
        print("disabling upcasting of attention")
        os.environ['ATTN_PRECISION'] = "fp16"

-    try:
-        index = sys.argv.index('--cuda-device')
-        device = sys.argv[index + 1]
-        os.environ['CUDA_VISIBLE_DEVICES'] = device
-        print("Set cuda device to:", device)
-    except:
-        pass
+    if args.cuda_device is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
+        print("Set cuda device to:", args.cuda_device)
+
+import yaml

import execution
-import server
import folder_paths
-import yaml
+import server
+from nodes import init_custom_nodes


def prompt_worker(q, server):
    e = execution.PromptExecutor(server)
@@ -98,47 +81,36 @@ if __name__ == "__main__":
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

+    init_custom_nodes()
+    server.add_routes()
    hijack_progress(server)

    threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
-    try:
-        address = '0.0.0.0'
-        p_index = sys.argv.index('--listen')
-        try:
-            ip = sys.argv[p_index + 1]
-            if ip[:2] != '--':
-                address = ip
-        except:
-            pass
-    except:
-        address = '127.0.0.1'

-    dont_print = False
-    if '--dont-print-server' in sys.argv:
-        dont_print = True
+    address = args.listen
+
+    dont_print = args.dont_print_server

    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        load_extra_path_config(extra_model_paths_config_path)

-    if '--extra-model-paths-config' in sys.argv:
-        indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config']
-        for i in indices:
-            load_extra_path_config(sys.argv[i])
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)

-    port = 8188
-    try:
-        p_index = sys.argv.index('--port')
-        port = int(sys.argv[p_index + 1])
-    except:
-        pass
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        print(f"Setting output directory to: {output_dir}")
+        folder_paths.set_output_directory(output_dir)

-    if '--quick-test-for-ci' in sys.argv:
+    port = args.port
+
+    if args.quick_test_for_ci:
        exit(0)

    call_on_start = None
-    if "--windows-standalone-build" in sys.argv:
+    if args.windows_standalone_build:
        def startup_server(address, port):
            import webbrowser
            webbrowser.open("http://{}:{}".format(address, port))
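The hand-rolled sys.argv parsing is replaced by the shared argparse namespace from comfy.cli_args, so any module can read the same flags. A hedged sketch of reading them; the example command line is illustrative only:

from comfy.cli_args import args

# Values come from the argparse defaults unless overridden on the command line,
# e.g. `python main.py --listen 0.0.0.0 --port 8080`.
print(args.listen, args.port)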
models/diffusers/put_diffusers_models_here (new empty file)
nodes.py
@@ -4,21 +4,22 @@ import os
import sys
import json
import hashlib
-import copy
import traceback

from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np


sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))


+import comfy.diffusers_convert
import comfy.samplers
import comfy.sd
import comfy.utils

-import comfy_extras.clip_vision
+import comfy.clip_vision

import model_management
import importlib
@@ -197,7 +198,7 @@ class CheckpointLoader:
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

    def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
        config_path = folder_paths.get_full_path("configs", config_name)
@@ -219,6 +220,45 @@ class CheckpointLoaderSimple:
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return out

+class DiffusersLoader:
+    @classmethod
+    def INPUT_TYPES(cls):
+        paths = []
+        for search_path in folder_paths.get_folder_paths("diffusers"):
+            if os.path.exists(search_path):
+                paths += next(os.walk(search_path))[1]
+        return {"required": {"model_path": (paths,), }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
+        for search_path in folder_paths.get_folder_paths("diffusers"):
+            if os.path.exists(search_path):
+                paths = next(os.walk(search_path))[1]
+                if model_path in paths:
+                    model_path = os.path.join(search_path, model_path)
+                    break
+
+        return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+
+
+class unCLIPCheckpointLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                             }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders"
+
+    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return out
+
class CLIPSetLastLayer:
    @classmethod
    def INPUT_TYPES(s):
@@ -254,6 +294,22 @@ class LoraLoader:
        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
        return (model_lora, clip_lora)

+class TomePatchModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, ratio):
+        m = model.clone()
+        m.set_model_tomesd(ratio)
+        return (m, )
+
class VAELoader:
    @classmethod
    def INPUT_TYPES(s):
@@ -354,7 +410,7 @@ class CLIPVisionLoader:

    def load_clip(self, clip_name):
        clip_path = folder_paths.get_full_path("clip_vision", clip_name)
-        clip_vision = comfy_extras.clip_vision.load(clip_path)
+        clip_vision = comfy.clip_vision.load(clip_path)
        return (clip_vision,)

class CLIPVisionEncode:
@@ -366,7 +422,7 @@ class CLIPVisionEncode:
    RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
    FUNCTION = "encode"

-    CATEGORY = "conditioning/style_model"
+    CATEGORY = "conditioning"

    def encode(self, clip_vision, image):
        output = clip_vision.encode_image(image)
@@ -408,6 +464,33 @@ class StyleModelApply:
            c.append(n)
        return (c, )

+class unCLIPConditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                             "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_adm"
+
+    CATEGORY = "conditioning"
+
+    def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
+        c = []
+        for t in conditioning:
+            o = t[1].copy()
+            x = (clip_vision_output, strength, noise_augmentation)
+            if "adm" in o:
+                o["adm"] = o["adm"][:] + [x]
+            else:
+                o["adm"] = [x]
+            n = [t[0], o]
+            c.append(n)
+        return (c, )
+
+
class EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device
@@ -646,7 +729,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
        model_management.load_controlnet_gpu(control_net_models)

    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
    else:
        #other samplers
        pass
@@ -719,7 +802,7 @@ class KSamplerAdvanced:

class SaveImage:
    def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
+        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
@@ -771,9 +854,6 @@ class SaveImage:
            os.makedirs(full_output_folder, exist_ok=True)
            counter = 1

-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir)
-
        results = list()
        for image in images:
            i = 255. * image.cpu().numpy()
@@ -798,7 +878,7 @@ class SaveImage:

class PreviewImage(SaveImage):
    def __init__(self):
-        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
+        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"

    @classmethod
@@ -809,13 +889,11 @@ class PreviewImage(SaveImage):
                }

class LoadImage:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
-        if not os.path.exists(s.input_dir):
-            os.makedirs(s.input_dir)
+        input_dir = folder_paths.get_input_directory()
        return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), )},
+                    {"image": (sorted(os.listdir(input_dir)), )},
                }

    CATEGORY = "image"
@@ -823,7 +901,8 @@ class LoadImage:
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"
    def load_image(self, image):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        i = Image.open(image_path)
        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
@@ -837,18 +916,19 @@ class LoadImage:

    @classmethod
    def IS_CHANGED(s, image):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

class LoadImageMask:
-    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
    @classmethod
    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
        return {"required":
-                    {"image": (sorted(os.listdir(s.input_dir)), ),
+                    {"image": (sorted(os.listdir(input_dir)), ),
                     "channel": (["alpha", "red", "green", "blue"], ),}
                }

@@ -857,7 +937,8 @@ class LoadImageMask:
    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"
    def load_image(self, image, channel):
-        image_path = os.path.join(self.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        i = Image.open(image_path)
        mask = None
        c = channel[0].upper()
@@ -872,7 +953,8 @@ class LoadImageMask:

    @classmethod
    def IS_CHANGED(s, image, channel):
-        image_path = os.path.join(s.input_dir, image)
+        input_dir = folder_paths.get_input_directory()
+        image_path = os.path.join(input_dir, image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
@@ -980,7 +1062,6 @@ class ImagePadForOutpaint:

NODE_CLASS_MAPPINGS = {
    "KSampler": KSampler,
-    "CheckpointLoader": CheckpointLoader,
    "CheckpointLoaderSimple": CheckpointLoaderSimple,
    "CLIPTextEncode": CLIPTextEncode,
    "CLIPSetLastLayer": CLIPSetLastLayer,
@@ -1009,6 +1090,7 @@ NODE_CLASS_MAPPINGS = {
    "CLIPLoader": CLIPLoader,
    "CLIPVisionEncode": CLIPVisionEncode,
    "StyleModelApply": StyleModelApply,
+    "unCLIPConditioning": unCLIPConditioning,
    "ControlNetApply": ControlNetApply,
    "ControlNetLoader": ControlNetLoader,
    "DiffControlNetLoader": DiffControlNetLoader,
@@ -1016,6 +1098,10 @@ NODE_CLASS_MAPPINGS = {
    "CLIPVisionLoader": CLIPVisionLoader,
    "VAEDecodeTiled": VAEDecodeTiled,
    "VAEEncodeTiled": VAEEncodeTiled,
+    "TomePatchModel": TomePatchModel,
+    "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+    "CheckpointLoader": CheckpointLoader,
+    "DiffusersLoader": DiffusersLoader,
}

def load_custom_node(module_path):
@@ -1050,7 +1136,8 @@ def load_custom_nodes():
        if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
        load_custom_node(module_path)

-load_custom_nodes()
-load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
-load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "silver_custom.py"))
+def init_custom_nodes():
+    load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "silver_custom.py"))
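A rough sketch of what the new unCLIPConditioning node does to a conditioning list when called directly; the tensors and the CLIP vision output below are synthetic stand-ins, and the import assumes the repository root with its dependencies installed. In a real workflow the inputs come from CLIPTextEncode and CLIPVisionEncode.

import torch
import nodes

# One conditioning entry is a [cond_tensor, options_dict] pair; the vision output is opaque here.
conditioning = [[torch.zeros(1, 77, 1024), {}]]
clip_vision_output = object()

(c,) = nodes.unCLIPConditioning().apply_adm(conditioning, clip_vision_output, strength=1.0, noise_augmentation=0.1)
print(c[0][1]["adm"])   # [(clip_vision_output, 1.0, 0.1)] appended onto the conditioning options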
@@ -47,7 +47,7 @@
    " !git pull\n",
    "\n",
    "!echo -= Install dependencies =-\n",
-    "!pip install xformers==0.0.16 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu117"
+    "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
   ]
  },
  {
@@ -4,7 +4,7 @@ torchsde
einops
open-clip-torch
transformers>=4.25.1
-safetensors
+safetensors>=0.3.0
pytorch_lightning
aiohttp
accelerate
server.py
@@ -18,6 +18,7 @@ except ImportError:
    sys.exit()

import mimetypes
+from comfy.cli_args import args


@web.middleware
@@ -27,6 +28,23 @@ async def cache_control(request: web.Request, handler):
    response.headers.setdefault('Cache-Control', 'no-cache')
    return response

+def create_cors_middleware(allowed_origin: str):
+    @web.middleware
+    async def cors_middleware(request: web.Request, handler):
+        if request.method == "OPTIONS":
+            # Pre-flight request. Reply successfully:
+            response = web.Response()
+        else:
+            response = await handler(request)
+
+        response.headers['Access-Control-Allow-Origin'] = allowed_origin
+        response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+        response.headers['Access-Control-Allow-Credentials'] = 'true'
+        return response
+
+    return cors_middleware
+
class PromptServer():
    def __init__(self, loop):
        PromptServer.instance = self
@@ -37,11 +55,17 @@ class PromptServer():
        self.loop = loop
        self.messages = asyncio.Queue()
        self.number = 0
-        self.app = web.Application(client_max_size=20971520, middlewares=[cache_control])
+
+        middlewares = [cache_control]
+        if args.enable_cors_header:
+            middlewares.append(create_cors_middleware(args.enable_cors_header))
+
+        self.app = web.Application(client_max_size=20971520, middlewares=middlewares)
        self.sockets = dict()
        self.web_root = os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "web")
        routes = web.RouteTableDef()
+        self.routes = routes
        self.last_node_id = None
        self.client_id = None

@@ -88,7 +112,7 @@ class PromptServer():

        @routes.post("/upload/image")
        async def upload_image(request):
-            upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+            upload_dir = folder_paths.get_input_directory()

            if not os.path.exists(upload_dir):
                os.makedirs(upload_dir)
@@ -130,10 +154,10 @@ class PromptServer():
        async def view_image(request):
            if "filename" in request.rel_url.query:
                type = request.rel_url.query.get("type", "output")
-                if type not in ["output", "input", "temp"]:
+                output_dir = folder_paths.get_directory_by_type(type)
+                if output_dir is None:
                    return web.Response(status=400)

-                output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
@@ -271,8 +295,9 @@ class PromptServer():
                self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

-        self.app.add_routes(routes)
+    def add_routes(self):
+        self.app.add_routes(self.routes)
        self.app.add_routes([
            web.static('/', self.web_root),
        ])
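The CORS support follows the standard aiohttp middleware pattern. A stripped-down, standalone sketch of the same idea, re-implemented here for illustration rather than imported from server.py; the route and port are placeholders.

from aiohttp import web

def create_cors_middleware(allowed_origin: str):
    # Same shape as the factory added to server.py above: answer preflights and tag responses.
    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        if request.method == "OPTIONS":
            response = web.Response()
        else:
            response = await handler(request)
        response.headers['Access-Control-Allow-Origin'] = allowed_origin
        return response
    return cors_middleware

async def hello(request):
    return web.json_response({"ok": True})

app = web.Application(middlewares=[create_cors_middleware("*")])
app.add_routes([web.get("/hello", hello)])
# web.run_app(app, port=8188)  # left commented out; this is only a sketch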
web/extensions/core/contextMenuFilter.js (new file, 137 lines)
@@ -0,0 +1,137 @@
import { app } from "/scripts/app.js";

// Adds filtering to combo context menus

const id = "Comfy.ContextMenuFilter";
app.registerExtension({
	name: id,
	init() {
		const ctxMenu = LiteGraph.ContextMenu;
		LiteGraph.ContextMenu = function (values, options) {
			const ctx = ctxMenu.call(this, values, options);

			// If we are a dark menu (only used for combo boxes) then add a filter input
			if (options?.className === "dark" && values?.length > 10) {
				const filter = document.createElement("input");
				Object.assign(filter.style, {
					width: "calc(100% - 10px)",
					border: "0",
					boxSizing: "border-box",
					background: "#333",
					border: "1px solid #999",
					margin: "0 0 5px 5px",
					color: "#fff",
				});
				filter.placeholder = "Filter list";
				this.root.prepend(filter);

				let selectedIndex = 0;
				let items = this.root.querySelectorAll(".litemenu-entry");
				let itemCount = items.length;
				let selectedItem;

				// Apply highlighting to the selected item
				function updateSelected() {
					if (selectedItem) {
						selectedItem.style.setProperty("background-color", "");
						selectedItem.style.setProperty("color", "");
					}
					selectedItem = items[selectedIndex];
					if (selectedItem) {
						selectedItem.style.setProperty("background-color", "#ccc", "important");
						selectedItem.style.setProperty("color", "#000", "important");
					}
				}

				const positionList = () => {
					const rect = this.root.getBoundingClientRect();

					// If the top is off screen then shift the element with scaling applied
					if (rect.top < 0) {
						const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
						const shift = (this.root.clientHeight * scale) / 2;
						this.root.style.top = -shift + "px";
					}
				}

				updateSelected();

				// Arrow up/down to select items
				filter.addEventListener("keydown", (e) => {
					if (e.key === "ArrowUp") {
						if (selectedIndex === 0) {
							selectedIndex = itemCount - 1;
						} else {
							selectedIndex--;
						}
						updateSelected();
						e.preventDefault();
					} else if (e.key === "ArrowDown") {
						if (selectedIndex === itemCount - 1) {
							selectedIndex = 0;
						} else {
							selectedIndex++;
						}
						updateSelected();
						e.preventDefault();
					} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
						selectedItem.click();
					} else if(e.key === "Escape") {
						this.close();
					}
				});

				filter.addEventListener("input", () => {
					// Hide all items that dont match our filter
					const term = filter.value.toLocaleLowerCase();
					items = this.root.querySelectorAll(".litemenu-entry");
					// When filtering recompute which items are visible for arrow up/down
					// Try and maintain selection
					let visibleItems = [];
					for (const item of items) {
						const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
						if (visible) {
							item.style.display = "block";
							if (item === selectedItem) {
								selectedIndex = visibleItems.length;
							}
							visibleItems.push(item);
						} else {
							item.style.display = "none";
							if (item === selectedItem) {
								selectedIndex = 0;
							}
						}
					}
					items = visibleItems;
					updateSelected();

					// If we have an event then we can try and position the list under the source
					if (options.event) {
						let top = options.event.clientY - 10;

						const bodyRect = document.body.getBoundingClientRect();
						const rootRect = this.root.getBoundingClientRect();
						if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
							top = Math.max(0, bodyRect.height - rootRect.height - 10);
						}

						this.root.style.top = top + "px";
						positionList();
					}
				});

				requestAnimationFrame(() => {
					// Focus the filter box when opening
					filter.focus();

					positionList();
				});
			}

			return ctx;
		};

		LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
	},
});
@@ -30,7 +30,8 @@ app.registerExtension({
		}

		// Overwrite the value in the serialized workflow pnginfo
-		workflowNode.widgets_values[widgetIndex] = prompt;
+		if (workflowNode?.widgets_values)
+			workflowNode.widgets_values[widgetIndex] = prompt;

		return prompt;
	};
|
|||||||
@ -3,10 +3,10 @@ import { app } from "/scripts/app.js";
|
|||||||
// Inverts the scrolling of context menus
|
// Inverts the scrolling of context menus
|
||||||
|
|
||||||
const id = "Comfy.InvertMenuScrolling";
|
const id = "Comfy.InvertMenuScrolling";
|
||||||
const ctxMenu = LiteGraph.ContextMenu;
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: id,
|
name: id,
|
||||||
init() {
|
init() {
|
||||||
|
const ctxMenu = LiteGraph.ContextMenu;
|
||||||
const replace = () => {
|
const replace = () => {
|
||||||
LiteGraph.ContextMenu = function (values, options) {
|
LiteGraph.ContextMenu = function (values, options) {
|
||||||
options = options || {};
|
options = options || {};
|
||||||
|
|||||||
@ -11,11 +11,14 @@ app.registerExtension({
|
|||||||
this.properties = {};
|
this.properties = {};
|
||||||
}
|
}
|
||||||
this.properties.showOutputText = RerouteNode.defaultVisibility;
|
 				this.properties.showOutputText = RerouteNode.defaultVisibility;
+				this.properties.horizontal = false;
 
 				this.addInput("", "*");
 				this.addOutput(this.properties.showOutputText ? "*" : "", "*");
 
 				this.onConnectionsChange = function (type, index, connected, link_info) {
+					this.applyOrientation();
 
 					// Prevent multiple connections to different types when we have no input
 					if (connected && type === LiteGraph.OUTPUT) {
 						// Ignore wildcard nodes as these will be updated to real types
@@ -43,12 +46,19 @@ app.registerExtension({
 							const node = app.graph.getNodeById(link.origin_id);
 							const type = node.constructor.type;
 							if (type === "Reroute") {
-								// Move the previous node
-								currentNode = node;
+								if (node === this) {
+									// We've found a circle
+									currentNode.disconnectInput(link.target_slot);
+									currentNode = null;
+								}
+								else {
+									// Move the previous node
+									currentNode = node;
+								}
 							} else {
 								// We've found the end
 								inputNode = currentNode;
-								inputType = node.outputs[link.origin_slot].type;
+								inputType = node.outputs[link.origin_slot]?.type ?? null;
 								break;
 							}
 						} else {
@@ -80,7 +90,7 @@ app.registerExtension({
 							updateNodes.push(node);
 						} else {
 							// We've found an output
-							const nodeOutType = node.inputs[link.target_slot].type;
+							const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
 							if (inputType && nodeOutType !== inputType) {
 								// The output doesnt match our input so disconnect it
 								node.disconnectInput(link.target_slot);
@@ -105,6 +115,7 @@ app.registerExtension({
 						node.__outputType = displayType;
 						node.outputs[0].name = node.properties.showOutputText ? displayType : "";
 						node.size = node.computeSize();
+						node.applyOrientation();
 
 						for (const l of node.outputs[0].links || []) {
 							const link = app.graph.links[l];
@@ -146,6 +157,7 @@ app.registerExtension({
 						this.outputs[0].name = "";
 					}
 					this.size = this.computeSize();
+					this.applyOrientation();
 					app.graph.setDirtyCanvas(true, true);
 				},
 			},
@@ -154,9 +166,32 @@ app.registerExtension({
 					callback: () => {
 						RerouteNode.setDefaultTextVisibility(!RerouteNode.defaultVisibility);
 					},
+				},
+				{
+					// naming is inverted with respect to LiteGraphNode.horizontal
+					// LiteGraphNode.horizontal == true means that
+					// each slot in the inputs and outputs are layed out horizontally,
+					// which is the opposite of the visual orientation of the inputs and outputs as a node
+					content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
+					callback: () => {
+						this.properties.horizontal = !this.properties.horizontal;
+						this.applyOrientation();
+					},
 				}
 			);
 		}
 
+		applyOrientation() {
+			this.horizontal = this.properties.horizontal;
+			if (this.horizontal) {
+				// we correct the input position, because LiteGraphNode.horizontal
+				// doesn't account for title presence
+				// which reroute nodes don't have
+				this.inputs[0].pos = [this.size[0] / 2, 0];
+			} else {
+				delete this.inputs[0].pos;
+			}
+			app.graph.setDirtyCanvas(true, true);
+		}
+
 		computeSize() {
 			return [
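The `applyOrientation` method added above is what re-anchors a reroute node's single input slot when it is flipped between vertical and horizontal. As an illustration only (not part of the commit), a minimal sketch of that positioning rule on a plain stand-in object rather than a real LiteGraph node:

```js
// Sketch: re-anchor the input slot of a reroute-style node.
// `node` is a plain stand-in object, not a real LiteGraph node.
function applyOrientation(node) {
	node.horizontal = node.properties.horizontal;
	if (node.horizontal) {
		// Center the input on the top edge; reroute nodes draw no title bar,
		// so no title-height offset is needed.
		node.inputs[0].pos = [node.size[0] / 2, 0];
	} else {
		// Fall back to LiteGraph's default slot placement.
		delete node.inputs[0].pos;
	}
}

const node = { size: [80, 30], inputs: [{}], properties: { horizontal: true } };
applyOrientation(node);
console.log(node.inputs[0].pos); // [40, 0]
```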
web/extensions/core/slotDefaults.js (new file, 21 lines)
@@ -0,0 +1,21 @@
+import { app } from "/scripts/app.js";
+
+// Adds defaults for quickly adding nodes with middle click on the input/output
+
+app.registerExtension({
+	name: "Comfy.SlotDefaults",
+	init() {
+		LiteGraph.middle_click_slot_add_default_node = true;
+		LiteGraph.slot_types_default_in = {
+			MODEL: "CheckpointLoaderSimple",
+			LATENT: "EmptyLatentImage",
+			VAE: "VAELoader",
+		};
+
+		LiteGraph.slot_types_default_out = {
+			LATENT: "VAEDecode",
+			IMAGE: "SaveImage",
+			CLIP: "CLIPTextEncode",
+		};
+	},
+});
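Because `Comfy.SlotDefaults` simply populates the `LiteGraph.slot_types_default_in`/`_out` tables, another extension could merge in additional type-to-node mappings. A hypothetical sketch (the extension name and the extra defaults are assumptions for illustration, not part of this commit):

```js
import { app } from "/scripts/app.js";

// Hypothetical follow-on extension: extend the middle-click defaults above.
app.registerExtension({
	name: "Example.ExtraSlotDefaults", // assumed name
	init() {
		// Middle-clicking a CONDITIONING *input* would add a node that produces it.
		Object.assign(LiteGraph.slot_types_default_in, {
			CONDITIONING: "CLIPTextEncode",
		});
		// Middle-clicking a MODEL *output* would add a node that consumes it.
		Object.assign(LiteGraph.slot_types_default_out, {
			MODEL: "KSampler",
		});
	},
});
```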
web/extensions/core/snapToGrid.js (new file, 89 lines)
@@ -0,0 +1,89 @@
+import { app } from "/scripts/app.js";
+
+// Shift + drag/resize to snap to grid
+
+app.registerExtension({
+	name: "Comfy.SnapToGrid",
+	init() {
+		// Add setting to control grid size
+		app.ui.settings.addSetting({
+			id: "Comfy.SnapToGrid.GridSize",
+			name: "Grid Size",
+			type: "number",
+			attrs: {
+				min: 1,
+				max: 500,
+			},
+			tooltip:
+				"When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
+			defaultValue: LiteGraph.CANVAS_GRID_SIZE,
+			onChange(value) {
+				LiteGraph.CANVAS_GRID_SIZE = +value;
+			},
+		});
+
+		// After moving a node, if the shift key is down align it to grid
+		const onNodeMoved = app.canvas.onNodeMoved;
+		app.canvas.onNodeMoved = function (node) {
+			const r = onNodeMoved?.apply(this, arguments);
+
+			if (app.shiftDown) {
+				// Ensure all selected nodes are realigned
+				for (const id in this.selected_nodes) {
+					this.selected_nodes[id].alignToGrid();
+				}
+			}
+
+			return r;
+		};
+
+		// When a node is added, add a resize handler to it so we can fix align the size with the grid
+		const onNodeAdded = app.graph.onNodeAdded;
+		app.graph.onNodeAdded = function (node) {
+			const onResize = node.onResize;
+			node.onResize = function () {
+				if (app.shiftDown) {
+					const w = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[0] / LiteGraph.CANVAS_GRID_SIZE);
+					const h = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.size[1] / LiteGraph.CANVAS_GRID_SIZE);
+					node.size[0] = w;
+					node.size[1] = h;
+				}
+				return onResize?.apply(this, arguments);
+			};
+			return onNodeAdded?.apply(this, arguments);
+		};
+
+		// Draw a preview of where the node will go if holding shift and the node is selected
+		const origDrawNode = LGraphCanvas.prototype.drawNode;
+		LGraphCanvas.prototype.drawNode = function (node, ctx) {
+			if (app.shiftDown && this.node_dragged && node.id in this.selected_nodes) {
+				const x = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[0] / LiteGraph.CANVAS_GRID_SIZE);
+				const y = LiteGraph.CANVAS_GRID_SIZE * Math.round(node.pos[1] / LiteGraph.CANVAS_GRID_SIZE);
+
+				const shiftX = x - node.pos[0];
+				let shiftY = y - node.pos[1];
+
+				let w, h;
+				if (node.flags.collapsed) {
+					w = node._collapsed_width;
+					h = LiteGraph.NODE_TITLE_HEIGHT;
+					shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
+				} else {
+					w = node.size[0];
+					h = node.size[1];
+					let titleMode = node.constructor.title_mode;
+					if (titleMode !== LiteGraph.TRANSPARENT_TITLE && titleMode !== LiteGraph.NO_TITLE) {
+						h += LiteGraph.NODE_TITLE_HEIGHT;
+						shiftY -= LiteGraph.NODE_TITLE_HEIGHT;
+					}
+				}
+				const f = ctx.fillStyle;
+				ctx.fillStyle = "rgba(100, 100, 100, 0.5)";
+				ctx.fillRect(shiftX, shiftY, w, h);
+				ctx.fillStyle = f;
+			}
+
+			return origDrawNode.apply(this, arguments);
+		};
+	},
+});
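The snapping above is plain rounding to the nearest grid multiple, applied to both positions and sizes while Shift is held. A self-contained sketch of that arithmetic, with a fixed stand-in for `LiteGraph.CANVAS_GRID_SIZE`:

```js
// Quantize a coordinate or size to the nearest multiple of the grid size.
const GRID = 10; // stand-in for LiteGraph.CANVAS_GRID_SIZE

function snap(value, grid = GRID) {
	return grid * Math.round(value / grid);
}

console.log(snap(123)); // 120
console.log(snap(127)); // 130
console.log([77, 42].map((v) => snap(v))); // [80, 40], e.g. a node position
```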
@@ -20,7 +20,7 @@ function hideWidget(node, widget, suffix = "") {
 		if (link == null) {
 			return undefined;
 		}
-		return widget.value;
+		return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
 	};
 
 	// Hide any linked widgets, e.g. seed+randomize
@@ -101,7 +101,7 @@ app.registerExtension({
 						callback: () => convertToWidget(this, w),
 					});
 				} else {
-					const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}];
+					const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}];
 					if (isConvertableWidget(w, config)) {
 						toInput.push({
 							content: `Convert ${w.name} to input`,
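The second hunk above makes the widget-to-input conversion consult a node's optional inputs as well as its required ones. A small sketch of that lookup chain against a hand-written `nodeData` shape (the example widget names and ranges are made up):

```js
// Hand-written stand-in for nodeData: "steps" is required, "denoise" is optional.
const nodeData = {
	input: {
		required: { steps: ["INT", { default: 20, min: 1 }] },
		optional: { denoise: ["FLOAT", { default: 1.0, min: 0, max: 1 }] },
	},
};

// Same fallback chain as the diff: required -> optional -> the widget's own type/options.
function widgetConfig(nodeData, w) {
	return nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}];
}

console.log(widgetConfig(nodeData, { name: "denoise", type: "FLOAT", options: {} }));
// -> ["FLOAT", { default: 1, min: 0, max: 1 }]
```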
@@ -5,10 +5,20 @@ import { defaultGraph } from "./defaultGraph.js";
 import { getPngMetadata, importA1111 } from "./pnginfo.js";
 
 class ComfyApp {
+	/**
+	 * List of {number, batchCount} entries to queue
+	 */
+	#queueItems = [];
+	/**
+	 * If the queue is currently being processed
+	 */
+	#processingQueue = false;
+
 	constructor() {
 		this.ui = new ComfyUI(this);
 		this.extensions = [];
 		this.nodeOutputs = {};
+		this.shiftDown = false;
 	}
 
 	/**
@@ -102,6 +112,46 @@ class ComfyApp {
 		};
 	}
 
+	#addNodeKeyHandler(node) {
+		const app = this;
+		const origNodeOnKeyDown = node.prototype.onKeyDown;
+
+		node.prototype.onKeyDown = function(e) {
+			if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) {
+				return false;
+			}
+
+			if (this.flags.collapsed || !this.imgs || this.imageIndex === null) {
+				return;
+			}
+
+			let handled = false;
+
+			if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
+				if (e.key === "ArrowLeft") {
+					this.imageIndex -= 1;
+				} else if (e.key === "ArrowRight") {
+					this.imageIndex += 1;
+				}
+				this.imageIndex %= this.imgs.length;
+
+				if (this.imageIndex < 0) {
+					this.imageIndex = this.imgs.length + this.imageIndex;
+				}
+				handled = true;
+			} else if (e.key === "Escape") {
+				this.imageIndex = null;
+				handled = true;
+			}
+
+			if (handled === true) {
+				e.preventDefault();
+				e.stopImmediatePropagation();
+				return false;
+			}
+		}
+	}
+
 	/**
 	 * Adds Custom drawing logic for nodes
 	 * e.g. Draws images and handles thumbnail navigation on nodes that output images
@@ -628,11 +678,16 @@ class ComfyApp {
 
 	#addKeyboardHandler() {
 		window.addEventListener("keydown", (e) => {
+			this.shiftDown = e.shiftKey;
+
 			// Queue prompt using ctrl or command + enter
 			if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
 				this.queuePrompt(e.shiftKey ? -1 : 0);
 			}
 		});
+		window.addEventListener("keyup", (e) => {
+			this.shiftDown = e.shiftKey;
+		});
 	}
 
 	/**
@@ -667,6 +722,9 @@ class ComfyApp {
 		const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph));
 		this.ctx = canvasEl.getContext("2d");
 
+		LiteGraph.release_link_on_empty_shows_menu = true;
+		LiteGraph.alt_drag_do_clone_nodes = true;
+
 		this.graph.start();
 
 		function resizeCanvas() {
@@ -785,6 +843,7 @@ class ComfyApp {
 
 			this.#addNodeContextMenuHandler(node);
 			this.#addDrawBackgroundHandler(node, app);
+			this.#addNodeKeyHandler(node);
 
 			await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
 			LiteGraph.registerNodeType(nodeId, node);
@@ -802,7 +861,7 @@ class ComfyApp {
 		this.clean();
 
 		if (!graphData) {
-			graphData = defaultGraph;
+			graphData = structuredClone(defaultGraph);
 		}
 
 		// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
@@ -915,31 +974,47 @@ class ComfyApp {
 	}
 
 	async queuePrompt(number, batchCount = 1) {
-		for (let i = 0; i < batchCount; i++) {
-			const p = await this.graphToPrompt();
-
-			try {
-				await api.queuePrompt(number, p);
-			} catch (error) {
-				this.ui.dialog.show(error.response || error.toString());
-				return;
-			}
-
-			for (const n of p.workflow.nodes) {
-				const node = graph.getNodeById(n.id);
-				if (node.widgets) {
-					for (const widget of node.widgets) {
-						// Allow widgets to run callbacks after a prompt has been queued
-						// e.g. random seed after every gen
-						if (widget.afterQueued) {
-							widget.afterQueued();
-						}
-					}
-				}
-			}
-
-			this.canvas.draw(true, true);
-			await this.ui.queue.update();
+		this.#queueItems.push({ number, batchCount });
+
+		// Only have one action process the items so each one gets a unique seed correctly
+		if (this.#processingQueue) {
+			return;
+		}
+
+		this.#processingQueue = true;
+		try {
+			while (this.#queueItems.length) {
+				({ number, batchCount } = this.#queueItems.pop());
+
+				for (let i = 0; i < batchCount; i++) {
+					const p = await this.graphToPrompt();
+
+					try {
+						await api.queuePrompt(number, p);
+					} catch (error) {
+						this.ui.dialog.show(error.response || error.toString());
+						break;
+					}
+
+					for (const n of p.workflow.nodes) {
+						const node = graph.getNodeById(n.id);
+						if (node.widgets) {
+							for (const widget of node.widgets) {
+								// Allow widgets to run callbacks after a prompt has been queued
+								// e.g. random seed after every gen
+								if (widget.afterQueued) {
+									widget.afterQueued();
+								}
+							}
+						}
+					}
+
+					this.canvas.draw(true, true);
+					await this.ui.queue.update();
+				}
+			}
+		} finally {
+			this.#processingQueue = false;
 		}
 	}
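The reworked `queuePrompt` turns the method into a producer plus a single consumer: every call pushes `{ number, batchCount }` onto `#queueItems`, but only the first caller drains the list, so prompts are built one at a time and widget callbacks such as seed randomization run in order. A standalone sketch of that pattern with a fake submit step (the class and method names here are illustrative, not the real API):

```js
// Minimal single-consumer queue mirroring the #queueItems/#processingQueue pattern.
class PromptQueue {
	#items = [];
	#processing = false;

	async enqueue(number, batchCount = 1) {
		this.#items.push({ number, batchCount });

		// Later callers only add work; the first caller's loop below drains it.
		if (this.#processing) return;

		this.#processing = true;
		try {
			while (this.#items.length) {
				const { number, batchCount } = this.#items.pop();
				for (let i = 0; i < batchCount; i++) {
					// Sequential awaits keep per-prompt state (e.g. seeds) consistent.
					await this.submit(number);
				}
			}
		} finally {
			this.#processing = false;
		}
	}

	async submit(number) {
		console.log("submitting prompt", number); // stand-in for building/POSTing a prompt
	}
}

const q = new PromptQueue();
q.enqueue(0, 2);
q.enqueue(1); // queued behind the first call's drain loop
```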
@@ -35,21 +35,86 @@ export function $el(tag, propsOrChildren, children) {
 	return element;
 }
 
-function dragElement(dragEl) {
+function dragElement(dragEl, settings) {
 	var posDiffX = 0,
 		posDiffY = 0,
 		posStartX = 0,
 		posStartY = 0,
 		newPosX = 0,
 		newPosY = 0;
-	if (dragEl.getElementsByClassName('drag-handle')[0]) {
+	if (dragEl.getElementsByClassName("drag-handle")[0]) {
 		// if present, the handle is where you move the DIV from:
-		dragEl.getElementsByClassName('drag-handle')[0].onmousedown = dragMouseDown;
+		dragEl.getElementsByClassName("drag-handle")[0].onmousedown = dragMouseDown;
 	} else {
 		// otherwise, move the DIV from anywhere inside the DIV:
 		dragEl.onmousedown = dragMouseDown;
 	}
 
+	// When the element resizes (e.g. view queue) ensure it is still in the windows bounds
+	const resizeObserver = new ResizeObserver(() => {
+		ensureInBounds();
+	}).observe(dragEl);
+
+	function ensureInBounds() {
+		if (dragEl.classList.contains("comfy-menu-manual-pos")) {
+			newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft));
+			newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop));
+
+			positionElement();
+		}
+	}
+
+	function positionElement() {
+		const halfWidth = document.body.clientWidth / 2;
+		const anchorRight = newPosX + dragEl.clientWidth / 2 > halfWidth;
+
+		// set the element's new position:
+		if (anchorRight) {
+			dragEl.style.left = "unset";
+			dragEl.style.right = document.body.clientWidth - newPosX - dragEl.clientWidth + "px";
+		} else {
+			dragEl.style.left = newPosX + "px";
+			dragEl.style.right = "unset";
+		}
+
+		dragEl.style.top = newPosY + "px";
+		dragEl.style.bottom = "unset";
+
+		if (savePos) {
+			localStorage.setItem(
+				"Comfy.MenuPosition",
+				JSON.stringify({
+					x: dragEl.offsetLeft,
+					y: dragEl.offsetTop,
+				})
+			);
+		}
+	}
+
+	function restorePos() {
+		let pos = localStorage.getItem("Comfy.MenuPosition");
+		if (pos) {
+			pos = JSON.parse(pos);
+			newPosX = pos.x;
+			newPosY = pos.y;
+			positionElement();
+			ensureInBounds();
+		}
+	}
+
+	let savePos = undefined;
+	settings.addSetting({
+		id: "Comfy.MenuPosition",
+		name: "Save menu position",
+		type: "boolean",
+		defaultValue: savePos,
+		onChange(value) {
+			if (savePos === undefined && value) {
+				restorePos();
+			}
+			savePos = value;
+		},
+	});
+
 	function dragMouseDown(e) {
 		e = e || window.event;
 		e.preventDefault();
@@ -64,18 +129,25 @@ function dragElement(dragEl) {
 	function elementDrag(e) {
 		e = e || window.event;
 		e.preventDefault();
+
+		dragEl.classList.add("comfy-menu-manual-pos");
+
 		// calculate the new cursor position:
 		posDiffX = e.clientX - posStartX;
 		posDiffY = e.clientY - posStartY;
 		posStartX = e.clientX;
 		posStartY = e.clientY;
-		newPosX = Math.min((document.body.clientWidth - dragEl.clientWidth), Math.max(0, (dragEl.offsetLeft + posDiffX)));
-		newPosY = Math.min((document.body.clientHeight - dragEl.clientHeight), Math.max(0, (dragEl.offsetTop + posDiffY)));
-		// set the element's new position:
-		dragEl.style.top = newPosY + "px";
-		dragEl.style.left = newPosX + "px";
+
+		newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft + posDiffX));
+		newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop + posDiffY));
+
+		positionElement();
 	}
 
+	window.addEventListener("resize", () => {
+		ensureInBounds();
+	});
+
 	function closeDragElement() {
 		// stop moving when mouse button is released:
 		document.onmouseup = null;
@@ -90,7 +162,7 @@ class ComfyDialog {
 			$el("p", { $: (p) => (this.textElement = p) }),
 			$el("button", {
 				type: "button",
-				textContent: "CLOSE",
+				textContent: "Close",
 				onclick: () => this.close(),
 			}),
 		]),
@@ -125,7 +197,7 @@ class ComfySettingsDialog extends ComfyDialog {
 		localStorage[settingId] = JSON.stringify(value);
 	}
 
-	addSetting({ id, name, type, defaultValue, onChange }) {
+	addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", }) {
 		if (!id) {
 			throw new Error("Settings must have an ID");
 		}
@@ -152,42 +224,83 @@ class ComfySettingsDialog extends ComfyDialog {
 					value = v;
 				};
 
+				let element;
+				value = this.getSettingValue(id, defaultValue);
+
 				if (typeof type === "function") {
-					return type(name, setter, value);
+					element = type(name, setter, value, attrs);
+				} else {
+					switch (type) {
+						case "boolean":
+							element = $el("div", [
+								$el("label", { textContent: name || id }, [
+									$el("input", {
+										type: "checkbox",
+										checked: !!value,
+										oninput: (e) => {
+											setter(e.target.checked);
+										},
+										...attrs
+									}),
+								]),
+							]);
+							break;
+						case "number":
+							element = $el("div", [
+								$el("label", { textContent: name || id }, [
+									$el("input", {
+										type,
+										value,
+										oninput: (e) => {
+											setter(e.target.value);
+										},
+										...attrs
+									}),
+								]),
+							]);
+							break;
+						default:
+							console.warn("Unsupported setting type, defaulting to text");
+							element = $el("div", [
+								$el("label", { textContent: name || id }, [
+									$el("input", {
+										value,
+										oninput: (e) => {
+											setter(e.target.value);
+										},
+										...attrs
+									}),
+								]),
+							]);
+							break;
+					}
+				}
+				if(tooltip) {
+					element.title = tooltip;
 				}
 
-				switch (type) {
-					case "boolean":
-						return $el("div", [
-							$el("label", { textContent: name || id }, [
-								$el("input", {
-									type: "checkbox",
-									checked: !!value,
-									oninput: (e) => {
-										setter(e.target.checked);
-									},
-								}),
-							]),
-						]);
-					default:
-						console.warn("Unsupported setting type, defaulting to text");
-						return $el("div", [
-							$el("label", { textContent: name || id }, [
-								$el("input", {
-									value,
-									oninput: (e) => {
-										setter(e.target.value);
-									},
-								}),
-							]),
-						]);
-				}
+				return element;
 			},
 		});
+
+		const self = this;
+		return {
+			get value() {
+				return self.getSettingValue(id, defaultValue);
+			},
+			set value(v) {
+				self.setSettingValue(id, v);
+			},
+		};
 	}
 
 	show() {
 		super.show();
+		Object.assign(this.textElement.style, {
+			display: "flex",
+			flexDirection: "column",
+			gap: "10px"
+		});
 		this.textElement.replaceChildren(...this.settings.map((s) => s.render()));
 	}
 }
@@ -225,10 +338,10 @@ class ComfyList {
 					$el("button", {
 						textContent: "Load",
 						onclick: () => {
+							app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
 							if (item.outputs) {
 								app.nodeOutputs = item.outputs;
 							}
-							app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
 						},
 					}),
 					$el("button", {
@@ -300,6 +413,13 @@ export class ComfyUI {
 			this.history.update();
 		});
 
+		const confirmClear = this.settings.addSetting({
+			id: "Comfy.ConfirmClear",
+			name: "Require confirmation when clearing workflow",
+			type: "boolean",
+			defaultValue: true,
+		});
+
 		const fileInput = $el("input", {
 			type: "file",
 			accept: ".json,image/png",
@@ -311,39 +431,57 @@ export class ComfyUI {
 		});
 
 		this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [
-			$el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [
+			$el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [
 				$el("span.drag-handle"),
 				$el("span", { $: (q) => (this.queueSize = q) }),
 				$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
 			]),
-			$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }),
+			$el("button.comfy-queue-btn", {
+				textContent: "Queue Prompt",
+				onclick: () => app.queuePrompt(0, this.batchCount),
+			}),
 			$el("div", {}, [
-				$el("label", { innerHTML: "Extra options"}, [
-					$el("input", { type: "checkbox",
-						onchange: (i) => {
-							document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none";
-							this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1;
-							document.getElementById('autoQueueCheckbox').checked = false;
-						}
-					})
-				])
-			]),
-			$el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [
-				$el("label", { innerHTML: "Batch count" }, [
-					$el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" },
-						oninput: (i) => {
-							this.batchCount = i.target.value;
-							document.getElementById('batchCountInputRange').value = this.batchCount;
-						}
-					}),
-					$el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount,
-						oninput: (i) => {
-							this.batchCount = i.srcElement.value;
-							document.getElementById('batchCountInputNumber').value = i.srcElement.value;
-						}
-					}),
-					$el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "automatically queue prompt when the queue size hits 0",
-					})
-				]),
-			]),
+				$el("label", { innerHTML: "Extra options" }, [
+					$el("input", {
+						type: "checkbox",
+						onchange: (i) => {
+							document.getElementById("extraOptions").style.display = i.srcElement.checked ? "block" : "none";
+							this.batchCount = i.srcElement.checked ? document.getElementById("batchCountInputRange").value : 1;
+							document.getElementById("autoQueueCheckbox").checked = false;
+						},
+					}),
+				]),
+			]),
+			$el("div", { id: "extraOptions", style: { width: "100%", display: "none" } }, [
+				$el("label", { innerHTML: "Batch count" }, [
+					$el("input", {
+						id: "batchCountInputNumber",
+						type: "number",
+						value: this.batchCount,
+						min: "1",
+						style: { width: "35%", "margin-left": "0.4em" },
+						oninput: (i) => {
+							this.batchCount = i.target.value;
+							document.getElementById("batchCountInputRange").value = this.batchCount;
+						},
+					}),
+					$el("input", {
+						id: "batchCountInputRange",
+						type: "range",
+						min: "1",
+						max: "100",
+						value: this.batchCount,
+						oninput: (i) => {
+							this.batchCount = i.srcElement.value;
+							document.getElementById("batchCountInputNumber").value = i.srcElement.value;
+						},
+					}),
+					$el("input", {
+						id: "autoQueueCheckbox",
+						type: "checkbox",
+						checked: false,
+						title: "automatically queue prompt when the queue size hits 0",
+					}),
+				]),
+			]),
 			$el("div.comfy-menu-btns", [
@@ -389,14 +527,19 @@ export class ComfyUI {
 			$el("button", { textContent: "Load", onclick: () => fileInput.click() }),
 			$el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
 			$el("button", { textContent: "Clear", onclick: () => {
-				app.clean();
-				app.graph.clear();
+				if (!confirmClear.value || confirm("Clear workflow?")) {
+					app.clean();
+					app.graph.clear();
+				}
+			}}),
+			$el("button", { textContent: "Load Default", onclick: () => {
+				if (!confirmClear.value || confirm("Load default workflow?")) {
+					app.loadGraphData()
+				}
 			}}),
-			$el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }),
-			$el("button", { textContent: "Delete Images", onclick: () => api.deleteAllImages() }),
 		]);
 
-		dragElement(this.menuContainer);
+		dragElement(this.menuContainer, this.settings);
 
 		this.setStatus({ exec_info: { queue_remaining: "X" } });
 	}
@@ -404,10 +547,14 @@ export class ComfyUI {
 	setStatus(status) {
 		this.queueSize.textContent = "Queue size: " + (status ? status.exec_info.queue_remaining : "ERR");
 		if (status) {
-			if (this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && document.getElementById('autoQueueCheckbox').checked) {
+			if (
+				this.lastQueueSize != 0 &&
+				status.exec_info.queue_remaining == 0 &&
+				document.getElementById("autoQueueCheckbox").checked
+			) {
 				app.queuePrompt(0, this.batchCount);
 			}
-			this.lastQueueSize = status.exec_info.queue_remaining
+			this.lastQueueSize = status.exec_info.queue_remaining;
 		}
 	}
 }
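The new `dragElement` logic anchors the menu to whichever half of the window it sits in and, when the "Save menu position" setting is on, persists the coordinates to localStorage under "Comfy.MenuPosition". A browser-only sketch of that save/restore round trip, with the settings wiring omitted:

```js
// Save the current position of an absolutely-positioned element.
function saveMenuPosition(el) {
	localStorage.setItem("Comfy.MenuPosition", JSON.stringify({ x: el.offsetLeft, y: el.offsetTop }));
}

// Restore it, clamping into the viewport and anchoring left or right.
function restoreMenuPosition(el) {
	const raw = localStorage.getItem("Comfy.MenuPosition");
	if (!raw) return;
	const { x, y } = JSON.parse(raw);
	const newX = Math.min(document.body.clientWidth - el.clientWidth, Math.max(0, x));
	const newY = Math.min(document.body.clientHeight - el.clientHeight, Math.max(0, y));
	if (newX + el.clientWidth / 2 > document.body.clientWidth / 2) {
		el.style.left = "unset";
		el.style.right = document.body.clientWidth - newX - el.clientWidth + "px";
	} else {
		el.style.left = newX + "px";
		el.style.right = "unset";
	}
	el.style.top = newY + "px";
}
```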
@@ -306,7 +306,7 @@ export const ComfyWidgets = {
 		const fileInput = document.createElement("input");
 		Object.assign(fileInput, {
 			type: "file",
-			accept: "image/jpeg,image/png",
+			accept: "image/jpeg,image/png,image/webp",
 			style: "display: none",
 			onchange: async () => {
 				if (fileInput.files.length) {
@@ -39,18 +39,19 @@ body {
 	position: fixed; /* Stay in place */
 	z-index: 100; /* Sit on top */
 	padding: 30px 30px 10px 30px;
-	background-color: #ff0000; /* Modal background */
+	background-color: #353535; /* Modal background */
+	color: #ff4444;
 	box-shadow: 0px 0px 20px #888888;
 	border-radius: 10px;
-	text-align: center;
 	top: 50%;
 	left: 50%;
 	max-width: 80vw;
 	max-height: 80vh;
 	transform: translate(-50%, -50%);
 	overflow: hidden;
-	min-width: 60%;
 	justify-content: center;
+	font-family: monospace;
+	font-size: 15px;
 }
 
 .comfy-modal-content {
@@ -70,31 +71,11 @@ body {
 	margin: 3px 3px 3px 4px;
 }
 
-.comfy-modal button {
-	cursor: pointer;
-	color: #aaaaaa;
-	border: none;
-	background-color: transparent;
-	font-size: 24px;
-	font-weight: bold;
-	width: 100%;
-}
-
-.comfy-modal button:hover,
-.comfy-modal button:focus {
-	color: #000;
-	text-decoration: none;
-	cursor: pointer;
-}
-
 .comfy-menu {
-	width: 200px;
 	font-size: 15px;
 	position: absolute;
 	top: 50%;
 	right: 0%;
-	background-color: white;
-	color: #000;
 	text-align: center;
 	z-index: 100;
 	width: 170px;
@@ -109,7 +90,8 @@ body {
 	box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4);
 }
 
-.comfy-menu button {
+.comfy-menu button,
+.comfy-modal button {
 	font-size: 20px;
 }
 
@@ -130,7 +112,8 @@ body {
 
 .comfy-menu > button,
 .comfy-menu-btns button,
-.comfy-menu .comfy-list button {
+.comfy-menu .comfy-list button,
+.comfy-modal button{
 	color: #ddd;
 	background-color: #222;
 	border-radius: 8px;
@@ -220,11 +203,22 @@ button.comfy-queue-btn {
 }
 
 .comfy-modal.comfy-settings {
-	background-color: var(--bg-color);
-	color: var(--fg-color);
+	text-align: center;
+	font-family: sans-serif;
+	color: #999;
 	z-index: 99;
 }
 
+.comfy-modal input,
+.comfy-modal select {
+	color: #ddd;
+	background-color: #222;
+	border-radius: 8px;
+	border-color: #4e4e4e;
+	border-style: solid;
+	font-size: inherit;
+}
+
 @media only screen and (max-height: 850px) {
 	.comfy-menu {
 		top: 0 !important;
@@ -237,3 +231,28 @@ button.comfy-queue-btn {
 		visibility:hidden
 	}
 }
+
+.graphdialog {
+	min-height: 1em;
+}
+
+.graphdialog .name {
+	font-size: 14px;
+	font-family: sans-serif;
+	color: #999999;
+}
+
+.graphdialog button {
+	margin-top: unset;
+	vertical-align: unset;
+	height: 1.6em;
+	padding-right: 8px;
+}
+
+.graphdialog input, .graphdialog textarea, .graphdialog select {
+	background-color: #222;
+	border: 2px solid;
+	border-color: #444444;
+	color: #ddd;
+	border-radius: 12px 0 0 12px;
+}