Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)

Merge branch 'comfyanonymous:master' into dpr

Commit 7352a6b41a
@@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
else:
    raise AssertionError('Unknown merge analysis result')

pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
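For context only (this is not part of the diff): the hunk above belongs to the standard pygit2 "pull" recipe. The sketch below is a minimal, hypothetical reconstruction of that pattern, not code taken from the repository, showing where the `AssertionError` branch and the owner-validation override usually sit.

```python
import sys
import pygit2

def pull(repo, remote_name='origin', branch='master'):
    """Minimal pygit2 pull: fetch, then fast-forward or merge the remote branch."""
    for remote in repo.remotes:
        if remote.name != remote_name:
            continue
        remote.fetch()
        remote_id = repo.lookup_reference(f'refs/remotes/{remote_name}/{branch}').target
        merge_result, _ = repo.merge_analysis(remote_id)
        if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
            return  # already current, nothing to do
        elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
            repo.checkout_tree(repo.get(remote_id))
            repo.head.set_target(remote_id)
        elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
            repo.merge(remote_id)  # a full script would also create the merge commit here
        else:
            raise AssertionError('Unknown merge analysis result')

# Disabling owner validation lets the updater operate on a repository owned by a
# different user, which is what the pygit2.option(...) line in the hunk does.
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
```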
.gitignore · 1 · vendored

@@ -9,3 +9,4 @@ custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs
.idea/
README.md · 72

@@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -37,28 +38,28 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git

## Shortcuts

| Keybind | Explanation |
| - | - |
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Keybind | Explanation |
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |

Ctrl can also be replaced with Cmd instead for MacOS users
Ctrl can also be replaced with Cmd instead for macOS users

# Installing

@@ -118,13 +119,26 @@ After this you should have everything installed and can proceed to running Comfy

### Others:

[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)
#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476)

Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own.
#### Apple Mac silicon

You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.

1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide.
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies).
1. Launch ComfyUI by running `python main.py`.

> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux).

#### DirectML (AMD Cards on Windows)

```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```

### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?

You don't. If you have another UI installed and working with it's own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:

```source path_to_other_sd_gui/venv/bin/activate```

@@ -134,7 +148,7 @@ With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```

With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```

And then you can use that terminal to run Comfyui without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.

# Running

@@ -158,6 +172,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or (

You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}.

Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`.

To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):

```embedding:embedding_filename.pt```
@@ -181,6 +197,12 @@ You can set this command line setting to disable the upcasting to fp32 in some c

```--dont-upcast-attention```

## How to show high-quality previews?

Use ```--preview-method auto``` to enable previews.

The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.

## Support and dev channel

[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).

@@ -1,4 +1,35 @@
import argparse
import enum


class EnumAction(argparse.Action):
    """
    Argparse action for handling Enums
    """
    def __init__(self, **kwargs):
        # Pop off the type value
        enum_type = kwargs.pop("type", None)

        # Ensure an Enum subclass is provided
        if enum_type is None:
            raise ValueError("type must be assigned an Enum when using EnumAction")
        if not issubclass(enum_type, enum.Enum):
            raise TypeError("type must be an Enum when using EnumAction")

        # Generate choices from the Enum
        choices = tuple(e.value for e in enum_type)
        kwargs.setdefault("choices", choices)
        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")

        super(EnumAction, self).__init__(**kwargs)

        self._enum = enum_type

    def __call__(self, parser, namespace, values, option_string=None):
        # Convert value back into an Enum
        value = self._enum(values)
        setattr(namespace, self.dest, value)


parser = argparse.ArgumentParser()

@@ -13,6 +44,14 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

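As a quick illustration (not part of the diff): `EnumAction` pops `type` off before the base `argparse.Action` is configured, so argparse validates the raw string against the enum *values* and `__call__` then hands back the enum *member*. The sketch below is hypothetical and assumes the `EnumAction` and `LatentPreviewMethod` definitions shown above are in scope (for example via `from comfy.cli_args import ...`).

```python
import argparse

# Assumes EnumAction and LatentPreviewMethod from the hunk above are importable.
parser = argparse.ArgumentParser()
parser.add_argument("--preview-method", type=LatentPreviewMethod,
                    default=LatentPreviewMethod.NoPreviews, action=EnumAction)

args = parser.parse_args(["--preview-method", "taesd"])
print(args.preview_method)         # LatentPreviewMethod.TAESD (an enum member, not a string)
print(args.preview_method.value)   # "taesd"
```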
@@ -1,14 +1,5 @@
import json
import os
import yaml

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file

# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py

@@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    if 'prediction_type' in diffusers_scheduler_conf:
        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = instantiate_from_config(config["model"])
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if fp16:
        model = model.half()

    return ModelPatcher(model), clip, vae

comfy/diffusers_load.py · 87 · Normal file

@@ -0,0 +1,87 @@
import json
import os
import yaml

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
import diffusers_convert

def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    if 'prediction_type' in diffusers_scheduler_conf:
        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
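For orientation (not part of the diff), a hypothetical call to the new loader might look like the sketch below. The directory path is a placeholder, and the three return values are assumed to mirror what `load_checkpoint` in `comfy/sd.py` produces.

```python
# Hypothetical usage of comfy/diffusers_load.py; the path is a placeholder for a
# diffusers-format checkpoint containing unet/, vae/, text_encoder/ and scheduler/.
from comfy.diffusers_load import load_diffusers

model_patcher, clip, vae = load_diffusers(
    "models/diffusers/some-sd15-checkpoint",   # placeholder directory
    fp16=True,
    embedding_directory="models/embeddings",
)
```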
comfy/model_base.py · 66 · Normal file

@@ -0,0 +1,66 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import numpy as np

class BaseModel(torch.nn.Module):
    def __init__(self, unet_config, v_prediction=False):
        super().__init__()

        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
        self.diffusion_model = UNetModel(**unet_config)
        self.v_prediction = v_prediction
        if self.v_prediction:
            self.parameterization = "v"
        else:
            self.parameterization = "eps"
        if "adm_in_channels" in unet_config:
            self.adm_channels = unet_config["adm_in_channels"]
        else:
            self.adm_channels = 0
        print("v_prediction", v_prediction)
        print("adm", self.adm_channels)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
        if c_concat is not None:
            xc = torch.cat([x] + c_concat, dim=1)
        else:
            xc = x
        context = torch.cat(c_crossattn, 1)
        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

class SD21UNCLIP(BaseModel):
    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
        super().__init__(unet_config, v_prediction)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

class SDInpaint(BaseModel):
    def __init__(self, unet_config, v_prediction=False):
        super().__init__(unet_config, v_prediction)
        self.concat_keys = ("mask", "masked_image")
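A brief aside (not from the diff): `register_schedule` precomputes the DDPM noise-schedule buffers. Assuming the default "linear" schedule is the usual LDM square-root interpolation, those buffers can be reproduced in a few lines.

```python
# Hypothetical sketch of the schedule math, assuming the LDM-style "linear"
# schedule (interpolated in sqrt space) with the defaults used by BaseModel.
import numpy as np

timesteps, linear_start, linear_end = 1000, 0.00085, 0.012
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2
alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

print(alphas_cumprod[0])    # close to 1.0: almost no noise at t=0
print(alphas_cumprod[-1])   # close to 0.0: nearly pure noise at t=T
```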
@ -1,23 +1,29 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import torch
|
||||
|
||||
class VRAMState(Enum):
|
||||
CPU = 0
|
||||
NO_VRAM = 1
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
MPS = 5
|
||||
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
|
||||
|
||||
class CPUState(Enum):
|
||||
GPU = 0
|
||||
CPU = 1
|
||||
MPS = 2
|
||||
|
||||
# Determine VRAM State
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
total_vram_available_mb = -1
|
||||
|
||||
accelerate_enabled = False
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
@ -31,30 +37,80 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import torch
|
||||
if directml_enabled:
|
||||
total_vram = 4097 #TODO
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||
except:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
if not args.normalvram and not args.cpu:
|
||||
if total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
cpu_state = CPUState.MPS
|
||||
except:
|
||||
pass
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if cpu_state == CPUState.MPS:
|
||||
return torch.device("mps")
|
||||
if cpu_state == CPUState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_total_torch = mem_total
|
||||
elif xpu_available:
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_cuda
|
||||
|
||||
if torch_total_too:
|
||||
return (mem_total, mem_total_torch)
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
if not args.normalvram and not args.cpu:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
|
||||
try:
|
||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||
except:
|
||||
@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
if args.lowvram:
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
lowvram_available = True
|
||||
elif args.novram:
|
||||
set_vram_to = VRAMState.NO_VRAM
|
||||
elif args.highvram:
|
||||
@ -102,54 +159,38 @@ if args.force_fp32:
|
||||
print("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
accelerate_enabled = True
|
||||
vram_state = set_vram_to
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
||||
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
|
||||
lowvram_available = False
|
||||
|
||||
total_vram_available_mb = (total_vram - 1024) // 2
|
||||
total_vram_available_mb = int(max(256, total_vram_available_mb))
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
vram_state = VRAMState.MPS
|
||||
except:
|
||||
pass
|
||||
if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
if args.cpu:
|
||||
vram_state = VRAMState.CPU
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
return "{}".format(device.type)
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
if device.type == "cuda":
|
||||
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
||||
else:
|
||||
return "{}".format(device.type)
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||
print("Device:", get_torch_device_name(get_torch_device()))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
@ -199,22 +240,29 @@ def load_model_gpu(model):
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
model.model_patches_to(get_torch_device())
|
||||
torch_dev = get_torch_device()
|
||||
model.model_patches_to(torch_dev)
|
||||
|
||||
vram_set_state = vram_state
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = model.model_size()
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
current_loaded_model = model
|
||||
if vram_state == VRAMState.CPU:
|
||||
|
||||
if vram_set_state == VRAMState.DISABLED:
|
||||
pass
|
||||
elif vram_state == VRAMState.MPS:
|
||||
mps_device = torch.device("mps")
|
||||
real_model.to(mps_device)
|
||||
pass
|
||||
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
real_model.to(get_torch_device())
|
||||
else:
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||
model_accelerated = True
|
||||
@ -223,7 +271,7 @@ def load_model_gpu(model):
|
||||
def load_controlnet_gpu(control_models):
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
if vram_state == VRAMState.CPU:
|
||||
if vram_state == VRAMState.DISABLED:
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
@ -268,7 +316,8 @@ def get_autocast_device(dev):
|
||||
def xformers_enabled():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if vram_state == VRAMState.CPU:
|
||||
global cpu_state
|
||||
if cpu_state != CPUState.GPU:
|
||||
return False
|
||||
if xpu_available:
|
||||
return False
|
||||
@ -340,12 +389,12 @@ def maximum_batch_area():
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.CPU
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
||||
def mps_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.MPS
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def should_use_fp16():
|
||||
global xpu_available
|
||||
@ -377,7 +426,10 @@ def should_use_fp16():
|
||||
|
||||
def soft_empty_cache():
|
||||
global xpu_available
|
||||
if xpu_available:
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
|
||||
@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
|
||||
def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
def encode_adm(conds, batch_size, device, noise_augmentor=None):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'adm' in x[1]:
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
noise_aug = []
|
||||
adm_in = x[1]["adm"]
|
||||
for adm_c in adm_in:
|
||||
adm_cond = adm_c[0].image_embeds
|
||||
weight = adm_c[1]
|
||||
noise_augment = adm_c[2]
|
||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
||||
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||
weights.append(weight)
|
||||
noise_aug.append(noise_augment)
|
||||
adm_inputs.append(adm_out)
|
||||
adm_out = None
|
||||
if noise_augmentor is not None:
|
||||
if 'adm' in x[1]:
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
noise_aug = []
|
||||
adm_in = x[1]["adm"]
|
||||
for adm_c in adm_in:
|
||||
adm_cond = adm_c[0].image_embeds
|
||||
weight = adm_c[1]
|
||||
noise_augment = adm_c[2]
|
||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
|
||||
adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
|
||||
weights.append(weight)
|
||||
noise_aug.append(noise_augment)
|
||||
adm_inputs.append(adm_out)
|
||||
|
||||
if len(noise_aug) > 1:
|
||||
adm_out = torch.stack(adm_inputs).sum(0)
|
||||
#TODO: add a way to control this
|
||||
noise_augment = 0.05
|
||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
||||
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
||||
if len(noise_aug) > 1:
|
||||
adm_out = torch.stack(adm_inputs).sum(0)
|
||||
#TODO: add a way to control this
|
||||
noise_augment = 0.05
|
||||
noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
|
||||
c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
|
||||
adm_out = torch.cat((c_adm, noise_level_emb), 1)
|
||||
else:
|
||||
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
|
||||
else:
|
||||
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
|
||||
x[1] = x[1].copy()
|
||||
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
|
||||
if 'adm' in x[1]:
|
||||
adm_out = x[1]["adm"].to(device)
|
||||
if adm_out is not None:
|
||||
x[1] = x[1].copy()
|
||||
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
|
||||
|
||||
return conds
|
||||
|
||||
@ -591,14 +597,17 @@ class KSampler:
|
||||
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||
if self.model.get_dtype() == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
if hasattr(self.model, 'noise_augmentor'): #unclip
|
||||
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
|
||||
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
|
||||
if self.model.is_adm():
|
||||
noise_augmentor = None
|
||||
if hasattr(self.model, 'noise_augmentor'): #unclip
|
||||
noise_augmentor = self.model.noise_augmentor
|
||||
positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
|
||||
negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||
|
||||
|
||||
142
comfy/sd.py
142
comfy/sd.py
@ -14,8 +14,16 @@ from .t2i_adapter import adapter
|
||||
from . import utils
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
from . import diffusers_convert
|
||||
from . import model_base
|
||||
|
||||
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
|
||||
for rp in replace_prefix:
|
||||
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
|
||||
for x in replace:
|
||||
sd[x[1]] = sd.pop(x[0])
|
||||
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
k = list(sd.keys())
|
||||
@ -30,17 +38,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||
if ids.dtype == torch.float32:
|
||||
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
keys_to_replace = {
|
||||
"cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
|
||||
"cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
|
||||
"cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
|
||||
}
|
||||
|
||||
for x in keys_to_replace:
|
||||
if x in sd:
|
||||
sd[keys_to_replace[x]] = sd.pop(x)
|
||||
|
||||
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
|
||||
|
||||
for x in load_state_dict_to:
|
||||
@ -192,7 +189,7 @@ def model_lora_keys(model, key_map={}):
|
||||
|
||||
counter = 0
|
||||
for b in range(12):
|
||||
tk = "model.diffusion_model.input_blocks.{}.1".format(b)
|
||||
tk = "diffusion_model.input_blocks.{}.1".format(b)
|
||||
up_counter = 0
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
@ -203,13 +200,13 @@ def model_lora_keys(model, key_map={}):
|
||||
if up_counter >= 4:
|
||||
counter += 1
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
|
||||
k = "diffusion_model.middle_block.1.{}.weight".format(c)
|
||||
if k in sdk:
|
||||
lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
|
||||
key_map[lora_key] = k
|
||||
counter = 3
|
||||
for b in range(12):
|
||||
tk = "model.diffusion_model.output_blocks.{}.1".format(b)
|
||||
tk = "diffusion_model.output_blocks.{}.1".format(b)
|
||||
up_counter = 0
|
||||
for c in LORA_UNET_MAP_ATTENTIONS:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
@ -233,7 +230,7 @@ def model_lora_keys(model, key_map={}):
|
||||
ds_counter = 0
|
||||
counter = 0
|
||||
for b in range(12):
|
||||
tk = "model.diffusion_model.input_blocks.{}.0".format(b)
|
||||
tk = "diffusion_model.input_blocks.{}.0".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
@ -252,7 +249,7 @@ def model_lora_keys(model, key_map={}):
|
||||
|
||||
counter = 0
|
||||
for b in range(3):
|
||||
tk = "model.diffusion_model.middle_block.{}".format(b)
|
||||
tk = "diffusion_model.middle_block.{}".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
@ -266,7 +263,7 @@ def model_lora_keys(model, key_map={}):
|
||||
counter = 0
|
||||
us_counter = 0
|
||||
for b in range(12):
|
||||
tk = "model.diffusion_model.output_blocks.{}.0".format(b)
|
||||
tk = "diffusion_model.output_blocks.{}.0".format(b)
|
||||
key_in = False
|
||||
for c in LORA_UNET_MAP_RESNET:
|
||||
k = "{}.{}.weight".format(tk, c)
|
||||
@ -285,15 +282,29 @@ def model_lora_keys(model, key_map={}):
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, size=0):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = []
|
||||
self.backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
size = 0
|
||||
for k in model_sd:
|
||||
t = model_sd[k]
|
||||
size += t.nelement() * t.element_size()
|
||||
self.size = size
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model)
|
||||
n = ModelPatcher(self.model, self.size)
|
||||
n.patches = self.patches[:]
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
return n
|
||||
@ -328,7 +339,7 @@ class ModelPatcher:
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
return self.model.diffusion_model.dtype
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength=1.0):
|
||||
p = {}
|
||||
@ -504,10 +515,16 @@ class VAE:
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
if ckpt_path is not None:
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
if device is None:
|
||||
device = model_management.get_torch_device()
|
||||
@ -600,7 +617,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model, device=None):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
@ -609,6 +626,7 @@ class ControlNet:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||
control_prev = None
|
||||
@ -644,6 +662,9 @@ class ControlNet:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control[i]
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
@ -674,7 +695,7 @@ class ControlNet:
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
return c
|
||||
@ -722,7 +743,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
use_spatial_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
use_checkpoint=False,
|
||||
legacy=False,
|
||||
use_fp16=use_fp16)
|
||||
else:
|
||||
@ -739,7 +760,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
use_checkpoint=False,
|
||||
legacy=False,
|
||||
use_fp16=use_fp16)
|
||||
if pth:
|
||||
@ -750,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
for x in controlnet_data:
|
||||
c_m = "control_model."
|
||||
if x.startswith(c_m):
|
||||
sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
|
||||
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
||||
if sd_key in model_sd:
|
||||
cd = controlnet_data[x]
|
||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||
@ -769,7 +790,11 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
|
||||
control = ControlNet(control_model)
|
||||
global_average_pooling = False
|
||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
return control
|
||||
|
||||
class T2IAdapter:
|
||||
@ -913,9 +938,10 @@ def load_gligen(ckpt_path):
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
if config is None:
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
@ -924,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
||||
fp16 = False
|
||||
if "unet_config" in model_config_params:
|
||||
if "params" in model_config_params["unet_config"]:
|
||||
if "use_fp16" in model_config_params["unet_config"]["params"]:
|
||||
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
|
||||
unet_config = model_config_params["unet_config"]["params"]
|
||||
if "use_fp16" in unet_config:
|
||||
fp16 = unet_config["use_fp16"]
|
||||
|
||||
noise_aug_config = None
|
||||
if "noise_aug_config" in model_config_params:
|
||||
noise_aug_config = model_config_params["noise_aug_config"]
|
||||
|
||||
v_prediction = False
|
||||
|
||||
if "parameterization" in model_config_params:
|
||||
if model_config_params["parameterization"] == "v":
|
||||
v_prediction = True
|
||||
|
||||
clip = None
|
||||
vae = None
|
||||
@ -945,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
model = instantiate_from_config(config["model"])
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
|
||||
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
|
||||
else:
|
||||
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = utils.load_torch_file(ckpt_path)
|
||||
model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
|
||||
if fp16:
|
||||
model = model.half()
|
||||
@ -1024,7 +1068,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
}
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": True,
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"out_channels": 4,
|
||||
"attention_resolutions": [
|
||||
@ -1044,47 +1088,59 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
"legacy": False
|
||||
}
|
||||
|
||||
if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2:
|
||||
if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2:
|
||||
unet_config['use_linear_in_transformer'] = True
|
||||
|
||||
unet_config["use_fp16"] = fp16
|
||||
unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0]
|
||||
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
|
||||
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
|
||||
unclip_model = False
|
||||
inpaint_model = False
|
||||
if noise_aug_config is not None: #SD2.x unclip model
|
||||
sd_config["noise_aug_config"] = noise_aug_config
|
||||
sd_config["image_size"] = 96
|
||||
sd_config["embedding_dropout"] = 0.25
|
||||
sd_config["conditioning_key"] = 'crossattn-adm'
|
||||
unclip_model = True
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||
elif unet_config["in_channels"] > 4: #inpainting model
|
||||
sd_config["conditioning_key"] = "hybrid"
|
||||
sd_config["finetune_keys"] = None
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
inpaint_model = True
|
||||
else:
|
||||
sd_config["conditioning_key"] = "crossattn"
|
||||
|
||||
if unet_config["context_dim"] == 1024:
|
||||
unet_config["num_head_channels"] = 64 #SD2.x
|
||||
else:
|
||||
if unet_config["context_dim"] == 768:
|
||||
unet_config["num_heads"] = 8 #SD1.x
|
||||
else:
|
||||
unet_config["num_head_channels"] = 64 #SD2.x
|
||||
|
||||
unclip = 'model.diffusion_model.label_emb.0.0.weight'
|
||||
if unclip in sd_keys:
|
||||
unet_config["num_classes"] = "sequential"
|
||||
unet_config["adm_in_channels"] = sd[unclip].shape[1]
|
||||
|
||||
v_prediction = False
|
||||
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
|
||||
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
||||
out = sd[k]
|
||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
v_prediction = True
|
||||
sd_config["parameterization"] = 'v'
|
||||
|
||||
model = instantiate_from_config(model_config)
|
||||
if inpaint_model:
|
||||
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
|
||||
elif unclip_model:
|
||||
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
|
||||
else:
|
||||
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
|
||||
|
||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
|
||||
if fp16:
|
||||
|
||||
@@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                        next_new_token += 1
                    else:
                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
            while len(tokens_temp) < len(x):
                tokens_temp += [self.empty_tokens[0][-1]]
            out_tokens += [tokens_temp]

        if len(embedding_weights) > 0:

65
comfy/taesd/taesd.py
Normal file
@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 4),
    )

def Decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )

class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
        if decoder_path is not None:
            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
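A minimal usage sketch (not part of the commit) for the decoder half of the file above, assuming the pretrained taesd_decoder.pth from the TAESD repository sits in the working directory; the latent is a random stand-in.

import torch
from comfy.taesd.taesd import TAESD

taesd = TAESD(encoder_path=None, decoder_path="taesd_decoder.pth")  # decoder only

with torch.no_grad():
    latent = torch.randn(1, 4, 64, 64)                 # stand-in for a real SD latent
    rgb = taesd.decoder(latent)[0]                     # output is roughly in [0, 1]
    img = rgb.clamp(0, 1).mul(255).byte().permute(1, 2, 0).cpu().numpy()  # HWC uint8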
143
comfy/utils.py
@ -1,11 +1,16 @@
import torch
import math
import struct

def load_torch_file(ckpt, safe_load=False):
    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        sd = safetensors.torch.load_file(ckpt, device="cpu")
    else:
        if safe_load:
            if not 'weights_only' in torch.load.__code__.co_varnames:
                print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
        else:
@ -19,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
    return sd

def transformers_convert(sd, prefix_from, prefix_to, number):
    keys_to_replace = {
        "{}.positional_embedding": "{}.embeddings.position_embedding.weight",
        "{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
        "{}.ln_final.weight": "{}.final_layer_norm.weight",
        "{}.ln_final.bias": "{}.final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
@ -46,71 +63,87 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
            sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
    return sd

#slow and inefficient, should be optimized
def safetensors_header(safetensors_path, max_size=100*1024*1024):
    with open(safetensors_path, "rb") as f:
        header = f.read(8)
        length_of_header = struct.unpack('<Q', header)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)
Removed implementation of bislerp (the per-pixel loop version):

def bislerp(samples, width, height):
    shape = list(samples.shape)
    width_scale = (shape[3]) / (width)
    height_scale = (shape[2]) / (height)

    shape[3] = width
    shape[2] = height
    out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device)

    def algorithm(in1, in2, t):
        dims = in1.shape
        val = t

        #flatten to batches
        low = in1.reshape(dims[0], -1)
        high = in2.reshape(dims[0], -1)

        low_weight = torch.norm(low, dim=1, keepdim=True)
        low_weight[low_weight == 0] = 0.0000000001
        low_norm = low/low_weight
        high_weight = torch.norm(high, dim=1, keepdim=True)
        high_weight[high_weight == 0] = 0.0000000001
        high_norm = high/high_weight

        dot_prod = (low_norm*high_norm).sum(1)
        dot_prod[dot_prod > 0.9995] = 0.9995
        dot_prod[dot_prod < -0.9995] = -0.9995
        omega = torch.acos(dot_prod)
        so = torch.sin(omega)
        res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
        res *= (low_weight * (1.0-val) + high_weight * val)
        return res.reshape(dims)

    for x_dest in range(shape[3]):
        for y_dest in range(shape[2]):
            y = (y_dest + 0.5) * height_scale - 0.5
            x = (x_dest + 0.5) * width_scale - 0.5

            x1 = max(math.floor(x), 0)
            x2 = min(x1 + 1, samples.shape[3] - 1)
            wx = x - math.floor(x)

            y1 = max(math.floor(y), 0)
            y2 = min(y1 + 1, samples.shape[2] - 1)
            wy = y - math.floor(y)

            in1 = samples[:,:,y1,x1]
            in2 = samples[:,:,y1,x2]
            in3 = samples[:,:,y2,x1]
            in4 = samples[:,:,y2,x2]

            if (x1 == x2) and (y1 == y2):
                out_value = in1
            elif (x1 == x2):
                out_value = algorithm(in1, in3, wy)
            elif (y1 == y2):
                out_value = algorithm(in1, in2, wx)
            else:
                o1 = algorithm(in1, in2, wx)
                o2 = algorithm(in3, in4, wx)
                out_value = algorithm(o1, o2, wy)
            out1[:,:,y_dest,x_dest] = out_value
    return out1

Replacement implementation of bislerp (fully vectorized):

def bislerp(samples, width, height):
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new):
        coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
        coords_2[:,:,:,-1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    n,c,h,w = samples.shape
    h_new, w_new = (height, width)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result
def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
@ -176,14 +209,14 @@ class ProgressBar:
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK

    def update_absolute(self, value, total=None):                   (removed)
    def update_absolute(self, value, total=None, preview=None):     (added)
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total)            (removed)
            self.hook(self.current, self.total, preview)   (added)

    def update(self, value):
        self.update_absolute(self.current + value)
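A sketch (not from the commit) of a global hook that uses the new third argument; set_progress_bar_global_hook is the registration helper that main.py calls further below.

import comfy.utils

def my_hook(value, total, preview):
    # preview is either None or the encoded preview image bytes
    print(f"step {value}/{total}", "no preview" if preview is None else f"{len(preview)} preview bytes")

comfy.utils.set_progress_bar_global_hook(my_hook)
pbar = comfy.utils.ProgressBar(20)
pbar.update_absolute(5)   # the hook is called as my_hook(5, 20, None)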
@ -167,7 +167,7 @@ class MaskComposite:
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract"],),                       (removed)
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),   (added)
            }
        }

@ -193,6 +193,12 @@ class MaskComposite:
            output[top:bottom, left:right] = destination_portion + source_portion
        elif operation == "subtract":
            output[top:bottom, left:right] = destination_portion - source_portion
        elif operation == "and":
            output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "or":
            output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "xor":
            output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()

        output = torch.clamp(output, 0.0, 1.0)
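A worked example (illustrative values, not from the commit) of how the new boolean operations behave: both masks are rounded to {0, 1}, combined bitwise, then cast back to float.

import torch

destination = torch.tensor([[0.0, 1.0], [0.7, 0.2]])
source = torch.tensor([[1.0, 1.0], [0.4, 0.9]])
xor = torch.bitwise_xor(destination.round().bool(), source.round().bool()).float()
# tensor([[1., 0.], [1., 1.]])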
445
execution.py
@ -102,13 +102,21 @@ def get_output_data(obj, input_data_all):
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
return None
|
||||
elif isinstance(x, (int, float, bool, str)):
|
||||
return x
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@ -117,22 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
input_data_formatted = {}
|
||||
if input_data_all is not None:
|
||||
input_data_formatted = {}
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
}
|
||||
return (False, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
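A sketch of how a caller consumes the new (success, error_details, exception) return value of recursive_execute; the surrounding variables are assumed to be in scope.

success, error, ex = recursive_execute(server, prompt, outputs, node_id,
                                       extra_data, executed, prompt_id, outputs_ui)
if success is not True:
    # error always carries the failing node id; the exception fields are only
    # filled in for real errors, not for user interruptions
    print("failed at node", error["node_id"], error.get("exception_message", "(interrupted)"))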
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -210,6 +260,48 @@ class PromptExecutor:
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
node_id = error["node_id"]
|
||||
class_type = prompt[node_id]["class_type"]
|
||||
|
||||
# First, send back the status to the frontend depending
|
||||
# on the exception type
|
||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
}
|
||||
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
||||
else:
|
||||
if self.server.client_id is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
}
|
||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
@ -244,42 +336,30 @@ class PromptExecutor:
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in list(execute_outputs):
|
||||
to_execute += [(0, x)]
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
x = to_execute.pop(0)[-1]
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
except Exception as e:
|
||||
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||
print("Processing interrupted")
|
||||
else:
|
||||
message = str(traceback.format_exc())
|
||||
print(message)
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
finally:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
if success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
@ -297,57 +377,202 @@ def validate_inputs(prompt, item, validated):
|
||||
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
for x in required_inputs:
|
||||
if x not in inputs:
|
||||
return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id)
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
return (False, "Bad Input. {}, {}".format(class_type, x), unique_id)
|
||||
error = {
|
||||
"type": "bad_linked_input",
|
||||
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id)
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] == False:
|
||||
validated[o_id] = r
|
||||
return r
|
||||
received_type = r[val[1]]
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_type": received_type,
|
||||
"linked_node": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
continue
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_inner_validation",
|
||||
"message": "Exception when validating inner node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"linked_node": val
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
continue
|
||||
else:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
try:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
"message": f"Failed to convert an input value to a {type_input} value",
|
||||
"details": f"{x}, {val}, {ex}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
"exception_message": str(ex)
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id)
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id)
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
return (False, "{}, {}".format(class_type, r), unique_id)
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id)
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
# Don't send back gigantic lists like if they're lots of
|
||||
# scanned model filepaths
|
||||
if len(type_input) > 20:
|
||||
list_info = f"(list of length {len(type_input)})"
|
||||
input_config = None
|
||||
else:
|
||||
list_info = str(type_input)
|
||||
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
else:
|
||||
ret = (True, [], unique_id)
|
||||
|
||||
ret = (True, "", unique_id)
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
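For reference, a single validation error emitted above has this shape (the values here are made up):

error = {
    "type": "value_not_in_list",
    "message": "Value not in list",
    "details": "ckpt_name: 'missing.safetensors' not in (list of length 42)",
    "extra_info": {
        "input_name": "ckpt_name",
        "input_config": None,
        "received_value": "missing.safetensors",
    },
}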
def full_type_name(klass):
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -356,7 +581,13 @@ def validate_prompt(prompt):
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
return (False, "Prompt has no outputs", [], [])
|
||||
error = {
|
||||
"type": "prompt_no_outputs",
|
||||
"message": "Prompt has no outputs",
|
||||
"details": "",
|
||||
"extra_info": {}
|
||||
}
|
||||
return (False, error, [], [])
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
@ -364,34 +595,72 @@ def validate_prompt(prompt):
|
||||
validated = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reason = ""
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
node_id = m[2]
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
node_id = None
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_validation",
|
||||
"message": "Exception when validating node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb)
|
||||
}
|
||||
}]
|
||||
validated[o] = (False, reasons, o)
|
||||
|
||||
if valid == True:
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||
print("output will be ignored")
|
||||
errors += [(o, reason)]
|
||||
if node_id is not None:
|
||||
if node_id not in node_errors:
|
||||
node_errors[node_id] = {"message": reason, "dependent_outputs": []}
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print(f"Failed to validate prompt for output {o}:")
|
||||
if len(reasons) > 0:
|
||||
print("* (prompt):")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
errors += [(o, reasons)]
|
||||
for node_id, result in validated.items():
|
||||
valid = result[0]
|
||||
reasons = result[1]
|
||||
# If a node upstream has errors, the nodes downstream will also
|
||||
# be reported as invalid, but there will be no errors attached.
|
||||
# So don't return those nodes as having errors in the response.
|
||||
if valid is not True and len(reasons) > 0:
|
||||
if node_id not in node_errors:
|
||||
class_type = prompt[node_id]['class_type']
|
||||
node_errors[node_id] = {
|
||||
"errors": reasons,
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type
|
||||
}
|
||||
print(f"* {class_type} {node_id}:")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print("Output will be ignored")
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors)
|
||||
errors_list = []
|
||||
for o, errors in errors:
|
||||
for error in errors:
|
||||
errors_list.append(f"{error['message']}: {error['details']}")
|
||||
errors_list = "\n".join(errors_list)
|
||||
|
||||
return (True, "", list(good_outputs), node_errors)
|
||||
error = {
|
||||
"type": "prompt_outputs_failed_validation",
|
||||
"message": "Prompt outputs failed validation",
|
||||
"details": errors_list,
|
||||
"extra_info": {}
|
||||
}
|
||||
|
||||
return (False, error, list(good_outputs), node_errors)
|
||||
|
||||
return (True, None, list(good_outputs), node_errors)
|
||||
|
||||
|
||||
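A sketch of how a caller uses the new validate_prompt return shape (the /prompt route in server.py below does exactly this); prompt, number and the queue are assumed to be in scope.

valid = validate_prompt(prompt)
if valid[0]:
    prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
else:
    # valid[1] is a structured error dict, valid[3] maps node ids to their errors
    response = {"error": valid[1], "node_errors": valid[3]}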
class PromptQueue:
|
||||
|
||||
@ -1,14 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions.add('.safetensors')
|
||||
supported_pt_extensions.add('.safetensors')
|
||||
except:
|
||||
print("Could not import safetensors, safetensors support disabled.")
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
|
||||
|
||||
folder_names_and_paths = {}
|
||||
|
||||
@ -24,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
|
||||
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
|
||||
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
|
||||
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
|
||||
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
|
||||
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
|
||||
@ -38,6 +33,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
|
||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
|
||||
filename_list_cache = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
os.makedirs(input_directory)
|
||||
|
||||
@ -118,12 +115,18 @@ def get_folder_paths(folder_name):
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory):
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
result = []
|
||||
dirs = {directory: os.path.getmtime(directory)}
|
||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||
for filepath in file:
|
||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||
return result
|
||||
for d in subdir:
|
||||
path = os.path.join(root, d)
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
@ -132,20 +135,58 @@ def filter_files_extensions(files, extensions):
|
||||
|
||||
def get_full_path(folder_name, filename):
|
||||
global folder_names_and_paths
|
||||
if folder_name not in folder_names_and_paths:
|
||||
return None
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
filename = os.path.relpath(os.path.join("/", filename), "/")
|
||||
for x in folders[0]:
|
||||
full_path = os.path.join(x, filename)
|
||||
if os.path.isfile(full_path):
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
def get_filename_list_(folder_name):
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
output_folders = {}
|
||||
for x in folders[0]:
|
||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||
return sorted(list(output_list))
|
||||
files, folders_all = recursive_search(x)
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return (sorted(list(output_list)), output_folders, time.perf_counter())
|
||||
|
||||
def cached_filename_list_(folder_name):
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
if time.perf_counter() < (out[2] + 0.5):
|
||||
return out
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
if os.path.isdir(x):
|
||||
if x not in out[1]:
|
||||
return None
|
||||
|
||||
return out
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
global filename_list_cache
|
||||
filename_list_cache[folder_name] = out
|
||||
return list(out[0])
|
||||
|
||||
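A sketch of the new caching behaviour: get_filename_list_ returns a (filenames, directory mtimes, timestamp) triple, repeat calls within half a second reuse it, and a changed directory mtime invalidates it. The folder name below is only an example.

import folder_paths

names = folder_paths.get_filename_list("checkpoints")        # scans the disk and fills the cache
names_again = folder_paths.get_filename_list("checkpoints")  # served from the cache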
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
|
||||
95
latent_preview.py
Normal file
@ -0,0 +1,95 @@
import torch
from PIL import Image, ImageOps
from io import BytesIO
import struct
import numpy as np

from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths

MAX_PREVIEW_RESOLUTION = 512

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)

        preview_type = 1
        if preview_format == "JPEG":
            preview_type = 1
        elif preview_format == "PNG":
            preview_type = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", preview_type)
        bytesIO.write(header)
        preview_image.save(bytesIO, format=preview_format, quality=95)
        preview_bytes = bytesIO.getvalue()
        return preview_bytes

class TAESDPreviewerImpl(LatentPreviewer):
    def __init__(self, taesd):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        x_sample = self.taesd.decoder(x0)[0].detach()
        # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
        x_sample = x_sample.sub(0.5).mul(2)

        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)

        preview_image = Image.fromarray(x_sample)
        return preview_image


class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self):
        self.latent_rgb_factors = torch.tensor([
            #     R        G        B
            [ 0.298,  0.207,  0.208],  # L1
            [ 0.187,  0.286,  0.173],  # L2
            [-0.158,  0.189,  0.264],  # L3
            [-0.184, -0.271, -0.473],  # L4
        ], device="cpu")

    def decode_latent_to_preview(self, x0):
        latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors

        latents_ubyte = (((latent_image + 1) / 2)
                         .clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)    # to 0..255
                         .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())


def get_previewer(device):
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer methods
        taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB
            if taesd_decoder_path:
                method = LatentPreviewMethod.TAESD

        if method == LatentPreviewMethod.TAESD:
            if taesd_decoder_path:
                taesd = TAESD(None, taesd_decoder_path).to(device)
                previewer = TAESDPreviewerImpl(taesd)
            else:
                print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")

        if previewer is None:
            previewer = Latent2RGBPreviewer()
    return previewer
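A usage sketch (not part of the commit) mirroring what the KSampler callback in nodes.py does; it assumes previews were enabled with --preview-method (get_previewer returns None otherwise) and uses a random latent in place of x0.

import torch
import latent_preview

device = torch.device("cpu")
previewer = latent_preview.get_previewer(device)   # TAESD if the decoder model is installed, else latent-to-RGB
x0 = torch.randn(1, 4, 64, 64)
preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)  # 4-byte format header + JPEG data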
25
main.py
@ -26,6 +26,7 @@ import yaml
|
||||
import execution
|
||||
import folder_paths
|
||||
import server
|
||||
from server import BinaryEventTypes
|
||||
from nodes import init_custom_nodes
|
||||
|
||||
|
||||
@ -36,19 +37,25 @@ def prompt_worker(q, server):
|
||||
e.execute(item[2], item[1], item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
||||
|
||||
def hijack_progress(server):
|
||||
def hook(value, total):
|
||||
server.send_sync("progress", { "value": value, "max": total}, server.client_id)
|
||||
def hook(value, total, preview_image_bytes):
|
||||
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
|
||||
if preview_image_bytes is not None:
|
||||
server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
|
||||
comfy.utils.set_progress_bar_global_hook(hook)
|
||||
|
||||
|
||||
def cleanup_temp():
|
||||
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_extra_path_config(yaml_path):
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
@ -69,6 +76,7 @@ def load_extra_path_config(yaml_path):
|
||||
print("Adding extra search path", x, full_path)
|
||||
folder_paths.add_model_folder_path(x, full_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cleanup_temp()
|
||||
|
||||
@ -89,7 +97,7 @@ if __name__ == "__main__":
|
||||
server.add_routes()
|
||||
hijack_progress(server)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
|
||||
|
||||
if args.output_directory:
|
||||
output_dir = os.path.abspath(args.output_directory)
|
||||
@ -103,15 +111,12 @@ if __name__ == "__main__":
|
||||
if args.auto_launch:
|
||||
def startup_server(address, port):
|
||||
import webbrowser
|
||||
webbrowser.open("http://{}:{}".format(address, port))
|
||||
webbrowser.open(f"http://{address}:{port}")
|
||||
call_on_start = startup_server
|
||||
|
||||
if os.name == "nt":
|
||||
try:
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopped server")
|
||||
|
||||
cleanup_temp()
|
||||
|
||||
35
nodes.py
@ -13,11 +13,10 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
|
||||
import comfy.diffusers_convert
|
||||
import comfy.diffusers_load
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
import comfy.sd
|
||||
@ -29,7 +28,7 @@ import comfy.model_management
|
||||
import importlib
|
||||
|
||||
import folder_paths
|
||||
|
||||
import latent_preview
|
||||
|
||||
def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
@ -248,7 +247,6 @@ class VAEEncodeForInpaint:
|
||||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
@ -377,7 +375,7 @@ class DiffusersLoader:
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
|
||||
class unCLIPCheckpointLoader:
|
||||
@ -426,6 +424,9 @@ class LoraLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (model, clip)
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
@ -507,6 +508,9 @@ class ControlNetApply:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
for t in conditioning:
|
||||
@ -613,6 +617,9 @@ class unCLIPConditioning:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
for t in conditioning:
|
||||
o = t[1].copy()
|
||||
@ -922,6 +929,7 @@ class SetLatentNoiseMask:
|
||||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
return (s,)
|
||||
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
latent_image = latent["samples"]
|
||||
@ -936,9 +944,18 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
preview_format = "JPEG"
|
||||
if preview_format not in ["JPEG", "PNG"]:
|
||||
preview_format = "JPEG"
|
||||
|
||||
previewer = latent_preview.get_previewer(device)
|
||||
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
def callback(step, x0, x, total_steps):
|
||||
pbar.update_absolute(step + 1, total_steps)
|
||||
preview_bytes = None
|
||||
if previewer:
|
||||
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
@ -961,7 +978,8 @@ class KSampler:
|
||||
"negative": ("CONDITIONING", ),
|
||||
"latent_image": ("LATENT", ),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "sample"
|
||||
@ -988,7 +1006,8 @@ class KSamplerAdvanced:
|
||||
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
||||
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
||||
"return_with_leftover_noise": (["disable", "enable"], ),
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "sample"
|
||||
|
||||
107
server.py
@ -7,6 +7,7 @@ import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
import struct
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
@ -22,6 +23,12 @@ except ImportError:
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
|
||||
|
||||
@web.middleware
|
||||
@ -216,6 +223,27 @@ class PromptServer():
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if 'preview' in request.rel_url.query:
|
||||
with Image.open(file) as img:
|
||||
preview_info = request.rel_url.query['preview'].split(';')
|
||||
|
||||
image_format = preview_info[0]
|
||||
if image_format not in ['webp', 'jpeg']:
|
||||
image_format = 'webp'
|
||||
|
||||
quality = 90
|
||||
if preview_info[-1].isdigit():
|
||||
quality = int(preview_info[-1])
|
||||
|
||||
buffer = BytesIO()
|
||||
if image_format in ['jpeg']:
|
||||
img = img.convert("RGB")
|
||||
img.save(buffer, format=image_format, quality=quality)
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
else:
|
||||
@ -257,6 +285,50 @@ class PromptServer():
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/view_metadata/{folder_name}")
|
||||
async def view_metadata(request):
|
||||
folder_name = request.match_info.get("folder_name", None)
|
||||
if folder_name is None:
|
||||
return web.Response(status=404)
|
||||
if not "filename" in request.rel_url.query:
|
||||
return web.Response(status=404)
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
if not filename.endswith(".safetensors"):
|
||||
return web.Response(status=404)
|
||||
|
||||
safetensors_path = folder_paths.get_full_path(folder_name, filename)
|
||||
if safetensors_path is None:
|
||||
return web.Response(status=404)
|
||||
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
|
||||
if out is None:
|
||||
return web.Response(status=404)
|
||||
dt = json.loads(out)
|
||||
if not "__metadata__" in dt:
|
||||
return web.Response(status=404)
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
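A client-side sketch (not part of the commit) for the metadata endpoint above; host, port, folder name and filename are only examples.

import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({"filename": "example_lora.safetensors"})
with urllib.request.urlopen(f"http://127.0.0.1:8188/view_metadata/loras?{params}") as resp:
    metadata = json.load(resp)   # the "__metadata__" dict from the safetensors header, or HTTP 404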
@routes.get("/system_stats")
|
||||
async def get_queue(request):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
system_stats = {
|
||||
"devices": [
|
||||
{
|
||||
"name": device_name,
|
||||
"type": device.type,
|
||||
"index": device.index,
|
||||
"vram_total": vram_total,
|
||||
"vram_free": vram_free,
|
||||
"torch_vram_total": torch_vram_total,
|
||||
"torch_vram_free": torch_vram_free,
|
||||
}
|
||||
]
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
@ -338,7 +410,7 @@ class PromptServer():
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
return web.json_response({"prompt_id": prompt_id})
|
||||
return web.json_response({"prompt_id": prompt_id, "number": number})
|
||||
else:
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
@ -391,16 +463,37 @@ class PromptServer():
|
||||
return prompt_info
|
||||
|
||||
async def send(self, event, data, sid=None):
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if isinstance(message, str) == False:
|
||||
message = json.dumps(message)
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
await self.send_bytes(event, data, sid)
|
||||
else:
|
||||
await self.send_json(event, data, sid)
|
||||
|
||||
def encode_bytes(self, event, data):
|
||||
if not isinstance(event, int):
|
||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||
|
||||
packed = struct.pack(">I", event)
|
||||
message = bytearray(packed)
|
||||
message.extend(data)
|
||||
return message
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
await ws.send_str(message)
|
||||
await ws.send_bytes(message)
|
||||
elif sid in self.sockets:
|
||||
await self.sockets[sid].send_str(message)
|
||||
await self.sockets[sid].send_bytes(message)
|
||||
|
||||
async def send_json(self, event, data, sid=None):
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
await ws.send_json(message)
|
||||
elif sid in self.sockets:
|
||||
await self.sockets[sid].send_json(message)
|
||||
|
||||
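A sketch of how a client can take apart the binary frames produced by encode_bytes above: a 4-byte big-endian event type (1 = PREVIEW_IMAGE) followed by the payload, which for previews itself starts with the 4-byte image-format code (1 = JPEG, 2 = PNG) written by latent_preview.py.

import struct

def parse_binary_event(message: bytes):
    event_type = struct.unpack(">I", message[:4])[0]
    payload = message[4:]
    if event_type == 1:  # BinaryEventTypes.PREVIEW_IMAGE
        image_format = struct.unpack(">I", payload[:4])[0]
        return event_type, image_format, payload[4:]
    return event_type, None, payload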
def send_sync(self, event, data, sid=None):
|
||||
self.loop.call_soon_threadsafe(
|
||||
|
||||
@ -21,6 +21,7 @@ const colorPalettes = {
|
||||
"MODEL": "#B39DDB", // light lavender-purple
|
||||
"STYLE_MODEL": "#C2FFAE", // light green-yellow
|
||||
"VAE": "#FF6E6E", // bright red
|
||||
"TAESD": "#DCC274", // cheesecake
|
||||
},
|
||||
"litegraph_base": {
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
|
||||
@ -1,132 +1,138 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import {app} from "/scripts/app.js";
|
||||
|
||||
// Adds filtering to combo context menus
|
||||
|
||||
const id = "Comfy.ContextMenuFilter";
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
const ext = {
|
||||
name: "Comfy.ContextMenuFilter",
|
||||
init() {
|
||||
const ctxMenu = LiteGraph.ContextMenu;
|
||||
|
||||
LiteGraph.ContextMenu = function (values, options) {
|
||||
const ctx = ctxMenu.call(this, values, options);
|
||||
|
||||
// If we are a dark menu (only used for combo boxes) then add a filter input
|
||||
if (options?.className === "dark" && values?.length > 10) {
|
||||
const filter = document.createElement("input");
|
||||
Object.assign(filter.style, {
|
||||
width: "calc(100% - 10px)",
|
||||
border: "0",
|
||||
boxSizing: "border-box",
|
||||
background: "#333",
|
||||
border: "1px solid #999",
|
||||
margin: "0 0 5px 5px",
|
||||
color: "#fff",
|
||||
});
|
||||
filter.classList.add("comfy-context-menu-filter");
|
||||
filter.placeholder = "Filter list";
|
||||
this.root.prepend(filter);
|
||||
|
||||
let selectedIndex = 0;
|
||||
let items = this.root.querySelectorAll(".litemenu-entry");
|
||||
let itemCount = items.length;
|
||||
let selectedItem;
|
||||
const items = Array.from(this.root.querySelectorAll(".litemenu-entry"));
|
||||
let displayedItems = [...items];
|
||||
let itemCount = displayedItems.length;
|
||||
|
||||
// Apply highlighting to the selected item
|
||||
function updateSelected() {
|
||||
if (selectedItem) {
|
||||
selectedItem.style.setProperty("background-color", "");
|
||||
selectedItem.style.setProperty("color", "");
|
||||
}
|
||||
selectedItem = items[selectedIndex];
|
||||
if (selectedItem) {
|
||||
selectedItem.style.setProperty("background-color", "#ccc", "important");
|
||||
selectedItem.style.setProperty("color", "#000", "important");
|
||||
}
|
||||
}
|
||||
// We must request an animation frame for the current node of the active canvas to update.
|
||||
requestAnimationFrame(() => {
|
const currentNode = LGraphCanvas.active_canvas.current_node;
const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i]))
.value;

const positionList = () => {
const rect = this.root.getBoundingClientRect();

// If the top is off screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}

updateSelected();

// Arrow up/down to select items
filter.addEventListener("keydown", (e) => {
if (e.key === "ArrowUp") {
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
e.preventDefault();
} else if (e.key === "ArrowDown") {
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
e.preventDefault();
} else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) {
selectedItem.click();
} else if(e.key === "Escape") {
this.close();
}
});

filter.addEventListener("input", () => {
// Hide all items that dont match our filter
const term = filter.value.toLocaleLowerCase();
items = this.root.querySelectorAll(".litemenu-entry");
// When filtering recompute which items are visible for arrow up/down
// Try and maintain selection
let visibleItems = [];
for (const item of items) {
const visible = !term || item.textContent.toLocaleLowerCase().includes(term);
if (visible) {
item.style.display = "block";
if (item === selectedItem) {
selectedIndex = visibleItems.length;
}
visibleItems.push(item);
} else {
item.style.display = "none";
if (item === selectedItem) {
selectedIndex = 0;
}
}
}
items = visibleItems;
let selectedIndex = values.findIndex(v => v === clickedComboValue);
let selectedItem = displayedItems?.[selectedIndex];
updateSelected();

// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;

const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}

this.root.style.top = top + "px";
positionList();
// Apply highlighting to the selected item
function updateSelected() {
selectedItem?.style.setProperty("background-color", "");
selectedItem?.style.setProperty("color", "");
selectedItem = displayedItems[selectedIndex];
selectedItem?.style.setProperty("background-color", "#ccc", "important");
selectedItem?.style.setProperty("color", "#000", "important");
}
});

requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();
const positionList = () => {
const rect = this.root.getBoundingClientRect();

positionList();
});
// If the top is off-screen then shift the element with scaling applied
if (rect.top < 0) {
const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight;
const shift = (this.root.clientHeight * scale) / 2;
this.root.style.top = -shift + "px";
}
}

// Arrow up/down to select items
filter.addEventListener("keydown", (event) => {
switch (event.key) {
case "ArrowUp":
event.preventDefault();
if (selectedIndex === 0) {
selectedIndex = itemCount - 1;
} else {
selectedIndex--;
}
updateSelected();
break;
case "ArrowRight":
event.preventDefault();
selectedIndex = itemCount - 1;
updateSelected();
break;
case "ArrowDown":
event.preventDefault();
if (selectedIndex === itemCount - 1) {
selectedIndex = 0;
} else {
selectedIndex++;
}
updateSelected();
break;
case "ArrowLeft":
event.preventDefault();
selectedIndex = 0;
updateSelected();
break;
case "Enter":
selectedItem?.click();
break;
case "Escape":
this.close();
break;
}
});

filter.addEventListener("input", () => {
// Hide all items that don't match our filter
const term = filter.value.toLocaleLowerCase();
// When filtering, recompute which items are visible for arrow up/down and maintain selection.
displayedItems = items.filter(item => {
const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term);
item.style.display = isVisible ? "block" : "none";
return isVisible;
});

selectedIndex = 0;
if (displayedItems.includes(selectedItem)) {
selectedIndex = displayedItems.findIndex(d => d === selectedItem);
}
itemCount = displayedItems.length;

updateSelected();

// If we have an event then we can try and position the list under the source
if (options.event) {
let top = options.event.clientY - 10;

const bodyRect = document.body.getBoundingClientRect();
const rootRect = this.root.getBoundingClientRect();
if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) {
top = Math.max(0, bodyRect.height - rootRect.height - 10);
}

this.root.style.top = top + "px";
positionList();
}
});

requestAnimationFrame(() => {
// Focus the filter box when opening
filter.focus();

positionList();
});
})
}

return ctx;
@ -134,4 +140,6 @@ app.registerExtension({

LiteGraph.ContextMenu.prototype = ctxMenu.prototype;
},
});
}

app.registerExtension(ext);

@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js";
// Allows for simple dynamic prompt replacement
// Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued.

/*
 * Strips C-style line and block comments from a string
 */
function stripComments(str) {
return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,'');
}

app.registerExtension({
name: "Comfy.DynamicPrompts",
nodeCreated(node) {
@ -15,7 +22,7 @@ app.registerExtension({
for (const widget of widgets) {
// Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node
widget.serializeValue = (workflowNode, widgetIndex) => {
let prompt = widget.value;
let prompt = stripComments(widget.value);
while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) {
const startIndex = prompt.replace("\\{", "00").indexOf("{");
const endIndex = prompt.replace("\\}", "00").indexOf("}");

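To make the `{a|b}` behaviour described in the comments above concrete, here is a minimal sketch (a hypothetical helper, not part of this commit) that strips comments and then resolves each bracketed group to one randomly chosen option, which is the effect the overridden serializeValue produces when a prompt is queued:

// Minimal sketch of the {a|b} resolution described above; resolveDynamicPrompt
// is a hypothetical helper, not part of this commit.
function resolveDynamicPrompt(text) {
	// Strip C-style line and block comments, as the extension now does via stripComments().
	const prompt = text.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g, "");
	// Replace each {a|b|c} group with one randomly chosen option, then tidy the ends.
	return prompt
		.replace(/\{([^{}]+)\}/g, (_, group) => {
			const options = group.split("|");
			return options[Math.floor(Math.random() * options.length)];
		})
		.trim();
}

// Prints either "a photo of a cat" or "a photo of a dog" (the comment is stripped).
console.log(resolveDynamicPrompt("a photo of a {cat|dog} // favourite subject"));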
@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) {
});

ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam();

if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);

// update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});

@ -335,6 +335,7 @@ class MaskEditorDialog extends ComfyDialog {

const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url;

@ -345,6 +346,7 @@ class MaskEditorDialog extends ComfyDialog {

const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.delete('preview');
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;

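The alpha/rgb reloads above amount to rewriting the query parameters of the clipspace image URL. A standalone sketch of the same idea (the helper name and sample path are illustrative, not from the commit):

// Sketch only: derive channel-specific /view URLs the way the mask editor does above.
// baseSrc stands in for ComfyApp.clipspace.imgs[selectedIndex].src.
function channelUrl(baseSrc, channel) {
	const url = new URL(baseSrc, window.location.origin);
	url.searchParams.delete("channel");
	url.searchParams.delete("preview"); // drop any preview conversion so the original image is loaded
	url.searchParams.set("channel", channel);
	return url.toString();
}

// Alpha channel for the mask layer, rgb for the underlying image.
const alphaSrc = channelUrl("/view?filename=clipspace-mask.png&type=input", "a");
const rgbSrc = channelUrl("/view?filename=clipspace-mask.png&type=input", "rgb");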
@ -200,8 +200,23 @@ app.registerExtension({
applyToGraph() {
if (!this.outputs[0].links?.length) return;

function get_links(node) {
let links = [];
for (const l of node.outputs[0].links) {
const linkInfo = app.graph.links[l];
const n = node.graph.getNodeById(linkInfo.target_id);
if (n.type == "Reroute") {
links = links.concat(get_links(n));
} else {
links.push(l);
}
}
return links;
}

let links = get_links(this);
// For each output link copy our value over the original widget value
for (const l of this.outputs[0].links) {
for (const l of links) {
const linkInfo = app.graph.links[l];
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];

@ -14,5 +14,5 @@
window.graph = app.graph;
</script>
</head>
<body></body>
<body class="litegraph"></body>
</html>

@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
if (this.onShowNodePanel) {
this.onShowNodePanel(n);
}
else
{
this.showShowNodePanel(n);
}

if (this.onNodeDblClicked) {
this.onNodeDblClicked(n);
@ -8099,11 +8095,15 @@ LGraphNode.prototype.executeAction = function(action)
bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
hovercolor = hovercolor || "#555";
textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title
var pos = this.mouse;
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
pos = this.last_click_position;
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
if(pos) {
var rect = this.canvas.getBoundingClientRect();
pos[0] -= rect.left;
pos[1] -= rect.top;
}
var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );

ctx.fillStyle = hover ? hovercolor : bgcolor;
if(clicked)
@ -13067,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action)
has_submenu: true,
callback: LGraphCanvas.onShowMenuNodeProperties
},
{
content: "Properties Panel",
callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
},
null,
{
content: "Title",

@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
this.socket = new WebSocket(
`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
);
this.socket.binaryType = "arraybuffer";

this.socket.addEventListener("open", () => {
opened = true;
@ -70,33 +71,65 @@ class ComfyApi extends EventTarget {

this.socket.addEventListener("message", (event) => {
try {
const msg = JSON.parse(event.data);
switch (msg.type) {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
window.name = this.clientId;
if (event.data instanceof ArrayBuffer) {
const view = new DataView(event.data);
const eventType = view.getUint32(0);
const buffer = event.data.slice(4);
switch (eventType) {
case 1:
const view2 = new DataView(event.data);
const imageType = view2.getUint32(0)
let imageMime
switch (imageType) {
case 1:
default:
imageMime = "image/jpeg";
break;
case 2:
imageMime = "image/png"
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else {
throw new Error("Unknown message type");
}
throw new Error(`Unknown binary websocket message of type ${eventType}`);
}
}
else {
const msg = JSON.parse(event.data);
switch (msg.type) {
case "status":
if (msg.data.sid) {
this.clientId = msg.data.sid;
window.name = this.clientId;
}
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
break;
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break;
case "execution_start":
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
break;
case "execution_error":
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
} else {
throw new Error(`Unknown message type ${msg.type}`);
}
}
}
} catch (error) {
console.warn("Unhandled message:", event.data);
console.warn("Unhandled message:", event.data, error);
}
});
}

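The ArrayBuffer branch above implies a simple frame layout for binary websocket messages: a big-endian uint32 event type, a uint32 image type (1 = JPEG, 2 = PNG), then the image bytes. The decoder below is a hedged sketch of that layout for use outside ComfyApi; the exact server-side framing is assumed here, and only the handler in this diff is authoritative:

// Hedged sketch of the binary preview frame the handler above expects:
// [uint32 BE event type][uint32 BE image type][image bytes].
function decodePreviewFrame(arrayBuffer) {
	const view = new DataView(arrayBuffer);
	const eventType = view.getUint32(0); // 1 = preview image
	if (eventType !== 1) return null;
	const imageType = view.getUint32(4); // 1 = JPEG (default), 2 = PNG
	const mime = imageType === 2 ? "image/png" : "image/jpeg";
	return new Blob([arrayBuffer.slice(8)], { type: mime });
}

// Usage inside a websocket "message" listener with binaryType = "arraybuffer":
// const blob = decodePreviewFrame(event.data);
// if (blob) previewImg.src = URL.createObjectURL(blob);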
@ -44,6 +44,12 @@ export class ComfyApp {
*/
this.nodeOutputs = {};

/**
* Stores the preview image data for each node
* @type {Record<string, Image>}
*/
this.nodePreviewImages = {};

/**
* If the shift key on the keyboard is pressed
* @type {boolean}
@ -51,6 +57,14 @@
this.shiftDown = false;
}

getPreviewFormatParam() {
let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat");
if(preview_format)
return `&preview=${preview_format}`;
else
return "";
}

static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
}
@ -111,10 +125,14 @@ export class ComfyApp {
if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
else {
node.images = ComfyApp.clipspace.images;
}

if(app.nodeOutputs[node.id + ""])
app.nodeOutputs[node.id + ""].images = node.images;
}

if(ComfyApp.clipspace.imgs) {
@ -147,7 +165,16 @@ export class ComfyApp {
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
if (prop && prop.type != 'image') {
if(typeof prop.value == "string" && value.filename) {
prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
}
else {
prop.value = value;
prop.callback(value);
}
}
else if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
@ -231,14 +258,20 @@ export class ComfyApp {
options.unshift(
{
content: "Open Image",
callback: () => window.open(img.src, "_blank"),
callback: () => {
let url = new URL(img.src);
url.searchParams.delete('preview');
window.open(url, "_blank")
},
},
{
content: "Save Image",
callback: () => {
const a = document.createElement("a");
a.href = img.src;
a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename"));
let url = new URL(img.src);
url.searchParams.delete('preview');
a.href = url;
a.setAttribute("download", new URLSearchParams(url.search).get("filename"));
document.body.append(a);
a.click();
requestAnimationFrame(() => a.remove());
@ -345,6 +378,10 @@ export class ComfyApp {
}

node.prototype.setSizeForImage = function () {
if (this.inputHeight) {
this.setSize(this.size);
return;
}
const minHeight = getImageTop(this) + 220;
if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight]);
@ -353,29 +390,52 @@ export class ComfyApp {

node.prototype.onDrawBackground = function (ctx) {
if (!this.flags.collapsed) {
let imgURLs = []
let imagesChanged = false

const output = app.nodeOutputs[this.id + ""];
if (output && output.images) {
if (this.images !== output.images) {
this.images = output.images;
this.imgs = null;
this.imageIndex = null;
imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => {
return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
}))
}
}

const preview = app.nodePreviewImages[this.id + ""]
if (this.preview !== preview) {
this.preview = preview
imagesChanged = true;
if (preview != null) {
imgURLs.push(preview);
}
}

if (imagesChanged) {
this.imageIndex = null;
if (imgURLs.length > 0) {
Promise.all(
output.images.map((src) => {
imgURLs.map((src) => {
return new Promise((r) => {
const img = new Image();
img.onload = () => r(img);
img.onerror = () => r(null);
img.src = "/view?" + new URLSearchParams(src).toString();
img.src = src
});
})
).then((imgs) => {
if (this.images === output.images) {
if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
this.imgs = imgs.filter(Boolean);
this.setSizeForImage?.();
app.graph.setDirtyCanvas(true);
}
});
}
else {
this.imgs = null;
}
}

if (this.imgs && this.imgs.length) {
@ -771,16 +831,27 @@ export class ComfyApp {
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = origDrawNodeShape.apply(this, arguments);

const nodeErrors = self.lastPromptError?.node_errors[node.id];

let color = null;
let lineWidth = 1;
if (node.id === +self.runningNodeId) {
color = "#0f0";
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
color = "dodgerblue";
}
else if (self.lastPromptError != null && nodeErrors?.errors) {
color = "red";
lineWidth = 2;
}
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
color = "#f0f";
lineWidth = 2;
}

if (color) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.lineWidth = lineWidth;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
@ -807,11 +878,28 @@ export class ComfyApp {
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
}

if (self.progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
if (self.progress && node.id === +self.runningNodeId) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
}

// Highlight inputs that failed validation
if (nodeErrors) {
ctx.lineWidth = 2;
ctx.strokeStyle = "red";
for (const error of nodeErrors.errors) {
if (error.extra_info && error.extra_info.input_name) {
const inputIndex = node.findInputSlot(error.extra_info.input_name)
if (inputIndex !== -1) {
let pos = node.getConnectionPos(true, inputIndex);
ctx.beginPath();
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
ctx.stroke();
}
}
}
}

@ -859,16 +947,40 @@ export class ComfyApp {
this.progress = null;
this.runningNodeId = detail;
this.graph.setDirtyCanvas(true, false);
delete this.nodePreviewImages[this.runningNodeId]
});

api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
const node = this.graph.getNodeById(detail.node);
if (node?.onExecuted) {
node.onExecuted(detail.output);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
}
});

api.addEventListener("execution_start", ({ detail }) => {
this.runningNodeId = null;
this.lastExecutionError = null
});

api.addEventListener("execution_error", ({ detail }) => {
this.lastExecutionError = detail;
const formattedError = this.#formatExecutionError(detail);
this.ui.dialog.show(formattedError);
this.canvas.draw(true, true);
});

api.addEventListener("b_preview", ({ detail }) => {
const id = this.runningNodeId
if (id == null)
return;

const blob = detail
const blobUrl = URL.createObjectURL(blob)
this.nodePreviewImages[id] = [blobUrl]
});

api.init();
}

@ -975,6 +1087,11 @@ export class ComfyApp {
const app = this;
// Load node definitions from the backend
const defs = await api.getNodeDefs();
await this.registerNodesFromDefs(defs);
await this.#invokeExtensionsAsync("registerCustomNodes");
}

async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);

// Generate list of known widgets
@ -1047,8 +1164,6 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
}

await this.#invokeExtensionsAsync("registerCustomNodes");
}

/**
@ -1247,6 +1362,43 @@ export class ComfyApp {
return { workflow, output };
}

#formatPromptError(error) {
if (error == null) {
return "(unknown error)"
}
else if (typeof error === "string") {
return error;
}
else if (error.stack && error.message) {
return error.toString()
}
else if (error.response) {
let message = error.response.error.message;
if (error.response.error.details)
message += ": " + error.response.error.details;
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
message += "\n" + nodeError.class_type + ":"
for (const errorReason of nodeError.errors) {
message += "\n - " + errorReason.message + ": " + errorReason.details
}
}
return message
}
return "(unknown error)"
}

#formatExecutionError(error) {
if (error == null) {
return "(unknown error)"
}

const traceback = error.traceback.join("")
const nodeId = error.node_id
const nodeType = error.node_type

return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
}

async queuePrompt(number, batchCount = 1) {
this.#queueItems.push({ number, batchCount });

@ -1254,8 +1406,10 @@ export class ComfyApp {
if (this.#processingQueue) {
return;
}


this.#processingQueue = true;
this.lastPromptError = null;

try {
while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop());
@ -1266,7 +1420,12 @@ export class ComfyApp {
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response.error || error.toString());
const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError);
if (error.response) {
this.lastPromptError = error.response;
this.canvas.draw(true, true);
}
break;
}

@ -1345,6 +1504,11 @@ export class ComfyApp {

const def = defs[node.type];

// HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
// and additional work is needed to consider the primitive logic in the refresh logic.
if(!def)
continue;

for(const widgetNum in node.widgets) {
const widget = node.widgets[widgetNum]
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
@ -1364,6 +1528,10 @@ export class ComfyApp {
*/
clean() {
this.nodeOutputs = {};
this.nodePreviewImages = {}
this.lastPromptError = null;
this.lastExecutionError = null;
this.runningNodeId = null;
}
}


@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) {
const embeddings = await api.getEmbeddings();
const opts = parameters
.substr(p)
.split("\n")[1]
.split(",")
.reduce((p, n) => {
const s = n.split(":");

@ -462,6 +462,24 @@ export class ComfyUI {
defaultValue: true,
});

/**
* file format for preview
*
* format;quality
*
* ex)
* webp;50 -> webp, quality 50
* jpeg;80 -> rgb, jpeg, quality 80
*
* @type {string}
*/
const previewImage = this.settings.addSetting({
id: "Comfy.PreviewFormat",
name: "When displaying a preview in the image widget, convert it to a lightweight image. (webp, jpeg, webp;50, ...)",
type: "string",
defaultValue: "",
});

const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",

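The Comfy.PreviewFormat value documented above is consumed by getPreviewFormatParam() in app.js, which turns a non-empty `format;quality` string into a `&preview=` query parameter on /view URLs. A small sketch of that round trip (the sample filename is illustrative, not from the commit):

// Sketch: how a "format;quality" setting value becomes the &preview= query
// parameter appended to /view URLs. previewFormat stands in for the
// Comfy.PreviewFormat setting value.
function previewParam(previewFormat) {
	return previewFormat ? `&preview=${previewFormat}` : "";
}

// "webp;50" -> "/view?filename=x.png&type=output&preview=webp;50"
// ""        -> "/view?filename=x.png&type=output" (no conversion requested)
const src = "/view?filename=x.png&type=output" + previewParam("webp;50");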
@ -115,12 +115,12 @@ function addMultilineWidget(node, name, opts, app) {

// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length;
freeSpace /= multi.length + (!!node.imgs?.length);

if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * multi.length;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}

@ -303,7 +303,7 @@ export const ComfyWidgets = {
subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1);
}
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`;
node.setSizeForImage?.();
}


@ -50,7 +50,7 @@ body {
padding: 30px 30px 10px 30px;
background-color: var(--comfy-menu-bg); /* Modal background */
color: var(--error-text);
box-shadow: 0px 0px 20px #888888;
box-shadow: 0 0 20px #888888;
border-radius: 10px;
top: 50%;
left: 50%;
@ -84,7 +84,7 @@ body {
font-size: 15px;
position: absolute;
top: 50%;
right: 0%;
right: 0;
text-align: center;
z-index: 100;
width: 170px;
@ -252,7 +252,7 @@ button.comfy-queue-btn {
bottom: 0 !important;
left: auto !important;
right: 0 !important;
border-radius: 0px;
border-radius: 0;
}
.comfy-menu span.drag-handle {
visibility:hidden
@ -289,6 +289,11 @@ button.comfy-queue-btn {

/* Context menu */

.litegraph .dialog {
z-index: 1;
font-family: Arial, sans-serif;
}

.litegraph .litemenu-entry.has_submenu {
position: relative;
padding-right: 20px;
@ -325,12 +330,20 @@ button.comfy-queue-btn {
color: var(--input-text) !important;
}

.comfy-context-menu-filter {
box-sizing: border-box;
border: 1px solid #999;
margin: 0 0 5px 5px;
width: calc(100% - 10px);
}

/* Search box */

.litegraph.litesearchbox {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
display: block;
}

.litegraph.litesearchbox input,
