Merge branch 'comfyanonymous:master' into dpr

This commit is contained in:
kali-linex 2023-06-11 01:26:40 +02:00 committed by GitHub
commit 7352a6b41a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1759 additions and 636 deletions

View File

@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
else:
raise AssertionError('Unknown merge analysis result')
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs
.idea/

View File

@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@ -37,28 +38,28 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
## Shortcuts
| Keybind | Explanation |
| - | - |
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Keybind | Explanation |
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
Ctrl can also be replaced with Cmd instead for MacOS users
Ctrl can also be replaced with Cmd instead for macOS users
# Installing
@ -118,13 +119,26 @@ After this you should have everything installed and can proceed to running Comfy
### Others:
[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own.
#### Apple Mac silicon
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
1. Launch ComfyUI by running `python main.py`.
> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).
#### DirectML (AMD Cards on Windows)
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
You don't. If you have another UI installed and working with it's own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
```source path_to_other_sd_gui/venv/bin/activate```
@ -134,7 +148,7 @@ With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
And then you can use that terminal to run Comfyui without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
# Running
@ -158,6 +172,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or (
You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}.
Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`.
To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
```embedding:embedding_filename.pt```
@ -181,6 +197,12 @@ You can set this command line setting to disable the upcasting to fp32 in some c
```--dont-upcast-attention```
## How to show high-quality previews?
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
## Support and dev channel
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).

View File

@ -1,4 +1,35 @@
import argparse
import enum
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
parser = argparse.ArgumentParser()
@ -13,6 +44,14 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
Latent2RGB = "latent2rgb"
TAESD = "taesd"
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

View File

@ -1,14 +1,5 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae

87
comfy/diffusers_load.py Normal file
View File

@ -0,0 +1,87 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
import diffusers_convert
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if it exists, if it doesn't pytorch
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)

66
comfy/model_base.py Normal file
View File

@ -0,0 +1,66 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import numpy as np
class BaseModel(torch.nn.Module):
def __init__(self, unet_config, v_prediction=False):
super().__init__()
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
if self.v_prediction:
self.parameterization = "v"
else:
self.parameterization = "eps"
if "adm_in_channels" in unet_config:
self.adm_channels = unet_config["adm_in_channels"]
else:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
if c_concat is not None:
xc = torch.cat([x] + c_concat, dim=1)
else:
xc = x
context = torch.cat(c_crossattn, 1)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
def get_dtype(self):
return self.diffusion_model.dtype
def is_adm(self):
return self.adm_channels > 0
class SD21UNCLIP(BaseModel):
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
super().__init__(unet_config, v_prediction)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
class SDInpaint(BaseModel):
def __init__(self, unet_config, v_prediction=False):
super().__init__(unet_config, v_prediction)
self.concat_keys = ("mask", "masked_image")

View File

@ -1,23 +1,29 @@
import psutil
from enum import Enum
from comfy.cli_args import args
import torch
class VRAMState(Enum):
CPU = 0
NO_VRAM = 1
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
total_vram = 0
total_vram_available_mb = -1
accelerate_enabled = False
lowvram_available = True
xpu_available = False
directml_enabled = False
@ -31,30 +37,80 @@ if args.directml is not None:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import torch
if directml_enabled:
total_vram = 4097 #TODO
else:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
except:
pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except:
pass
if args.cpu:
cpu_state = CPUState.CPU
def get_torch_device():
global xpu_available
global directml_enabled
global cpu_state
if directml_enabled:
global directml_device
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_cuda
if torch_total_too:
return (mem_total, mem_total_torch)
else:
return mem_total
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
@ -102,54 +159,38 @@ if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
if lowvram_available:
try:
import accelerate
accelerate_enabled = True
vram_state = set_vram_to
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb))
try:
if torch.backends.mps.is_available():
vram_state = VRAMState.MPS
except:
pass
if cpu_state != CPUState.GPU:
vram_state = VRAMState.DISABLED
if args.cpu:
vram_state = VRAMState.CPU
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
return "{}".format(device.type)
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
if device.type == "cuda":
return "{} {}".format(device, torch.cuda.get_device_name(device))
else:
return "{}".format(device.type)
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Using device:", get_torch_device_name(get_torch_device()))
print("Device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
@ -199,22 +240,29 @@ def load_model_gpu(model):
model.unpatch_model()
raise e
model.model_patches_to(get_torch_device())
torch_dev = get_torch_device()
model.model_patches_to(torch_dev)
vram_set_state = vram_state
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = model.model_size()
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
current_loaded_model = model
if vram_state == VRAMState.CPU:
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(get_torch_device())
else:
if vram_state == VRAMState.NO_VRAM:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
model_accelerated = True
@ -223,7 +271,7 @@ def load_model_gpu(model):
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.CPU:
if vram_state == VRAMState.DISABLED:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@ -268,7 +316,8 @@ def get_autocast_device(dev):
def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU:
global cpu_state
if cpu_state != CPUState.GPU:
return False
if xpu_available:
return False
@ -340,12 +389,12 @@ def maximum_batch_area():
return int(max(area, 0))
def cpu_mode():
global vram_state
return vram_state == VRAMState.CPU
global cpu_state
return cpu_state == CPUState.CPU
def mps_mode():
global vram_state
return vram_state == VRAMState.MPS
global cpu_state
return cpu_state == CPUState.MPS
def should_use_fp16():
global xpu_available
@ -377,7 +426,10 @@ def should_use_fp16():
def soft_empty_cache():
global xpu_available
if xpu_available:
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif xpu_available:
torch.xpu.empty_cache()
elif torch.cuda.is_available():
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda

View File

@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
model_management.throw_exception_if_processing_interrupted()
@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
def encode_adm(conds, batch_size, device, noise_augmentor=None):
for t in range(len(conds)):
x = conds[t]
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
adm_out = None
if noise_augmentor is not None:
if 'adm' in x[1]:
adm_inputs = []
weights = []
noise_aug = []
adm_in = x[1]["adm"]
for adm_c in adm_in:
adm_cond = adm_c[0].image_embeds
weight = adm_c[1]
noise_augment = adm_c[2]
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
weights.append(weight)
noise_aug.append(noise_augment)
adm_inputs.append(adm_out)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
if len(noise_aug) > 1:
adm_out = torch.stack(adm_inputs).sum(0)
#TODO: add a way to control this
noise_augment = 0.05
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
adm_out = torch.cat((c_adm, noise_level_emb), 1)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
if 'adm' in x[1]:
adm_out = x[1]["adm"].to(device)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
@ -591,14 +597,17 @@ class KSampler:
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
if self.model.get_dtype() == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
if hasattr(self.model, 'noise_augmentor'): #unclip
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
if self.model.is_adm():
noise_augmentor = None
if hasattr(self.model, 'noise_augmentor'): #unclip
noise_augmentor = self.model.noise_augmentor
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

View File

@ -14,8 +14,16 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
for x in replace:
sd[x[1]] = sd.pop(x[0])
m, u = model.load_state_dict(sd, strict=False)
k = list(sd.keys())
@ -30,17 +38,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
keys_to_replace = {
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
}
for x in keys_to_replace:
if x in sd:
sd[keys_to_replace[x]] = sd.pop(x)
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
for x in load_state_dict_to:
@ -192,7 +189,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.1".format(b)
tk = "diffusion_model.input_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@ -203,13 +200,13 @@ def model_lora_keys(model, key_map={}):
if up_counter >= 4:
counter += 1
for c in LORA_UNET_MAP_ATTENTIONS:
k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
k = "diffusion_model.middle_block.1.{}.weight".format(c)
if k in sdk:
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
key_map[lora_key] = k
counter = 3
for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.1".format(b)
tk = "diffusion_model.output_blocks.{}.1".format(b)
up_counter = 0
for c in LORA_UNET_MAP_ATTENTIONS:
k = "{}.{}.weight".format(tk, c)
@ -233,7 +230,7 @@ def model_lora_keys(model, key_map={}):
ds_counter = 0
counter = 0
for b in range(12):
tk = "model.diffusion_model.input_blocks.{}.0".format(b)
tk = "diffusion_model.input_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -252,7 +249,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
for b in range(3):
tk = "model.diffusion_model.middle_block.{}".format(b)
tk = "diffusion_model.middle_block.{}".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -266,7 +263,7 @@ def model_lora_keys(model, key_map={}):
counter = 0
us_counter = 0
for b in range(12):
tk = "model.diffusion_model.output_blocks.{}.0".format(b)
tk = "diffusion_model.output_blocks.{}.0".format(b)
key_in = False
for c in LORA_UNET_MAP_RESNET:
k = "{}.{}.weight".format(tk, c)
@ -285,15 +282,29 @@ def model_lora_keys(model, key_map={}):
return key_map
class ModelPatcher:
def __init__(self, model):
def __init__(self, model, size=0):
self.size = size
self.model = model
self.patches = []
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
return size
def clone(self):
n = ModelPatcher(self.model)
n = ModelPatcher(self.model, self.size)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
return n
@ -328,7 +339,7 @@ class ModelPatcher:
patch_list[i] = patch_list[i].to(device)
def model_dtype(self):
return self.model.diffusion_model.dtype
return self.model.get_dtype()
def add_patches(self, patches, strength=1.0):
p = {}
@ -504,10 +515,16 @@ class VAE:
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
else:
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
self.scale_factor = scale_factor
if device is None:
device = model_management.get_torch_device()
@ -600,7 +617,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return torch.cat([tensor] * batched_number, dim=0)
class ControlNet:
def __init__(self, control_model, device=None):
def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
@ -609,6 +626,7 @@ class ControlNet:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None
@ -644,6 +662,9 @@ class ControlNet:
key = 'output'
index = i
x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)
@ -674,7 +695,7 @@ class ControlNet:
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model)
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
return c
@ -722,7 +743,7 @@ def load_controlnet(ckpt_path, model=None):
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
else:
@ -739,7 +760,7 @@ def load_controlnet(ckpt_path, model=None):
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
if pth:
@ -750,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
sd_key = "diffusion_model.{}".format(x[len(c_m):])
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@ -769,7 +790,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16:
control_model = control_model.half()
control = ControlNet(control_model)
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
class T2IAdapter:
@ -913,9 +938,10 @@ def load_gligen(ckpt_path):
model = model.half()
return model
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
@ -924,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
if "use_fp16" in model_config_params["unet_config"]["params"]:
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"]
noise_aug_config = None
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
clip = None
vae = None
@ -945,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
sd = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
if state_dict is None:
state_dict = utils.load_torch_file(ckpt_path)
model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
@ -1024,7 +1068,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
}
unet_config = {
"use_checkpoint": True,
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"attention_resolutions": [
@ -1044,47 +1088,59 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
"legacy": False
}
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
unet_config['use_linear_in_transformer'] = True
unet_config["use_fp16"] = fp16
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
unclip_model = False
inpaint_model = False
if noise_aug_config is not None: #SD2.x unclip model
sd_config["noise_aug_config"] = noise_aug_config
sd_config["image_size"] = 96
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
if unet_config["context_dim"] == 1024:
unet_config["num_head_channels"] = 64 #SD2.x
else:
if unet_config["context_dim"] == 768:
unet_config["num_heads"] = 8 #SD1.x
else:
unet_config["num_head_channels"] = 64 #SD2.x
unclip = 'model.diffusion_model.label_emb.0.0.weight'
if unclip in sd_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = sd[unclip].shape[1]
v_prediction = False
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
out = sd[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
v_prediction = True
sd_config["parameterization"] = 'v'
model = instantiate_from_config(model_config)
if inpaint_model:
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
elif unclip_model:
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
else:
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:

View File

@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
next_new_token += 1
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
out_tokens += [tokens_temp]
if len(embedding_weights) > 0:

65
comfy/taesd/taesd.py Normal file
View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def Encoder():
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
)
def Decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
@staticmethod
def scale_latents(x):
"""raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

View File

@ -1,11 +1,16 @@
import torch
import math
import struct
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
@ -19,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
return sd
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias",
}
for k in keys_to_replace:
x = k.format(prefix_from)
if x in sd:
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",
@ -46,71 +63,87 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
#slow and inefficient, should be optimized
def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
length_of_header = struct.unpack('<Q', header)[0]
if length_of_header > max_size:
return None
return f.read(length_of_header)
def bislerp(samples, width, height):
shape = list(samples.shape)
width_scale = (shape[3]) / (width )
height_scale = (shape[2]) / (height )
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
shape[3] = width
shape[2] = height
out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device)
#norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
def algorithm(in1, in2, t):
dims = in1.shape
val = t
#normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
#flatten to batches
low = in1.reshape(dims[0], -1)
high = in2.reshape(dims[0], -1)
#zero when norms are zero
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
low_weight = torch.norm(low, dim=1, keepdim=True)
low_weight[low_weight == 0] = 0.0000000001
low_norm = low/low_weight
high_weight = torch.norm(high, dim=1, keepdim=True)
high_weight[high_weight == 0] = 0.0000000001
high_norm = high/high_weight
dot_prod = (low_norm*high_norm).sum(1)
dot_prod[dot_prod > 0.9995] = 0.9995
dot_prod[dot_prod < -0.9995] = -0.9995
omega = torch.acos(dot_prod)
#slerp
dot = (b1_normalized*b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
res *= (low_weight * (1.0-val) + high_weight * val)
return res.reshape(dims)
for x_dest in range(shape[3]):
for y_dest in range(shape[2]):
y = (y_dest + 0.5) * height_scale - 0.5
x = (x_dest + 0.5) * width_scale - 0.5
#technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
x1 = max(math.floor(x), 0)
x2 = min(x1 + 1, samples.shape[3] - 1)
wx = x - math.floor(x)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
y1 = max(math.floor(y), 0)
y2 = min(y1 + 1, samples.shape[2] - 1)
wy = y - math.floor(y)
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
in1 = samples[:,:,y1,x1]
in2 = samples[:,:,y1,x2]
in3 = samples[:,:,y2,x1]
in4 = samples[:,:,y2,x2]
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
if (x1 == x2) and (y1 == y2):
out_value = in1
elif (x1 == x2):
out_value = algorithm(in1, in3, wy)
elif (y1 == y2):
out_value = algorithm(in1, in2, wx)
else:
o1 = algorithm(in1, in2, wx)
o2 = algorithm(in3, in4, wx)
out_value = algorithm(o1, o2, wy)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
out1[:,:,y_dest,x_dest] = out_value
return out1
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
@ -176,14 +209,14 @@ class ProgressBar:
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value, total=None):
def update_absolute(self, value, total=None, preview=None):
if total is not None:
self.total = total
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
self.hook(self.current, self.total, preview)
def update(self, value):
self.update_absolute(self.current + value)

View File

@ -167,7 +167,7 @@ class MaskComposite:
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],),
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
}
}
@ -193,6 +193,12 @@ class MaskComposite:
output[top:bottom, left:right] = destination_portion + source_portion
elif operation == "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
elif operation == "and":
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "or":
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "xor":
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
output = torch.clamp(output, 0.0, 1.0)

View File

@ -102,13 +102,21 @@ def get_output_data(obj, input_data_all):
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def format_value(x):
if x is None:
return None
elif isinstance(x, (int, float, bool, str)):
return x
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return
return (True, None, None)
for x in inputs:
input_data = inputs[x]
@ -117,22 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
if result[0] is not True:
# Another node failed further upstream
return result
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
}
return (False, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
input_data_formatted = {}
if input_data_all is not None:
input_data_formatted = {}
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!")
print(traceback.format_exc())
error_details = {
"node_id": unique_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
}
return (False, error_details, ex)
executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@ -210,6 +260,48 @@ class PromptExecutor:
self.old_prompt = {}
self.server = server
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
class_type = prompt[node_id]["class_type"]
# First, send back the status to the frontend depending
# on the exception type
if isinstance(ex, comfy.model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
}
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
else:
if self.server.client_id is not None:
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.server.send_sync("execution_error", mes, self.server.client_id)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
@ -244,42 +336,30 @@ class PromptExecutor:
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set()
try:
to_execute = []
for x in list(execute_outputs):
to_execute += [(0, x)]
output_node_id = None
to_execute = []
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
except Exception as e:
if isinstance(e, comfy.model_management.InterruptProcessingException):
print("Processing interrupted")
else:
message = str(traceback.format_exc())
print(message)
if self.server.client_id is not None:
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
finally:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
@ -297,57 +377,202 @@ def validate_inputs(prompt, item, validated):
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
errors = []
valid = True
for x in required_inputs:
if x not in inputs:
return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id)
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
if isinstance(val, list):
if len(val) != 2:
return (False, "Bad Input. {}, {}".format(class_type, x), unique_id)
error = {
"type": "bad_linked_input",
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
}
errors.append(error)
continue
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id)
r = validate_inputs(prompt, o_id, validated)
if r[0] == False:
validated[o_id] = r
return r
received_type = r[val[1]]
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_type": received_type,
"linked_node": val
}
}
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
continue
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_inner_validation",
"message": "Exception when validating inner node",
"details": str(ex),
"extra_info": {
"input_name": x,
"input_config": info,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
continue
else:
if type_input == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
val = str(val)
inputs[x] = val
try:
if type_input == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
val = str(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
"exception_message": str(ex)
}
}
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id)
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id)
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for r in ret:
if r != True:
return (False, "{}, {}".format(class_type, r), unique_id)
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id)
input_config = info
list_info = ""
# Don't send back gigantic lists like if they're lots of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(type_input)
error = {
"type": "value_not_in_list",
"message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}",
"extra_info": {
"input_name": x,
"input_config": input_config,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
ret = (True, [], unique_id)
ret = (True, "", unique_id)
validated[unique_id] = ret
return ret
def full_type_name(klass):
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
outputs = set()
for x in prompt:
@ -356,7 +581,13 @@ def validate_prompt(prompt):
outputs.add(x)
if len(outputs) == 0:
return (False, "Prompt has no outputs", [], [])
error = {
"type": "prompt_no_outputs",
"message": "Prompt has no outputs",
"details": "",
"extra_info": {}
}
return (False, error, [], [])
good_outputs = set()
errors = []
@ -364,34 +595,72 @@ def validate_prompt(prompt):
validated = {}
for o in outputs:
valid = False
reason = ""
reasons = []
try:
m = validate_inputs(prompt, o, validated)
valid = m[0]
reason = m[1]
node_id = m[2]
except Exception as e:
print(traceback.format_exc())
reasons = m[1]
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
reason = "Parsing error"
node_id = None
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_validation",
"message": "Exception when validating node",
"details": str(ex),
"extra_info": {
"exception_type": exception_type,
"traceback": traceback.format_tb(tb)
}
}]
validated[o] = (False, reasons, o)
if valid == True:
if valid is True:
good_outputs.add(o)
else:
print("Failed to validate prompt for output {} {}".format(o, reason))
print("output will be ignored")
errors += [(o, reason)]
if node_id is not None:
if node_id not in node_errors:
node_errors[node_id] = {"message": reason, "dependent_outputs": []}
node_errors[node_id]["dependent_outputs"].append(o)
print(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0:
print("* (prompt):")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
reasons = result[1]
# If a node upstream has errors, the nodes downstream will also
# be reported as invalid, but there will be no errors attached.
# So don't return those nodes as having errors in the response.
if valid is not True and len(reasons) > 0:
if node_id not in node_errors:
class_type = prompt[node_id]['class_type']
node_errors[node_id] = {
"errors": reasons,
"dependent_outputs": [],
"class_type": class_type
}
print(f"* {class_type} {node_id}:")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
if len(good_outputs) == 0:
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors)
errors_list = []
for o, errors in errors:
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
errors_list = "\n".join(errors_list)
return (True, "", list(good_outputs), node_errors)
error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
}
return (False, error, list(good_outputs), node_errors)
return (True, None, list(good_outputs), node_errors)
class PromptQueue:

View File

@ -1,14 +1,8 @@
import os
import time
supported_ckpt_extensions = set(['.ckpt', '.pth'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
try:
import safetensors.torch
supported_ckpt_extensions.add('.safetensors')
supported_pt_extensions.add('.safetensors')
except:
print("Could not import safetensors, safetensors support disabled.")
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
folder_names_and_paths = {}
@ -24,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
@ -38,6 +33,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
filename_list_cache = {}
if not os.path.exists(input_directory):
os.makedirs(input_directory)
@ -118,12 +115,18 @@ def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory):
if not os.path.isdir(directory):
return [], {}
result = []
dirs = {directory: os.path.getmtime(directory)}
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path,join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result
for d in subdir:
path = os.path.join(root, d)
dirs[path] = os.path.getmtime(path)
return result, dirs
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
@ -132,20 +135,58 @@ def filter_files_extensions(files, extensions):
def get_full_path(folder_name, filename):
global folder_names_and_paths
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name]
filename = os.path.relpath(os.path.join("/", filename), "/")
for x in folders[0]:
full_path = os.path.join(x, filename)
if os.path.isfile(full_path):
return full_path
return None
def get_filename_list(folder_name):
def get_filename_list_(folder_name):
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
output_folders = {}
for x in folders[0]:
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
return sorted(list(output_list))
files, folders_all = recursive_search(x)
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}
return (sorted(list(output_list)), output_folders, time.perf_counter())
def cached_filename_list_(folder_name):
global filename_list_cache
global folder_names_and_paths
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
if x not in out[1]:
return None
return out
def get_filename_list(folder_name):
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
return list(out[0])
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):

95
latent_preview.py Normal file
View File

@ -0,0 +1,95 @@
import torch
from PIL import Image, ImageOps
from io import BytesIO
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths
MAX_PREVIEW_RESOLUTION = 512
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
preview_type = 1
if preview_format == "JPEG":
preview_type = 1
elif preview_format == "PNG":
preview_type = 2
bytesIO = BytesIO()
header = struct.pack(">I", preview_type)
bytesIO.write(header)
preview_image.save(bytesIO, format=preview_format, quality=95)
preview_bytes = bytesIO.getvalue()
return preview_bytes
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0)[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
preview_image = Image.fromarray(x_sample)
return preview_image
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self):
self.latent_rgb_factors = torch.tensor([
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
def get_previewer(device):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if taesd_decoder_path:
method = LatentPreviewMethod.TAESD
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
if previewer is None:
previewer = Latent2RGBPreviewer()
return previewer

25
main.py
View File

@ -26,6 +26,7 @@ import yaml
import execution
import folder_paths
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
@ -36,19 +37,25 @@ def prompt_worker(q, server):
e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
def hook(value, total):
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
def hook(value, total, preview_image_bytes):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
if preview_image_bytes is not None:
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
@ -69,6 +76,7 @@ def load_extra_path_config(yaml_path):
print("Adding extra search path", x, full_path)
folder_paths.add_model_folder_path(x, full_path)
if __name__ == "__main__":
cleanup_temp()
@ -89,7 +97,7 @@ if __name__ == "__main__":
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
@ -103,15 +111,12 @@ if __name__ == "__main__":
if args.auto_launch:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
except KeyboardInterrupt:
print("\nStopped server")
cleanup_temp()

View File

@ -13,11 +13,10 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_convert
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
import comfy.sd
@ -29,7 +28,7 @@ import comfy.model_management
import importlib
import folder_paths
import latent_preview
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@ -248,7 +247,6 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -377,7 +375,7 @@ class DiffusersLoader:
model_path = path
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:
@ -426,6 +424,9 @@ class LoraLoader:
CATEGORY = "loaders"
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora)
@ -507,6 +508,9 @@ class ControlNetApply:
CATEGORY = "conditioning"
def apply_controlnet(self, conditioning, control_net, image, strength):
if strength == 0:
return (conditioning, )
c = []
control_hint = image.movedim(-1,1)
for t in conditioning:
@ -613,6 +617,9 @@ class unCLIPConditioning:
CATEGORY = "conditioning"
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
if strength == 0:
return (conditioning, )
c = []
for t in conditioning:
o = t[1].copy()
@ -922,6 +929,7 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
@ -936,9 +944,18 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = latent_preview.get_previewer(device)
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
pbar.update_absolute(step + 1, total_steps)
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@ -961,7 +978,8 @@ class KSampler:
"negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"
@ -988,7 +1006,8 @@ class KSamplerAdvanced:
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ),
}}
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "sample"

107
server.py
View File

@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
import struct
from PIL import Image
from io import BytesIO
@ -22,6 +23,12 @@ except ImportError:
import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@web.middleware
@ -216,6 +223,27 @@ class PromptServer():
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
if 'preview' in request.rel_url.query:
with Image.open(file) as img:
preview_info = request.rel_url.query['preview'].split(';')
image_format = preview_info[0]
if image_format not in ['webp', 'jpeg']:
image_format = 'webp'
quality = 90
if preview_info[-1].isdigit():
quality = int(preview_info[-1])
buffer = BytesIO()
if image_format in ['jpeg']:
img = img.convert("RGB")
img.save(buffer, format=image_format, quality=quality)
buffer.seek(0)
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
headers={"Content-Disposition": f"filename=\"{filename}\""})
if 'channel' not in request.rel_url.query:
channel = 'rgba'
else:
@ -257,6 +285,50 @@ class PromptServer():
return web.Response(status=404)
@routes.get("/view_metadata/{folder_name}")
async def view_metadata(request):
folder_name = request.match_info.get("folder_name", None)
if folder_name is None:
return web.Response(status=404)
if not "filename" in request.rel_url.query:
return web.Response(status=404)
filename = request.rel_url.query["filename"]
if not filename.endswith(".safetensors"):
return web.Response(status=404)
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
if not "__metadata__" in dt:
return web.Response(status=404)
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(self.get_queue_info())
@ -338,7 +410,7 @@ class PromptServer():
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
return web.json_response({"prompt_id": prompt_id})
return web.json_response({"prompt_id": prompt_id, "number": number})
else:
print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
@ -391,16 +463,37 @@ class PromptServer():
return prompt_info
async def send(self, event, data, sid=None):
message = {"type": event, "data": data}
if isinstance(message, str) == False:
message = json.dumps(message)
if isinstance(data, (bytes, bytearray)):
await self.send_bytes(event, data, sid)
else:
await self.send_json(event, data, sid)
def encode_bytes(self, event, data):
if not isinstance(event, int):
raise RuntimeError(f"Binary event types must be integers, got {event}")
packed = struct.pack(">I", event)
message = bytearray(packed)
message.extend(data)
return message
async def send_bytes(self, event, data, sid=None):
message = self.encode_bytes(event, data)
if sid is None:
for ws in self.sockets.values():
await ws.send_str(message)
await ws.send_bytes(message)
elif sid in self.sockets:
await self.sockets[sid].send_str(message)
await self.sockets[sid].send_bytes(message)
async def send_json(self, event, data, sid=None):
message = {"type": event, "data": data}
if sid is None:
for ws in self.sockets.values():
await ws.send_json(message)
elif sid in self.sockets:
await self.sockets[sid].send_json(message)
def send_sync(self, event, data, sid=None):
self.loop.call_soon_threadsafe(

View File

@ -21,6 +21,7 @@ const colorPalettes = {
"MODEL": "#B39DDB", // light lavender-purple
"STYLE_MODEL": "#C2FFAE", // light green-yellow
"VAE": "#FF6E6E", // bright red
"TAESD": "#DCC274", // cheesecake
},
"litegraph_base": {
"NODE_TITLE_COLOR": "#999",

View File

@ -1,132 +1,138 @@
import { app } from "/scripts/app.js";
import {app} from "/scripts/app.js";
// Adds filtering to combo context menus
const id = "Comfy.ContextMenuFilter";
app.registerExtension({
name: id,
const ext = {
name: "Comfy.ContextMenuFilter",
init() {
const ctxMenu = LiteGraph.ContextMenu;
LiteGraph.ContextMenu = function (values, options) {
const ctx = ctxMenu.call(this, values, options);
// If we are a dark menu (only used for combo boxes) then add a filter input
if (options?.className === "dark" && values?.length > 10) {
const filter = document.createElement("input");
Object.assign(filter.style, {
width: "calc(100% - 10px)",
border: "0",
boxSizing: "border-box",
background: "#333",
border: "1px solid #999",
margin: "0 0 5px 5px",
color: "#fff",
});
filter.classList.add("comfy-context-menu-filter");
filter.placeholder = "Filter list";
this.root.prepend(filter);
let selectedIndex = 0;
let items = this.root.querySelectorAll(".litemenu-entry");
let itemCount = items.length;
let selectedItem;
const items = Array.from(this.root.querySelectorAll(".litemenu-entry"));
let displayedItems = [...items];
let itemCount = displayedItems.length;
// Apply highlighting to the selected item
function updateSelected() {
if (selectedItem) {
selectedItem.style.setProperty("background-color", "");
selectedItem.style.setProperty("color", "");
}
selectedItem = items[selectedIndex];
if (selectedItem) {
selectedItem.style.setProperty("background-color", "#ccc", "important");
selectedItem.style.setProperty("color", "#000", "important");
}
}
// We must request an animation frame for the current node of the active canvas to update.
requestAnimationFrame(() => {
const currentNode = LGraphCanvas.active_canvas.current_node;
const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i]))
.value;
const positionList = () => {
const rect = this.root.getBoundingClientRect();
// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
updateSelected();
// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});
filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
let selectedIndex = values.findIndex(v => v === clickedComboValue);
let selectedItem = displayedItems?.[selectedIndex];
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
// Apply highlighting to the selected item
function updateSelected() {
selectedItem?.style.setProperty("background-color", "");
selectedItem?.style.setProperty("color", "");
selectedItem = displayedItems[selectedIndex];
selectedItem?.style.setProperty("background-color", "#ccc", "important");
selectedItem?.style.setProperty("color", "#000", "important");
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
const positionList = () => {
const rect = this.root.getBoundingClientRect();
positionList();
});
// If the top is off-screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}
// Arrow up/down to select items
filter.addEventListener("keydown", (event) => {
switch (event.key) {
case "ArrowUp":
event.preventDefault();
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
break;
case "ArrowRight":
event.preventDefault();
selectedIndex = itemCount - 1;
updateSelected();
break;
case "ArrowDown":
event.preventDefault();
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
break;
case "ArrowLeft":
event.preventDefault();
selectedIndex = 0;
updateSelected();
break;
case "Enter":
selectedItem?.click();
break;
case "Escape":
this.close();
break;
}
});
filter.addEventListener("input", () => {
// Hide all items that don't match our filter
const term = filter.value.toLocaleLowerCase();
// When filtering, recompute which items are visible for arrow up/down and maintain selection.
displayedItems = items.filter(item => {
const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term);
item.style.display = isVisible ? "block" : "none";
return isVisible;
});
selectedIndex = 0;
if (displayedItems.includes(selectedItem)) {
selectedIndex = displayedItems.findIndex(d => d === selectedItem);
}
itemCount = displayedItems.length;
updateSelected();
// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;
const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}
this.root.style.top = top + "px";
positionList();
}
});
requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
positionList();
});
})
}
return ctx;
@ -134,4 +140,6 @@ app.registerExtension({
LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});
}
app.registerExtension(ext);

View File

@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js";
// Allows for simple dynamic prompt replacement
// Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued.
/*
* Strips C-style line and block comments from a string
*/
function stripComments(str) {
return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,'');
}
app.registerExtension({
name: "Comfy.DynamicPrompts",
nodeCreated(node) {
@ -15,7 +22,7 @@ app.registerExtension({
for (const widget of widgets) {
// Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node
widget.serializeValue = (workflowNode, widgetIndex) => {
let prompt = widget.value;
let prompt = stripComments(widget.value);
while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) {
const startIndex = prompt.replace("\\{", "00").indexOf("{");
const endIndex = prompt.replace("\\}", "00").indexOf("}");

View File

@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) {
});
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam();
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});
@ -335,6 +335,7 @@ class MaskEditorDialog extends ComfyDialog {
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url;
@ -345,6 +346,7 @@ class MaskEditorDialog extends ComfyDialog {
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.delete('preview');
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;

View File

@ -200,8 +200,23 @@ app.registerExtension({
applyToGraph() {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
let links = [];
for (const l of node.outputs[0].links) {
const linkInfo = app.graph.links[l];
const n = node.graph.getNodeById(linkInfo.target_id);
if (n.type == "Reroute") {
links = links.concat(get_links(n));
} else {
links.push(l);
}
}
return links;
}
let links = get_links(this);
// For each output link copy our value over the original widget value
for (const l of this.outputs[0].links) {
for (const l of links) {
const linkInfo = app.graph.links[l];
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];

View File

@ -14,5 +14,5 @@
window.graph = app.graph;
</script>
</head>
<body></body>
<body class="litegraph"></body>
</html>

View File

@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
if (this.onShowNodePanel) {
this.onShowNodePanel(n);
}
else
{
this.showShowNodePanel(n);
}
if (this.onNodeDblClicked) {
this.onNodeDblClicked(n);
@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action)
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
hovercolor = hovercolor || "#555";
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title
var pos = this.mouse;
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
pos = this.last_click_position;
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
if(pos) {
var rect = this.canvas.getBoundingClientRect();
pos[0] -= rect.left;
pos[1] -= rect.top;
}
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
ctx.fillStyle = hover ? hovercolor : bgcolor;
if(clicked)
@ -13067,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action)
has_submenu: true,
callback: LGraphCanvas.onShowMenuNodeProperties
},
{
content: "Properties Panel",
callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
},
null,
{
content: "Title",

View File

@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
this.socket = new WebSocket(
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
);
this.socket.binaryType = "arraybuffer";
this.socket.addEventListener("open", () => {
opened = true;
@ -70,33 +71,65 @@ class ComfyApi extends EventTarget {
this.socket.addEventListener("message", (event) => {
try {
const msg = JSON.parse(event.data);
switch (msg.type) {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
window.name = this.clientId;
if (event.data instanceof ArrayBuffer) {
const view = new DataView(event.data);
const eventType = view.getUint32(0);
const buffer = event.data.slice(4);
switch (eventType) {
case 1:
const view2 = new DataView(event.data);
const imageType = view2.getUint32(0)
let imageMime
switch (imageType) {
case 1:
default:
imageMime = "image/jpeg";
break;
case 2:
imageMime = "image/png"
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else {
throw new Error("Unknown message type");
}
throw new Error(`Unknown binary websocket message of type ${eventType}`);
}
}
else {
const msg = JSON.parse(event.data);
switch (msg.type) {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
window.name = this.clientId;
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break;
case "execution_start":
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
break;
case "execution_error":
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else {
throw new Error(`Unknown message type ${msg.type}`);
}
}
}
} catch (error) {
console.warn("Unhandled message:", event.data);
console.warn("Unhandled message:", event.data, error);
}
});
}

View File

@ -44,6 +44,12 @@ export class ComfyApp {
*/
this.nodeOutputs = {};
/**
* Stores the preview image data for each node
* @type {Record<string, Image>}
*/
this.nodePreviewImages = {};
/**
* If the shift key on the keyboard is pressed
* @type {boolean}
@ -51,6 +57,14 @@ export class ComfyApp {
this.shiftDown = false;
}
getPreviewFormatParam() {
let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat");
if(preview_format)
return `&preview=${preview_format}`;
else
return "";
}
static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
}
@ -111,10 +125,14 @@ export class ComfyApp {
if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
else {
node.images = ComfyApp.clipspace.images;
}
if(app.nodeOutputs[node.id + ""])
app.nodeOutputs[node.id + ""].images = node.images;
}
if(ComfyApp.clipspace.imgs) {
@ -147,7 +165,16 @@ export class ComfyApp {
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
if (prop && prop.type != 'image') {
if(typeof prop.value == "string" && value.filename) {
prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
}
else {
prop.value = value;
prop.callback(value);
}
}
else if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
@ -231,14 +258,20 @@ export class ComfyApp {
options.unshift(
{
content: "Open Image",
callback: () => window.open(img.src, "_blank"),
callback: () => {
let url = new URL(img.src);
url.searchParams.delete('preview');
window.open(url, "_blank")
},
},
{
content: "Save Image",
callback: () => {
const a = document.createElement("a");
a.href = img.src;
a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename"));
let url = new URL(img.src);
url.searchParams.delete('preview');
a.href = url;
a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
document.body.append(a);
a.click();
requestAnimationFrame(() => a.remove());
@ -345,6 +378,10 @@ export class ComfyApp {
}
node.prototype.setSizeForImage = function () {
if (this.inputHeight) {
this.setSize(this.size);
return;
}
const minHeight = getImageTop(this) + 220;
if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight]);
@ -353,29 +390,52 @@ export class ComfyApp {
node.prototype.onDrawBackground = function (ctx) {
if (!this.flags.collapsed) {
let imgURLs = []
let imagesChanged = false
const output = app.nodeOutputs[this.id + ""];
if (output && output.images) {
if (this.images !== output.images) {
this.images = output.images;
this.imgs = null;
this.imageIndex = null;
imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => {
return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
}))
}
}
const preview = app.nodePreviewImages[this.id + ""]
if (this.preview !== preview) {
this.preview = preview
imagesChanged = true;
if (preview != null) {
imgURLs.push(preview);
}
}
if (imagesChanged) {
this.imageIndex = null;
if (imgURLs.length > 0) {
Promise.all(
output.images.map((src) => {
imgURLs.map((src) => {
return new Promise((r) => {
const img = new Image();
img.onload = () => r(img);
img.onerror = () => r(null);
img.src = "/view?" + new URLSearchParams(src).toString();
img.src = src
});
})
).then((imgs) => {
if (this.images === output.images) {
if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
this.imgs = imgs.filter(Boolean);
this.setSizeForImage?.();
app.graph.setDirtyCanvas(true);
}
});
}
else {
this.imgs = null;
}
}
if (this.imgs && this.imgs.length) {
@ -771,16 +831,27 @@ export class ComfyApp {
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = origDrawNodeShape.apply(this, arguments);
const nodeErrors = self.lastPromptError?.node_errors[node.id];
let color = null;
let lineWidth = 1;
if (node.id === +self.runningNodeId) {
color = "#0f0";
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
color = "dodgerblue";
}
else if (self.lastPromptError != null && nodeErrors?.errors) {
color = "red";
lineWidth = 2;
}
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
color = "#f0f";
lineWidth = 2;
}
if (color) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.lineWidth = lineWidth;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
@ -807,11 +878,28 @@ export class ComfyApp {
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
}
if (self.progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
if (self.progress && node.id === +self.runningNodeId) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
}
// Highlight inputs that failed validation
if (nodeErrors) {
ctx.lineWidth = 2;
ctx.strokeStyle = "red";
for (const error of nodeErrors.errors) {
if (error.extra_info && error.extra_info.input_name) {
const inputIndex = node.findInputSlot(error.extra_info.input_name)
if (inputIndex !== -1) {
let pos = node.getConnectionPos(true, inputIndex);
ctx.beginPath();
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
ctx.stroke();
}
}
}
}
@ -859,16 +947,40 @@ export class ComfyApp {
this.progress = null;
this.runningNodeId = detail;
this.graph.setDirtyCanvas(true, false);
delete this.nodePreviewImages[this.runningNodeId]
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
const node = this.graph.getNodeById(detail.node);
if (node?.onExecuted) {
node.onExecuted(detail.output);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
}
});
api.addEventListener("execution_start", ({ detail }) => {
this.runningNodeId = null;
this.lastExecutionError = null
});
api.addEventListener("execution_error", ({ detail }) => {
this.lastExecutionError = detail;
const formattedError = this.#formatExecutionError(detail);
this.ui.dialog.show(formattedError);
this.canvas.draw(true, true);
});
api.addEventListener("b_preview", ({ detail }) => {
const id = this.runningNodeId
if (id == null)
return;
const blob = detail
const blobUrl = URL.createObjectURL(blob)
this.nodePreviewImages[id] = [blobUrl]
});
api.init();
}
@ -975,6 +1087,11 @@ export class ComfyApp {
const app = this;
// Load node definitions from the backend
const defs = await api.getNodeDefs();
await this.registerNodesFromDefs(defs);
await this.#invokeExtensionsAsync("registerCustomNodes");
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
@ -1047,8 +1164,6 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
}
await this.#invokeExtensionsAsync("registerCustomNodes");
}
/**
@ -1247,6 +1362,43 @@ export class ComfyApp {
return { workflow, output };
}
#formatPromptError(error) {
if (error == null) {
return "(unknown error)"
}
else if (typeof error === "string") {
return error;
}
else if (error.stack && error.message) {
return error.toString()
}
else if (error.response) {
let message = error.response.error.message;
if (error.response.error.details)
message += ": " + error.response.error.details;
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
message += "\n" + nodeError.class_type + ":"
for (const errorReason of nodeError.errors) {
message += "\n - " + errorReason.message + ": " + errorReason.details
}
}
return message
}
return "(unknown error)"
}
#formatExecutionError(error) {
if (error == null) {
return "(unknown error)"
}
const traceback = error.traceback.join("")
const nodeId = error.node_id
const nodeType = error.node_type
return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
}
async queuePrompt(number, batchCount = 1) {
this.#queueItems.push({ number, batchCount });
@ -1254,8 +1406,10 @@ export class ComfyApp {
if (this.#processingQueue) {
return;
}
this.#processingQueue = true;
this.lastPromptError = null;
try {
while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop());
@ -1266,7 +1420,12 @@ export class ComfyApp {
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response.error || error.toString());
const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError);
if (error.response) {
this.lastPromptError = error.response;
this.canvas.draw(true, true);
}
break;
}
@ -1345,6 +1504,11 @@ export class ComfyApp {
const def = defs[node.type];
// HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
// and additional work is needed to consider the primitive logic in the refresh logic.
if(!def)
continue;
for(const widgetNum in node.widgets) {
const widget = node.widgets[widgetNum]
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
@ -1364,6 +1528,10 @@ export class ComfyApp {
*/
clean() {
this.nodeOutputs = {};
this.nodePreviewImages = {}
this.lastPromptError = null;
this.lastExecutionError = null;
this.runningNodeId = null;
}
}

View File

@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) {
const embeddings = await api.getEmbeddings();
const opts = parameters
.substr(p)
.split("\n")[1]
.split(",")
.reduce((p, n) => {
const s = n.split(":");

View File

@ -462,6 +462,24 @@ export class ComfyUI {
defaultValue: true,
});
/**
* file format for preview
*
* format;quality
*
* ex)
* webp;50 -> webp, quality 50
* jpeg;80 -> rgb, jpeg, quality 80
*
* @type {string}
*/
const previewImage = this.settings.addSetting({
id: "Comfy.PreviewFormat",
name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)",
type: "string",
defaultValue: "",
});
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",

View File

@ -115,12 +115,12 @@ function addMultilineWidget(node, name, opts, app) {
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length;
freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * multi.length;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}
@ -303,7 +303,7 @@ export const ComfyWidgets = {
subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1);
}
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`;
node.setSizeForImage?.();
}

View File

@ -50,7 +50,7 @@ body {
padding: 30px 30px 10px 30px;
background-color: var(--comfy-menu-bg); /* Modal background */
color: var(--error-text);
box-shadow: 0px 0px 20px #888888;
box-shadow: 0 0 20px #888888;
border-radius: 10px;
top: 50%;
left: 50%;
@ -84,7 +84,7 @@ body {
font-size: 15px;
position: absolute;
top: 50%;
right: 0%;
right: 0;
text-align: center;
z-index: 100;
width: 170px;
@ -252,7 +252,7 @@ button.comfy-queue-btn {
bottom: 0 !important;
left: auto !important;
right: 0 !important;
border-radius: 0px;
border-radius: 0;
}
.comfy-menu span.drag-handle {
visibility:hidden
@ -289,6 +289,11 @@ button.comfy-queue-btn {
/* Context menu */
.litegraph .dialog {
z-index: 1;
font-family: Arial, sans-serif;
}
.litegraph .litemenu-entry.has_submenu {
position: relative;
padding-right: 20px;
@ -325,12 +330,20 @@ button.comfy-queue-btn {
color: var(--input-text) !important;
}
.comfy-context-menu-filter {
box-sizing: border-box;
border: 1px solid #999;
margin: 0 0 5px 5px;
width: calc(100% - 10px);
}
/* Search box */
.litegraph.litesearchbox {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
display: block;
}
.litegraph.litesearchbox input,