diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: diff --git a/.gitignore b/.gitignore index df6adbe4b..8380a2f7c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs +.idea/ \ No newline at end of file diff --git a/README.md b/README.md index bfa8904df..1de9d4c3b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) +- Latent previews with [TAESD](https://github.com/madebyollin/taesd) - Starts up very fast. - Works fully offline: will never download anything. - [Config file](extra_model_paths.yaml.example) to set the search paths for models. @@ -37,28 +38,28 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Shortcuts -| Keybind | Explanation | -| - | - | -| Ctrl + Enter | Queue up current graph for generation | -| Ctrl + Shift + Enter | Queue up current graph as first for generation | -| Ctrl + S | Save workflow | -| Ctrl + O | Load workflow | -| Ctrl + A | Select all nodes | -| Ctrl + M | Mute/unmute selected nodes | -| Delete/Backspace | Delete selected nodes | -| Ctrl + Delete/Backspace | Delete the current graph | -| Space | Move the canvas around when held and moving the cursor | -| Ctrl/Shift + Click | Add clicked node to selection | -| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | -| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | -| Shift + Drag | Move multiple selected nodes at the same time | -| Ctrl + D | Load default graph | -| Q | Toggle visibility of the queue | -| H | Toggle visibility of history | -| R | Refresh graph | -| Double-Click LMB | Open node quick search palette | +| Keybind | Explanation | +|---------------------------|--------------------------------------------------------------------------------------------------------------------| +| Ctrl + Enter | Queue up current graph for generation | +| Ctrl + Shift + Enter | Queue up current graph as first for generation | +| Ctrl + S | Save workflow | +| Ctrl + O | Load workflow | +| Ctrl + A | Select all nodes | +| Ctrl + M | Mute/unmute selected nodes | +| Delete/Backspace | Delete selected nodes | +| Ctrl + Delete/Backspace | Delete the current graph | +| Space | Move the canvas around when held and moving the cursor | +| Ctrl/Shift + Click | Add clicked node to selection | +| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | +| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | +| Shift + Drag | Move multiple 
selected nodes at the same time | +| Ctrl + D | Load default graph | +| Q | Toggle visibility of the queue | +| H | Toggle visibility of history | +| R | Refresh graph | +| Double-Click LMB | Open node quick search palette | -Ctrl can also be replaced with Cmd instead for MacOS users +Ctrl can also be replaced with Cmd instead for macOS users # Installing @@ -118,13 +119,26 @@ After this you should have everything installed and can proceed to running Comfy ### Others: -[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) +#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) -Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own. +#### Apple Mac silicon + +You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. + +1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. +1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. +1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). +1. Launch ComfyUI by running `python main.py`. + +> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). + +#### DirectML (AMD Cards on Windows) + +```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? -You don't. If you have another UI installed and working with it's own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: +You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: ```source path_to_other_sd_gui/venv/bin/activate``` @@ -134,7 +148,7 @@ With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"``` With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"``` -And then you can use that terminal to run Comfyui without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI. +And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI. # Running @@ -158,6 +172,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or ( You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}. +Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`. 
+ To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension): ```embedding:embedding_filename.pt``` @@ -181,6 +197,12 @@ You can set this command line setting to disable the upcasting to fp32 in some c ```--dont-upcast-attention``` +## How to show high-quality previews? + +Use ```--preview-method auto``` to enable previews. + +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. + ## Support and dev channel [Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source). diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc4709f70..b56497de0 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,4 +1,35 @@ import argparse +import enum + + +class EnumAction(argparse.Action): + """ + Argparse action for handling Enums + """ + def __init__(self, **kwargs): + # Pop off the type value + enum_type = kwargs.pop("type", None) + + # Ensure an Enum subclass is provided + if enum_type is None: + raise ValueError("type must be assigned an Enum when using EnumAction") + if not issubclass(enum_type, enum.Enum): + raise TypeError("type must be an Enum when using EnumAction") + + # Generate choices from the Enum + choices = tuple(e.value for e in enum_type) + kwargs.setdefault("choices", choices) + kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") + + super(EnumAction, self).__init__(**kwargs) + + self._enum = enum_type + + def __call__(self, parser, namespace, values, option_string=None): + # Convert value back into an Enum + value = self._enum(values) + setattr(namespace, self.dest, value) + parser = argparse.ArgumentParser() @@ -13,6 +44,14 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") +class LatentPreviewMethod(enum.Enum): + NoPreviews = "none" + Auto = "auto" + Latent2RGB = "latent2rgb" + TAESD = "taesd" + +parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. 
Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." 
+ k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..f494f1d30 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,87 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = 
osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config) diff --git a/comfy/model_base.py b/comfy/model_base.py new file mode 100644 index 000000000..7370c19fd --- /dev/null +++ b/comfy/model_base.py @@ -0,0 +1,66 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule +import numpy as np + +class BaseModel(torch.nn.Module): + def __init__(self, unet_config, v_prediction=False): + super().__init__() + + self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.diffusion_model = UNetModel(**unet_config) + self.v_prediction = v_prediction + if self.v_prediction: + self.parameterization = "v" + else: + self.parameterization = "eps" + if "adm_in_channels" in unet_config: + self.adm_channels = unet_config["adm_in_channels"] + else: + self.adm_channels = 0 + print("v_prediction", v_prediction) + print("adm", self.adm_channels) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + context = torch.cat(c_crossattn, 1) + return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options) + + def get_dtype(self): + return self.diffusion_model.dtype + + def is_adm(self): + return self.adm_channels > 0 + +class SD21UNCLIP(BaseModel): + def __init__(self, unet_config, noise_aug_config, v_prediction=True): + super().__init__(unet_config, v_prediction) + self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) + +class SDInpaint(BaseModel): + def __init__(self, unet_config, v_prediction=False): + super().__init__(unet_config, v_prediction) + self.concat_keys = ("mask", "masked_image") diff --git a/comfy/model_management.py b/comfy/model_management.py index c15323219..1a8a1be17 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,23 +1,29 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): - CPU = 0 - NO_VRAM = 1 + DISABLED = 0 #No vram present: no need to move models to vram + NO_VRAM = 1 #Very low vram: enable all the options to save vram LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - MPS = 5 + SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both. + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 # Determine VRAM State vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU total_vram = 0 -total_vram_available_mb = -1 -accelerate_enabled = False +lowvram_available = True xpu_available = False directml_enabled = False @@ -31,30 +37,80 @@ if args.directml is not None: directml_device = torch_directml.device(device_index) print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import torch - if directml_enabled: - total_vram = 4097 #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. 
If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS +except: + pass + +if args.cpu: + cpu_state = CPUState.CPU + +def get_torch_device(): + global xpu_available + global directml_enabled + global cpu_state + if directml_enabled: + global directml_device + return directml_device + if cpu_state == CPUState.MPS: + return torch.device("mps") + if cpu_state == CPUState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. 
If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION: if args.lowvram: set_vram_to = VRAMState.LOW_VRAM + lowvram_available = True elif args.novram: set_vram_to = VRAMState.NO_VRAM elif args.highvram: @@ -102,54 +159,38 @@ if args.force_fp32: print("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True - -if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): +if lowvram_available: try: import accelerate - accelerate_enabled = True - vram_state = set_vram_to + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to except Exception as e: import traceback print(traceback.format_exc()) - print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + print("ERROR: LOW VRAM MODE NEEDS accelerate.") + lowvram_available = False - total_vram_available_mb = (total_vram - 1024) // 2 - total_vram_available_mb = int(max(256, total_vram_available_mb)) -try: - if torch.backends.mps.is_available(): - vram_state = VRAMState.MPS -except: - pass +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED -if args.cpu: - vram_state = VRAMState.CPU +if cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED print(f"Set vram state to: {vram_state.name}") -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() def get_torch_device_name(device): if hasattr(device, 'type'): - return "{}".format(device.type) - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Using device:", get_torch_device_name(get_torch_device())) + print("Device:", get_torch_device_name(get_torch_device())) except: print("Could not pick default device.") @@ -199,22 +240,29 @@ def load_model_gpu(model): model.unpatch_model() raise e - model.model_patches_to(get_torch_device()) + torch_dev = get_torch_device() + model.model_patches_to(torch_dev) + + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = model.model_size() + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + current_loaded_model = model - if vram_state == VRAMState.CPU: + + if vram_set_state == VRAMState.DISABLED: pass - elif vram_state == VRAMState.MPS: - mps_device = torch.device("mps") - real_model.to(mps_device) - pass - elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(get_torch_device()) else: - if vram_state == VRAMState.NO_VRAM: + if vram_set_state == VRAMState.NO_VRAM: device_map = 
accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True @@ -223,7 +271,7 @@ def load_model_gpu(model): def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state - if vram_state == VRAMState.CPU: + if vram_state == VRAMState.DISABLED: return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: @@ -268,7 +316,8 @@ def get_autocast_device(dev): def xformers_enabled(): global xpu_available global directml_enabled - if vram_state == VRAMState.CPU: + global cpu_state + if cpu_state != CPUState.GPU: return False if xpu_available: return False @@ -340,12 +389,12 @@ def maximum_batch_area(): return int(max(area, 0)) def cpu_mode(): - global vram_state - return vram_state == VRAMState.CPU + global cpu_state + return cpu_state == CPUState.CPU def mps_mode(): - global vram_state - return vram_state == VRAMState.MPS + global cpu_state + return cpu_state == CPUState.MPS def should_use_fp16(): global xpu_available @@ -377,7 +426,10 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - if xpu_available: + global cpu_state + if cpu_state == CPUState.MPS: + torch.mps.empty_cache() + elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda diff --git a/comfy/samplers.py b/comfy/samplers.py index 1fb928f8d..a33d150d0 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c['transformer_options'] = transformer_options - output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) + output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x model_management.throw_exception_if_processing_interrupted() @@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond[temp[1]] = [o[0], n] -def encode_adm(noise_augmentor, conds, batch_size, device): +def encode_adm(conds, batch_size, device, noise_augmentor=None): for t in range(len(conds)): x = conds[t] - if 'adm' in x[1]: - adm_inputs = [] - weights = [] - noise_aug = [] - adm_in = x[1]["adm"] - for adm_c in adm_in: - adm_cond = adm_c[0].image_embeds - weight = adm_c[1] - noise_augment = adm_c[2] - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight - weights.append(weight) - noise_aug.append(noise_augment) - adm_inputs.append(adm_out) + adm_out = None + if noise_augmentor is not None: + if 'adm' in x[1]: + adm_inputs = [] + weights = [] + noise_aug = [] + adm_in = x[1]["adm"] + for adm_c in adm_in: + adm_cond = adm_c[0].image_embeds + weight = adm_c[1] + noise_augment = adm_c[2] + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), 
noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + noise_aug.append(noise_augment) + adm_inputs.append(adm_out) - if len(noise_aug) > 1: - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: add a way to control this - noise_augment = 0.05 - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) + else: + adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) else: - adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) - x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) + if 'adm' in x[1]: + adm_out = x[1]["adm"].to(device) + if adm_out is not None: + x[1] = x[1].copy() + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) return conds @@ -591,14 +597,17 @@ class KSampler: apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if self.model.model.diffusion_model.dtype == torch.float16: + if self.model.get_dtype() == torch.float16: precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext - if hasattr(self.model, 'noise_augmentor'): #unclip - positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) - negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) + if self.model.is_adm(): + noise_augmentor = None + if hasattr(self.model, 'noise_augmentor'): #unclip + noise_augmentor = self.model.noise_augmentor + positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor) + negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} diff --git a/comfy/sd.py b/comfy/sd.py index c6be900ad..3747f53b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,8 +14,16 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . import diffusers_convert +from . 
import model_base def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): + replace_prefix = {"model.diffusion_model.": "diffusion_model."} + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys()))) + for x in replace: + sd[x[1]] = sd.pop(x[0]) + m, u = model.load_state_dict(sd, strict=False) k = list(sd.keys()) @@ -30,17 +38,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - keys_to_replace = { - "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", - "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", - "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", - "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", - } - - for x in keys_to_replace: - if x in sd: - sd[keys_to_replace[x]] = sd.pop(x) - sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24) for x in load_state_dict_to: @@ -192,7 +189,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.1".format(b) + tk = "diffusion_model.input_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -203,13 +200,13 @@ def model_lora_keys(model, key_map={}): if up_counter >= 4: counter += 1 for c in LORA_UNET_MAP_ATTENTIONS: - k = "model.diffusion_model.middle_block.1.{}.weight".format(c) + k = "diffusion_model.middle_block.1.{}.weight".format(c) if k in sdk: lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c]) key_map[lora_key] = k counter = 3 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.1".format(b) + tk = "diffusion_model.output_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -233,7 +230,7 @@ def model_lora_keys(model, key_map={}): ds_counter = 0 counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.0".format(b) + tk = "diffusion_model.input_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -252,7 +249,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(3): - tk = "model.diffusion_model.middle_block.{}".format(b) + tk = "diffusion_model.middle_block.{}".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -266,7 +263,7 @@ def model_lora_keys(model, key_map={}): counter = 0 us_counter = 0 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.0".format(b) + tk = "diffusion_model.output_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -285,15 +282,29 @@ def model_lora_keys(model, key_map={}): return key_map + class ModelPatcher: - def __init__(self, model): + def __init__(self, model, size=0): + self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} + self.model_size() + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k 
in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + return size def clone(self): - n = ModelPatcher(self.model) + n = ModelPatcher(self.model, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) return n @@ -328,7 +339,7 @@ class ModelPatcher: patch_list[i] = patch_list[i].to(device) def model_dtype(self): - return self.model.diffusion_model.dtype + return self.model.get_dtype() def add_patches(self, patches, strength=1.0): p = {} @@ -504,10 +515,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() @@ -600,7 +617,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): return torch.cat([tensor] * batched_number, dim=0) class ControlNet: - def __init__(self, control_model, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None @@ -609,6 +626,7 @@ class ControlNet: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None @@ -644,6 +662,9 @@ class ControlNet: key = 'output' index = i x = control[i] + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) @@ -674,7 +695,7 @@ class ControlNet: self.cond_hint = None def copy(self): - c = ControlNet(self.control_model) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c.cond_hint_original = self.cond_hint_original c.strength = self.strength return c @@ -722,7 +743,7 @@ def load_controlnet(ckpt_path, model=None): use_spatial_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) else: @@ -739,7 +760,7 @@ def load_controlnet(ckpt_path, model=None): use_linear_in_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) if pth: @@ -750,7 +771,7 @@ def load_controlnet(ckpt_path, model=None): for x in controlnet_data: c_m = "control_model." 
if x.startswith(c_m): - sd_key = "model.diffusion_model.{}".format(x[len(c_m):]) + sd_key = "diffusion_model.{}".format(x[len(c_m):]) if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) @@ -769,7 +790,11 @@ def load_controlnet(ckpt_path, model=None): if use_fp16: control_model = control_model.half() - control = ControlNet(control_model) + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control class T2IAdapter: @@ -913,9 +938,10 @@ def load_gligen(ckpt_path): model = model.half() return model -def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) +def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): + if config is None: + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) model_config_params = config['model']['params'] clip_config = model_config_params['cond_stage_config'] scale_factor = model_config_params['scale_factor'] @@ -924,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e fp16 = False if "unet_config" in model_config_params: if "params" in model_config_params["unet_config"]: - if "use_fp16" in model_config_params["unet_config"]["params"]: - fp16 = model_config_params["unet_config"]["params"]["use_fp16"] + unet_config = model_config_params["unet_config"]["params"] + if "use_fp16" in unet_config: + fp16 = unet_config["use_fp16"] + + noise_aug_config = None + if "noise_aug_config" in model_config_params: + noise_aug_config = model_config_params["noise_aug_config"] + + v_prediction = False + + if "parameterization" in model_config_params: + if model_config_params["parameterization"] == "v": + v_prediction = True clip = None vae = None @@ -945,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] - model = instantiate_from_config(config["model"]) - sd = utils.load_torch_file(ckpt_path) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + if config['model']["target"].endswith("LatentInpaintDiffusion"): + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + + if state_dict is None: + state_dict = utils.load_torch_file(ckpt_path) + model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: model = model.half() @@ -1024,7 +1068,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o } unet_config = { - "use_checkpoint": True, + "use_checkpoint": False, "image_size": 32, "out_channels": 4, "attention_resolutions": [ @@ -1044,47 +1088,59 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o "legacy": False } - if 
len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2: + if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2: unet_config['use_linear_in_transformer'] = True unet_config["use_fp16"] = fp16 unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0] unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] - unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] + unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + unclip_model = False + inpaint_model = False if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' + unclip_model = True model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + inpaint_model = True else: sd_config["conditioning_key"] = "crossattn" - if unet_config["context_dim"] == 1024: - unet_config["num_head_channels"] = 64 #SD2.x - else: + if unet_config["context_dim"] == 768: unet_config["num_heads"] = 8 #SD1.x + else: + unet_config["num_head_channels"] = 64 #SD2.x unclip = 'model.diffusion_model.label_emb.0.0.weight' if unclip in sd_keys: unet_config["num_classes"] = "sequential" unet_config["adm_in_channels"] = sd[unclip].shape[1] + v_prediction = False if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" out = sd[k] if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. 
+ v_prediction = True sd_config["parameterization"] = 'v' - model = instantiate_from_config(model_config) + if inpaint_model: + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif unclip_model: + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b1a392736..91fb4ff27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): next_new_token += 1 else: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + while len(tokens_temp) < len(x): + tokens_temp += [self.empty_tokens[0][-1]] out_tokens += [tokens_temp] if len(embedding_weights) > 0: diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py new file mode 100644 index 000000000..1549345ae --- /dev/null +++ b/comfy/taesd/taesd.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) +""" +import torch +import torch.nn as nn + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + +def Encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + +def Decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + +class TAESD(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = Encoder() + self.decoder = Decoder() + if encoder_path is not None: + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) + if decoder_path is not None: + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) + + @staticmethod + def scale_latents(x): + """raw latents -> [0, 1]""" + return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) diff --git a/comfy/utils.py b/comfy/utils.py 
index 300eda6aa..585ebda51 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1,11 +1,16 @@
 import torch
 import math
+import struct

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
+        if safe_load:
+            if not 'weights_only' in torch.load.__code__.co_varnames:
+                print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
+                safe_load = False
         if safe_load:
             pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
         else:
@@ -19,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
     return sd

 def transformers_convert(sd, prefix_from, prefix_to, number):
+    keys_to_replace = {
+        "{}.positional_embedding": "{}.embeddings.position_embedding.weight",
+        "{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
+        "{}.ln_final.weight": "{}.final_layer_norm.weight",
+        "{}.ln_final.bias": "{}.final_layer_norm.bias",
+    }
+
+    for k in keys_to_replace:
+        x = k.format(prefix_from)
+        if x in sd:
+            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
+
     resblock_to_replace = {
         "ln_1": "layer_norm1",
         "ln_2": "layer_norm2",
@@ -46,71 +63,87 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                 sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd

-#slow and inefficient, should be optimized
+def safetensors_header(safetensors_path, max_size=100*1024*1024):
+    with open(safetensors_path, "rb") as f:
+        header = f.read(8)
+        length_of_header = struct.unpack('<Q', header)[0]
+        if length_of_header > max_size:
+            return None
+        return f.read(length_of_header)
+
 def bislerp(samples, width, height):
-    shape = list(samples.shape)
-    width_scale = (shape[3]) / (width )
-    height_scale = (shape[2]) / (height )
+    def slerp(b1, b2, r):
+        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
+
+        c = b1.shape[-1]

-    shape[3] = width
-    shape[2] = height
-    out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device)
+        #norms
+        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
+        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

-    def algorithm(in1, in2, t):
-        dims = in1.shape
-        val = t
+        #normalize
+        b1_normalized = b1 / b1_norms
+        b2_normalized = b2 / b2_norms

-        #flatten to batches
-        low = in1.reshape(dims[0], -1)
-        high = in2.reshape(dims[0], -1)
+        #zero when norms are zero
+        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
+        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

-        low_weight = torch.norm(low, dim=1, keepdim=True)
-        low_weight[low_weight == 0] = 0.0000000001
-        low_norm = low/low_weight
-        high_weight = torch.norm(high, dim=1, keepdim=True)
-        high_weight[high_weight == 0] = 0.0000000001
-        high_norm = high/high_weight
-
-        dot_prod = (low_norm*high_norm).sum(1)
-        dot_prod[dot_prod > 0.9995] = 0.9995
-        dot_prod[dot_prod < -0.9995] = -0.9995
-        omega = torch.acos(dot_prod)
+        #slerp
+        dot = (b1_normalized*b2_normalized).sum(1)
+        omega = torch.acos(dot)
         so = torch.sin(omega)
-        res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
-        res *= (low_weight * (1.0-val) + high_weight * val)
-        return res.reshape(dims)
-    for x_dest in range(shape[3]):
-        for y_dest in range(shape[2]):
-            y = (y_dest + 0.5) * height_scale - 0.5
-            x = (x_dest + 0.5) * width_scale - 0.5
+        #technically not mathematically correct, but more pleasing?
+ res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - out1[:,:,y_dest,x_dest] = out_value - return out1 + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) + + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -176,14 +209,14 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK - def update_absolute(self, value, total=None): + def update_absolute(self, value, total=None, preview=None): if total is not None: self.total = total if value > self.total: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total) + self.hook(self.current, self.total, preview) def update(self, value): self.update_absolute(self.current + value) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9916f3b21..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": 
MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) diff --git a/execution.py b/execution.py index 25f2fcacd..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,21 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def format_value(x): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + 
"node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +260,48 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +336,30 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if 
self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -297,57 +377,202 @@ def validate_inputs(prompt, item, validated): class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. 
{}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type, + "linked_node": val + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. 
{}, {}".format(val, info[1]["max"], class_type, x), unique_id) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r), unique_id) + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) - ret = (True, "", unique_id) validated[unique_id] = ret return ret +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' 
+ klass.__qualname__ + def validate_prompt(prompt): outputs = set() for x in prompt: @@ -356,7 +581,13 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs", [], []) + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] @@ -364,34 +595,72 @@ def validate_prompt(prompt): validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - node_id = m[2] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" - node_id = None + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "exception_type": exception_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: + if valid is True: good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - if node_id is not None: - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. 
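The comments above describe how per-node validation results are surfaced to the client. Purely as an illustration of the resulting shape (the node id, class type, and values below are invented, not taken from the patch), one entry of the node_errors mapping built by this loop looks roughly like:

# Hypothetical example of a node_errors entry (ids and values are made up)
node_errors = {
    "7": {
        "class_type": "KSampler",
        "dependent_outputs": ["9"],
        "errors": [
            {
                "type": "value_not_in_list",
                "message": "Value not in list",
                "details": "sampler_name: 'foo' not in (list of length 14)",
                "extra_info": {
                    "input_name": "sampler_name",
                    "input_config": None,
                    "received_value": "foo",
                },
            }
        ],
    }
}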
+ if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "", list(good_outputs), node_errors) + error = { + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: diff --git a/folder_paths.py b/folder_paths.py index 28f117824..2ad1b1719 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,8 @@ import os +import time -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} @@ -24,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision" folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) +folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) @@ -38,6 +33,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +filename_list_cache = {} + if not os.path.exists(input_directory): os.makedirs(input_directory) @@ -118,12 +115,18 @@ def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): + if not os.path.isdir(directory): + return [], {} result = [] + dirs = {directory: os.path.getmtime(directory)} for root, subdir, file in os.walk(directory, followlinks=True): for filepath in file: #we os.path,join directory with a blank string to generate a path separator at the end. 
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result + for d in subdir: + path = os.path.join(root, d) + dirs[path] = os.path.getmtime(path) + return result, dirs def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -132,20 +135,58 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths + if folder_name not in folder_names_and_paths: + return None folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None -def get_filename_list(folder_name): +def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] + output_folders = {} for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) + files, folders_all = recursive_search(x) + output_list.update(filter_files_extensions(files, folders[1])) + output_folders = {**output_folders, **folders_all} + + return (sorted(list(output_list)), output_folders, time.perf_counter()) + +def cached_filename_list_(folder_name): + global filename_list_cache + global folder_names_and_paths + if folder_name not in filename_list_cache: + return None + out = filename_list_cache[folder_name] + if time.perf_counter() < (out[2] + 0.5): + return out + for x in out[1]: + time_modified = out[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + if os.path.isdir(x): + if x not in out[1]: + return None + + return out + +def get_filename_list(folder_name): + out = cached_filename_list_(folder_name) + if out is None: + out = get_filename_list_(folder_name) + global filename_list_cache + filename_list_cache[folder_name] = out + return list(out[0]) def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): diff --git a/latent_preview.py b/latent_preview.py new file mode 100644 index 000000000..ef6c201b6 --- /dev/null +++ b/latent_preview.py @@ -0,0 +1,95 @@ +import torch +from PIL import Image, ImageOps +from io import BytesIO +import struct +import numpy as np + +from comfy.cli_args import args, LatentPreviewMethod +from comfy.taesd.taesd import TAESD +import folder_paths + +MAX_PREVIEW_RESOLUTION = 512 + +class LatentPreviewer: + def decode_latent_to_preview(self, x0): + pass + + def decode_latent_to_preview_image(self, preview_format, x0): + preview_image = self.decode_latent_to_preview(x0) + preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) + + preview_type = 1 + if preview_format == "JPEG": + preview_type = 1 + elif preview_format == "PNG": + preview_type = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", preview_type) + bytesIO.write(header) + preview_image.save(bytesIO, format=preview_format, quality=95) + preview_bytes = bytesIO.getvalue() + return preview_bytes + +class TAESDPreviewerImpl(LatentPreviewer): + def __init__(self, taesd): + self.taesd = taesd + + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decoder(x0)[0].detach() + # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 
2] + x_sample = x_sample.sub(0.5).mul(2) + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + + preview_image = Image.fromarray(x_sample) + return preview_image + + +class Latent2RGBPreviewer(LatentPreviewer): + def __init__(self): + self.latent_rgb_factors = torch.tensor([ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ], device="cpu") + + def decode_latent_to_preview(self, x0): + latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +def get_previewer(device): + previewer = None + method = args.preview_method + if method != LatentPreviewMethod.NoPreviews: + # TODO previewer methods + taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + + if method == LatentPreviewMethod.Auto: + method = LatentPreviewMethod.Latent2RGB + if taesd_decoder_path: + method = LatentPreviewMethod.TAESD + + if method == LatentPreviewMethod.TAESD: + if taesd_decoder_path: + taesd = TAESD(None, taesd_decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + + if previewer is None: + previewer = Latent2RGBPreviewer() + return previewer + + diff --git a/main.py b/main.py index 50d3b9a62..8293c06fc 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import yaml import execution import folder_paths import server +from server import BinaryEventTypes from nodes import init_custom_nodes @@ -36,19 +37,25 @@ def prompt_worker(q, server): e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs_ui) + async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) + def hijack_progress(server): - def hook(value, total): - server.send_sync("progress", { "value": value, "max": total}, server.client_id) + def hook(value, total, preview_image_bytes): + server.send_sync("progress", {"value": value, "max": total}, server.client_id) + if preview_image_bytes is not None: + server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) + def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) + def load_extra_path_config(yaml_path): with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) @@ -69,6 +76,7 @@ def load_extra_path_config(yaml_path): print("Adding extra search path", x, full_path) folder_paths.add_model_folder_path(x, full_path) + if __name__ == "__main__": cleanup_temp() @@ -89,7 +97,7 @@ if __name__ == "__main__": server.add_routes() hijack_progress(server) - threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() if args.output_directory: output_dir = os.path.abspath(args.output_directory) @@ -103,15 +111,12 @@ if __name__ == "__main__": if args.auto_launch: def startup_server(address, port): import webbrowser - 
webbrowser.open("http://{}:{}".format(address, port)) + webbrowser.open(f"http://{address}:{port}") call_on_start = startup_server - if os.name == "nt": - try: - loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) - except KeyboardInterrupt: - pass - else: + try: loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) + except KeyboardInterrupt: + print("\nStopped server") cleanup_temp() diff --git a/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index f0a93ebd5..b057504ed 100644 --- a/nodes.py +++ b/nodes.py @@ -13,11 +13,10 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch - sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -29,7 +28,7 @@ import comfy.model_management import importlib import folder_paths - +import latent_preview def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -248,7 +247,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) - class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -377,7 +375,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: @@ -426,6 +424,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) @@ -507,6 +508,9 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) for t in conditioning: @@ -613,6 +617,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() @@ -922,6 +929,7 @@ class SetLatentNoiseMask: s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) + def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -936,9 +944,18 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if 
"noise_mask" in latent: noise_mask = latent["noise_mask"] + preview_format = "JPEG" + if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + + previewer = latent_preview.get_previewer(device) + pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): - pbar.update_absolute(step + 1, total_steps) + preview_bytes = None + if previewer: + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) + pbar.update_absolute(step + 1, total_steps, preview_bytes) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, @@ -961,7 +978,8 @@ class KSampler: "negative": ("CONDITIONING", ), "latent_image": ("LATENT", ), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" @@ -988,7 +1006,8 @@ class KSamplerAdvanced: "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "return_with_leftover_noise": (["disable", "enable"], ), - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" diff --git a/server.py b/server.py index c0f79cbd5..174d38af1 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import execution import uuid import json import glob +import struct from PIL import Image from io import BytesIO @@ -22,6 +23,12 @@ except ImportError: import mimetypes from comfy.cli_args import args +import comfy.utils +import comfy.model_management + + +class BinaryEventTypes: + PREVIEW_IMAGE = 1 @web.middleware @@ -216,6 +223,27 @@ class PromptServer(): file = os.path.join(output_dir, filename) if os.path.isfile(file): + if 'preview' in request.rel_url.query: + with Image.open(file) as img: + preview_info = request.rel_url.query['preview'].split(';') + + image_format = preview_info[0] + if image_format not in ['webp', 'jpeg']: + image_format = 'webp' + + quality = 90 + if preview_info[-1].isdigit(): + quality = int(preview_info[-1]) + + buffer = BytesIO() + if image_format in ['jpeg']: + img = img.convert("RGB") + img.save(buffer, format=image_format, quality=quality) + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type=f'image/{image_format}', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + if 'channel' not in request.rel_url.query: channel = 'rgba' else: @@ -257,6 +285,50 @@ class PromptServer(): return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + + @routes.get("/system_stats") + async def get_queue(request): + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) 
+ vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + system_stats = { + "devices": [ + { + "name": device_name, + "type": device.type, + "index": device.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + } + ] + } + return web.json_response(system_stats) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) @@ -338,7 +410,7 @@ class PromptServer(): prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) - return web.json_response({"prompt_id": prompt_id}) + return web.json_response({"prompt_id": prompt_id, "number": number}) else: print("invalid prompt:", valid[1]) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) @@ -391,16 +463,37 @@ class PromptServer(): return prompt_info async def send(self, event, data, sid=None): - message = {"type": event, "data": data} - - if isinstance(message, str) == False: - message = json.dumps(message) + if isinstance(data, (bytes, bytearray)): + await self.send_bytes(event, data, sid) + else: + await self.send_json(event, data, sid) + + def encode_bytes(self, event, data): + if not isinstance(event, int): + raise RuntimeError(f"Binary event types must be integers, got {event}") + + packed = struct.pack(">I", event) + message = bytearray(packed) + message.extend(data) + return message + + async def send_bytes(self, event, data, sid=None): + message = self.encode_bytes(event, data) if sid is None: for ws in self.sockets.values(): - await ws.send_str(message) + await ws.send_bytes(message) elif sid in self.sockets: - await self.sockets[sid].send_str(message) + await self.sockets[sid].send_bytes(message) + + async def send_json(self, event, data, sid=None): + message = {"type": event, "data": data} + + if sid is None: + for ws in self.sockets.values(): + await ws.send_json(message) + elif sid in self.sockets: + await self.sockets[sid].send_json(message) def send_sync(self, event, data, sid=None): self.loop.call_soon_threadsafe( diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index bfcd847a3..84c2a3d10 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -21,6 +21,7 @@ const colorPalettes = { "MODEL": "#B39DDB", // light lavender-purple "STYLE_MODEL": "#C2FFAE", // light green-yellow "VAE": "#FF6E6E", // bright red + "TAESD": "#DCC274", // cheesecake }, "litegraph_base": { "NODE_TITLE_COLOR": "#999", diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 51e66f924..662d87e74 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -1,132 +1,138 @@ -import { app } from "/scripts/app.js"; +import {app} from "/scripts/app.js"; // Adds filtering to combo context menus -const id = "Comfy.ContextMenuFilter"; -app.registerExtension({ - name: id, +const ext = { + name: "Comfy.ContextMenuFilter", init() { const ctxMenu = LiteGraph.ContextMenu; + LiteGraph.ContextMenu = function (values, options) { const ctx = ctxMenu.call(this, values, options); // If we are a dark menu (only used for combo boxes) then add a filter input if (options?.className === "dark" && values?.length > 10) { const filter = 
document.createElement("input"); - Object.assign(filter.style, { - width: "calc(100% - 10px)", - border: "0", - boxSizing: "border-box", - background: "#333", - border: "1px solid #999", - margin: "0 0 5px 5px", - color: "#fff", - }); + filter.classList.add("comfy-context-menu-filter"); filter.placeholder = "Filter list"; this.root.prepend(filter); - let selectedIndex = 0; - let items = this.root.querySelectorAll(".litemenu-entry"); - let itemCount = items.length; - let selectedItem; + const items = Array.from(this.root.querySelectorAll(".litemenu-entry")); + let displayedItems = [...items]; + let itemCount = displayedItems.length; - // Apply highlighting to the selected item - function updateSelected() { - if (selectedItem) { - selectedItem.style.setProperty("background-color", ""); - selectedItem.style.setProperty("color", ""); - } - selectedItem = items[selectedIndex]; - if (selectedItem) { - selectedItem.style.setProperty("background-color", "#ccc", "important"); - selectedItem.style.setProperty("color", "#000", "important"); - } - } + // We must request an animation frame for the current node of the active canvas to update. + requestAnimationFrame(() => { + const currentNode = LGraphCanvas.active_canvas.current_node; + const clickedComboValue = currentNode.widgets + .filter(w => w.type === "combo" && w.options.values.length === values.length) + .find(w => w.options.values.every((v, i) => v === values[i])) + .value; - const positionList = () => { - const rect = this.root.getBoundingClientRect(); - - // If the top is off screen then shift the element with scaling applied - if (rect.top < 0) { - const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; - const shift = (this.root.clientHeight * scale) / 2; - this.root.style.top = -shift + "px"; - } - } - - updateSelected(); - - // Arrow up/down to select items - filter.addEventListener("keydown", (e) => { - if (e.key === "ArrowUp") { - if (selectedIndex === 0) { - selectedIndex = itemCount - 1; - } else { - selectedIndex--; - } - updateSelected(); - e.preventDefault(); - } else if (e.key === "ArrowDown") { - if (selectedIndex === itemCount - 1) { - selectedIndex = 0; - } else { - selectedIndex++; - } - updateSelected(); - e.preventDefault(); - } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) { - selectedItem.click(); - } else if(e.key === "Escape") { - this.close(); - } - }); - - filter.addEventListener("input", () => { - // Hide all items that dont match our filter - const term = filter.value.toLocaleLowerCase(); - items = this.root.querySelectorAll(".litemenu-entry"); - // When filtering recompute which items are visible for arrow up/down - // Try and maintain selection - let visibleItems = []; - for (const item of items) { - const visible = !term || item.textContent.toLocaleLowerCase().includes(term); - if (visible) { - item.style.display = "block"; - if (item === selectedItem) { - selectedIndex = visibleItems.length; - } - visibleItems.push(item); - } else { - item.style.display = "none"; - if (item === selectedItem) { - selectedIndex = 0; - } - } - } - items = visibleItems; + let selectedIndex = values.findIndex(v => v === clickedComboValue); + let selectedItem = displayedItems?.[selectedIndex]; updateSelected(); - // If we have an event then we can try and position the list under the source - if (options.event) { - let top = options.event.clientY - 10; - - const bodyRect = document.body.getBoundingClientRect(); - const rootRect = this.root.getBoundingClientRect(); - 
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { - top = Math.max(0, bodyRect.height - rootRect.height - 10); - } - - this.root.style.top = top + "px"; - positionList(); + // Apply highlighting to the selected item + function updateSelected() { + selectedItem?.style.setProperty("background-color", ""); + selectedItem?.style.setProperty("color", ""); + selectedItem = displayedItems[selectedIndex]; + selectedItem?.style.setProperty("background-color", "#ccc", "important"); + selectedItem?.style.setProperty("color", "#000", "important"); } - }); - requestAnimationFrame(() => { - // Focus the filter box when opening - filter.focus(); + const positionList = () => { + const rect = this.root.getBoundingClientRect(); - positionList(); - }); + // If the top is off-screen then shift the element with scaling applied + if (rect.top < 0) { + const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; + const shift = (this.root.clientHeight * scale) / 2; + this.root.style.top = -shift + "px"; + } + } + + // Arrow up/down to select items + filter.addEventListener("keydown", (event) => { + switch (event.key) { + case "ArrowUp": + event.preventDefault(); + if (selectedIndex === 0) { + selectedIndex = itemCount - 1; + } else { + selectedIndex--; + } + updateSelected(); + break; + case "ArrowRight": + event.preventDefault(); + selectedIndex = itemCount - 1; + updateSelected(); + break; + case "ArrowDown": + event.preventDefault(); + if (selectedIndex === itemCount - 1) { + selectedIndex = 0; + } else { + selectedIndex++; + } + updateSelected(); + break; + case "ArrowLeft": + event.preventDefault(); + selectedIndex = 0; + updateSelected(); + break; + case "Enter": + selectedItem?.click(); + break; + case "Escape": + this.close(); + break; + } + }); + + filter.addEventListener("input", () => { + // Hide all items that don't match our filter + const term = filter.value.toLocaleLowerCase(); + // When filtering, recompute which items are visible for arrow up/down and maintain selection. + displayedItems = items.filter(item => { + const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term); + item.style.display = isVisible ? 
"block" : "none"; + return isVisible; + }); + + selectedIndex = 0; + if (displayedItems.includes(selectedItem)) { + selectedIndex = displayedItems.findIndex(d => d === selectedItem); + } + itemCount = displayedItems.length; + + updateSelected(); + + // If we have an event then we can try and position the list under the source + if (options.event) { + let top = options.event.clientY - 10; + + const bodyRect = document.body.getBoundingClientRect(); + const rootRect = this.root.getBoundingClientRect(); + if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { + top = Math.max(0, bodyRect.height - rootRect.height - 10); + } + + this.root.style.top = top + "px"; + positionList(); + } + }); + + requestAnimationFrame(() => { + // Focus the filter box when opening + filter.focus(); + + positionList(); + }); + }) } return ctx; @@ -134,4 +140,6 @@ app.registerExtension({ LiteGraph.ContextMenu.prototype = ctxMenu.prototype; }, -}); +} + +app.registerExtension(ext); diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js index 7dae07f4d..599a9e685 100644 --- a/web/extensions/core/dynamicPrompts.js +++ b/web/extensions/core/dynamicPrompts.js @@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js"; // Allows for simple dynamic prompt replacement // Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued. +/* + * Strips C-style line and block comments from a string + */ +function stripComments(str) { + return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,''); +} + app.registerExtension({ name: "Comfy.DynamicPrompts", nodeCreated(node) { @@ -15,7 +22,7 @@ app.registerExtension({ for (const widget of widgets) { // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node widget.serializeValue = (workflowNode, widgetIndex) => { - let prompt = widget.value; + let prompt = stripComments(widget.value); while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) { const startIndex = prompt.replace("\\{", "00").indexOf("{"); const endIndex = prompt.replace("\\}", "00").indexOf("}"); diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 4b0c12747..764164d5e 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" 
+ new URLSearchParams(filepath).toString() + app.getPreviewFormatParam(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog { imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); // update mask - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCanvas.width = drawWidth; maskCanvas.height = drawHeight; maskCanvas.style.top = imgCanvas.offsetTop + "px"; maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); }); @@ -335,6 +335,7 @@ class MaskEditorDialog extends ComfyDialog { const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) alpha_url.searchParams.delete('channel'); + alpha_url.searchParams.delete('preview'); alpha_url.searchParams.set('channel', 'a'); touched_image.src = alpha_url; @@ -345,6 +346,7 @@ class MaskEditorDialog extends ComfyDialog { const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); + rgb_url.searchParams.delete('preview'); rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 4fe0a6013..c356655b0 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -200,8 +200,23 @@ app.registerExtension({ applyToGraph() { if (!this.outputs[0].links?.length) return; + function get_links(node) { + let links = []; + for (const l of node.outputs[0].links) { + const linkInfo = app.graph.links[l]; + const n = node.graph.getNodeById(linkInfo.target_id); + if (n.type == "Reroute") { + links = links.concat(get_links(n)); + } else { + links.push(l); + } + } + return links; + } + + let links = get_links(this); // For each output link copy our value over the original widget value - for (const l of this.outputs[0].links) { + for (const l of links) { const linkInfo = app.graph.links[l]; const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; diff --git a/web/index.html b/web/index.html index bb79433ce..da0adb6c2 100644 --- a/web/index.html +++ b/web/index.html @@ -14,5 +14,5 @@ window.graph = app.graph; - + diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 95f4a2735..a60848d77 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action) if (this.onShowNodePanel) { this.onShowNodePanel(n); } - else - { - this.showShowNodePanel(n); - } if (this.onNodeDblClicked) { this.onNodeDblClicked(n); @@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action) bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR; hovercolor = hovercolor || "#555"; textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR; - var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title - var pos = this.mouse; - var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); - pos = this.last_click_position; - var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); + var pos = 
this.ds.convertOffsetToCanvas(this.graph_mouse); + var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); + pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null; + if(pos) { + var rect = this.canvas.getBoundingClientRect(); + pos[0] -= rect.left; + pos[1] -= rect.top; + } + var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); ctx.fillStyle = hover ? hovercolor : bgcolor; if(clicked) @@ -13067,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action) has_submenu: true, callback: LGraphCanvas.onShowMenuNodeProperties }, + { + content: "Properties Panel", + callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) } + }, null, { content: "Title", diff --git a/web/scripts/api.js b/web/scripts/api.js index 4f061c358..8313f1abe 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -42,6 +42,7 @@ class ComfyApi extends EventTarget { this.socket = new WebSocket( `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}` ); + this.socket.binaryType = "arraybuffer"; this.socket.addEventListener("open", () => { opened = true; @@ -70,33 +71,65 @@ class ComfyApi extends EventTarget { this.socket.addEventListener("message", (event) => { try { - const msg = JSON.parse(event.data); - switch (msg.type) { - case "status": - if (msg.data.sid) { - this.clientId = msg.data.sid; - window.name = this.clientId; + if (event.data instanceof ArrayBuffer) { + const view = new DataView(event.data); + const eventType = view.getUint32(0); + const buffer = event.data.slice(4); + switch (eventType) { + case 1: + const view2 = new DataView(event.data); + const imageType = view2.getUint32(0) + let imageMime + switch (imageType) { + case 1: + default: + imageMime = "image/jpeg"; + break; + case 2: + imageMime = "image/png" } - this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); - break; - case "progress": - this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); - break; - case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); - break; - case "executed": - this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + const imageBlob = new Blob([buffer.slice(4)], { type: imageMime }); + this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob })); break; default: - if (this.#registered.has(msg.type)) { - this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); - } else { - throw new Error("Unknown message type"); - } + throw new Error(`Unknown binary websocket message of type ${eventType}`); + } + } + else { + const msg = JSON.parse(event.data); + switch (msg.type) { + case "status": + if (msg.data.sid) { + this.clientId = msg.data.sid; + window.name = this.clientId; + } + this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); + break; + case "progress": + this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); + break; + case "executing": + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + break; + case "executed": + this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; + default: + if (this.#registered.has(msg.type)) { 
+ this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); + } else { + throw new Error(`Unknown message type ${msg.type}`); + } + } } } catch (error) { - console.warn("Unhandled message:", event.data); + console.warn("Unhandled message:", event.data, error); } }); } diff --git a/web/scripts/app.js b/web/scripts/app.js index 88ce388ed..7be8ea537 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -44,6 +44,12 @@ export class ComfyApp { */ this.nodeOutputs = {}; + /** + * Stores the preview image data for each node + * @type {Record} + */ + this.nodePreviewImages = {}; + /** * If the shift key on the keyboard is pressed * @type {boolean} @@ -51,6 +57,14 @@ export class ComfyApp { this.shiftDown = false; } + getPreviewFormatParam() { + let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat"); + if(preview_format) + return `&preview=${preview_format}`; + else + return ""; + } + static isImageNode(node) { return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); } @@ -111,10 +125,14 @@ export class ComfyApp { if(ComfyApp.clipspace.imgs && node.imgs) { if(node.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; } - else - app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + else { + node.images = ComfyApp.clipspace.images; + } + + if(app.nodeOutputs[node.id + ""]) + app.nodeOutputs[node.id + ""].images = node.images; } if(ComfyApp.clipspace.imgs) { @@ -147,7 +165,16 @@ export class ComfyApp { if(ComfyApp.clipspace.widgets) { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { + if (prop && prop.type != 'image') { + if(typeof prop.value == "string" && value.filename) { + prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:''); + } + else { + prop.value = value; + prop.callback(value); + } + } + else if (prop && prop.type != 'button') { prop.value = value; prop.callback(value); } @@ -231,14 +258,20 @@ export class ComfyApp { options.unshift( { content: "Open Image", - callback: () => window.open(img.src, "_blank"), + callback: () => { + let url = new URL(img.src); + url.searchParams.delete('preview'); + window.open(url, "_blank") + }, }, { content: "Save Image", callback: () => { const a = document.createElement("a"); - a.href = img.src; - a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename")); + let url = new URL(img.src); + url.searchParams.delete('preview'); + a.href = url; + a.setAttribute("download", new URLSearchParams(url.search).get("filename")); document.body.append(a); a.click(); requestAnimationFrame(() => a.remove()); @@ -345,6 +378,10 @@ export class ComfyApp { } node.prototype.setSizeForImage = function () { + if (this.inputHeight) { + this.setSize(this.size); + return; + } const minHeight = getImageTop(this) + 220; if (this.size[1] < minHeight) { this.setSize([this.size[0], minHeight]); @@ -353,29 +390,52 @@ export class ComfyApp { node.prototype.onDrawBackground = function (ctx) { if (!this.flags.collapsed) { + let imgURLs = [] + let imagesChanged = false + const output = app.nodeOutputs[this.id + 
""]; if (output && output.images) { if (this.images !== output.images) { this.images = output.images; - this.imgs = null; - this.imageIndex = null; + imagesChanged = true; + imgURLs = imgURLs.concat(output.images.map(params => { + return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam(); + })) + } + } + + const preview = app.nodePreviewImages[this.id + ""] + if (this.preview !== preview) { + this.preview = preview + imagesChanged = true; + if (preview != null) { + imgURLs.push(preview); + } + } + + if (imagesChanged) { + this.imageIndex = null; + if (imgURLs.length > 0) { Promise.all( - output.images.map((src) => { + imgURLs.map((src) => { return new Promise((r) => { const img = new Image(); img.onload = () => r(img); img.onerror = () => r(null); - img.src = "/view?" + new URLSearchParams(src).toString(); + img.src = src }); }) ).then((imgs) => { - if (this.images === output.images) { + if ((!output || this.images === output.images) && (!preview || this.preview === preview)) { this.imgs = imgs.filter(Boolean); this.setSizeForImage?.(); app.graph.setDirtyCanvas(true); } }); } + else { + this.imgs = null; + } } if (this.imgs && this.imgs.length) { @@ -771,16 +831,27 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -807,11 +878,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -859,16 +947,40 @@ export class ComfyApp { this.progress = null; this.runningNodeId = detail; this.graph.setDirtyCanvas(true, false); + delete this.nodePreviewImages[this.runningNodeId] }); api.addEventListener("executed", ({ detail }) => { this.nodeOutputs[detail.node] = detail.output; const node = this.graph.getNodeById(detail.node); - if (node?.onExecuted) { - node.onExecuted(detail.output); + if (node) { + if (node.onExecuted) + node.onExecuted(detail.output); } }); + 
api.addEventListener("execution_start", ({ detail }) => { + this.runningNodeId = null; + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + + api.addEventListener("b_preview", ({ detail }) => { + const id = this.runningNodeId + if (id == null) + return; + + const blob = detail + const blobUrl = URL.createObjectURL(blob) + this.nodePreviewImages[id] = [blobUrl] + }); + api.init(); } @@ -975,6 +1087,11 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); + await this.registerNodesFromDefs(defs); + await this.#invokeExtensionsAsync("registerCustomNodes"); + } + + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -1047,8 +1164,6 @@ export class ComfyApp { LiteGraph.registerNodeType(nodeId, node); node.category = nodeData.category; } - - await this.#invokeExtensionsAsync("registerCustomNodes"); } /** @@ -1247,6 +1362,43 @@ export class ComfyApp { return { workflow, output }; } + #formatPromptError(error) { + if (error == null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1254,8 +1406,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1266,7 +1420,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response.error || error.toString()); + const formattedError = this.#formatPromptError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1345,6 +1504,11 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. 
+ if(!def) + continue; + for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { @@ -1364,6 +1528,10 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.nodePreviewImages = {} + this.lastPromptError = null; + this.lastExecutionError = null; + this.runningNodeId = null; } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 8ddb7a1c5..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 2c9043d00..a26eedec3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -462,6 +462,24 @@ export class ComfyUI { defaultValue: true, }); + /** + * file format for preview + * + * format;quality + * + * ex) + * webp;50 -> webp, quality 50 + * jpeg;80 -> rgb, jpeg, quality 80 + * + * @type {string} + */ + const previewImage = this.settings.addSetting({ + id: "Comfy.PreviewFormat", + name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)", + type: "string", + defaultValue: "", + }); + const fileInput = $el("input", { id: "comfy-file-input", type: "file", diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 82168b08b..dfa26aef4 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -115,12 +115,12 @@ function addMultilineWidget(node, name, opts, app) { // See how large each text input can be freeSpace -= widgetHeight; - freeSpace /= multi.length; + freeSpace /= multi.length + (!!node.imgs?.length); if (freeSpace < MIN_SIZE) { // There isnt enough space for all the widgets, increase the size of the node freeSpace = MIN_SIZE; - node.size[1] = y + widgetHeight + freeSpace * multi.length; + node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length)); node.graph.setDirtyCanvas(true); } @@ -303,7 +303,7 @@ export const ComfyWidgets = { subfolder = name.substring(0, folder_separator); name = name.substring(folder_separator + 1); } - img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`; node.setSizeForImage?.(); } diff --git a/web/style.css b/web/style.css index 5323ed4bd..cff9fa514 100644 --- a/web/style.css +++ b/web/style.css @@ -50,7 +50,7 @@ body { padding: 30px 30px 10px 30px; background-color: var(--comfy-menu-bg); /* Modal background */ color: var(--error-text); - box-shadow: 0px 0px 20px #888888; + box-shadow: 0 0 20px #888888; border-radius: 10px; top: 50%; left: 50%; @@ -84,7 +84,7 @@ body { font-size: 15px; position: absolute; top: 50%; - right: 0%; + right: 0; text-align: center; z-index: 100; width: 170px; @@ -252,7 +252,7 @@ button.comfy-queue-btn { bottom: 0 !important; left: auto !important; right: 0 !important; - border-radius: 0px; + border-radius: 0; } .comfy-menu span.drag-handle { visibility:hidden @@ -289,6 +289,11 @@ button.comfy-queue-btn { /* Context menu */ +.litegraph .dialog { + z-index: 1; + font-family: Arial, sans-serif; +} + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; @@ -325,12 +330,20 @@ button.comfy-queue-btn { color: var(--input-text) !important; } 
+.comfy-context-menu-filter { + box-sizing: border-box; + border: 1px solid #999; + margin: 0 0 5px 5px; + width: calc(100% - 10px); +} + /* Search box */ .litegraph.litesearchbox { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; overflow: hidden; + display: block; } .litegraph.litesearchbox input,
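
For reference, a minimal sketch (not part of the patch) of how the new Comfy.PreviewFormat setting is expected to reach the /view endpoint: getPreviewFormatParam() appends the raw setting value, so a value such as webp;50 (format;quality, as described in the ui.js comment) becomes &preview=webp;50 on every preview URL. The fakeSettings object below is a hypothetical stand-in for app.ui.settings; in the app the same logic lives in app.getPreviewFormatParam().

    // Illustrative sketch only, assuming the setting value "webp;50".
    function previewFormatParam(settings) {
        const previewFormat = settings.getSettingValue("Comfy.PreviewFormat");
        return previewFormat ? `&preview=${previewFormat}` : "";
    }

    // Hypothetical stand-in for app.ui.settings.
    const fakeSettings = { getSettingValue: () => "webp;50" };

    // Mirrors how onDrawBackground and the image widget build preview URLs.
    const src =
        "/view?" +
        new URLSearchParams({ filename: "example.png", type: "input", subfolder: "" }).toString() +
        previewFormatParam(fakeSettings);
    // src === "/view?filename=example.png&type=input&subfolder=&preview=webp;50"

When the setting is left empty, the parameter is omitted entirely and the /view URL is unchanged from its previous form, presumably serving the image as before.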