diff --git a/README.md b/README.md index 0f7d24c45..90931141d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) - Works even if you don't have a GPU with: ```--cpu``` (slow) -- Can load both ckpt and safetensors models/checkpoints. Standalone VAEs and CLIP models. +- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - Loading full workflows (with seeds) from generated PNG files. diff --git a/comfy/cli_args.py b/comfy/cli_args.py new file mode 100644 index 000000000..b24054ce0 --- /dev/null +++ b/comfy/cli_args.py @@ -0,0 +1,31 @@ +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") +parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") +parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") +parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") +parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") + +attn_group = parser.add_mutually_exclusive_group() +attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") +attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + +parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + +vram_group = parser.add_mutually_exclusive_group() +vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") +vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") +vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") +vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") +vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + +parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") +parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") +parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") + +args = parser.parse_args() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index cb29df432..efb2d5384 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,7 @@ from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor from .utils import load_torch_file, transformers_convert import os +import torch class ClipVisionModel(): def __init__(self, json_config): @@ -20,7 +21,8 @@ class ClipVisionModel(): self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - inputs = self.processor(images=[image[0]], return_tensors="pt") + img = torch.clip((255. * image[0]), 0, 255).round().int() + inputs = self.processor(images=[img], return_tensors="pt") outputs = self.model(**inputs) return outputs diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py new file mode 100644 index 000000000..ceca80305 --- /dev/null +++ b/comfy/diffusers_convert.py @@ -0,0 +1,362 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file + +# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py + +# =================# +# UNet Conversion # +# =================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3 - i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i + 1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), +] + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + return new_state_dict + + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 07553627c..92b3eca7c 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -21,6 +21,8 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") +from cli_args import args + def exists(val): return val is not None @@ -474,7 +476,6 @@ class CrossAttentionPytorch(nn.Module): return self.to_out(out) -import sys if model_management.xformers_enabled(): print("Using xformers cross attention") CrossAttention = MemoryEfficientCrossAttention @@ -482,7 +483,7 @@ elif model_management.pytorch_attention_enabled(): print("Using pytorch cross attention") CrossAttention = CrossAttentionPytorch else: - if "--use-split-cross-attention" in sys.argv: + if args.use_split_cross_attention: print("Using split optimization for cross attention") CrossAttention = CrossAttentionDoggettx else: diff --git a/comfy/model_management.py b/comfy/model_management.py index 052dfb775..8303cb437 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,36 +1,42 @@ +import psutil +from enum import Enum +from cli_args import args -CPU = 0 -NO_VRAM = 1 -LOW_VRAM = 2 -NORMAL_VRAM = 3 -HIGH_VRAM = 4 -MPS = 5 +class VRAMState(Enum): + CPU = 0 + NO_VRAM = 1 + LOW_VRAM = 2 + NORMAL_VRAM = 3 + HIGH_VRAM = 4 + MPS = 5 -accelerate_enabled = False -vram_state = NORMAL_VRAM +# Determine VRAM State +vram_state = VRAMState.NORMAL_VRAM +set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 total_vram_available_mb = -1 -import sys -import psutil - -forced_cpu = "--cpu" in sys.argv - -set_vram_to = NORMAL_VRAM +accelerate_enabled = False +xpu_available = False try: import torch - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + except: + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) - forced_normal_vram = "--normalvram" in sys.argv - if not forced_normal_vram and not forced_cpu: + if not args.normalvram and not args.cpu: if total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = LOW_VRAM + set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = HIGH_VRAM + vram_state = VRAMState.HIGH_VRAM except: pass @@ -39,34 +45,50 @@ try: except: OOM_EXCEPTION = Exception -if "--disable-xformers" in sys.argv: - XFORMERS_IS_AVAILBLE = False +XFORMERS_VERSION = "" +XFORMERS_ENABLED_VAE = True +if args.disable_xformers: + XFORMERS_IS_AVAILABLE = False else: try: import xformers import xformers.ops - XFORMERS_IS_AVAILBLE = True + XFORMERS_IS_AVAILABLE = True + try: + XFORMERS_VERSION = xformers.version.__version__ + print("xformers version:", XFORMERS_VERSION) + if XFORMERS_VERSION.startswith("0.0.18"): + print() + print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") + print("Please downgrade or upgrade xformers to a different version.") + print() + XFORMERS_ENABLED_VAE = False + except: + pass except: - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False -ENABLE_PYTORCH_ATTENTION = False -if "--use-pytorch-cross-attention" in sys.argv: +ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) - ENABLE_PYTORCH_ATTENTION = True - XFORMERS_IS_AVAILBLE = False + XFORMERS_IS_AVAILABLE = False + +if args.lowvram: + set_vram_to = VRAMState.LOW_VRAM +elif args.novram: + set_vram_to = VRAMState.NO_VRAM +elif args.highvram: + vram_state = VRAMState.HIGH_VRAM + +FORCE_FP32 = False +if args.force_fp32: + print("Forcing FP32, if this improves things please report it.") + FORCE_FP32 = True -if "--lowvram" in sys.argv: - set_vram_to = LOW_VRAM -if "--novram" in sys.argv: - set_vram_to = NO_VRAM -if "--highvram" in sys.argv: - vram_state = HIGH_VRAM - - -if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: +if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): try: import accelerate accelerate_enabled = True @@ -81,14 +103,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: if torch.backends.mps.is_available(): - vram_state = MPS + vram_state = VRAMState.MPS except: pass -if forced_cpu: - vram_state = CPU +if args.cpu: + vram_state = VRAMState.CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) +print(f"Set vram state to: {vram_state.name}") current_loaded_model = None @@ -109,12 +131,12 @@ def unload_model(): model_accelerated = False #never unload models from GPU on high vram - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None - if vram_state != HIGH_VRAM: + if vram_state != VRAMState.HIGH_VRAM: if len(current_gpu_controlnets) > 0: for n in current_gpu_controlnets: n.cpu() @@ -135,32 +157,32 @@ def load_model_gpu(model): model.unpatch_model() raise e current_loaded_model = model - if vram_state == CPU: + if vram_state == VRAMState.CPU: pass - elif vram_state == MPS: + elif vram_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: + elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: model_accelerated = False - real_model.cuda() + real_model.to(get_torch_device()) else: - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == LOW_VRAM: + elif vram_state == VRAMState.LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") + accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True return current_loaded_model def load_controlnet_gpu(models): global current_gpu_controlnets global vram_state - if vram_state == CPU: + if vram_state == VRAMState.CPU: return - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return @@ -176,23 +198,27 @@ def load_controlnet_gpu(models): def load_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: - return model.cuda() + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: + return model.to(get_torch_device()) return model def unload_if_low_vram(model): global vram_state - if vram_state == LOW_VRAM or vram_state == NO_VRAM: + if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: return model.cpu() return model def get_torch_device(): - if vram_state == MPS: + global xpu_available + if vram_state == VRAMState.MPS: return torch.device("mps") - if vram_state == CPU: + if vram_state == VRAMState.CPU: return torch.device("cpu") else: - return torch.cuda.current_device() + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() def get_autocast_device(dev): if hasattr(dev, 'type'): @@ -201,27 +227,23 @@ def get_autocast_device(dev): def xformers_enabled(): - if vram_state == CPU: + if vram_state == VRAMState.CPU: return False - return XFORMERS_IS_AVAILBLE + return XFORMERS_IS_AVAILABLE def xformers_enabled_vae(): enabled = xformers_enabled() if not enabled: return False - try: - #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above) - if xformers.version.__version__ == "0.0.18": - return False - except: - pass - return enabled + + return XFORMERS_ENABLED_VAE def pytorch_attention_enabled(): return ENABLE_PYTORCH_ATTENTION def get_free_memory(dev=None, torch_free_too=False): + global xpu_available if dev is None: dev = get_torch_device() @@ -229,12 +251,16 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - stats = torch.cuda.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(dev) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + if xpu_available: + mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) + mem_free_torch = mem_free_total + else: + stats = torch.cuda.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch if torch_free_too: return (mem_free_total, mem_free_torch) @@ -243,7 +269,7 @@ def get_free_memory(dev=None, torch_free_too=False): def maximum_batch_area(): global vram_state - if vram_state == NO_VRAM: + if vram_state == VRAMState.NO_VRAM: return 0 memory_free = get_free_memory() / (1024 * 1024) @@ -252,14 +278,18 @@ def maximum_batch_area(): def cpu_mode(): global vram_state - return vram_state == CPU + return vram_state == VRAMState.CPU def mps_mode(): global vram_state - return vram_state == MPS + return vram_state == VRAMState.MPS def should_use_fp16(): - if cpu_mode() or mps_mode(): + global xpu_available + if FORCE_FP32: + return False + + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? if torch.cuda.is_bf16_supported(): diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index fb8172648..175202aeb 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -88,3 +88,8 @@ class Example: NODE_CLASS_MAPPINGS = { "Example": Example } + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "Example": "Example Node" +} diff --git a/folder_paths.py b/folder_paths.py index af56a6da1..ab3359347 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -23,10 +23,45 @@ folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) +folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") +temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") +input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + +if not os.path.exists(input_directory): + os.makedirs(input_directory) + +def set_output_directory(output_dir): + global output_directory + output_directory = output_dir + +def get_output_directory(): + global output_directory + return output_directory + +def get_temp_directory(): + global temp_directory + return temp_directory + +def get_input_directory(): + global input_directory + return input_directory + + +#NOTE: used in http server so don't put folders that should not be accessed remotely +def get_directory_by_type(type_name): + if type_name == "output": + return get_output_directory() + if type_name == "temp": + return get_temp_directory() + if type_name == "input": + return get_input_directory() + return None + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths diff --git a/main.py b/main.py index fbfaf6be5..9c0a3d8a1 100644 --- a/main.py +++ b/main.py @@ -1,56 +1,33 @@ -import os -import sys -import shutil - -import threading import asyncio +import itertools +import os +import shutil +import threading + +from comfy.cli_args import args if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": - if '--help' in sys.argv: - print() - print("Valid Command line Arguments:") - print("\t--listen [ip]\t\t\tListen on ip or 0.0.0.0 if none given so the UI can be accessed from other computers.") - print("\t--port 8188\t\t\tSet the listen port.") - print() - print("\t--extra-model-paths-config file.yaml\tload an extra_model_paths.yaml file.") - print() - print() - print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") - print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") - print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.") - print("\t--disable-xformers\t\tdisables xformers") - print("\t--cuda-device 1\t\tSet the id of the cuda device this instance will use.") - print() - print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n") - print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.") - print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.") - print("\t--novram\t\t\tWhen lowvram isn't enough.") - print() - print("\t--cpu\t\t\tTo use the CPU for everything (slow).") - exit() - - if '--dont-upcast-attention' in sys.argv: + if args.dont_upcast_attention: print("disabling upcasting of attention") os.environ['ATTN_PRECISION'] = "fp16" - try: - index = sys.argv.index('--cuda-device') - device = sys.argv[index + 1] - os.environ['CUDA_VISIBLE_DEVICES'] = device - print("Set cuda device to:", device) - except: - pass + if args.cuda_device is not None: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) + print("Set cuda device to:", args.cuda_device) + -from nodes import init_custom_nodes -import execution -import server -import folder_paths import yaml +import execution +import folder_paths +import server +from nodes import init_custom_nodes + + def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: @@ -109,43 +86,31 @@ if __name__ == "__main__": hijack_progress(server) threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - try: - address = '0.0.0.0' - p_index = sys.argv.index('--listen') - try: - ip = sys.argv[p_index + 1] - if ip[:2] != '--': - address = ip - except: - pass - except: - address = '127.0.0.1' - dont_print = False - if '--dont-print-server' in sys.argv: - dont_print = True + address = args.listen + + dont_print = args.dont_print_server extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): load_extra_path_config(extra_model_paths_config_path) - if '--extra-model-paths-config' in sys.argv: - indices = [(i + 1) for i in range(len(sys.argv) - 1) if sys.argv[i] == '--extra-model-paths-config'] - for i in indices: - load_extra_path_config(sys.argv[i]) + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) - port = 8188 - try: - p_index = sys.argv.index('--port') - port = int(sys.argv[p_index + 1]) - except: - pass + if args.output_directory: + output_dir = os.path.abspath(args.output_directory) + print(f"Setting output directory to: {output_dir}") + folder_paths.set_output_directory(output_dir) - if '--quick-test-for-ci' in sys.argv: + port = args.port + + if args.quick_test_for_ci: exit(0) call_on_start = None - if "--windows-standalone-build" in sys.argv: + if args.windows_standalone_build: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) diff --git a/models/diffusers/put_diffusers_models_here b/models/diffusers/put_diffusers_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index c4fffeb2a..897e615d5 100644 --- a/nodes.py +++ b/nodes.py @@ -4,16 +4,17 @@ import os import sys import json import hashlib -import copy import traceback from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) +import comfy.diffusers_convert import comfy.samplers import comfy.sd import comfy.utils @@ -219,6 +220,30 @@ class CheckpointLoaderSimple: out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out +class DiffusersLoader: + @classmethod + def INPUT_TYPES(cls): + paths = [] + for search_path in folder_paths.get_folder_paths("diffusers"): + if os.path.exists(search_path): + paths += next(os.walk(search_path))[1] + return {"required": {"model_path": (paths,), }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "advanced/loaders" + + def load_checkpoint(self, model_path, output_vae=True, output_clip=True): + for search_path in folder_paths.get_folder_paths("diffusers"): + if os.path.exists(search_path): + paths = next(os.walk(search_path))[1] + if model_path in paths: + model_path = os.path.join(search_path, model_path) + break + + return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + + class unCLIPCheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -777,7 +802,7 @@ class KSamplerAdvanced: class SaveImage: def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + self.output_dir = folder_paths.get_output_directory() self.type = "output" @classmethod @@ -831,9 +856,6 @@ class SaveImage: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - results = list() for image in images: i = 255. * image.cpu().numpy() @@ -858,7 +880,7 @@ class SaveImage: class PreviewImage(SaveImage): def __init__(self): - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") + self.output_dir = folder_paths.get_temp_directory() self.type = "temp" @classmethod @@ -871,15 +893,13 @@ class PreviewImage(SaveImage): WIDGET_TYPES = {"send to img": ("IMAGESEND", "TEMP")} class LoadImage: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") - temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @classmethod def INPUT_TYPES(s): - if not os.path.exists(s.input_dir): - os.makedirs(s.input_dir) + input_dir = folder_paths.get_input_directory() + output_dir = folder_paths.get_output_directory() + temp_dir = folder_paths.get_temp_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), )}, + {"image": (sorted(os.listdir(input_dir)), )}, } WIDGET_TYPES = {"recv img": (["disable", "enable"], )} @@ -898,6 +918,7 @@ class LoadImage: return os.path.join(self.input_dir, image) def load_image(self, image): + input_dir = folder_paths.get_input_directory() image_path = LoadImage.get_image_path(self, image) i = Image.open(image_path) image = i.convert("RGB") @@ -912,6 +933,7 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): + input_dir = folder_paths.get_input_directory() image_path = LoadImage.get_image_path(s, image) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -919,13 +941,13 @@ class LoadImage: return m.digest().hex() class LoadImageMask: - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") - temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @classmethod def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + output_dir = folder_paths.get_output_directory() + temp_dir = folder_paths.get_temp_directory() return {"required": - {"image": (sorted(os.listdir(s.input_dir)), ), + {"image": (sorted(os.listdir(input_dir)), ), "channel": (["alpha", "red", "green", "blue"], ),} } @@ -936,6 +958,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): + input_dir = folder_paths.get_input_directory() image_path = LoadImage.get_image_path(self, image) i = Image.open(image_path) mask = None @@ -951,6 +974,7 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): + input_dir = folder_paths.get_input_directory() image_path = LoadImage.get_image_path(s, image) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -1098,6 +1122,55 @@ NODE_CLASS_MAPPINGS = { "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "CheckpointLoader": CheckpointLoader, + "DiffusersLoader": DiffusersLoader, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "KSampler": "KSampler", + "KSamplerAdvanced": "KSampler (Advanced)", + # Loaders + "CheckpointLoader": "Load Checkpoint (With Config)", + "CheckpointLoaderSimple": "Load Checkpoint", + "VAELoader": "Load VAE", + "LoraLoader": "Load LoRA", + "CLIPLoader": "Load CLIP", + "ControlNetLoader": "Load ControlNet Model", + "DiffControlNetLoader": "Load ControlNet Model (diff)", + "StyleModelLoader": "Load Style Model", + "CLIPVisionLoader": "Load CLIP Vision", + "UpscaleModelLoader": "Load Upscale Model", + # Conditioning + "CLIPVisionEncode": "CLIP Vision Encode", + "StyleModelApply": "Apply Style Model", + "CLIPTextEncode": "CLIP Text Encode (Prompt)", + "CLIPSetLastLayer": "CLIP Set Last Layer", + "ConditioningCombine": "Conditioning (Combine)", + "ConditioningSetArea": "Conditioning (Set Area)", + "ControlNetApply": "Apply ControlNet", + # Latent + "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", + "SetLatentNoiseMask": "Set Latent Noise Mask", + "VAEDecode": "VAE Decode", + "VAEEncode": "VAE Encode", + "LatentRotate": "Rotate Latent", + "LatentFlip": "Flip Latent", + "LatentCrop": "Crop Latent", + "EmptyLatentImage": "Empty Latent Image", + "LatentUpscale": "Upscale Latent", + "LatentComposite": "Latent Composite", + # Image + "SaveImage": "Save Image", + "PreviewImage": "Preview Image", + "LoadImage": "Load Image", + "LoadImageMask": "Load Image (as Mask)", + "ImageScale": "Upscale Image", + "ImageUpscaleWithModel": "Upscale Image (using Model)", + "ImageInvert": "Invert Image", + "ImagePadForOutpaint": "Pad Image for Outpainting", + # _for_testing + "VAEDecodeTiled": "VAE Decode (Tiled)", + "VAEEncodeTiled": "VAE Encode (Tiled)", } def load_custom_node(module_path): @@ -1115,6 +1188,8 @@ def load_custom_node(module_path): module_spec.loader.exec_module(module) if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) + if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: + NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") except Exception as e: diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 3e59fbde7..d17f9877d 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" ] }, { diff --git a/requirements.txt b/requirements.txt index 3b4040a29..0527b31df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ torchsde einops open-clip-torch transformers>=4.25.1 -safetensors +safetensors>=0.3.0 pytorch_lightning aiohttp accelerate diff --git a/server.py b/server.py index d6e4013b3..9250c8bd9 100644 --- a/server.py +++ b/server.py @@ -7,7 +7,6 @@ import execution import uuid import json import glob - try: import aiohttp from aiohttp import web @@ -19,6 +18,7 @@ except ImportError: sys.exit() import mimetypes +from comfy.cli_args import args @web.middleware @@ -28,6 +28,23 @@ async def cache_control(request: web.Request, handler): response.headers.setdefault('Cache-Control', 'no-cache') return response +def create_cors_middleware(allowed_origin: str): + @web.middleware + async def cors_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Access-Control-Allow-Origin'] = allowed_origin + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + return cors_middleware + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -38,7 +55,12 @@ class PromptServer(): self.loop = loop self.messages = asyncio.Queue() self.number = 0 - self.app = web.Application(client_max_size=20971520, middlewares=[cache_control]) + + middlewares = [cache_control] + if args.enable_cors_header: + middlewares.append(create_cors_middleware(args.enable_cors_header)) + + self.app = web.Application(client_max_size=20971520, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") @@ -90,7 +112,7 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + upload_dir = folder_paths.get_input_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -123,10 +145,10 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: type = request.rel_url.query.get("type", "output") - if type not in ["output", "input", "temp"]: + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) - output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type) if "subfolder" in request.rel_url.query: full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: @@ -155,9 +177,10 @@ class PromptServer(): info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = x + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x if hasattr(obj_class, 'WIDGET_TYPES'): info['widget'] = obj_class.WIDGET_TYPES - info['name'] = x #TODO info['description'] = '' info['category'] = 'sd' if hasattr(obj_class, 'CATEGORY'): diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index e54bc2a38..a08d46684 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -21,28 +21,74 @@ const colorPalettes = { "MODEL": "#B39DDB", // light lavender-purple "STYLE_MODEL": "#C2FFAE", // light green-yellow "VAE": "#FF6E6E", // bright red - } - } + }, + "litegraph_base": { + "NODE_TITLE_COLOR": "#999", + "NODE_SELECTED_TITLE_COLOR": "#FFF", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#AAA", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#333", + "NODE_DEFAULT_BGCOLOR": "#353535", + "NODE_DEFAULT_BOXCOLOR": "#666", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#FFF", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + + "WIDGET_BGCOLOR": "#222", + "WIDGET_OUTLINE_COLOR": "#666", + "WIDGET_TEXT_COLOR": "#DDD", + "WIDGET_SECONDARY_TEXT_COLOR": "#999", + + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA", + }, + }, }, - "palette_2": { - "id": "palette_2", - "name": "Palette 2", + "solarized": { + "id": "solarized", + "name": "Solarized", "colors": { "node_slot": { - "CLIP": "#556B2F", // Dark Olive Green - "CLIP_VISION": "#4B0082", // Indigo - "CLIP_VISION_OUTPUT": "#006400", // Green - "CONDITIONING": "#FF1493", // Deep Pink - "CONTROL_NET": "#8B4513", // Saddle Brown - "IMAGE": "#8B0000", // Dark Red - "LATENT": "#00008B", // Dark Blue - "MASK": "#2F4F4F", // Dark Slate Grey - "MODEL": "#FF8C00", // Dark Orange - "STYLE_MODEL": "#004A4A", // Sherpa Blue - "UPSCALE_MODEL": "#4A004A", // Tyrian Purple - "VAE": "#4F394F", // Loulou - } - } + "CLIP": "#859900", // Green + "CLIP_VISION": "#6c71c4", // Indigo + "CLIP_VISION_OUTPUT": "#859900", // Green + "CONDITIONING": "#d33682", // Magenta + "CONTROL_NET": "#cb4b16", // Orange + "IMAGE": "#dc322f", // Red + "LATENT": "#268bd2", // Blue + "MASK": "#073642", // Base02 + "MODEL": "#cb4b16", // Orange + "STYLE_MODEL": "#073642", // Base02 + "UPSCALE_MODEL": "#6c71c4", // Indigo + "VAE": "#586e75", // Base1 + }, + "litegraph_base": { + "NODE_TITLE_COLOR": "#fdf6e3", + "NODE_SELECTED_TITLE_COLOR": "#b58900", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#657b83", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#586e75", + "NODE_DEFAULT_BGCOLOR": "#073642", + "NODE_DEFAULT_BOXCOLOR": "#839496", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#fdf6e3", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + + "WIDGET_BGCOLOR": "#002b36", + "WIDGET_OUTLINE_COLOR": "#839496", + "WIDGET_TEXT_COLOR": "#fdf6e3", + "WIDGET_SECONDARY_TEXT_COLOR": "#93a1a1", + + "LINK_COLOR": "#2aa198", + "EVENT_LINK_COLOR": "#268bd2", + "CONNECTING_LINK_COLOR": "#859900", + }, + }, } }; @@ -192,8 +238,20 @@ app.registerExtension({ if (colorPalette.colors) { if (colorPalette.colors.node_slot) { Object.assign(app.canvas.default_connection_color_byType, colorPalette.colors.node_slot); - app.canvas.draw(true, true); + Object.assign(LGraphCanvas.link_type_colors, colorPalette.colors.node_slot); } + if (colorPalette.colors.litegraph_base) { + // Everything updates correctly in the loop, except the Node Title and Link Color for some reason + app.canvas.node_title_color = colorPalette.colors.litegraph_base.NODE_TITLE_COLOR; + app.canvas.default_link_color = colorPalette.colors.litegraph_base.LINK_COLOR; + + for (const key in colorPalette.colors.litegraph_base) { + if (colorPalette.colors.litegraph_base.hasOwnProperty(key) && LiteGraph.hasOwnProperty(key)) { + LiteGraph[key] = colorPalette.colors.litegraph_base[key]; + } + } + } + app.canvas.draw(true, true); } }; diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js new file mode 100644 index 000000000..69d09cde8 --- /dev/null +++ b/web/extensions/core/nodeTemplates.js @@ -0,0 +1,184 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; + +// Adds the ability to save and add multiple nodes as a template +// To save: +// Select multiple nodes (ctrl + drag to select a region or ctrl+click individual nodes) +// Right click the canvas +// Save Node Template -> give it a name +// +// To add: +// Right click the canvas +// Node templates -> click the one to add +// +// To delete/rename: +// Right click the canvas +// Node templates -> Manage + +const id = "Comfy.NodeTemplates"; + +class ManageTemplates extends ComfyDialog { + constructor() { + super(); + this.element.classList.add("comfy-manage-templates"); + this.templates = this.load(); + } + + createButtons() { + const btns = super.createButtons(); + btns[0].textContent = "Cancel"; + btns.unshift( + $el("button", { + type: "button", + textContent: "Save", + onclick: () => this.save(), + }) + ); + return btns; + } + + load() { + const templates = localStorage.getItem(id); + if (templates) { + return JSON.parse(templates); + } else { + return []; + } + } + + save() { + // Find all visible inputs and save them as our new list + const inputs = this.element.querySelectorAll("input"); + const updated = []; + + for (let i = 0; i < inputs.length; i++) { + const input = inputs[i]; + if (input.parentElement.style.display !== "none") { + const t = this.templates[i]; + t.name = input.value.trim() || input.getAttribute("data-name"); + updated.push(t); + } + } + + this.templates = updated; + this.store(); + this.close(); + } + + store() { + localStorage.setItem(id, JSON.stringify(this.templates)); + } + + show() { + // Show list of template names + delete button + super.show( + $el( + "div", + { + style: { + display: "grid", + gridTemplateColumns: "1fr auto", + gap: "5px", + }, + }, + this.templates.flatMap((t) => { + let nameInput; + return [ + $el( + "label", + { + textContent: "Name: ", + }, + [ + $el("input", { + value: t.name, + dataset: { name: t.name }, + $: (el) => (nameInput = el), + }), + ] + ), + $el("button", { + textContent: "Delete", + style: { + fontSize: "12px", + color: "red", + fontWeight: "normal", + }, + onclick: (e) => { + nameInput.value = ""; + e.target.style.display = "none"; + e.target.previousElementSibling.style.display = "none"; + }, + }), + ]; + }) + ) + ); + } +} + +app.registerExtension({ + name: id, + setup() { + const manage = new ManageTemplates(); + + const clipboardAction = (cb) => { + // We use the clipboard functions but dont want to overwrite the current user clipboard + // Restore it after we've run our callback + const old = localStorage.getItem("litegrapheditor_clipboard"); + cb(); + localStorage.setItem("litegrapheditor_clipboard", old); + }; + + const orig = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function () { + const options = orig.apply(this, arguments); + + options.push(null); + options.push({ + content: `Save Selected as Template`, + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + const name = prompt("Enter name"); + if (!name || !name.trim()) return; + + clipboardAction(() => { + app.canvas.copyToClipboard(); + manage.templates.push({ + name, + data: localStorage.getItem("litegrapheditor_clipboard"), + }); + manage.store(); + }); + }, + }); + + // Map each template to a menu item + const subItems = manage.templates.map((t) => ({ + content: t.name, + callback: () => { + clipboardAction(() => { + localStorage.setItem("litegrapheditor_clipboard", t.data); + app.canvas.pasteFromClipboard(); + }); + }, + })); + + if (subItems.length) { + subItems.push(null, { + content: "Manage", + callback: () => manage.show(), + }); + + options.push({ + content: "Node Templates", + submenu: { + options: subItems, + }, + }); + } + + return options; + }; + }, +}); diff --git a/web/index.html b/web/index.html index 86156a7f8..bb79433ce 100644 --- a/web/index.html +++ b/web/index.html @@ -2,6 +2,7 @@ + ComfyUI diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 862d59067..c3efa22a9 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -89,6 +89,7 @@ NO_TITLE: 1, TRANSPARENT_TITLE: 2, AUTOHIDE_TITLE: 3, + VERTICAL_LAYOUT: "vertical", // arrange nodes vertically proxy: null, //used to redirect calls node_images_path: "", @@ -125,14 +126,14 @@ registered_slot_out_types: {}, // slot types for nodeclass slot_types_in: [], // slot types IN slot_types_out: [], // slot types OUT - slot_types_default_in: [], // specify for each IN slot type a(/many) deafult node(s), use single string, array, or object (with node, title, parameters, ..) like for search - slot_types_default_out: [], // specify for each OUT slot type a(/many) deafult node(s), use single string, array, or object (with node, title, parameters, ..) like for search + slot_types_default_in: [], // specify for each IN slot type a(/many) default node(s), use single string, array, or object (with node, title, parameters, ..) like for search + slot_types_default_out: [], // specify for each OUT slot type a(/many) default node(s), use single string, array, or object (with node, title, parameters, ..) like for search alt_drag_do_clone_nodes: false, // [true!] very handy, ALT click to clone and drag the new node do_add_triggers_slots: false, // [true!] will create and connect event slots when using action/events connections, !WILL CHANGE node mode when using onTrigger (enable mode colors), onExecuted does not need this - allow_multi_output_for_events: true, // [false!] being events, it is strongly reccomended to use them sequentually, one by one + allow_multi_output_for_events: true, // [false!] being events, it is strongly reccomended to use them sequentially, one by one middle_click_slot_add_default_node: false, //[true!] allows to create and connect a ndoe clicking with the third button (wheel) @@ -158,80 +159,67 @@ console.log("Node registered: " + type); } - var categories = type.split("/"); - var classname = base_class.name; + const classname = base_class.name; - var pos = type.lastIndexOf("/"); - base_class.category = type.substr(0, pos); + const pos = type.lastIndexOf("/"); + base_class.category = type.substring(0, pos); if (!base_class.title) { base_class.title = classname; } - //info.name = name.substr(pos+1,name.length - pos); //extend class - if (base_class.prototype) { - //is a class - for (var i in LGraphNode.prototype) { - if (!base_class.prototype[i]) { - base_class.prototype[i] = LGraphNode.prototype[i]; - } + for (var i in LGraphNode.prototype) { + if (!base_class.prototype[i]) { + base_class.prototype[i] = LGraphNode.prototype[i]; } } - var prev = this.registered_node_types[type]; - if(prev) - console.log("replacing node type: " + type); - else - { - if( !Object.hasOwnProperty( base_class.prototype, "shape") ) - Object.defineProperty(base_class.prototype, "shape", { - set: function(v) { - switch (v) { - case "default": - delete this._shape; - break; - case "box": - this._shape = LiteGraph.BOX_SHAPE; - break; - case "round": - this._shape = LiteGraph.ROUND_SHAPE; - break; - case "circle": - this._shape = LiteGraph.CIRCLE_SHAPE; - break; - case "card": - this._shape = LiteGraph.CARD_SHAPE; - break; - default: - this._shape = v; - } - }, - get: function(v) { - return this._shape; - }, - enumerable: true, - configurable: true - }); + const prev = this.registered_node_types[type]; + if(prev) { + console.log("replacing node type: " + type); + } + if( !Object.prototype.hasOwnProperty.call( base_class.prototype, "shape") ) { + Object.defineProperty(base_class.prototype, "shape", { + set: function(v) { + switch (v) { + case "default": + delete this._shape; + break; + case "box": + this._shape = LiteGraph.BOX_SHAPE; + break; + case "round": + this._shape = LiteGraph.ROUND_SHAPE; + break; + case "circle": + this._shape = LiteGraph.CIRCLE_SHAPE; + break; + case "card": + this._shape = LiteGraph.CARD_SHAPE; + break; + default: + this._shape = v; + } + }, + get: function() { + return this._shape; + }, + enumerable: true, + configurable: true + }); + - //warnings - if (base_class.prototype.onPropertyChange) { - console.warn( - "LiteGraph node class " + - type + - " has onPropertyChange method, it must be called onPropertyChanged with d at the end" - ); - } - - //used to know which nodes create when dragging files to the canvas - if (base_class.supported_extensions) { - for (var i in base_class.supported_extensions) { - var ext = base_class.supported_extensions[i]; - if(ext && ext.constructor === String) - this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class; - } - } - } + //used to know which nodes to create when dragging files to the canvas + if (base_class.supported_extensions) { + for (let i in base_class.supported_extensions) { + const ext = base_class.supported_extensions[i]; + if(ext && ext.constructor === String) { + this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class; + } + } + } + } this.registered_node_types[type] = base_class; if (base_class.constructor.name) { @@ -252,19 +240,11 @@ " has onPropertyChange method, it must be called onPropertyChanged with d at the end" ); } - - //used to know which nodes create when dragging files to the canvas - if (base_class.supported_extensions) { - for (var i=0; i < base_class.supported_extensions.length; i++) { - var ext = base_class.supported_extensions[i]; - if(ext && ext.constructor === String) - this.node_types_by_file_extension[ ext.toLowerCase() ] = base_class; - } - } - // TODO one would want to know input and ouput :: this would allow trought registerNodeAndSlotType to get all the slots types - //console.debug("Registering "+type); - if (this.auto_load_slot_types) nodeTmp = new base_class(base_class.title || "tmpnode"); + // TODO one would want to know input and ouput :: this would allow through registerNodeAndSlotType to get all the slots types + if (this.auto_load_slot_types) { + new base_class(base_class.title || "tmpnode"); + } }, /** @@ -1260,37 +1240,39 @@ * Positions every node in a more readable manner * @method arrange */ - LGraph.prototype.arrange = function(margin) { + LGraph.prototype.arrange = function (margin, layout) { margin = margin || 100; - var nodes = this.computeExecutionOrder(false, true); - var columns = []; - for (var i = 0; i < nodes.length; ++i) { - var node = nodes[i]; - var col = node._level || 1; + const nodes = this.computeExecutionOrder(false, true); + const columns = []; + for (let i = 0; i < nodes.length; ++i) { + const node = nodes[i]; + const col = node._level || 1; if (!columns[col]) { columns[col] = []; } columns[col].push(node); } - var x = margin; + let x = margin; - for (var i = 0; i < columns.length; ++i) { - var column = columns[i]; + for (let i = 0; i < columns.length; ++i) { + const column = columns[i]; if (!column) { continue; } - var max_size = 100; - var y = margin + LiteGraph.NODE_TITLE_HEIGHT; - for (var j = 0; j < column.length; ++j) { - var node = column[j]; - node.pos[0] = x; - node.pos[1] = y; - if (node.size[0] > max_size) { - max_size = node.size[0]; + let max_size = 100; + let y = margin + LiteGraph.NODE_TITLE_HEIGHT; + for (let j = 0; j < column.length; ++j) { + const node = column[j]; + node.pos[0] = (layout == LiteGraph.VERTICAL_LAYOUT) ? y : x; + node.pos[1] = (layout == LiteGraph.VERTICAL_LAYOUT) ? x : y; + const max_size_index = (layout == LiteGraph.VERTICAL_LAYOUT) ? 1 : 0; + if (node.size[max_size_index] > max_size) { + max_size = node.size[max_size_index]; } - y += node.size[1] + margin + LiteGraph.NODE_TITLE_HEIGHT; + const node_size_index = (layout == LiteGraph.VERTICAL_LAYOUT) ? 0 : 1; + y += node.size[node_size_index] + margin + LiteGraph.NODE_TITLE_HEIGHT; } x += max_size + margin; } @@ -2468,43 +2450,34 @@ this.title = this.constructor.title; } - if (this.onConnectionsChange) { - if (this.inputs) { - for (var i = 0; i < this.inputs.length; ++i) { - var input = this.inputs[i]; - var link_info = this.graph - ? this.graph.links[input.link] - : null; - this.onConnectionsChange( - LiteGraph.INPUT, - i, - true, - link_info, - input - ); //link_info has been created now, so its updated - } - } + if (this.inputs) { + for (var i = 0; i < this.inputs.length; ++i) { + var input = this.inputs[i]; + var link_info = this.graph ? this.graph.links[input.link] : null; + if (this.onConnectionsChange) + this.onConnectionsChange( LiteGraph.INPUT, i, true, link_info, input ); //link_info has been created now, so its updated - if (this.outputs) { - for (var i = 0; i < this.outputs.length; ++i) { - var output = this.outputs[i]; - if (!output.links) { - continue; - } - for (var j = 0; j < output.links.length; ++j) { - var link_info = this.graph - ? this.graph.links[output.links[j]] - : null; - this.onConnectionsChange( - LiteGraph.OUTPUT, - i, - true, - link_info, - output - ); //link_info has been created now, so its updated - } - } - } + if( this.onInputAdded ) + this.onInputAdded(input); + + } + } + + if (this.outputs) { + for (var i = 0; i < this.outputs.length; ++i) { + var output = this.outputs[i]; + if (!output.links) { + continue; + } + for (var j = 0; j < output.links.length; ++j) { + var link_info = this.graph ? this.graph.links[output.links[j]] : null; + if (this.onConnectionsChange) + this.onConnectionsChange( LiteGraph.OUTPUT, i, true, link_info, output ); //link_info has been created now, so its updated + } + + if( this.onOutputAdded ) + this.onOutputAdded(output); + } } if( this.widgets ) @@ -3200,6 +3173,15 @@ return; } + if(slot == null) + { + console.error("slot must be a number"); + return; + } + + if(slot.constructor !== Number) + console.warn("slot must be a number, use node.trigger('name') if you want to use a string"); + var output = this.outputs[slot]; if (!output) { return; @@ -3346,26 +3328,26 @@ * @param {Object} extra_info this can be used to have special properties of an output (label, special color, position, etc) */ LGraphNode.prototype.addOutput = function(name, type, extra_info) { - var o = { name: name, type: type, links: null }; + var output = { name: name, type: type, links: null }; if (extra_info) { for (var i in extra_info) { - o[i] = extra_info[i]; + output[i] = extra_info[i]; } } if (!this.outputs) { this.outputs = []; } - this.outputs.push(o); + this.outputs.push(output); if (this.onOutputAdded) { - this.onOutputAdded(o); + this.onOutputAdded(output); } if (LiteGraph.auto_load_slot_types) LiteGraph.registerNodeAndSlotType(this,type,true); this.setSize( this.computeSize() ); this.setDirtyCanvas(true, true); - return o; + return output; }; /** @@ -3437,10 +3419,10 @@ */ LGraphNode.prototype.addInput = function(name, type, extra_info) { type = type || 0; - var o = { name: name, type: type, link: null }; + var input = { name: name, type: type, link: null }; if (extra_info) { for (var i in extra_info) { - o[i] = extra_info[i]; + input[i] = extra_info[i]; } } @@ -3448,17 +3430,17 @@ this.inputs = []; } - this.inputs.push(o); + this.inputs.push(input); this.setSize( this.computeSize() ); if (this.onInputAdded) { - this.onInputAdded(o); + this.onInputAdded(input); } LiteGraph.registerNodeAndSlotType(this,type); this.setDirtyCanvas(true, true); - return o; + return input; }; /** @@ -5210,6 +5192,7 @@ LGraphNode.prototype.executeAction = function(action) this.allow_dragcanvas = true; this.allow_dragnodes = true; this.allow_interaction = true; //allow to control widgets, buttons, collapse, etc + this.multi_select = false; //allow selecting multi nodes without pressing extra keys this.allow_searchbox = true; this.allow_reconnect_links = true; //allows to change a connection with having to redo it again this.align_to_grid = false; //snap to grid @@ -5435,7 +5418,7 @@ LGraphNode.prototype.executeAction = function(action) }; /** - * returns the visualy active graph (in case there are more in the stack) + * returns the visually active graph (in case there are more in the stack) * @method getCurrentGraph * @return {LGraph} the active graph */ @@ -6060,9 +6043,13 @@ LGraphNode.prototype.executeAction = function(action) this.graph.beforeChange(); this.node_dragged = node; } - if (!this.selected_nodes[node.id]) { - this.processNodeSelected(node, e); - } + this.processNodeSelected(node, e); + } else { // double-click + /** + * Don't call the function if the block is already selected. + * Otherwise, it could cause the block to be unselected while its panel is open. + */ + if (!node.is_selected) this.processNodeSelected(node, e); } this.dirty_canvas = true; @@ -6474,6 +6461,10 @@ LGraphNode.prototype.executeAction = function(action) var n = this.selected_nodes[i]; n.pos[0] += delta[0] / this.ds.scale; n.pos[1] += delta[1] / this.ds.scale; + if (!n.is_selected) this.processNodeSelected(n, e); /* + * Don't call the function if the block is already selected. + * Otherwise, it could cause the block to be unselected while dragging. + */ } this.dirty_canvas = true; @@ -7287,7 +7278,7 @@ LGraphNode.prototype.executeAction = function(action) }; LGraphCanvas.prototype.processNodeSelected = function(node, e) { - this.selectNode(node, e && (e.shiftKey||e.ctrlKey)); + this.selectNode(node, e && (e.shiftKey || e.ctrlKey || this.multi_select)); if (this.onNodeSelected) { this.onNodeSelected(node); } @@ -7323,6 +7314,7 @@ LGraphNode.prototype.executeAction = function(action) for (var i in nodes) { var node = nodes[i]; if (node.is_selected) { + this.deselectNode(node); continue; } @@ -9742,13 +9734,17 @@ LGraphNode.prototype.executeAction = function(action) ctx.fillRect(margin, y, widget_width - margin * 2, H); var range = w.options.max - w.options.min; var nvalue = (w.value - w.options.min) / range; - ctx.fillStyle = active_widget == w ? "#89A" : "#678"; + if(nvalue < 0.0) nvalue = 0.0; + if(nvalue > 1.0) nvalue = 1.0; + ctx.fillStyle = w.options.hasOwnProperty("slider_color") ? w.options.slider_color : (active_widget == w ? "#89A" : "#678"); ctx.fillRect(margin, y, nvalue * (widget_width - margin * 2), H); if(show_text && !w.disabled) ctx.strokeRect(margin, y, widget_width - margin * 2, H); if (w.marker) { var marker_nvalue = (w.marker - w.options.min) / range; - ctx.fillStyle = "#AA9"; + if(marker_nvalue < 0.0) marker_nvalue = 0.0; + if(marker_nvalue > 1.0) marker_nvalue = 1.0; + ctx.fillStyle = w.options.hasOwnProperty("marker_color") ? w.options.marker_color : "#AA9"; ctx.fillRect( margin + marker_nvalue * (widget_width - margin * 2), y, 2, H ); } if (show_text) { @@ -9915,6 +9911,7 @@ LGraphNode.prototype.executeAction = function(action) case "slider": var range = w.options.max - w.options.min; var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); + if(w.options.read_only) break; w.value = w.options.min + (w.options.max - w.options.min) * nvalue; if (w.callback) { setTimeout(function() { @@ -9926,8 +9923,16 @@ LGraphNode.prototype.executeAction = function(action) case "number": case "combo": var old_value = w.value; - if (event.type == LiteGraph.pointerevents_method+"move" && w.type == "number") { - w.value += event.deltaX * 0.1 * (w.options.step || 1); + var delta = x < 40 ? -1 : x > widget_width - 40 ? 1 : 0; + var allow_scroll = true; + if (delta) { + if (x > -3 && x < widget_width + 3) { + allow_scroll = false; + } + } + if (allow_scroll && event.type == LiteGraph.pointerevents_method+"move" && w.type == "number") { + if(event.deltaX) + w.value += event.deltaX * 0.1 * (w.options.step || 1); if ( w.options.min != null && w.value < w.options.min ) { w.value = w.options.min; } @@ -9994,6 +9999,12 @@ LGraphNode.prototype.executeAction = function(action) var delta = x < 40 ? -1 : x > widget_width - 40 ? 1 : 0; if (event.click_time < 200 && delta == 0) { this.prompt("Value",w.value,function(v) { + // check if v is a valid equation or a number + if (/^[0-9+\-*/()\s]+$/.test(v)) { + try {//solve the equation if possible + v = eval(v); + } catch (e) { } + } this.value = Number(v); inner_value_change(this, this.value); }.bind(w), @@ -10022,7 +10033,6 @@ LGraphNode.prototype.executeAction = function(action) case "text": if (event.type == LiteGraph.pointerevents_method+"down") { this.prompt("Value",w.value,function(v) { - this.value = v; inner_value_change(this, v); }.bind(w), event,w.options ? w.options.multiline : false ); @@ -10047,6 +10057,9 @@ LGraphNode.prototype.executeAction = function(action) }//end for function inner_value_change(widget, value) { + if(widget.type == "number"){ + value = Number(value); + } widget.value = value; if ( widget.options && widget.options.property && node.properties[widget.options.property] !== undefined ) { node.setProperty( widget.options.property, value ); @@ -11165,7 +11178,7 @@ LGraphNode.prototype.executeAction = function(action) LGraphCanvas.search_limit = -1; LGraphCanvas.prototype.showSearchBox = function(event, options) { // proposed defaults - def_options = { slot_from: null + var def_options = { slot_from: null ,node_from: null ,node_to: null ,do_type_filter: LiteGraph.search_filter_enabled // TODO check for registered_slot_[in/out]_types not empty // this will be checked for functionality enabled : filter on slot type, in and out @@ -11863,7 +11876,7 @@ LGraphNode.prototype.executeAction = function(action) // TODO refactor, theer are different dialog, some uses createDialog, some dont LGraphCanvas.prototype.createDialog = function(html, options) { - def_options = { checkForInput: false, closeOnLeave: true, closeOnLeave_checkModified: true }; + var def_options = { checkForInput: false, closeOnLeave: true, closeOnLeave_checkModified: true }; options = Object.assign(def_options, options || {}); var dialog = document.createElement("div"); @@ -11993,7 +12006,8 @@ LGraphNode.prototype.executeAction = function(action) if (root.onClose && typeof root.onClose == "function"){ root.onClose(); } - root.parentNode.removeChild(root); + if(root.parentNode) + root.parentNode.removeChild(root); /* XXX CHECK THIS */ if(this.parentNode){ this.parentNode.removeChild(this); @@ -12285,7 +12299,7 @@ LGraphNode.prototype.executeAction = function(action) var ref_window = this.getCanvasWindow(); var that = this; var graphcanvas = this; - panel = this.createPanel(node.title || "",{ + var panel = this.createPanel(node.title || "",{ closable: true ,window: ref_window ,onOpen: function(){ diff --git a/web/scripts/app.js b/web/scripts/app.js index 132509390..edbe8ae76 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyWidgets } from "./widgets.js"; -import { ComfyUI } from "./ui.js"; +import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; @@ -112,6 +112,46 @@ class ComfyApp { }; } + #addNodeKeyHandler(node) { + const app = this; + const origNodeOnKeyDown = node.prototype.onKeyDown; + + node.prototype.onKeyDown = function(e) { + if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) { + return false; + } + + if (this.flags.collapsed || !this.imgs || this.imageIndex === null) { + return; + } + + let handled = false; + + if (e.key === "ArrowLeft" || e.key === "ArrowRight") { + if (e.key === "ArrowLeft") { + this.imageIndex -= 1; + } else if (e.key === "ArrowRight") { + this.imageIndex += 1; + } + this.imageIndex %= this.imgs.length; + + if (this.imageIndex < 0) { + this.imageIndex = this.imgs.length + this.imageIndex; + } + handled = true; + } else if (e.key === "Escape") { + this.imageIndex = null; + handled = true; + } + + if (handled === true) { + e.preventDefault(); + e.stopImmediatePropagation(); + return false; + } + } + } + /** * Adds Custom drawing logic for nodes * e.g. Draws images and handles thumbnail navigation on nodes that output images @@ -798,7 +838,7 @@ class ComfyApp { app.#invokeExtensionsAsync("nodeCreated", this); }, { - title: nodeData.name, + title: nodeData.display_name || nodeData.name, comfyClass: nodeData.name, } ); @@ -806,6 +846,7 @@ class ComfyApp { this.#addNodeContextMenuHandler(node); this.#addDrawBackgroundHandler(node, app); + this.#addNodeKeyHandler(node); await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); LiteGraph.registerNodeType(nodeId, node); @@ -826,12 +867,62 @@ class ComfyApp { graphData = structuredClone(defaultGraph); } - // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now + const missingNodeTypes = []; for (let n of graphData.nodes) { + // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; + + // Find missing node types + if (!(n.type in LiteGraph.registered_node_types)) { + missingNodeTypes.push(n.type); + } } - this.graph.configure(graphData); + try { + this.graph.configure(graphData); + } catch (error) { + let errorHint = []; + // Try extracting filename to see if it was caused by an extension script + const filename = error.fileName || (error.stack || "").match(/(\/extensions\/.*\.js)/)?.[1]; + const pos = (filename || "").indexOf("/extensions/"); + if (pos > -1) { + errorHint.push( + $el("span", { textContent: "This may be due to the following script:" }), + $el("br"), + $el("span", { + style: { + fontWeight: "bold", + }, + textContent: filename.substring(pos), + }) + ); + } + + // Show dialog to let the user know something went wrong loading the data + this.ui.dialog.show( + $el("div", [ + $el("p", { textContent: "Loading aborted due to error reloading workflow data" }), + $el("pre", { + style: { padding: "5px", backgroundColor: "rgba(255,0,0,0.2)" }, + textContent: error.toString(), + }), + $el("pre", { + style: { + padding: "5px", + color: "#ccc", + fontSize: "10px", + maxHeight: "50vh", + overflow: "auto", + backgroundColor: "rgba(0,0,0,0.2)", + }, + textContent: error.stack || "No stacktrace available", + }), + ...errorHint, + ]).outerHTML + ); + + return; + } for (const node of this.graph._nodes) { const size = node.computeSize(); @@ -855,6 +946,14 @@ class ComfyApp { this.#invokeExtensions("loadedGraphNode", node); } + + if (missingNodeTypes.length) { + this.ui.dialog.show( + `When loading the graph, the following node types were not found: Nodes that have failed to load will show as red on the graph.` + ); + } } /** diff --git a/web/scripts/defaultGraph.js b/web/scripts/defaultGraph.js index 967377ad6..9b3cb4a7e 100644 --- a/web/scripts/defaultGraph.js +++ b/web/scripts/defaultGraph.js @@ -13,7 +13,7 @@ export const defaultGraph = { inputs: [{ name: "clip", type: "CLIP", link: 5 }], outputs: [{ name: "CONDITIONING", type: "CONDITIONING", links: [6], slot_index: 0 }], properties: {}, - widgets_values: ["bad hands"], + widgets_values: ["text, watermark"], }, { id: 6, @@ -26,7 +26,7 @@ export const defaultGraph = { inputs: [{ name: "clip", type: "CLIP", link: 3 }], outputs: [{ name: "CONDITIONING", type: "CONDITIONING", links: [4], slot_index: 0 }], properties: {}, - widgets_values: ["masterpiece best quality girl"], + widgets_values: ["beautiful scenery nature glass bottle landscape, , purple galaxy bottle,"], }, { id: 5, @@ -56,7 +56,7 @@ export const defaultGraph = { ], outputs: [{ name: "LATENT", type: "LATENT", links: [7], slot_index: 0 }], properties: {}, - widgets_values: [8566257, true, 20, 8, "euler", "normal", 1], + widgets_values: [156680208700286, true, 20, 8, "euler", "normal", 1], }, { id: 8, diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 580030d81..31f470739 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -32,8 +32,9 @@ export function getPngMetadata(file) { } const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end)); // Get the text - const text = String.fromCharCode(...pngData.slice(keyword_end + 1, offset + 8 + length)); - txt_chunks[keyword] = text; + const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length); + const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('') + txt_chunks[keyword] = contentJson; } offset += 12 + length; diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 91821fac0..09861c440 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -8,14 +8,18 @@ export function $el(tag, propsOrChildren, children) { if (Array.isArray(propsOrChildren)) { element.append(...propsOrChildren); } else { - const parent = propsOrChildren.parent; + const { parent, $: cb, dataset, style } = propsOrChildren; delete propsOrChildren.parent; - const cb = propsOrChildren.$; delete propsOrChildren.$; + delete propsOrChildren.dataset; + delete propsOrChildren.style; - if (propsOrChildren.style) { - Object.assign(element.style, propsOrChildren.style); - delete propsOrChildren.style; + if (style) { + Object.assign(element.style, style); + } + + if (dataset) { + Object.assign(element.dataset, dataset); } Object.assign(element, propsOrChildren); @@ -76,7 +80,7 @@ function dragElement(dragEl, settings) { dragEl.style.left = newPosX + "px"; dragEl.style.right = "unset"; } - + dragEl.style.top = newPosY + "px"; dragEl.style.bottom = "unset"; @@ -115,14 +119,6 @@ function dragElement(dragEl, settings) { savePos = value; }, }); - - settings.addSetting({ - id: "Comfy.ConfirmClear", - name: "Require confirmation when clearing workflow", - type: "boolean", - defaultValue: true, - }); - function dragMouseDown(e) { e = e || window.event; e.preventDefault(); @@ -153,7 +149,7 @@ function dragElement(dragEl, settings) { } window.addEventListener("resize", () => { - ensureInBounds(); + ensureInBounds(); }); function closeDragElement() { @@ -163,26 +159,33 @@ function dragElement(dragEl, settings) { } } -class ComfyDialog { +export class ComfyDialog { constructor() { this.element = $el("div.comfy-modal", { parent: document.body }, [ - $el("div.comfy-modal-content", [ - $el("p", { $: (p) => (this.textElement = p) }), - $el("button", { - type: "button", - textContent: "CLOSE", - onclick: () => this.close(), - }), - ]), + $el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]), ]); } + createButtons() { + return [ + $el("button", { + type: "button", + textContent: "Close", + onclick: () => this.close(), + }), + ]; + } + close() { this.element.style.display = "none"; } show(html) { - this.textElement.innerHTML = html; + if (typeof html === "string") { + this.textElement.innerHTML = html; + } else { + this.textElement.replaceChildren(html); + } this.element.style.display = "flex"; } } @@ -233,6 +236,7 @@ class ComfySettingsDialog extends ComfyDialog { }; let element; + value = this.getSettingValue(id, defaultValue); if (typeof type === "function") { element = type(name, setter, value, attrs); @@ -289,6 +293,16 @@ class ComfySettingsDialog extends ComfyDialog { return element; }, }); + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; } show() { @@ -410,6 +424,13 @@ export class ComfyUI { this.history.update(); }); + const confirmClear = this.settings.addSetting({ + id: "Comfy.ConfirmClear", + name: "Require confirmation when clearing workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { type: "file", accept: ".json,image/png", @@ -421,7 +442,7 @@ export class ComfyUI { }); this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ - $el("div", { style: { overflow: "hidden", position: "relative", width: "100%" } }, [ + $el("div.drag-handle", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } }, [ $el("span.drag-handle"), $el("span", { $: (q) => (this.queueSize = q) }), $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), @@ -517,13 +538,13 @@ export class ComfyUI { $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), $el("button", { textContent: "Clear", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Clear workflow?")) { + if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), $el("button", { textContent: "Load Default", onclick: () => { - if (localStorage.getItem("Comfy.Settings.Comfy.ConfirmClear") == "false" || confirm("Load default workflow?")) { + if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } }}), diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index a66ef745c..f6147be32 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -357,7 +357,7 @@ export const ComfyWidgets = { const fileInput = document.createElement("input"); Object.assign(fileInput, { type: "file", - accept: "image/jpeg,image/png", + accept: "image/jpeg,image/png,image/webp", style: "display: none", onchange: async () => { if (fileInput.files.length) { diff --git a/web/style.css b/web/style.css index 393d1667e..d00a2fbe2 100644 --- a/web/style.css +++ b/web/style.css @@ -39,18 +39,19 @@ body { position: fixed; /* Stay in place */ z-index: 100; /* Sit on top */ padding: 30px 30px 10px 30px; - background-color: #ff0000; /* Modal background */ + background-color: #353535; /* Modal background */ + color: #ff4444; box-shadow: 0px 0px 20px #888888; border-radius: 10px; - text-align: center; top: 50%; left: 50%; max-width: 80vw; max-height: 80vh; transform: translate(-50%, -50%); overflow: hidden; - min-width: 60%; justify-content: center; + font-family: monospace; + font-size: 15px; } .comfy-modal-content { @@ -70,31 +71,11 @@ body { margin: 3px 3px 3px 4px; } -.comfy-modal button { - cursor: pointer; - color: #aaaaaa; - border: none; - background-color: transparent; - font-size: 24px; - font-weight: bold; - width: 100%; -} - -.comfy-modal button:hover, -.comfy-modal button:focus { - color: #000; - text-decoration: none; - cursor: pointer; -} - .comfy-menu { - width: 200px; font-size: 15px; position: absolute; top: 50%; right: 0%; - background-color: white; - color: #000; text-align: center; z-index: 100; width: 170px; @@ -109,7 +90,8 @@ body { box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } -.comfy-menu button { +.comfy-menu button, +.comfy-modal button { font-size: 20px; } @@ -130,7 +112,8 @@ body { .comfy-menu > button, .comfy-menu-btns button, -.comfy-menu .comfy-list button { +.comfy-menu .comfy-list button, +.comfy-modal button{ color: #ddd; background-color: #222; border-radius: 8px; @@ -219,12 +202,24 @@ button.comfy-queue-btn { margin: 6px 0 !important; } -.comfy-modal.comfy-settings { - background-color: var(--bg-color); - color: var(--fg-color); +.comfy-modal.comfy-settings, +.comfy-modal.comfy-manage-templates { + text-align: center; + font-family: sans-serif; + color: #999; z-index: 99; } +.comfy-modal input, +.comfy-modal select { + color: #ddd; + background-color: #222; + border-radius: 8px; + border-color: #4e4e4e; + border-style: solid; + font-size: inherit; +} + @media only screen and (max-height: 850px) { .comfy-menu { top: 0 !important; @@ -239,26 +234,39 @@ button.comfy-queue-btn { } .graphdialog { - min-height: 1em; + min-height: 1em; } .graphdialog .name { - font-size: 14px; - font-family: sans-serif; - color: #999999; + font-size: 14px; + font-family: sans-serif; + color: #999999; } .graphdialog button { - margin-top: unset; - vertical-align: unset; - height: 1.6em; - padding-right: 8px; + margin-top: unset; + vertical-align: unset; + height: 1.6em; + padding-right: 8px; } .graphdialog input, .graphdialog textarea, .graphdialog select { - background-color: #222; - border: 2px solid; - border-color: #444444; - color: #ddd; - border-radius: 12px 0 0 12px; + background-color: #222; + border: 2px solid; + border-color: #444444; + color: #ddd; + border-radius: 12px 0 0 12px; } + +.litegraph .litemenu-entry.has_submenu { + position: relative; + padding-right: 20px; + } + + .litemenu-entry.has_submenu::after { + content: ">"; + position: absolute; + top: 0; + right: 2px; + } + \ No newline at end of file