mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Dynamic VRAM unloading fix (#12227)
* mp: fix full dynamic unloading This was not unloading dynamic models when requesting a full unload via the unpatch() code path. This was ok if your workflow was all dynamic models, but fails with big VRAM leaks if you need to fully unload something for a regular ModelPatcher. It also fixes the "unload models" button. * mm: load models outside of Aimdo Mempool In dynamic_vram mode, escape the Aimdo mempool and load into the regular mempool. Use a dummy thread to do it.
This commit is contained in:
parent
37f711d4a1
commit
de9ada6a41
@ -19,7 +19,8 @@
|
|||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||||
|
import threading
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
import platform
|
import platform
|
||||||
@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
return unloaded_models
|
return unloaded_models
|
||||||
|
|
||||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||||
cleanup_models_gc()
|
cleanup_models_gc()
|
||||||
global vram_state
|
global vram_state
|
||||||
|
|
||||||
@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
return
|
return
|
||||||
|
|
||||||
def load_model_gpu(model):
|
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
|
||||||
return load_models_gpu([model])
|
with torch.inference_mode():
|
||||||
|
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||||
|
#Deliberately load models outside of the Aimdo mempool so they can be retained across
|
||||||
|
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
|
||||||
|
#thread local. So exploit that to escape context
|
||||||
|
if enables_dynamic_vram():
|
||||||
|
t = threading.Thread(
|
||||||
|
target=load_models_gpu_thread,
|
||||||
|
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
else:
|
||||||
|
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
|
||||||
|
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||||
|
|
||||||
def loaded_models(only_currently_used=False):
|
def loaded_models(only_currently_used=False):
|
||||||
output = []
|
output = []
|
||||||
@ -1717,9 +1735,6 @@ def debug_memory_summary():
|
|||||||
return torch.cuda.memory.memory_summary()
|
return torch.cuda.memory.memory_summary()
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
#TODO: might be cleaner to put this somewhere else
|
|
||||||
import threading
|
|
||||||
|
|
||||||
class InterruptProcessingException(Exception):
|
class InterruptProcessingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
self.partially_unload_ram(1e32)
|
self.partially_unload_ram(1e32)
|
||||||
self.partially_unload(None)
|
self.partially_unload(None, 1e32)
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
assert not force_patch_weights #See above
|
assert not force_patch_weights #See above
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user