Prevent and detect some types of memory leaks.

comfyanonymous 2024-11-01 06:54:37 -04:00
parent 975927cc79
commit bd5d8f150f
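
The new cleanup_models_gc() below flags a potential leak whenever a loaded model's real_model() weak reference still resolves even though its model reference has already gone to None: in that state the underlying module is usually being kept alive by a reference cycle, which only Python's cycle collector can reclaim, hence the forced gc.collect(). A minimal sketch of that idea, using hypothetical FakeModel and Wrapper classes in place of ComfyUI's LoadedModel:

import gc
import weakref

class FakeModel:
    # Stand-in for a torch module; holds a back-reference that forms a cycle.
    def __init__(self):
        self.owner = None

class Wrapper:
    # Stand-in for the object that owns the model.
    def __init__(self, model):
        self.model = model
        model.owner = self  # circular reference: Wrapper <-> FakeModel

wrapper = Wrapper(FakeModel())
real_model_ref = weakref.ref(wrapper.model)  # weak reference, like real_model()

del wrapper  # drop the only strong reference we control

# The cycle keeps both objects alive, so the weak reference still resolves;
# this is the "potential memory leak" case the new function logs.
if real_model_ref() is not None:
    gc.collect()  # a full collect breaks the cycle

# If the reference still resolved here, something external would be holding
# the model, which is what the follow-up warning in the diff reports.
print("leaked" if real_model_ref() is not None else "reclaimed by gc")

cleanup_models_gc() applies this same check to every entry in current_loaded_models, once before and once after the collect.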


@@ -24,6 +24,7 @@ import torch
import sys
import platform
import weakref
import gc

class VRAMState(Enum):
    DISABLED = 0 #No vram present: no need to move models to vram
@@ -400,6 +401,7 @@ def minimum_inference_memory():
    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def free_memory(memory_required, device, keep_loaded=[]):
    cleanup_models_gc()
    unloaded_model = []
    can_unload = []
    unloaded_models = []
@@ -436,6 +438,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    cleanup_models_gc()
    global vram_state
    inference_memory = minimum_inference_memory()
@@ -523,6 +526,27 @@ def loaded_models(only_currently_used=False):
        output.append(m.model)
    return output

def cleanup_models_gc():
    do_gc = False
    for i in range(len(current_loaded_models)):
        cur = current_loaded_models[i]
        if cur.real_model() is not None and cur.model is None:
            logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
            do_gc = True
            break

    if do_gc:
        gc.collect()
        soft_empty_cache()

        for i in range(len(current_loaded_models)):
            cur = current_loaded_models[i]
            if cur.real_model() is not None and cur.model is None:
                logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))

def cleanup_models():
    to_delete = []
    for i in range(len(current_loaded_models)):