merge with upstream

Benjamin Berman 2023-08-21 12:00:59 -07:00
commit 30ed9fc143
25 changed files with 713 additions and 332 deletions

View File

@@ -31,7 +31,7 @@ jobs:
echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
-python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth

View File

@@ -6,8 +6,6 @@ import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
-    conv_nd,
-    linear,
    zero_module,
    timestep_embedding,
)
@@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
+import comfy.ops

class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
@@ -55,6 +53,8 @@ class ControlNet(nn.Module):
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
+        device=None,
+        operations=comfy.ops,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -117,9 +117,9 @@ class ControlNet(nn.Module):
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
@@ -132,9 +132,9 @@ class ControlNet(nn.Module):
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim),
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
@@ -143,28 +143,28 @@ class ControlNet(nn.Module):
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])

        self.input_hint_block = TimestepEmbedSequential(
-            conv_nd(dims, hint_channels, 16, 3, padding=1),
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            operations.conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            operations.conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
+            operations.conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
-            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+            zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
@@ -182,6 +182,7 @@ class ControlNet(nn.Module):
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
+                        operations=operations
                    )
                ]
                ch = mult * model_channels
@@ -204,11 +205,11 @@ class ControlNet(nn.Module):
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, operations=operations
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
@@ -224,16 +225,17 @@ class ControlNet(nn.Module):
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
+                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                ds *= 2
                self._feature_size += ch
@@ -253,11 +255,12 @@ class ControlNet(nn.Module):
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
            ),
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint
+                use_checkpoint=use_checkpoint, operations=operations
            ),
            ResBlock(
                ch,
@@ -266,16 +269,17 @@ class ControlNet(nn.Module):
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
            ),
        )
-        self.middle_block_out = self.make_zero_conv(ch)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations)
        self._feature_size += ch

-    def make_zero_conv(self, channels):
-        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+    def make_zero_conv(self, channels, operations=None):
+        return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)
@@ -283,9 +287,6 @@ class ControlNet(nn.Module):
        outs = []

        hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

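The pattern repeated throughout this file (and in attention.py and openaimodel.py further down) is dependency injection for layer construction: the hard-coded linear/conv_nd helpers are dropped and every layer is built through an operations namespace threaded down the call chain, defaulting to comfy.ops. Roughly, any object exposing Linear, Conv2d and a conv_nd helper with compatible signatures could be swapped in; the sketch below only illustrates that idea and is not the real comfy.ops module:

import torch.nn as nn

# Hypothetical stand-in for an "operations" namespace; mirrors how the diff
# calls it (operations.Linear, operations.Conv2d, operations.conv_nd).
class default_ops:
    Linear = nn.Linear
    Conv2d = nn.Conv2d

    @staticmethod
    def conv_nd(dims, *args, **kwargs):
        # choose a convolution class by dimensionality, like the old util helper
        if dims == 1:
            return nn.Conv1d(*args, **kwargs)
        elif dims == 2:
            return nn.Conv2d(*args, **kwargs)
        elif dims == 3:
            return nn.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")

def make_input_layer(in_channels, model_channels, operations=default_ops):
    # a module written against the namespace can be re-targeted by passing a
    # different operations object (e.g. one whose Linear skips weight init)
    return operations.conv_nd(2, in_channels, model_channels, 3, padding=1)

print(make_input_layer(3, 16))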
View File

@@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")

class LatentPreviewMethod(enum.Enum):
    NoPreviews = "none"
    Auto = "auto"
@@ -82,6 +84,9 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

+parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")

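Both new options are plain store_true flags; the code that consumes them (see the model_management.py hunks below, which read args.disable_smart_memory and args.disable_ipex_optimize) only ever checks the resulting booleans. A minimal self-contained sketch of the same pattern, standing in for the real comfy.cli_args module:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-ipex-optimize", action="store_true",
                    help="Disables ipex.optimize when loading models with Intel GPUs.")
parser.add_argument("--disable-smart-memory", action="store_true",
                    help="Aggressively offload models to regular RAM instead of keeping them in VRAM.")

args = parser.parse_args([])        # defaults: both flags are False
print(args.disable_smart_memory, args.disable_ipex_optimize)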
View File

@@ -25,6 +25,7 @@ class ClipVisionModel():
    def encode_image(self, image):
        img = torch.clip((255. * image), 0, 255).round().int()
+        img = list(map(lambda a: a, img))
        inputs = self.processor(images=img, return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs
@@ -49,18 +50,22 @@ def convert_to_transformers(sd, prefix):
    if "{}proj".format(prefix) in sd_k:
        sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)

-    sd = transformers_convert(sd, prefix, "vision_model.", 32)
+    sd = transformers_convert(sd, prefix, "vision_model.", 48)
    return sd

def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
    if convert_keys:
        sd = convert_to_transformers(sd, prefix)
-    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+    if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    clip = ClipVisionModel(json_config)
    m, u = clip.load_sd(sd)
+    if len(m) > 0:
+        print("missing clip vision:", m)
+
    u = set(u)
    keys = list(sd.keys())
    for k in keys:
@@ -71,4 +76,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):

def load(ckpt_path):
    sd = load_torch_file(ckpt_path)
-    return load_clipvision_from_sd(sd)
+    if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+        return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+    else:
+        return load_clipvision_from_sd(sd)

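The loader now distinguishes three CLIP vision variants by probing for the deepest encoder layer present in the state dict: a checkpoint with a layers.47 key has 48 hidden layers (the ViT-bigG style model described by the new clip_vision_config_g.json below), one with a layers.30 key maps to the 32-layer H config, and anything else falls back to the ViT-L config. A simplified, self-contained sketch of that selection (the key names follow the diff; the sample dict is made up):

def pick_clip_vision_config(sd):
    # probe for the deepest encoder layer present in the checkpoint
    if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
        return "clip_vision_config_g.json"      # 48-layer vision tower
    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        return "clip_vision_config_h.json"      # 32-layer vision tower
    else:
        return "clip_vision_config_vitl.json"   # 24-layer ViT-L fallback

sample_sd = {"vision_model.encoder.layers.47.layer_norm1.weight": None}  # made-up
print(pick_clip_vision_config(sample_sd))  # clip_vision_config_g.json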
View File

@@ -0,0 +1,18 @@
+{
+  "attention_dropout": 0.0,
+  "dropout": 0.0,
+  "hidden_act": "gelu",
+  "hidden_size": 1664,
+  "image_size": 224,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "layer_norm_eps": 1e-05,
+  "model_type": "clip_vision_model",
+  "num_attention_heads": 16,
+  "num_channels": 3,
+  "num_hidden_layers": 48,
+  "patch_size": 14,
+  "projection_dim": 1280,
+  "torch_dtype": "float32"
+}

View File

@@ -409,6 +409,7 @@ class PromptExecutor:
                d = self.outputs_ui.pop(x)
                del d

+        comfy.model_management.cleanup_models()
        if self.server.client_id is not None:
            self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
                                  self.server.client_id)

View File

@@ -4,6 +4,7 @@ import glob
import struct
import sys
import shutil
+from urllib.parse import quote
from PIL import Image, ImageOps
from io import BytesIO
@@ -80,6 +81,8 @@ class PromptServer():
        mimetypes.init()
        mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

+        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = None
        self.loop = loop
        self.messages = asyncio.Queue()
@@ -140,9 +143,17 @@ class PromptServer():
        @routes.get("/extensions")
        async def get_extensions(request):
-            files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
-            return web.json_response(
-                list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
+            files = glob.glob(os.path.join(
+                self.web_root, 'extensions/**/*.js'), recursive=True)
+
+            extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
+
+            for name, dir in nodes.EXTENSION_WEB_DIRS.items():
+                files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True)
+                extensions.extend(list(map(lambda f: "/extensions/" + quote(
+                    name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
+
+            return web.json_response(extensions)

        def get_dir_by_type(dir_type=None):
            type_dir = ""
@@ -638,6 +649,12 @@ class PromptServer():
    def add_routes(self):
        self.app.add_routes(self.routes)

+        for name, dir in nodes.EXTENSION_WEB_DIRS.items():
+            self.app.add_routes([
+                web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True),
+            ])
+
        self.app.add_routes([
            web.static('/', self.web_root, follow_symlinks=True),
        ])

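Custom node packages can now ship their own frontend files: nodes.EXTENSION_WEB_DIRS maps a package name to a directory on disk, add_routes mounts each directory under /extensions/<quoted name>/, and the /extensions endpoint lists the .js files from both the built-in web root and those directories. A reduced aiohttp sketch of the same idea; EXTENSION_WEB_DIRS here is a stand-in dict (pointing at a temporary directory), not the real nodes module:

import glob
import os
import tempfile
from urllib.parse import quote

from aiohttp import web

# stand-in for nodes.EXTENSION_WEB_DIRS: extension name -> directory of web assets
EXTENSION_WEB_DIRS = {"my_custom_nodes": tempfile.mkdtemp()}

app = web.Application()
routes = web.RouteTableDef()

@routes.get("/extensions")
async def get_extensions(request):
    extensions = []
    for name, dir in EXTENSION_WEB_DIRS.items():
        files = glob.glob(os.path.join(dir, "**/*.js"), recursive=True)
        extensions.extend("/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/")
                          for f in files)
    return web.json_response(extensions)

app.add_routes(routes)
for name, dir in EXTENSION_WEB_DIRS.items():
    # static mount so the listed URLs actually resolve
    app.add_routes([web.static("/extensions/" + quote(name), dir, follow_symlinks=True)])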
View File

@@ -244,30 +244,15 @@ class Gligen(nn.Module):
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30
-        self.lowvram = False
+        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
-        if self.lowvram == True:
-            self.position_net.to(boxes.device)
        objs = self.position_net(boxes, masks, positive_embeddings)
-        if self.lowvram == True:
-            self.position_net.cpu()
-            def func_lowvram(x, extra_options):
-                key = extra_options["transformer_index"]
-                module = self.module_list[key]
-                module.to(x.device)
-                r = module(x, objs)
-                module.cpu()
-                return r
-            return func_lowvram
-        else:
-            def func(x, extra_options):
-                key = extra_options["transformer_index"]
-                module = self.module_list[key]
-                return module(x, objs)
-            return func
+        def func(x, extra_options):
+            key = extra_options["transformer_index"]
+            module = self.module_list[key]
+            return module(x, objs)
+        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
@@ -312,14 +297,6 @@ class Gligen(nn.Module):
                masks.to(device),
                conds.to(device))

-    def set_lowvram(self, value=True):
-        self.lowvram = value
-
-    def cleanup(self):
-        self.lowvram = False
-
-    def get_models(self):
-        return [self]

def load_gligen(sd):
    sd_k = sd.keys()

View File

@@ -649,7 +649,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
    s_in = x.new_ones([x.shape[0]])

    denoised_1, denoised_2 = None, None
-    h_1, h_2 = None, None
+    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)

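The one-line change simply gives h an initial value alongside the other history variables, which presumably avoids referencing h before assignment on iterations where the step-size update is skipped. A toy, self-contained illustration of the pattern (not the actual sampler):

num_steps = 4

def needs_update(i):
    return i % 2 == 1          # dummy condition: the first iteration skips the update

h, h_1, h_2 = None, None, None  # the fix: give h a starting value too
for i in range(num_steps):
    if needs_update(i):
        h = 0.5 * (i + 1)       # dummy stand-in for the real log-sigma step size
    # the history shuffle runs on every iteration; on the first pass h may not
    # have been recomputed, so it needs the initial value above
    h_1, h_2 = h, h_1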
View File

@@ -10,13 +10,14 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
-import comfy.ops

if model_management.xformers_enabled():
    import xformers
    import xformers.ops

from comfy.cli_args import args
+import comfy.ops

# CrossAttn precision handling
if args.dont_upcast_attention:
    print("disabling upcasting of attention")
@@ -52,9 +53,9 @@ def init_(tensor):

# feedforward
class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None, device=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
@@ -62,19 +63,19 @@ class GEGLU(nn.Module):

class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
+            operations.Linear(dim, inner_dim, dtype=dtype, device=device),
            nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
+            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
@@ -148,7 +149,7 @@ class SpatialSelfAttention(nn.Module):

class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -156,12 +157,12 @@ class CrossAttentionBirchSan(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -245,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module):

class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -253,12 +254,12 @@ class CrossAttentionDoggettx(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -343,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module):
        return self.to_out(r2)

class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -351,12 +352,12 @@ class CrossAttention(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -399,7 +400,7 @@ class CrossAttention(nn.Module):

class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
@@ -409,11 +410,11 @@ class MemoryEfficientCrossAttention(nn.Module):
        self.heads = heads
        self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module):
        return self.to_out(out)

class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -458,11 +459,11 @@ class CrossAttentionPytorch(nn.Module):
        self.heads = heads
        self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, value=None, mask=None):
@@ -508,14 +509,14 @@ else:

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None):
+                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
@@ -648,7 +649,7 @@ class SpatialTransformer(nn.Module):
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None, device=None):
+                 use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
@@ -656,26 +657,26 @@ class SpatialTransformer(nn.Module):
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        if not use_linear:
-            self.proj_in = nn.Conv2d(in_channels,
+            self.proj_in = operations.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0, dtype=dtype, device=device)
        else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

+            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
                for d in range(depth)]
        )
        if not use_linear:
-            self.proj_out = nn.Conv2d(inner_dim,in_channels,
+            self.proj_out = operations.Conv2d(inner_dim,in_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0, dtype=dtype, device=device)
        else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):

View File

@@ -8,8 +8,6 @@ import torch.nn.functional as F
from .util import (
    checkpoint,
-    conv_nd,
-    linear,
    avg_pool_nd,
    zero_module,
    normalization,
@@ -17,7 +15,7 @@ from .util import (
)
from ..attention import SpatialTransformer
from comfy.ldm.util import exists
+import comfy.ops

class TimestepBlock(nn.Module):
    """
@@ -72,14 +70,14 @@ class Upsample(nn.Module):
                 upsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
+            self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

    def forward(self, x, output_shape=None):
        assert x.shape[1] == self.channels
@@ -108,7 +106,7 @@ class Downsample(nn.Module):
                 downsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -116,7 +114,7 @@ class Downsample(nn.Module):
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
-            self.op = conv_nd(
+            self.op = operations.conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
            )
        else:
@@ -158,6 +156,7 @@ class ResBlock(TimestepBlock):
        down=False,
        dtype=None,
        device=None,
+        operations=comfy.ops
    ):
        super().__init__()
        self.channels = channels
@@ -171,7 +170,7 @@ class ResBlock(TimestepBlock):
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
        )

        self.updown = up or down
@@ -187,7 +186,7 @@ class ResBlock(TimestepBlock):
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
-            linear(
+            operations.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
            ),
@@ -197,18 +196,18 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
-            self.skip_connection = conv_nd(
+            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
            )
        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
+            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

    def forward(self, x, emb):
        """
@@ -317,6 +316,7 @@ class UNetModel(nn.Module):
        adm_in_channels=None,
        transformer_depth_middle=None,
        device=None,
+        operations=comfy.ops,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -379,9 +379,9 @@ class UNetModel(nn.Module):
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
@@ -394,9 +394,9 @@ class UNetModel(nn.Module):
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
@@ -405,7 +405,7 @@ class UNetModel(nn.Module):
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
@@ -426,6 +426,7 @@ class UNetModel(nn.Module):
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
+                        operations=operations,
                    )
                ]
                ch = mult * model_channels
@@ -447,7 +448,7 @@ class UNetModel(nn.Module):
                    layers.append(SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -468,10 +469,11 @@ class UNetModel(nn.Module):
                            down=True,
                            dtype=self.dtype,
                            device=device,
+                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
@@ -498,11 +500,12 @@ class UNetModel(nn.Module):
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
+                operations=operations
            ),
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
            ),
            ResBlock(
                ch,
@@ -513,6 +516,7 @@ class UNetModel(nn.Module):
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
+                operations=operations
            ),
        )
        self._feature_size += ch
@@ -532,6 +536,7 @@ class UNetModel(nn.Module):
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
+                        operations=operations
                    )
                ]
                ch = model_channels * mult
@@ -554,7 +559,7 @@ class UNetModel(nn.Module):
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                if level and i == self.num_res_blocks[level]:
@@ -571,9 +576,10 @@ class UNetModel(nn.Module):
                            up=True,
                            dtype=self.dtype,
                            device=device,
+                            operations=operations
                        )
                        if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -582,12 +588,12 @@ class UNetModel(nn.Module):
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
            nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
+            zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
-                conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
+                operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

View File

@@ -148,13 +148,20 @@ class SDInpaint(BaseModel):
        super().__init__(model_config, model_type, device=device)
        self.concat_keys = ("mask", "masked_image")

+def sdxl_pooled(args, noise_augmentor):
+    if "unclip_conditioning" in args:
+        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
+    else:
+        return args["pooled_output"]
+
class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
-        clip_pooled = kwargs["pooled_output"]
+        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
@@ -178,9 +185,10 @@ class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
-        clip_pooled = kwargs["pooled_output"]
+        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)

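The new sdxl_pooled helper lets both SDXL models accept unCLIP-style image conditioning: if "unclip_conditioning" is present in the conditioning kwargs, the pooled vector comes from the newly attached CLIPEmbeddingNoiseAugmentation path and is truncated to the first 1280 channels; otherwise the text encoder's pooled_output is used unchanged. A toy illustration of that dispatch; the unclip_adm stand-in below is a dummy, not the real helper:

import torch

def unclip_adm(unclip_conditioning, device, noise_augmentor):
    # dummy: the real helper noise-augments CLIP image embeddings
    return torch.zeros(1, 2048, device=device)

def sdxl_pooled(args, noise_augmentor=None):
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:, :1280]
    else:
        return args["pooled_output"]

print(sdxl_pooled({"pooled_output": torch.ones(1, 1280)}).shape)        # text path
print(sdxl_pooled({"unclip_conditioning": [], "device": "cpu"}).shape)  # unCLIP path, first 1280 channels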
View File

@@ -121,9 +121,20 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
    return model_config_from_unet_config(unet_config)

-def model_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, use_fp16):
    match = {}
-    match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    attention_resolutions = []
+
+    attn_res = 1
+    for i in range(5):
+        k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
+        if k in state_dict:
+            match["context_dim"] = state_dict[k].shape[1]
+            attention_resolutions.append(attn_res)
+        attn_res *= 2
+
+    match["attention_resolutions"] = attention_resolutions
    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
    match["adm_in_channels"] = None
@@ -135,22 +146,22 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
            'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
-            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+            'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}

    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
                    'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+                    'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}

    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+                    'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
@@ -160,9 +171,20 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
            'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
-            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+            'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+    SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                     'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
+                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
+
+    SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                       'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
+                       'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet]

    for unet_config in supported_models:
        matches = True
@@ -171,5 +193,11 @@ def model_config_from_diffusers_unet(state_dict, use_fp16):
                matches = False
                break
        if matches:
-            return model_config_from_unet_config(unet_config)
+            return unet_config
+    return None
+
+def model_config_from_diffusers_unet(state_dict, use_fp16):
+    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+    if unet_config is not None:
+        return model_config_from_unet_config(unet_config)
    return None

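The rewritten detector derives attention_resolutions by walking the down blocks of a diffusers-format UNet: each block index that contains cross-attention weights contributes the current resolution factor, and the factor doubles per block whether or not attention was found. For an SDXL-style UNet with attention in down blocks 1 and 2 this yields [2, 4], matching the SDXL config above. A standalone sketch of that derivation with a made-up state dict:

def attention_resolutions_from_sd(state_dict):
    resolutions = []
    attn_res = 1
    for i in range(5):
        k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
        if k in state_dict:
            resolutions.append(attn_res)
        attn_res *= 2
    return resolutions

fake_sd = {  # SDXL-like: attention only in down blocks 1 and 2
    "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight": None,
    "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight": None,
}
print(attention_resolutions_from_sd(fake_sd))  # [2, 4]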
View File

@@ -2,6 +2,7 @@ import psutil
from enum import Enum
from comfy.cli_args import args
import torch
+import sys

class VRAMState(Enum):
    DISABLED = 0    #No vram present: no need to move models to vram
@@ -87,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
        mem_total = 1024 * 1024 * 1024 #TODO
        mem_total_torch = mem_total
    elif xpu_available:
+        stats = torch.xpu.memory_stats(dev)
+        mem_reserved = stats['reserved_bytes.all.current']
        mem_total = torch.xpu.get_device_properties(dev).total_memory
-        mem_total_torch = mem_total
+        mem_total_torch = mem_reserved
    else:
        stats = torch.cuda.memory_stats(dev)
        mem_reserved = stats['reserved_bytes.all.current']
@@ -201,8 +204,13 @@ if cpu_state == CPUState.MPS:

print(f"Set vram state to: {vram_state.name}")

+DISABLE_SMART_MEMORY = args.disable_smart_memory
+
+if DISABLE_SMART_MEMORY:
+    print("Disabling smart memory management")
+
def get_torch_device_name(device):
+    global xpu_available
    if hasattr(device, 'type'):
        if device.type == "cuda":
            try:
@@ -212,6 +220,8 @@ def get_torch_device_name(device):
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        else:
            return "{}".format(device.type)
+    elif xpu_available:
+        return "{} {}".format(device, torch.xpu.get_device_name(device))
    else:
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -221,132 +231,168 @@ except:
    print("Could not pick default device.")

-current_loaded_model = None
-current_gpu_controlnets = []
-model_accelerated = False
-
-def unload_model():
-    global current_loaded_model
-    global model_accelerated
-    global current_gpu_controlnets
-    global vram_state
-
-    if current_loaded_model is not None:
-        if model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
-            model_accelerated = False
-
-        current_loaded_model.unpatch_model()
-        current_loaded_model.model.to(current_loaded_model.offload_device)
-        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
-        current_loaded_model = None
-        if vram_state != VRAMState.HIGH_VRAM:
-            soft_empty_cache()
-
-    if vram_state != VRAMState.HIGH_VRAM:
-        if len(current_gpu_controlnets) > 0:
-            for n in current_gpu_controlnets:
-                n.cpu()
-            current_gpu_controlnets = []
+current_loaded_models = []
+
+class LoadedModel:
+    def __init__(self, model):
+        self.model = model
+        self.model_accelerated = False
+        self.device = model.load_device
+
+    def model_memory(self):
+        return self.model.model_size()
+
+    def model_memory_required(self, device):
+        if device == self.model.current_device:
+            return 0
+        else:
+            return self.model_memory()
+
+    def model_load(self, lowvram_model_memory=0):
+        global xpu_available
+        patch_model_to = None
+        if lowvram_model_memory == 0:
+            patch_model_to = self.device
+
+        self.model.model_patches_to(self.device)
+        self.model.model_patches_to(self.model.model_dtype())
+
+        try:
+            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+        except Exception as e:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model_unload()
+            raise e
+
+        if lowvram_model_memory > 0:
+            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
+            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            self.model_accelerated = True
+
+        if xpu_available and not args.disable_ipex_optimize:
+            self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
+
+        return self.real_model
+
+    def model_unload(self):
+        if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
self.model.model_patches_to(self.model.offload_device)
def __eq__(self, other):
return self.model is other.model
def minimum_inference_memory(): def minimum_inference_memory():
return (768 * 1024 * 1024) return (1024 * 1024 * 1024)
def unload_model_clones(model):
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
print("unload clone", i)
current_loaded_models.pop(i).model_unload()
def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = False
for i in range(len(current_loaded_models) -1, -1, -1):
if DISABLE_SMART_MEMORY:
current_free_mem = 0
else:
current_free_mem = get_free_memory(device)
if current_free_mem > memory_required:
break
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
current_loaded_models.pop(i).model_unload()
unloaded_model = True
if unloaded_model:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
if loaded_model in current_loaded_models:
index = current_loaded_models.index(loaded_model)
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
return
print("loading new")
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model)
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
model = loaded_model.model
torch_dev = model.load_device
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
return
def load_model_gpu(model): def load_model_gpu(model):
global current_loaded_model return load_models_gpu([model])
global vram_state
global model_accelerated
if model is current_loaded_model: def cleanup_models():
return to_delete = []
unload_model() for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete
torch_dev = model.load_device for i in to_delete:
model.model_patches_to(torch_dev) x = current_loaded_models.pop(i)
model.model_patches_to(model.model_dtype()) x.model_unload()
current_loaded_model = model del x
if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
else:
vram_set_state = vram_state
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = model.model_size()
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
real_model = model.model
patch_model_to = None
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
patch_model_to = torch_dev
try:
real_model = model.patch_model(device_to=patch_model_to)
except Exception as e:
model.unpatch_model()
unload_model()
raise e
if patch_model_to is not None:
real_model.to(torch_dev)
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.DISABLED:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
for m in control_models:
if hasattr(m, 'set_lowvram'):
m.set_lowvram(True)
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
models = []
for m in control_models:
models += m.get_models()
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
device = get_torch_device()
current_gpu_controlnets = []
for m in models:
current_gpu_controlnets.append(m.to(device))
def load_if_low_vram(model):
global vram_state
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.to(get_torch_device())
return model
def unload_if_low_vram(model):
global vram_state
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu()
return model
def unet_offload_device(): def unet_offload_device():
if vram_state == VRAMState.HIGH_VRAM: if vram_state == VRAMState.HIGH_VRAM:
@ -354,6 +400,28 @@ def unet_offload_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM:
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY:
return cpu_dev
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
model_size = dtype_size * parameters
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -441,8 +509,12 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = 1024 * 1024 * 1024 #TODO mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
elif xpu_available: elif xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) stats = torch.xpu.memory_stats(dev)
mem_free_torch = mem_free_total mem_active = stats['active_bytes.all.current']
mem_allocated = stats['allocated_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
@ -456,6 +528,13 @@ def get_free_memory(dev=None, torch_free_too=False):
else: else:
return mem_free_total return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area(): def maximum_batch_area():
global vram_state global vram_state
if vram_state == VRAMState.NO_VRAM: if vram_state == VRAMState.NO_VRAM:
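The new batch_area_memory() converts latent area into an estimated VRAM reservation using the constants noted as copied from maximum_batch_area. A worked example for a single 512x512 image (64x64 latent), showing both branches:

# area = batch * latent_h * latent_w = 1 * 64 * 64
area = 1 * 64 * 64
print((area / 20) * (1024 * 1024))                    # xformers / flash attention: ~205 MiB
print((((area * 0.6) / 0.9) + 1024) * (1024 * 1024))  # fallback estimate: ~3.7 GiB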
@ -507,9 +586,12 @@ def should_use_fp16(device=None, model_params=0):
if directml_enabled: if directml_enabled:
return False return False
if cpu_mode() or mps_mode() or xpu_available: if cpu_mode() or mps_mode():
return False #TODO ? return False #TODO ?
if xpu_available:
return True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
return True return True

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import importlib import importlib
import os
import pkgutil import pkgutil
import time import time
import types import types
@ -14,6 +15,7 @@ except:
custom_nodes = None custom_nodes = None
from comfy.nodes.package_typing import ExportedNodes from comfy.nodes.package_typing import ExportedNodes
from functools import reduce from functools import reduce
from pkg_resources import resource_filename
_comfy_nodes = ExportedNodes() _comfy_nodes = ExportedNodes()
@ -21,10 +23,19 @@ _comfy_nodes = ExportedNodes()
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
web_directory = getattr(module, "WEB_DIRECTORY", None)
if node_class_mappings: if node_class_mappings:
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings) exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
if node_display_names: if node_display_names:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names) exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
if web_directory:
# load the extension resources path
abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory))
if not os.path.isdir(abs_web_directory):
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
if not os.path.isdir(abs_web_directory):
raise ImportError(path=abs_web_directory)
exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes: def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes:
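The importer above now also records a module-level WEB_DIRECTORY in EXTENSION_WEB_DIRS, resolving it first via resource_filename and then relative to the module file. A hypothetical custom node package exporting the expected attributes (all names here are illustrative, not part of this commit):

# my_custom_nodes/__init__.py
WEB_DIRECTORY = "./js"   # frontend assets; resolved to an absolute path at import time

class ExampleTextNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {"default": ""})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "example"

    def run(self, text):
        return (text,)

NODE_CLASS_MAPPINGS = {"ExampleTextNode": ExampleTextNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleTextNode": "Example Text Node"}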


@ -22,22 +22,27 @@ class CustomNode(Protocol):
class ExportedNodes: class ExportedNodes:
NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict) NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict) NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes: def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS) self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS) self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS)
return self return self
def __len__(self): def __len__(self):
return len(self.NODE_CLASS_MAPPINGS) return len(self.NODE_CLASS_MAPPINGS)
def __sub__(self, other): def __sub__(self, other: ExportedNodes):
exported_nodes = ExportedNodes().update(self) exported_nodes = ExportedNodes().update(self)
for self_key in exported_nodes.NODE_CLASS_MAPPINGS: for self_key in exported_nodes.NODE_CLASS_MAPPINGS:
if self_key in other.NODE_CLASS_MAPPINGS: if self_key in other.NODE_CLASS_MAPPINGS:
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key) exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS: if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key) exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
for self_key in exported_nodes.EXTENSION_WEB_DIRS:
if self_key in other.EXTENSION_WEB_DIRS:
exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
return exported_nodes return exported_nodes
def __add__(self, other): def __add__(self, other):


@ -21,6 +21,11 @@ class Conv2d(torch.nn.Conv2d):
def reset_parameters(self): def reset_parameters(self):
return None return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager @contextmanager
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
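comfy.ops gains a conv_nd helper that mirrors the ldm utility but builds the init-skipping Conv2d above. A small sketch of constructing one and loading weights into it (the zero tensors are placeholders, not real checkpoint data):

import torch
import comfy.ops

# 2D is the only supported case; reset_parameters() is a no-op here, so the
# layer's weights stay uninitialized until a state dict is loaded over them.
conv = comfy.ops.conv_nd(2, 4, 320, 3, padding=1)
conv.load_state_dict({"weight": torch.zeros(320, 4, 3, 3), "bias": torch.zeros(320)})
print(conv(torch.zeros(1, 4, 64, 64)).shape)   # torch.Size([1, 320, 64, 64])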


@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type):
models += [c[1][model_type]] models += [c[1][model_type]]
return models return models
def load_additional_models(positive, negative, dtype): def get_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning""" """loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
control_models = []
for m in control_nets:
control_models += m.get_models()
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1].to(dtype) for x in gligen] gligen = [x[1] for x in gligen]
models = control_nets + gligen models = control_models + gligen
comfy.model_management.load_controlnet_gpu(models)
return models return models
def cleanup_additional_models(models): def cleanup_additional_models(models):
"""cleanup additional models that were loaded""" """cleanup additional models that were loaded"""
for m in models: for m in models:
m.cleanup() if hasattr(m, 'cleanup'):
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
noise_mask = prepare_mask(noise_mask, noise.shape, device) noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None real_model = None
comfy.model_management.load_model_gpu(model) models = get_additional_models(positive, negative)
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
real_model = model.model real_model = model.model
noise = noise.to(device) noise = noise.to(device)
@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
positive_copy = broadcast_cond(positive, noise.shape[0], device) positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device) negative_copy = broadcast_cond(negative, noise.shape[0], device)
models = load_additional_models(positive, negative, model.model_dtype())
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
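Taken together, the sampling path now asks get_additional_models() for the controlnet/gligen patchers and loads them alongside the main model in one load_models_gpu() call sized by batch_area_memory(). A condensed sketch of that flow, not the exact body of sample(); it assumes these helpers live in comfy.sample and that model is a ModelPatcher:

import comfy.model_management as model_management
from comfy.sample import get_additional_models, cleanup_additional_models

def sample_with_memory_management(model, noise, positive, negative):
    models = get_additional_models(positive, negative)   # controlnet + gligen ModelPatchers
    memory = model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])
    model_management.load_models_gpu([model] + models, memory)
    try:
        ...  # run the KSampler against model.model, as sample() does above
    finally:
        cleanup_additional_models(models)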


@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
gligen_type = gligen[0] gligen_type = gligen[0]
gligen_model = gligen[1] gligen_model = gligen[1]
if gligen_type == "position": if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else: else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch] patches['middle_patch'] = [gligen_patch]
@ -478,7 +478,7 @@ def pre_run_control(model, conds):
timestep_end = None timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]: if 'control' in x[1]:
x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function) x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []


@ -243,8 +243,15 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], torch.nn.Parameter(value)) setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev del prev
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
obj = getattr(obj, name)
return obj
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0): def __init__(self, model, load_device, offload_device, size=0, current_device=None):
self.size = size self.size = size
self.model = model self.model = model
self.patches = {} self.patches = {}
@ -253,6 +260,10 @@ class ModelPatcher:
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
@ -267,7 +278,7 @@ class ModelPatcher:
return size return size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -276,6 +287,11 @@ class ModelPatcher:
n.model_keys = self.model_keys n.model_keys = self.model_keys
return n return n
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@ -390,6 +406,11 @@ class ModelPatcher:
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
set_attr(self.model, key, out_weight) set_attr(self.model, key, out_weight)
del temp_weight del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@ -482,7 +503,7 @@ class ModelPatcher:
return weight return weight
def unpatch_model(self): def unpatch_model(self, device_to=None):
keys = list(self.backup.keys()) keys = list(self.backup.keys())
for k in keys: for k in keys:
@ -490,6 +511,11 @@ class ModelPatcher:
self.backup = {} self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = model_lora_keys_unet(model.model) key_map = model_lora_keys_unet(model.model)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
@ -555,7 +581,7 @@ class CLIP:
else: else:
self.cond_stage_model.reset_clip_layer() self.cond_stage_model.reset_clip_layer()
model_management.load_model_gpu(self.patcher) self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
if return_pooled: if return_pooled:
return cond, pooled return cond, pooled
@ -571,11 +597,9 @@ class CLIP:
def get_sd(self): def get_sd(self):
return self.cond_stage_model.state_dict() return self.cond_stage_model.state_dict()
def patch_model(self): def load_model(self):
self.patcher.patch_model() model_management.load_model_gpu(self.patcher)
return self.patcher
def unpatch_model(self):
self.patcher.unpatch_model()
def get_key_patches(self): def get_key_patches(self):
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
@ -630,11 +654,12 @@ class VAE:
return samples return samples
def decode(self, samples_in): def decode(self, samples_in):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
try: try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64)) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
@ -650,19 +675,19 @@ class VAE:
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device) self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1) return output.movedim(1,-1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
@ -677,7 +702,6 @@ class VAE:
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
@ -757,6 +781,7 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
@ -780,19 +805,14 @@ class ControlNet(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_model.dtype == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
with precision_scope(model_management.get_autocast_device(self.device)): context = torch.cat(cond['c_crossattn'], 1)
self.control_model = model_management.load_if_low_vram(self.control_model) y = cond.get('c_adm', None)
context = torch.cat(cond['c_crossattn'], 1) if y is not None:
y = cond.get('c_adm', None) y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = {'middle':[], 'output': []} out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled()
for i in range(len(control)): for i in range(len(control)):
if i == (len(control) - 1): if i == (len(control) - 1):
@ -806,7 +826,7 @@ class ControlNet(ControlBase):
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype:
x = x.to(output_dtype) x = x.to(output_dtype)
if control_prev is not None and key in control_prev: if control_prev is not None and key in control_prev:
@ -825,17 +845,133 @@ class ControlNet(ControlBase):
def get_models(self): def get_models(self):
out = super().get_models() out = super().get_models()
out.append(self.control_model) out.append(self.control_model_wrapped)
return out return out
class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.up = None
self.down = None
self.bias = None
def forward(self, input):
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
class Conv2d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = False
self.output_padding = 0
self.groups = groups
self.padding_mode = padding_mode
self.weight = None
self.bias = None
self.up = None
self.down = None
def forward(self, input):
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
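These ControlLoraOps layers keep the frozen base weight plus a low-rank up/down pair and add their product at forward time. A tiny numerical check of that reconstruction, with illustrative shapes:

import torch

out_f, in_f, rank = 8, 16, 4
weight = torch.randn(out_f, in_f)   # frozen weight copied from the diffusion model
up = torch.randn(out_f, rank)       # ControlLora "up" factor
down = torch.randn(rank, in_f)      # ControlLora "down" factor
x = torch.randn(2, in_f)

# Same expression as ControlLoraOps.Linear.forward above.
effective = weight + torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
print(torch.nn.functional.linear(x, effective).shape)   # torch.Size([2, 8])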
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = cldm.ControlNet(**controlnet_config)
if model_management.should_use_fp16():
self.control_model.half()
self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
try:
set_attr(self.control_model, k, weight)
except:
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
self.copy_to(c)
return c
def cleanup(self):
del self.control_model
self.control_model = None
super().cleanup()
def get_models(self):
out = ControlBase.get_models(self)
return out
def load_controlnet(ckpt_path, model=None): def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
controlnet_config = None controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = model_management.should_use_fp16() use_fp16 = model_management.should_use_fp16()
controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
diffusers_keys = utils.unet_to_diffusers(controlnet_config) diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -874,6 +1010,9 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data: if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k) new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys)
controlnet_data = new_sd controlnet_data = new_sd
pth_key = 'control_model.zero_convs.0.0.weight' pth_key = 'control_model.zero_convs.0.0.weight'
@ -901,8 +1040,8 @@ def load_controlnet(ckpt_path, model=None):
if pth: if pth:
if 'difference' in controlnet_data: if 'difference' in controlnet_data:
if model is not None: if model is not None:
m = model.patch_model() model_management.load_models_gpu([model])
model_sd = m.state_dict() model_sd = model.model_state_dict()
for x in controlnet_data: for x in controlnet_data:
c_m = "control_model." c_m = "control_model."
if x.startswith(c_m): if x.startswith(c_m):
@ -910,7 +1049,6 @@ def load_controlnet(ckpt_path, model=None):
if sd_key in model_sd: if sd_key in model_sd:
cd = controlnet_data[x] cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device) cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
model.unpatch_model()
else: else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
@ -970,11 +1108,10 @@ class T2IAdapter(ControlBase):
output_dtype = x_noisy.dtype output_dtype = x_noisy.dtype
out = {'input':[]} out = {'input':[]}
autocast_enabled = torch.is_autocast_enabled()
for i in range(len(self.control_input)): for i in range(len(self.control_input)):
key = 'input' key = 'input'
x = self.control_input[i] * self.strength x = self.control_input[i] * self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype:
x = x.to(output_dtype) x = x.to(output_dtype)
if control_prev is not None and key in control_prev: if control_prev is not None and key in control_prev:
@ -1001,7 +1138,6 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def load_t2i_adapter(t2i_data): def load_t2i_adapter(t2i_data):
keys = t2i_data.keys() keys = t2i_data.keys()
if 'adapter' in keys: if 'adapter' in keys:
@ -1087,7 +1223,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data) model = gligen.load_gligen(data)
if model_management.should_use_fp16(): if model_management.should_use_fp16():
model = model.half() model = model.half()
return model return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually #TODO: this function is a mess and should be removed eventually
@ -1199,8 +1335,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision: if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device) model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
@ -1221,7 +1362,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if len(left_over) > 0: if len(left_over) > 0:
print("left over keys:", left_over) print("left over keys:", left_over)
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format def load_unet(unet_path): #load unet in diffusers format
@ -1249,14 +1395,6 @@ def load_unet(unet_path): #load unet in diffusers format
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip, vae, metadata=None):
try: model_management.load_models_gpu([model, clip.load_model()])
model.patch_model() sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
clip.patch_model() utils.save_torch_file(sd, output_path, metadata=metadata)
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e


@ -1,3 +1,5 @@
import numpy as np
from scipy.ndimage import grey_dilation
import torch import torch
from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes.common import MAX_RESOLUTION
@ -277,6 +279,35 @@ class FeatherMask:
output[-y, :] *= feather_rate output[-y, :] *= feather_rate
return (output,) return (output,)
class GrowMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"mask": ("MASK",),
"expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"tapered_corners": ("BOOLEAN", {"default": True}),
},
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)
FUNCTION = "expand_mask"
def expand_mask(self, mask, expand, tapered_corners):
c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c],
[1, 1, 1],
[c, 1, c]])
output = mask.numpy().copy()
while expand > 0:
output = grey_dilation(output, footprint=kernel)
expand -= 1
output = torch.from_numpy(output)
return (output,)
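GrowMask dilates the mask one pixel per iteration with a 3x3 footprint; tapered_corners zeroes the kernel corners so diagonals grow more slowly. A standalone check of one expansion step on a single-pixel mask:

import numpy as np
import torch
from scipy.ndimage import grey_dilation

mask = torch.zeros((8, 8))
mask[3, 3] = 1.0

c = 0  # tapered_corners=True
kernel = np.array([[c, 1, c],
                   [1, 1, 1],
                   [c, 1, c]])

out = grey_dilation(mask.numpy(), footprint=kernel)  # one "expand" step
print(torch.from_numpy(out).sum())                   # tensor(5.): a plus-shaped neighbourhood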
@ -290,6 +321,7 @@ NODE_CLASS_MAPPINGS = {
"CropMask": CropMask, "CropMask": CropMask,
"MaskComposite": MaskComposite, "MaskComposite": MaskComposite,
"FeatherMask": FeatherMask, "FeatherMask": FeatherMask,
"GrowMask": GrowMask,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {


@ -2,6 +2,7 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
import math
import comfy.utils import comfy.utils
@ -209,9 +210,36 @@ class Sharpen:
return (result,) return (result,)
class ImageScaleToTotalPixels:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
crop_methods = ["disabled", "center"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
"megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, megapixels):
samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1,-1)
return (s,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend, "ImageBlend": Blend,
"ImageBlur": Blur, "ImageBlur": Blur,
"ImageQuantize": Quantize, "ImageQuantize": Quantize,
"ImageSharpen": Sharpen, "ImageSharpen": Sharpen,
"ImageScaleToTotalPixels": ImageScaleToTotalPixels,
} }
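The ImageScaleToTotalPixels node above picks one uniform scale factor so width*height lands on the requested megapixel budget (1 megapixel = 1024*1024 pixels here). A worked example for a 1280x720 input scaled to 1.0 megapixels:

import math

width_in, height_in = 1280, 720
total = int(1.0 * 1024 * 1024)                        # 1048576 target pixels
scale_by = math.sqrt(total / (width_in * height_in))  # ~1.0667
print(round(width_in * scale_by), round(height_in * scale_by))  # 1365 768 -> 1365*768 ≈ 1.048e6 px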


@ -6,6 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no"> <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<link rel="stylesheet" type="text/css" href="./lib/litegraph.css" /> <link rel="stylesheet" type="text/css" href="./lib/litegraph.css" />
<link rel="stylesheet" type="text/css" href="./style.css" /> <link rel="stylesheet" type="text/css" href="./style.css" />
<link rel="stylesheet" type="text/css" href="./user.css" />
<script type="text/javascript" src="./lib/litegraph.core.js"></script> <script type="text/javascript" src="./lib/litegraph.core.js"></script>
<script type="text/javascript" src="./lib/litegraph.extensions.js" defer></script> <script type="text/javascript" src="./lib/litegraph.extensions.js" defer></script>
<script type="module"> <script type="module">


@ -1026,18 +1026,21 @@ export class ComfyApp {
} }
/** /**
* Loads all extensions from the API into the window * Loads all extensions from the API into the window in parallel
*/ */
async #loadExtensions() { async #loadExtensions() {
const extensions = await api.getExtensions(); const extensions = await api.getExtensions();
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions }); this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
for (const ext of extensions) {
try { const extensionPromises = extensions.map(async ext => {
await import(api.apiURL(ext)); try {
} catch (error) { await import(api.apiURL(ext));
console.error("Error loading extension", ext, error); } catch (error) {
} console.error("Error loading extension", ext, error);
} }
});
await Promise.all(extensionPromises);
} }
/** /**

web/user.css Normal file

@ -0,0 +1 @@
/* Put custom styles here */