From 766ae119a86e51bf0ee0b068c26fa3acc0e1b815 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Aug 2024 14:00:56 -0400
Subject: [PATCH 1/5] CheckpointSave node name.

---
 comfy_extras/nodes_model_merging.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py
index 8fb8bf799..ccf601158 100644
--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py
@@ -365,3 +365,7 @@ NODE_CLASS_MAPPINGS = {
     "VAESave": VAESave,
     "ModelSave": ModelSave,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CheckpointSave": "Save Checkpoint",
+}

From be0726c1ed9c8ae5d02733dcc2fd5b997bd265de Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Aug 2024 15:24:07 -0400
Subject: [PATCH 2/5] Remove duplication.

---
 comfy/model_management.py |  5 +---
 comfy/model_patcher.py    | 53 +++++++++++++++------------------------
 2 files changed, 21 insertions(+), 37 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index a6996709b..6387c8d08 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -315,10 +315,7 @@ class LoadedModel:
             self.model_use_more_vram(use_more_vram)
         else:
             try:
-                if lowvram_model_memory > 0 and load_weights:
-                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-                else:
-                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
             except Exception as e:
                 self.model.unpatch_model(self.model.offload_device)
                 self.model_unload()
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 6be8f7730..8dafc54a3 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -336,33 +336,7 @@ class ModelPatcher:
         else:
             comfy.utils.set_attr_param(self.model, key, out_weight)
 
-    def patch_model(self, device_to=None, patch_weights=True):
-        for k in self.object_patches:
-            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
-            if k not in self.object_patches_backup:
-                self.object_patches_backup[k] = old
-
-        if patch_weights:
-            model_sd = self.model_state_dict()
-            keys_sort = []
-            for key in self.patches:
-                if key not in model_sd:
-                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
-                    continue
-                keys_sort.append((math.prod(model_sd[key].shape), key))
-
-            keys_sort.sort(reverse=True)
-            for ks in keys_sort:
-                self.patch_weight_to_device(ks[1], device_to)
-
-            if device_to is not None:
-                self.model.to(device_to)
-                self.model.device = device_to
-                self.model.model_loaded_weight_memory = self.model_size()
-
-        return self.model
-
-    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
@@ -430,16 +404,29 @@ class ModelPatcher:
             logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
+            if full_load:
+                self.model.to(device_to)
+                mem_counter = self.model_size()
+
         self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
         self.model.model_loaded_weight_memory = mem_counter
 
+    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+        for k in self.object_patches:
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+            if k not in self.object_patches_backup:
+                self.object_patches_backup[k] = old
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-        self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+        if lowvram_model_memory == 0:
+            full_load = True
+        else:
+            full_load = False
+
+        if load_weights:
+            self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
         return self.model
 
     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
@@ -692,14 +679,14 @@ class ModelPatcher:
 
     def partially_load(self, device_to, extra_memory=0):
         self.unpatch_model(unpatch_weights=False)
-        self.patch_model(patch_weights=False)
+        self.patch_model(load_weights=False)
         full_load = False
         if self.model.model_lowvram == False:
             return 0
        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
             full_load = True
         current_used = self.model.model_loaded_weight_memory
-        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+        self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
         return self.model.model_loaded_weight_memory - current_used
 
     def current_loaded_device(self):

From 6138f92084cff2ad380aa232c208d0b448887620 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Aug 2024 15:35:25 -0400
Subject: [PATCH 3/5] Use better dtype for the lowvram lora system.

---
 comfy/model_patcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 8dafc54a3..51259b559 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -96,7 +96,7 @@ class LowVramPatch:
         self.key = key
         self.model_patcher = model_patcher
     def __call__(self, weight):
-        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
 
 class ModelPatcher:

From 4d341b78e8a04f76a289d9bcec264e4f38997c64 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Aug 2024 16:28:55 -0400
Subject: [PATCH 4/5] Bug fixes.

---
 comfy/model_patcher.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 51259b559..ae33687c1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -387,15 +387,14 @@ class ModelPatcher:
             m = x[2]
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)
-            param = list(m.parameters())
-            if len(param) > 0:
-                weight = param[0]
-                if weight.device == device_to:
+            if hasattr(m, "comfy_patched_weights"):
+                if m.comfy_patched_weights == True:
                     continue
 
             self.patch_weight_to_device(weight_key, device_to=device_to)
             self.patch_weight_to_device(bias_key, device_to=device_to)
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+            m.comfy_patched_weights = True
 
         for x in load_completely:
             x[2].to(device_to)
@@ -622,6 +621,10 @@ class ModelPatcher:
                 self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
 
+            for m in self.model.modules():
+                if hasattr(m, "comfy_patched_weights"):
+                    del m.comfy_patched_weights
+
         keys = list(self.object_patches_backup.keys())
         for k in keys:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@@ -649,7 +652,7 @@ class ModelPatcher:
                 weight_key = "{}.weight".format(n)
                 bias_key = "{}.bias".format(n)
 
-                if m.weight is not None and m.weight.device != device_to:
+                if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                     for key in [weight_key, bias_key]:
                         bk = self.backup.get(key, None)
                         if bk is not None:
@@ -669,6 +672,7 @@ class ModelPatcher:
 
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
+                    m.comfy_patched_weights = False
                     memory_freed += module_mem
                     logging.debug("freed {}".format(n))
 

From 045377ea893d0703e515d87f891936784cb2f5de Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 19 Aug 2024 17:16:18 -0400
Subject: [PATCH 5/5] Add a --reserve-vram argument if you don't want comfy to use all of it.

--reserve-vram 1.0 for example will make ComfyUI try to keep 1GB vram free.

This can also be useful if workflows are failing because of OOM errors but in
that case please report it if --reserve-vram improves your situation.
---
 comfy/cli_args.py         |  3 +++
 comfy/model_management.py | 15 +++++++++++++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index a895c7e10..77009c912 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -116,6 +116,9 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
+parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
+
+
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6387c8d08..fb136bd1e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -367,6 +367,17 @@ def offloaded_memory(loaded_models, device):
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2
 
+EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
+if any(platform.win32_ver()):
+    EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 #Windows is higher because of the shared vram issue
+
+if args.reserve_vram is not None:
+    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
+    logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
+
+def extra_reserved_memory():
+    return EXTRA_RESERVED_VRAM
+
 def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     to_unload = []
     for i in range(len(current_loaded_models)):
@@ -436,11 +447,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     global vram_state
 
     inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
+    extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
     if minimum_memory_required is None:
         minimum_memory_required = extra_mem
     else:
-        minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
+        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
 
     models = set(models)
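
Note on patch 5 (an illustration, not part of the patch series): before this change load_models_gpu() padded every request with a hard-coded 300MB; the patch swaps that for extra_reserved_memory(), which defaults to 200MB (400MB on Windows) and can be overridden from the command line, e.g. python main.py --reserve-vram 1.0 (assuming the usual main.py entry point). The sketch below mirrors that arithmetic so the numbers are easy to check by hand; the helper names are invented for this example and are not ComfyUI API.

    # Standalone sketch of the VRAM reservation math from patch 5.
    # Helper names are illustrative only, not ComfyUI API.
    def reserved_vram_bytes(reserve_vram_gb=None, windows=False):
        # Defaults from the patch: 200MB, or 400MB on Windows (shared vram issue).
        reserved = 400 * 1024 * 1024 if windows else 200 * 1024 * 1024
        if reserve_vram_gb is not None:
            # --reserve-vram is given in GB and converted directly to bytes.
            reserved = reserve_vram_gb * 1024 * 1024 * 1024
        return reserved

    def extra_mem_budget(memory_required, reserve_vram_gb=None, windows=False):
        # load_models_gpu() budgets at least the 1.2GB inference floor, otherwise
        # the workflow's own requirement plus the reservation.
        minimum_inference = (1024 * 1024 * 1024) * 1.2
        return max(minimum_inference, memory_required + reserved_vram_bytes(reserve_vram_gb, windows))

    if __name__ == "__main__":
        print(reserved_vram_bytes(1.0))                                  # 1073741824.0 (~1GiB kept free)
        print(extra_mem_budget(512 * 1024 * 1024, reserve_vram_gb=1.0))  # 1610612736.0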