better partial unload

strint 2025-10-23 18:09:47 +08:00
parent c312733b8c
commit dc7c77e78c
2 changed files with 46 additions and 25 deletions


@@ -26,6 +26,14 @@ import importlib
 import platform
 import weakref
 import gc
+import os
+
+def get_mmap_mem_threshold_gb():
+    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
+    return mmap_mem_threshold_gb
+
+def get_free_disk():
+    return psutil.disk_usage("/").free
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
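For context on the two new helpers: the threshold comes from the MMAP_MEM_THRESHOLD_GB environment variable (defaulting to 0, which leaves the mmap path disabled), and free disk is measured on the root filesystem via psutil. A minimal standalone sketch of the same reads, with an illustrative 4 GB threshold:

import os
import psutil

# Illustrative only: export MMAP_MEM_THRESHOLD_GB=4 before launching;
# left unset, the helper returns 0 and the mmap path stays off.
os.environ.setdefault("MMAP_MEM_THRESHOLD_GB", "4")

threshold_bytes = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) * 1024 * 1024 * 1024
free_disk_bytes = psutil.disk_usage("/").free  # free space on the root filesystem

print(f"mmap threshold: {threshold_bytes/(1024**3):.1f} GB")
print(f"free disk:      {free_disk_bytes/(1024**3):.1f} GB")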
@@ -524,9 +532,7 @@ class LoadedModel:
         logging.debug(f"unpatch_weights: {unpatch_weights}")
         logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
         logging.debug(f"offload_device: {self.model.offload_device}")
-        available_memory = get_free_memory(self.model.offload_device)
-        logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
         # reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage
         # if available_memory < reserved_memory:
         #     logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
         #     return False
@@ -537,30 +543,42 @@ class LoadedModel:
         #     memory_to_free = offload_memory
         #     logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB")
         #     logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB")
         try:
-            if memory_to_free is not None:
-                if memory_to_free < self.model.loaded_size():
-                    logging.debug("Do partially unload")
-                    freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                    logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
-                    if freed >= memory_to_free:
-                        return False
+            if memory_to_free is None:
+                # free the full model
+                memory_to_free = self.model.loaded_size()
+            available_memory = get_free_memory(self.model.offload_device)
+            logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+            mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
+            if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size():
+                partially_unload = True
+            else:
+                partially_unload = False
+            if partially_unload:
+                logging.debug("Do partially unload")
+                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
+                if freed < memory_to_free:
+                    logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB")
+            else:
+                logging.debug("Do full unload")
+                self.model.detach(unpatch_weights)
+                logging.debug("Do full unload done")
         except Exception as e:
             logging.error(f"Error in model_unload: {e}")
+            available_memory = get_free_memory(self.model.offload_device)
+            logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
             return False
-        finally:
-            available_memory = get_free_memory(self.model.offload_device)
-            logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
-        self.model_finalizer.detach()
-        self.model_finalizer = None
-        self.real_model = None
-        return True
+        self.model_finalizer.detach()
+        self.model_finalizer = None
+        self.real_model = None
+        available_memory = get_free_memory(self.model.offload_device)
+        logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+        if partially_unload:
+            return False
+        else:
+            return True
 
     def model_use_more_vram(self, extra_memory, force_patch_weights=False):
         return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
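Taken on its own, the new branch reduces to a single predicate: unload partially when the requested amount exceeds what the offload device can absorb after the mmap reserve, or when less than the full model was requested. A hypothetical standalone sketch of that predicate with made-up sizes:

GiB = 1024 * 1024 * 1024

def should_partially_unload(memory_to_free, available_memory, loaded_size, reserve):
    # Restates the condition in model_unload (names here are hypothetical):
    # partial unload if the offload device cannot take everything we want
    # to free, or if the caller asked to free less than the whole model.
    return memory_to_free > available_memory - reserve or memory_to_free < loaded_size

# Illustrative numbers: a 12 GB model, 4 GB reserved for mmap/system use.
print(should_partially_unload(12 * GiB, 10 * GiB, 12 * GiB, 4 * GiB))  # True  -> partial
print(should_partially_unload(12 * GiB, 20 * GiB, 12 * GiB, 4 * GiB))  # False -> full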


@@ -40,11 +40,11 @@ import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
-from comfy.model_management import get_free_memory
+from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
 
 def need_mmap() -> bool:
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
+    mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
     if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
         logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
         return True
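need_mmap() now calls the shared helper instead of reading the environment variable a second time. A rough self-contained approximation of the check, where psutil's available RAM stands in for the repo's get_free_memory(torch.device("cpu")):

import os
import psutil

GiB = 1024 * 1024 * 1024

def get_mmap_mem_threshold_gb():
    return int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))

def need_mmap():
    free_cpu_mem = psutil.virtual_memory().available  # stand-in for get_free_memory
    return free_cpu_mem < get_mmap_mem_threshold_gb() * GiB

os.environ["MMAP_MEM_THRESHOLD_GB"] = "8"  # illustrative value
print(need_mmap())  # True only once free RAM drops below 8 GB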
@@ -972,6 +972,9 @@ class ModelPatcher:
             if move_weight:
                 cast_weight = self.force_cast_weights
                 if need_mmap():
+                    if get_free_disk() < module_mem:
+                        logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
+                        break
                     # offload to mmap
                     model_to_mmap(m)
                 else:
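The guard makes the failure mode explicit: an mmap-backed tensor lives in a file, so a module is only moved to mmap if the filesystem can hold it, and the loop stops (break) at the first module that no longer fits. A hypothetical sketch of the same check in isolation:

import psutil

GiB = 1024 * 1024 * 1024

def can_offload_to_mmap(module_mem_bytes):
    # Hypothetical helper mirroring the new guard: the module must fit in
    # the free space of the filesystem backing the mmap file.
    free = psutil.disk_usage("/").free
    if free < module_mem_bytes:
        print(f"not enough disk: {free/GiB:.1f} GB free < {module_mem_bytes/GiB:.1f} GB needed")
        return False
    return True

print(can_offload_to_mmap(2 * GiB))  # e.g. a 2 GB module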