Merge branch 'comfyanonymous:master' into dpr

kali-linex authored on 2023-06-11 01:26:40 +02:00, committed by GitHub
commit 7352a6b41a
34 changed files with 1759 additions and 636 deletions

@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
else:
raise AssertionError('Unknown merge analysis result')
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:

.gitignore (1 addition)
@ -9,3 +9,4 @@ custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs
.idea/

@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@ -37,28 +38,28 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Shortcuts
| Keybind | Explanation |
|---------------------------|---------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
Ctrl can also be replaced with Cmd instead for macOS users
# Installing
@ -118,13 +119,26 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
#### Apple Mac silicon
You can install ComfyUI on Apple silicon Macs (M1 or M2) with any recent macOS version.
1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
1. Launch ComfyUI by running `python main.py`.
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
```source path_to_other_sd_gui/venv/bin/activate```
@ -134,7 +148,7 @@ With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
# Running
@ -158,6 +172,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or (
You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}.
Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`.
To use textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
```embedding:embedding_filename.pt```
@ -181,6 +197,12 @@ You can set this command line setting to disable the upcasting to fp32 in some c
```--dont-upcast-attention```
## How to show high-quality previews?
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
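If you prefer to script the download step, here is a minimal sketch using only the Python standard library; the URLs are the two linked above and the target folder is the `models/vae_approx` directory this section mentions (run it from the ComfyUI root):

```python
# Minimal sketch: fetch the TAESD preview weights into models/vae_approx (paths assume the ComfyUI root).
import os
import urllib.request

urls = [
    "https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth",
    "https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth",
]
os.makedirs("models/vae_approx", exist_ok=True)
for url in urls:
    dest = os.path.join("models/vae_approx", os.path.basename(url))
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
        print("downloaded", dest)
```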
## Support and dev channel
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).

@ -1,4 +1,35 @@
import argparse
import enum
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
parser = argparse.ArgumentParser()
@ -13,6 +44,14 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
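EnumAction pops the enum class off the `type` keyword, derives the legal `choices` from the member values, and converts the parsed string back into an enum member in `__call__`, so `args.preview_method` is always a `LatentPreviewMethod` rather than a raw string. A self-contained sketch of the same pattern (the `Mode` enum and `--mode` flag below are hypothetical, not part of this commit):

```python
import argparse
import enum

class Mode(enum.Enum):
    FAST = "fast"
    ACCURATE = "accurate"

class EnumAction(argparse.Action):
    """Store the matching enum member instead of the raw CLI string."""
    def __init__(self, **kwargs):
        enum_type = kwargs.pop("type", None)  # keep argparse from doing its own type conversion
        if enum_type is None or not issubclass(enum_type, enum.Enum):
            raise TypeError("type must be an Enum when using EnumAction")
        kwargs.setdefault("choices", tuple(e.value for e in enum_type))
        super().__init__(**kwargs)
        self._enum = enum_type

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, self._enum(values))

parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=Mode, default=Mode.FAST, action=EnumAction)
print(parser.parse_args(["--mode", "accurate"]).mode)  # Mode.ACCURATE
```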

@ -1,14 +1,5 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae

comfy/diffusers_load.py (new file, 87 lines)
@ -0,0 +1,87 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
import diffusers_convert
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
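For context, a hedged sketch of how this loader might be called; the checkpoint directory below is a hypothetical example of a diffusers-format export (with `unet/`, `vae/`, `text_encoder/` and `scheduler/` subfolders), and the import path is assumed from the file location:

```python
# Hypothetical usage sketch of comfy/diffusers_load.py; the model path is an example, not from the diff.
from comfy import diffusers_load

model_path = "models/diffusers/some-diffusers-checkpoint"  # folder with unet/, vae/, text_encoder/, scheduler/
model, clip, vae = diffusers_load.load_diffusers(
    model_path,
    fp16=True,                               # forwarded into the unet config as use_fp16
    output_vae=True,
    output_clip=True,
    embedding_directory="models/embeddings",
)
# The return value is whatever comfy.sd.load_checkpoint returns for the converted state dict
# (a patched model plus the optional CLIP and VAE wrappers).
```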

comfy/model_base.py (new file, 66 lines)
@ -0,0 +1,66 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import numpy as np
class BaseModel(torch.nn.Module):
def __init__(self, unet_config, v_prediction=False):
super().__init__()
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
if self.v_prediction:
self.parameterization = "v"
else:
self.parameterization = "eps"
if "adm_in_channels" in unet_config:
self.adm_channels = unet_config["adm_in_channels"]
else:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
if c_concat is not None:
xc = torch.cat([x] + c_concat, dim=1)
else:
xc = x
context = torch.cat(c_crossattn, 1)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
def get_dtype(self):
return self.diffusion_model.dtype
def is_adm(self):
return self.adm_channels > 0
class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.concat_keys = ("mask", "masked_image")
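register_schedule precomputes the `betas`, `alphas_cumprod` and `alphas_cumprod_prev` buffers the samplers rely on. A small numerical sketch of that computation, assuming the standard LDM "linear" schedule (a squared linspace between the square roots of linear_start and linear_end), which is what make_beta_schedule is expected to return for beta_schedule="linear":

```python
import numpy as np

def linear_betas(timesteps=1000, linear_start=0.00085, linear_end=0.012):
    # Assumed form of make_beta_schedule(..., "linear", ...): squared linspace of sqrt-betas.
    return np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2

betas = linear_betas()
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)               # cumulative signal retention per timestep
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) # shifted by one step for the posterior

print(alphas_cumprod[0], alphas_cumprod[-1])  # ~0.999 at t=0, a few thousandths at t=999
```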

@ -1,23 +1,29 @@
import psutil
from enum import Enum
from comfy.cli_args import args
import torch
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
total_vram = 0
total_vram_available_mb = -1
lowvram_available = True
xpu_available = False
directml_enabled = False
@ -31,30 +37,80 @@ if args.directml is not None:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
else:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
except:
pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except:
pass
if args.cpu:
cpu_state = CPUState.CPU
def get_torch_device():
global xpu_available
global directml_enabled
global cpu_state
if directml_enabled:
global directml_device
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_cuda
if torch_total_too:
return (mem_total, mem_total_torch)
else:
return mem_total
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
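Since get_torch_device and get_total_memory are plain module-level helpers, the numbers behind the "Total VRAM / total RAM" line can be checked interactively. A hedged sketch (assuming a working install where comfy.model_management imports cleanly):

```python
# Interactive check of the helpers defined above (usage sketch, not part of the diff).
import comfy.model_management as mm

dev = mm.get_torch_device()
mem_total, mem_total_torch = mm.get_total_memory(dev, torch_total_too=True)
print(dev)
# On CUDA the second value is the torch reserved pool; on CPU/MPS it equals the system total.
print("{:0.0f} MB total, {:0.0f} MB tracked by torch".format(
    mem_total / (1024 * 1024), mem_total_torch / (1024 * 1024)))
```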
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
@ -102,54 +159,38 @@ if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb))
if cpu_state != CPUState.GPU:
vram_state = VRAMState.DISABLED
vram_state = VRAMState.MPS
except:
pass
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
return "{} {}".format(device, torch.cuda.get_device_name(device))
else:
return "{}".format(device.type)
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
@ -199,22 +240,29 @@ def load_model_gpu(model):
model.unpatch_model()
raise e
torch_dev = get_torch_device()
model.model_patches_to(torch_dev)
vram_set_state = vram_state
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = model.model_size()
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
current_loaded_model = model
if vram_state == VRAMState.CPU:
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False
real_model.to(get_torch_device())
else:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
model_accelerated = True
@ -223,7 +271,7 @@ def load_model_gpu(model):
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.DISABLED:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@ -268,7 +316,8 @@ def get_autocast_device(dev):
def xformers_enabled():
global xpu_available
global directml_enabled
global cpu_state
if cpu_state != CPUState.GPU:
return False
if xpu_available:
return False
@ -340,12 +389,12 @@ def maximum_batch_area():
return int(max(area, 0))
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU
def mps_mode():
global cpu_state
return cpu_state == CPUState.MPS
def should_use_fp16():
global xpu_available
@ -377,7 +426,10 @@ def should_use_fp16():
def soft_empty_cache():
global xpu_available
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif xpu_available:
torch.xpu.empty_cache()
elif torch.cuda.is_available():
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda

@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond[temp[1]] = [o[0], n]
def encode_adm(conds, batch_size, device, noise_augmentor=None):
for t in range(len(conds)):
x = conds[t]
adm_out = None
if noise_augmentor is not None:
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
else:
if 'adm' in x[1]:
adm_out = x[1]["adm"].to(device)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
@ -591,14 +597,17 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.get_dtype() == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
if self.model.is_adm():
noise_augmentor = None
if hasattr(self.model, 'noise_augmentor'): #unclip
noise_augmentor = self.model.noise_augmentor
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

@ -14,8 +14,16 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
for x in replace:
sd[x[1]] = sd.pop(x[0])
m, u = model.load_state_dict(sd, strict=False)
k = list(sd.keys())
@ -30,17 +38,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
keys_to_replace = {
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
}
for x in keys_to_replace:
if x in sd:
sd[keys_to_replace[x]] = sd.pop(x)
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
for x in load_state_dict_to:
@ -192,7 +189,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(12):
tk = "diffusion_model.input_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@ -203,13 +200,13 @@ def model_lora_keys(model, key_map={}):
if up_counter >= 4:
counter += 1
for c in LORA_UNET_MAP_ATTENTIONS:
k = "diffusion_model.middle_block.1.{}.weight".format(c)
if k in sdk:
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
counter = 3
for b in range(12):
tk = "diffusion_model.output_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@ -233,7 +230,7 @@ def model_lora_keys(model, key_map={}):
ds_counter = 0
counter = 0
for b in range(12):
tk = "diffusion_model.input_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -252,7 +249,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(3):
tk = "diffusion_model.middle_block.{}".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -266,7 +263,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
us_counter = 0
for b in range(12):
tk = "diffusion_model.output_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -285,15 +282,29 @@ def model_lora_keys(model, key_map={}):
return key_map
class ModelPatcher:
def __init__(self, model, size=0):
self.size = size
self.model = model
self.patches = []
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
return size
def clone(self):
n = ModelPatcher(self.model, self.size)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
return n
@ -328,7 +339,7 @@ class ModelPatcher:
patch_list[i] = patch_list[i].to(device)
def model_dtype(self):
return self.model.get_dtype()
def add_patches(self, patches, strength=1.0):
p = {}
@ -504,10 +515,16 @@ class VAE:
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
self.scale_factor = scale_factor
if device is None:
device = model_management.get_torch_device()
@ -600,7 +617,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return torch.cat([tensor] * batched_number, dim=0)
class ControlNet:
def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
@ -609,6 +626,7 @@ class ControlNet:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None
@ -644,6 +662,9 @@ class ControlNet:
key = 'output'
index = i
x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)
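The global_average_pooling branch above collapses each ControlNet output map to its spatial mean and then tiles that mean back to the original height and width (the behaviour enabled for the "shuffle" ControlNets further down). A tiny self-contained illustration of that operation:

```python
import torch

x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
pooled = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
print(pooled.shape)            # torch.Size([2, 3, 4, 4]) -- same shape as x
print(pooled[0, 0].unique())   # a single value per channel, tiled across the 4x4 map
```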
@ -674,7 +695,7 @@ class ControlNet:
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
return c
@ -722,7 +743,7 @@ def load_controlnet(ckpt_path, model=None):
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
else:
@ -739,7 +760,7 @@ def load_controlnet(ckpt_path, model=None):
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
if pth:
@ -750,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@ -769,7 +790,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16:
control_model = control_model.half()
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
class T2IAdapter:
@ -913,9 +938,10 @@ def load_gligen(ckpt_path):
model = model.half()
return model
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
@ -924,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"]
noise_aug_config = None
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
clip = None
vae = None
@ -945,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
@ -1024,7 +1068,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
}
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"attention_resolutions": [
@ -1044,47 +1088,59 @@
"legacy": False
}
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
unet_config['use_linear_in_transformer'] = True
unet_config["use_fp16"] = fp16
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
unclip_model = False
inpaint_model = False
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
if unet_config["context_dim"] == 768:
unet_config["num_heads"] = 8 #SD1.x
else:
unet_config["num_head_channels"] = 64 #SD2.x
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
v_prediction = False
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
v_prediction = True
sd_config["parameterization"] = 'v'
if inpaint_model:
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif unclip_model:
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:

@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
next_new_token += 1
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
out_tokens += [tokens_temp] out_tokens += [tokens_temp]
if len(embedding_weights) > 0: if len(embedding_weights) > 0:

comfy/taesd/taesd.py Normal file

@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def Encoder():
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
)
def Decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
@staticmethod
def scale_latents(x):
"""raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
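For orientation, a minimal usage sketch of the TAESD class above; the decoder weights path and the latent tensor are placeholders, and the decoder output is only approximately in [0, 1].

import torch
from PIL import Image
from comfy.taesd.taesd import TAESD

# Placeholders: real latents come from a sampler, weights from models/vae_approx/taesd_decoder.pth
taesd = TAESD(encoder_path=None, decoder_path="taesd_decoder.pth")
latent = torch.randn(1, 4, 64, 64)               # stand-in (N, 4, H/8, W/8) latent
with torch.no_grad():
    rgb = taesd.decoder(latent)[0]               # (3, H, W), roughly in [0, 1]
rgb = rgb.clamp(0.0, 1.0).permute(1, 2, 0)       # HWC layout for PIL
Image.fromarray((rgb.cpu().numpy() * 255).astype("uint8")).save("taesd_preview.png")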


@ -1,11 +1,16 @@
import torch import torch
import math import math
import struct
def load_torch_file(ckpt, safe_load=False): def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"): if ckpt.lower().endswith(".safetensors"):
import safetensors.torch import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu") sd = safetensors.torch.load_file(ckpt, device="cpu")
else: else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load: if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else: else:
@ -19,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
return sd return sd
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias",
}
for k in keys_to_replace:
x = k.format(prefix_from)
if x in sd:
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
resblock_to_replace = { resblock_to_replace = {
"ln_1": "layer_norm1", "ln_1": "layer_norm1",
"ln_2": "layer_norm2", "ln_2": "layer_norm2",
@ -46,71 +63,87 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd return sd
#slow and inefficient, should be optimized def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
length_of_header = struct.unpack('<Q', header)[0]
if length_of_header > max_size:
return None
return f.read(length_of_header)
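A short sketch of reading safetensors metadata with the helper above (the checkpoint path is a placeholder); this is the same header consumed by the new /view_metadata route further down.

import json
import comfy.utils

header_bytes = comfy.utils.safetensors_header("models/checkpoints/example.safetensors")
if header_bytes is not None:
    header = json.loads(header_bytes)
    print(header.get("__metadata__", {}))        # user metadata stored next to the tensor entries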
def bislerp(samples, width, height): def bislerp(samples, width, height):
shape = list(samples.shape) def slerp(b1, b2, r):
width_scale = (shape[3]) / (width ) '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
height_scale = (shape[2]) / (height )
c = b1.shape[-1]
shape[3] = width #norms
shape[2] = height b1_norms = torch.norm(b1, dim=-1, keepdim=True)
out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) b2_norms = torch.norm(b2, dim=-1, keepdim=True)
def algorithm(in1, in2, t): #normalize
dims = in1.shape b1_normalized = b1 / b1_norms
val = t b2_normalized = b2 / b2_norms
#flatten to batches #zero when norms are zero
low = in1.reshape(dims[0], -1) b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
high = in2.reshape(dims[0], -1) b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
low_weight = torch.norm(low, dim=1, keepdim=True) #slerp
low_weight[low_weight == 0] = 0.0000000001 dot = (b1_normalized*b2_normalized).sum(1)
low_norm = low/low_weight omega = torch.acos(dot)
high_weight = torch.norm(high, dim=1, keepdim=True)
high_weight[high_weight == 0] = 0.0000000001
high_norm = high/high_weight
dot_prod = (low_norm*high_norm).sum(1)
dot_prod[dot_prod > 0.9995] = 0.9995
dot_prod[dot_prod < -0.9995] = -0.9995
omega = torch.acos(dot_prod)
so = torch.sin(omega) so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
res *= (low_weight * (1.0-val) + high_weight * val)
return res.reshape(dims)
for x_dest in range(shape[3]): #technically not mathematically correct, but more pleasing?
for y_dest in range(shape[2]): res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
y = (y_dest + 0.5) * height_scale - 0.5 res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
x = (x_dest + 0.5) * width_scale - 0.5
x1 = max(math.floor(x), 0) #edge cases for same or polar opposites
x2 = min(x1 + 1, samples.shape[3] - 1) res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
wx = x - math.floor(x) res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
y1 = max(math.floor(y), 0) pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
y2 = min(y1 + 1, samples.shape[2] - 1) pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
wy = y - math.floor(y) ratios = ratios.movedim(1, -1).reshape((-1,1))
in1 = samples[:,:,y1,x1] result = slerp(pass_1, pass_2, ratios)
in2 = samples[:,:,y1,x2] result = result.reshape(n, h, w_new, c).movedim(-1, 1)
in3 = samples[:,:,y2,x1]
in4 = samples[:,:,y2,x2]
if (x1 == x2) and (y1 == y2): #linear h
out_value = in1 ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
elif (x1 == x2): coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
out_value = algorithm(in1, in3, wy) coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
elif (y1 == y2): ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
out_value = algorithm(in1, in2, wx)
else:
o1 = algorithm(in1, in2, wx)
o2 = algorithm(in3, in4, wx)
out_value = algorithm(o1, o2, wy)
out1[:,:,y_dest,x_dest] = out_value pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
return out1 pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
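The rewritten bislerp interpolates whole rows and columns with batched tensor ops instead of per-pixel Python loops; a tiny call sketch (the latent is a placeholder):

import torch
import comfy.utils

samples = torch.randn(1, 4, 32, 32)              # stand-in latent batch (N, C, H, W)
upscaled = comfy.utils.bislerp(samples, 64, 48)  # target width=64, height=48
print(upscaled.shape)                            # torch.Size([1, 4, 48, 64])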
def common_upscale(samples, width, height, upscale_method, crop): def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center": if crop == "center":
@ -176,14 +209,14 @@ class ProgressBar:
self.current = 0 self.current = 0
self.hook = PROGRESS_BAR_HOOK self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value, total=None): def update_absolute(self, value, total=None, preview=None):
if total is not None: if total is not None:
self.total = total self.total = total
if value > self.total: if value > self.total:
value = self.total value = self.total
self.current = value self.current = value
if self.hook is not None: if self.hook is not None:
self.hook(self.current, self.total) self.hook(self.current, self.total, preview)
def update(self, value): def update(self, value):
self.update_absolute(self.current + value) self.update_absolute(self.current + value)
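The progress hook now receives an optional preview argument; a minimal sketch of wiring a custom hook (the hook itself is hypothetical):

import comfy.utils

def my_hook(value, total, preview):
    # preview is None or encoded preview-image bytes supplied by a previewer
    print(f"step {value}/{total}" + (" (preview attached)" if preview else ""))

comfy.utils.set_progress_bar_global_hook(my_hook)
pbar = comfy.utils.ProgressBar(20)
pbar.update_absolute(5, preview=None)
pbar.update(1)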


@ -167,7 +167,7 @@ class MaskComposite:
"source": ("MASK",), "source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],), "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
} }
} }
@ -193,6 +193,12 @@ class MaskComposite:
output[top:bottom, left:right] = destination_portion + source_portion output[top:bottom, left:right] = destination_portion + source_portion
elif operation == "subtract": elif operation == "subtract":
output[top:bottom, left:right] = destination_portion - source_portion output[top:bottom, left:right] = destination_portion - source_portion
elif operation == "and":
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "or":
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "xor":
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
output = torch.clamp(output, 0.0, 1.0) output = torch.clamp(output, 0.0, 1.0)
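The new and/or/xor operations binarize both mask regions by rounding before combining them; a tiny sketch of the semantics:

import torch

destination = torch.tensor([[0.2, 0.8], [1.0, 0.0]])
source = torch.tensor([[0.6, 0.4], [1.0, 1.0]])
xor = torch.bitwise_xor(destination.round().bool(), source.round().bool()).float()
print(xor)                                       # tensor([[1., 1.], [0., 1.]])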


@ -102,13 +102,21 @@ def get_output_data(obj, input_data_all):
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui return output, ui
def format_value(x):
if x is None:
return None
elif isinstance(x, (int, float, bool, str)):
return x
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs: if unique_id in outputs:
return return (True, None, None)
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
@ -117,22 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
if result[0] is not True:
# Another node failed further upstream
return result
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) input_data_all = None
if server.client_id is not None: try:
server.last_node_id = unique_id input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
}
return (False, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
input_data_formatted = {}
if input_data_all is not None:
input_data_formatted = {}
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!")
print(traceback.format_exc())
error_details = {
"node_id": unique_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
}
return (False, error_details, ex)
executed.add(unique_id) executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item): def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
@ -210,6 +260,48 @@ class PromptExecutor:
self.old_prompt = {} self.old_prompt = {}
self.server = server self.server = server
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
class_type = prompt[node_id]["class_type"]
# First, send back the status to the frontend depending
# on the exception type
if isinstance(ex, comfy.model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
}
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
else:
if self.server.client_id is not None:
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.server.send_sync("execution_error", mes, self.server.client_id)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
@ -244,42 +336,30 @@ class PromptExecutor:
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set() executed = set()
try: output_node_id = None
to_execute = [] to_execute = []
for x in list(execute_outputs):
to_execute += [(0, x)]
while len(to_execute) > 0: for node_id in list(execute_outputs):
#always execute the output that depends on the least amount of unexecuted nodes first to_execute += [(0, node_id)]
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) while len(to_execute) > 0:
except Exception as e: #always execute the output that depends on the least amount of unexecuted nodes first
if isinstance(e, comfy.model_management.InterruptProcessingException): to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
print("Processing interrupted") output_node_id = to_execute.pop(0)[-1]
else:
message = str(traceback.format_exc())
print(message)
if self.server.client_id is not None:
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
to_delete = [] # This call shouldn't raise anything if there's an error deep in
for o in self.outputs: # the actual SD code, instead it will report the node where the
if (o not in current_outputs) and (o not in executed): # error was raised
to_delete += [o] success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
if o in self.old_prompt: if success is not True:
d = self.old_prompt.pop(o) self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
del d break
for o in to_delete:
d = self.outputs.pop(o) for x in executed:
del d self.old_prompt[x] = copy.deepcopy(prompt[x])
finally: self.server.last_node_id = None
for x in executed: if self.server.client_id is not None:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect() gc.collect()
@ -297,57 +377,202 @@ def validate_inputs(prompt, item, validated):
class_inputs = obj_class.INPUT_TYPES() class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required'] required_inputs = class_inputs['required']
errors = []
valid = True
for x in required_inputs: for x in required_inputs:
if x not in inputs: if x not in inputs:
return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
errors.append(error)
continue
val = inputs[x] val = inputs[x]
info = required_inputs[x] info = required_inputs[x]
type_input = info[0] type_input = info[0]
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) error = {
"type": "bad_linked_input",
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
}
errors.append(error)
continue
o_id = val[0] o_id = val[0]
o_class_type = prompt[o_id]['class_type'] o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input: if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) received_type = r[val[1]]
r = validate_inputs(prompt, o_id, validated) details = f"{x}, {received_type} != {type_input}"
if r[0] == False: error = {
validated[o_id] = r "type": "return_type_mismatch",
return r "message": "Return type mismatch between linked nodes",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_type": received_type,
"linked_node": val
}
}
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
continue
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_inner_validation",
"message": "Exception when validating inner node",
"details": str(ex),
"extra_info": {
"input_name": x,
"input_config": info,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
continue
else: else:
if type_input == "INT": try:
val = int(val) if type_input == "INT":
inputs[x] = val val = int(val)
if type_input == "FLOAT": inputs[x] = val
val = float(val) if type_input == "FLOAT":
inputs[x] = val val = float(val)
if type_input == "STRING": inputs[x] = val
val = str(val) if type_input == "STRING":
inputs[x] = val val = str(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
"exception_message": str(ex)
}
}
errors.append(error)
continue
if len(info) > 1: if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]: if "min" in info[1] and val < info[1]["min"]:
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]: if "max" in info[1] and val > info[1]["max"]:
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id) input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all) #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for r in ret: for i, r in enumerate(ret):
if r != True: if r is not True:
return (False, "{}, {}".format(class_type, r), unique_id) details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else: else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(type_input)
error = {
"type": "value_not_in_list",
"message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}",
"extra_info": {
"input_name": x,
"input_config": input_config,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
ret = (True, [], unique_id)
ret = (True, "", unique_id)
validated[unique_id] = ret validated[unique_id] = ret
return ret return ret
def full_type_name(klass):
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt): def validate_prompt(prompt):
outputs = set() outputs = set()
for x in prompt: for x in prompt:
@ -356,7 +581,13 @@ def validate_prompt(prompt):
outputs.add(x) outputs.add(x)
if len(outputs) == 0: if len(outputs) == 0:
return (False, "Prompt has no outputs", [], []) error = {
"type": "prompt_no_outputs",
"message": "Prompt has no outputs",
"details": "",
"extra_info": {}
}
return (False, error, [], [])
good_outputs = set() good_outputs = set()
errors = [] errors = []
@ -364,34 +595,72 @@ def validate_prompt(prompt):
validated = {} validated = {}
for o in outputs: for o in outputs:
valid = False valid = False
reason = "" reasons = []
try: try:
m = validate_inputs(prompt, o, validated) m = validate_inputs(prompt, o, validated)
valid = m[0] valid = m[0]
reason = m[1] reasons = m[1]
node_id = m[2] except Exception as ex:
except Exception as e: typ, _, tb = sys.exc_info()
print(traceback.format_exc())
valid = False valid = False
reason = "Parsing error" exception_type = full_type_name(typ)
node_id = None reasons = [{
"type": "exception_during_validation",
"message": "Exception when validating node",
"details": str(ex),
"extra_info": {
"exception_type": exception_type,
"traceback": traceback.format_tb(tb)
}
}]
validated[o] = (False, reasons, o)
if valid == True: if valid is True:
good_outputs.add(o) good_outputs.add(o)
else: else:
print("Failed to validate prompt for output {} {}".format(o, reason)) print(f"Failed to validate prompt for output {o}:")
print("output will be ignored") if len(reasons) > 0:
errors += [(o, reason)] print("* (prompt):")
if node_id is not None: for reason in reasons:
if node_id not in node_errors: print(f" - {reason['message']}: {reason['details']}")
node_errors[node_id] = {"message": reason, "dependent_outputs": []} errors += [(o, reasons)]
node_errors[node_id]["dependent_outputs"].append(o) for node_id, result in validated.items():
valid = result[0]
reasons = result[1]
# If a node upstream has errors, the nodes downstream will also
# be reported as invalid, but there will be no errors attached.
# So don't return those nodes as having errors in the response.
if valid is not True and len(reasons) > 0:
if node_id not in node_errors:
class_type = prompt[node_id]['class_type']
node_errors[node_id] = {
"errors": reasons,
"dependent_outputs": [],
"class_type": class_type
}
print(f"* {class_type} {node_id}:")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
if len(good_outputs) == 0: if len(good_outputs) == 0:
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) errors_list = []
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) for o, errors in errors:
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
errors_list = "\n".join(errors_list)
return (True, "", list(good_outputs), node_errors) error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
}
return (False, error, list(good_outputs), node_errors)
return (True, None, list(good_outputs), node_errors)
class PromptQueue: class PromptQueue:


@ -1,14 +1,8 @@
import os import os
import time
supported_ckpt_extensions = set(['.ckpt', '.pth']) supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
try:
import safetensors.torch
supported_ckpt_extensions.add('.safetensors')
supported_pt_extensions.add('.safetensors')
except:
print("Could not import safetensors, safetensors support disabled.")
folder_names_and_paths = {} folder_names_and_paths = {}
@ -24,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
@ -38,6 +33,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
filename_list_cache = {}
if not os.path.exists(input_directory): if not os.path.exists(input_directory):
os.makedirs(input_directory) os.makedirs(input_directory)
@ -118,12 +115,18 @@ def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:] return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory): def recursive_search(directory):
if not os.path.isdir(directory):
return [], {}
result = [] result = []
dirs = {directory: os.path.getmtime(directory)}
for root, subdir, file in os.walk(directory, followlinks=True): for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file: for filepath in file:
#we os.path.join directory with a blank string to generate a path separator at the end. #we os.path.join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result for d in subdir:
path = os.path.join(root, d)
dirs[path] = os.path.getmtime(path)
return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
@ -132,20 +135,58 @@ def filter_files_extensions(files, extensions):
def get_full_path(folder_name, filename): def get_full_path(folder_name, filename):
global folder_names_and_paths global folder_names_and_paths
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
filename = os.path.relpath(os.path.join("/", filename), "/")
for x in folders[0]: for x in folders[0]:
full_path = os.path.join(x, filename) full_path = os.path.join(x, filename)
if os.path.isfile(full_path): if os.path.isfile(full_path):
return full_path return full_path
return None
def get_filename_list(folder_name): def get_filename_list_(folder_name):
global folder_names_and_paths global folder_names_and_paths
output_list = set() output_list = set()
folders = folder_names_and_paths[folder_name] folders = folder_names_and_paths[folder_name]
output_folders = {}
for x in folders[0]: for x in folders[0]:
output_list.update(filter_files_extensions(recursive_search(x), folders[1])) files, folders_all = recursive_search(x)
return sorted(list(output_list)) output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}
return (sorted(list(output_list)), output_folders, time.perf_counter())
def cached_filename_list_(folder_name):
global filename_list_cache
global folder_names_and_paths
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
if x not in out[1]:
return None
return out
def get_filename_list(folder_name):
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
return list(out[0])
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename): def map_filename(filename):

latent_preview.py Normal file

@ -0,0 +1,95 @@
import torch
from PIL import Image, ImageOps
from io import BytesIO
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths
MAX_PREVIEW_RESOLUTION = 512
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0)[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
preview_image = Image.fromarray(x_sample)
return preview_image
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
def get_previewer(device):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if taesd_decoder_path:
method = LatentPreviewMethod.TAESD
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
if previewer is None:
previewer = Latent2RGBPreviewer()
return previewer
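A small sketch of how the previewer above is driven from a sampling callback (the device and latent are placeholders):

import torch
import latent_preview

device = torch.device("cpu")                             # placeholder for the active torch device
previewer = latent_preview.get_previewer(device)         # falls back to Latent2RGB without TAESD weights
x0 = torch.randn(1, 4, 64, 64)                           # stand-in for the current denoised latent
preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
# preview_bytes: 4-byte big-endian format id (1 = JPEG, 2 = PNG) followed by the encoded image,
# ready to be sent to the client as a PREVIEW_IMAGE binary event.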

main.py

@ -26,6 +26,7 @@ import yaml
import execution import execution
import folder_paths import folder_paths
import server import server
from server import BinaryEventTypes
from nodes import init_custom_nodes from nodes import init_custom_nodes
@ -36,19 +37,25 @@ def prompt_worker(q, server):
e.execute(item[2], item[1], item[3], item[4]) e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui) q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None): async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
def hook(value, total): def hook(value, total, preview_image_bytes):
server.send_sync("progress", { "value": value, "max": total}, server.client_id) server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook) comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp(): def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
if os.path.exists(temp_dir): if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path): def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream: with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream) config = yaml.safe_load(stream)
@ -69,6 +76,7 @@ def load_extra_path_config(yaml_path):
print("Adding extra search path", x, full_path) print("Adding extra search path", x, full_path)
folder_paths.add_model_folder_path(x, full_path) folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__": if __name__ == "__main__":
cleanup_temp() cleanup_temp()
@ -89,7 +97,7 @@ if __name__ == "__main__":
server.add_routes() server.add_routes()
hijack_progress(server) hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
if args.output_directory: if args.output_directory:
output_dir = os.path.abspath(args.output_directory) output_dir = os.path.abspath(args.output_directory)
@ -103,15 +111,12 @@ if __name__ == "__main__":
if args.auto_launch: if args.auto_launch:
def startup_server(address, port): def startup_server(address, port):
import webbrowser import webbrowser
webbrowser.open("http://{}:{}".format(address, port)) webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server
if os.name == "nt": try:
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
print("\nStopped server")
cleanup_temp() cleanup_temp()


@ -13,11 +13,10 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_convert import comfy.diffusers_load
import comfy.samplers import comfy.samplers
import comfy.sample import comfy.sample
import comfy.sd import comfy.sd
@ -29,7 +28,7 @@ import comfy.model_management
import importlib import importlib
import folder_paths import folder_paths
import latent_preview
def before_node_execution(): def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
@ -248,7 +247,6 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
@ -377,7 +375,7 @@ class DiffusersLoader:
model_path = path model_path = path
break break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader: class unCLIPCheckpointLoader:
@ -426,6 +424,9 @@ class LoraLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_lora(self, model, clip, lora_name, strength_model, strength_clip): def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name) lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora) return (model_lora, clip_lora)
@ -507,6 +508,9 @@ class ControlNetApply:
CATEGORY = "conditioning" CATEGORY = "conditioning"
def apply_controlnet(self, conditioning, control_net, image, strength): def apply_controlnet(self, conditioning, control_net, image, strength):
if strength == 0:
return (conditioning, )
c = [] c = []
control_hint = image.movedim(-1,1) control_hint = image.movedim(-1,1)
for t in conditioning: for t in conditioning:
@ -613,6 +617,9 @@ class unCLIPConditioning:
CATEGORY = "conditioning" CATEGORY = "conditioning"
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
if strength == 0:
return (conditioning, )
c = [] c = []
for t in conditioning: for t in conditioning:
o = t[1].copy() o = t[1].copy()
@ -922,6 +929,7 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,) return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
latent_image = latent["samples"] latent_image = latent["samples"]
@ -936,9 +944,18 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent["noise_mask"] noise_mask = latent["noise_mask"]
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = latent_preview.get_previewer(device)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps): def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps) preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@ -961,7 +978,8 @@ class KSampler:
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ), "latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}} }
}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "sample" FUNCTION = "sample"
@ -988,7 +1006,8 @@ class KSamplerAdvanced:
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ), "return_with_leftover_noise": (["disable", "enable"], ),
}} }
}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "sample" FUNCTION = "sample"

server.py

@ -7,6 +7,7 @@ import execution
import uuid import uuid
import json import json
import glob import glob
import struct
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
@ -22,6 +23,12 @@ except ImportError:
import mimetypes import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils
import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@web.middleware @web.middleware
@ -216,6 +223,27 @@ class PromptServer():
file = os.path.join(output_dir, filename) file = os.path.join(output_dir, filename)
if os.path.isfile(file): if os.path.isfile(file):
if 'preview' in request.rel_url.query:
with Image.open(file) as img:
preview_info = request.rel_url.query['preview'].split(';')
image_format = preview_info[0]
if image_format not in ['webp', 'jpeg']:
image_format = 'webp'
quality = 90
if preview_info[-1].isdigit():
quality = int(preview_info[-1])
buffer = BytesIO()
if image_format in ['jpeg']:
img = img.convert("RGB")
img.save(buffer, format=image_format, quality=quality)
buffer.seek(0)
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
headers={"Content-Disposition": f"filename=\"{filename}\""})
if 'channel' not in request.rel_url.query: if 'channel' not in request.rel_url.query:
channel = 'rgba' channel = 'rgba'
else: else:
@ -257,6 +285,50 @@ class PromptServer():
return web.Response(status=404) return web.Response(status=404)
@routes.get("/view_metadata/{folder_name}")
async def view_metadata(request):
folder_name = request.match_info.get("folder_name", None)
if folder_name is None:
return web.Response(status=404)
if not "filename" in request.rel_url.query:
return web.Response(status=404)
filename = request.rel_url.query["filename"]
if not filename.endswith(".safetensors"):
return web.Response(status=404)
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
if not "__metadata__" in dt:
return web.Response(status=404)
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
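The new endpoint can be queried like any other route; a minimal sketch against a locally running server (default address and port assumed):

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
    stats = json.loads(resp.read())
for dev in stats["devices"]:
    print(dev["name"], dev["vram_free"], "/", dev["vram_total"])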
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
@ -338,7 +410,7 @@ class PromptServer():
prompt_id = str(uuid.uuid4()) prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2] outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
return web.json_response({"prompt_id": prompt_id}) return web.json_response({"prompt_id": prompt_id, "number": number})
else: else:
print("invalid prompt:", valid[1]) print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
@ -391,16 +463,37 @@ class PromptServer():
return prompt_info return prompt_info
async def send(self, event, data, sid=None): async def send(self, event, data, sid=None):
message = {"type": event, "data": data} if isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
if isinstance(message, str) == False: else:
message = json.dumps(message) await self.send_json(event, data, sid)
def encode_bytes(self, event, data):
if not isinstance(event, int):
raise RuntimeError(f"Binary event types must be integers, got {event}")
packed = struct.pack(">I", event)
message = bytearray(packed)
message.extend(data)
return message
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)
if sid is None: if sid is None:
for ws in self.sockets.values(): for ws in self.sockets.values():
await ws.send_str(message) await ws.send_bytes(message)
elif sid in self.sockets: elif sid in self.sockets:
await self.sockets[sid].send_str(message) await self.sockets[sid].send_bytes(message)
async def send_json(self, event, data, sid=None):
message = {"type": event, "data": data}
if sid is None:
for ws in self.sockets.values():
await ws.send_json(message)
elif sid in self.sockets:
await self.sockets[sid].send_json(message)
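Binary websocket messages are framed as a 4-byte big-endian event id followed by the payload; a small sketch of decoding such a frame on the receiving side (the helper is hypothetical):

import struct

def decode_binary_event(message: bytes):
    event_type = struct.unpack(">I", message[:4])[0]
    return event_type, message[4:]

frame = struct.pack(">I", 1) + b"...encoded preview..."
event_type, payload = decode_binary_event(frame)
assert event_type == 1                           # BinaryEventTypes.PREVIEW_IMAGE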
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(


@ -21,6 +21,7 @@ const colorPalettes = {
"MODEL": "#B39DDB", // light lavender-purple "MODEL": "#B39DDB", // light lavender-purple
"STYLE_MODEL": "#C2FFAE", // light green-yellow "STYLE_MODEL": "#C2FFAE", // light green-yellow
"VAE": "#FF6E6E", // bright red "VAE": "#FF6E6E", // bright red
"TAESD": "#DCC274", // cheesecake
}, },
"litegraph_base": { "litegraph_base": {
"NODE_TITLE_COLOR": "#999", "NODE_TITLE_COLOR": "#999",


@ -1,132 +1,138 @@
import { app } from "/scripts/app.js"; import {app} from "/scripts/app.js";
// Adds filtering to combo context menus // Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter"; const ext = {
app.registerExtension({ name: "Comfy.ContextMenuFilter",
name: id,
init() { init() {
const ctxMenu = LiteGraph.ContextMenu; const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) { LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options); const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input // If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) { if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input"); const filter = document.createElement("input");
Object.assign(filter.style, { filter.classList.add("comfy-context-menu-filter");
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.placeholder = "Filter list"; filter.placeholder = "Filter list";
this.root.prepend(filter); this.root.prepend(filter);
let selectedIndex = 0; const items = Array.from(this.root.querySelectorAll(".litemenu-entry"));
let items = this.root.querySelectorAll(".litemenu-entry"); let displayedItems = [...items];
let itemCount = items.length; let itemCount = displayedItems.length;
let selectedItem;
// Apply highlighting to the selected item // We must request an animation frame for the current node of the active canvas to update.
function updateSelected() { requestAnimationFrame(() => {
if (selectedItem) { const currentNode = LGraphCanvas.active_canvas.current_node;
selectedItem.style.setProperty("background-color", ""); const clickedComboValue = currentNode.widgets
selectedItem.style.setProperty("color", ""); .filter(w => w.type === "combo" && w.options.values.length === values.length)
} .find(w => w.options.values.every((v, i) => v === values[i]))
selectedItem = items[selectedIndex]; .value;
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
const positionList = () => { let selectedIndex = values.findIndex(v => v === clickedComboValue);
const rect = this.root.getBoundingClientRect(); let selectedItem = displayedItems?.[selectedIndex];
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
updateSelected(); updateSelected();
// If we have an event then we can try and position the list under the source // Apply highlighting to the selected item
if (options.event) { function updateSelected() {
let top = options.event.clientY - 10; selectedItem?.style.setProperty("background-color", "");
selectedItem?.style.setProperty("color", "");
const bodyRect = document.body.getBoundingClientRect(); selectedItem = displayedItems[selectedIndex];
const rootRect = this.root.getBoundingClientRect(); selectedItem?.style.setProperty("background-color", "#ccc", "important");
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { selectedItem?.style.setProperty("color", "#000", "important");
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
} }
});
requestAnimationFrame(() => { const positionList = () => {
// Focus the filter box when opening const rect = this.root.getBoundingClientRect();
filter.focus();
positionList(); // If the top is off-screen then shift the element with scaling applied
}); if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
// Arrow up/down to select items
filter.addEventListener("keydown", (event) => {
switch (event.key) {
case "ArrowUp":
event.preventDefault();
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
break;
case "ArrowRight":
event.preventDefault();
selectedIndex = itemCount - 1;
updateSelected();
break;
case "ArrowDown":
event.preventDefault();
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
break;
case "ArrowLeft":
event.preventDefault();
selectedIndex = 0;
updateSelected();
break;
case "Enter":
selectedItem?.click();
break;
case "Escape":
this.close();
break;
}
});
filter.addEventListener("input", () => {
// Hide all items that don't match our filter
const term = filter.value.toLocaleLowerCase();
// When filtering, recompute which items are visible for arrow up/down and maintain selection.
displayedItems = items.filter(item => {
const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term);
item.style.display = isVisible ? "block" : "none";
return isVisible;
});
selectedIndex = 0;
if (displayedItems.includes(selectedItem)) {
selectedIndex = displayedItems.findIndex(d => d === selectedItem);
}
itemCount = displayedItems.length;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
})
} }
return ctx; return ctx;
@ -134,4 +140,6 @@ app.registerExtension({
LiteGraph.ContextMenu.prototype = ctxMenu.prototype; LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
}, },
}); }
app.registerExtension(ext);
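For reference, the pre-selection added above works by locating the combo widget whose options array is identical to the menu's values and reading its current value; a minimal sketch of that lookup, assuming a LiteGraph-style node with a widgets array (names taken from the diff, the wrapper function itself is illustrative only):

// Illustrative helper, not part of the extension: returns the current value of the
// combo widget that produced this context menu, or undefined if none matches.
function findClickedComboValue(node, menuValues) {
    return node.widgets
        .filter(w => w.type === "combo" && w.options.values.length === menuValues.length)
        .find(w => w.options.values.every((v, i) => v === menuValues[i]))
        ?.value;
}
// selectedIndex then starts at menuValues.findIndex(v => v === findClickedComboValue(node, menuValues)).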

View File

@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js";
// Allows for simple dynamic prompt replacement // Allows for simple dynamic prompt replacement
// Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued. // Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued.
/*
* Strips C-style line and block comments from a string
*/
function stripComments(str) {
return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,'');
}
app.registerExtension({ app.registerExtension({
name: "Comfy.DynamicPrompts", name: "Comfy.DynamicPrompts",
nodeCreated(node) { nodeCreated(node) {
@ -15,7 +22,7 @@ app.registerExtension({
for (const widget of widgets) { for (const widget of widgets) {
// Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node
widget.serializeValue = (workflowNode, widgetIndex) => { widget.serializeValue = (workflowNode, widgetIndex) => {
let prompt = widget.value; let prompt = stripComments(widget.value);
while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) { while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) {
const startIndex = prompt.replace("\\{", "00").indexOf("{"); const startIndex = prompt.replace("\\{", "00").indexOf("{");
const endIndex = prompt.replace("\\}", "00").indexOf("}"); const endIndex = prompt.replace("\\}", "00").indexOf("}");
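A worked example of the comment stripping, with an assumed prompt string; the loop shown above then resolves each unescaped {a|b} group to one random option, while the replace("\\{", ...) checks suggest escaped braces such as \{literal\} are left alone:

// Assumed input, for illustration only:
const cleaned = stripComments("a {red|blue} car // pick a colour\n/* draft */ at night");
// cleaned === "a {red|blue} car \n at night"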

View File

@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) {
}); });
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam();
if(ComfyApp.clipspace.images) if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask // update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth; maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight; maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px"; maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px"; maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
}); });
@ -335,6 +335,7 @@ class MaskEditorDialog extends ComfyDialog {
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel'); alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a'); alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url; touched_image.src = alpha_url;
@ -345,6 +346,7 @@ class MaskEditorDialog extends ComfyDialog {
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel'); rgb_url.searchParams.delete('channel');
rgb_url.searchParams.delete('preview');
rgb_url.searchParams.set('channel', 'rgb'); rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url; orig_image.src = rgb_url;
this.image = orig_image; this.image = orig_image;
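The two hunks above strip the preview parameter before forcing a specific channel, so the mask editor always loads the full-quality source rather than the lightweight preview configured via Comfy.PreviewFormat; a small sketch of that URL handling, with an assumed local server address and filename:

// Assumed URL, for illustration only:
const alpha_url = new URL("http://127.0.0.1:8188/view?filename=clipspace.png&type=input&preview=webp;50");
alpha_url.searchParams.delete('preview');   // drop the lightweight-preview conversion
alpha_url.searchParams.set('channel', 'a'); // request the alpha channel only
// alpha_url.toString() === "http://127.0.0.1:8188/view?filename=clipspace.png&type=input&channel=a"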

View File

@ -200,8 +200,23 @@ app.registerExtension({
applyToGraph() { applyToGraph() {
if (!this.outputs[0].links?.length) return; if (!this.outputs[0].links?.length) return;
function get_links(node) {
let links = [];
for (const l of node.outputs[0].links) {
const linkInfo = app.graph.links[l];
const n = node.graph.getNodeById(linkInfo.target_id);
if (n.type == "Reroute") {
links = links.concat(get_links(n));
} else {
links.push(l);
}
}
return links;
}
let links = get_links(this);
// For each output link copy our value over the original widget value // For each output link copy our value over the original widget value
for (const l of this.outputs[0].links) { for (const l of links) {
const linkInfo = app.graph.links[l]; const linkInfo = app.graph.links[l];
const node = this.graph.getNodeById(linkInfo.target_id); const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot]; const input = node.inputs[linkInfo.target_slot];
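The recursion above flattens chains of Reroute nodes so the primitive's widget value is written to the inputs that actually consume it; a hypothetical chain to illustrate (node and link ids are made up):

// Primitive(1) --link 10--> Reroute(2) --link 11--> KSampler(3).seed   (hypothetical graph)
// get_links(primitiveNode) descends through the Reroute reached by link 10 and returns [11],
// i.e. only the terminal link into a non-Reroute node is kept, so the value copy below
// lands directly on the KSampler input.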

View File

@ -14,5 +14,5 @@
window.graph = app.graph; window.graph = app.graph;
</script> </script>
</head> </head>
<body></body> <body class="litegraph"></body>
</html> </html>

View File

@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
if (this.onShowNodePanel) { if (this.onShowNodePanel) {
this.onShowNodePanel(n); this.onShowNodePanel(n);
} }
else
{
this.showShowNodePanel(n);
}
if (this.onNodeDblClicked) { if (this.onNodeDblClicked) {
this.onNodeDblClicked(n); this.onNodeDblClicked(n);
@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action)
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR; bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
hovercolor = hovercolor || "#555"; hovercolor = hovercolor || "#555";
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR; textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
var pos = this.mouse; var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
pos = this.last_click_position; if(pos) {
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); var rect = this.canvas.getBoundingClientRect();
pos[0] -= rect.left;
pos[1] -= rect.top;
}
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
ctx.fillStyle = hover ? hovercolor : bgcolor; ctx.fillStyle = hover ? hovercolor : bgcolor;
if(clicked) if(clicked)
@ -13067,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action)
has_submenu: true, has_submenu: true,
callback: LGraphCanvas.onShowMenuNodeProperties callback: LGraphCanvas.onShowMenuNodeProperties
}, },
{
content: "Properties Panel",
callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
},
null, null,
{ {
content: "Title", content: "Title",

View File

@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
this.socket = new WebSocket( this.socket = new WebSocket(
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}` `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
); );
this.socket.binaryType = "arraybuffer";
this.socket.addEventListener("open", () => { this.socket.addEventListener("open", () => {
opened = true; opened = true;
@ -70,33 +71,65 @@ class ComfyApi extends EventTarget {
this.socket.addEventListener("message", (event) => { this.socket.addEventListener("message", (event) => {
try { try {
const msg = JSON.parse(event.data); if (event.data instanceof ArrayBuffer) {
switch (msg.type) { const view = new DataView(event.data);
case "status": const eventType = view.getUint32(0);
if (msg.data.sid) { const buffer = event.data.slice(4);
this.clientId = msg.data.sid; switch (eventType) {
window.name = this.clientId; case 1:
const view2 = new DataView(event.data);
const imageType = view2.getUint32(0)
let imageMime
switch (imageType) {
case 1:
default:
imageMime = "image/jpeg";
break;
case 2:
imageMime = "image/png"
} }
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
break; this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break; break;
default: default:
if (this.#registered.has(msg.type)) { throw new Error(`Unknown binary websocket message of type ${eventType}`);
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); }
} else { }
throw new Error("Unknown message type"); else {
} const msg = JSON.parse(event.data);
switch (msg.type) {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
window.name = this.clientId;
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break;
case "execution_start":
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
break;
case "execution_error":
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else {
throw new Error(`Unknown message type ${msg.type}`);
}
}
} }
} catch (error) { } catch (error) {
console.warn("Unhandled message:", event.data); console.warn("Unhandled message:", event.data, error);
} }
}); });
} }
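The binary branch above slices four bytes twice before building the blob, which implies a frame layout of a 4-byte event type (1 = preview image), a 4-byte image format id (1 = JPEG, 2 = PNG) and then the encoded image; a minimal decode sketch under that reading (note that the merged code reads the format id with view2.getUint32(0), i.e. from the first word of the frame, while the image bytes start at offset 8):

// Sketch only; the layout is inferred from the slicing above, not from a protocol spec.
function decodePreviewFrame(arrayBuffer) {
    const view = new DataView(arrayBuffer);
    const eventType = view.getUint32(0);                     // bytes 0-3: event type
    if (eventType !== 1) throw new Error(`Unknown binary websocket message of type ${eventType}`);
    const imageType = view.getUint32(4);                     // bytes 4-7: image format id
    const mime = imageType === 2 ? "image/png" : "image/jpeg";
    return new Blob([arrayBuffer.slice(8)], { type: mime }); // bytes 8+: encoded image data
}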

View File

@ -44,6 +44,12 @@ export class ComfyApp {
*/ */
this.nodeOutputs = {}; this.nodeOutputs = {};
/**
* Stores the preview image data for each node
* @type {Record<string, Image>}
*/
this.nodePreviewImages = {};
/** /**
* If the shift key on the keyboard is pressed * If the shift key on the keyboard is pressed
* @type {boolean} * @type {boolean}
@ -51,6 +57,14 @@ export class ComfyApp {
this.shiftDown = false; this.shiftDown = false;
} }
getPreviewFormatParam() {
let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat");
if(preview_format)
return `&preview=${preview_format}`;
else
return "";
}
static isImageNode(node) { static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
} }
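getPreviewFormatParam() returns either an empty string or a raw "&preview=<format>" suffix, so it can be concatenated onto an already-built /view query string, as other hunks in this commit do; a usage sketch with an assumed filename and setting value:

// Assumes the global `app` instance and an example output file:
const params = new URLSearchParams({ filename: "ComfyUI_00001_.png", type: "output" });
const src = "/view?" + params.toString() + app.getPreviewFormatParam();
// Comfy.PreviewFormat unset        -> "/view?filename=ComfyUI_00001_.png&type=output"
// Comfy.PreviewFormat = "webp;50"  -> "/view?filename=ComfyUI_00001_.png&type=output&preview=webp;50"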
@ -111,10 +125,14 @@ export class ComfyApp {
if(ComfyApp.clipspace.imgs && node.imgs) { if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) { if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
} }
else else {
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; node.images = ComfyApp.clipspace.images;
}
if(app.nodeOutputs[node.id + ""])
app.nodeOutputs[node.id + ""].images = node.images;
} }
if(ComfyApp.clipspace.imgs) { if(ComfyApp.clipspace.imgs) {
@ -147,7 +165,16 @@ export class ComfyApp {
if(ComfyApp.clipspace.widgets) { if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') { if (prop && prop.type != 'image') {
if(typeof prop.value == "string" && value.filename) {
prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
}
else {
prop.value = value;
prop.callback(value);
}
}
else if (prop && prop.type != 'button') {
prop.value = value; prop.value = value;
prop.callback(value); prop.callback(value);
} }
@ -231,14 +258,20 @@ export class ComfyApp {
options.unshift( options.unshift(
{ {
content: "Open Image", content: "Open Image",
callback: () => window.open(img.src, "_blank"), callback: () => {
let url = new URL(img.src);
url.searchParams.delete('preview');
window.open(url, "_blank")
},
}, },
{ {
content: "Save Image", content: "Save Image",
callback: () => { callback: () => {
const a = document.createElement("a"); const a = document.createElement("a");
a.href = img.src; let url = new URL(img.src);
a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename")); url.searchParams.delete('preview');
a.href = url;
a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
document.body.append(a); document.body.append(a);
a.click(); a.click();
requestAnimationFrame(() => a.remove()); requestAnimationFrame(() => a.remove());
@ -345,6 +378,10 @@ export class ComfyApp {
} }
node.prototype.setSizeForImage = function () { node.prototype.setSizeForImage = function () {
if (this.inputHeight) {
this.setSize(this.size);
return;
}
const minHeight = getImageTop(this) + 220; const minHeight = getImageTop(this) + 220;
if (this.size[1] < minHeight) { if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight]); this.setSize([this.size[0], minHeight]);
@ -353,29 +390,52 @@ export class ComfyApp {
node.prototype.onDrawBackground = function (ctx) { node.prototype.onDrawBackground = function (ctx) {
if (!this.flags.collapsed) { if (!this.flags.collapsed) {
let imgURLs = []
let imagesChanged = false
const output = app.nodeOutputs[this.id + ""]; const output = app.nodeOutputs[this.id + ""];
if (output && output.images) { if (output && output.images) {
if (this.images !== output.images) { if (this.images !== output.images) {
this.images = output.images; this.images = output.images;
this.imgs = null; imagesChanged = true;
this.imageIndex = null; imgURLs = imgURLs.concat(output.images.map(params => {
return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
}))
}
}
const preview = app.nodePreviewImages[this.id + ""]
if (this.preview !== preview) {
this.preview = preview
imagesChanged = true;
if (preview != null) {
imgURLs.push(preview);
}
}
if (imagesChanged) {
this.imageIndex = null;
if (imgURLs.length > 0) {
Promise.all( Promise.all(
output.images.map((src) => { imgURLs.map((src) => {
return new Promise((r) => { return new Promise((r) => {
const img = new Image(); const img = new Image();
img.onload = () => r(img); img.onload = () => r(img);
img.onerror = () => r(null); img.onerror = () => r(null);
img.src = "/view?" + new URLSearchParams(src).toString(); img.src = src
}); });
}) })
).then((imgs) => { ).then((imgs) => {
if (this.images === output.images) { if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
this.imgs = imgs.filter(Boolean); this.imgs = imgs.filter(Boolean);
this.setSizeForImage?.(); this.setSizeForImage?.();
app.graph.setDirtyCanvas(true); app.graph.setDirtyCanvas(true);
} }
}); });
} }
else {
this.imgs = null;
}
} }
if (this.imgs && this.imgs.length) { if (this.imgs && this.imgs.length) {
@ -771,16 +831,27 @@ export class ComfyApp {
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = origDrawNodeShape.apply(this, arguments); const res = origDrawNodeShape.apply(this, arguments);
const nodeErrors = self.lastPromptError?.node_errors[node.id];
let color = null; let color = null;
let lineWidth = 1;
if (node.id === +self.runningNodeId) { if (node.id === +self.runningNodeId) {
color = "#0f0"; color = "#0f0";
} else if (self.dragOverNode && node.id === self.dragOverNode.id) { } else if (self.dragOverNode && node.id === self.dragOverNode.id) {
color = "dodgerblue"; color = "dodgerblue";
} }
else if (self.lastPromptError != null && nodeErrors?.errors) {
color = "red";
lineWidth = 2;
}
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
color = "#f0f";
lineWidth = 2;
}
if (color) { if (color) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1; ctx.lineWidth = lineWidth;
ctx.globalAlpha = 0.8; ctx.globalAlpha = 0.8;
ctx.beginPath(); ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE) if (shape == LiteGraph.BOX_SHAPE)
@ -807,11 +878,28 @@ export class ComfyApp {
ctx.stroke(); ctx.stroke();
ctx.strokeStyle = fgcolor; ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1; ctx.globalAlpha = 1;
}
if (self.progress) { if (self.progress && node.id === +self.runningNodeId) {
ctx.fillStyle = "green"; ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor; ctx.fillStyle = bgcolor;
}
// Highlight inputs that failed validation
if (nodeErrors) {
ctx.lineWidth = 2;
ctx.strokeStyle = "red";
for (const error of nodeErrors.errors) {
if (error.extra_info && error.extra_info.input_name) {
const inputIndex = node.findInputSlot(error.extra_info.input_name)
if (inputIndex !== -1) {
let pos = node.getConnectionPos(true, inputIndex);
ctx.beginPath();
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
ctx.stroke();
}
}
} }
} }
@ -859,16 +947,40 @@ export class ComfyApp {
this.progress = null; this.progress = null;
this.runningNodeId = detail; this.runningNodeId = detail;
this.graph.setDirtyCanvas(true, false); this.graph.setDirtyCanvas(true, false);
delete this.nodePreviewImages[this.runningNodeId]
}); });
api.addEventListener("executed", ({ detail }) => { api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output; this.nodeOutputs[detail.node] = detail.output;
const node = this.graph.getNodeById(detail.node); const node = this.graph.getNodeById(detail.node);
if (node?.onExecuted) { if (node) {
node.onExecuted(detail.output); if (node.onExecuted)
node.onExecuted(detail.output);
} }
}); });
api.addEventListener("execution_start", ({ detail }) => {
this.runningNodeId = null;
this.lastExecutionError = null
});
api.addEventListener("execution_error", ({ detail }) => {
this.lastExecutionError = detail;
const formattedError = this.#formatExecutionError(detail);
this.ui.dialog.show(formattedError);
this.canvas.draw(true, true);
});
api.addEventListener("b_preview", ({ detail }) => {
const id = this.runningNodeId
if (id == null)
return;
const blob = detail
const blobUrl = URL.createObjectURL(blob)
this.nodePreviewImages[id] = [blobUrl]
});
api.init(); api.init();
} }
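The b_preview handler above stores a fresh object URL per running node, and the executing handler deletes the map entry before the next node runs; as far as these hunks show, the previous object URL is not revoked, so if eager cleanup were ever wanted a hypothetical helper along these lines could do it (setNodePreview is not part of the codebase):

// Hypothetical helper, for illustration only:
function setNodePreview(app, nodeId, blob) {
    const previous = app.nodePreviewImages[nodeId];
    if (previous && previous[0]) URL.revokeObjectURL(previous[0]); // free the prior preview URL
    app.nodePreviewImages[nodeId] = [URL.createObjectURL(blob)];
}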
@ -975,6 +1087,11 @@ export class ComfyApp {
const app = this; const app = this;
// Load node definitions from the backend // Load node definitions from the backend
const defs = await api.getNodeDefs(); const defs = await api.getNodeDefs();
await this.registerNodesFromDefs(defs);
await this.#invokeExtensionsAsync("registerCustomNodes");
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets // Generate list of known widgets
@ -1047,8 +1164,6 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node); LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category; node.category = nodeData.category;
} }
await this.#invokeExtensionsAsync("registerCustomNodes");
} }
/** /**
@ -1247,6 +1362,43 @@ export class ComfyApp {
return { workflow, output }; return { workflow, output };
} }
#formatPromptError(error) {
if (error == null) {
return "(unknown error)"
}
else if (typeof error === "string") {
return error;
}
else if (error.stack && error.message) {
return error.toString()
}
else if (error.response) {
let message = error.response.error.message;
if (error.response.error.details)
message += ": " + error.response.error.details;
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
message += "\n" + nodeError.class_type + ":"
for (const errorReason of nodeError.errors) {
message += "\n - " + errorReason.message + ": " + errorReason.details
}
}
return message
}
return "(unknown error)"
}
#formatExecutionError(error) {
if (error == null) {
return "(unknown error)"
}
const traceback = error.traceback.join("")
const nodeId = error.node_id
const nodeType = error.node_type
return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
}
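A worked example of the execution-error formatting, with an assumed payload containing only the fields the method reads:

// Assumed payload, for illustration only:
const detail = {
    node_id: 7,
    node_type: "KSampler",
    exception_message: "CUDA out of memory",
    traceback: ["Traceback (most recent call last):\n", "  File \"nodes.py\", ...\n"],
};
// #formatExecutionError(detail) produces:
// "Error occurred when executing KSampler:\n\nCUDA out of memory\n\nTraceback (most recent call last):\n  File \"nodes.py\", ...\n"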
async queuePrompt(number, batchCount = 1) { async queuePrompt(number, batchCount = 1) {
this.#queueItems.push({ number, batchCount }); this.#queueItems.push({ number, batchCount });
@ -1254,8 +1406,10 @@ export class ComfyApp {
if (this.#processingQueue) { if (this.#processingQueue) {
return; return;
} }
this.#processingQueue = true; this.#processingQueue = true;
this.lastPromptError = null;
try { try {
while (this.#queueItems.length) { while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop()); ({ number, batchCount } = this.#queueItems.pop());
@ -1266,7 +1420,12 @@ export class ComfyApp {
try { try {
await api.queuePrompt(number, p); await api.queuePrompt(number, p);
} catch (error) { } catch (error) {
this.ui.dialog.show(error.response.error || error.toString()); const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError);
if (error.response) {
this.lastPromptError = error.response;
this.canvas.draw(true, true);
}
break; break;
} }
@ -1345,6 +1504,11 @@ export class ComfyApp {
const def = defs[node.type]; const def = defs[node.type];
// HOTFIX: skip node types that have no definition (e.g. primitive nodes) so the refresh doesn't break;
// properly accounting for primitive logic in the refresh flow still needs additional work.
if(!def)
continue;
for(const widgetNum in node.widgets) { for(const widgetNum in node.widgets) {
const widget = node.widgets[widgetNum] const widget = node.widgets[widgetNum]
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
@ -1364,6 +1528,10 @@ export class ComfyApp {
*/ */
clean() { clean() {
this.nodeOutputs = {}; this.nodeOutputs = {};
this.nodePreviewImages = {}
this.lastPromptError = null;
this.lastExecutionError = null;
this.runningNodeId = null;
} }
} }

View File

@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) {
const embeddings = await api.getEmbeddings(); const embeddings = await api.getEmbeddings();
const opts = parameters const opts = parameters
.substr(p) .substr(p)
.split("\n")[1]
.split(",") .split(",")
.reduce((p, n) => { .reduce((p, n) => {
const s = n.split(":"); const s = n.split(":");
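The added .split("\n")[1] limits the comma-split to the single settings line, assuming p (defined above the shown hunk) marks the newline in front of it, so trailing lines that some tools append to the A1111 metadata no longer leak into the key/value reduction; a small illustration with an assumed parameters tail:

// Assumed metadata tail, for illustration only:
const tail = "\nSteps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x512\nTemplate: foo";
tail.split("\n")[1].split(",");
// -> ["Steps: 20", " Sampler: Euler a", " CFG scale: 7", " Seed: 42", " Size: 512x512"]
// whereas tail.split(",") alone would also drag the "...\nTemplate: foo" fragment into the options.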

View File

@ -462,6 +462,24 @@ export class ComfyUI {
defaultValue: true, defaultValue: true,
}); });
/**
* file format for preview
*
* format;quality
*
 * e.g.
 * webp;50 -> WebP at quality 50
 * jpeg;80 -> JPEG (RGB) at quality 80
*
* @type {string}
*/
const previewImage = this.settings.addSetting({
id: "Comfy.PreviewFormat",
name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)",
type: "string",
defaultValue: "",
});
const fileInput = $el("input", { const fileInput = $el("input", {
id: "comfy-file-input", id: "comfy-file-input",
type: "file", type: "file",

View File

@ -115,12 +115,12 @@ function addMultilineWidget(node, name, opts, app) {
// See how large each text input can be // See how large each text input can be
freeSpace -= widgetHeight; freeSpace -= widgetHeight;
freeSpace /= multi.length; freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) { if (freeSpace < MIN_SIZE) {
// There isn't enough space for all the widgets, increase the size of the node // There isn't enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE; freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * multi.length; node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true); node.graph.setDirtyCanvas(true);
} }
@ -303,7 +303,7 @@ export const ComfyWidgets = {
subfolder = name.substring(0, folder_separator); subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1); name = name.substring(folder_separator + 1);
} }
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`;
node.setSizeForImage?.(); node.setSizeForImage?.();
} }

View File

@ -50,7 +50,7 @@ body {
padding: 30px 30px 10px 30px; padding: 30px 30px 10px 30px;
background-color: var(--comfy-menu-bg); /* Modal background */ background-color: var(--comfy-menu-bg); /* Modal background */
color: var(--error-text); color: var(--error-text);
box-shadow: 0px 0px 20px #888888; box-shadow: 0 0 20px #888888;
border-radius: 10px; border-radius: 10px;
top: 50%; top: 50%;
left: 50%; left: 50%;
@ -84,7 +84,7 @@ body {
font-size: 15px; font-size: 15px;
position: absolute; position: absolute;
top: 50%; top: 50%;
right: 0%; right: 0;
text-align: center; text-align: center;
z-index: 100; z-index: 100;
width: 170px; width: 170px;
@ -252,7 +252,7 @@ button.comfy-queue-btn {
bottom: 0 !important; bottom: 0 !important;
left: auto !important; left: auto !important;
right: 0 !important; right: 0 !important;
border-radius: 0px; border-radius: 0;
} }
.comfy-menu span.drag-handle { .comfy-menu span.drag-handle {
visibility:hidden visibility:hidden
@ -289,6 +289,11 @@ button.comfy-queue-btn {
/* Context menu */ /* Context menu */
.litegraph .dialog {
z-index: 1;
font-family: Arial, sans-serif;
}
.litegraph .litemenu-entry.has_submenu { .litegraph .litemenu-entry.has_submenu {
position: relative; position: relative;
padding-right: 20px; padding-right: 20px;
@ -325,12 +330,20 @@ button.comfy-queue-btn {
color: var(--input-text) !important; color: var(--input-text) !important;
} }
.comfy-context-menu-filter {
box-sizing: border-box;
border: 1px solid #999;
margin: 0 0 5px 5px;
width: calc(100% - 10px);
}
/* Search box */ /* Search box */
.litegraph.litesearchbox { .litegraph.litesearchbox {
z-index: 9999 !important; z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important; background-color: var(--comfy-menu-bg) !important;
overflow: hidden; overflow: hidden;
display: block;
} }
.litegraph.litesearchbox input, .litegraph.litesearchbox input,