strint 2025-12-15 18:47:35 +08:00
parent 5495b55ab2
commit fa674cc60d
3 changed files with 27 additions and 12 deletions

comfy/cli_args.py

@@ -136,6 +136,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+parser.add_argument("--offload-reserve-ram-gb", type=float, default=None, help="Set the amount of RAM in GB you want to reserve for other use. When free RAM drops to this limit, models offloaded from VRAM are moved to mmap-backed files to save RAM.")
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the number of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")

comfy/model_management.py

@@ -31,10 +31,20 @@ import os
from functools import lru_cache

@lru_cache(maxsize=1)
-def get_mmap_mem_threshold_gb():
-    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
-    logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
-    return mmap_mem_threshold_gb
+def get_offload_reserve_ram_gb():
+    offload_reserve_ram_gb = 0
+    try:
+        # argparse stores --offload-reserve-ram-gb as offload_reserve_ram_gb;
+        # a hyphenated attribute name would never resolve.
+        val = getattr(args, 'offload_reserve_ram_gb', None)
+    except Exception:
+        val = None
+    if val is not None:
+        try:
+            # The flag is parsed as a float, so keep fractional GB intact.
+            offload_reserve_ram_gb = float(val)
+        except Exception:
+            logging.warning(f"Invalid args.offload_reserve_ram_gb value: {val}, defaulting to 0")
+            offload_reserve_ram_gb = 0
+    return offload_reserve_ram_gb

def get_free_disk():
    return psutil.disk_usage("/").free
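Worth noting about get_free_disk(): psutil.disk_usage() reports the filesystem containing the given path, so hard-coding "/" measures the root mount even when the mmap files land on a different mount. A small sketch; the system temp dir here is just an illustrative alternative target, not what the patch uses:

import tempfile

import psutil

root_free = psutil.disk_usage("/").free
tmp_free = psutil.disk_usage(tempfile.gettempdir()).free
print(f"free on /: {root_free / 1024**3:.1f} GB, "
      f"free on temp dir: {tmp_free / 1024**3:.1f} GB")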
@@ -613,7 +623,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    can_unload = []
    unloaded_models = []

-    for i in range(len(current_loaded_models) - 1, -1):
+    for i in range(len(current_loaded_models) - 1, -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded and not shift_model.is_dead():
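The step argument is what makes the reverse walk work at all: with the default step of +1, a range whose start exceeds its stop is simply empty. A quick illustration:

models = ["a", "b", "c"]

print(list(range(len(models) - 1, -1)))       # [] -- start above stop with step +1: body never runs
print(list(range(len(models) - 1, -1, -1)))   # [2, 1, 0] -- walks the list back to front

# Equivalent and arguably clearer when the index itself is needed:
for i in reversed(range(len(models))):
    print(i, models[i])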

comfy/model_patcher.py

@@ -40,15 +40,19 @@ import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
-from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk
from comfy.quant_ops import QuantizedTensor

-def need_mmap() -> bool:
+def enable_offload_to_mmap() -> bool:
    if comfy.utils.DISABLE_MMAP:
        return False
    free_cpu_mem = get_free_memory(torch.device("cpu"))
-    mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
-    if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
-        logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
+    offload_reserve_ram_gb = get_offload_reserve_ram_gb()
+    if free_cpu_mem <= offload_reserve_ram_gb * 1024 * 1024 * 1024:
+        logging.debug(f"Enabling offload to mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB <= {offload_reserve_ram_gb} GB")
        return True
    return False
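The gate reads as: offload to mmap once free system RAM falls to the reserved amount. A standalone approximation, with psutil standing in for comfy.model_management.get_free_memory and a literal standing in for the parsed flag:

import psutil

DISABLE_MMAP = False      # stands in for comfy.utils.DISABLE_MMAP
RESERVE_RAM_GB = 8.0      # stands in for get_offload_reserve_ram_gb()

def should_offload_to_mmap() -> bool:
    if DISABLE_MMAP:
        return False
    free_cpu_mem = psutil.virtual_memory().available
    # Offload once free RAM is at or below the reserve, converted to bytes.
    return free_cpu_mem <= RESERVE_RAM_GB * 1024 ** 3

print(should_offload_to_mmap())

The hunk's trailing context shows only the signature of to_mmap, below.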
def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
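Since the body of to_mmap is elided here, a hedged sketch of what such a helper can look like, using torch.from_file with shared=True so the tensor's storage is backed by a disk file rather than anonymous RAM (the name, the temp-file choice, and the copy strategy are assumptions, not the project's implementation):

import tempfile
from typing import Optional

import torch

def to_mmap_sketch(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
    if filename is None:
        # Assumption: spill into the system temp dir when no path is given.
        filename = tempfile.NamedTemporaryFile(suffix=".mmap", delete=False).name
    cpu_t = t.detach().cpu().contiguous()
    # shared=True maps the file read-write and creates/grows it as needed.
    mapped = torch.from_file(filename, shared=True, size=cpu_t.numel(), dtype=cpu_t.dtype)
    mapped.copy_(cpu_t.view(-1))
    return mapped.view(t.shape)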
@@ -917,7 +921,7 @@ class ModelPatcher:
        if device_to is not None:
-            if need_mmap():
+            if enable_offload_to_mmap():
                # offload to mmap
                model_to_mmap(self.model)
            else:
@@ -982,7 +986,7 @@
                bias_key = "{}.bias".format(n)
                if move_weight:
                    cast_weight = self.force_cast_weights
-                    if need_mmap():
+                    if enable_offload_to_mmap():
                        if get_free_disk() < module_mem:
                            logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
                            break
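The disk-space guard reads in isolation as: refuse to spill a module of module_mem bytes when the target filesystem cannot hold it. A hedged standalone sketch, with psutil standing in for get_free_disk() and "/" mirroring the patch's choice of mount:

import logging

import psutil

def can_spill_to_disk(module_mem: int, path: str = "/") -> bool:
    free = psutil.disk_usage(path).free
    if free < module_mem:
        logging.warning(
            f"Not enough disk space: {free / 1024**3:.1f} GB free < "
            f"{module_mem / 1024**3:.1f} GB needed"
        )
        return False
    return True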