From ca73329fcd344aa2f00e1431c7047c1ca9cb75ed Mon Sep 17 00:00:00 2001
From: Rattus
Date: Sat, 8 Nov 2025 16:03:39 +1000
Subject: [PATCH] mm: allow unload of the current model

In some workflows it is possible for a model to be used twice but with
different requirements for the inference VRAM. Currently, once a model
is loaded at a certain level of offload, it is kept at that level of
offload when it is used again. This OOMs if the size of the inference
VRAM changes significantly between the two uses. It happens in the
classic latent upscaling workflow, where the same model generates and
then upscales, and is very noticeable for WAN in particular.

Fix this by two-passing the model VRAM unload process: first try with
the existing list of idle models, then try again with the models that
are about to be loaded added to the list. This gives the hot-in-VRAM
model the partial offload it needs to make space for the larger
inference.

Also improve the info messages reported for any unloads performed.
---
 comfy/model_management.py | 43 +++++++++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7012df858..44ef3dd4a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -577,14 +577,14 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], loaded_models=current_loaded_models):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        shift_model = current_loaded_models[i]
+    for i in range(len(loaded_models) -1, -1, -1):
+        shift_model = loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
@@ -598,12 +598,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        logging.info(f"Unloading {loaded_models[i].model.model.__class__.__name__}")
+        if loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+        unloaded_models.append(loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
@@ -634,6 +634,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         models = models_temp
 
     models_to_load = []
+    models_to_reload = []
 
     for x in models:
         loaded_model = LoadedModel(x)
@@ -645,17 +646,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
             loaded.currently_used = True
-            models_to_load.append(loaded)
+            models_to_reload.append(loaded)
         else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
+    models_to_load += models_to_reload
+
     for loaded_model in models_to_load:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if loaded_model.model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload
+                if not current_loaded_models[i] in models_to_reload:
+                    models_to_reload.append(current_loaded_models[i])
         for i in to_unload:
             model_to_unload = current_loaded_models.pop(i)
             model_to_unload.model.detach(unpatch_all=False)
@@ -665,16 +670,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+    def free_memory_required(vram, device, models_to_reload):
+        if get_free_memory(device) < vram:
+            models_unloaded = free_memory(vram, device)
+            if len(models_unloaded):
+                logging.info("{} idle models unloaded.".format(len(models_unloaded)))
+
+            models_unloaded = free_memory(vram, device, loaded_models=models_to_reload)
+            if len(models_unloaded):
+                logging.info("{} active models unloaded for increased offloading.".format(len(models_unloaded)))
+            for unloaded_model in models_unloaded:
+                if unloaded_model in models_to_load:
+                    unloaded_model.currently_used = True
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_mem = get_free_memory(device)
-            if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
-                logging.info("{} models unloaded.".format(len(models_l)))
+            free_memory_required(total_memory_required[device] * 1.1 + extra_mem, device, models_to_reload)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory_required(minimum_memory_required, device, models_to_reload)
 
     for loaded_model in models_to_load:
         model = loaded_model.model
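Below is a minimal standalone sketch of the two-pass unload order the commit message describes: first free VRAM by unloading idle models, and only if that is not enough, also offload the models that are about to be used. It is illustrative only; ModelSlot, free_from() and two_pass_free() are hypothetical stand-ins and not ComfyUI's LoadedModel/free_memory API.

# Hypothetical illustration of the two-pass unload order (not ComfyUI code).
from dataclasses import dataclass

@dataclass
class ModelSlot:
    name: str
    vram_bytes: int

def free_from(slots, bytes_needed, already_free):
    # Offload the largest slots first until the request fits; return bytes freed.
    freed = 0
    for slot in sorted(slots, key=lambda s: s.vram_bytes, reverse=True):
        if already_free + freed >= bytes_needed:
            break
        freed += slot.vram_bytes
        slot.vram_bytes = 0  # fully offloaded in this toy model
    return freed

def two_pass_free(bytes_needed, free_now, idle_models, active_models):
    # Pass 1: unload models that are not part of the upcoming inference.
    free_now += free_from(idle_models, bytes_needed, free_now)
    if free_now >= bytes_needed:
        return free_now
    # Pass 2: also offload the models about to be used, so a hot-in-VRAM
    # model can shrink when the new inference needs more headroom.
    free_now += free_from(active_models, bytes_needed, free_now)
    return free_now

# Example: 2 GiB free, the next inference wants 9 GiB, so pass 2 must run.
idle = [ModelSlot("vae", 2 << 30)]
active = [ModelSlot("wan", 14 << 30)]
print(two_pass_free(9 << 30, 2 << 30, idle, active))  # >= 9 GiB after pass 2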