diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7012df858..44ef3dd4a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -577,14 +577,14 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], loaded_models=current_loaded_models):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        shift_model = current_loaded_models[i]
+    for i in range(len(loaded_models) -1, -1, -1):
+        shift_model = loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
@@ -598,12 +598,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        logging.info(f"Unloading {loaded_models[i].model.model.__class__.__name__}")
+        if loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+        unloaded_models.append(loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
@@ -634,6 +634,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         models = models_temp
 
     models_to_load = []
+    models_to_reload = []
 
     for x in models:
         loaded_model = LoadedModel(x)
@@ -645,17 +646,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
             loaded.currently_used = True
-            models_to_load.append(loaded)
+            models_to_reload.append(loaded)
         else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
+    models_to_load += models_to_reload
+
     for loaded_model in models_to_load:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if loaded_model.model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload
+                if not current_loaded_models[i] in models_to_reload:
+                    models_to_reload.append(current_loaded_models[i])
         for i in to_unload:
             model_to_unload = current_loaded_models.pop(i)
             model_to_unload.model.detach(unpatch_all=False)
@@ -665,16 +670,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+    def free_memory_required(vram, device, models_to_reload):
+        if get_free_memory(device) < vram:
+            models_unloaded = free_memory(vram, device)
+            if len(models_unloaded):
+                logging.info("{} idle models unloaded.".format(len(models_unloaded)))
+
+            models_unloaded = free_memory(vram, device, loaded_models=models_to_reload)
+            if len(models_unloaded):
+                logging.info("{} active models unloaded for increased offloading.".format(len(models_unloaded)))
+            for unloaded_model in models_unloaded:
+                if unloaded_model in models_to_load:
+                    unloaded_model.currently_used = True
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_mem = get_free_memory(device)
-            if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
-                logging.info("{} models unloaded.".format(len(models_l)))
+            free_memory_required(total_memory_required[device] * 1.1 + extra_mem, device, models_to_reload)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory_required(minimum_memory_required, device, models_to_reload)
 
     for loaded_model in models_to_load:
         model = loaded_model.model