merge upstream

Benjamin Berman 2023-10-17 14:47:59 -07:00
commit d21655b5a2
36 changed files with 714 additions and 647 deletions

View File

@@ -41,13 +41,13 @@ jobs:
 - shell: bash
 run: |
 echo "@echo off
-..\python_embeded\python.exe .\update.py ..\ComfyUI\
+..\python_embeded\python.exe .\update.py ..\ComfyUI\\
-echo
+echo -
 echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
 echo You should not be running this anyways unless you really have to
-echo
+echo -
 echo If you just want to update normally, close this and run update_comfyui.bat instead.
-echo
+echo -
 pause
 ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
 pause" > update_comfyui_and_python_dependencies.bat

View File

@@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Ctrl + S | Save workflow |
 | Ctrl + O | Load workflow |
 | Ctrl + A | Select all nodes |
+| Alt + C | Collapse/uncollapse selected nodes |
 | Ctrl + M | Mute/unmute selected nodes |
 | Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
 | Delete/Backspace | Delete selected nodes |
@@ -93,7 +94,7 @@ Ctrl can also be replaced with Cmd instead for macOS users
 There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
-### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)
+### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
 Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
@@ -127,6 +128,8 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
 source ./venv/bin/activate
 ```
+Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
 3. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux):
 ```shell
 pip install git+https://github.com/hiddenswitch/ComfyUI.git

View File

@@ -34,8 +34,7 @@ class ControlNet(nn.Module):
 dims=2,
 num_classes=None,
 use_checkpoint=False,
-use_fp16=False,
-use_bf16=False,
+dtype=torch.float32,
 num_heads=-1,
 num_head_channels=-1,
 num_heads_upsample=-1,
@@ -108,8 +107,7 @@ class ControlNet(nn.Module):
 self.conv_resample = conv_resample
 self.num_classes = num_classes
 self.use_checkpoint = use_checkpoint
-self.dtype = th.float16 if use_fp16 else th.float32
-self.dtype = th.bfloat16 if use_bf16 else self.dtype
+self.dtype = dtype
 self.num_heads = num_heads
 self.num_head_channels = num_head_channels
 self.num_heads_upsample = num_heads_upsample
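For orientation, a small hedged sketch of what the removed flag logic computed and what callers now pass instead; the helper below is hypothetical and not part of the diff.

```python
import torch

# Hypothetical helper, not in the diff: reproduces the removed use_fp16/use_bf16 logic
# so an older call site can be translated to the new single dtype argument.
def legacy_flags_to_dtype(use_fp16=False, use_bf16=False):
    if use_bf16:
        return torch.bfloat16
    return torch.float16 if use_fp16 else torch.float32

print(legacy_flags_to_dtype(use_fp16=True))   # torch.float16
print(legacy_flags_to_dtype(use_bf16=True))   # torch.bfloat16
print(legacy_flags_to_dtype())                # torch.float32
```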

View File

@@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
 parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
 parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
 parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
+parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
 parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
 parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -52,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")

View File

@@ -4,6 +4,8 @@ import asyncio
 import copy
 import datetime
 import heapq
+import json
+import logging
 import threading
 import time
 import traceback
@@ -16,7 +18,7 @@ import torch
 from ..nodes.package import import_all_nodes_in_workspace
 nodes = import_all_nodes_in_workspace()
-from .. import model_management
+from .. import model_management # type: ignore
 """
 A queued item
@@ -209,7 +211,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
 server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
 server.client_id)
 except model_management.InterruptProcessingException as iex:
-print("Processing interrupted")
+logging.info("Processing interrupted")
 # skip formatting inputs/outputs
 error_details = {
@@ -230,8 +232,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
 for node_id, node_outputs in outputs.items():
 output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
-print("!!! Exception during processing !!!")
-print(traceback.format_exc())
+logging.error("!!! Exception during processing !!!")
+logging.error(traceback.format_exc())
 error_details = {
 "node_id": unique_id,
@@ -443,7 +445,7 @@ class PromptExecutor:
-def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
+def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
 # todo: this should check if LoadImage / LoadImageMask paths exist
 # todo: or, nodes should provide a way to validate their values
 unique_id = item
@@ -511,8 +513,8 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
 errors.append(error)
 continue
 try:
-r = validate_inputs(prompt, o_id, validated)
-if r[0] is False:
+r2 = validate_inputs(prompt, o_id, validated)
+if r2[0] is False:
 # `r` will be set in `validated[o_id]` already
 valid = False
 continue
@@ -593,11 +595,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
 input_data_all = get_input_data(inputs, obj_class, unique_id)
 # ret = obj_class.VALIDATE_INPUTS(**input_data_all)
 ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
-for i, r in enumerate(ret):
-if r is not True:
+for i, r3 in enumerate(ret):
+if r3 is not True:
 details = f"{x}"
-if r is not False:
-details += f" - {str(r)}"
+if r3 is not False:
+details += f" - {str(r3)}"
 error = {
 "type": "custom_validation_failed",
@@ -698,11 +700,11 @@ def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict]
 if valid is True:
 good_outputs.add(o)
 else:
-print(f"Failed to validate prompt for output {o}:")
+logging.error(f"Failed to validate prompt for output {o}:")
 if len(reasons) > 0:
-print("* (prompt):")
+logging.error("* (prompt):")
 for reason in reasons:
-print(f" - {reason['message']}: {reason['details']}")
+logging.error(f" - {reason['message']}: {reason['details']}")
 errors += [(o, reasons)]
 for node_id, result in validated.items():
 valid = result[0]
@@ -718,16 +720,16 @@ def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict]
 "dependent_outputs": [],
 "class_type": class_type
 }
-print(f"* {class_type} {node_id}:")
+logging.error(f"* {class_type} {node_id}:")
 for reason in reasons:
-print(f" - {reason['message']}: {reason['details']}")
+logging.error(f" - {reason['message']}: {reason['details']}")
 node_errors[node_id]["dependent_outputs"].append(o)
-print("Output will be ignored")
+logging.error("Output will be ignored")
 if len(good_outputs) == 0:
 errors_list = []
-for o, errors in errors:
-for error in errors:
+for o, _errors in errors:
+for error in _errors:
 errors_list.append(f"{error['message']}: {error['details']}")
 errors_list = "\n".join(errors_list)
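Since the executor now reports through the standard `logging` module instead of `print`, an embedding application can route or silence these messages with ordinary handlers. A minimal sketch, assuming default logging configuration:

```python
import logging

# Sketch only: route the executor's messages (formerly print() calls) through a handler.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logging.info("Processing interrupted")
logging.error("!!! Exception during processing !!!")
```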

View File

@@ -33,6 +33,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
 folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
+folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
 output_directory = os.path.join(base_path, "output")
 temp_directory = os.path.join(base_path, "temp")
 input_directory = os.path.join(base_path, "input")
@@ -50,6 +52,10 @@ def set_temp_directory(temp_dir):
 global temp_directory
 temp_directory = temp_dir
+def set_input_directory(input_dir):
+global input_directory
+input_directory = input_dir
 def get_output_directory():
 global output_directory
 return output_directory
@@ -144,7 +150,7 @@ def recursive_search(directory, excluded_dir_names=None):
 return result, dirs
 def filter_files_extensions(files, extensions):
-return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
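An illustrative sketch of the relaxed `filter_files_extensions` check above, using made-up file names: the new "classifiers" folder is registered with `{""}`, which matches extension-less files, while an empty set now matches everything.

```python
import os

def filter_files_extensions(files, extensions):
    # same expression as the patched function above
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

files = ["model.safetensors", "labels.txt", "README"]
print(filter_files_extensions(files, {".safetensors"}))  # ['model.safetensors']
print(filter_files_extensions(files, {""}))              # ['README'] (no extension)
print(filter_files_extensions(files, set()))             # all three files
```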

View File

@@ -55,7 +55,12 @@ def get_previewer(device, latent_format):
 # TODO previewer methods
 taesd_decoder_path = None
 if latent_format.taesd_decoder_name is not None:
-taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
+taesd_decoder_path = next(
+(fn for fn in folder_paths.get_filename_list("vae_approx")
+if fn.startswith(latent_format.taesd_decoder_name)),
+""
+)
+taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
 if method == LatentPreviewMethod.Auto:
 method = LatentPreviewMethod.Latent2RGB
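A small self-contained sketch of the prefix lookup introduced above, with hypothetical file names: the latent format now stores a bare prefix such as "taesd_decoder", and the previewer picks the first file in the vae_approx folder that starts with it, regardless of extension.

```python
# Hypothetical folder contents, for illustration only.
available = ["taesd_decoder.safetensors", "taesdxl_decoder.pth"]
taesd_decoder_name = "taesd_decoder"
match = next((fn for fn in available if fn.startswith(taesd_decoder_name)), "")
print(match)  # taesd_decoder.safetensors
```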

View File

@@ -178,6 +178,16 @@ def main():
 print(f"Setting output directory to: {output_dir}")
 folder_paths.set_output_directory(output_dir)
+#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
+folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
+folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
+if args.input_directory:
+input_dir = os.path.abspath(args.input_directory)
+print(f"Setting input directory to: {input_dir}")
+folder_paths.set_input_directory(input_dir)
 if args.quick_test_for_ci:
 exit(0)

View File

@@ -6,6 +6,7 @@ import struct
 import sys
 import shutil
 from urllib.parse import quote
+from pkg_resources import resource_filename
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
@@ -105,7 +106,6 @@ class PromptServer():
 self.sockets = dict()
 web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../web")
 if not os.path.exists(web_root_path):
-from pkg_resources import resource_filename
 web_root_path = resource_filename('comfy', 'web/')
 self.web_root = web_root_path
 routes = web.RouteTableDef()

View File

@@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
 controlnet_config = None
 if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-use_fp16 = comfy.model_management.should_use_fp16()
-controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
+unet_dtype = comfy.model_management.unet_dtype()
+controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
 diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
 diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
 diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
 return net
 if controlnet_config is None:
-use_fp16 = comfy.model_management.should_use_fp16()
-controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
+unet_dtype = comfy.model_management.unet_dtype()
+controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
 controlnet_config.pop("out_channels")
 controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
 control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
 missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
 print(missing, unexpected)
-if use_fp16:
-control_model = control_model.half()
+control_model = control_model.to(unet_dtype)
 global_average_pooling = False
 filename = os.path.splitext(ckpt_path)[0]
@@ -456,7 +455,7 @@ def load_t2i_adapter(t2i_data):
 for i in range(4):
 for j in range(2):
 prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
-prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
+prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
 prefix_replace["adapter."] = ""
 t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
 keys = t2i_data.keys()
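A minimal sketch of the casting pattern adopted above: rather than conditionally calling `.half()`, the loaded model is cast once to whatever dtype was chosen up front. The dtype value here is a stand-in for `comfy.model_management.unet_dtype()`.

```python
import torch

unet_dtype = torch.float16  # stand-in for comfy.model_management.unet_dtype()
control_model = torch.nn.Linear(4, 4)
control_model = control_model.to(unet_dtype)
print(next(control_model.parameters()).dtype)  # torch.float16
```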

View File

@@ -1,13 +0,0 @@
-from asyncio import AbstractEventLoop
-from ..cmd.execution import PromptQueue
-from ..cmd.server import PromptServer
-class Comfy:
-loop: AbstractEventLoop
-server: PromptServer
-queue: PromptQueue
-def __init__(self):
-pass

View File

@@ -20,7 +20,7 @@ class SD15(LatentFormat):
 [-0.2829, 0.1762, 0.2721],
 [-0.2120, -0.2616, -0.7177]
 ]
-self.taesd_decoder_name = "taesd_decoder.pth"
+self.taesd_decoder_name = "taesd_decoder"
 class SDXL(LatentFormat):
 def __init__(self):
@@ -32,4 +32,4 @@ class SDXL(LatentFormat):
 [ 0.0568, 0.1687, -0.0755],
 [-0.3112, -0.2359, -0.2076]
 ]
-self.taesd_decoder_name = "taesdxl_decoder.pth"
+self.taesd_decoder_name = "taesdxl_decoder"

View File

@@ -93,253 +93,222 @@ def zero_module(module):
 def Normalize(in_channels, dtype=None, device=None):
 return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
+def attention_basic(q, k, v, heads, mask=None):
+h = heads
+scale = (q.shape[-1] // heads) ** -0.5
+q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-class SpatialSelfAttention(nn.Module):
+# force cast to fp32 to avoid overflowing
-def __init__(self, in_channels):
+if _ATTN_PRECISION =="fp32":
-super().__init__()
+with torch.autocast(enabled=False, device_type = 'cuda'):
-self.in_channels = in_channels
+q, k = q.float(), k.float()
+sim = einsum('b i d, b j d -> b i j', q, k) * scale
+else:
+sim = einsum('b i d, b j d -> b i j', q, k) * scale
-self.norm = Normalize(in_channels)
+del q, k
-self.q = torch.nn.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.k = torch.nn.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.v = torch.nn.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.proj_out = torch.nn.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-def forward(self, x):
+if exists(mask):
-h_ = x
+mask = rearrange(mask, 'b ... -> b (...)')
-h_ = self.norm(h_)
+max_neg_value = -torch.finfo(sim.dtype).max
-q = self.q(h_)
+mask = repeat(mask, 'b j -> (b h) () j', h=h)
-k = self.k(h_)
+sim.masked_fill_(~mask, max_neg_value)
-v = self.v(h_)
-# compute attention
+# attention, what we cannot get enough of
-b,c,h,w = q.shape
+sim = sim.softmax(dim=-1)
-q = rearrange(q, 'b c h w -> b (h w) c')
-k = rearrange(k, 'b c h w -> b c (h w)')
-w_ = torch.einsum('bij,bjk->bik', q, k)
-w_ = w_ * (int(c)**(-0.5))
+out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-w_ = torch.nn.functional.softmax(w_, dim=2)
+out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+return out
-# attend to values
-v = rearrange(v, 'b c h w -> b c (h w)')
-w_ = rearrange(w_, 'b i j -> b j i')
-h_ = torch.einsum('bij,bjk->bik', v, w_)
-h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
-h_ = self.proj_out(h_)
-return x+h_
-class CrossAttentionBirchSan(nn.Module):
+def attention_sub_quad(query, key, value, heads, mask=None):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
+scale = (query.shape[-1] // heads) ** -0.5
-super().__init__()
+query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
-inner_dim = dim_head * heads
+key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
-context_dim = default(context_dim, query_dim)
+del key
+value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
-self.scale = dim_head ** -0.5
+dtype = query.dtype
-self.heads = heads
+upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+if upcast_attention:
+bytes_per_token = torch.finfo(torch.float32).bits//8
+else:
+bytes_per_token = torch.finfo(query.dtype).bits//8
+batch_x_heads, q_tokens, _ = query.shape
+_, _, k_tokens = key_t.shape
+qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
-self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
-self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_out = nn.Sequential(
+chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
-operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-nn.Dropout(dropout)
-)
-def forward(self, x, context=None, value=None, mask=None):
+kv_chunk_size_min = None
-h = self.heads
-query = self.to_q(x)
+#not sure at all about the math here
-context = default(context, x)
+#TODO: tweak this
-key = self.to_k(context)
+if mem_free_total > 8192 * 1024 * 1024 * 1.3:
-if value is not None:
+query_chunk_size_x = 1024 * 4
-value = self.to_v(value)
+elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
-else:
+query_chunk_size_x = 1024 * 2
-value = self.to_v(context)
+else:
+query_chunk_size_x = 1024
+kv_chunk_size_min_x = None
+kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
+if kv_chunk_size_x < 1024:
+kv_chunk_size_x = None
-del context, x
+if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+# the big matmul fits into our memory limit; do everything in 1 chunk,
+# i.e. send it down the unchunked fast-path
+query_chunk_size = q_tokens
+kv_chunk_size = k_tokens
+else:
+query_chunk_size = query_chunk_size_x
+kv_chunk_size = kv_chunk_size_x
+kv_chunk_size_min = kv_chunk_size_min_x
-query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
+hidden_states = efficient_dot_product_attention(
-key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
+query,
-del key
+key_t,
-value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
+value,
+query_chunk_size=query_chunk_size,
+kv_chunk_size=kv_chunk_size,
+kv_chunk_size_min=kv_chunk_size_min,
+use_checkpoint=False,
+upcast_attention=upcast_attention,
+)
-dtype = query.dtype
+hidden_states = hidden_states.to(dtype)
-upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
-if upcast_attention:
-bytes_per_token = torch.finfo(torch.float32).bits//8
-else:
-bytes_per_token = torch.finfo(query.dtype).bits//8
-batch_x_heads, q_tokens, _ = query.shape
-_, _, k_tokens = key_t.shape
-qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
-mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+return hidden_states
-chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
+def attention_split(q, k, v, heads, mask=None):
+scale = (q.shape[-1] // heads) ** -0.5
+h = heads
+q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-kv_chunk_size_min = None
+r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-#not sure at all about the math here
+mem_free_total = model_management.get_free_memory(q.device)
-#TODO: tweak this
-if mem_free_total > 8192 * 1024 * 1024 * 1.3:
-query_chunk_size_x = 1024 * 4
-elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
-query_chunk_size_x = 1024 * 2
-else:
-query_chunk_size_x = 1024
-kv_chunk_size_min_x = None
-kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
-if kv_chunk_size_x < 1024:
-kv_chunk_size_x = None
-if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+gb = 1024 ** 3
-# the big matmul fits into our memory limit; do everything in 1 chunk,
+tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-# i.e. send it down the unchunked fast-path
+modifier = 3 if q.element_size() == 2 else 2.5
-query_chunk_size = q_tokens
+mem_required = tensor_size * modifier
-kv_chunk_size = k_tokens
+steps = 1
-else:
-query_chunk_size = query_chunk_size_x
-kv_chunk_size = kv_chunk_size_x
-kv_chunk_size_min = kv_chunk_size_min_x
-hidden_states = efficient_dot_product_attention(
-query,
-key_t,
-value,
-query_chunk_size=query_chunk_size,
-kv_chunk_size=kv_chunk_size,
-kv_chunk_size_min=kv_chunk_size_min,
-use_checkpoint=self.training,
-upcast_attention=upcast_attention,
-)
-hidden_states = hidden_states.to(dtype)
-hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
-out_proj, dropout = self.to_out
-hidden_states = out_proj(hidden_states)
-hidden_states = dropout(hidden_states)
-return hidden_states
-class CrossAttentionDoggettx(nn.Module):
+if mem_required > mem_free_total:
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
+steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-super().__init__()
+# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-inner_dim = dim_head * heads
+# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-context_dim = default(context_dim, query_dim)
-self.scale = dim_head ** -0.5
+if steps > 64:
-self.heads = heads
+max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
-self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+first_op_done = False
-self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+cleared_cache = False
+while True:
-self.to_out = nn.Sequential(
+try:
-operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-nn.Dropout(dropout)
+for i in range(0, q.shape[1], slice_size):
-)
+end = i + slice_size
+if _ATTN_PRECISION =="fp32":
-def forward(self, x, context=None, value=None, mask=None):
+with torch.autocast(enabled=False, device_type = 'cuda'):
-h = self.heads
+s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
-q_in = self.to_q(x)
-context = default(context, x)
-k_in = self.to_k(context)
-if value is not None:
-v_in = self.to_v(value)
-del value
-else:
-v_in = self.to_v(context)
-del context, x
-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-del q_in, k_in, v_in
-r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-mem_free_total = model_management.get_free_memory(q.device)
-gb = 1024 ** 3
-tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-modifier = 3 if q.element_size() == 2 else 2.5
-mem_required = tensor_size * modifier
-steps = 1
-if mem_required > mem_free_total:
-steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-if steps > 64:
-max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
-first_op_done = False
-cleared_cache = False
-while True:
-try:
-slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-for i in range(0, q.shape[1], slice_size):
-end = i + slice_size
-if _ATTN_PRECISION =="fp32":
-with torch.autocast(enabled=False, device_type = 'cuda'):
-s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
-else:
-s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-first_op_done = True
-s2 = s1.softmax(dim=-1).to(v.dtype)
-del s1
-r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-del s2
-break
-except model_management.OOM_EXCEPTION as e:
-if first_op_done == False:
-model_management.soft_empty_cache(True)
-if cleared_cache == False:
-cleared_cache = True
-print("out of memory error, emptying cache and trying again")
-continue
-steps *= 2
-if steps > 64:
-raise e
-print("out of memory error, increasing steps and trying again", steps)
 else:
+s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
+first_op_done = True
+s2 = s1.softmax(dim=-1).to(v.dtype)
+del s1
+r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+del s2
+break
+except model_management.OOM_EXCEPTION as e:
+if first_op_done == False:
+model_management.soft_empty_cache(True)
+if cleared_cache == False:
+cleared_cache = True
+print("out of memory error, emptying cache and trying again")
+continue
+steps *= 2
+if steps > 64:
 raise e
+print("out of memory error, increasing steps and trying again", steps)
+else:
+raise e
 del q, k, v
 r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
 del r1
+return r2
-return self.to_out(r2)
+def attention_xformers(q, k, v, heads, mask=None):
+b, _, _ = q.shape
+q, k, v = map(
+lambda t: t.unsqueeze(3)
+.reshape(b, t.shape[1], heads, -1)
+.permute(0, 2, 1, 3)
+.reshape(b * heads, t.shape[1], -1)
+.contiguous(),
+(q, k, v),
+)
+# actually compute the attention, what we cannot get enough of
+out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+if exists(mask):
+raise NotImplementedError
+out = (
+out.unsqueeze(0)
+.reshape(b, heads, out.shape[1], -1)
+.permute(0, 2, 1, 3)
+.reshape(b, out.shape[1], -1)
+)
+return out
+def attention_pytorch(q, k, v, heads, mask=None):
+b, _, dim_head = q.shape
+dim_head //= heads
+q, k, v = map(
+lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+(q, k, v),
+)
+out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+out = (
+out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+)
+return out
+optimized_attention = attention_basic
+optimized_attention_masked = attention_basic
+if model_management.xformers_enabled():
+print("Using xformers cross attention")
+optimized_attention = attention_xformers
+elif model_management.pytorch_attention_enabled():
+print("Using pytorch cross attention")
+optimized_attention = attention_pytorch
+else:
+if args.use_split_cross_attention:
+print("Using split optimization for cross attention")
+optimized_attention = attention_split
+else:
+print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+optimized_attention = attention_sub_quad
+if model_management.pytorch_attention_enabled():
+optimized_attention_masked = attention_pytorch
 class CrossAttention(nn.Module):
 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
@@ -347,62 +316,6 @@ class CrossAttention(nn.Module):
 inner_dim = dim_head * heads
 context_dim = default(context_dim, query_dim)
-self.scale = dim_head ** -0.5
-self.heads = heads
-self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_out = nn.Sequential(
-operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-nn.Dropout(dropout)
-)
-def forward(self, x, context=None, value=None, mask=None):
-h = self.heads
-q = self.to_q(x)
-context = default(context, x)
-k = self.to_k(context)
-if value is not None:
-v = self.to_v(value)
-del value
-else:
-v = self.to_v(context)
-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-# force cast to fp32 to avoid overflowing
-if _ATTN_PRECISION =="fp32":
-with torch.autocast(enabled=False, device_type = 'cuda'):
-q, k = q.float(), k.float()
-sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-else:
-sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-del q, k
-if exists(mask):
-mask = rearrange(mask, 'b ... -> b (...)')
-max_neg_value = -torch.finfo(sim.dtype).max
-mask = repeat(mask, 'b j -> (b h) () j', h=h)
-sim.masked_fill_(~mask, max_neg_value)
-# attention, what we cannot get enough of
-sim = sim.softmax(dim=-1)
-out = einsum('b i j, b j d -> b i d', sim, v)
-out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-return self.to_out(out)
-class MemoryEfficientCrossAttention(nn.Module):
-# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=ops):
-super().__init__()
-inner_dim = dim_head * heads
-context_dim = default(context_dim, query_dim)
 self.heads = heads
 self.dim_head = dim_head
@@ -411,7 +324,6 @@ class MemoryEfficientCrossAttention(nn.Module):
 self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
-self.attention_op: Optional[Any] = None
 def forward(self, x, context=None, value=None, mask=None):
 q = self.to_q(x)
@@ -423,85 +335,12 @@ class MemoryEfficientCrossAttention(nn.Module):
 else:
 v = self.to_v(context)
-b, _, _ = q.shape
+if mask is None:
-q, k, v = map(
+out = optimized_attention(q, k, v, self.heads)
-lambda t: t.unsqueeze(3)
-.reshape(b, t.shape[1], self.heads, self.dim_head)
-.permute(0, 2, 1, 3)
-.reshape(b * self.heads, t.shape[1], self.dim_head)
-.contiguous(),
-(q, k, v),
-)
-# actually compute the attention, what we cannot get enough of
-out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-if exists(mask):
-raise NotImplementedError
-out = (
-out.unsqueeze(0)
-.reshape(b, self.heads, out.shape[1], self.dim_head)
-.permute(0, 2, 1, 3)
-.reshape(b, out.shape[1], self.heads * self.dim_head)
-)
-return self.to_out(out)
-class CrossAttentionPytorch(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
-super().__init__()
-inner_dim = dim_head * heads
-context_dim = default(context_dim, query_dim)
-self.heads = heads
-self.dim_head = dim_head
-self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
-self.attention_op: Optional[Any] = None
-def forward(self, x, context=None, value=None, mask=None):
-q = self.to_q(x)
-context = default(context, x)
-k = self.to_k(context)
-if value is not None:
-v = self.to_v(value)
-del value
 else:
-v = self.to_v(context)
+out = optimized_attention_masked(q, k, v, self.heads, mask)
-b, _, _ = q.shape
-q, k, v = map(
-lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
-(q, k, v),
-)
-out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-if exists(mask):
-raise NotImplementedError
-out = (
-out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
-)
 return self.to_out(out)
-if model_management.xformers_enabled():
-print("Using xformers cross attention")
-CrossAttention = MemoryEfficientCrossAttention
-elif model_management.pytorch_attention_enabled():
-print("Using pytorch cross attention")
-CrossAttention = CrossAttentionPytorch
-else:
-if args.use_split_cross_attention:
-print("Using split optimization for cross attention")
-CrossAttention = CrossAttentionDoggettx
-else:
-print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-CrossAttention = CrossAttentionBirchSan
 class BasicTransformerBlock(nn.Module):
 def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
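For orientation, a self-contained sketch of the calling convention shared by the new `attention_*` helpers above: flattened `(batch, tokens, heads * dim_head)` tensors plus a head count. The function body mirrors `attention_pytorch` from the hunk; the shapes are illustrative.

```python
import torch

def attention_pytorch(q, k, v, heads, mask=None):
    # mirrors the new helper above: split heads, run SDPA, merge heads again
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)  # batch=2, 77 tokens, 8 heads of 64 channels
print(attention_pytorch(q, k, v, heads=8).shape)  # torch.Size([2, 77, 512])
```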

View File

@@ -6,7 +6,6 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
-from ..attention import MemoryEfficientCrossAttention
 from .... import model_management
 from .... import ops
@@ -194,6 +193,52 @@ def slice_attention(q, k, v):
 return r1
+def normal_attention(q, k, v):
+# compute attention
+b,c,h,w = q.shape
+q = q.reshape(b,c,h*w)
+q = q.permute(0,2,1) # b,hw,c
+k = k.reshape(b,c,h*w) # b,c,hw
+v = v.reshape(b,c,h*w)
+r1 = slice_attention(q, k, v)
+h_ = r1.reshape(b,c,h,w)
+del r1
+return h_
+def xformers_attention(q, k, v):
+# compute attention
+B, C, H, W = q.shape
+q, k, v = map(
+lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
+(q, k, v),
+)
+try:
+out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+out = out.transpose(1, 2).reshape(B, C, H, W)
+except NotImplementedError as e:
+out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+return out
+def pytorch_attention(q, k, v):
+# compute attention
+B, C, H, W = q.shape
+q, k, v = map(
+lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
+(q, k, v),
+)
+try:
+out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+out = out.transpose(2, 3).reshape(B, C, H, W)
+except model_management.OOM_EXCEPTION as e:
+print("scaled_dot_product_attention OOMed: switched to slice attention")
+out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+return out
 class AttnBlock(nn.Module):
 def __init__(self, in_channels):
 super().__init__()
@@ -221,6 +266,16 @@ class AttnBlock(nn.Module):
 stride=1,
 padding=0)
+if model_management.xformers_enabled_vae():
+print("Using xformers attention in VAE")
+self.optimized_attention = xformers_attention
+elif model_management.pytorch_attention_enabled():
+print("Using pytorch attention in VAE")
+self.optimized_attention = pytorch_attention
+else:
+print("Using split attention in VAE")
+self.optimized_attention = normal_attention
 def forward(self, x):
 h_ = x
 h_ = self.norm(h_)
@@ -228,161 +283,15 @@ class AttnBlock(nn.Module):
 k = self.k(h_)
 v = self.v(h_)
-# compute attention
+h_ = self.optimized_attention(q, k, v)
-b,c,h,w = q.shape
-q = q.reshape(b,c,h*w)
-q = q.permute(0,2,1) # b,hw,c
-k = k.reshape(b,c,h*w) # b,c,hw
-v = v.reshape(b,c,h*w)
-r1 = slice_attention(q, k, v)
-h_ = r1.reshape(b,c,h,w)
-del r1
 h_ = self.proj_out(h_)
 return x+h_
-class MemoryEfficientAttnBlock(nn.Module):
-"""
-Uses xformers efficient implementation,
-see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-Note: this is a single-head self-attention operation
-"""
-#
-def __init__(self, in_channels):
-super().__init__()
-self.in_channels = in_channels
-self.norm = Normalize(in_channels)
-self.q = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.k = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.v = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.proj_out = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.attention_op: Optional[Any] = None
-def forward(self, x):
-h_ = x
-h_ = self.norm(h_)
-q = self.q(h_)
-k = self.k(h_)
-v = self.v(h_)
-# compute attention
-B, C, H, W = q.shape
-q, k, v = map(
-lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
-(q, k, v),
-)
-try:
-out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-out = out.transpose(1, 2).reshape(B, C, H, W)
-except NotImplementedError as e:
-out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
-out = self.proj_out(out)
-return x+out
-class MemoryEfficientAttnBlockPytorch(nn.Module):
-def __init__(self, in_channels):
-super().__init__()
-self.in_channels = in_channels
-self.norm = Normalize(in_channels)
-self.q = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.k = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.v = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.proj_out = ops.Conv2d(in_channels,
-in_channels,
-kernel_size=1,
-stride=1,
-padding=0)
-self.attention_op: Optional[Any] = None
-def forward(self, x):
-h_ = x
-h_ = self.norm(h_)
-q = self.q(h_)
-k = self.k(h_)
-v = self.v(h_)
-# compute attention
-B, C, H, W = q.shape
-q, k, v = map(
-lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
-(q, k, v),
-)
-try:
-out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-out = out.transpose(2, 3).reshape(B, C, H, W)
-except model_management.OOM_EXCEPTION as e:
-print("scaled_dot_product_attention OOMed: switched to slice attention")
-out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
-out = self.proj_out(out)
-return x+out
-class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
-def forward(self, x, context=None, mask=None):
-b, c, h, w = x.shape
-x = rearrange(x, 'b c h w -> b (h w) c')
-out = super().forward(x, context=context, mask=mask)
-out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
-return x + out
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
-assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+return AttnBlock(in_channels)
-if model_management.xformers_enabled_vae() and attn_type == "vanilla":
-attn_type = "vanilla-xformers"
-if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
-attn_type = "vanilla-pytorch"
-print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
-if attn_type == "vanilla":
-assert attn_kwargs is None
-return AttnBlock(in_channels)
-elif attn_type == "vanilla-xformers":
-print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
-return MemoryEfficientAttnBlock(in_channels)
-elif attn_type == "vanilla-pytorch":
-return MemoryEfficientAttnBlockPytorch(in_channels)
-elif type == "memory-efficient-cross-attn":
-attn_kwargs["query_dim"] = in_channels
-return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
-elif attn_type == "none":
-return nn.Identity(in_channels)
-else:
-raise NotImplementedError()
 class Model(nn.Module):

View File

@@ -296,8 +296,7 @@ class UNetModel(nn.Module):
 dims=2,
 num_classes=None,
 use_checkpoint=False,
-use_fp16=False,
-use_bf16=False,
+dtype=th.float32,
 num_heads=-1,
 num_head_channels=-1,
 num_heads_upsample=-1,
@@ -370,8 +369,7 @@ class UNetModel(nn.Module):
 self.conv_resample = conv_resample
 self.num_classes = num_classes
 self.use_checkpoint = use_checkpoint
-self.dtype = th.float16 if use_fp16 else th.float32
-self.dtype = th.bfloat16 if use_bf16 else self.dtype
+self.dtype = dtype
 self.num_heads = num_heads
 self.num_head_channels = num_head_channels
 self.num_heads_upsample = num_heads_upsample

View File

@@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
 count += 1
 return count
-def detect_unet_config(state_dict, key_prefix, use_fp16):
+def detect_unet_config(state_dict, key_prefix, dtype):
 state_dict_keys = list(state_dict.keys())
 unet_config = {
@@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
 else:
 unet_config["adm_in_channels"] = None
-unet_config["use_fp16"] = use_fp16
+unet_config["dtype"] = dtype
 model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
 in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
 print("no match", unet_config)
 return None
-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
-unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
+unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
 model_config = model_config_from_unet_config(unet_config)
 if model_config is None and use_base_if_no_match:
 return comfy.supported_models_base.BASE(unet_config)
 else:
 return model_config
-def unet_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, dtype):
 match = {}
 attention_resolutions = []
@@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
 match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
 SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
 SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
 SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
 SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
 SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
@@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
 return unet_config
 return None
def model_config_from_diffusers_unet(state_dict, use_fp16): def model_config_from_diffusers_unet(state_dict, dtype):
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16) unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
if unet_config is not None: if unet_config is not None:
return model_config_from_unet_config(unet_config) return model_config_from_unet_config(unet_config)
return None return None


@ -154,14 +154,18 @@ def is_nvidia():
return True return True
return False return False
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32 VAE_DTYPE = torch.float32
try: try:
if is_nvidia(): if is_nvidia():
torch_version = torch.version.__version__ torch_version = torch.version.__version__
if int(torch_version[0]) >= 2: if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
XFORMERS_IS_AVAILABLE = False
if args.lowvram: if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0):
current_loaded_models.insert(0, current_loaded_models.pop(index)) current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model) models_already_loaded.append(loaded_model)
else: else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
if len(models_to_load) == 0: if len(models_to_load) == 0:
@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
return return
print("loading new") print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
@ -405,7 +410,6 @@ def load_model_gpu(model):
def cleanup_models(): def cleanup_models():
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2: if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete to_delete = [i] + to_delete
@ -444,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
else: else:
return cpu_dev return cpu_dev
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -656,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return False return False
#FP16 is just broken on these cards #FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"] nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
for x in nvidia_16_series: for x in nvidia_16_series:
if x in props.name: if x in props.name:
return False return False
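For context on the dtype changes in this file: the new unet_dtype() helper (added a few hunks above) resolves the UNet's working precision in one place instead of callers checking should_use_fp16() directly. A minimal, hypothetical usage sketch (the parameter count is just a ballpark for an SD1.5-sized UNet, not taken from the diff):

```python
# Hypothetical usage of the unet_dtype() helper introduced above.
import comfy.model_management as model_management

params = 860_000_000  # rough parameter count of an SD1.5 UNet; any integer works

dtype = model_management.unet_dtype(model_params=params)
# Resolution order, per the function body above:
#   1. torch.bfloat16 if args.bf16_unet is set
#   2. torch.float16  if should_use_fp16() approves the device/model size
#   3. torch.float32  otherwise
print(dtype)
```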


@ -107,6 +107,10 @@ class ModelPatcher:
for k in patch_list: for k in patch_list:
if hasattr(patch_list[k], "to"): if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device) patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):


@ -1185,7 +1185,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = latent["noise_mask"] noise_mask = latent["noise_mask"]
callback = latent_preview.prepare_callback(model, steps) callback = latent_preview.prepare_callback(model, steps)
disable_pbar = False disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@ -1645,7 +1645,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler", "KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)", "KSamplerAdvanced": "KSampler (Advanced)",
# Loaders # Loaders
"CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint", "CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"LoraLoader": "Load LoRA", "LoraLoader": "Load LoRA",


@ -29,7 +29,7 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round() noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)
return noise_mask return noise_mask


@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "params" in model_config_params["unet_config"]: if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"] unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config: if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"] fp16 = unet_config.pop("use_fp16")
if fp16:
unet_config["dtype"] = torch.float16
noise_aug_config = None noise_aug_config = None
if "noise_aug_config" in model_config_params: if "noise_aug_config" in model_config_params:
@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
class WeightsLoader(torch.nn.Module): class WeightsLoader(torch.nn.Module):
pass pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16) model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision: if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet(unet_path): #load unet in diffusers format def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", fp16) model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
new_sd = sd new_sd = sd
else: #diffusers else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None: if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path) print("ERROR UNSUPPORTED UNET", unet_path)
return None return None
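The load_unet() change above shows the general pattern: detection entry points now receive a torch dtype rather than a use_fp16 boolean. A hedged sketch of that flow outside the diff (the checkpoint path and import spellings are assumptions for illustration):

```python
# Sketch of the dtype-threaded detection flow; the checkpoint path is a placeholder.
import comfy.utils
import comfy.model_management as model_management
import comfy.model_detection as model_detection

sd = comfy.utils.load_torch_file("models/unet/some_diffusers_unet.safetensors")
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)

# Same call sites as before, but passing a dtype through instead of a bool.
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
    raise RuntimeError("unsupported diffusers UNet")
```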


@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.utils
def conv(n_in, n_out, **kwargs): def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
@ -50,9 +52,9 @@ class TAESD(nn.Module):
self.encoder = Encoder() self.encoder = Encoder()
self.decoder = Decoder() self.decoder = Decoder()
if encoder_path is not None: if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None: if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod @staticmethod
def scale_latents(x): def scale_latents(x):


@ -410,6 +410,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled
PROGRESS_BAR_HOOK = None PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function): def set_progress_bar_global_hook(function):
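The new PROGRESS_BAR_ENABLED flag and its setter let embedders silence sampling progress output globally; the KSampler and SamplerCustom changes elsewhere in this commit read it via disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED. A minimal usage sketch:

```python
# Turn off terminal progress bars before queueing work (sketch, not part of the diff).
import comfy.utils

comfy.utils.set_progress_bar_enabled(False)
assert comfy.utils.PROGRESS_BAR_ENABLED is False
```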


@ -158,7 +158,7 @@ class SplitImageWithAlpha:
def split_image_with_alpha(self, image: torch.Tensor): def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image] out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), torch.stack(out_alphas)) result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result return result
@ -180,7 +180,7 @@ class JoinImageWithAlpha:
batch_size = min(len(image), len(alpha)) batch_size = min(len(image), len(alpha))
out_images = [] out_images = []
alpha = resize_mask(alpha, image.shape[1:]) alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
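Both alpha nodes now invert between the image's alpha channel and the mask they pass around, so a split followed by a join reproduces the original alpha. A small sketch of the arithmetic with plain tensors (not the node classes themselves):

```python
# The two 1.0 - ... inversions mirror each other, so alpha survives a split/join round trip.
import torch

image = torch.rand(1, 64, 64, 4)   # one RGBA image, channels last
alpha = image[0, :, :, 3]

mask = 1.0 - alpha                 # what SplitImageWithAlpha now returns
restored = 1.0 - mask              # JoinImageWithAlpha applies the inverse before writing it back
assert torch.allclose(restored, alpha)
```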


@ -3,6 +3,7 @@ import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview import latent_preview
import torch import torch
import comfy.utils
class BasicScheduler: class BasicScheduler:
@ -15,7 +16,7 @@ class BasicScheduler:
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -35,7 +36,7 @@ class KarrasScheduler:
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -53,7 +54,7 @@ class ExponentialScheduler:
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -72,7 +73,7 @@ class PolyexponentialScheduler:
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -91,7 +92,7 @@ class VPScheduler:
} }
} }
RETURN_TYPES = ("SIGMAS",) RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -108,7 +109,7 @@ class SplitSigmas:
} }
} }
RETURN_TYPES = ("SIGMAS","SIGMAS") RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -125,7 +126,7 @@ class KSamplerSelect:
} }
} }
RETURN_TYPES = ("SAMPLER",) RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler" FUNCTION = "get_sampler"
@ -144,7 +145,7 @@ class SamplerDPMPP_2M_SDE:
} }
} }
RETURN_TYPES = ("SAMPLER",) RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler" FUNCTION = "get_sampler"
@ -168,7 +169,7 @@ class SamplerDPMPP_SDE:
} }
} }
RETURN_TYPES = ("SAMPLER",) RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler" FUNCTION = "get_sampler"
@ -201,7 +202,7 @@ class SamplerCustom:
FUNCTION = "sample" FUNCTION = "sample"
CATEGORY = "_for_testing/custom_sampling" CATEGORY = "sampling/custom_sampling"
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
latent = latent_image latent = latent_image
@ -219,7 +220,7 @@ class SamplerCustom:
x0_output = {} x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = False disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy() out = latent.copy()


@ -241,8 +241,8 @@ class MaskComposite:
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
visible_width, visible_height = (right - left, bottom - top,) visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width] source_portion = source[:, :visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right] destination_portion = destination[:, top:bottom, left:right]
if operation == "multiply": if operation == "multiply":
output[:, top:bottom, left:right] = destination_portion * source_portion output[:, top:bottom, left:right] = destination_portion * source_portion
@ -283,10 +283,10 @@ class FeatherMask:
def feather(self, mask, left, top, right, bottom): def feather(self, mask, left, top, right, bottom):
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[1]) left = min(left, output.shape[-1])
right = min(right, output.shape[1]) right = min(right, output.shape[-1])
top = min(top, output.shape[0]) top = min(top, output.shape[-2])
bottom = min(bottom, output.shape[0]) bottom = min(bottom, output.shape[-2])
for x in range(left): for x in range(left):
feather_rate = (x + 1.0) / left feather_rate = (x + 1.0) / left
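The MaskComposite and FeatherMask fixes above treat masks as batched (batch, height, width) tensors: crops keep the leading batch dimension and the clamps read sizes from the last two axes. A short illustration of why the old indexing misbehaved:

```python
# Masks have shape (batch, height, width); slicing must skip the batch axis.
import torch

mask = torch.zeros(2, 512, 512)

good = mask[:, :64, :64]    # (2, 64, 64): crop height and width, keep the batch
bad = mask[:64, :64]        # (2, 64, 512): old style sliced the batch and height axes instead

height, width = mask.shape[-2], mask.shape[-1]   # matches the new shape[-2]/shape[-1] clamps
```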


@ -1,5 +1,6 @@
from comfy import sd from comfy import sd
from comfy import model_base from comfy import model_base
import comfy.model_management
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
import json import json
@ -177,6 +178,95 @@ class CheckpointSave:
sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {} return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue
p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple, "ModelMergeSimple": ModelMergeSimple,
@ -185,4 +275,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd, "ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave, "CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple, "CLIPMergeSimple": CLIPMergeSimple,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
} }
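Both new save nodes end up in comfy.utils.save_torch_file with a metadata dict; safetensors metadata values must be strings, which is why the prompt and extra_pnginfo entries are json.dumps()'d first. A hedged sketch with a stand-in state dict and file name:

```python
# Sketch of the shared save path; the weights and output name are placeholders.
import json
import torch
import comfy.utils

sd = {"decoder.conv_in.weight": torch.zeros(4, 4, 3, 3)}              # stand-in VAE weights
metadata = {"prompt": json.dumps({"1": {"class_type": "VAESave"}})}   # values must be str

comfy.utils.save_torch_file(sd, "ComfyUI_vae_00001_.safetensors", metadata=metadata)
```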


@ -1,5 +1,6 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it #Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui #config for a1111 ui
#all you have to do is change the base_path to where yours is installed #all you have to do is change the base_path to where yours is installed
a111: a111:
@ -19,6 +20,21 @@ a111:
hypernetworks: models/hypernetworks hypernetworks: models/hypernetworks
controlnet: models/ControlNet controlnet: models/ControlNet
#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
#comfyui:
# base_path: path/to/comfyui/
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/
# vae: models/vae/
#other_ui: #other_ui:
# base_path: path/to/ui # base_path: path/to/ui
# checkpoints: models/checkpoints # checkpoints: models/checkpoints

mypy.ini Normal file

@ -0,0 +1,4 @@
[mypy]
files = comfy/, comfy_extras/
ignore_missing_imports = True
strict_optional = True


@ -26,4 +26,5 @@ Pillow
scipy scipy
tqdm tqdm
protobuf==3.20.3 protobuf==3.20.3
psutil psutil
mypy>=1.6.0


@ -4,6 +4,7 @@ import os.path
import platform import platform
import subprocess import subprocess
import sys import sys
from typing import List
from pip._internal.index.collector import LinkCollector from pip._internal.index.collector import LinkCollector
from pip._internal.index.package_finder import PackageFinder from pip._internal.index.package_finder import PackageFinder
@ -106,7 +107,7 @@ def _is_linux_arm64():
return os_name == 'Linux' and architecture == 'aarch64' return os_name == 'Linux' and architecture == 'aarch64'
def dependencies() -> [str]: def dependencies() -> List[str]:
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines() _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
# todo: also add all plugin dependencies # todo: also add all plugin dependencies
_alternative_indices = [amd_torch_index, nvidia_torch_index, cpu_torch_index_nightlies] _alternative_indices = [amd_torch_index, nvidia_torch_index, cpu_torch_index_nightlies]
@ -137,7 +138,7 @@ def dependencies() -> [str]:
except: except:
try: try:
# pip 22 # pip 22
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls)), finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls)), # type: ignore
SelectionPreferences(allow_yanked=False, prefer_binary=False, SelectionPreferences(allow_yanked=False, prefer_binary=False,
allow_all_prereleases=True) allow_all_prereleases=True)
, use_deprecated_html5lib=False) , use_deprecated_html5lib=False)


@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
node.graph.change(); node.graph.change();
} }
function addNodesToGroup(group, nodes=[]) {
var x1, y1, x2, y2;
var nx1, ny1, nx2, ny2;
var node;
x1 = y1 = x2 = y2 = -1;
nx1 = ny1 = nx2 = ny2 = -1;
for (var n of [group._nodes, nodes]) {
for (var i in n) {
node = n[i]
nx1 = node.pos[0]
ny1 = node.pos[1]
nx2 = node.pos[0] + node.size[0]
ny2 = node.pos[1] + node.size[1]
if (node.type != "Reroute") {
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
}
if (node.flags?.collapsed) {
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
if (node?._collapsed_width) {
nx2 = nx1 + Math.round(node._collapsed_width);
}
}
if (x1 == -1 || nx1 < x1) {
x1 = nx1;
}
if (y1 == -1 || ny1 < y1) {
y1 = ny1;
}
if (x2 == -1 || nx2 > x2) {
x2 = nx2;
}
if (y2 == -1 || ny2 > y2) {
y2 = ny2;
}
}
}
var padding = 10;
y1 = y1 - Math.round(group.font_size * 1.4);
group.pos = [x1 - padding, y1 - padding];
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
}
app.registerExtension({ app.registerExtension({
name: "Comfy.GroupOptions", name: "Comfy.GroupOptions",
setup() { setup() {
@ -14,6 +69,17 @@ app.registerExtension({
const options = orig.apply(this, arguments); const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]); const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) { if (!group) {
options.push({
content: "Add Group For Selected Nodes",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
var group = new LiteGraph.LGraphGroup();
addNodesToGroup(group, this.selected_nodes)
app.canvas.graph.add(group);
this.graph.change();
}
});
return options; return options;
} }
@ -21,6 +87,15 @@ app.registerExtension({
group.recomputeInsideNodes(); group.recomputeInsideNodes();
const nodesInGroup = group._nodes; const nodesInGroup = group._nodes;
options.push({
content: "Add Selected Nodes To Group",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
addNodesToGroup(group, this.selected_nodes)
this.graph.change();
}
});
// No nodes in group, return default options // No nodes in group, return default options
if (nodesInGroup.length === 0) { if (nodesInGroup.length === 0) {
return options; return options;
@ -38,6 +113,23 @@ app.registerExtension({
} }
} }
options.push({
content: "Fit Group To Nodes",
callback: () => {
addNodesToGroup(group)
this.graph.change();
}
});
options.push({
content: "Select Nodes",
callback: () => {
this.selectNodes(nodesInGroup);
this.graph.change();
this.canvas.focus();
}
});
// Modes // Modes
// 0: Always // 0: Always
// 1: On Event // 1: On Event


@ -200,6 +200,10 @@ app.registerExtension({
for (const input of this.inputs) { for (const input of this.inputs) {
if (input.widget && !input.widget[GET_CONFIG]) { if (input.widget && !input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name); input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
const w = this.widgets.find((w) => w.name === input.widget.name);
if (w) {
hideWidget(this, w);
}
} }
} }
} }


@ -3796,7 +3796,7 @@
out = out || new Float32Array(4); out = out || new Float32Array(4);
out[0] = this.pos[0] - 4; out[0] = this.pos[0] - 4;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
out[2] = this.size[0] + 4; out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT; out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
if (this.onBounding) { if (this.onBounding) {


@ -450,6 +450,47 @@ export class ComfyApp {
} }
} }
function calculateGrid(w, h, n) {
let columns, rows, cellsize;
if (w > h) {
cellsize = h;
columns = Math.ceil(w / cellsize);
rows = Math.ceil(n / columns);
} else {
cellsize = w;
rows = Math.ceil(h / cellsize);
columns = Math.ceil(n / rows);
}
while (columns * rows < n) {
cellsize++;
if (w >= h) {
columns = Math.ceil(w / cellsize);
rows = Math.ceil(n / columns);
} else {
rows = Math.ceil(h / cellsize);
columns = Math.ceil(n / rows);
}
}
const cell_size = Math.min(w/columns, h/rows);
return {cell_size, columns, rows};
}
function is_all_same_aspect_ratio(imgs) {
// assume: imgs.length >= 2
let ratio = imgs[0].naturalWidth/imgs[0].naturalHeight;
for(let i=1; i<imgs.length; i++) {
let this_ratio = imgs[i].naturalWidth/imgs[i].naturalHeight;
if(ratio != this_ratio)
return false;
}
return true;
}
if (this.imgs && this.imgs.length) { if (this.imgs && this.imgs.length) {
const canvas = graph.list_of_graphcanvas[0]; const canvas = graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse; const mouse = canvas.graph_mouse;
@ -460,44 +501,60 @@ export class ComfyApp {
this.pointerDown = null; this.pointerDown = null;
} }
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
let imageIndex = this.imageIndex; let imageIndex = this.imageIndex;
const numImages = this.imgs.length; const numImages = this.imgs.length;
if (numImages === 1 && !imageIndex) { if (numImages === 1 && !imageIndex) {
this.imageIndex = imageIndex = 0; this.imageIndex = imageIndex = 0;
} }
const shiftY = getImageTop(this); const top = getImageTop(this);
var shiftY = top;
let dw = this.size[0]; let dw = this.size[0];
let dh = this.size[1]; let dh = this.size[1];
dh -= shiftY; dh -= shiftY;
if (imageIndex == null) { if (imageIndex == null) {
let best = 0; var cellWidth, cellHeight, shiftX, cell_padding, cols;
let cellWidth;
let cellHeight;
let cols = 0;
let shiftX = 0;
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1); const compact_mode = is_all_same_aspect_ratio(this.imgs);
const imageW = w * scale; if(!compact_mode) {
const imageH = h * scale; // use rectangle cell style and border line
const area = imageW * imageH * numImages; cell_padding = 2;
const { cell_size, columns, rows } = calculateGrid(dw, dh, numImages);
cols = columns;
if (area > best) { cellWidth = cell_size;
best = area; cellHeight = cell_size;
cellWidth = imageW; shiftX = (dw-cell_size*cols)/2;
cellHeight = imageH; shiftY = (dh-cell_size*rows)/2 + top;
cols = c; }
shiftX = c * ((cW - imageW) / 2); else {
cell_padding = 0;
let best = 0;
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
} }
} }
@ -542,7 +599,14 @@ export class ComfyApp {
let imgWidth = ratio * img.width; let imgWidth = ratio * img.width;
let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2;
ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.drawImage(img, imgX+cell_padding, imgY+cell_padding, imgWidth-cell_padding*2, imgHeight-cell_padding*2);
if(!compact_mode) {
// rectangle cell and border line style
ctx.strokeStyle = "#8F8F8F";
ctx.lineWidth = 1;
ctx.strokeRect(x+cell_padding, y+cell_padding, cellWidth-cell_padding*2, cellHeight-cell_padding*2);
}
ctx.filter = "none"; ctx.filter = "none";
} }
@ -552,6 +616,9 @@ export class ComfyApp {
} }
} else { } else {
// Draw individual // Draw individual
let w = this.imgs[imageIndex].naturalWidth;
let h = this.imgs[imageIndex].naturalHeight;
const scaleX = dw / w; const scaleX = dw / w;
const scaleY = dh / h; const scaleY = dh / h;
const scale = Math.min(scaleX, scaleY, 1); const scale = Math.min(scaleX, scaleY, 1);
@ -594,14 +661,14 @@ export class ComfyApp {
}; };
if (numImages > 1) { if (numImages > 1) {
if (drawButton(x + w - 35, y + h - 35, 30, `${this.imageIndex + 1}/${numImages}`)) { if (drawButton(dw - 40, dh + top - 40, 30, `${this.imageIndex + 1}/${numImages}`)) {
let i = this.imageIndex + 1 >= numImages ? 0 : this.imageIndex + 1; let i = this.imageIndex + 1 >= numImages ? 0 : this.imageIndex + 1;
if (!this.pointerDown || !this.pointerDown.index === i) { if (!this.pointerDown || !this.pointerDown.index === i) {
this.pointerDown = { index: i, pos: [...mouse] }; this.pointerDown = { index: i, pos: [...mouse] };
} }
} }
if (drawButton(x + w - 35, y + 5, 30, `x`)) { if (drawButton(dw - 40, top + 10, 30, `x`)) {
if (!this.pointerDown || !this.pointerDown.index === null) { if (!this.pointerDown || !this.pointerDown.index === null) {
this.pointerDown = { index: null, pos: [...mouse] }; this.pointerDown = { index: null, pos: [...mouse] };
} }
@ -861,6 +928,16 @@ export class ComfyApp {
block_default = true; block_default = true;
} }
// Alt + C collapse/uncollapse
if (e.key === 'c' && e.altKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
this.selected_nodes[i].collapse()
}
}
block_default = true;
}
// Ctrl+C Copy // Ctrl+C Copy
if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
// Trigger onCopy // Trigger onCopy
@ -1525,7 +1602,7 @@ export class ComfyApp {
all_inputs = all_inputs.concat(Object.keys(parent.inputs)) all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) { for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input]; parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) { if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input); link = parent.getInputLink(parent_input);
if (link) { if (link) {
parent = parent.getInputNode(parent_input); parent = parent.getInputNode(parent_input);


@ -809,7 +809,8 @@ export class ComfyUI {
if ( if (
this.lastQueueSize != 0 && this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 && status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked document.getElementById("autoQueueCheckbox").checked &&
! app.lastExecutionError
) { ) {
app.queuePrompt(0, this.batchCount); app.queuePrompt(0, this.batchCount);
} }