diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index c7ef93ce1..319942e7c 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -31,7 +31,7 @@ jobs: echo 'import site' >> ./python311._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python311._pth diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 46fbf0a69..251483131 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -6,8 +6,6 @@ import torch as th import torch.nn as nn from ..ldm.modules.diffusionmodules.util import ( - conv_nd, - linear, zero_module, timestep_embedding, ) @@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import ( from ..ldm.modules.attention import SpatialTransformer from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.util import exists - +import comfy.ops class ControlledUnetModel(UNetModel): #implemented in the ldm unet @@ -55,6 +53,8 @@ class ControlNet(nn.Module): use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + device=None, + operations=comfy.ops, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" @@ -117,9 +117,9 @@ class ControlNet(nn.Module): time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), + operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim), + operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) if self.num_classes is not None: @@ -132,9 +132,9 @@ class ControlNet(nn.Module): assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( - linear(adm_in_channels, time_embed_dim), + operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim), + operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) ) else: @@ -143,28 +143,28 @@ class ControlNet(nn.Module): self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) + operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) ] ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) self.input_hint_block = TimestepEmbedSequential( - conv_nd(dims, hint_channels, 16, 3, padding=1), + operations.conv_nd(dims, hint_channels, 16, 3, padding=1), nn.SiLU(), - conv_nd(dims, 16, 16, 3, padding=1), + operations.conv_nd(dims, 16, 16, 3, padding=1), nn.SiLU(), - conv_nd(dims, 16, 32, 3, padding=1, stride=2), + operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 32, 32, 3, padding=1), + operations.conv_nd(dims, 32, 32, 3, padding=1), nn.SiLU(), - conv_nd(dims, 32, 96, 3, padding=1, stride=2), + operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), nn.SiLU(), - conv_nd(dims, 96, 96, 3, padding=1), + operations.conv_nd(dims, 96, 96, 3, padding=1), nn.SiLU(), - conv_nd(dims, 96, 256, 3, padding=1, stride=2), + operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), nn.SiLU(), - zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) ) self._feature_size = model_channels @@ -182,6 +182,7 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + operations=operations ) ] ch = mult * model_channels @@ -204,11 +205,11 @@ class ControlNet(nn.Module): SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + use_checkpoint=use_checkpoint, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: @@ -224,16 +225,17 @@ class ControlNet(nn.Module): use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, + operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch + ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations ) ) ) ch = out_ch input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) ds *= 2 self._feature_size += ch @@ -253,11 +255,12 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + operations=operations ), SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + use_checkpoint=use_checkpoint, operations=operations ), ResBlock( ch, @@ -266,16 +269,17 @@ class ControlNet(nn.Module): dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + operations=operations ), ) - self.middle_block_out = self.make_zero_conv(ch) + self.middle_block_out = self.make_zero_conv(ch, operations=operations) self._feature_size += ch - def make_zero_conv(self, channels): - return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + def make_zero_conv(self, channels, operations=None): + return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -283,9 +287,6 @@ class ControlNet(nn.Module): outs = [] hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index ec7d34a55..b4f22f319 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") +parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") + class LatentPreviewMethod(enum.Enum): NoPreviews = "none" Auto = "auto" @@ -82,6 +84,9 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn' vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") +parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") + + parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 8d04faf71..a887e51b5 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -25,6 +25,7 @@ class ClipVisionModel(): def encode_image(self, image): img = torch.clip((255. * image), 0, 255).round().int() + img = list(map(lambda a: a, img)) inputs = self.processor(images=img, return_tensors="pt") outputs = self.model(**inputs) return outputs @@ -49,18 +50,22 @@ def convert_to_transformers(sd, prefix): if "{}proj".format(prefix) in sd_k: sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) - sd = transformers_convert(sd, prefix, "vision_model.", 32) + sd = transformers_convert(sd, prefix, "vision_model.", 48) return sd def load_clipvision_from_sd(sd, prefix="", convert_keys=False): if convert_keys: sd = convert_to_transformers(sd, prefix) - if "vision_model.encoder.layers.30.layer_norm1.weight" in sd: + if "vision_model.encoder.layers.47.layer_norm1.weight" in sd: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") + elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) + if len(m) > 0: + print("missing clip vision:", m) u = set(u) keys = list(sd.keys()) for k in keys: @@ -71,4 +76,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): def load(ckpt_path): sd = load_torch_file(ckpt_path) - return load_clipvision_from_sd(sd) + if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd: + return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True) + else: + return load_clipvision_from_sd(sd) diff --git a/comfy/clip_vision_config_g.json b/comfy/clip_vision_config_g.json new file mode 100644 index 000000000..708e7e21a --- /dev/null +++ b/comfy/clip_vision_config_g.json @@ -0,0 +1,18 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "gelu", + "hidden_size": 1664, + "image_size": 224, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 8192, + "layer_norm_eps": 1e-05, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 48, + "patch_size": 14, + "projection_dim": 1280, + "torch_dtype": "float32" +} diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 88a30d98d..7c9d7d07e 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -409,6 +409,7 @@ class PromptExecutor: d = self.outputs_ui.pop(x) del d + comfy.model_management.cleanup_models() if self.server.client_id is not None: self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, self.server.client_id) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index da5edab30..761b6bee7 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -4,6 +4,7 @@ import glob import struct import sys import shutil +from urllib.parse import quote from PIL import Image, ImageOps from io import BytesIO @@ -80,6 +81,8 @@ class PromptServer(): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + + self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop self.messages = asyncio.Queue() @@ -140,9 +143,17 @@ class PromptServer(): @routes.get("/extensions") async def get_extensions(request): - files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) - return web.json_response( - list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + files = glob.glob(os.path.join( + self.web_root, 'extensions/**/*.js'), recursive=True) + + extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) + + for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True) + extensions.extend(list(map(lambda f: "/extensions/" + quote( + name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) + + return web.json_response(extensions) def get_dir_by_type(dir_type=None): type_dir = "" @@ -638,6 +649,12 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) + + for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + self.app.add_routes([ + web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True), + ]) + self.app.add_routes([ web.static('/', self.web_root, follow_symlinks=True), ]) diff --git a/comfy/gligen.py b/comfy/gligen.py index 90558785b..8d182839e 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -244,30 +244,15 @@ class Gligen(nn.Module): self.position_net = position_net self.key_dim = key_dim self.max_objs = 30 - self.lowvram = False + self.current_device = torch.device("cpu") def _set_position(self, boxes, masks, positive_embeddings): - if self.lowvram == True: - self.position_net.to(boxes.device) - objs = self.position_net(boxes, masks, positive_embeddings) - - if self.lowvram == True: - self.position_net.cpu() - def func_lowvram(x, extra_options): - key = extra_options["transformer_index"] - module = self.module_list[key] - module.to(x.device) - r = module(x, objs) - module.cpu() - return r - return func_lowvram - else: - def func(x, extra_options): - key = extra_options["transformer_index"] - module = self.module_list[key] - return module(x, objs) - return func + def func(x, extra_options): + key = extra_options["transformer_index"] + module = self.module_list[key] + return module(x, objs) + return func def set_position(self, latent_image_shape, position_params, device): batch, c, h, w = latent_image_shape @@ -312,14 +297,6 @@ class Gligen(nn.Module): masks.to(device), conds.to(device)) - def set_lowvram(self, value=True): - self.lowvram = value - - def cleanup(self): - self.lowvram = False - - def get_models(self): - return [self] def load_gligen(sd): sd_k = sd.keys() diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index beaa623f3..eb088d92b 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -649,7 +649,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl s_in = x.new_ones([x.shape[0]]) denoised_1, denoised_2 = None, None - h_1, h_2 = None, None + h, h_1, h_2 = None, None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 573cea6ac..973619bf2 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -10,13 +10,14 @@ from .diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management -import comfy.ops if model_management.xformers_enabled(): import xformers import xformers.ops from comfy.cli_args import args +import comfy.ops + # CrossAttn precision handling if args.dont_upcast_attention: print("disabling upcasting of attention") @@ -52,9 +53,9 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None, device=None): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): super().__init__() - self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) + self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -62,19 +63,19 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device), + operations.Linear(dim, inner_dim, dtype=dtype, device=device), nn.GELU() - ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device) + ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device) + operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): @@ -148,7 +149,7 @@ class SpatialSelfAttention(nn.Module): class CrossAttentionBirchSan(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -156,12 +157,12 @@ class CrossAttentionBirchSan(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -245,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module): class CrossAttentionDoggettx(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -253,12 +254,12 @@ class CrossAttentionDoggettx(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -343,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module): return self.to_out(r2) class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -351,12 +352,12 @@ class CrossAttention(nn.Module): self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential( - comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -399,7 +400,7 @@ class CrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops): super().__init__() print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " f"{heads} heads.") @@ -409,11 +410,11 @@ class MemoryEfficientCrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): @@ -450,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module): return self.to_out(out) class CrossAttentionPytorch(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -458,11 +459,11 @@ class CrossAttentionPytorch(nn.Module): self.heads = heads self.dim_head = dim_head - self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): @@ -508,14 +509,14 @@ else: class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None): + disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): super().__init__() self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device) + context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) @@ -648,7 +649,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None): + use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -656,26 +657,26 @@ class SpatialTransformer(nn.Module): inner_dim = n_heads * d_head self.norm = Normalize(in_channels, dtype=dtype, device=device) if not use_linear: - self.proj_in = nn.Conv2d(in_channels, + self.proj_in = operations.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) else: - self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) + self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device) + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations) for d in range(depth)] ) if not use_linear: - self.proj_out = nn.Conv2d(inner_dim,in_channels, + self.proj_out = operations.Conv2d(inner_dim,in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) else: - self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) + self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.use_linear = use_linear def forward(self, x, context=None, transformer_options={}): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 90c153465..11cec0eda 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -8,8 +8,6 @@ import torch.nn.functional as F from .util import ( checkpoint, - conv_nd, - linear, avg_pool_nd, zero_module, normalization, @@ -17,7 +15,7 @@ from .util import ( ) from ..attention import SpatialTransformer from comfy.ldm.util import exists - +import comfy.ops class TimestepBlock(nn.Module): """ @@ -72,14 +70,14 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device) + self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device) def forward(self, x, output_shape=None): assert x.shape[1] == self.channels @@ -108,7 +106,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -116,7 +114,7 @@ class Downsample(nn.Module): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( + self.op = operations.conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device ) else: @@ -158,6 +156,7 @@ class ResBlock(TimestepBlock): down=False, dtype=None, device=None, + operations=comfy.ops ): super().__init__() self.channels = channels @@ -171,7 +170,7 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), + operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), ) self.updown = up or down @@ -187,7 +186,7 @@ class ResBlock(TimestepBlock): self.emb_layers = nn.Sequential( nn.SiLU(), - linear( + operations.Linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device ), @@ -197,18 +196,18 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) + operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( + self.skip_connection = operations.conv_nd( dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) + self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) def forward(self, x, emb): """ @@ -317,6 +316,7 @@ class UNetModel(nn.Module): adm_in_channels=None, transformer_depth_middle=None, device=None, + operations=comfy.ops, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" @@ -379,9 +379,9 @@ class UNetModel(nn.Module): time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), + operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), + operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) if self.num_classes is not None: @@ -394,9 +394,9 @@ class UNetModel(nn.Module): assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( - linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), + operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), nn.SiLU(), - linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), + operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), ) ) else: @@ -405,7 +405,7 @@ class UNetModel(nn.Module): self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device) + operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) ] ) @@ -426,6 +426,7 @@ class UNetModel(nn.Module): use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, + operations=operations, ) ] ch = mult * model_channels @@ -447,7 +448,7 @@ class UNetModel(nn.Module): layers.append(SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -468,10 +469,11 @@ class UNetModel(nn.Module): down=True, dtype=self.dtype, device=device, + operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations ) ) ) @@ -498,11 +500,12 @@ class UNetModel(nn.Module): use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, + operations=operations ), SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -513,6 +516,7 @@ class UNetModel(nn.Module): use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, + operations=operations ), ) self._feature_size += ch @@ -532,6 +536,7 @@ class UNetModel(nn.Module): use_scale_shift_norm=use_scale_shift_norm, dtype=self.dtype, device=device, + operations=operations ) ] ch = model_channels * mult @@ -554,7 +559,7 @@ class UNetModel(nn.Module): SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) if level and i == self.num_res_blocks[level]: @@ -571,9 +576,10 @@ class UNetModel(nn.Module): up=True, dtype=self.dtype, device=device, + operations=operations ) if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device) + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) @@ -582,12 +588,12 @@ class UNetModel(nn.Module): self.out = nn.Sequential( nn.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), + zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( nn.GroupNorm(32, ch, dtype=self.dtype, device=device), - conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), + operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) diff --git a/comfy/model_base.py b/comfy/model_base.py index ad661ec7d..979e2c65e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -148,13 +148,20 @@ class SDInpaint(BaseModel): super().__init__(model_config, model_type, device=device) self.concat_keys = ("mask", "masked_image") +def sdxl_pooled(args, noise_augmentor): + if "unclip_conditioning" in args: + return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] + else: + return args["pooled_output"] + class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) self.embedder = Timestep(256) + self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280}) def encode_adm(self, **kwargs): - clip_pooled = kwargs["pooled_output"] + clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor) width = kwargs.get("width", 768) height = kwargs.get("height", 768) crop_w = kwargs.get("crop_w", 0) @@ -178,9 +185,10 @@ class SDXL(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) self.embedder = Timestep(256) + self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280}) def encode_adm(self, **kwargs): - clip_pooled = kwargs["pooled_output"] + clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor) width = kwargs.get("width", 768) height = kwargs.get("height", 768) crop_w = kwargs.get("crop_w", 0) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 49ee9ea70..0edc4f180 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -121,9 +121,20 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): return model_config_from_unet_config(unet_config) -def model_config_from_diffusers_unet(state_dict, use_fp16): +def unet_config_from_diffusers_unet(state_dict, use_fp16): match = {} - match["context_dim"] = state_dict["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1] + attention_resolutions = [] + + attn_res = 1 + for i in range(5): + k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i) + if k in state_dict: + match["context_dim"] = state_dict[k].shape[1] + attention_resolutions.append(attn_res) + attn_res *= 2 + + match["attention_resolutions"] = attention_resolutions + match["model_channels"] = state_dict["conv_in.weight"].shape[0] match["in_channels"] = state_dict["conv_in.weight"].shape[1] match["adm_in_channels"] = None @@ -135,22 +146,22 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048} + 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280} + 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, @@ -160,9 +171,20 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768} + 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl] + SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + + SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} + + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet] for unet_config in supported_models: matches = True @@ -171,5 +193,11 @@ def model_config_from_diffusers_unet(state_dict, use_fp16): matches = False break if matches: - return model_config_from_unet_config(unet_config) + return unet_config + return None + +def model_config_from_diffusers_unet(state_dict, use_fp16): + unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16) + if unet_config is not None: + return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index e634f10d2..9b1b7f55d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -2,6 +2,7 @@ import psutil from enum import Enum from comfy.cli_args import args import torch +import sys class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -87,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total = 1024 * 1024 * 1024 #TODO mem_total_torch = mem_total elif xpu_available: + stats = torch.xpu.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] mem_total = torch.xpu.get_device_properties(dev).total_memory - mem_total_torch = mem_total + mem_total_torch = mem_reserved else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -201,8 +204,13 @@ if cpu_state == CPUState.MPS: print(f"Set vram state to: {vram_state.name}") +DISABLE_SMART_MEMORY = args.disable_smart_memory + +if DISABLE_SMART_MEMORY: + print("Disabling smart memory management") def get_torch_device_name(device): + global xpu_available if hasattr(device, 'type'): if device.type == "cuda": try: @@ -212,6 +220,8 @@ def get_torch_device_name(device): return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) else: return "{}".format(device.type) + elif xpu_available: + return "{} {}".format(device, torch.xpu.get_device_name(device)) else: return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) @@ -221,132 +231,168 @@ except: print("Could not pick default device.") -current_loaded_model = None -current_gpu_controlnets = [] +current_loaded_models = [] -model_accelerated = False +class LoadedModel: + def __init__(self, model): + self.model = model + self.model_accelerated = False + self.device = model.load_device + def model_memory(self): + return self.model.model_size() -def unload_model(): - global current_loaded_model - global model_accelerated - global current_gpu_controlnets - global vram_state + def model_memory_required(self, device): + if device == self.model.current_device: + return 0 + else: + return self.model_memory() - if current_loaded_model is not None: - if model_accelerated: - accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) - model_accelerated = False + def model_load(self, lowvram_model_memory=0): + global xpu_available + patch_model_to = None + if lowvram_model_memory == 0: + patch_model_to = self.device - current_loaded_model.unpatch_model() - current_loaded_model.model.to(current_loaded_model.offload_device) - current_loaded_model.model_patches_to(current_loaded_model.offload_device) - current_loaded_model = None - if vram_state != VRAMState.HIGH_VRAM: - soft_empty_cache() + self.model.model_patches_to(self.device) + self.model.model_patches_to(self.model.model_dtype()) - if vram_state != VRAMState.HIGH_VRAM: - if len(current_gpu_controlnets) > 0: - for n in current_gpu_controlnets: - n.cpu() - current_gpu_controlnets = [] + try: + self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU + except Exception as e: + self.model.unpatch_model(self.model.offload_device) + self.model_unload() + raise e + + if lowvram_model_memory > 0: + print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) + device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) + accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) + self.model_accelerated = True + + if xpu_available and not args.disable_ipex_optimize: + self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) + + return self.real_model + + def model_unload(self): + if self.model_accelerated: + accelerate.hooks.remove_hook_from_submodules(self.real_model) + self.model_accelerated = False + + self.model.unpatch_model(self.model.offload_device) + self.model.model_patches_to(self.model.offload_device) + + def __eq__(self, other): + return self.model is other.model def minimum_inference_memory(): - return (768 * 1024 * 1024) + return (1024 * 1024 * 1024) + +def unload_model_clones(model): + to_unload = [] + for i in range(len(current_loaded_models)): + if model.is_clone(current_loaded_models[i].model): + to_unload = [i] + to_unload + + for i in to_unload: + print("unload clone", i) + current_loaded_models.pop(i).model_unload() + +def free_memory(memory_required, device, keep_loaded=[]): + unloaded_model = False + for i in range(len(current_loaded_models) -1, -1, -1): + if DISABLE_SMART_MEMORY: + current_free_mem = 0 + else: + current_free_mem = get_free_memory(device) + if current_free_mem > memory_required: + break + shift_model = current_loaded_models[i] + if shift_model.device == device: + if shift_model not in keep_loaded: + current_loaded_models.pop(i).model_unload() + unloaded_model = True + + if unloaded_model: + soft_empty_cache() + + +def load_models_gpu(models, memory_required=0): + global vram_state + + inference_memory = minimum_inference_memory() + extra_mem = max(inference_memory, memory_required) + + models_to_load = [] + models_already_loaded = [] + for x in models: + loaded_model = LoadedModel(x) + + if loaded_model in current_loaded_models: + index = current_loaded_models.index(loaded_model) + current_loaded_models.insert(0, current_loaded_models.pop(index)) + models_already_loaded.append(loaded_model) + else: + models_to_load.append(loaded_model) + + if len(models_to_load) == 0: + devs = set(map(lambda a: a.device, models_already_loaded)) + for d in devs: + if d != torch.device("cpu"): + free_memory(extra_mem, d, models_already_loaded) + return + + print("loading new") + + total_memory_required = {} + for loaded_model in models_to_load: + unload_model_clones(loaded_model.model) + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + + for device in total_memory_required: + if device != torch.device("cpu"): + free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + + for loaded_model in models_to_load: + model = loaded_model.model + torch_dev = model.load_device + if is_device_cpu(torch_dev): + vram_set_state = VRAMState.DISABLED + else: + vram_set_state = vram_state + lowvram_model_memory = 0 + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = loaded_model.model_memory_required(torch_dev) + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + else: + lowvram_model_memory = 0 + + if vram_set_state == VRAMState.NO_VRAM: + lowvram_model_memory = 256 * 1024 * 1024 + + cur_loaded_model = loaded_model.model_load(lowvram_model_memory) + current_loaded_models.insert(0, loaded_model) + return + def load_model_gpu(model): - global current_loaded_model - global vram_state - global model_accelerated + return load_models_gpu([model]) - if model is current_loaded_model: - return - unload_model() +def cleanup_models(): + to_delete = [] + for i in range(len(current_loaded_models)): + print(sys.getrefcount(current_loaded_models[i].model)) + if sys.getrefcount(current_loaded_models[i].model) <= 2: + to_delete = [i] + to_delete - torch_dev = model.load_device - model.model_patches_to(torch_dev) - model.model_patches_to(model.model_dtype()) - current_loaded_model = model - - if is_device_cpu(torch_dev): - vram_set_state = VRAMState.DISABLED - else: - vram_set_state = vram_state - - if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): - model_size = model.model_size() - current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) - if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary - vram_set_state = VRAMState.LOW_VRAM - - real_model = model.model - patch_model_to = None - if vram_set_state == VRAMState.DISABLED: - pass - elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: - model_accelerated = False - patch_model_to = torch_dev - - try: - real_model = model.patch_model(device_to=patch_model_to) - except Exception as e: - model.unpatch_model() - unload_model() - raise e - - if patch_model_to is not None: - real_model.to(torch_dev) - - if vram_set_state == VRAMState.NO_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) - model_accelerated = True - elif vram_set_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) - model_accelerated = True - - return current_loaded_model - -def load_controlnet_gpu(control_models): - global current_gpu_controlnets - global vram_state - if vram_state == VRAMState.DISABLED: - return - - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - for m in control_models: - if hasattr(m, 'set_lowvram'): - m.set_lowvram(True) - #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after - return - - models = [] - for m in control_models: - models += m.get_models() - - for m in current_gpu_controlnets: - if m not in models: - m.cpu() - - device = get_torch_device() - current_gpu_controlnets = [] - for m in models: - current_gpu_controlnets.append(m.to(device)) - - -def load_if_low_vram(model): - global vram_state - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - return model.to(get_torch_device()) - return model - -def unload_if_low_vram(model): - global vram_state - if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: - return model.cpu() - return model + for i in to_delete: + x = current_loaded_models.pop(i) + x.model_unload() + del x def unet_offload_device(): if vram_state == VRAMState.HIGH_VRAM: @@ -354,6 +400,28 @@ def unet_offload_device(): else: return torch.device("cpu") +def unet_inital_load_device(parameters, dtype): + torch_dev = get_torch_device() + if vram_state == VRAMState.HIGH_VRAM: + return torch_dev + + cpu_dev = torch.device("cpu") + if DISABLE_SMART_MEMORY: + return cpu_dev + + dtype_size = 4 + if dtype == torch.float16 or dtype == torch.bfloat16: + dtype_size = 2 + + model_size = dtype_size * parameters + + mem_dev = get_free_memory(torch_dev) + mem_cpu = get_free_memory(cpu_dev) + if mem_dev > mem_cpu and model_size < mem_dev: + return torch_dev + else: + return cpu_dev + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -441,8 +509,12 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = 1024 * 1024 * 1024 #TODO mem_free_torch = mem_free_total elif xpu_available: - mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) - mem_free_torch = mem_free_total + stats = torch.xpu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_allocated = stats['allocated_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_torch = mem_reserved - mem_active + mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -456,6 +528,13 @@ def get_free_memory(dev=None, torch_free_too=False): else: return mem_free_total +def batch_area_memory(area): + if xformers_enabled() or pytorch_attention_flash_attention(): + #TODO: these formulas are copied from maximum_batch_area below + return (area / 20) * (1024 * 1024) + else: + return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) + def maximum_batch_area(): global vram_state if vram_state == VRAMState.NO_VRAM: @@ -507,9 +586,12 @@ def should_use_fp16(device=None, model_params=0): if directml_enabled: return False - if cpu_mode() or mps_mode() or xpu_available: + if cpu_mode() or mps_mode(): return False #TODO ? + if xpu_available: + return True + if torch.cuda.is_bf16_supported(): return True diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index e90f6d163..150497662 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib +import os import pkgutil import time import types @@ -14,6 +15,7 @@ except: custom_nodes = None from comfy.nodes.package_typing import ExportedNodes from functools import reduce +from pkg_resources import resource_filename _comfy_nodes = ExportedNodes() @@ -21,10 +23,19 @@ _comfy_nodes = ExportedNodes() def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) + web_directory = getattr(module, "WEB_DIRECTORY", None) if node_class_mappings: exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings) if node_display_names: exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names) + if web_directory: + # load the extension resources path + abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory)) + if not os.path.isdir(abs_web_directory): + abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory)) + if not os.path.isdir(abs_web_directory): + raise ImportError(path=abs_web_directory) + exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes: diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index bf83c0a7e..6cb8d212f 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -22,22 +22,27 @@ class CustomNode(Protocol): class ExportedNodes: NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict) NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict) + EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict) def update(self, exported_nodes: ExportedNodes) -> ExportedNodes: self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS) self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS) + self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS) return self def __len__(self): return len(self.NODE_CLASS_MAPPINGS) - def __sub__(self, other): + def __sub__(self, other: ExportedNodes): exported_nodes = ExportedNodes().update(self) for self_key in exported_nodes.NODE_CLASS_MAPPINGS: if self_key in other.NODE_CLASS_MAPPINGS: exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key) if self_key in other.NODE_DISPLAY_NAME_MAPPINGS: exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key) + for self_key in exported_nodes.EXTENSION_WEB_DIRS: + if self_key in other.EXTENSION_WEB_DIRS: + exported_nodes.EXTENSION_WEB_DIRS.pop(self_key) return exported_nodes def __add__(self, other): diff --git a/comfy/ops.py b/comfy/ops.py index 2e72030bd..678c2c6d0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -21,6 +21,11 @@ class Conv2d(torch.nn.Conv2d): def reset_parameters(self): return None +def conv_nd(dims, *args, **kwargs): + if dims == 2: + return Conv2d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") @contextmanager def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way diff --git a/comfy/sample.py b/comfy/sample.py index 48530f132..d7292024e 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type): models += [c[1][model_type]] return models -def load_additional_models(positive, negative, dtype): +def get_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") + control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) + + control_models = [] + for m in control_nets: + control_models += m.get_models() + gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") - gligen = [x[1].to(dtype) for x in gligen] - models = control_nets + gligen - comfy.model_management.load_controlnet_gpu(models) + gligen = [x[1] for x in gligen] + models = control_models + gligen return models def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: - m.cleanup() + if hasattr(m, 'cleanup'): + m.cleanup() def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): device = comfy.model_management.get_torch_device() @@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative noise_mask = prepare_mask(noise_mask, noise.shape, device) real_model = None - comfy.model_management.load_model_gpu(model) + models = get_additional_models(positive, negative) + comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) real_model = model.model noise = noise.to(device) @@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative positive_copy = broadcast_cond(positive, noise.shape[0], device) negative_copy = broadcast_cond(negative, noise.shape[0], device) - models = load_additional_models(positive, negative, model.model_dtype()) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index 28cd46667..134336de6 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con gligen_type = gligen[0] gligen_model = gligen[1] if gligen_type == "position": - gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) else: - gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) patches['middle_patch'] = [gligen_patch] @@ -478,7 +478,7 @@ def pre_run_control(model, conds): timestep_end = None percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) if 'control' in x[1]: - x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function) + x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] diff --git a/comfy/sd.py b/comfy/sd.py index bff9ee141..b0482c782 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -243,8 +243,15 @@ def set_attr(obj, attr, value): setattr(obj, attrs[-1], torch.nn.Parameter(value)) del prev +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj + + class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0): + def __init__(self, model, load_device, offload_device, size=0, current_device=None): self.size = size self.model = model self.patches = {} @@ -253,6 +260,10 @@ class ModelPatcher: self.model_size() self.load_device = load_device self.offload_device = offload_device + if current_device is None: + self.current_device = self.offload_device + else: + self.current_device = current_device def model_size(self): if self.size > 0: @@ -267,7 +278,7 @@ class ModelPatcher: return size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -276,6 +287,11 @@ class ModelPatcher: n.model_keys = self.model_keys return n + def is_clone(self, other): + if hasattr(other, 'model') and self.model is other.model: + return True + return False + def set_model_sampler_cfg_function(self, sampler_cfg_function): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way @@ -390,6 +406,11 @@ class ModelPatcher: out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) set_attr(self.model, key, out_weight) del temp_weight + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + return self.model def calculate_weight(self, patches, weight, key): @@ -482,7 +503,7 @@ class ModelPatcher: return weight - def unpatch_model(self): + def unpatch_model(self, device_to=None): keys = list(self.backup.keys()) for k in keys: @@ -490,6 +511,11 @@ class ModelPatcher: self.backup = {} + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + + def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = model_lora_keys_unet(model.model) key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) @@ -555,7 +581,7 @@ class CLIP: else: self.cond_stage_model.reset_clip_layer() - model_management.load_model_gpu(self.patcher) + self.load_model() cond, pooled = self.cond_stage_model.encode_token_weights(tokens) if return_pooled: return cond, pooled @@ -571,11 +597,9 @@ class CLIP: def get_sd(self): return self.cond_stage_model.state_dict() - def patch_model(self): - self.patcher.patch_model() - - def unpatch_model(self): - self.patcher.unpatch_model() + def load_model(self): + model_management.load_model_gpu(self.patcher) + return self.patcher def get_key_patches(self): return self.patcher.get_key_patches() @@ -630,11 +654,12 @@ class VAE: return samples def decode(self, samples_in): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) try: + memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 + model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) - batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64)) + batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") @@ -650,19 +675,19 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: + memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) - batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): @@ -677,7 +702,6 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -757,6 +781,7 @@ class ControlNet(ControlBase): def __init__(self, control_model, global_average_pooling=False, device=None): super().__init__(device) self.control_model = control_model + self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond, batched_number): @@ -780,19 +805,14 @@ class ControlNet(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - if self.control_model.dtype == torch.float16: - precision_scope = torch.autocast - else: - precision_scope = contextlib.nullcontext - with precision_scope(model_management.get_autocast_device(self.device)): - self.control_model = model_management.load_if_low_vram(self.control_model) - context = torch.cat(cond['c_crossattn'], 1) - y = cond.get('c_adm', None) - control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y) - self.control_model = model_management.unload_if_low_vram(self.control_model) + context = torch.cat(cond['c_crossattn'], 1) + y = cond.get('c_adm', None) + if y is not None: + y = y.to(self.control_model.dtype) + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + out = {'middle':[], 'output': []} - autocast_enabled = torch.is_autocast_enabled() for i in range(len(control)): if i == (len(control) - 1): @@ -806,7 +826,7 @@ class ControlNet(ControlBase): x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) x *= self.strength - if x.dtype != output_dtype and not autocast_enabled: + if x.dtype != output_dtype: x = x.to(output_dtype) if control_prev is not None and key in control_prev: @@ -825,17 +845,133 @@ class ControlNet(ControlBase): def get_models(self): out = super().get_models() - out.append(self.control_model) + out.append(self.control_model_wrapped) return out +class ControlLoraOps: + class Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.up = None + self.down = None + self.bias = None + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + else: + return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + + class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = False + self.output_padding = 0 + self.groups = groups + self.padding_mode = padding_mode + + self.weight = None + self.bias = None + self.up = None + self.down = None + + + def forward(self, input): + if self.up is not None: + return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + else: + return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + + def conv_nd(self, dims, *args, **kwargs): + if dims == 2: + return self.Conv2d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class ControlLora(ControlNet): + def __init__(self, control_weights, global_average_pooling=False, device=None): + ControlBase.__init__(self, device) + self.control_weights = control_weights + self.global_average_pooling = global_average_pooling + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + controlnet_config = model.model_config.unet_config.copy() + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] + controlnet_config["operations"] = ControlLoraOps() + self.control_model = cldm.ControlNet(**controlnet_config) + if model_management.should_use_fp16(): + self.control_model.half() + self.control_model.to(model_management.get_torch_device()) + diffusion_model = model.diffusion_model + sd = diffusion_model.state_dict() + cm = self.control_model.state_dict() + + for k in sd: + weight = sd[k] + if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. + key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. + op = get_attr(diffusion_model, '.'.join(key_split[:-1])) + weight = op._hf_hook.weights_map[key_split[-1]] + + try: + set_attr(self.control_model, k, weight) + except: + pass + + for k in self.control_weights: + if k not in {"lora_controlnet"}: + set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device())) + + def copy(self): + c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) + self.copy_to(c) + return c + + def cleanup(self): + del self.control_model + self.control_model = None + super().cleanup() + + def get_models(self): + out = ControlBase.get_models(self) + return out def load_controlnet(ckpt_path, model=None): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) + if "lora_controlnet" in controlnet_data: + return ControlLora(controlnet_data) controlnet_config = None if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format use_fp16 = model_management.should_use_fp16() - controlnet_config = model_detection.model_config_from_diffusers_unet(controlnet_data, use_fp16).unet_config + controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) diffusers_keys = utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -874,6 +1010,9 @@ def load_controlnet(ckpt_path, model=None): if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + print("leftover keys:", leftover_keys) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' @@ -901,8 +1040,8 @@ def load_controlnet(ckpt_path, model=None): if pth: if 'difference' in controlnet_data: if model is not None: - m = model.patch_model() - model_sd = m.state_dict() + model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() for x in controlnet_data: c_m = "control_model." if x.startswith(c_m): @@ -910,7 +1049,6 @@ def load_controlnet(ckpt_path, model=None): if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) - model.unpatch_model() else: print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") @@ -970,11 +1108,10 @@ class T2IAdapter(ControlBase): output_dtype = x_noisy.dtype out = {'input':[]} - autocast_enabled = torch.is_autocast_enabled() for i in range(len(self.control_input)): key = 'input' x = self.control_input[i] * self.strength - if x.dtype != output_dtype and not autocast_enabled: + if x.dtype != output_dtype: x = x.to(output_dtype) if control_prev is not None and key in control_prev: @@ -1001,7 +1138,6 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c - def load_t2i_adapter(t2i_data): keys = t2i_data.keys() if 'adapter' in keys: @@ -1087,7 +1223,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return model + return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): #TODO: this function is a mess and should be removed eventually @@ -1199,8 +1335,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) + dtype = torch.float32 + if fp16: + dtype = torch.float16 + + inital_load_device = model_management.unet_inital_load_device(parameters, dtype) offload_device = model_management.unet_offload_device() - model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device) + model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model.load_model_weights(sd, "model.diffusion_model.") if output_vae: @@ -1221,7 +1362,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if len(left_over) > 0: print("left over keys:", left_over) - return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision) + model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + if inital_load_device != torch.device("cpu"): + print("loaded straight to GPU") + model_management.load_model_gpu(model_patcher) + + return (model_patcher, clip, vae, clipvision) def load_unet(unet_path): #load unet in diffusers format @@ -1249,14 +1395,6 @@ def load_unet(unet_path): #load unet in diffusers format return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) def save_checkpoint(output_path, model, clip, vae, metadata=None): - try: - model.patch_model() - clip.patch_model() - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) - utils.save_torch_file(sd, output_path, metadata=metadata) - model.unpatch_model() - clip.unpatch_model() - except Exception as e: - model.unpatch_model() - clip.unpatch_model() - raise e + model_management.load_models_gpu([model, clip.load_model()]) + sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) + utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index 8194a89d7..acdc0d96b 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -1,3 +1,5 @@ +import numpy as np +from scipy.ndimage import grey_dilation import torch from comfy.nodes.common import MAX_RESOLUTION @@ -277,6 +279,35 @@ class FeatherMask: output[-y, :] *= feather_rate return (output,) + +class GrowMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "tapered_corners": ("BOOLEAN", {"default": True}), + }, + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "expand_mask" + + def expand_mask(self, mask, expand, tapered_corners): + c = 0 if tapered_corners else 1 + kernel = np.array([[c, 1, c], + [1, 1, 1], + [c, 1, c]]) + output = mask.numpy().copy() + while expand > 0: + output = grey_dilation(output, footprint=kernel) + expand -= 1 + output = torch.from_numpy(output) + return (output,) @@ -290,6 +321,7 @@ NODE_CLASS_MAPPINGS = { "CropMask": CropMask, "MaskComposite": MaskComposite, "FeatherMask": FeatherMask, + "GrowMask": GrowMask, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py index a138b292e..51bdb24fa 100644 --- a/comfy_extras/nodes/nodes_post_processing.py +++ b/comfy_extras/nodes/nodes_post_processing.py @@ -2,6 +2,7 @@ import numpy as np import torch import torch.nn.functional as F from PIL import Image +import math import comfy.utils @@ -209,9 +210,36 @@ class Sharpen: return (result,) +class ImageScaleToTotalPixels: + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + crop_methods = ["disabled", "center"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, megapixels): + samples = image.movedim(-1,1) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1,-1) + return (s,) + NODE_CLASS_MAPPINGS = { "ImageBlend": Blend, "ImageBlur": Blur, "ImageQuantize": Quantize, "ImageSharpen": Sharpen, + "ImageScaleToTotalPixels": ImageScaleToTotalPixels, } diff --git a/web/index.html b/web/index.html index 71067d993..41bc246c0 100644 --- a/web/index.html +++ b/web/index.html @@ -6,6 +6,7 @@ +