Mirror of https://github.com/comfyanonymous/ComfyUI.git
mm: allow unload of the current model
In some workflows it is possible for a model to be used twice but with different inference VRAM requirements. Currently, once a model is loaded at a certain offload level, it stays at that level when it is used again, which can OOM if the inference VRAM requirement grows significantly. This happens in the classic latent upscaling workflow where the same model is used to generate and then to upscale, and it is very noticeable for WAN in particular. Fix by making the model VRAM unload a two-pass process: first try with the existing list of idle models, then try again including the models that are about to be loaded. This performs the partial offload of the hot-in-VRAM model that is needed to make space for the larger inference. Also improve the info messages about any unloads performed.
parent a1a70362ca
commit ca73329fcd
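
At its core, the change makes free_memory reusable against an arbitrary list of loaded models and then runs it in two passes. The standalone sketch below only illustrates that pattern; the helper name two_pass_free, its parameters, and the idea of passing free_memory/get_free_memory in as callables are illustrative assumptions, not part of the actual patch (the real logic is the nested free_memory_required helper in the diff below).

def two_pass_free(vram_needed, device, models_to_reload, get_free_memory, free_memory):
    # Hypothetical sketch of the two-pass unload; not the patched ComfyUI code.
    # Nothing to do if the device already has enough free VRAM.
    if get_free_memory(device) >= vram_needed:
        return []

    unloaded = []

    # Pass 1: unload idle models, i.e. resident models that are not part of
    # the upcoming load (free_memory's default list).
    unloaded += free_memory(vram_needed, device)

    # Pass 2: also offer up the models that are about to be used; free_memory
    # is expected to stop early if pass 1 already freed enough, otherwise the
    # hot-in-VRAM models are partially offloaded so the larger inference fits.
    unloaded += free_memory(vram_needed, device, loaded_models=models_to_reload)
    return unloaded

In the actual patch this logic lives inside load_models_gpu, and any model unloaded in pass 2 that is still needed is re-marked as currently_used so it gets loaded again at a higher offload level.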
@@ -577,14 +577,14 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], loaded_models=current_loaded_models):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
-        shift_model = current_loaded_models[i]
+    for i in range(len(loaded_models) -1, -1, -1):
+        shift_model = loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
@@ -598,12 +598,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
         if free_mem > memory_required:
             break
         memory_to_free = memory_required - free_mem
-        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        logging.info(f"Unloading {loaded_models[i].model.model.__class__.__name__}")
+        if loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        unloaded_models.append(current_loaded_models.pop(i))
+        unloaded_models.append(loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
@@ -634,6 +634,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         models = models_temp
 
     models_to_load = []
+    models_to_reload = []
 
     for x in models:
         loaded_model = LoadedModel(x)
@@ -645,17 +646,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if loaded_model_index is not None:
             loaded = current_loaded_models[loaded_model_index]
             loaded.currently_used = True
-            models_to_load.append(loaded)
+            models_to_reload.append(loaded)
         else:
             if hasattr(x, "model"):
                 logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
+    models_to_load += models_to_reload
+
     for loaded_model in models_to_load:
         to_unload = []
         for i in range(len(current_loaded_models)):
             if loaded_model.model.is_clone(current_loaded_models[i].model):
                 to_unload = [i] + to_unload
+                if not current_loaded_models[i] in models_to_reload:
+                    models_to_reload.append(current_loaded_models[i])
         for i in to_unload:
             model_to_unload = current_loaded_models.pop(i)
             model_to_unload.model.detach(unpatch_all=False)
@@ -665,16 +670,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for loaded_model in models_to_load:
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+    def free_memory_required(vram, device, models_to_reload):
+        if get_free_memory(device) < vram:
+            models_unloaded = free_memory(vram, device)
+            if len(models_unloaded):
+                logging.info("{} idle models unloaded.".format(len(models_unloaded)))
+
+            models_unloaded = free_memory(vram, device, loaded_models=models_to_reload)
+            if len(models_unloaded):
+                logging.info("{} active models unloaded for increased offloading.".format(len(models_unloaded)))
+                for unloaded_model in models_unloaded:
+                    if unloaded_model in models_to_load:
+                        unloaded_model.currently_used = True
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_mem = get_free_memory(device)
-            if free_mem < minimum_memory_required:
-                models_l = free_memory(minimum_memory_required, device)
-                logging.info("{} models unloaded.".format(len(models_l)))
+            free_memory_required(total_memory_required[device] * 1.1 + extra_mem, device, models_to_reload)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory_required(minimum_memory_required, device, models_to_reload)
 
     for loaded_model in models_to_load:
         model = loaded_model.model