mm: allow unload of the current model

In some workflows it is possible for a model to be used twice but with
different inference VRAM requirements.

Currently, once a model has been loaded at a certain level of offload,
it is kept at that level if it is used again. This will OOM if the
VRAM required for inference grows significantly between uses, which is
exactly what happens in the classic latent upscaling workflow where the
same model is used twice, first to generate and then to upscale.

This is very noticeable for WAN in particular.

Fix by making the model VRAM unload process two-pass: first try to
free memory using the existing list of idle models, then try again
with the models that are about to be loaded added to the list. This
allows a model that is already hot in VRAM to be partially offloaded
to make room for the larger inference.
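
Sketched as pseudocode (names follow the diff below; this is a
simplified outline of the idea, not the literal patch):

    def free_memory_required(vram, device, models_to_reload):
        if get_free_memory(device) < vram:
            # pass 1: unload idle models from the global list
            free_memory(vram, device)
            # pass 2: also consider the models about to be used, so a
            # hot-in-VRAM model can be partially offloaded
            free_memory(vram, device, loaded_models=models_to_reload)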

Improve the info messages logged for any unloads performed.
Author: Rattus
Date:   2025-11-08 16:03:39 +10:00
Parent: a1a70362ca
Commit: ca73329fcd


@@ -577,14 +577,14 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], loaded_models=current_loaded_models):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        shift_model = current_loaded_models[i]
+    for i in range(len(loaded_models) -1, -1, -1):
+        shift_model = loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
@@ -598,12 +598,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        logging.info(f"Unloading {loaded_models[i].model.model.__class__.__name__}")
+        if loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+        unloaded_models.append(loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
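
As a usage note (not part of the diff): the new keyword argument lets
free_memory operate on any list of LoadedModel objects instead of only
the global current_loaded_models; whatever it unloads is popped from
that list and returned. A hypothetical call against a caller-maintained
list might look like:

    candidates = list(models_to_reload)  # hypothetical caller-side list of LoadedModel objects
    unloaded = free_memory(memory_required, device, loaded_models=candidates)
    logging.info("{} models unloaded from the candidate list.".format(len(unloaded)))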
@@ -634,6 +634,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         models = models_temp
 
     models_to_load = []
+    models_to_reload = []
 
     for x in models:
         loaded_model = LoadedModel(x)
@@ -645,17 +646,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
             loaded.currently_used = True
-            models_to_load.append(loaded)
+            models_to_reload.append(loaded)
         else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
+    models_to_load += models_to_reload
+
     for loaded_model in models_to_load:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if loaded_model.model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload
+                if not current_loaded_models[i] in models_to_reload:
+                    models_to_reload.append(current_loaded_models[i])
 
         for i in to_unload:
             model_to_unload = current_loaded_models.pop(i)
             model_to_unload.model.detach(unpatch_all=False)
@@ -665,16 +670,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+    def free_memory_required(vram, device, models_to_reload):
+        if get_free_memory(device) < vram:
+            models_unloaded = free_memory(vram, device)
+            if len(models_unloaded):
+                logging.info("{} idle models unloaded.".format(len(models_unloaded)))
+            models_unloaded = free_memory(vram, device, loaded_models=models_to_reload)
+            if len(models_unloaded):
+                logging.info("{} active models unloaded for increased offloading.".format(len(models_unloaded)))
+            for unloaded_model in models_unloaded:
+                if unloaded_model in models_to_load:
+                    unloaded_model.currently_used = True
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_mem = get_free_memory(device)
-            if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
-                logging.info("{} models unloaded.".format(len(models_l)))
+            free_memory_required(total_memory_required[device] * 1.1 + extra_mem, device, models_to_reload)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory_required(minimum_memory_required, device, models_to_reload)
 
     for loaded_model in models_to_load:
         model = loaded_model.model
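
Taken together, the behaviour this targets in the latent-upscale case
looks roughly like the following usage sketch (the patcher object and
byte figures are placeholders, not from the commit):

    import comfy.model_management as mm

    # First sampling pass: the model loads at whatever offload level fits.
    mm.load_models_gpu([model_patcher], memory_required=low_res_inference_bytes)

    # The upscaled pass needs much more inference VRAM. Previously the already
    # resident model kept its old offload level and could OOM; with this change
    # the second free_memory pass over models_to_reload can partially offload it.
    mm.load_models_gpu([model_patcher], memory_required=high_res_inference_bytes)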