diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index c7ef93ce1..319942e7c 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -31,7 +31,7 @@ jobs:
echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
- python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+ python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index 46fbf0a69..251483131 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -6,8 +6,6 @@ import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
- conv_nd,
- linear,
zero_module,
timestep_embedding,
)
@@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
-
+import comfy.ops
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
@@ -55,6 +53,8 @@ class ControlNet(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
+ device=None,
+ operations=comfy.ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -117,9 +117,9 @@ class ControlNet(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
if self.num_classes is not None:
@@ -132,9 +132,9 @@ class ControlNet(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
- linear(adm_in_channels, time_embed_dim),
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
)
else:
@@ -143,28 +143,28 @@ class ControlNet(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
]
)
- self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.input_hint_block = TimestepEmbedSequential(
- conv_nd(dims, hint_channels, 16, 3, padding=1),
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(),
- conv_nd(dims, 16, 16, 3, padding=1),
+ operations.conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
- conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
- conv_nd(dims, 32, 32, 3, padding=1),
+ operations.conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
- conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
- conv_nd(dims, 96, 96, 3, padding=1),
+ operations.conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
- conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
- zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
)
self._feature_size = model_channels
@@ -182,6 +182,7 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ operations=operations
)
]
ch = mult * model_channels
@@ -204,11 +205,11 @@ class ControlNet(nn.Module):
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
+ use_checkpoint=use_checkpoint, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
- self.zero_convs.append(self.make_zero_conv(ch))
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@@ -224,16 +225,17 @@ class ControlNet(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
+ operations=operations
)
if resblock_updown
else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
+ ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
)
)
)
ch = out_ch
input_block_chans.append(ch)
- self.zero_convs.append(self.make_zero_conv(ch))
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
ds *= 2
self._feature_size += ch
@@ -253,11 +255,12 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ operations=operations
),
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
+ use_checkpoint=use_checkpoint, operations=operations
),
ResBlock(
ch,
@@ -266,16 +269,17 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+ operations=operations
),
)
- self.middle_block_out = self.make_zero_conv(ch)
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch
- def make_zero_conv(self, channels):
- return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+ def make_zero_conv(self, channels, operations=None):
+ return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
@@ -283,9 +287,6 @@ class ControlNet(nn.Module):
outs = []
hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index ec7d34a55..b4f22f319 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
Auto = "auto"
@@ -82,6 +84,9 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+
+
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 8d04faf71..a887e51b5 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -25,6 +25,7 @@ class ClipVisionModel():
def encode_image(self, image):
img = torch.clip((255. * image), 0, 255).round().int()
+ img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
outputs = self.model(**inputs)
return outputs
@@ -49,18 +50,22 @@ def convert_to_transformers(sd, prefix):
if "{}proj".format(prefix) in sd_k:
sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
- sd = transformers_convert(sd, prefix, "vision_model.", 32)
+ sd = transformers_convert(sd, prefix, "vision_model.", 48)
return sd
def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys:
sd = convert_to_transformers(sd, prefix)
- if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+ if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+ elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
+ if len(m) > 0:
+ print("missing clip vision:", m)
u = set(u)
keys = list(sd.keys())
for k in keys:
@@ -71,4 +76,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
- return load_clipvision_from_sd(sd)
+ if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+ return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+ else:
+ return load_clipvision_from_sd(sd)
diff --git a/comfy/clip_vision_config_g.json b/comfy/clip_vision_config_g.json
new file mode 100644
index 000000000..708e7e21a
--- /dev/null
+++ b/comfy/clip_vision_config_g.json
@@ -0,0 +1,18 @@
+{
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "gelu",
+ "hidden_size": 1664,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 48,
+ "patch_size": 14,
+ "projection_dim": 1280,
+ "torch_dtype": "float32"
+}
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 88a30d98d..7c9d7d07e 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -409,6 +409,7 @@ class PromptExecutor:
d = self.outputs_ui.pop(x)
del d
+ comfy.model_management.cleanup_models()
if self.server.client_id is not None:
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
self.server.client_id)
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index da5edab30..761b6bee7 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -4,6 +4,7 @@ import glob
import struct
import sys
import shutil
+from urllib.parse import quote
from PIL import Image, ImageOps
from io import BytesIO
@@ -80,6 +81,8 @@ class PromptServer():
mimetypes.init()
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
+
+ self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
@@ -140,9 +143,17 @@ class PromptServer():
@routes.get("/extensions")
async def get_extensions(request):
- files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
- return web.json_response(
- list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
+ files = glob.glob(os.path.join(
+ self.web_root, 'extensions/**/*.js'), recursive=True)
+
+ extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
+
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
+ files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True)
+ extensions.extend(list(map(lambda f: "/extensions/" + quote(
+ name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
+
+ return web.json_response(extensions)
def get_dir_by_type(dir_type=None):
type_dir = ""
@@ -638,6 +649,12 @@ class PromptServer():
def add_routes(self):
self.app.add_routes(self.routes)
+
+ for name, dir in nodes.EXTENSION_WEB_DIRS.items():
+ self.app.add_routes([
+ web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True),
+ ])
+
self.app.add_routes([
web.static('/', self.web_root, follow_symlinks=True),
])
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 90558785b..8d182839e 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -244,30 +244,15 @@ class Gligen(nn.Module):
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
- self.lowvram = False
+ self.current_device = torch.device("cpu")
def _set_position(self, boxes, masks, positive_embeddings):
- if self.lowvram == True:
- self.position_net.to(boxes.device)
-
objs = self.position_net(boxes, masks, positive_embeddings)
-
- if self.lowvram == True:
- self.position_net.cpu()
- def func_lowvram(x, extra_options):
- key = extra_options["transformer_index"]
- module = self.module_list[key]
- module.to(x.device)
- r = module(x, objs)
- module.cpu()
- return r
- return func_lowvram
- else:
- def func(x, extra_options):
- key = extra_options["transformer_index"]
- module = self.module_list[key]
- return module(x, objs)
- return func
+ def func(x, extra_options):
+ key = extra_options["transformer_index"]
+ module = self.module_list[key]
+ return module(x, objs)
+ return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
@@ -312,14 +297,6 @@ class Gligen(nn.Module):
masks.to(device),
conds.to(device))
- def set_lowvram(self, value=True):
- self.lowvram = value
-
- def cleanup(self):
- self.lowvram = False
-
- def get_models(self):
- return [self]
def load_gligen(sd):
sd_k = sd.keys()
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index beaa623f3..eb088d92b 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -649,7 +649,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None
- h_1, h_2 = None, None
+ h, h_1, h_2 = None, None, None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 573cea6ac..973619bf2 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -10,13 +10,14 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
-import comfy.ops
if model_management.xformers_enabled():
import xformers
import xformers.ops
from comfy.cli_args import args
+import comfy.ops
+
# CrossAttn precision handling
if args.dont_upcast_attention:
print("disabling upcasting of attention")
@@ -52,9 +53,9 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out, dtype=None, device=None):
+ def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
super().__init__()
- self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -62,19 +63,19 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
- comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
+ operations.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
+ ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
- comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
+ operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
@@ -148,7 +149,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -156,12 +157,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
- comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+ operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@@ -245,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -253,12 +254,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
- comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+ operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@@ -343,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2)
class CrossAttention(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -351,12 +352,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5
self.heads = heads
- self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
- comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+ operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
@@ -399,7 +400,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
@@ -409,11 +410,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -458,11 +459,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
- self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
@@ -508,14 +509,14 @@ else:
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
- disable_self_attn=False, dtype=None, device=None):
+ disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
- context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
+ context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
- heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none
+ heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
@@ -648,7 +649,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
- use_checkpoint=True, dtype=None, device=None):
+ use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
@@ -656,26 +657,26 @@ class SpatialTransformer(nn.Module):
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device)
if not use_linear:
- self.proj_in = nn.Conv2d(in_channels,
+ self.proj_in = operations.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
- self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+ self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
)
if not use_linear:
- self.proj_out = nn.Conv2d(inner_dim,in_channels,
+ self.proj_out = operations.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
- self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+ self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}):
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 90c153465..11cec0eda 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -8,8 +8,6 @@ import torch.nn.functional as F
from .util import (
checkpoint,
- conv_nd,
- linear,
avg_pool_nd,
zero_module,
normalization,
@@ -17,7 +15,7 @@ from .util import (
)
from ..attention import SpatialTransformer
from comfy.ldm.util import exists
-
+import comfy.ops
class TimestepBlock(nn.Module):
"""
@@ -72,14 +70,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
+ self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
@@ -108,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -116,7 +114,7 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
- self.op = conv_nd(
+ self.op = operations.conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
)
else:
@@ -158,6 +156,7 @@ class ResBlock(TimestepBlock):
down=False,
dtype=None,
device=None,
+ operations=comfy.ops
):
super().__init__()
self.channels = channels
@@ -171,7 +170,7 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+ operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
)
self.updown = up or down
@@ -187,7 +186,7 @@ class ResBlock(TimestepBlock):
self.emb_layers = nn.Sequential(
nn.SiLU(),
- linear(
+ operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
@@ -197,18 +196,18 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+ operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
- self.skip_connection = conv_nd(
+ self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
)
else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
+ self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
def forward(self, x, emb):
"""
@@ -317,6 +316,7 @@ class UNetModel(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
device=None,
+ operations=comfy.ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -379,9 +379,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
if self.num_classes is not None:
@@ -394,9 +394,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
- linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(),
- linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
)
)
else:
@@ -405,7 +405,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
]
)
@@ -426,6 +426,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
+ operations=operations,
)
]
ch = mult * model_channels
@@ -447,7 +448,7 @@ class UNetModel(nn.Module):
layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -468,10 +469,11 @@ class UNetModel(nn.Module):
down=True,
dtype=self.dtype,
device=device,
+ operations=operations
)
if resblock_updown
else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
)
)
)
@@ -498,11 +500,12 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
+ operations=operations
),
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
@@ -513,6 +516,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
+ operations=operations
),
)
self._feature_size += ch
@@ -532,6 +536,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
+ operations=operations
)
]
ch = model_channels * mult
@@ -554,7 +559,7 @@ class UNetModel(nn.Module):
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
if level and i == self.num_res_blocks[level]:
@@ -571,9 +576,10 @@ class UNetModel(nn.Module):
up=True,
dtype=self.dtype,
device=device,
+ operations=operations
)
if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -582,12 +588,12 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
+ zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
- conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
+ operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index ad661ec7d..979e2c65e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -148,13 +148,20 @@ class SDInpaint(BaseModel):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("mask", "masked_image")
+def sdxl_pooled(args, noise_augmentor):
+ if "unclip_conditioning" in args:
+ return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
+ else:
+ return args["pooled_output"]
+
class SDXLRefiner(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
+ self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
def encode_adm(self, **kwargs):
- clip_pooled = kwargs["pooled_output"]
+ clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
@@ -178,9 +185,10 @@ class SDXL(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
+ self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
def encode_adm(self, **kwargs):
- clip_pooled = kwargs["pooled_output"]
+ clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 49ee9ea70..0edc4f180 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -121,9 +121,20 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
return model_config_from_unet_config(unet_config)
-def model_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, use_fp16):
match = {}
- match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+ attention_resolutions = []
+
+ attn_res = 1
+ for i in range(5):
+ k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
+ if k in state_dict:
+ match["context_dim"] = state_dict[k].shape[1]
+ attention_resolutions.append(attn_res)
+ attn_res *= 2
+
+ match["attention_resolutions"] = attention_resolutions
+
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["adm_in_channels"] = None
@@ -135,22 +146,22 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
- 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+ 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
- 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+ 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
- 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+ 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
- 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+ 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
@@ -160,9 +171,20 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
- 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+ 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
- supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+ SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+ 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+ 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
+ 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
+
+ SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+ 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+ 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
+ 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
+
+
+ supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet]
for unet_config in supported_models:
matches = True
@@ -171,5 +193,11 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
matches = False
break
if matches:
- return model_config_from_unet_config(unet_config)
+ return unet_config
+ return None
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+ unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+ if unet_config is not None:
+ return model_config_from_unet_config(unet_config)
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e634f10d2..9b1b7f55d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -2,6 +2,7 @@ import psutil
from enum import Enum
from comfy.cli_args import args
import torch
+import sys
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -87,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
+ stats = torch.xpu.memory_stats(dev)
+ mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.xpu.get_device_properties(dev).total_memory
- mem_total_torch = mem_total
+ mem_total_torch = mem_reserved
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -201,8 +204,13 @@ if cpu_state == CPUState.MPS:
print(f"Set vram state to: {vram_state.name}")
+DISABLE_SMART_MEMORY = args.disable_smart_memory
+
+if DISABLE_SMART_MEMORY:
+ print("Disabling smart memory management")
def get_torch_device_name(device):
+ global xpu_available
if hasattr(device, 'type'):
if device.type == "cuda":
try:
@@ -212,6 +220,8 @@ def get_torch_device_name(device):
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
else:
return "{}".format(device.type)
+ elif xpu_available:
+ return "{} {}".format(device, torch.xpu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -221,132 +231,168 @@ except:
print("Could not pick default device.")
-current_loaded_model = None
-current_gpu_controlnets = []
+current_loaded_models = []
-model_accelerated = False
+class LoadedModel:
+ def __init__(self, model):
+ self.model = model
+ self.model_accelerated = False
+ self.device = model.load_device
+ def model_memory(self):
+ return self.model.model_size()
-def unload_model():
- global current_loaded_model
- global model_accelerated
- global current_gpu_controlnets
- global vram_state
+ def model_memory_required(self, device):
+ if device == self.model.current_device:
+ return 0
+ else:
+ return self.model_memory()
- if current_loaded_model is not None:
- if model_accelerated:
- accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
- model_accelerated = False
+ def model_load(self, lowvram_model_memory=0):
+ global xpu_available
+ patch_model_to = None
+ if lowvram_model_memory == 0:
+ patch_model_to = self.device
- current_loaded_model.unpatch_model()
- current_loaded_model.model.to(current_loaded_model.offload_device)
- current_loaded_model.model_patches_to(current_loaded_model.offload_device)
- current_loaded_model = None
- if vram_state != VRAMState.HIGH_VRAM:
- soft_empty_cache()
+ self.model.model_patches_to(self.device)
+ self.model.model_patches_to(self.model.model_dtype())
- if vram_state != VRAMState.HIGH_VRAM:
- if len(current_gpu_controlnets) > 0:
- for n in current_gpu_controlnets:
- n.cpu()
- current_gpu_controlnets = []
+ try:
+ self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+ except Exception as e:
+ self.model.unpatch_model(self.model.offload_device)
+ self.model_unload()
+ raise e
+
+ if lowvram_model_memory > 0:
+ print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+ device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
+ accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+ self.model_accelerated = True
+
+ if xpu_available and not args.disable_ipex_optimize:
+ self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
+
+ return self.real_model
+
+ def model_unload(self):
+ if self.model_accelerated:
+ accelerate.hooks.remove_hook_from_submodules(self.real_model)
+ self.model_accelerated = False
+
+ self.model.unpatch_model(self.model.offload_device)
+ self.model.model_patches_to(self.model.offload_device)
+
+ def __eq__(self, other):
+ return self.model is other.model
def minimum_inference_memory():
- return (768 * 1024 * 1024)
+ return (1024 * 1024 * 1024)
+
+def unload_model_clones(model):
+ to_unload = []
+ for i in range(len(current_loaded_models)):
+ if model.is_clone(current_loaded_models[i].model):
+ to_unload = [i] + to_unload
+
+ for i in to_unload:
+ print("unload clone", i)
+ current_loaded_models.pop(i).model_unload()
+
+def free_memory(memory_required, device, keep_loaded=[]):
+ unloaded_model = False
+ for i in range(len(current_loaded_models) -1, -1, -1):
+ if DISABLE_SMART_MEMORY:
+ current_free_mem = 0
+ else:
+ current_free_mem = get_free_memory(device)
+ if current_free_mem > memory_required:
+ break
+ shift_model = current_loaded_models[i]
+ if shift_model.device == device:
+ if shift_model not in keep_loaded:
+ current_loaded_models.pop(i).model_unload()
+ unloaded_model = True
+
+ if unloaded_model:
+ soft_empty_cache()
+
+
+def load_models_gpu(models, memory_required=0):
+ global vram_state
+
+ inference_memory = minimum_inference_memory()
+ extra_mem = max(inference_memory, memory_required)
+
+ models_to_load = []
+ models_already_loaded = []
+ for x in models:
+ loaded_model = LoadedModel(x)
+
+ if loaded_model in current_loaded_models:
+ index = current_loaded_models.index(loaded_model)
+ current_loaded_models.insert(0, current_loaded_models.pop(index))
+ models_already_loaded.append(loaded_model)
+ else:
+ models_to_load.append(loaded_model)
+
+ if len(models_to_load) == 0:
+ devs = set(map(lambda a: a.device, models_already_loaded))
+ for d in devs:
+ if d != torch.device("cpu"):
+ free_memory(extra_mem, d, models_already_loaded)
+ return
+
+ print("loading new")
+
+ total_memory_required = {}
+ for loaded_model in models_to_load:
+ unload_model_clones(loaded_model.model)
+ total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+ for device in total_memory_required:
+ if device != torch.device("cpu"):
+ free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+ for loaded_model in models_to_load:
+ model = loaded_model.model
+ torch_dev = model.load_device
+ if is_device_cpu(torch_dev):
+ vram_set_state = VRAMState.DISABLED
+ else:
+ vram_set_state = vram_state
+ lowvram_model_memory = 0
+ if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+ model_size = loaded_model.model_memory_required(torch_dev)
+ current_free_mem = get_free_memory(torch_dev)
+ lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+ if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+ vram_set_state = VRAMState.LOW_VRAM
+ else:
+ lowvram_model_memory = 0
+
+ if vram_set_state == VRAMState.NO_VRAM:
+ lowvram_model_memory = 256 * 1024 * 1024
+
+ cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+ current_loaded_models.insert(0, loaded_model)
+ return
+
def load_model_gpu(model):
- global current_loaded_model
- global vram_state
- global model_accelerated
+ return load_models_gpu([model])
- if model is current_loaded_model:
- return
- unload_model()
+def cleanup_models():
+ to_delete = []
+ for i in range(len(current_loaded_models)):
+ print(sys.getrefcount(current_loaded_models[i].model))
+ if sys.getrefcount(current_loaded_models[i].model) <= 2:
+ to_delete = [i] + to_delete
- torch_dev = model.load_device
- model.model_patches_to(torch_dev)
- model.model_patches_to(model.model_dtype())
- current_loaded_model = model
-
- if is_device_cpu(torch_dev):
- vram_set_state = VRAMState.DISABLED
- else:
- vram_set_state = vram_state
-
- if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
- model_size = model.model_size()
- current_free_mem = get_free_memory(torch_dev)
- lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
- if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
- vram_set_state = VRAMState.LOW_VRAM
-
- real_model = model.model
- patch_model_to = None
- if vram_set_state == VRAMState.DISABLED:
- pass
- elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
- model_accelerated = False
- patch_model_to = torch_dev
-
- try:
- real_model = model.patch_model(device_to=patch_model_to)
- except Exception as e:
- model.unpatch_model()
- unload_model()
- raise e
-
- if patch_model_to is not None:
- real_model.to(torch_dev)
-
- if vram_set_state == VRAMState.NO_VRAM:
- device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
- accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
- model_accelerated = True
- elif vram_set_state == VRAMState.LOW_VRAM:
- device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
- accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
- model_accelerated = True
-
- return current_loaded_model
-
-def load_controlnet_gpu(control_models):
- global current_gpu_controlnets
- global vram_state
- if vram_state == VRAMState.DISABLED:
- return
-
- if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
- for m in control_models:
- if hasattr(m, 'set_lowvram'):
- m.set_lowvram(True)
- #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
- return
-
- models = []
- for m in control_models:
- models += m.get_models()
-
- for m in current_gpu_controlnets:
- if m not in models:
- m.cpu()
-
- device = get_torch_device()
- current_gpu_controlnets = []
- for m in models:
- current_gpu_controlnets.append(m.to(device))
-
-
-def load_if_low_vram(model):
- global vram_state
- if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
- return model.to(get_torch_device())
- return model
-
-def unload_if_low_vram(model):
- global vram_state
- if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
- return model.cpu()
- return model
+ for i in to_delete:
+ x = current_loaded_models.pop(i)
+ x.model_unload()
+ del x
def unet_offload_device():
if vram_state == VRAMState.HIGH_VRAM:
@@ -354,6 +400,28 @@ def unet_offload_device():
else:
return torch.device("cpu")
+def unet_inital_load_device(parameters, dtype):
+ torch_dev = get_torch_device()
+ if vram_state == VRAMState.HIGH_VRAM:
+ return torch_dev
+
+ cpu_dev = torch.device("cpu")
+ if DISABLE_SMART_MEMORY:
+ return cpu_dev
+
+ dtype_size = 4
+ if dtype == torch.float16 or dtype == torch.bfloat16:
+ dtype_size = 2
+
+ model_size = dtype_size * parameters
+
+ mem_dev = get_free_memory(torch_dev)
+ mem_cpu = get_free_memory(cpu_dev)
+ if mem_dev > mem_cpu and model_size < mem_dev:
+ return torch_dev
+ else:
+ return cpu_dev
+
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@@ -441,8 +509,12 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
elif xpu_available:
- mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
- mem_free_torch = mem_free_total
+ stats = torch.xpu.memory_stats(dev)
+ mem_active = stats['active_bytes.all.current']
+ mem_allocated = stats['allocated_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@@ -456,6 +528,13 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
+def batch_area_memory(area):
+ if xformers_enabled() or pytorch_attention_flash_attention():
+ #TODO: these formulas are copied from maximum_batch_area below
+ return (area / 20) * (1024 * 1024)
+ else:
+ return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+
def maximum_batch_area():
global vram_state
if vram_state == VRAMState.NO_VRAM:
@@ -507,9 +586,12 @@ def should_use_fp16(device=None, model_params=0):
if directml_enabled:
return False
- if cpu_mode() or mps_mode() or xpu_available:
+ if cpu_mode() or mps_mode():
return False #TODO ?
+ if xpu_available:
+ return True
+
if torch.cuda.is_bf16_supported():
return True
diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py
index e90f6d163..150497662 100644
--- a/comfy/nodes/package.py
+++ b/comfy/nodes/package.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import importlib
+import os
import pkgutil
import time
import types
@@ -14,6 +15,7 @@ except:
custom_nodes = None
from comfy.nodes.package_typing import ExportedNodes
from functools import reduce
+from pkg_resources import resource_filename
_comfy_nodes = ExportedNodes()
@@ -21,10 +23,19 @@ _comfy_nodes = ExportedNodes()
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
+ web_directory = getattr(module, "WEB_DIRECTORY", None)
if node_class_mappings:
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
if node_display_names:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
+ if web_directory:
+ # load the extension resources path
+ abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory))
+ if not os.path.isdir(abs_web_directory):
+ abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
+ if not os.path.isdir(abs_web_directory):
+ raise ImportError(path=abs_web_directory)
+ exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes:
diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py
index bf83c0a7e..6cb8d212f 100644
--- a/comfy/nodes/package_typing.py
+++ b/comfy/nodes/package_typing.py
@@ -22,22 +22,27 @@ class CustomNode(Protocol):
class ExportedNodes:
NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
+ EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
+ self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS)
return self
def __len__(self):
return len(self.NODE_CLASS_MAPPINGS)
- def __sub__(self, other):
+ def __sub__(self, other: ExportedNodes):
exported_nodes = ExportedNodes().update(self)
for self_key in exported_nodes.NODE_CLASS_MAPPINGS:
if self_key in other.NODE_CLASS_MAPPINGS:
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
+ for self_key in exported_nodes.EXTENSION_WEB_DIRS:
+ if self_key in other.EXTENSION_WEB_DIRS:
+ exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
return exported_nodes
def __add__(self, other):
diff --git a/comfy/ops.py b/comfy/ops.py
index 2e72030bd..678c2c6d0 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,6 +21,11 @@ class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
+def conv_nd(dims, *args, **kwargs):
+ if dims == 2:
+ return Conv2d(*args, **kwargs)
+ else:
+ raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
diff --git a/comfy/sample.py b/comfy/sample.py
index 48530f132..d7292024e 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type):
models += [c[1][model_type]]
return models
-def load_additional_models(positive, negative, dtype):
+def get_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning"""
- control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
+ control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
+
+ control_models = []
+ for m in control_nets:
+ control_models += m.get_models()
+
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
- gligen = [x[1].to(dtype) for x in gligen]
- models = control_nets + gligen
- comfy.model_management.load_controlnet_gpu(models)
+ gligen = [x[1] for x in gligen]
+ models = control_models + gligen
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
- m.cleanup()
+ if hasattr(m, 'cleanup'):
+ m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
device = comfy.model_management.get_torch_device()
@@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None
- comfy.model_management.load_model_gpu(model)
+ models = get_additional_models(positive, negative)
+ comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
real_model = model.model
noise = noise.to(device)
@@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)
- models = load_additional_models(positive, negative, model.model_dtype())
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 28cd46667..134336de6 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
- gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
+ gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
- gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
+ gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
@@ -478,7 +478,7 @@ def pre_run_control(model, conds):
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]:
- x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+ x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
diff --git a/comfy/sd.py b/comfy/sd.py
index bff9ee141..b0482c782 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -243,8 +243,15 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
+def get_attr(obj, attr):
+ attrs = attr.split(".")
+ for name in attrs:
+ obj = getattr(obj, name)
+ return obj
+
+
class ModelPatcher:
- def __init__(self, model, load_device, offload_device, size=0):
+ def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size
self.model = model
self.patches = {}
@@ -253,6 +260,10 @@ class ModelPatcher:
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
+ if current_device is None:
+ self.current_device = self.offload_device
+ else:
+ self.current_device = current_device
def model_size(self):
if self.size > 0:
@@ -267,7 +278,7 @@ class ModelPatcher:
return size
def clone(self):
- n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
+ n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -276,6 +287,11 @@ class ModelPatcher:
n.model_keys = self.model_keys
return n
+ def is_clone(self, other):
+ if hasattr(other, 'model') and self.model is other.model:
+ return True
+ return False
+
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -390,6 +406,11 @@ class ModelPatcher:
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight)
del temp_weight
+
+ if device_to is not None:
+ self.model.to(device_to)
+ self.current_device = device_to
+
return self.model
def calculate_weight(self, patches, weight, key):
@@ -482,7 +503,7 @@ class ModelPatcher:
return weight
- def unpatch_model(self):
+ def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
for k in keys:
@@ -490,6 +511,11 @@ class ModelPatcher:
self.backup = {}
+ if device_to is not None:
+ self.model.to(device_to)
+ self.current_device = device_to
+
+
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = model_lora_keys_unet(model.model)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
@@ -555,7 +581,7 @@ class CLIP:
else:
self.cond_stage_model.reset_clip_layer()
- model_management.load_model_gpu(self.patcher)
+ self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
if return_pooled:
return cond, pooled
@@ -571,11 +597,9 @@ class CLIP:
def get_sd(self):
return self.cond_stage_model.state_dict()
- def patch_model(self):
- self.patcher.patch_model()
-
- def unpatch_model(self):
- self.patcher.unpatch_model()
+ def load_model(self):
+ model_management.load_model_gpu(self.patcher)
+ return self.patcher
def get_key_patches(self):
return self.patcher.get_key_patches()
@@ -630,11 +654,12 @@ class VAE:
return samples
def decode(self, samples_in):
- model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
try:
+ memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
+ model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
- batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64))
+ batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
@@ -650,19 +675,19 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
- model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)
def encode(self, pixel_samples):
- model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
+ memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+ model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
- batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+ batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
@@ -677,7 +702,6 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
- model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@@ -757,6 +781,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
+ self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number):
@@ -780,19 +805,14 @@ class ControlNet(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
- if self.control_model.dtype == torch.float16:
- precision_scope = torch.autocast
- else:
- precision_scope = contextlib.nullcontext
- with precision_scope(model_management.get_autocast_device(self.device)):
- self.control_model = model_management.load_if_low_vram(self.control_model)
- context = torch.cat(cond['c_crossattn'], 1)
- y = cond.get('c_adm', None)
- control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
- self.control_model = model_management.unload_if_low_vram(self.control_model)
+ context = torch.cat(cond['c_crossattn'], 1)
+ y = cond.get('c_adm', None)
+ if y is not None:
+ y = y.to(self.control_model.dtype)
+ control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
+
out = {'middle':[], 'output': []}
- autocast_enabled = torch.is_autocast_enabled()
for i in range(len(control)):
if i == (len(control) - 1):
@@ -806,7 +826,7 @@ class ControlNet(ControlBase):
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
- if x.dtype != output_dtype and not autocast_enabled:
+ if x.dtype != output_dtype:
x = x.to(output_dtype)
if control_prev is not None and key in control_prev:
@@ -825,17 +845,133 @@ class ControlNet(ControlBase):
def get_models(self):
out = super().get_models()
- out.append(self.control_model)
+ out.append(self.control_model_wrapped)
return out
+class ControlLoraOps:
+ class Linear(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
+ device=None, dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = None
+ self.up = None
+ self.down = None
+ self.bias = None
+
+ def forward(self, input):
+ if self.up is not None:
+ return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+ else:
+ return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
+
+ class Conv2d(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode='zeros',
+ device=None,
+ dtype=None
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.transposed = False
+ self.output_padding = 0
+ self.groups = groups
+ self.padding_mode = padding_mode
+
+ self.weight = None
+ self.bias = None
+ self.up = None
+ self.down = None
+
+
+ def forward(self, input):
+ if self.up is not None:
+ return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+ else:
+ return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+ def conv_nd(self, dims, *args, **kwargs):
+ if dims == 2:
+ return self.Conv2d(*args, **kwargs)
+ else:
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class ControlLora(ControlNet):
+ def __init__(self, control_weights, global_average_pooling=False, device=None):
+ ControlBase.__init__(self, device)
+ self.control_weights = control_weights
+ self.global_average_pooling = global_average_pooling
+
+ def pre_run(self, model, percent_to_timestep_function):
+ super().pre_run(model, percent_to_timestep_function)
+ controlnet_config = model.model_config.unet_config.copy()
+ controlnet_config.pop("out_channels")
+ controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
+ controlnet_config["operations"] = ControlLoraOps()
+ self.control_model = cldm.ControlNet(**controlnet_config)
+ if model_management.should_use_fp16():
+ self.control_model.half()
+ self.control_model.to(model_management.get_torch_device())
+ diffusion_model = model.diffusion_model
+ sd = diffusion_model.state_dict()
+ cm = self.control_model.state_dict()
+
+ for k in sd:
+ weight = sd[k]
+ if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+ key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
+ op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
+ weight = op._hf_hook.weights_map[key_split[-1]]
+
+ try:
+ set_attr(self.control_model, k, weight)
+ except:
+ pass
+
+ for k in self.control_weights:
+ if k not in {"lora_controlnet"}:
+ set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
+
+ def copy(self):
+ c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
+ self.copy_to(c)
+ return c
+
+ def cleanup(self):
+ del self.control_model
+ self.control_model = None
+ super().cleanup()
+
+ def get_models(self):
+ out = ControlBase.get_models(self)
+ return out
def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+ if "lora_controlnet" in controlnet_data:
+ return ControlLora(controlnet_data)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = model_management.should_use_fp16()
- controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config
+ controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -874,6 +1010,9 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+ leftover_keys = controlnet_data.keys()
+ if len(leftover_keys) > 0:
+ print("leftover keys:", leftover_keys)
controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight'
@@ -901,8 +1040,8 @@ def load_controlnet(ckpt_path, model=None):
if pth:
if 'difference' in controlnet_data:
if model is not None:
- m = model.patch_model()
- model_sd = m.state_dict()
+ model_management.load_models_gpu([model])
+ model_sd = model.model_state_dict()
for x in controlnet_data:
c_m = "control_model."
if x.startswith(c_m):
@@ -910,7 +1049,6 @@ def load_controlnet(ckpt_path, model=None):
if sd_key in model_sd:
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
- model.unpatch_model()
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
@@ -970,11 +1108,10 @@ class T2IAdapter(ControlBase):
output_dtype = x_noisy.dtype
out = {'input':[]}
- autocast_enabled = torch.is_autocast_enabled()
for i in range(len(self.control_input)):
key = 'input'
x = self.control_input[i] * self.strength
- if x.dtype != output_dtype and not autocast_enabled:
+ if x.dtype != output_dtype:
x = x.to(output_dtype)
if control_prev is not None and key in control_prev:
@@ -1001,7 +1138,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
-
def load_t2i_adapter(t2i_data):
keys = t2i_data.keys()
if 'adapter' in keys:
@@ -1087,7 +1223,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
- return model
+ return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
@@ -1199,8 +1335,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
+ dtype = torch.float32
+ if fp16:
+ dtype = torch.float16
+
+ inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
offload_device = model_management.unet_offload_device()
- model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
+ model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
@@ -1221,7 +1362,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0:
print("left over keys:", left_over)
- return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
+ model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+ if inital_load_device != torch.device("cpu"):
+ print("loaded straight to GPU")
+ model_management.load_model_gpu(model_patcher)
+
+ return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format
@@ -1249,14 +1395,6 @@ def load_unet(unet_path): #load unet in diffusers format
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None):
- try:
- model.patch_model()
- clip.patch_model()
- sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
- utils.save_torch_file(sd, output_path, metadata=metadata)
- model.unpatch_model()
- clip.unpatch_model()
- except Exception as e:
- model.unpatch_model()
- clip.unpatch_model()
- raise e
+ model_management.load_models_gpu([model, clip.load_model()])
+ sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+ utils.save_torch_file(sd, output_path, metadata=metadata)
diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py
index 8194a89d7..acdc0d96b 100644
--- a/comfy_extras/nodes/nodes_mask.py
+++ b/comfy_extras/nodes/nodes_mask.py
@@ -1,3 +1,5 @@
+import numpy as np
+from scipy.ndimage import grey_dilation
import torch
from comfy.nodes.common import MAX_RESOLUTION
@@ -277,6 +279,35 @@ class FeatherMask:
output[-y, :] *= feather_rate
return (output,)
+
+class GrowMask:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "mask": ("MASK",),
+ "expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
+ "tapered_corners": ("BOOLEAN", {"default": True}),
+ },
+ }
+
+ CATEGORY = "mask"
+
+ RETURN_TYPES = ("MASK",)
+
+ FUNCTION = "expand_mask"
+
+ def expand_mask(self, mask, expand, tapered_corners):
+ c = 0 if tapered_corners else 1
+ kernel = np.array([[c, 1, c],
+ [1, 1, 1],
+ [c, 1, c]])
+ output = mask.numpy().copy()
+ while expand > 0:
+ output = grey_dilation(output, footprint=kernel)
+ expand -= 1
+ output = torch.from_numpy(output)
+ return (output,)
@@ -290,6 +321,7 @@ NODE_CLASS_MAPPINGS = {
"CropMask": CropMask,
"MaskComposite": MaskComposite,
"FeatherMask": FeatherMask,
+ "GrowMask": GrowMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {
diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py
index a138b292e..51bdb24fa 100644
--- a/comfy_extras/nodes/nodes_post_processing.py
+++ b/comfy_extras/nodes/nodes_post_processing.py
@@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
+import math
import comfy.utils
@@ -209,9 +210,36 @@ class Sharpen:
return (result,)
+class ImageScaleToTotalPixels:
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
+ crop_methods = ["disabled", "center"]
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
+ "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "upscale"
+
+ CATEGORY = "image/upscaling"
+
+ def upscale(self, image, upscale_method, megapixels):
+ samples = image.movedim(-1,1)
+ total = int(megapixels * 1024 * 1024)
+
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+ width = round(samples.shape[3] * scale_by)
+ height = round(samples.shape[2] * scale_by)
+
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+ s = s.movedim(1,-1)
+ return (s,)
+
NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
"ImageQuantize": Quantize,
"ImageSharpen": Sharpen,
+ "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
}
diff --git a/web/index.html b/web/index.html
index 71067d993..41bc246c0 100644
--- a/web/index.html
+++ b/web/index.html
@@ -6,6 +6,7 @@
+