From bd5d8f150f2b7e8283a512510f5a08224176ada2 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 1 Nov 2024 06:54:37 -0400
Subject: [PATCH] Prevent and detect some types of memory leaks.

---
 comfy/model_management.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7f21559e5..adda3841d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -24,6 +24,7 @@ import torch
 import sys
 import platform
 import weakref
+import gc
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -400,6 +401,7 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
 def free_memory(memory_required, device, keep_loaded=[]):
+    cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -436,6 +438,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     return unloaded_models
 
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    cleanup_models_gc()
     global vram_state
 
     inference_memory = minimum_inference_memory()
@@ -523,6 +526,27 @@ def loaded_models(only_currently_used=False):
                 output.append(m.model)
     return output
 
+
+def cleanup_models_gc():
+    do_gc = False
+    for i in range(len(current_loaded_models)):
+        cur = current_loaded_models[i]
+        if cur.real_model() is not None and cur.model is None:
+            logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
+            do_gc = True
+            break
+
+    if do_gc:
+        gc.collect()
+        soft_empty_cache()
+
+        for i in range(len(current_loaded_models)):
+            cur = current_loaded_models[i]
+            if cur.real_model() is not None and cur.model is None:
+                logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
+
+
+
 def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
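
Note: a minimal, self-contained sketch of the failure mode this patch detects
(illustrative only, not ComfyUI code; the LeakyModel class is hypothetical).
An object caught in a reference cycle is not freed by reference counting when
the last external name to it is dropped, so a weakref to it keeps resolving
until the cycle collector runs. That is the state cleanup_models_gc() flags
(cur.model already gone while the cur.real_model() weakref still resolves),
and why it falls back to gc.collect():

    import gc
    import weakref

    class LeakyModel:
        # Hypothetical stand-in for a model object (not ComfyUI code).
        def __init__(self):
            # Circular reference: the instance refers back to itself,
            # so its refcount can never drop to zero on its own.
            self.inner = self

    m = LeakyModel()
    ref = weakref.ref(m)        # analogous to the weakref check above

    del m                       # drop the last strong external reference
    # Reference counting alone cannot reclaim the cycle, so the weakref
    # still resolves -- the "potential memory leak" case:
    print(ref() is not None)    # True

    gc.collect()                # the full collect the patch falls back to
    # The cycle collector has broken the cycle and freed the object:
    print(ref() is None)        # True

Avoiding such cycles in model code (or breaking them explicitly) lets objects
be freed immediately by reference counting, without waiting for a full, and
comparatively slow, garbage collection pass.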