From 63e5fd17907fa2100725d6d46053084c14065232 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Wed, 4 Oct 2023 19:45:15 -0300 Subject: [PATCH 01/34] Option to input directory --- comfy/cli_args.py | 1 + folder_paths.py | 4 ++++ main.py | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index ffae81c49..35d44164f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") +parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") diff --git a/folder_paths.py b/folder_paths.py index 4a10c68e7..898513b0e 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -46,6 +46,10 @@ def set_temp_directory(temp_dir): global temp_directory temp_directory = temp_dir +def set_input_directory(input_dir): + global input_directory + input_directory = input_dir + def get_output_directory(): global output_directory return output_directory diff --git a/main.py b/main.py index 7c5eaee0a..875ea1aa9 100644 --- a/main.py +++ b/main.py @@ -175,6 +175,11 @@ if __name__ == "__main__": print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) + if args.input_directory: + input_dir = os.path.abspath(args.input_directory) + print(f"Setting input directory to: {input_dir}") + folder_paths.set_input_directory(input_dir) + if args.quick_test_for_ci: exit(0) From 7bb9f6b7e87d533cbbbc7aceafafb744a62dbfa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 9 Oct 2023 01:42:15 -0400 Subject: [PATCH 02/34] Add a VAESave node. 
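
The node writes vae.get_sd() to output/vae/*.safetensors together with the
prompt/workflow metadata (unless --disable-metadata is set). As a quick sanity
check, the metadata can be read back with the safetensors package, which
ComfyUI already depends on; the path below is only a hypothetical example of
the files this node produces:

    from safetensors import safe_open

    path = "output/vae/ComfyUI_vae_00001_.safetensors"  # hypothetical output file
    with safe_open(path, framework="pt", device="cpu") as f:
        meta = f.metadata() or {}            # None when saved with --disable-metadata
        print(list(f.keys())[:5])            # a few VAE tensor names
        print(meta.get("prompt", "")[:200])  # embedded workflow json, if present
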
--- comfy_extras/nodes_model_merging.py | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 3d42d7806..1e3fc9359 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,6 +1,7 @@ import comfy.sd import comfy.utils import comfy.model_base +import comfy.model_management import folder_paths import json @@ -178,6 +179,39 @@ class CheckpointSave: comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) return {} +class VAESave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) + return {} NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, @@ -186,4 +220,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "VAESave": VAESave, } From 4308862ce0434718a6f3acdc7164f3a1436220e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 9 Oct 2023 01:51:01 -0400 Subject: [PATCH 03/34] Add a note to README about pytorch 3.12 not being supported. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 925caa732..6bef25cee 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae +Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier. 
+ ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: From 9eb621c95a7c867d118e90348057474cbc96c20c Mon Sep 17 00:00:00 2001 From: Yukimasa Funaoka Date: Tue, 10 Oct 2023 13:21:44 +0900 Subject: [PATCH 04/34] Supports TAESD models in safetensors format --- comfy/latent_formats.py | 4 ++-- comfy/taesd/taesd.py | 12 ++++++++++-- latent_preview.py | 7 ++++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fadc0eec7..c209087e0 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -20,7 +20,7 @@ class SD15(LatentFormat): [-0.2829, 0.1762, 0.2721], [-0.2120, -0.2616, -0.7177] ] - self.taesd_decoder_name = "taesd_decoder.pth" + self.taesd_decoder_name = "taesd_decoder" class SDXL(LatentFormat): def __init__(self): @@ -32,4 +32,4 @@ class SDXL(LatentFormat): [ 0.0568, 0.1687, -0.0755], [-0.3112, -0.2359, -0.2076] ] - self.taesd_decoder_name = "taesdxl_decoder.pth" + self.taesd_decoder_name = "taesdxl_decoder" diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 1549345ae..92f74c11a 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -50,9 +50,17 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) + if encoder_path.lower().endswith(".safetensors"): + import safetensors.torch + self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu")) + else: + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) if decoder_path is not None: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) + if decoder_path.lower().endswith(".safetensors"): + import safetensors.torch + self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu")) + else: + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) @staticmethod def scale_latents(x): diff --git a/latent_preview.py b/latent_preview.py index 740e08607..e1553c85c 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -56,7 +56,12 @@ def get_previewer(device, latent_format): # TODO previewer methods taesd_decoder_path = None if latent_format.taesd_decoder_name is not None: - taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) + taesd_decoder_path = next( + (fn for fn in folder_paths.get_filename_list("vae_approx") + if fn.startswith(latent_format.taesd_decoder_name)), + "" + ) + taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB From 877553843f914d9b18d470c285f0d745a4a85c06 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 10 Oct 2023 01:24:49 -0400 Subject: [PATCH 05/34] Add a CLIPSave node to save CLIP model weights. 
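
The save() method splits checkpoints that carry two text encoders (SDXL style)
by key prefix, strips the "clip_l."/"clip_g." and leading "transformer."
prefixes with comfy.utils.state_dict_prefix_replace, and writes each group to
its own file with a _clip_l / _clip_g filename suffix. A toy illustration of
that key handling, using made-up keys and a simplified stand-in for the prefix
replacement:

    sd = {"clip_l.transformer.text_model.x": 1, "clip_g.transformer.text_model.y": 2}
    for prefix in ["clip_l.", "clip_g.", ""]:
        keys = [k for k in sd if k.startswith(prefix)]
        part = {k: sd.pop(k) for k in keys}
        if not part:
            continue
        # drop the encoder prefix and the leading "transformer." before saving
        part = {k[len(prefix):].removeprefix("transformer."): v for k, v in part.items()}
        print(prefix or "(no prefix)", part)
    # -> clip_l. {'text_model.x': 1}
    # -> clip_g. {'text_model.y': 2}
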
--- comfy_extras/nodes_model_merging.py | 57 +++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 1e3fc9359..dad1dd637 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -179,6 +179,62 @@ class CheckpointSave: comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) return {} +class CLIPSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip": ("CLIP",), + "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + comfy.model_management.load_models_gpu([clip.load_model()]) + clip_sd = clip.get_sd() + + for prefix in ["clip_l.", "clip_g.", ""]: + k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) + current_clip_sd = {} + for x in k: + current_clip_sd[x] = clip_sd.pop(x) + if len(current_clip_sd) == 0: + continue + + p = prefix[:-1] + replace_prefix = {} + filename_prefix_ = filename_prefix + if len(p) > 0: + filename_prefix_ = "{}_{}".format(filename_prefix_, p) + replace_prefix[prefix] = "" + replace_prefix["transformer."] = "" + + full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) + + comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) + return {} + class VAESave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -220,5 +276,6 @@ NODE_CLASS_MAPPINGS = { "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "CLIPSave": CLIPSave, "VAESave": VAESave, } From be903eb2e2921f03a3a03dc9d6b0c6437ae201f5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 10 Oct 2023 01:25:47 -0400 Subject: [PATCH 06/34] Add default CheckpointSave, CLIPSave and VAESave paths to model paths. --- main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main.py b/main.py index 875ea1aa9..1100a07f4 100644 --- a/main.py +++ b/main.py @@ -175,6 +175,11 @@ if __name__ == "__main__": print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) + #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. 
nodes + folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) + folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) + folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) + if args.input_directory: input_dir = os.path.abspath(args.input_directory) print(f"Setting input directory to: {input_dir}") From 5e885bd9c822a3cc7d50d75a40958265f497cd03 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 10 Oct 2023 21:46:53 -0400 Subject: [PATCH 07/34] Cleanup. --- comfy/taesd/taesd.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 92f74c11a..8df1f1609 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion import torch import torch.nn as nn +import comfy.utils + def conv(n_in, n_out, **kwargs): return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) @@ -50,17 +52,9 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - if encoder_path.lower().endswith(".safetensors"): - import safetensors.torch - self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu")) - else: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) + self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) if decoder_path is not None: - if decoder_path.lower().endswith(".safetensors"): - import safetensors.torch - self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu")) - else: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) + self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod def scale_latents(x): From 8cc75c64ff7188ce72cd4ba595119586e425c09f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Oct 2023 01:34:38 -0400 Subject: [PATCH 08/34] Let unet wrapper functions have .to attributes. --- comfy/model_patcher.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ba505221e..50b725b86 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -107,6 +107,10 @@ class ModelPatcher: for k in patch_list: if hasattr(patch_list[k], "to"): patch_list[k] = patch_list[k].to(device) + if "unet_wrapper_function" in self.model_options: + wrap_func = self.model_options["unet_wrapper_function"] + if hasattr(wrap_func, "to"): + self.model_options["unet_wrapper_function"] = wrap_func.to(device) def model_dtype(self): if hasattr(self.model, "get_dtype"): From 1a4bd9e9a6fc2b364ebb547dbd80736548cf9f5c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Oct 2023 15:47:53 -0400 Subject: [PATCH 09/34] Refactor the attention functions. There's no reason for the whole CrossAttention object to be repeated when only the operation in the middle changes. 
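
Concretely, the per-backend CrossAttention classes (split/Doggettx,
sub-quadratic, xformers, pytorch) become plain functions that all share the
signature (q, k, v, heads, mask); a module-level optimized_attention is picked
once at import time and CrossAttention.forward() just projects q/k/v and
delegates to it. A condensed sketch of the pattern, simplified from the full
diff below:

    import torch

    def attention_pytorch(q, k, v, heads, mask=None):
        # every backend implements this same signature
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

    # chosen once, based on what the environment supports
    optimized_attention = attention_pytorch

    # inside CrossAttention.forward(): out = optimized_attention(q, k, v, self.heads, mask)
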
--- comfy/ldm/modules/attention.py | 538 +++++++------------- comfy/ldm/modules/diffusionmodules/model.py | 13 - 2 files changed, 186 insertions(+), 365 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index fcae6b66a..3230cfaf5 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -94,253 +94,220 @@ def zero_module(module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) +def attention_basic(q, k, v, heads, mask=None): + h = heads + scale = (q.shape[-1] // heads) ** -0.5 + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + del q, k - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return out -class CrossAttentionBirchSan(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) +def attention_sub_quad(query, key, value, heads, mask=None): + scale = (query.shape[-1] // heads) ** -0.5 + query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) + key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1) + del key + value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) - self.scale = dim_head ** -0.5 - self.heads = heads + dtype = query.dtype + upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 + if upcast_attention: + bytes_per_token = torch.finfo(torch.float32).bits//8 + else: + bytes_per_token = torch.finfo(query.dtype).bits//8 + batch_x_heads, q_tokens, _ = query.shape + _, 
_, k_tokens = key_t.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) + chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD - def forward(self, x, context=None, value=None, mask=None): - h = self.heads + kv_chunk_size_min = None - query = self.to_q(x) - context = default(context, x) - key = self.to_k(context) - if value is not None: - value = self.to_v(value) - else: - value = self.to_v(context) + #not sure at all about the math here + #TODO: tweak this + if mem_free_total > 8192 * 1024 * 1024 * 1.3: + query_chunk_size_x = 1024 * 4 + elif mem_free_total > 4096 * 1024 * 1024 * 1.3: + query_chunk_size_x = 1024 * 2 + else: + query_chunk_size_x = 1024 + kv_chunk_size_min_x = None + kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 + if kv_chunk_size_x < 1024: + kv_chunk_size_x = None - del context, x + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + else: + query_chunk_size = query_chunk_size_x + kv_chunk_size = kv_chunk_size_x + kv_chunk_size_min = kv_chunk_size_min_x - query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) - key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1) - del key - value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) + hidden_states = efficient_dot_product_attention( + query, + key_t, + value, + query_chunk_size=query_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min=kv_chunk_size_min, + use_checkpoint=False, + upcast_attention=upcast_attention, + ) - dtype = query.dtype - upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 - if upcast_attention: - bytes_per_token = torch.finfo(torch.float32).bits//8 - else: - bytes_per_token = torch.finfo(query.dtype).bits//8 - batch_x_heads, q_tokens, _ = query.shape - _, _, k_tokens = key_t.shape - qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + hidden_states = hidden_states.to(dtype) - mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) + hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) + return hidden_states - chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD +def attention_split(q, k, v, heads, mask=None): + scale = (q.shape[-1] // heads) ** -0.5 + h = heads + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - kv_chunk_size_min = None + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - #not sure at all about the math here - #TODO: tweak this - if mem_free_total > 8192 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 4 - elif mem_free_total > 4096 * 
1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 2 - else: - query_chunk_size_x = 1024 - kv_chunk_size_min_x = None - kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 - if kv_chunk_size_x < 1024: - kv_chunk_size_x = None + mem_free_total = model_management.get_free_memory(q.device) - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. send it down the unchunked fast-path - query_chunk_size = q_tokens - kv_chunk_size = k_tokens - else: - query_chunk_size = query_chunk_size_x - kv_chunk_size = kv_chunk_size_x - kv_chunk_size_min = kv_chunk_size_min_x - - hidden_states = efficient_dot_product_attention( - query, - key_t, - value, - query_chunk_size=query_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min=kv_chunk_size_min, - use_checkpoint=self.training, - upcast_attention=upcast_attention, - ) - - hidden_states = hidden_states.to(dtype) - - hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2) - - out_proj, dropout = self.to_out - hidden_states = out_proj(hidden_states) - hidden_states = dropout(hidden_states) - - return hidden_states + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 -class CrossAttentionDoggettx(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - self.scale = dim_head ** -0.5 - self.heads = heads + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, value=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) - if value is not None: - v_in = self.to_v(value) - del value - else: - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - - # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) - first_op_done = False - cleared_cache = False - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale - else: - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - first_op_done = True - - s2 = s1.softmax(dim=-1).to(v.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - break - except model_management.OOM_EXCEPTION as e: - if first_op_done == False: - model_management.soft_empty_cache(True) - if cleared_cache == False: - cleared_cache = True - print("out of memory error, emptying cache and trying again") - continue - steps *= 2 - if steps > 64: - raise e - print("out of memory error, increasing steps and trying again", steps) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) + first_op_done = False + cleared_cache = False + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + first_op_done = True + + s2 = s1.softmax(dim=-1).to(v.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b 
i d', s2, v) + del s2 + break + except model_management.OOM_EXCEPTION as e: + if first_op_done == False: + model_management.soft_empty_cache(True) + if cleared_cache == False: + cleared_cache = True + print("out of memory error, emptying cache and trying again") + continue + steps *= 2 + if steps > 64: raise e + print("out of memory error, increasing steps and trying again", steps) + else: + raise e - del q, k, v + del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + return r2 - return self.to_out(r2) +def attention_xformers(q, k, v, heads, mask=None): + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], heads, -1) + .permute(0, 2, 1, 3) + .reshape(b * heads, t.shape[1], -1) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, heads, out.shape[1], -1) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], -1) + ) + return out + +def attention_pytorch(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + + if exists(mask): + raise NotImplementedError + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out + +optimized_attention = attention_basic + +if model_management.xformers_enabled(): + print("Using xformers cross attention") + optimized_attention = attention_xformers +elif model_management.pytorch_attention_enabled(): + print("Using pytorch cross attention") + optimized_attention = attention_pytorch +else: + if args.use_split_cross_attention: + print("Using split optimization for cross attention") + optimized_attention = attention_split + else: + print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + optimized_attention = attention_sub_quad class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): @@ -348,62 +315,6 @@ class CrossAttention(nn.Module): inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 - self.heads = heads - - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, value=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - if value is not None: - v = self.to_v(value) - del value - else: - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * 
self.scale - else: - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - self.heads = heads self.dim_head = dim_head @@ -412,7 +323,6 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) @@ -424,85 +334,9 @@ class MemoryEfficientCrossAttention(nn.Module): else: v = self.to_v(context) - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) + out = optimized_attention(q, k, v, self.heads, mask) return self.to_out(out) -class CrossAttentionPytorch(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None - - def forward(self, x, context=None, value=None, mask=None): - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - if value is not None: - v = self.to_v(value) - del value - else: - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), - (q, k, v), - ) - - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - - if exists(mask): - raise NotImplementedError - out = ( - out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) - ) - - return self.to_out(out) - -if 
model_management.xformers_enabled(): - print("Using xformers cross attention") - CrossAttention = MemoryEfficientCrossAttention -elif model_management.pytorch_attention_enabled(): - print("Using pytorch cross attention") - CrossAttention = CrossAttentionPytorch -else: - if args.use_split_cross_attention: - print("Using split optimization for cross attention") - CrossAttention = CrossAttentionDoggettx - else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") - CrossAttention = CrossAttentionBirchSan - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5f38640c3..e6cf954ff 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -6,7 +6,6 @@ import numpy as np from einops import rearrange from typing import Optional, Any -from ..attention import MemoryEfficientCrossAttention from comfy import model_management import comfy.ops @@ -352,15 +351,6 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): out = self.proj_out(out) return x+out -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): - b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) - return x + out - - def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' if model_management.xformers_enabled_vae() and attn_type == "vanilla": @@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): return MemoryEfficientAttnBlock(in_channels) elif attn_type == "vanilla-pytorch": return MemoryEfficientAttnBlockPytorch(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": return nn.Identity(in_channels) else: From ac7d8cfa875e08623993da8109cc73a68df42379 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Oct 2023 20:24:17 -0400 Subject: [PATCH 10/34] Allow attn_mask in attention_pytorch. --- comfy/ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 3230cfaf5..ac0d9c8c5 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None): (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if exists(mask): raise NotImplementedError From 20d3852aa1a23db24b288531ef07c26b49308629 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Oct 2023 20:35:50 -0400 Subject: [PATCH 11/34] Pull some small changes from the other repo. 
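
Two of the smaller changes are worth calling out: execution.py moves its
prints onto the logging module, and comfy.utils gains a global progress-bar
switch (PROGRESS_BAR_ENABLED / set_progress_bar_enabled) that the samplers
consult through disable_pbar. A rough usage sketch, using only the names added
in this diff:

    import comfy.utils

    comfy.utils.set_progress_bar_enabled(False)           # silence sampler tqdm bars
    disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED   # True -> passed on to comfy.sample.sample()
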
--- comfy/model_management.py | 5 +++-- comfy/utils.py | 4 ++++ comfy_extras/nodes_custom_sampler.py | 3 ++- execution.py | 19 ++++++++++--------- folder_paths.py | 2 ++ nodes.py | 2 +- 6 files changed, 22 insertions(+), 13 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8b8963726..3b43b21ac 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -354,6 +354,8 @@ def load_models_gpu(models, memory_required=0): current_loaded_models.insert(0, current_loaded_models.pop(index)) models_already_loaded.append(loaded_model) else: + if hasattr(x, "model"): + print(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -363,7 +365,7 @@ def load_models_gpu(models, memory_required=0): free_memory(extra_mem, d, models_already_loaded) return - print("loading new") + print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: @@ -405,7 +407,6 @@ def load_model_gpu(model): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - print(sys.getrefcount(current_loaded_models[i].model)) if sys.getrefcount(current_loaded_models[i].model) <= 2: to_delete = [i] + to_delete diff --git a/comfy/utils.py b/comfy/utils.py index 7843b58cc..df016ef9e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -408,6 +408,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output +PROGRESS_BAR_ENABLED = True +def set_progress_bar_enabled(enabled): + global PROGRESS_BAR_ENABLED + PROGRESS_BAR_ENABLED = enabled PROGRESS_BAR_HOOK = None def set_progress_bar_global_hook(function): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 9391c7147..b52ad8fbd 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -3,6 +3,7 @@ import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling import latent_preview import torch +import comfy.utils class BasicScheduler: @@ -219,7 +220,7 @@ class SamplerCustom: x0_output = {} callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) - disable_pbar = False + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() diff --git a/execution.py b/execution.py index 5f5d6c738..918c2bc5c 100644 --- a/execution.py +++ b/execution.py @@ -2,6 +2,7 @@ import os import sys import copy import json +import logging import threading import heapq import traceback @@ -156,7 +157,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) except comfy.model_management.InterruptProcessingException as iex: - print("Processing interrupted") + logging.info("Processing interrupted") # skip formatting inputs/outputs error_details = { @@ -177,8 +178,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute for node_id, node_outputs in outputs.items(): output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] - print("!!! 
Exception during processing !!!") - print(traceback.format_exc()) + logging.error("!!! Exception during processing !!!") + logging.error(traceback.format_exc()) error_details = { "node_id": unique_id, @@ -636,11 +637,11 @@ def validate_prompt(prompt): if valid is True: good_outputs.add(o) else: - print(f"Failed to validate prompt for output {o}:") + logging.error(f"Failed to validate prompt for output {o}:") if len(reasons) > 0: - print("* (prompt):") + logging.error("* (prompt):") for reason in reasons: - print(f" - {reason['message']}: {reason['details']}") + logging.error(f" - {reason['message']}: {reason['details']}") errors += [(o, reasons)] for node_id, result in validated.items(): valid = result[0] @@ -656,11 +657,11 @@ def validate_prompt(prompt): "dependent_outputs": [], "class_type": class_type } - print(f"* {class_type} {node_id}:") + logging.error(f"* {class_type} {node_id}:") for reason in reasons: - print(f" - {reason['message']}: {reason['details']}") + logging.error(f" - {reason['message']}: {reason['details']}") node_errors[node_id]["dependent_outputs"].append(o) - print("Output will be ignored") + logging.error("Output will be ignored") if len(good_outputs) == 0: errors_list = [] diff --git a/folder_paths.py b/folder_paths.py index 898513b0e..5d121b443 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) + output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/nodes.py b/nodes.py index 16bf07cca..208cbc84f 100644 --- a/nodes.py +++ b/nodes.py @@ -1202,7 +1202,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] callback = latent_preview.prepare_callback(model, steps) - disable_pbar = False + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) From 88733c997fd807a572d4a214d2c15fc5dd17b3c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 11 Oct 2023 21:29:03 -0400 Subject: [PATCH 12/34] pytorch_attention_enabled can now return True when xformers is enabled. 
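
The practical effect: auto-enabling pytorch attention for the UNet no longer
forces XFORMERS_IS_AVAILABLE off, and make_attn() now prefers the xformers VAE
attention block when it is available, falling back to the pytorch one. A tiny
sketch of the resulting precedence; the helper name is illustrative only and
mirrors the if/elif chain in the diff:

    def pick_vae_attn(xformers_vae_ok: bool, pytorch_ok: bool) -> str:
        # xformers first, then pytorch SDP, otherwise the plain vanilla block
        if xformers_vae_ok:
            return "vanilla-xformers"
        elif pytorch_ok:
            return "vanilla-pytorch"
        return "vanilla"
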
--- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/model_management.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index e6cf954ff..6576df4b6 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -355,7 +355,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' if model_management.xformers_enabled_vae() and attn_type == "vanilla": attn_type = "vanilla-xformers" - if model_management.pytorch_attention_enabled() and attn_type == "vanilla": + elif model_management.pytorch_attention_enabled() and attn_type == "vanilla": attn_type = "vanilla-pytorch" print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b43b21ac..3c390d9ca 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -154,14 +154,18 @@ def is_nvidia(): return True return False -ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +ENABLE_PYTORCH_ATTENTION = False +if args.use_pytorch_cross_attention: + ENABLE_PYTORCH_ATTENTION = True + XFORMERS_IS_AVAILABLE = False + VAE_DTYPE = torch.float32 try: if is_nvidia(): torch_version = torch.version.__version__ if int(torch_version[0]) >= 2: - if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported(): VAE_DTYPE = torch.bfloat16 @@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) - XFORMERS_IS_AVAILABLE = False if args.lowvram: set_vram_to = VRAMState.LOW_VRAM From 41d2c5660dab2404c7db78ff14c7faecd6c57e1f Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 12 Oct 2023 14:26:53 +1100 Subject: [PATCH 13/34] add query --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 7698d0f11..3cf3585d2 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1592,7 +1592,7 @@ export class ComfyApp { all_inputs = all_inputs.concat(Object.keys(parent.inputs)) for (let parent_input in all_inputs) { parent_input = all_inputs[parent_input]; - if (parent.inputs[parent_input].type === node.inputs[i].type) { + if (parent.inputs[parent_input]?.type === node.inputs[i].type) { link = parent.getInputLink(parent_input); if (link) { parent = parent.getInputNode(parent_input); From 851a4bdb803871d62ceca4249036da34449891ed Mon Sep 17 00:00:00 2001 From: Nick Teeple <64399276+nickteeple@users.noreply.github.com> Date: Thu, 12 Oct 2023 21:26:27 +0800 Subject: [PATCH 14/34] Update extra_model_paths.yaml.example with comfy specific example --- extra_model_paths.yaml.example | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 36078fffc..0870811f3 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -1,5 +1,20 @@ #Rename this to 
extra_model_paths.yaml and ComfyUI will load it +#config for comfyui +#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. +#in this example, we make a central folder called c:/machine-learning/ and replicate the ComfyUI models folder structure inside +comfyui: + base_path: c:/machine-learning/ + checkpoints: models/checkpoints/ + clip: models/clip/ + clip_vision: models/clip_vision/ + configs: models/configs/ + controlnet: models/controlnet/ + embeddings: models/embeddings/ + loras: models/loras/ + upscale_models: models/upscale_models/ + vae: models/vae/ + #config for a1111 ui #all you have to do is change the base_path to where yours is installed a111: From fee3b0c0700aedaafacd72ef90d49c7be8c1a003 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 12 Oct 2023 20:54:43 -0400 Subject: [PATCH 15/34] Move and comment out. --- extra_model_paths.yaml.example | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 0870811f3..846d04dbe 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -1,19 +1,5 @@ #Rename this to extra_model_paths.yaml and ComfyUI will load it -#config for comfyui -#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. -#in this example, we make a central folder called c:/machine-learning/ and replicate the ComfyUI models folder structure inside -comfyui: - base_path: c:/machine-learning/ - checkpoints: models/checkpoints/ - clip: models/clip/ - clip_vision: models/clip_vision/ - configs: models/configs/ - controlnet: models/controlnet/ - embeddings: models/embeddings/ - loras: models/loras/ - upscale_models: models/upscale_models/ - vae: models/vae/ #config for a1111 ui #all you have to do is change the base_path to where yours is installed @@ -34,6 +20,21 @@ a111: hypernetworks: models/hypernetworks controlnet: models/ControlNet +#config for comfyui +#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc. 
+ +#comfyui: +# base_path: path/to/comfyui/ +# checkpoints: models/checkpoints/ +# clip: models/clip/ +# clip_vision: models/clip_vision/ +# configs: models/configs/ +# controlnet: models/controlnet/ +# embeddings: models/embeddings/ +# loras: models/loras/ +# upscale_models: models/upscale_models/ +# vae: models/vae/ + #other_ui: # base_path: path/to/ui # checkpoints: models/checkpoints From 87097a11c32ac7393093055c4ed1432553995560 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Fri, 13 Oct 2023 12:26:54 -0300 Subject: [PATCH 16/34] Fix FeatherMask --- comfy_extras/nodes_mask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9b0b289c1..3cc4cfd76 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -282,10 +282,10 @@ class FeatherMask: def feather(self, mask, left, top, right, bottom): output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() - left = min(left, output.shape[1]) - right = min(right, output.shape[1]) - top = min(top, output.shape[0]) - bottom = min(bottom, output.shape[0]) + left = min(left, output.shape[-1]) + right = min(right, output.shape[-1]) + top = min(top, output.shape[-2]) + bottom = min(bottom, output.shape[-2]) for x in range(left): feather_rate = (x + 1.0) / left From b5fa3d28d7a9242a73c40b3b37e526267cbd67fd Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Fri, 13 Oct 2023 13:40:53 -0300 Subject: [PATCH 17/34] Fix MaskComposite --- comfy_extras/nodes_mask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 3cc4cfd76..d8c65c2b6 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -240,8 +240,8 @@ class MaskComposite: right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) visible_width, visible_height = (right - left, bottom - top,) - source_portion = source[:visible_height, :visible_width] - destination_portion = destination[top:bottom, left:right] + source_portion = source[:, :visible_height, :visible_width] + destination_portion = destination[:, top:bottom, left:right] if operation == "multiply": output[:, top:bottom, left:right] = destination_portion * source_portion From 9a55dadb4c9c82d772d934125112b8b0613e4f28 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 13 Oct 2023 14:35:21 -0400 Subject: [PATCH 18/34] Refactor code so model can be a dtype other than fp32 or fp16. 
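
The gist: instead of threading use_fp16/use_bf16 booleans into every model
constructor, the modules now take a torch dtype directly and
model_management.unet_dtype() decides it in one place (currently fp16 or fp32,
but any dtype can flow through the same path). A condensed before/after
sketch, simplified from the diff below:

    import torch

    # before: each module derived its dtype from flags
    #   self.dtype = torch.float16 if use_fp16 else torch.float32
    #   self.dtype = torch.bfloat16 if use_bf16 else self.dtype

    # after: the caller decides once and passes a real dtype in
    def unet_dtype(fp16_ok: bool) -> torch.dtype:
        # simplified stand-in for model_management.unet_dtype()
        return torch.float16 if fp16_ok else torch.float32

    class UNetModelSketch:
        def __init__(self, dtype: torch.dtype = torch.float32):
            self.dtype = dtype  # bf16 or anything else works the same way

    model = UNetModelSketch(dtype=unet_dtype(fp16_ok=True))
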
--- comfy/cldm/cldm.py | 6 ++-- comfy/controlnet.py | 11 +++---- .../modules/diffusionmodules/openaimodel.py | 6 ++-- comfy/model_detection.py | 32 +++++++++---------- comfy/model_management.py | 5 +++ comfy/sd.py | 20 ++++++------ 6 files changed, 39 insertions(+), 41 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 251483131..f982d648c 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -34,8 +34,7 @@ class ControlNet(nn.Module): dims=2, num_classes=None, use_checkpoint=False, - use_fp16=False, - use_bf16=False, + dtype=torch.float32, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, @@ -108,8 +107,7 @@ class ControlNet(nn.Module): self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.dtype = dtype self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ea219c7e5..73a40acfa 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None): controlnet_config = None if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - use_fp16 = comfy.model_management.should_use_fp16() - controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) + unet_dtype = comfy.model_management.unet_dtype() + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None): return net if controlnet_config is None: - use_fp16 = comfy.model_management.should_use_fp16() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config + unet_dtype = comfy.model_management.unet_dtype() + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) @@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None): missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) print(missing, unexpected) - if use_fp16: - control_model = control_model.half() + control_model = control_model.to(unet_dtype) global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index b42637c82..bf58a4045 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -296,8 +296,7 @@ class UNetModel(nn.Module): dims=2, num_classes=None, use_checkpoint=False, - use_fp16=False, - use_bf16=False, + dtype=th.float32, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, @@ -370,8 +369,7 @@ class UNetModel(nn.Module): self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else 
th.float32 - self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.dtype = dtype self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 787c78575..0ff2e7fb5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count -def detect_unet_config(state_dict, key_prefix, use_fp16): +def detect_unet_config(state_dict, key_prefix, dtype): state_dict_keys = list(state_dict.keys()) unet_config = { @@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16): else: unet_config["adm_in_channels"] = None - unet_config["use_fp16"] = use_fp16 + unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] @@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config): print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) +def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): + unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) model_config = model_config_from_unet_config(unet_config) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) else: return model_config -def unet_config_from_diffusers_unet(state_dict, use_fp16): +def unet_config_from_diffusers_unet(state_dict, dtype): match = {} attention_resolutions = [] @@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16): match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1] SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384, + 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, 
"num_head_channels": 64} @@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16): return unet_config return None -def model_config_from_diffusers_unet(state_dict, use_fp16): - unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16) +def model_config_from_diffusers_unet(state_dict, dtype): + unet_config = unet_config_from_diffusers_unet(state_dict, dtype) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 3c390d9ca..1161c2447 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -448,6 +448,11 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev +def unet_dtype(device=None, model_params=0): + if should_use_fp16(device=device, model_params=model_params): + return torch.float16 + return torch.float32 + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() diff --git a/comfy/sd.py b/comfy/sd.py index cfd6fb3cb..fd8b94df8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl if "params" in model_config_params["unet_config"]: unet_config = model_config_params["unet_config"]["params"] if "use_fp16" in unet_config: - fp16 = unet_config["use_fp16"] + fp16 = unet_config.pop("use_fp16") + if fp16: + unet_config["dtype"] = torch.float16 noise_aug_config = None if "noise_aug_config" in model_config_params: @@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") - fp16 = model_management.should_use_fp16(model_params=parameters) + unet_dtype = model_management.unet_dtype(model_params=parameters) class WeightsLoader(torch.nn.Module): pass - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16) + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clipvision: clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) - dtype = torch.float32 - if fp16: - dtype = torch.float16 - if output_model: - inital_load_device = model_management.unet_inital_load_device(parameters, dtype) + inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) offload_device = model_management.unet_offload_device() model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model.load_model_weights(sd, "model.diffusion_model.") @@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet(unet_path): #load unet in diffusers format sd = comfy.utils.load_torch_file(unet_path) parameters = comfy.utils.calculate_parameters(sd) - fp16 = model_management.should_use_fp16(model_params=parameters) + unet_dtype = model_management.unet_dtype(model_params=parameters) if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", fp16) + model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) new_sd = sd else: 
#diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) + model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) if model_config is None: print("ERROR UNSUPPORTED UNET", unet_path) return None From fd4c5f07e7b4d034c7f9fa96d6853f4aeec7b2de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 13 Oct 2023 14:51:10 -0400 Subject: [PATCH 19/34] Add a --bf16-unet to test running the unet in bf16. --- comfy/cli_args.py | 2 ++ comfy/model_management.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 35d44164f..d86557646 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") +parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") + fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 1161c2447..c24c7b27e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -449,6 +449,8 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev def unet_dtype(device=None, model_params=0): + if args.bf16_unet: + return torch.bfloat16 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32 From 25f0f4e9c8fdbad43311b8e5161b3e6efa9d58d9 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sat, 14 Oct 2023 11:54:33 -0300 Subject: [PATCH 20/34] Shortcut Alt + C to collapse/uncollapse selected nodes --- README.md | 1 + web/scripts/app.js | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/README.md b/README.md index 6bef25cee..d622c9072 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Ctrl + S | Save workflow | | Ctrl + O | Load workflow | | Ctrl + A | Select all nodes | +| Alt + C | Collapse/uncollapse selected nodes | | Ctrl + M | Mute/unmute selected nodes | | Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) | | Delete/Backspace | Delete selected nodes | diff --git a/web/scripts/app.js b/web/scripts/app.js index 3cf3585d2..1a07d69bc 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -928,6 +928,16 @@ export class ComfyApp { block_default = true; } + // Alt + C collapse/uncollapse + if (e.key === 'c' && e.altKey) { + if (this.selected_nodes) { + for (var i in this.selected_nodes) { + this.selected_nodes[i].collapse() + } + } + block_default = true; + } + // Ctrl+C Copy if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { // Trigger onCopy From 2e6270e328d4b43f9d76172f21f6d7636bf4a452 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sat, 14 Oct 2023 11:56:44 -0300 Subject: [PATCH 21/34] Stop auto queue on error --- web/scripts/ui.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 1e7920167..c3b3fbda1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -809,7 +809,8 @@ 
export class ComfyUI { if ( this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && - document.getElementById("autoQueueCheckbox").checked + document.getElementById("autoQueueCheckbox").checked && + ! app.lastExecutionError ) { app.queuePrompt(0, this.batchCount); } From 8d04978298d8bb2bcb969608a994fe4ef39be3ee Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sat, 14 Oct 2023 11:59:35 -0300 Subject: [PATCH 22/34] Allow all extensions if extension list is empty --- folder_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index 5d121b443..4a38deec0 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -146,7 +146,7 @@ def recursive_search(directory, excluded_dir_names=None): return result, dirs def filter_files_extensions(files, extensions): - return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) + return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) From a7b65b9505b0504fefc7b57a5e80a243bd6630cf Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sat, 14 Oct 2023 12:11:49 -0300 Subject: [PATCH 23/34] Group menu option select nodes --- web/extensions/core/groupOptions.js | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js index 1d935e90a..d523737dc 100644 --- a/web/extensions/core/groupOptions.js +++ b/web/extensions/core/groupOptions.js @@ -38,6 +38,15 @@ app.registerExtension({ } } + options.push({ + content: "Select Nodes", + callback: () => { + this.selectNodes(nodesInGroup); + this.graph.change(); + this.canvas.focus(); + } + }); + // Modes // 0: Always // 1: On Event From 7e09e889e3d52eaa0e51144e6257e3baa2f9349e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 15 Oct 2023 02:22:22 -0400 Subject: [PATCH 24/34] Make clear that the old CheckpointLoader is deprecated. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 208cbc84f..d0e9ffcc5 100644 --- a/nodes.py +++ b/nodes.py @@ -1660,7 +1660,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "KSampler": "KSampler", "KSamplerAdvanced": "KSampler (Advanced)", # Loaders - "CheckpointLoader": "Load Checkpoint (With Config)", + "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", From bb064c97965da8af32e08d7ac00768f62c33cdab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 16 Oct 2023 02:31:24 -0400 Subject: [PATCH 25/34] Add a separate optimized_attention_masked function. 
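The unmasked path keeps whichever kernel was selected at import time (xformers, pytorch, split, or sub-quadratic), while masked calls go through optimized_attention_masked, which points at attention_pytorch only when pytorch attention is enabled and otherwise stays on the basic kernel. A minimal caller sketch, assuming q, k and v are already projected to (batch, tokens, heads * dim_head) as in CrossAttention.forward; the helper name is illustrative:

    from comfy.ldm.modules.attention import optimized_attention, optimized_attention_masked

    def run_attention(q, k, v, heads, mask=None):
        # Unmasked requests take the fast path chosen at import time;
        # masked ones use the implementation that accepts an attention mask.
        if mask is None:
            return optimized_attention(q, k, v, heads)
        return optimized_attention_masked(q, k, v, heads, mask)
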
--- comfy/ldm/modules/attention.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ac0d9c8c5..9cd14a537 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None): ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - - if exists(mask): - raise NotImplementedError out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) ) return out + optimized_attention = attention_basic +optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -309,6 +308,9 @@ else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad +if model_management.pytorch_attention_enabled(): + optimized_attention_masked = attention_pytorch + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() @@ -334,7 +336,10 @@ class CrossAttention(nn.Module): else: v = self.to_v(context) - out = optimized_attention(q, k, v, self.heads, mask) + if mask is None: + out = optimized_attention(q, k, v, self.heads) + else: + out = optimized_attention_masked(q, k, v, self.heads, mask) return self.to_out(out) From 7d5d0fd577fd6c4ad5de8d07a71aa7599c457b70 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Mon, 16 Oct 2023 15:12:40 -0300 Subject: [PATCH 26/34] Group options - Add Group For Selected Nodes - Add Selected Nodes To Group - Fit Group To Nodes --- web/extensions/core/groupOptions.js | 71 +++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js index d523737dc..3ec0e3fd9 100644 --- a/web/extensions/core/groupOptions.js +++ b/web/extensions/core/groupOptions.js @@ -5,6 +5,49 @@ function setNodeMode(node, mode) { node.graph.change(); } +function addNodesToGroup(group, nodes=[]) { + var x1, y1, x2, y2; + var nx1, ny1, nx2, ny2; + var node; + + x1 = y1 = x2 = y2 = -1; + nx1 = ny1 = nx2 = ny2 = -1; + + for (var n of [group._nodes, nodes]) { + for (var i in n) { + node = n[i] + + nx1 = node.pos[0] + ny1 = node.pos[1] + nx2 = node.pos[0] + node.size[0] + ny2 = node.pos[1] + node.size[1] + + if (x1 == -1 || nx1 < x1) { + x1 = nx1; + } + + if (y1 == -1 || ny1 < y1) { + y1 = ny1; + } + + if (x2 == -1 || nx2 > x2) { + x2 = nx2; + } + + if (y2 == -1 || ny2 > y2) { + y2 = ny2; + } + } + } + + var padding = 10; + + y1 = y1 - Math.round(group.font_size * 2.7); + + group.pos = [x1 - padding, y1 - padding]; + group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2]; +} + app.registerExtension({ name: "Comfy.GroupOptions", setup() { @@ -14,6 +57,17 @@ app.registerExtension({ const options = orig.apply(this, arguments); const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]); if (!group) { + options.push({ + content: "Add Group For Selected Nodes", + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + var group = new LiteGraph.LGraphGroup(); + addNodesToGroup(group, this.selected_nodes) + app.canvas.graph.add(group); + this.graph.change(); + } + }); + return options; } @@ -38,6 +92,23 @@ app.registerExtension({ } } + options.push({ + 
content: "Add Selected Nodes To Group", + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + addNodesToGroup(group, this.selected_nodes) + this.graph.change(); + } + }); + + options.push({ + content: "Fit Group To Nodes", + callback: () => { + addNodesToGroup(group) + this.graph.change(); + } + }); + options.push({ content: "Select Nodes", callback: () => { From e8c02219ee41424d076682841dd84a1ee9093190 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Mon, 16 Oct 2023 15:26:36 -0300 Subject: [PATCH 27/34] Fix add selected nodes to empty group --- web/extensions/core/groupOptions.js | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js index 3ec0e3fd9..d3d52bf29 100644 --- a/web/extensions/core/groupOptions.js +++ b/web/extensions/core/groupOptions.js @@ -75,6 +75,15 @@ app.registerExtension({ group.recomputeInsideNodes(); const nodesInGroup = group._nodes; + options.push({ + content: "Add Selected Nodes To Group", + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + addNodesToGroup(group, this.selected_nodes) + this.graph.change(); + } + }); + // No nodes in group, return default options if (nodesInGroup.length === 0) { return options; @@ -92,15 +101,6 @@ app.registerExtension({ } } - options.push({ - content: "Add Selected Nodes To Group", - disabled: !Object.keys(app.canvas.selected_nodes || {}).length, - callback: () => { - addNodesToGroup(group, this.selected_nodes) - this.graph.change(); - } - }); - options.push({ content: "Fit Group To Nodes", callback: () => { From 682c84ccf3fb601d0bd3ae8a506cc7e3d42d2199 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Mon, 16 Oct 2023 16:00:01 -0300 Subject: [PATCH 28/34] Fix fit group to nodes with reroute and collapsed nodes --- web/extensions/core/groupOptions.js | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js index d3d52bf29..5dd21e730 100644 --- a/web/extensions/core/groupOptions.js +++ b/web/extensions/core/groupOptions.js @@ -22,6 +22,18 @@ function addNodesToGroup(group, nodes=[]) { nx2 = node.pos[0] + node.size[0] ny2 = node.pos[1] + node.size[1] + if (node.type != "Reroute") { + ny1 -= LiteGraph.NODE_TITLE_HEIGHT; + } + + if (node.flags?.collapsed) { + ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT; + + if (node?._collapsed_width) { + nx2 = nx1 + Math.round(node._collapsed_width); + } + } + if (x1 == -1 || nx1 < x1) { x1 = nx1; } @@ -42,7 +54,7 @@ function addNodesToGroup(group, nodes=[]) { var padding = 10; - y1 = y1 - Math.round(group.font_size * 2.7); + y1 = y1 - Math.round(group.font_size * 1.4); group.pos = [x1 - padding, y1 - padding]; group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2]; From 5a608aa37c3113db286cc7d3c06147c175249325 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Mon, 16 Oct 2023 17:29:23 -0300 Subject: [PATCH 29/34] Fix node getBounding for collapsed nodes --- web/lib/litegraph.core.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index f81c83a8a..e906590f5 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -3796,7 +3796,7 @@ out = out || new Float32Array(4); out[0] = this.pos[0] - 4; out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; - out[2] = this.size[0] + 4; + out[2] = this.flags.collapsed ? 
(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4; out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT; if (this.onBounding) { From c8013f73e57b27f8e89d96a13d806d4340bd48a0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 16 Oct 2023 16:46:41 -0400 Subject: [PATCH 30/34] Add some Quadro cards to the list of cards with broken fp16. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c24c7b27e..64ed19727 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -667,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return False #FP16 is just broken on these cards - nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"] + nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"] for x in nvidia_16_series: if x in props.name: return False From 23680a9155727b67eb542445417ec362a53f2148 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Oct 2023 03:19:29 -0400 Subject: [PATCH 31/34] Refactor the attention stuff in the VAE. --- comfy/ldm/modules/diffusionmodules/model.py | 194 ++++++-------------- 1 file changed, 58 insertions(+), 136 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 6576df4b6..852d367c9 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -193,6 +193,52 @@ def slice_attention(q, k, v): return r1 +def normal_attention(q, k, v): + # compute attention + b,c,h,w = q.shape + + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + v = v.reshape(b,c,h*w) + + r1 = slice_attention(q, k, v) + h_ = r1.reshape(b,c,h,w) + del r1 + return h_ + +def xformers_attention(q, k, v): + # compute attention + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), + (q, k, v), + ) + + try: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + out = out.transpose(1, 2).reshape(B, C, H, W) + except NotImplementedError as e: + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + return out + +def pytorch_attention(q, k, v): + # compute attention + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), + (q, k, v), + ) + + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + return out + + class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() @@ -220,6 +266,16 @@ class AttnBlock(nn.Module): stride=1, padding=0) + if model_management.xformers_enabled_vae(): + print("Using xformers attention in VAE") + self.optimized_attention = xformers_attention + elif model_management.pytorch_attention_enabled(): + print("Using pytorch attention in VAE") + self.optimized_attention = pytorch_attention + else: + print("Using split 
attention in VAE") + self.optimized_attention = normal_attention + def forward(self, x): h_ = x h_ = self.norm(h_) @@ -227,149 +283,15 @@ class AttnBlock(nn.Module): k = self.k(h_) v = self.v(h_) - # compute attention - b,c,h,w = q.shape + h_ = self.optimized_attention(q, k, v) - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - v = v.reshape(b,c,h*w) - - r1 = slice_attention(q, k, v) - h_ = r1.reshape(b,c,h,w) - del r1 h_ = self.proj_out(h_) return x+h_ -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map( - lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), - (q, k, v), - ) - - try: - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = out.transpose(1, 2).reshape(B, C, H, W) - except NotImplementedError as e: - out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) - - out = self.proj_out(out) - return x+out - -class MemoryEfficientAttnBlockPytorch(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map( - lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), - (q, k, v), - ) - - try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(2, 3).reshape(B, C, H, W) - except model_management.OOM_EXCEPTION as e: - print("scaled_dot_product_attention OOMed: switched to slice attention") - out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) - - out = self.proj_out(out) - return x+out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if model_management.xformers_enabled_vae() 
and attn_type == "vanilla": - attn_type = "vanilla-xformers" - elif model_management.pytorch_attention_enabled() and attn_type == "vanilla": - attn_type = "vanilla-pytorch" - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return MemoryEfficientAttnBlock(in_channels) - elif attn_type == "vanilla-pytorch": - return MemoryEfficientAttnBlockPytorch(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise NotImplementedError() + return AttnBlock(in_channels) class Model(nn.Module): From 92f0318630f4567227dd8b4b6e53328ec5526195 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Oct 2023 11:39:15 -0400 Subject: [PATCH 32/34] Try to fix notebook. --- notebooks/comfyui_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 4fdccaace..ec83265b4 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" ] }, { From f8caa24bcc6abf226aeb486beeffd853e711c770 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Oct 2023 12:08:03 -0400 Subject: [PATCH 33/34] Support hypernetwork with mish activation function and layer norm. 
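When a hypernetwork was trained with layer norm, each Linear entry in the sorted weight list is followed by a LayerNorm entry, so the rebuild loop now walks the list with an explicit index and loads that extra weight/bias pair instead of creating an untrained LayerNorm. For a hypothetical 1-2-1 layer structure with the newly supported mish activation, one rebuilt block ends up shaped roughly like this sketch (the dimension and the presence of dropout are assumptions that depend on how the hypernetwork was trained):

    import torch

    dim = 768  # example attention dimension; the real value comes from the state dict
    block = torch.nn.Sequential(
        torch.nn.Linear(dim, dim * 2),
        torch.nn.Mish(),                 # new "mish" entry in valid_activation
        torch.nn.LayerNorm(dim * 2),     # weights loaded from the next list entry
        torch.nn.Dropout(p=0.3),         # only present if the hypernetwork used dropout
        torch.nn.Linear(dim * 2, dim),
    )
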
--- comfy_extras/nodes_hypernetwork.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index d16c49aeb..f692945a8 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, "softsign": torch.nn.Softsign, + "mish": torch.nn.Mish, } if activation_func not in valid_activation: @@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength): linears = list(map(lambda a: a[:-len(".weight")], linears)) layers = [] - for i in range(len(linears)): + i = 0 + while i < len(linears): lin_name = linears[i] last_layer = (i == (len(linears) - 1)) penultimate_layer = (i == (len(linears) - 2)) @@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength): if (not last_layer) or (activate_output): layers.append(valid_activation[activation_func]()) if is_layer_norm: - layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + i += 1 + ln_name = linears[i] + ln_weight = attn_weights['{}.weight'.format(ln_name)] + ln_bias = attn_weights['{}.bias'.format(ln_name)] + ln = torch.nn.LayerNorm(ln_weight.shape[0]) + ln.load_state_dict({"weight": ln_weight, "bias": ln_bias}) + layers.append(ln) if use_dropout: if (not last_layer) and (not penultimate_layer or last_layer_dropout): layers.append(torch.nn.Dropout(p=0.3)) + i += 1 output.append(torch.nn.Sequential(*layers)) out[dim] = torch.nn.ModuleList(output) From d44a2de49f797af6b1adc709e709d6b0b4d89093 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 17 Oct 2023 14:51:51 -0400 Subject: [PATCH 34/34] Make VAE code closer to sgm. --- comfy/diffusers_load.py | 3 +- comfy/ldm/models/autoencoder.py | 348 ++++++++++---------- comfy/ldm/modules/diffusionmodules/model.py | 31 +- comfy/sd.py | 39 ++- comfy/utils.py | 11 +- nodes.py | 3 +- 6 files changed, 224 insertions(+), 211 deletions(-) diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index a52e0102b..c0b420e79 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire vae = None if output_vae: - vae = comfy.sd.VAE(ckpt_path=vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) return (unet, clip, vae) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 1fb7ed879..d2f1d74a9 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -2,67 +2,66 @@ import torch # import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union -from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.util import instantiate_from_config from comfy.ldm.modules.ema import LitEma -# class AutoencoderKL(pl.LightningModule): -class AutoencoderKL(torch.nn.Module): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): +class DiagonalGaussianRegularizer(torch.nn.Module): + def __init__(self, sample: bool = True): super().__init__() - self.learn_logvar = learn_logvar - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - 
self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log + + +class AbstractAutoencoder(torch.nn.Module): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + **kwargs, + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None if monitor is not None: self.monitor = monitor - self.use_ema = ema_decay is not None if self.use_ema: - self.ema_decay = ema_decay - assert 0. < ema_decay < 1. self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + def get_input(self, batch) -> Any: + raise NotImplementedError() - def init_from_ckpt(self, path, ignore_keys=list()): - if path.lower().endswith(".safetensors"): - import safetensors.torch - sd = safetensors.torch.load_file(path, device="cpu") - else: - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) @contextmanager def ema_scope(self, context=None): @@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module): self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: - print(f"{context}: Switched to EMA weights") + logpy.info(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: - print(f"{context}: Restored training weights") + logpy.info(f"{context}: Restored training weights") - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") - def decode(self, 
z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior + def configure_optimizers(self) -> Any: + raise NotImplementedError() - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. + """ - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + regularizer_config: Dict, + **kwargs, + ): + super().__init__(*args, **kwargs) - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], 
[] + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.regularization: AbstractRegularizer = instantiate_from_config( + regularizer_config + ) def get_last_layer(self): - return self.decoder.conv_out.weight + return self.decoder.get_last_layer() - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) return x + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface - super().__init__() - def encode(self, x, *args, **kwargs): - return x +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + super().__init__( + encoder_config={ + "target": "comfy.ldm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "comfy.ldm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim - def decode(self, x, *args, **kwargs): - return x + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = 
int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) - def forward(self, x, *args, **kwargs): - return x + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class AutoencoderKL(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer" + ) + }, + **kwargs, + ) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 852d367c9..f23417fd2 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -541,7 +541,10 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + conv_out_op=comfy.ops.Conv2d, + resnet_op=ResnetBlock, + attn_op=AttnBlock, + **ignorekwargs): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch @@ -570,12 +573,12 @@ class Decoder(nn.Module): # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, + self.mid.block_1 = resnet_op(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, + self.mid.attn_1 = attn_op(block_in) + self.mid.block_2 = resnet_op(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) @@ -587,13 +590,13 @@ class Decoder(nn.Module): attn = nn.ModuleList() block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, + block.append(resnet_op(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) + attn.append(attn_op(block_in)) up = nn.Module() up.block = block up.attn = attn @@ -604,13 +607,13 @@ class Decoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = conv_out_op(block_in, out_ch, kernel_size=3, stride=1, padding=1) - def forward(self, z): + def forward(self, z, **kwargs): #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape @@ -621,16 +624,16 @@ class Decoder(nn.Module): h = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, 
**kwargs) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) + h = self.up[i_level].block[i_block](h, temb, **kwargs) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + h = self.up[i_level].attn[i_block](h, **kwargs) if i_level != 0: h = self.up[i_level].upsample(h) @@ -640,7 +643,7 @@ class Decoder(nn.Module): h = self.norm_out(h) h = nonlinearity(h) - h = self.conv_out(h) + h = self.conv_out(h, **kwargs) if self.tanh_out: h = torch.tanh(h) return h diff --git a/comfy/sd.py b/comfy/sd.py index fd8b94df8..48ee5721b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -4,7 +4,7 @@ import math from comfy import model_management from .ldm.util import instantiate_from_config -from .ldm.models.autoencoder import AutoencoderKL +from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine import yaml import comfy.utils @@ -140,21 +140,24 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, ckpt_path=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None): + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - if ckpt_path is not None: - sd = comfy.utils.load_torch_file(ckpt_path) - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - print("Missing VAE keys", m) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False) + if len(m) > 0: + print("Missing VAE keys", m) + + if len(u) > 0: + print("Leftover VAE keys", u) if device is None: device = model_management.vae_device() @@ -183,7 +186,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() + encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) @@ -229,7 +232,7 @@ class VAE: samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. 
* pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") @@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model.load_model_weights(state_dict, "model.diffusion_model.") if output_vae: - w = WeightsLoader() - vae = VAE(config=vae_config) - w.first_stage_model = vae.first_stage_model - load_model_weights(w, state_dict) + vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True) + vae = VAE(sd=vae_sd, config=vae_config) if output_clip: w = WeightsLoader() @@ -427,10 +428,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae = VAE() - w = WeightsLoader() - w.first_stage_model = vae.first_stage_model - load_model_weights(w, sd) + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae = VAE(sd=vae_sd) if output_clip: w = WeightsLoader() diff --git a/comfy/utils.py b/comfy/utils.py index df016ef9e..a1807aa1d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace): state_dict[keys_to_replace[x]] = state_dict.pop(x) return state_dict -def state_dict_prefix_replace(state_dict, replace_prefix): +def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if filter_keys: + out = {} + else: + out = state_dict for rp in replace_prefix: replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) for x in replace: - state_dict[x[1]] = state_dict.pop(x[0]) - return state_dict + w = state_dict.pop(x[0]) + out[x[1]] = w + return out def transformers_convert(sd, prefix_from, prefix_to, number): diff --git a/nodes.py b/nodes.py index d0e9ffcc5..0dbc2be32 100644 --- a/nodes.py +++ b/nodes.py @@ -584,7 +584,8 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): vae_path = folder_paths.get_full_path("vae", vae_name) - vae = comfy.sd.VAE(ckpt_path=vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) return (vae,) class ControlNetLoader:
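
With this last patch the VAE wrapper is constructed from an already loaded state dict instead of a checkpoint path, matching the updated VAELoader node and diffusers loader. A minimal usage sketch; the file path is a placeholder:

    import comfy.sd
    import comfy.utils

    # Load the tensors first, then hand them to the wrapper; diffusers-format
    # VAE state dicts are converted automatically inside comfy.sd.VAE.
    sd = comfy.utils.load_torch_file("models/vae/example_vae.safetensors")
    vae = comfy.sd.VAE(sd=sd)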