diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 087d8f662..44f18cbcf 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -573,7 +573,7 @@ class PromptExecutor:
             if self.caches.outputs.get(node_id) is not None:
                 cached_nodes.append(node_id)
 
-        model_management.cleanup_models(keep_clone_weights_loaded=True)
+        model_management.cleanup_models_gc()
         self.add_message("execution_cached",
                          {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 24aa423de..75a65ca76 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -22,9 +22,10 @@ import logging
 import platform
 import sys
 import warnings
+import weakref
 from enum import Enum
 from threading import RLock
-from typing import Literal, List, Sequence, Final
+from typing import List, Sequence, Final
 
 import psutil
 import torch
@@ -338,11 +339,27 @@ def module_size(module):
 
 class LoadedModel:
     def __init__(self, model: ModelManageable):
-        self.model = model
+        self._set_model(model)
         self.device = model.load_device
-        self.weights_loaded = False
         self.real_model = None
         self.currently_used = True
+        self.model_finalizer = None
+        self._patcher_finalizer = None
+
+    def _set_model(self, model):
+        self._model = weakref.ref(model)
+        if model.parent is not None:
+            self._parent_model = weakref.ref(model.parent)
+            self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
+
+    def _switch_parent(self):
+        model = self._parent_model()
+        if model is not None:
+            self._set_model(model)
+
+    @property
+    def model(self):
+        return self._model()
 
     def model_memory(self):
         return self.model.model_size()
@@ -357,32 +374,23 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
-        patch_model_to = self.device
-
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
-        load_weights = not self.weights_loaded
+        # if self.model.loaded_size() > 0:
+        use_more_vram = lowvram_model_memory
+        if use_more_vram == 0:
+            use_more_vram = 1e32
+        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
+        real_model = self.model.model
 
-        if self.model.loaded_size() > 0:
-            use_more_vram = lowvram_model_memory
-            if use_more_vram == 0:
-                use_more_vram = 1e32
-            self.model_use_more_vram(use_more_vram)
-        else:
-            try:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
-            except Exception as e:
-                self.model.unpatch_model(self.model.offload_device)
-                self.model_unload()
-                raise e
-
-        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
+        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
             with torch.no_grad():
-                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
+                real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
 
-        self.weights_loaded = True
-        return self.real_model
+        self.real_model = weakref.ref(real_model)
+        self.model_finalizer = weakref.finalize(real_model, cleanup_models)
+        return real_model
 
     def should_reload_model(self, force_patch_weights=False):
         if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -395,14 +403,14 @@
                 freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
-        self.model.model_patches_to(self.model.offload_device)
-        self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.model.detach(unpatch_weights)
+        self.model_finalizer.detach()
+        self.model_finalizer = None
         self.real_model = None
         return True
 
-    def model_use_more_vram(self, extra_memory):
-        return self.model.partially_load(self.device, extra_memory)
+    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
+        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
 
     def __eq__(self, other):
         return self.model is other.model
@@ -413,6 +421,10 @@
         else:
             return f""
 
+    def __del__(self):
+        if self._patcher_finalizer is not None:
+            self._patcher_finalizer.detach()
+
 
 def use_more_memory(extra_memory, loaded_models, device):
     for m in loaded_models:
@@ -449,43 +461,6 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
 
-def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
-    with model_management_lock:
-        return _unload_model_clones(model, unload_weights_only, force_unload)
-
-
-def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
-    to_unload = []
-    for i in range(len(current_loaded_models)):
-        if model.is_clone(current_loaded_models[i].model):
-            to_unload = [i] + to_unload
-
-    if len(to_unload) == 0:
-        return True
-
-    same_weights = 0
-    for i in to_unload:
-        if model.clone_has_same_weights(current_loaded_models[i].model):
-            same_weights += 1
-
-    if same_weights == len(to_unload):
-        unload_weight = False
-    else:
-        unload_weight = True
-
-    if not force_unload:
-        if unload_weights_only and unload_weight == False:
-            return None
-    else:
-        unload_weight = True
-
-    for i in to_unload:
-        logging.debug("unload clone {} {}".format(i, unload_weight))
-        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-    return unload_weight
-
-
 @tracer.start_as_current_span("Free Memory")
 def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
     span = get_current_span()
@@ -496,7 +471,8 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
     return unloaded_models
 
 
-def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
+def _free_memory(memory_required, device, keep_loaded=[]):
+    cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -546,6 +522,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
 
 
 def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
+    cleanup_models_gc()
     global vram_state
 
     inference_memory = minimum_inference_memory()
@@ -558,12 +535,9 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
     models = set(models)
 
     models_to_load = []
-    models_already_loaded = []
     models_freed = []
     for x in models:
         loaded_model = LoadedModel(x)
-        loaded = None
-
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)
         except:
@@ -571,46 +545,34 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
 
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
-            if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
-                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                loaded = None
-            else:
-                loaded.currently_used = True
-                models_already_loaded.append(loaded)
-
-        if loaded is None:
+            loaded.currently_used = True
+            models_to_load.append(loaded)
+        else:
             models_to_load.append(loaded_model)
 
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
-                free_mem = get_free_memory(d)
-                if free_mem < minimum_memory_required:
-                    models_to_load = free_memory(minimum_memory_required, d)
-                    models_freed += models_to_load
-                else:
-                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
-        if len(models_to_load) == 0:
-            return
+    for loaded_model in models_to_load:
+        to_unload = []
+        for i in range(len(current_loaded_models)):
+            if loaded_model.model.is_clone(current_loaded_models[i].model):
+                to_unload = [i] + to_unload
+        for i in to_unload:
+            current_loaded_models.pop(i).model.detach(unpatch_all=False)
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for loaded_model in models_already_loaded:
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
-
-    for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
-        if weights_unloaded is not None:
-            loaded_model.weights_loaded = not weights_unloaded
-
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+            models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_mem = get_free_memory(device)
+            if free_mem < minimum_memory_required:
+                models_l = free_memory(minimum_memory_required, device)
+                models_freed += models_l
+                logging.debug("{} models unloaded.".format(len(models_l)))
 
     for loaded_model in models_to_load:
         model = loaded_model.model
@@ -633,13 +595,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
 
-    devs = set(map(lambda a: a.device, models_already_loaded))
-    for d in devs:
-        if d != torch.device("cpu"):
-            free_mem = get_free_memory(d)
-            if free_mem > minimum_memory_required:
-                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
-
     span = get_current_span()
     span.set_attribute("models_to_load", list(map(str, models_to_load)))
     span.set_attribute("models_freed", list(map(str, models_freed)))
@@ -662,23 +617,34 @@ def loaded_models(only_currently_used=False):
     return output
 
 
-def cleanup_models(keep_clone_weights_loaded=False):
-    with model_management_lock:
-        to_delete = []
-        for i in range(len(current_loaded_models)):
-            # TODO: very fragile function needs improvement
-            num_refs = sys.getrefcount(current_loaded_models[i].model)
-            if num_refs <= 2:
-                if not keep_clone_weights_loaded:
-                    to_delete = [i] + to_delete
-                # TODO: find a less fragile way to do this.
-                elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: # references from .real_model + the .model
-                    to_delete = [i] + to_delete
+def cleanup_models_gc():
+    do_gc = False
+    for i in range(len(current_loaded_models)):
+        cur = current_loaded_models[i]
+        if cur.real_model() is not None and cur.model is None:
+            logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
+            do_gc = True
+            break
 
-        for i in to_delete:
-            x = current_loaded_models.pop(i)
-            x.model_unload()
-            del x
+    if do_gc:
+        gc.collect()
+        soft_empty_cache()
+
+        for i in range(len(current_loaded_models)):
+            cur = current_loaded_models[i]
+            if cur.real_model() is not None and cur.model is None:
+                logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
+
+
+def cleanup_models():
+    to_delete = []
+    for i in range(len(current_loaded_models)):
+        if current_loaded_models[i].real_model() is None:
+            to_delete = [i] + to_delete
+
+    for i in to_delete:
+        x = current_loaded_models.pop(i)
+        del x
 
 
 def dtype_size(dtype):
@@ -747,7 +713,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
             pass
 
     if fp8_dtype is not None:
-        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
+        if supports_fp8_compute(device):  # if fp8 compute is supported the casting is most likely not expensive
             return fp8_dtype
 
         free_model_memory = maximum_vram_for_weights(device)
@@ -960,6 +926,7 @@ def force_channels_last():
     # TODO
     return False
 
+
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
     if device is None or weight.device == device:
         if not copy:
@@ -971,12 +938,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
     r.copy_(weight, non_blocking=non_blocking)
     return r
 
+
 def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
 
-
 FLASH_ATTENTION_ENABLED = False
 if not args.disable_flash_attn:
     try:
diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index af804631e..5291bbfa2 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable
+from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any
 
 import torch
 import torch.nn
@@ -71,11 +71,11 @@ class ModelManageable(Protocol):
     def lowvram_patch_counter(self) -> int:
         return 0
 
-    def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
+    def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False):
         self.patch_model(device_to=device_to)
         return self.model_size()
 
-    def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
+    def partially_unload(self, device_to: torch.device, memory_to_free: int = 0):
         self.unpatch_model(device_to)
         return self.model_size()
 
@@ -113,6 +113,16 @@ class ModelManageable(Protocol):
     def __del__(self):
         del self.model
 
+    @property
+    def parent(self) -> ModelManageableT | None:
+        return None
+
+    def detach(self, unpatch_all: bool = True):
+        self.model_patches_to(self.offload_device)
+        if unpatch_all:
+            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+        return self.model
+
 
 @dataclasses.dataclass
 class MemoryMeasurements:
@@ -120,6 +130,7 @@ class MemoryMeasurements:
     model_loaded_weight_memory: int = 0
     lowvram_patch_counter: int = 0
     model_lowvram: bool = False
+    current_weight_patches_uuid: Any = None
     _device: torch.device | None = None
 
     @property
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index a58a9e3d3..d45a360e6 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -146,6 +146,7 @@ class ModelPatcher(ModelManageable):
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
+        self._parent: ModelManageable | None = None
         self.patches_uuid: uuid.UUID = uuid.uuid4()
         self.ckpt_name = ckpt_name
         self._memory_measurements = MemoryMeasurements(self.model)
@@ -166,6 +167,18 @@ class ModelPatcher(ModelManageable):
     def model_device(self, value: torch.device):
         self._memory_measurements.device = value
 
+    @property
+    def current_weight_patches_uuid(self) -> Optional[uuid.UUID]:
+        return self._memory_measurements.current_weight_patches_uuid
+
+    @current_weight_patches_uuid.setter
+    def current_weight_patches_uuid(self, value):
+        self._memory_measurements.current_weight_patches_uuid = value
+
+    @property
+    def parent(self) -> Optional["ModelPatcher"]:
+        return self._parent
+
     def lowvram_patch_counter(self):
         return self._memory_measurements.lowvram_patch_counter
 
@@ -191,6 +204,7 @@ class ModelPatcher(ModelManageable):
         n._model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
+        n._parent = self
         return n
 
     def is_clone(self, other):
@@ -484,6 +498,7 @@ class ModelPatcher(ModelManageable):
 
         self.model_device = device_to
         self._memory_measurements.model_loaded_weight_memory = mem_counter
+        self._memory_measurements.current_weight_patches_uuid = self.patches_uuid
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         for k in self.object_patches:
@@ -518,6 +533,7 @@ class ModelPatcher(ModelManageable):
             else:
                 utils.set_attr_param(self.model, k, bk.weight)
 
+        self._memory_measurements.current_weight_patches_uuid = None
         self.backup.clear()
 
         if device_to is not None:
@@ -585,18 +601,35 @@ class ModelPatcher(ModelManageable):
         self._memory_measurements.model_loaded_weight_memory -= memory_freed
         return memory_freed
 
-    def partially_load(self, device_to, extra_memory=0):
-        self.unpatch_model(unpatch_weights=False)
+    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
+        unpatch_weights = self._memory_measurements.current_weight_patches_uuid is not None and (self._memory_measurements.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
+        # TODO: force_patch_weights should not unload + reload full model
+        used = self._memory_measurements.model_loaded_weight_memory
+        self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+        if unpatch_weights:
+            extra_memory += (used - self._memory_measurements.model_loaded_weight_memory)
+
         self.patch_model(load_weights=False)
         full_load = False
-        if not self._memory_measurements.model_lowvram:
+        if not self._memory_measurements.model_lowvram and self._memory_measurements.model_loaded_weight_memory > 0:
             return 0
         if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
            full_load = True
         current_used = self._memory_measurements.model_loaded_weight_memory
-        self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+        try:
+            self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
+        except Exception as e:
+            self.detach()
+            raise e
+
         return self._memory_measurements.model_loaded_weight_memory - current_used
 
+    def detach(self, unpatch_all=True):
+        self.model_patches_to(self.offload_device)
+        if unpatch_all:
+            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+        return self.model
+
     def current_loaded_device(self):
         return self.model_device
 
@@ -618,3 +651,6 @@ class ModelPatcher(ModelManageable):
     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
         return lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
+
+    def __del__(self):
+        self.detach(unpatch_all=False)
diff --git a/tests/inference/workflows/mochi-text-to-video-0.json b/tests/inference/workflows/mochi-text-to-video-0.json
index 56463127e..b75355096 100644
--- a/tests/inference/workflows/mochi-text-to-video-0.json
+++ b/tests/inference/workflows/mochi-text-to-video-0.json
@@ -2,7 +2,7 @@
   "3": {
     "inputs": {
       "seed": 309794859719915,
-      "steps": 30,
+      "steps": 1,
       "cfg": 4.5,
       "sampler_name": "euler",
       "scheduler": "simple",
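Reviewer note: the core of this patch is swapping the old `sys.getrefcount` heuristics for the standard-library `weakref` machinery. `LoadedModel` now holds only a `weakref.ref` to its `ModelPatcher` and registers `weakref.finalize` callbacks, so cleanup runs when the last strong reference to a patcher (or to the underlying `real_model`) goes away, and `model_unload()` / `__del__` cancel those callbacks with `.detach()`. The sketch below is a minimal, standalone illustration of that stdlib pattern only; `Patcher` and `Tracker` are made-up stand-ins, not code from this patch.

```python
import weakref


class Patcher:
    """Made-up stand-in for ModelPatcher: owns a heavyweight payload."""

    def __init__(self, name: str):
        self.name = name


class Tracker:
    """Made-up stand-in for LoadedModel: observes a Patcher without keeping it alive."""

    def __init__(self, patcher: Patcher):
        # weakref.ref does not prevent the Patcher from being collected.
        self._ref = weakref.ref(patcher)
        # weakref.finalize runs the callback at most once, when the Patcher is collected.
        # Only the name (a plain string) is captured, never the Patcher itself.
        self._finalizer = weakref.finalize(
            patcher, print, f"{patcher.name} collected, dropping bookkeeping")

    @property
    def patcher(self):
        # Resolves to the Patcher while it is alive, None afterwards.
        return self._ref()

    def cancel(self):
        # Mirrors model_unload()/__del__ in the patch: detach() cancels the
        # callback so cleanup does not run a second time.
        self._finalizer.detach()


p = Patcher("unet")
t = Tracker(p)
assert t.patcher is p          # strong reference still exists
del p                          # CPython's refcounting collects it immediately (no cycles)
assert t.patcher is None       # the weak reference now resolves to None
assert not t._finalizer.alive  # the finalize callback has already fired
```

`LoadedModel._switch_parent` builds on the same primitive: when a cloned patcher is collected, its finalizer re-points the loaded entry at the clone's `parent`, which is why `ModelPatcher.clone()` now records `n._parent = self`.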