diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 209fc185b..f83afa258 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -136,6 +136,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+parser.add_argument("--offload-reserve-ram-gb", type=float, default=None, help="Set the amount of ram in GB you want to reserve for other use. When free ram drops to this limit, models offloaded from vram are written to mmap files to save ram.")
 
 parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
 parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 5105111c6..c1ebb1282 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,10 +31,20 @@ import os
 from functools import lru_cache
 
 @lru_cache(maxsize=1)
-def get_mmap_mem_threshold_gb():
-    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
-    logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
-    return mmap_mem_threshold_gb
+def get_offload_reserve_ram_gb():
+    offload_reserve_ram_gb = 0
+    try:
+        # argparse stores --offload-reserve-ram-gb as offload_reserve_ram_gb
+        val = getattr(args, 'offload_reserve_ram_gb', None)
+    except Exception:
+        val = None
+
+    if val is not None:
+        try:
+            offload_reserve_ram_gb = float(val)
+        except Exception:
+            logging.warning(f"Invalid args.offload_reserve_ram_gb value: {val}, defaulting to 0")
+            offload_reserve_ram_gb = 0
+    return offload_reserve_ram_gb
 
 def get_free_disk():
     return psutil.disk_usage("/").free
@@ -613,7 +623,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
+    for i in range(len(current_loaded_models) - 1, -1, -1):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d3c69f614..5d6330321 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -40,15 +40,19 @@ import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
-from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk
 from comfy.quant_ops import QuantizedTensor
 
 
-def need_mmap() -> bool:
+def enable_offload_to_mmap() -> bool:
+    if comfy.utils.DISABLE_MMAP:
+        return False
+
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
-    if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
-        logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
+    offload_reserve_ram_gb = get_offload_reserve_ram_gb()
+    if free_cpu_mem <= offload_reserve_ram_gb * 1024 * 1024 * 1024:
+        logging.debug(f"Enabling offload to mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB <= {offload_reserve_ram_gb} GB")
         return True
+    return False
 
 def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
@@ -917,7 +921,7 @@ class ModelPatcher:
 
         if device_to is not None:
-            if need_mmap():
+            if enable_offload_to_mmap():
                 # offload to mmap
                 model_to_mmap(self.model)
             else:
@@ -982,7 +986,7 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    if need_mmap():
+                    if enable_offload_to_mmap():
                         if get_free_disk() < module_mem:
                             logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
                             break
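For reference, below is a minimal standalone sketch of the reserve-RAM check this patch introduces. It is not the ComfyUI implementation: psutil stands in for get_free_memory(torch.device("cpu")), and should_offload_to_mmap is a hypothetical name.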
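# Hypothetical standalone sketch of the reserve-RAM threshold check; not
# ComfyUI code. psutil stands in for get_free_memory(torch.device("cpu")).
import psutil

GiB = 1024 * 1024 * 1024

def should_offload_to_mmap(reserve_ram_gb: float) -> bool:
    # Available system RAM in bytes.
    free_cpu_mem = psutil.virtual_memory().available
    # Mirror the patch's <= comparison: once free RAM falls to or below the
    # reserved amount, weights go to mmap-backed files instead of RAM.
    return free_cpu_mem <= reserve_ram_gb * GiB

if __name__ == "__main__":
    # e.g. the equivalent of launching with --offload-reserve-ram-gb 8
    print(should_offload_to_mmap(8.0))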
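In use, launching with for example --offload-reserve-ram-gb 8 means that once free system RAM drops to 8 GB, weights leaving vram are written to mmap files rather than kept in RAM, provided mmap is not disabled (comfy.utils.DISABLE_MMAP) and enough free disk space is available for the file.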