mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-08 13:20:50 +08:00
merge upstream
This commit is contained in:
commit
d21655b5a2
@ -41,13 +41,13 @@ jobs:
|
||||
- shell: bash
|
||||
run: |
|
||||
echo "@echo off
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
echo
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\\
|
||||
echo -
|
||||
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
|
||||
echo You should not be running this anyways unless you really have to
|
||||
echo
|
||||
echo -
|
||||
echo If you just want to update normally, close this and run update_comfyui.bat instead.
|
||||
echo
|
||||
echo -
|
||||
pause
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > update_comfyui_and_python_dependencies.bat
|
||||
|
||||
@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
| Alt + C | Collapse/uncollapse selected nodes |
|
||||
| Ctrl + M | Mute/unmute selected nodes |
|
||||
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
|
||||
| Delete/Backspace | Delete selected nodes |
|
||||
@ -93,7 +94,7 @@ Ctrl can also be replaced with Cmd instead for macOS users
|
||||
|
||||
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
|
||||
@ -127,6 +128,8 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
|
||||
source ./venv/bin/activate
|
||||
```
|
||||
|
||||
Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
|
||||
|
||||
3. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux):
|
||||
```shell
|
||||
pip install git+https://github.com/hiddenswitch/ComfyUI.git
|
||||
|
||||
@ -34,8 +34,7 @@ class ControlNet(nn.Module):
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
use_bf16=False,
|
||||
dtype=torch.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
@ -108,8 +107,7 @@ class ControlNet(nn.Module):
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
||||
self.dtype = dtype
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
@ -52,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
|
||||
@ -4,6 +4,8 @@ import asyncio
|
||||
import copy
|
||||
import datetime
|
||||
import heapq
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@ -16,7 +18,7 @@ import torch
|
||||
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
nodes = import_all_nodes_in_workspace()
|
||||
from .. import model_management
|
||||
from .. import model_management # type: ignore
|
||||
|
||||
"""
|
||||
A queued item
|
||||
@ -209,7 +211,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
|
||||
server.client_id)
|
||||
except model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
@ -230,8 +232,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
logging.error("!!! Exception during processing !!!")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
@ -443,7 +445,7 @@ class PromptExecutor:
|
||||
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
|
||||
def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
|
||||
# todo: this should check if LoadImage / LoadImageMask paths exist
|
||||
# todo: or, nodes should provide a way to validate their values
|
||||
unique_id = item
|
||||
@ -511,8 +513,8 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
r2 = validate_inputs(prompt, o_id, validated)
|
||||
if r2[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
continue
|
||||
@ -593,11 +595,11 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
# ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
for i, r3 in enumerate(ret):
|
||||
if r3 is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
if r3 is not False:
|
||||
details += f" - {str(r3)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
@ -698,11 +700,11 @@ def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict]
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print(f"Failed to validate prompt for output {o}:")
|
||||
logging.error(f"Failed to validate prompt for output {o}:")
|
||||
if len(reasons) > 0:
|
||||
print("* (prompt):")
|
||||
logging.error("* (prompt):")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
logging.error(f" - {reason['message']}: {reason['details']}")
|
||||
errors += [(o, reasons)]
|
||||
for node_id, result in validated.items():
|
||||
valid = result[0]
|
||||
@ -718,16 +720,16 @@ def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict]
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type
|
||||
}
|
||||
print(f"* {class_type} {node_id}:")
|
||||
logging.error(f"* {class_type} {node_id}:")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
logging.error(f" - {reason['message']}: {reason['details']}")
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print("Output will be ignored")
|
||||
logging.error("Output will be ignored")
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = []
|
||||
for o, errors in errors:
|
||||
for error in errors:
|
||||
for o, _errors in errors:
|
||||
for error in _errors:
|
||||
errors_list.append(f"{error['message']}: {error['details']}")
|
||||
errors_list = "\n".join(errors_list)
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
|
||||
|
||||
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
@ -50,6 +52,10 @@ def set_temp_directory(temp_dir):
|
||||
global temp_directory
|
||||
temp_directory = temp_dir
|
||||
|
||||
def set_input_directory(input_dir):
|
||||
global input_directory
|
||||
input_directory = input_dir
|
||||
|
||||
def get_output_directory():
|
||||
global output_directory
|
||||
return output_directory
|
||||
@ -144,7 +150,7 @@ def recursive_search(directory, excluded_dir_names=None):
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
|
||||
|
||||
|
||||
|
||||
|
||||
@ -55,7 +55,12 @@ def get_previewer(device, latent_format):
|
||||
# TODO previewer methods
|
||||
taesd_decoder_path = None
|
||||
if latent_format.taesd_decoder_name is not None:
|
||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
|
||||
taesd_decoder_path = next(
|
||||
(fn for fn in folder_paths.get_filename_list("vae_approx")
|
||||
if fn.startswith(latent_format.taesd_decoder_name)),
|
||||
""
|
||||
)
|
||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
|
||||
|
||||
if method == LatentPreviewMethod.Auto:
|
||||
method = LatentPreviewMethod.Latent2RGB
|
||||
|
||||
@ -178,6 +178,16 @@ def main():
|
||||
print(f"Setting output directory to: {output_dir}")
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
|
||||
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
|
||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||
|
||||
if args.input_directory:
|
||||
input_dir = os.path.abspath(args.input_directory)
|
||||
print(f"Setting input directory to: {input_dir}")
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import struct
|
||||
import sys
|
||||
import shutil
|
||||
from urllib.parse import quote
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
@ -105,7 +106,6 @@ class PromptServer():
|
||||
self.sockets = dict()
|
||||
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../web")
|
||||
if not os.path.exists(web_root_path):
|
||||
from pkg_resources import resource_filename
|
||||
web_root_path = resource_filename('comfy', 'web/')
|
||||
self.web_root = web_root_path
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
use_fp16 = comfy.model_management.should_use_fp16()
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
use_fp16 = comfy.model_management.should_use_fp16()
|
||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
|
||||
unet_dtype = comfy.model_management.unet_dtype()
|
||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
control_model = control_model.to(unet_dtype)
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
@ -456,7 +455,7 @@ def load_t2i_adapter(t2i_data):
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
|
||||
prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
|
||||
prefix_replace["adapter."] = ""
|
||||
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
|
||||
keys = t2i_data.keys()
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
from ..cmd.execution import PromptQueue
|
||||
from ..cmd.server import PromptServer
|
||||
|
||||
|
||||
class Comfy:
|
||||
loop: AbstractEventLoop
|
||||
server: PromptServer
|
||||
queue: PromptQueue
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -20,7 +20,7 @@ class SD15(LatentFormat):
|
||||
[-0.2829, 0.1762, 0.2721],
|
||||
[-0.2120, -0.2616, -0.7177]
|
||||
]
|
||||
self.taesd_decoder_name = "taesd_decoder.pth"
|
||||
self.taesd_decoder_name = "taesd_decoder"
|
||||
|
||||
class SDXL(LatentFormat):
|
||||
def __init__(self):
|
||||
@ -32,4 +32,4 @@ class SDXL(LatentFormat):
|
||||
[ 0.0568, 0.1687, -0.0755],
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder.pth"
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
@ -93,253 +93,222 @@ def zero_module(module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None):
|
||||
h = heads
|
||||
scale = (q.shape[-1] // heads) ** -0.5
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
del q, k
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttentionBirchSan(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
scale = (query.shape[-1] // heads) ** -0.5
|
||||
query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
|
||||
del key
|
||||
value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
dtype = query.dtype
|
||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
||||
if upcast_attention:
|
||||
bytes_per_token = torch.finfo(torch.float32).bits//8
|
||||
else:
|
||||
bytes_per_token = torch.finfo(query.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = query.shape
|
||||
_, _, k_tokens = key_t.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
kv_chunk_size_min = None
|
||||
|
||||
query = self.to_q(x)
|
||||
context = default(context, x)
|
||||
key = self.to_k(context)
|
||||
if value is not None:
|
||||
value = self.to_v(value)
|
||||
else:
|
||||
value = self.to_v(context)
|
||||
#not sure at all about the math here
|
||||
#TODO: tweak this
|
||||
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 4
|
||||
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 2
|
||||
else:
|
||||
query_chunk_size_x = 1024
|
||||
kv_chunk_size_min_x = None
|
||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
|
||||
if kv_chunk_size_x < 1024:
|
||||
kv_chunk_size_x = None
|
||||
|
||||
del context, x
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
else:
|
||||
query_chunk_size = query_chunk_size_x
|
||||
kv_chunk_size = kv_chunk_size_x
|
||||
kv_chunk_size_min = kv_chunk_size_min_x
|
||||
|
||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
|
||||
del key
|
||||
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key_t,
|
||||
value,
|
||||
query_chunk_size=query_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min=kv_chunk_size_min,
|
||||
use_checkpoint=False,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
dtype = query.dtype
|
||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
||||
if upcast_attention:
|
||||
bytes_per_token = torch.finfo(torch.float32).bits//8
|
||||
else:
|
||||
bytes_per_token = torch.finfo(query.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = query.shape
|
||||
_, _, k_tokens = key_t.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||
def attention_split(q, k, v, heads, mask=None):
|
||||
scale = (q.shape[-1] // heads) ** -0.5
|
||||
h = heads
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
kv_chunk_size_min = None
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
#not sure at all about the math here
|
||||
#TODO: tweak this
|
||||
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 4
|
||||
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 2
|
||||
else:
|
||||
query_chunk_size_x = 1024
|
||||
kv_chunk_size_min_x = None
|
||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
|
||||
if kv_chunk_size_x < 1024:
|
||||
kv_chunk_size_x = None
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
else:
|
||||
query_chunk_size = query_chunk_size_x
|
||||
kv_chunk_size = kv_chunk_size_x
|
||||
kv_chunk_size_min = kv_chunk_size_min_x
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
key_t,
|
||||
value,
|
||||
query_chunk_size=query_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min=kv_chunk_size_min,
|
||||
use_checkpoint=self.training,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
hidden_states = out_proj(hidden_states)
|
||||
hidden_states = dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
|
||||
class CrossAttentionDoggettx(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
if value is not None:
|
||||
v_in = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
if first_op_done == False:
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
cleared_cache = True
|
||||
print("out of memory error, emptying cache and trying again")
|
||||
continue
|
||||
steps *= 2
|
||||
if steps > 64:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||
first_op_done = False
|
||||
cleared_cache = False
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
if first_op_done == False:
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
cleared_cache = True
|
||||
print("out of memory error, emptying cache and trying again")
|
||||
continue
|
||||
steps *= 2
|
||||
if steps > 64:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
else:
|
||||
raise e
|
||||
|
||||
del q, k, v
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
return r2
|
||||
|
||||
return self.to_out(r2)
|
||||
def attention_xformers(q, k, v, heads, mask=None):
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], heads, -1)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, t.shape[1], -1)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, out.shape[1], -1)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], -1)
|
||||
)
|
||||
return out
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
optimized_attention_masked = attention_basic
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
print("Using pytorch cross attention")
|
||||
optimized_attention = attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
print("Using split optimization for cross attention")
|
||||
optimized_attention = attention_split
|
||||
else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
if model_management.pytorch_attention_enabled():
|
||||
optimized_attention_masked = attention_pytorch
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
@ -347,62 +316,6 @@ class CrossAttention(nn.Module):
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
@ -411,7 +324,6 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
@ -423,85 +335,12 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossAttentionPytorch(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
|
||||
)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
||||
return self.to_out(out)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
print("Using xformers cross attention")
|
||||
CrossAttention = MemoryEfficientCrossAttention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
print("Using pytorch cross attention")
|
||||
CrossAttention = CrossAttentionPytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
print("Using split optimization for cross attention")
|
||||
CrossAttention = CrossAttentionDoggettx
|
||||
else:
|
||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
CrossAttention = CrossAttentionBirchSan
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||
|
||||
@ -6,7 +6,6 @@ import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
from ..attention import MemoryEfficientCrossAttention
|
||||
from .... import model_management
|
||||
from .... import ops
|
||||
|
||||
@ -194,6 +193,52 @@ def slice_attention(q, k, v):
|
||||
|
||||
return r1
|
||||
|
||||
def normal_attention(q, k, v):
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
del r1
|
||||
return h_
|
||||
|
||||
def xformers_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
def pytorch_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@ -221,6 +266,16 @@ class AttnBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
print("Using xformers attention in VAE")
|
||||
self.optimized_attention = xformers_attention
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
print("Using pytorch attention in VAE")
|
||||
self.optimized_attention = pytorch_attention
|
||||
else:
|
||||
print("Using split attention in VAE")
|
||||
self.optimized_attention = normal_attention
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@ -228,161 +283,15 @@ class AttnBlock(nn.Module):
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
h_ = self.optimized_attention(q, k, v)
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
del r1
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
class MemoryEfficientAttnBlockPytorch(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
def forward(self, x, context=None, mask=None):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
out = super().forward(x, context=context, mask=mask)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
|
||||
return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
||||
attn_type = "vanilla-xformers"
|
||||
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
|
||||
attn_type = "vanilla-pytorch"
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-pytorch":
|
||||
return MemoryEfficientAttnBlockPytorch(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return AttnBlock(in_channels)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
@ -296,8 +296,7 @@ class UNetModel(nn.Module):
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
use_bf16=False,
|
||||
dtype=th.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
@ -370,8 +369,7 @@ class UNetModel(nn.Module):
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
||||
self.dtype = dtype
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
unet_config = {
|
||||
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["use_fp16"] = use_fp16
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
|
||||
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
return unet_config
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
@ -154,14 +154,18 @@ def is_nvidia():
|
||||
return True
|
||||
return False
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
||||
ENABLE_PYTORCH_ATTENTION = False
|
||||
if args.use_pytorch_cross_attention:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch.cuda.is_bf16_supported():
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
if args.lowvram:
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0):
|
||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||
models_already_loaded.append(loaded_model)
|
||||
else:
|
||||
if hasattr(x, "model"):
|
||||
print(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
|
||||
print("loading new")
|
||||
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
@ -405,7 +410,6 @@ def load_model_gpu(model):
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
print(sys.getrefcount(current_loaded_models[i].model))
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
@ -444,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
@ -656,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||
return False
|
||||
|
||||
#FP16 is just broken on these cards
|
||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
|
||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
|
||||
for x in nvidia_16_series:
|
||||
if x in props.name:
|
||||
return False
|
||||
|
||||
@ -107,6 +107,10 @@ class ModelPatcher:
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
if "unet_wrapper_function" in self.model_options:
|
||||
wrap_func = self.model_options["unet_wrapper_function"]
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
|
||||
@ -1185,7 +1185,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = False
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
|
||||
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
@ -1645,7 +1645,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"KSampler": "KSampler",
|
||||
"KSamplerAdvanced": "KSampler (Advanced)",
|
||||
# Loaders
|
||||
"CheckpointLoader": "Load Checkpoint (With Config)",
|
||||
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||
"VAELoader": "Load VAE",
|
||||
"LoraLoader": "Load LoRA",
|
||||
|
||||
@ -29,7 +29,7 @@ def prepare_mask(noise_mask, shape, device):
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
|
||||
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = utils.repeat_to_batch_size(noise_mask, shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
return noise_mask
|
||||
|
||||
|
||||
20
comfy/sd.py
20
comfy/sd.py
@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
if "params" in model_config_params["unet_config"]:
|
||||
unet_config = model_config_params["unet_config"]["params"]
|
||||
if "use_fp16" in unet_config:
|
||||
fp16 = unet_config["use_fp16"]
|
||||
fp16 = unet_config.pop("use_fp16")
|
||||
if fp16:
|
||||
unet_config["dtype"] = torch.float16
|
||||
|
||||
noise_aug_config = None
|
||||
if "noise_aug_config" in model_config_params:
|
||||
@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clip_target = None
|
||||
|
||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
|
||||
@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_clipvision:
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
|
||||
dtype = torch.float32
|
||||
if fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", fp16)
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
if model_config is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
return None
|
||||
|
||||
@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
@ -50,9 +52,9 @@ class TAESD(nn.Module):
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
||||
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
||||
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
||||
@ -410,6 +410,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
||||
output[b:b+1] = out/out_div
|
||||
return output
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
PROGRESS_BAR_ENABLED = enabled
|
||||
|
||||
PROGRESS_BAR_HOOK = None
|
||||
def set_progress_bar_global_hook(function):
|
||||
|
||||
@ -158,7 +158,7 @@ class SplitImageWithAlpha:
|
||||
def split_image_with_alpha(self, image: torch.Tensor):
|
||||
out_images = [i[:,:,:3] for i in image]
|
||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||
result = (torch.stack(out_images), torch.stack(out_alphas))
|
||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
return result
|
||||
|
||||
|
||||
@ -180,7 +180,7 @@ class JoinImageWithAlpha:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
alpha = resize_mask(alpha, image.shape[1:])
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
import latent_preview
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class BasicScheduler:
|
||||
@ -15,7 +16,7 @@ class BasicScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -35,7 +36,7 @@ class KarrasScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -53,7 +54,7 @@ class ExponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -72,7 +73,7 @@ class PolyexponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -91,7 +92,7 @@ class VPScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -108,7 +109,7 @@ class SplitSigmas:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -125,7 +126,7 @@ class KSamplerSelect:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -144,7 +145,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -168,7 +169,7 @@ class SamplerDPMPP_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -201,7 +202,7 @@ class SamplerCustom:
|
||||
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "_for_testing/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
|
||||
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
|
||||
latent = latent_image
|
||||
@ -219,7 +220,7 @@ class SamplerCustom:
|
||||
x0_output = {}
|
||||
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
||||
|
||||
disable_pbar = False
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||
|
||||
out = latent.copy()
|
||||
|
||||
@ -241,8 +241,8 @@ class MaskComposite:
|
||||
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:visible_height, :visible_width]
|
||||
destination_portion = destination[top:bottom, left:right]
|
||||
source_portion = source[:, :visible_height, :visible_width]
|
||||
destination_portion = destination[:, top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[:, top:bottom, left:right] = destination_portion * source_portion
|
||||
@ -283,10 +283,10 @@ class FeatherMask:
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||
|
||||
left = min(left, output.shape[1])
|
||||
right = min(right, output.shape[1])
|
||||
top = min(top, output.shape[0])
|
||||
bottom = min(bottom, output.shape[0])
|
||||
left = min(left, output.shape[-1])
|
||||
right = min(right, output.shape[-1])
|
||||
top = min(top, output.shape[-2])
|
||||
bottom = min(bottom, output.shape[-2])
|
||||
|
||||
for x in range(left):
|
||||
feather_rate = (x + 1.0) / left
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from comfy import sd
|
||||
from comfy import model_base
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
import json
|
||||
@ -177,6 +178,95 @@ class CheckpointSave:
|
||||
sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
|
||||
return {}
|
||||
|
||||
class CLIPSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip": ("CLIP",),
|
||||
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
metadata["prompt"] = prompt_info
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comfy.model_management.load_models_gpu([clip.load_model()])
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", ""]:
|
||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||
current_clip_sd = {}
|
||||
for x in k:
|
||||
current_clip_sd[x] = clip_sd.pop(x)
|
||||
if len(current_clip_sd) == 0:
|
||||
continue
|
||||
|
||||
p = prefix[:-1]
|
||||
replace_prefix = {}
|
||||
filename_prefix_ = filename_prefix
|
||||
if len(p) > 0:
|
||||
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
|
||||
replace_prefix[prefix] = ""
|
||||
replace_prefix["transformer."] = ""
|
||||
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
|
||||
|
||||
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
class VAESave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae": ("VAE",),
|
||||
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
metadata["prompt"] = prompt_info
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
|
||||
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSimple": ModelMergeSimple,
|
||||
@ -185,4 +275,6 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeAdd": ModelAdd,
|
||||
"CheckpointSave": CheckpointSave,
|
||||
"CLIPMergeSimple": CLIPMergeSimple,
|
||||
"CLIPSave": CLIPSave,
|
||||
"VAESave": VAESave,
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#Rename this to extra_model_paths.yaml and ComfyUI will load it
|
||||
|
||||
|
||||
#config for a1111 ui
|
||||
#all you have to do is change the base_path to where yours is installed
|
||||
a111:
|
||||
@ -19,6 +20,21 @@ a111:
|
||||
hypernetworks: models/hypernetworks
|
||||
controlnet: models/ControlNet
|
||||
|
||||
#config for comfyui
|
||||
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
|
||||
|
||||
#comfyui:
|
||||
# base_path: path/to/comfyui/
|
||||
# checkpoints: models/checkpoints/
|
||||
# clip: models/clip/
|
||||
# clip_vision: models/clip_vision/
|
||||
# configs: models/configs/
|
||||
# controlnet: models/controlnet/
|
||||
# embeddings: models/embeddings/
|
||||
# loras: models/loras/
|
||||
# upscale_models: models/upscale_models/
|
||||
# vae: models/vae/
|
||||
|
||||
#other_ui:
|
||||
# base_path: path/to/ui
|
||||
# checkpoints: models/checkpoints
|
||||
|
||||
4
mypy.ini
Normal file
4
mypy.ini
Normal file
@ -0,0 +1,4 @@
|
||||
[mypy]
|
||||
files = comfy/, comfy_extras/
|
||||
ignore_missing_imports = True
|
||||
strict_optional = True
|
||||
@ -26,4 +26,5 @@ Pillow
|
||||
scipy
|
||||
tqdm
|
||||
protobuf==3.20.3
|
||||
psutil
|
||||
psutil
|
||||
mypy>=1.6.0
|
||||
5
setup.py
5
setup.py
@ -4,6 +4,7 @@ import os.path
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from pip._internal.index.collector import LinkCollector
|
||||
from pip._internal.index.package_finder import PackageFinder
|
||||
@ -106,7 +107,7 @@ def _is_linux_arm64():
|
||||
return os_name == 'Linux' and architecture == 'aarch64'
|
||||
|
||||
|
||||
def dependencies() -> [str]:
|
||||
def dependencies() -> List[str]:
|
||||
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
|
||||
# todo: also add all plugin dependencies
|
||||
_alternative_indices = [amd_torch_index, nvidia_torch_index, cpu_torch_index_nightlies]
|
||||
@ -137,7 +138,7 @@ def dependencies() -> [str]:
|
||||
except:
|
||||
try:
|
||||
# pip 22
|
||||
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls)),
|
||||
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls)), # type: ignore
|
||||
SelectionPreferences(allow_yanked=False, prefer_binary=False,
|
||||
allow_all_prereleases=True)
|
||||
, use_deprecated_html5lib=False)
|
||||
|
||||
@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
|
||||
node.graph.change();
|
||||
}
|
||||
|
||||
function addNodesToGroup(group, nodes=[]) {
|
||||
var x1, y1, x2, y2;
|
||||
var nx1, ny1, nx2, ny2;
|
||||
var node;
|
||||
|
||||
x1 = y1 = x2 = y2 = -1;
|
||||
nx1 = ny1 = nx2 = ny2 = -1;
|
||||
|
||||
for (var n of [group._nodes, nodes]) {
|
||||
for (var i in n) {
|
||||
node = n[i]
|
||||
|
||||
nx1 = node.pos[0]
|
||||
ny1 = node.pos[1]
|
||||
nx2 = node.pos[0] + node.size[0]
|
||||
ny2 = node.pos[1] + node.size[1]
|
||||
|
||||
if (node.type != "Reroute") {
|
||||
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
|
||||
}
|
||||
|
||||
if (node.flags?.collapsed) {
|
||||
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
|
||||
|
||||
if (node?._collapsed_width) {
|
||||
nx2 = nx1 + Math.round(node._collapsed_width);
|
||||
}
|
||||
}
|
||||
|
||||
if (x1 == -1 || nx1 < x1) {
|
||||
x1 = nx1;
|
||||
}
|
||||
|
||||
if (y1 == -1 || ny1 < y1) {
|
||||
y1 = ny1;
|
||||
}
|
||||
|
||||
if (x2 == -1 || nx2 > x2) {
|
||||
x2 = nx2;
|
||||
}
|
||||
|
||||
if (y2 == -1 || ny2 > y2) {
|
||||
y2 = ny2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var padding = 10;
|
||||
|
||||
y1 = y1 - Math.round(group.font_size * 1.4);
|
||||
|
||||
group.pos = [x1 - padding, y1 - padding];
|
||||
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.GroupOptions",
|
||||
setup() {
|
||||
@ -14,6 +69,17 @@ app.registerExtension({
|
||||
const options = orig.apply(this, arguments);
|
||||
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
|
||||
if (!group) {
|
||||
options.push({
|
||||
content: "Add Group For Selected Nodes",
|
||||
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||
callback: () => {
|
||||
var group = new LiteGraph.LGraphGroup();
|
||||
addNodesToGroup(group, this.selected_nodes)
|
||||
app.canvas.graph.add(group);
|
||||
this.graph.change();
|
||||
}
|
||||
});
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
@ -21,6 +87,15 @@ app.registerExtension({
|
||||
group.recomputeInsideNodes();
|
||||
const nodesInGroup = group._nodes;
|
||||
|
||||
options.push({
|
||||
content: "Add Selected Nodes To Group",
|
||||
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||
callback: () => {
|
||||
addNodesToGroup(group, this.selected_nodes)
|
||||
this.graph.change();
|
||||
}
|
||||
});
|
||||
|
||||
// No nodes in group, return default options
|
||||
if (nodesInGroup.length === 0) {
|
||||
return options;
|
||||
@ -38,6 +113,23 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
options.push({
|
||||
content: "Fit Group To Nodes",
|
||||
callback: () => {
|
||||
addNodesToGroup(group)
|
||||
this.graph.change();
|
||||
}
|
||||
});
|
||||
|
||||
options.push({
|
||||
content: "Select Nodes",
|
||||
callback: () => {
|
||||
this.selectNodes(nodesInGroup);
|
||||
this.graph.change();
|
||||
this.canvas.focus();
|
||||
}
|
||||
});
|
||||
|
||||
// Modes
|
||||
// 0: Always
|
||||
// 1: On Event
|
||||
|
||||
@ -200,6 +200,10 @@ app.registerExtension({
|
||||
for (const input of this.inputs) {
|
||||
if (input.widget && !input.widget[GET_CONFIG]) {
|
||||
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
|
||||
const w = this.widgets.find((w) => w.name === input.widget.name);
|
||||
if (w) {
|
||||
hideWidget(this, w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3796,7 +3796,7 @@
|
||||
out = out || new Float32Array(4);
|
||||
out[0] = this.pos[0] - 4;
|
||||
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
|
||||
out[2] = this.size[0] + 4;
|
||||
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
|
||||
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
|
||||
|
||||
if (this.onBounding) {
|
||||
|
||||
@ -450,6 +450,47 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
function calculateGrid(w, h, n) {
|
||||
let columns, rows, cellsize;
|
||||
|
||||
if (w > h) {
|
||||
cellsize = h;
|
||||
columns = Math.ceil(w / cellsize);
|
||||
rows = Math.ceil(n / columns);
|
||||
} else {
|
||||
cellsize = w;
|
||||
rows = Math.ceil(h / cellsize);
|
||||
columns = Math.ceil(n / rows);
|
||||
}
|
||||
|
||||
while (columns * rows < n) {
|
||||
cellsize++;
|
||||
if (w >= h) {
|
||||
columns = Math.ceil(w / cellsize);
|
||||
rows = Math.ceil(n / columns);
|
||||
} else {
|
||||
rows = Math.ceil(h / cellsize);
|
||||
columns = Math.ceil(n / rows);
|
||||
}
|
||||
}
|
||||
|
||||
const cell_size = Math.min(w/columns, h/rows);
|
||||
return {cell_size, columns, rows};
|
||||
}
|
||||
|
||||
function is_all_same_aspect_ratio(imgs) {
|
||||
// assume: imgs.length >= 2
|
||||
let ratio = imgs[0].naturalWidth/imgs[0].naturalHeight;
|
||||
|
||||
for(let i=1; i<imgs.length; i++) {
|
||||
let this_ratio = imgs[i].naturalWidth/imgs[i].naturalHeight;
|
||||
if(ratio != this_ratio)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this.imgs && this.imgs.length) {
|
||||
const canvas = graph.list_of_graphcanvas[0];
|
||||
const mouse = canvas.graph_mouse;
|
||||
@ -460,44 +501,60 @@ export class ComfyApp {
|
||||
this.pointerDown = null;
|
||||
}
|
||||
|
||||
let w = this.imgs[0].naturalWidth;
|
||||
let h = this.imgs[0].naturalHeight;
|
||||
let imageIndex = this.imageIndex;
|
||||
const numImages = this.imgs.length;
|
||||
if (numImages === 1 && !imageIndex) {
|
||||
this.imageIndex = imageIndex = 0;
|
||||
}
|
||||
|
||||
const shiftY = getImageTop(this);
|
||||
const top = getImageTop(this);
|
||||
var shiftY = top;
|
||||
|
||||
let dw = this.size[0];
|
||||
let dh = this.size[1];
|
||||
dh -= shiftY;
|
||||
|
||||
if (imageIndex == null) {
|
||||
let best = 0;
|
||||
let cellWidth;
|
||||
let cellHeight;
|
||||
let cols = 0;
|
||||
let shiftX = 0;
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const rows = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / rows;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
var cellWidth, cellHeight, shiftX, cell_padding, cols;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
const compact_mode = is_all_same_aspect_ratio(this.imgs);
|
||||
if(!compact_mode) {
|
||||
// use rectangle cell style and border line
|
||||
cell_padding = 2;
|
||||
const { cell_size, columns, rows } = calculateGrid(dw, dh, numImages);
|
||||
cols = columns;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
cellWidth = cell_size;
|
||||
cellHeight = cell_size;
|
||||
shiftX = (dw-cell_size*cols)/2;
|
||||
shiftY = (dh-cell_size*rows)/2 + top;
|
||||
}
|
||||
else {
|
||||
cell_padding = 0;
|
||||
let best = 0;
|
||||
let w = this.imgs[0].naturalWidth;
|
||||
let h = this.imgs[0].naturalHeight;
|
||||
|
||||
// compact style
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const rows = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / rows;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -542,7 +599,14 @@ export class ComfyApp {
|
||||
let imgWidth = ratio * img.width;
|
||||
let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2;
|
||||
|
||||
ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight);
|
||||
ctx.drawImage(img, imgX+cell_padding, imgY+cell_padding, imgWidth-cell_padding*2, imgHeight-cell_padding*2);
|
||||
if(!compact_mode) {
|
||||
// rectangle cell and border line style
|
||||
ctx.strokeStyle = "#8F8F8F";
|
||||
ctx.lineWidth = 1;
|
||||
ctx.strokeRect(x+cell_padding, y+cell_padding, cellWidth-cell_padding*2, cellHeight-cell_padding*2);
|
||||
}
|
||||
|
||||
ctx.filter = "none";
|
||||
}
|
||||
|
||||
@ -552,6 +616,9 @@ export class ComfyApp {
|
||||
}
|
||||
} else {
|
||||
// Draw individual
|
||||
let w = this.imgs[imageIndex].naturalWidth;
|
||||
let h = this.imgs[imageIndex].naturalHeight;
|
||||
|
||||
const scaleX = dw / w;
|
||||
const scaleY = dh / h;
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
@ -594,14 +661,14 @@ export class ComfyApp {
|
||||
};
|
||||
|
||||
if (numImages > 1) {
|
||||
if (drawButton(x + w - 35, y + h - 35, 30, `${this.imageIndex + 1}/${numImages}`)) {
|
||||
if (drawButton(dw - 40, dh + top - 40, 30, `${this.imageIndex + 1}/${numImages}`)) {
|
||||
let i = this.imageIndex + 1 >= numImages ? 0 : this.imageIndex + 1;
|
||||
if (!this.pointerDown || !this.pointerDown.index === i) {
|
||||
this.pointerDown = { index: i, pos: [...mouse] };
|
||||
}
|
||||
}
|
||||
|
||||
if (drawButton(x + w - 35, y + 5, 30, `x`)) {
|
||||
if (drawButton(dw - 40, top + 10, 30, `x`)) {
|
||||
if (!this.pointerDown || !this.pointerDown.index === null) {
|
||||
this.pointerDown = { index: null, pos: [...mouse] };
|
||||
}
|
||||
@ -861,6 +928,16 @@ export class ComfyApp {
|
||||
block_default = true;
|
||||
}
|
||||
|
||||
// Alt + C collapse/uncollapse
|
||||
if (e.key === 'c' && e.altKey) {
|
||||
if (this.selected_nodes) {
|
||||
for (var i in this.selected_nodes) {
|
||||
this.selected_nodes[i].collapse()
|
||||
}
|
||||
}
|
||||
block_default = true;
|
||||
}
|
||||
|
||||
// Ctrl+C Copy
|
||||
if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
|
||||
// Trigger onCopy
|
||||
@ -1525,7 +1602,7 @@ export class ComfyApp {
|
||||
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
|
||||
for (let parent_input in all_inputs) {
|
||||
parent_input = all_inputs[parent_input];
|
||||
if (parent.inputs[parent_input].type === node.inputs[i].type) {
|
||||
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
|
||||
link = parent.getInputLink(parent_input);
|
||||
if (link) {
|
||||
parent = parent.getInputNode(parent_input);
|
||||
|
||||
@ -809,7 +809,8 @@ export class ComfyUI {
|
||||
if (
|
||||
this.lastQueueSize != 0 &&
|
||||
status.exec_info.queue_remaining == 0 &&
|
||||
document.getElementById("autoQueueCheckbox").checked
|
||||
document.getElementById("autoQueueCheckbox").checked &&
|
||||
! app.lastExecutionError
|
||||
) {
|
||||
app.queuePrompt(0, this.batchCount);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user