diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index 767a7216b..dde50a73f 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -31,7 +31,7 @@ jobs:
         echo 'import site' >> ./python311._pth
         curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
         ./python.exe get-pip.py
-        python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+        python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
         ls ../temp_wheel_dir
         ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
         sed -i '1i../ComfyUI' ./python311._pth
diff --git a/README.md b/README.md
index 56ee873e0..5e32a74f3 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
 - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
-- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
+- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
+- Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -69,7 +70,7 @@ There is a portable standalone build for Windows that should work for running on
 
 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)
 
-Just download, extract and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
 
 #### How do I share models between another UI and ComfyUI?
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index f1306ef7f..38718b66b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/clip_config_bigg.json b/comfy/clip_config_bigg.json
index 16bafe448..32d82ff39 100644
--- a/comfy/clip_config_bigg.json
+++ b/comfy/clip_config_bigg.json
@@ -17,7 +17,7 @@
   "num_attention_heads": 20,
   "num_hidden_layers": 32,
   "pad_token_id": 1,
-  "projection_dim": 512,
+  "projection_dim": 1280,
   "torch_dtype": "float32",
   "vocab_size": 49408
 }
diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index 1eab54d4b..9688cbd52 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 
 code2idx = {"q": 0, "k": 1, "v": 2}
 
-def convert_text_enc_state_dict_v20(text_enc_dict):
+def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
     capture_qkv_bias = {}
     for k, v in text_enc_dict.items():
+        if not k.startswith(prefix):
+            continue
         if (
             k.endswith(".self_attn.q_proj.weight")
             or k.endswith(".self_attn.k_proj.weight")
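The new prefix argument lets convert_text_enc_state_dict_v20() be pointed at one text encoder inside a combined SDXL state dict while skipping everything else. A minimal sketch of the filtering idea, using invented keys rather than ones from a real checkpoint:

# Rough sketch of the prefix filter; the keys below are made up for illustration.
sd = {
    "clip_g.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "g-weight",
    "clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "l-weight",
}

# With prefix="clip_g" only the first key is considered; the default prefix=""
# matches every key, which preserves the old behaviour of converting the whole dict.
kept = {k: v for k, v in sd.items() if k.startswith("clip_g")}
print(list(kept))  # ['clip_g.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight']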
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 923c4348b..60997246c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from . import utils
 
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, v_prediction=False):
@@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
+        self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)
 
+    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
+        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
+        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+        if self.get_dtype() == torch.float16:
+            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
+            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+
 
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, v_prediction=True):
@@ -144,10 +156,10 @@ class SDXLRefiner(BaseModel):
         print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([aesthetic_score])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@@ -168,11 +180,11 @@ class SDXL(BaseModel):
         print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
+        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([crop_h])))
-        out.append(self.embedder(torch.Tensor([target_width])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([target_height])))
+        out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
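Two things change in model_base.py: the SDXL and Refiner size conditionings are now embedded in (height, width) and (crop_h, crop_w) order, and state_dict_for_saving() assembles a single checkpoint-style state dict from the UNet, VAE and CLIP weights after each config class has re-prefixed its keys. A loose sketch of the merge; the key names shown here are illustrative only:

# Loose sketch of what state_dict_for_saving() returns; key names are illustrative only.
unet_sd = {"model.diffusion_model.input_blocks.0.0.weight": "unet tensor"}
vae_sd = {"first_stage_model.decoder.conv_in.weight": "vae tensor"}
clip_sd = {"cond_stage_model.model.ln_final.weight": "clip tensor"}

# The three processed dicts are merged into one flat mapping, which is what
# comfy.utils.save_torch_file() later writes out as a .safetensors checkpoint.
checkpoint_sd = {**unet_sd, **vae_sd, **clip_sd}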
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 48137c78f..edad48b1c 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -16,13 +16,11 @@ def count_blocks(state_dict_keys, prefix_string):
 def detect_unet_config(state_dict, key_prefix, use_fp16):
     state_dict_keys = list(state_dict.keys())
 
-    num_res_blocks = 2
     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
         "out_channels": 4,
-        "num_res_blocks": num_res_blocks,
         "use_spatial_transformer": True,
         "legacy": False
     }
diff --git a/comfy/model_management.py b/comfy/model_management.py
index d64dce187..4f3f28571 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False
 
+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -200,6 +216,11 @@ current_gpu_controlnets = []
 
 model_accelerated = False
 
+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
 
 def unload_model():
     global current_loaded_model
@@ -212,10 +233,9 @@ def unload_model():
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False
 
-        #never unload models from GPU on high vram
-        if vram_state != VRAMState.HIGH_VRAM:
-            current_loaded_model.model.cpu()
-            current_loaded_model.model_patches_to("cpu")
+
+        current_loaded_model.model.to(unet_offload_device())
+        current_loaded_model.model_patches_to(unet_offload_device())
         current_loaded_model.unpatch_model()
         current_loaded_model = None
@@ -347,7 +367,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
            return True
     return False
@@ -438,7 +458,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
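unet_offload_device() gives unload_model() a single place to decide where a model goes when it is pushed out of VRAM, and is_nvidia() gates the CUDA-only paths (the PyTorch attention auto-enable, flash attention and empty_cache). A standalone sketch of the offload decision, with the state enum stubbed out since the real VRAMState and get_torch_device() live elsewhere in model_management.py:

# Standalone sketch of the offload decision; VRAMState here is a stub, not the real enum.
import enum
import torch

class VRAMState(enum.Enum):
    NORMAL_VRAM = 0
    HIGH_VRAM = 1
    SHARED = 2

def unet_offload_device(vram_state, torch_device):
    # High-VRAM and shared-memory configurations keep models on the compute device;
    # everything else offloads to the CPU between prompts.
    if vram_state in (VRAMState.HIGH_VRAM, VRAMState.SHARED):
        return torch_device
    return torch.device("cpu")

print(unet_offload_device(VRAMState.NORMAL_VRAM, torch.device("cuda")))  # cpu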
diff --git a/comfy/sd.py b/comfy/sd.py
index dbfbdbe38..8eac1f8ed 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -89,8 +89,7 @@ LORA_UNET_MAP_RESNET = {
     "skip_connection": "resnets_{}_conv_shortcut"
 }
 
-def load_lora(path, to_load):
-    lora = utils.load_torch_file(path, safe_load=True)
+def load_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -223,13 +222,28 @@ def model_lora_keys(model, key_map={}):
             counter += 1
     counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
-    for b in range(24):
+    clip_l_present = False
+    for b in range(32):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k
 
+            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                key_map[lora_key] = k
+                clip_l_present = True
+
+            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                if clip_l_present:
+                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                else:
+                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
+                key_map[lora_key] = k
+
     #Locon stuff
     ds_counter = 0
@@ -486,10 +500,10 @@ class ModelPatcher:
         self.backup = {}
 
 
-def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = model_lora_keys(model.model)
     key_map = model_lora_keys(clip.cond_stage_model, key_map)
-    loaded = load_lora(lora_path, key_map)
+    loaded = load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
     new_clip = clip.clone()
@@ -545,11 +559,11 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
-            self.patcher.patch_model()
+            self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.patcher.unpatch_model()
+            self.unpatch_model()
         except Exception as e:
-            self.patcher.unpatch_model()
+            self.unpatch_model()
             raise e
 
         cond_out = cond
@@ -564,6 +578,15 @@ class CLIP:
     def load_sd(self, sd):
         return self.cond_stage_model.load_sd(sd)
 
+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def patch_model(self):
+        self.patcher.patch_model()
+
+    def unpatch_model(self):
+        self.patcher.unpatch_model()
+
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
@@ -665,6 +688,10 @@ class VAE:
             self.first_stage_model = self.first_stage_model.cpu()
         return samples
 
+    def get_sd(self):
+        return self.first_stage_model.state_dict()
+
+
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -1114,6 +1141,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
     model = model_config.get_model(sd)
+    model = model.to(model_management.unet_offload_device())
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
@@ -1135,3 +1163,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         print("left over keys:", left_over)
 
     return (ModelPatcher(model), clip, vae, clipvision)
+
+def save_checkpoint(output_path, model, clip, vae, metadata=None):
+    try:
+        model.patch_model()
+        clip.patch_model()
+        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+        utils.save_torch_file(sd, output_path, metadata=metadata)
+        model.unpatch_model()
+        clip.unpatch_model()
+    except Exception as e:
+        model.unpatch_model()
+        clip.unpatch_model()
+        raise e
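With these changes load_lora() and load_lora_for_models() take an already-loaded state dict instead of a path, so callers can cache the tensors, and save_checkpoint() patches the model and CLIP, merges their weights through state_dict_for_saving(), and writes one .safetensors file. A hedged usage sketch; the file names are placeholders and model, clip and vae are assumed to come from load_checkpoint_guess_config():

# Hedged usage sketch; file names are placeholders.
import comfy.sd
import comfy.utils

model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config("base_checkpoint.safetensors")

# LoRA weights are now loaded once and passed in as a state dict:
lora_sd = comfy.utils.load_torch_file("example_lora.safetensors", safe_load=True)
model, clip = comfy.sd.load_lora_for_models(model, clip, lora_sd, 1.0, 1.0)

# Saving goes through the new helper, which patches, serializes and unpatches:
comfy.sd.save_checkpoint("example_merged.safetensors", model, clip, vae, metadata={"note": "example"})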
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0ee314ad5..02a998e5b 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -95,7 +95,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 51da9456e..6b17b089f 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -9,6 +9,8 @@ from . import sdxl_clip
 from . import supported_models_base
 from . import latent_formats
 
+from . import diffusers_convert
+
 class SD15(supported_models_base.BASE):
     unet_config = {
         "context_dim": 768,
@@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        replace_prefix[""] = "cond_stage_model.model."
+        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
+        return state_dict
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
@@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
@@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
 
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        keys_to_replace = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        for k in state_dict:
+            if k.startswith("clip_l"):
+                state_dict_g[k] = state_dict[k]
+
+        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
+        replace_prefix["clip_l"] = "conditioner.embedders.0"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 3312a99d5..0b0235ca4 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -64,3 +64,15 @@ class BASE:
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
+
+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "cond_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.diffusion_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_vae_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "first_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
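All of the *_for_saving hooks reduce to prepending a serialization prefix to every key via state_dict_prefix_replace(). That helper is not part of this patch, so the sketch below is only an assumption about its behaviour, shown for the default {"": "cond_stage_model."} mapping:

# Assumed behaviour of state_dict_prefix_replace(); the real helper lives in
# supported_models_base.py and is not shown in this patch.
def prefix_replace(state_dict, replace_prefix):
    out = {}
    for old, new in replace_prefix.items():
        for k, v in state_dict.items():
            if k.startswith(old):
                out[new + k[len(old):]] = v
    return out

print(prefix_replace({"ln_final.weight": 0}, {"": "cond_stage_model."}))
# {'cond_stage_model.ln_final.weight': 0}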
diff --git a/comfy/utils.py b/comfy/utils.py
index 7a7f1fa12..b64349054 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -2,10 +2,10 @@ import torch
 import math
 import struct
 import comfy.checkpoint_pickle
+import safetensors.torch
 
 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
         if safe_load:
@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
             sd = pl_sd
     return sd
 
+def save_torch_file(sd, ckpt, metadata=None):
+    if metadata is not None:
+        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
+    else:
+        safetensors.torch.save_file(sd, ckpt)
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                 sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd
 
+def convert_sd_to(state_dict, dtype):
+    keys = list(state_dict.keys())
+    for k in keys:
+        state_dict[k] = state_dict[k].to(dtype)
+    return state_dict
+
 def safetensors_header(safetensors_path, max_size=100*1024*1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 52b73f702..72eeffb39 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -1,4 +1,8 @@
-
+import comfy.sd
+import comfy.utils
+import folder_paths
+import json
+import os
 
 class ModelMergeSimple:
     @classmethod
@@ -10,7 +14,7 @@ class ModelMergeSimple:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"
 
-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"
 
     def merge(self, model1, model2, ratio):
         m = model1.clone()
@@ -31,7 +35,7 @@ class ModelMergeBlocks:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"
 
-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"
 
     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
@@ -49,7 +53,43 @@ class ModelMergeBlocks:
             m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )
 
+class CheckpointSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "advanced/model_merging"
+
+    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        return {}
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSimple": ModelMergeSimple,
-    "ModelMergeBlocks": ModelMergeBlocks
+    "ModelMergeBlocks": ModelMergeBlocks,
+    "CheckpointSave": CheckpointSave,
 }
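The CheckpointSave node funnels the workflow prompt and any extra PNG info into the safetensors header as JSON strings via comfy.utils.save_torch_file(). A small illustration of that metadata path using safetensors directly; the tensor contents and file name below are made up:

# Illustration of how workflow metadata ends up in the checkpoint header;
# the tensor and file name are made up.
import json
import safetensors.torch
import torch

metadata = {"prompt": json.dumps({"1": {"class_type": "CheckpointSave", "inputs": {}}})}
sd = {"example.weight": torch.zeros(1)}

# Header metadata must be a dict of strings, which is why the node json.dumps() each value.
safetensors.torch.save_file(sd, "example_checkpoint.safetensors", metadata=metadata)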
diff --git a/execution.py b/execution.py
index f93de8465..a40b1dd36 100644
--- a/execution.py
+++ b/execution.py
@@ -110,7 +110,7 @@ def format_value(x):
     else:
         return str(x)
 
-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
@@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
+                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                 if result[0] is not True:
                     # Another node failed further upstream
                     return result
@@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
     if server.client_id is not None:
         server.last_node_id = unique_id
         server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
-    obj = class_def()
+
+    obj = object_storage.get((unique_id, class_type), None)
+    if obj is None:
+        obj = class_def()
+        object_storage[(unique_id, class_type)] = obj
 
     output_data, output_ui = get_output_data(obj, input_data_all)
     outputs[unique_id] = output_data
@@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
 class PromptExecutor:
     def __init__(self, server):
         self.outputs = {}
+        self.object_storage = {}
         self.outputs_ui = {}
         self.old_prompt = {}
         self.server = server
@@ -322,6 +327,17 @@ class PromptExecutor:
             for o in to_delete:
                 d = self.outputs.pop(o)
                 del d
+            to_delete = []
+            for o in self.object_storage:
+                if o[0] not in prompt:
+                    to_delete += [o]
+                else:
+                    p = prompt[o[0]]
+                    if o[1] != p['class_type']:
+                        to_delete += [o]
+            for o in to_delete:
+                d = self.object_storage.pop(o)
+                del d
 
             for x in prompt:
                 recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@@ -349,7 +365,7 @@ class PromptExecutor:
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
+                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                 if success is not True:
                     self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                     break
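The executor now keeps node instances in object_storage keyed by (node id, class_type): an instance is reused across prompts as long as the node keeps its id and class, which is what lets stateful nodes such as the updated LoraLoader cache their loaded weights, and entries are evicted when a node disappears or changes class. A minimal sketch of the lookup, mirroring the logic added above:

# Minimal sketch of the per-node instance cache added to recursive_execute();
# names mirror the patch but this is not the executor itself.
def get_node_object(object_storage, unique_id, class_type, class_def):
    obj = object_storage.get((unique_id, class_type), None)
    if obj is None:
        obj = class_def()
        object_storage[(unique_id, class_type)] = obj
    return obj

class ExampleNode:  # stand-in for a real node class
    pass

storage = {}
a = get_node_object(storage, "5", "ExampleNode", ExampleNode)
b = get_node_object(storage, "5", "ExampleNode", ExampleNode)
print(a is b)  # True: the same instance is reused on the next execution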
diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example
index fa5418a68..e72f81f28 100644
--- a/extra_model_paths.yaml.example
+++ b/extra_model_paths.yaml.example
@@ -8,7 +8,9 @@ a111:
     checkpoints: models/Stable-diffusion
     configs: models/Stable-diffusion
     vae: models/VAE
-    loras: models/Lora
+    loras: |
+         models/Lora
+         models/LyCORIS
     upscale_models: |
                   models/ESRGAN
                   models/SwinIR
@@ -21,5 +23,3 @@ a111:
 #    checkpoints: models/checkpoints
 #    gligen: models/gligen
 #    custom_nodes: path/custom_nodes
-
-
diff --git a/nodes.py b/nodes.py
index 7280d7880..a9f2e962e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -148,6 +148,25 @@ class ConditioningSetMask:
             c.append(n)
         return (c, )
 
+class ConditioningZeroOut:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "zero_out"
+
+    CATEGORY = "advanced/conditioning"
+
+    def zero_out(self, conditioning):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            if "pooled_output" in d:
+                d["pooled_output"] = torch.zeros_like(d["pooled_output"])
+            n = [torch.zeros_like(t[0]), d]
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -286,8 +305,7 @@ class SaveLatent:
             output["latent_tensor"] = samples["samples"]
         output["latent_format_version_0"] = torch.tensor([])
 
-        safetensors.torch.save_file(output, file, metadata=metadata)
-
+        comfy.utils.save_torch_file(output, file, metadata=metadata)
         return {}
 
@@ -416,6 +434,9 @@ class CLIPSetLastLayer:
         return (clip,)
 
 class LoraLoader:
+    def __init__(self):
+        self.loaded_lora = None
+
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
@@ -434,7 +455,18 @@ class LoraLoader:
             return (model, clip)
 
         lora_path = folder_paths.get_full_path("loras", lora_name)
-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        lora = None
+        if self.loaded_lora is not None:
+            if self.loaded_lora[0] == lora_path:
+                lora = self.loaded_lora[1]
+            else:
+                del self.loaded_lora
+
+        if lora is None:
+            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+            self.loaded_lora = (lora_path, lora)
+
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
         return (model_lora, clip_lora)
@@ -1351,6 +1383,8 @@ NODE_CLASS_MAPPINGS = {
 
     "LoadLatent": LoadLatent,
     "SaveLatent": SaveLatent,
+
+    "ConditioningZeroOut": ConditioningZeroOut,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb
index c5a209eec..61c277bf6 100644
--- a/notebooks/comfyui_colab.ipynb
+++ b/notebooks/comfyui_colab.ipynb
@@ -144,6 +144,7 @@ "\n",
        "\n",
        "\n",
        "# ESRGAN upscale model\n",
+        "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
        "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
        "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
        "\n",
diff --git a/web/scripts/app.js b/web/scripts/app.js
index d8c9645fc..360e72830 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1484,7 +1484,7 @@ export class ComfyApp {
                     this.loadGraphData(JSON.parse(reader.result));
                 };
                 reader.readAsText(file);
-            } else if (file.name?.endsWith(".latent")) {
+            } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
                 const info = await getLatentMetadata(file);
                 if (info.workflow) {
                     this.loadGraphData(JSON.parse(info.workflow));
diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js
index 977b5ac2f..c5293dfa3 100644
--- a/web/scripts/pnginfo.js
+++ b/web/scripts/pnginfo.js
@@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
             const dataView = new DataView(safetensorsData.buffer);
             let header_size = dataView.getUint32(0, true);
             let offset = 8;
-            let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+            let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
             r(header.__metadata__);
         };
 
-        reader.readAsArrayBuffer(file);
+        var slice = file.slice(0, 1024 * 1024 * 4);
+        reader.readAsArrayBuffer(slice);
     });
 }
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index 99e9123ae..12fda1273 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -545,7 +545,7 @@ export class ComfyUI {
         const fileInput = $el("input", {
             id: "comfy-file-input",
             type: "file",
-            accept: ".json,image/png,.latent",
+            accept: ".json,image/png,.latent,.safetensors",
             style: {display: "none"},
             parent: document.body,
             onchange: () => {