This commit is contained in:
Julián Mulet 2026-05-02 09:08:55 -07:00 committed by GitHub
commit 352e5cc190
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 486 additions and 5 deletions

View File

@ -157,6 +157,7 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Moves all models larger than 1 GB to a virtual (meta) device between runs, preventing swap pressure on disk. Small models like the VAE are preserved. Trade-off: models are reloaded from disk on subsequent generations.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
class PerformanceFeature(enum.Enum):

View File

@ -461,10 +461,51 @@ if cpu_state == CPUState.MPS:
logging.info(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory
AGGRESSIVE_OFFLOAD = args.aggressive_offload
if DISABLE_SMART_MEMORY:
logging.info("Disabling smart memory management")
if AGGRESSIVE_OFFLOAD:
logging.info("Aggressive offload enabled: models will be freed from RAM after use (designed for Apple Silicon)")
# ---------------------------------------------------------------------------
# Model lifecycle callbacks — on_model_destroyed
# ---------------------------------------------------------------------------
# Why not comfy.hooks? The existing hook system (comfy/hooks.py) is scoped
# to *sampling conditioning* — LoRA weight injection, transformer_options,
# and keyframe scheduling. It has no concept of model-management lifecycle
# events such as "a model's parameters were deallocated".
#
# This lightweight callback list fills that gap. It is intentionally minimal
# (append-only, no priorities, no removal) because the only current consumer
# is the execution-engine cache invalidator registered in PromptExecutor.
# If upstream adopts a formal lifecycle-event bus in the future, these
# callbacks should migrate to that system.
# ---------------------------------------------------------------------------
_on_model_destroyed_callbacks: list = []
def register_model_destroyed_callback(callback):
"""Register a listener for post-destruction lifecycle events.
After ``free_memory`` moves one or more models to the ``meta`` device
(aggressive offload), every registered callback is invoked once with a
*reason* string describing the batch (e.g. ``"batch"``).
Typical usage executed by ``PromptExecutor.__init__``::
def _invalidate(reason):
executor.caches.outputs.clear_all()
register_model_destroyed_callback(_invalidate)
Args:
callback: ``Callable[[str], None]`` receives a human-readable
reason string. Must be safe to call from within the
``free_memory`` critical section (no heavy I/O, no model loads).
"""
_on_model_destroyed_callbacks.append(callback)
def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
@ -633,14 +674,21 @@ def offloaded_memory(loaded_models, device):
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
if cpu_state == CPUState.MPS and AGGRESSIVE_OFFLOAD:
# macOS with Apple Silicon + aggressive offload: shared memory means OS
# needs more headroom. Reserve 4 GB for macOS + system services to
# prevent swap thrashing during model destruction/reload cycles.
EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
logging.info("MPS detected with --aggressive-offload: reserving 4 GB for macOS system overhead")
elif WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
if not WINDOWS:
def get_free_ram():
return psutil.virtual_memory().available
@ -663,14 +711,25 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if device is None or shift_model.device == device:
# On Apple Silicon SHARED mode, CPU RAM == GPU VRAM (same physical memory).
# Bypass the device filter so CPU-loaded models (like CLIP) can be freed.
device_match = (device is None or shift_model.device == device)
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
device_match = True
if device_match:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
can_unload_sorted = sorted(can_unload)
# Collect models to destroy via meta device AFTER the unload loop completes,
# so we don't kill weakrefs of models still being iterated.
_meta_destroy_queue = []
for x in can_unload_sorted:
i = x[-1]
# Guard: weakref may already be dead from a previous iteration
if current_loaded_models[i].model is None:
continue
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
@ -681,15 +740,72 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
memory_to_free = 0
# Aggressive offload for Apple Silicon: force-unload unused models
# regardless of free memory, since CPU RAM == GPU VRAM.
# Only force-unload models > 1 GB — small models like the VAE (160 MB)
# are preserved to avoid unnecessary reload from disk.
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
model_ref = current_loaded_models[i].model
if model_ref is not None and not current_loaded_models[i].currently_used:
model_size = current_loaded_models[i].model_memory()
if model_size > 1024 * 1024 * 1024: # 1 GB threshold
memory_to_free = 1e32 # Force unload
inner = getattr(model_ref, "model", None)
model_name = inner.__class__.__name__ if inner is not None else "unknown"
model_size_mb = model_size / (1024 * 1024)
logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM")
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
# Queue for meta device destruction after loop completes.
# Only destroy large models (>1 GB) — small models like the VAE (160 MB)
# are kept because the execution cache may reuse their patcher across
# workflow nodes (e.g. vae_loader is cached while vae_decode runs later).
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
if current_loaded_models[i].model is not None:
model_size = current_loaded_models[i].model_memory()
if model_size > 1024 * 1024 * 1024: # Only meta-destroy models > 1 GB
_meta_destroy_queue.append(i)
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
if current_loaded_models[i].model is not None:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
# --- Phase 2: Deferred meta-device destruction -------------------------
# Move parameters of queued models to the 'meta' device. This replaces
# every nn.Parameter with a zero-storage meta tensor, releasing physical
# RAM on unified-memory systems (Apple Silicon). The operation is
# deferred until *after* the unload loop to avoid invalidating weakrefs
# that other iterations may still reference.
for i in _meta_destroy_queue:
try:
model_ref = current_loaded_models[i].model
if model_ref is None:
continue
inner_model = model_ref.model
model_name = inner_model.__class__.__name__
param_count = sum(p.numel() * p.element_size() for p in inner_model.parameters())
inner_model.to(device="meta")
logging.info(f"[aggressive-offload] Moved {model_name} params to meta device, freed {param_count / (1024**2):.0f} MB")
except Exception as e:
logging.warning(f"[aggressive-offload] Failed to move model to meta: {e}")
# --- Phase 3: Notify lifecycle listeners --------------------------------
# Fire on_model_destroyed callbacks *once* after the entire batch has been
# processed, not per-model. This lets the execution engine clear its
# output cache in a single operation (see PromptExecutor.__init__).
if _meta_destroy_queue and _on_model_destroyed_callbacks:
for cb in _on_model_destroyed_callbacks:
cb("batch")
logging.info(f"[aggressive-offload] Invalidated execution cache after destroying {len(_meta_destroy_queue)} model(s)")
for x in can_unload_sorted:
i = x[-1]
# Guard: weakref may be dead after cache invalidation (meta device move)
if current_loaded_models[i].model is None:
continue
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
@ -702,6 +818,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0:
soft_empty_cache()
if AGGRESSIVE_OFFLOAD:
gc.collect() # Force Python GC to release model tensors
soft_empty_cache() # Second pass to free MPS allocator cache
elif device is not None:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)

View File

@ -748,6 +748,18 @@ class KSAMPLER(Sampler):
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
# On Apple Silicon MPS, flush the allocator pool between steps to prevent
# progressive memory fragmentation and swap thrashing. Wrapping the callback
# here (rather than patching individual samplers) covers all sampler variants.
import comfy.model_management
if noise.device.type == "mps" and getattr(comfy.model_management, "AGGRESSIVE_OFFLOAD", False):
_inner_callback = k_callback
def _mps_flush_callback(x):
if _inner_callback is not None:
_inner_callback(x)
torch.mps.empty_cache()
k_callback = _mps_flush_callback
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples

View File

@ -191,6 +191,17 @@ class BasicCache:
for key in to_remove:
del self.subcaches[key]
def clear_all(self):
"""Drop all cached outputs unconditionally.
This is the public API for external subsystems (e.g. aggressive model
offloading) that need to invalidate every cached result for instance
after model parameters have been moved to the ``meta`` device and the
cached tensors are no longer usable.
"""
self.cache.clear()
self.subcaches.clear()
def clean_unused(self):
assert self.initialized
self._clean_cache()
@ -418,6 +429,10 @@ class NullCache:
def clean_unused(self):
pass
def clear_all(self):
"""No-op: null backend has nothing to invalidate."""
pass
def poll(self, **kwargs):
pass
@ -451,6 +466,13 @@ class LRUCache(BasicCache):
for node_id in node_ids:
self._mark_used(node_id)
def clear_all(self):
"""Drop all cached outputs and reset LRU bookkeeping."""
super().clear_all()
self.used_generation.clear()
self.children.clear()
self.min_generation = 0
def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
@ -509,6 +531,11 @@ class RAMPressureCache(LRUCache):
super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {}
def clear_all(self):
"""Drop all cached outputs and reset RAM-pressure bookkeeping."""
super().clear_all()
self.timestamps.clear()
def clean_unused(self):
self._clean_subcaches()

View File

@ -651,6 +651,17 @@ class PromptExecutor:
self.cache_type = cache_type
self.server = server
self.reset()
# Register callback so model_management can invalidate cached outputs
# after destroying a model via meta device move (aggressive offload).
# NOTE: self.caches is resolved at call time (not capture time), so this
# callback remains valid even if reset() replaces self.caches later.
import comfy.model_management as mm
if mm.AGGRESSIVE_OFFLOAD:
executor = self
def _invalidate_cache(reason):
logging.info(f"[aggressive-offload] Invalidating execution cache ({reason})")
executor.caches.outputs.clear_all()
mm.register_model_destroyed_callback(_invalidate_cache)
def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)

View File

@ -0,0 +1,311 @@
"""Tests for the aggressive-offload memory management feature.
These tests validate the Apple Silicon (MPS) memory optimisation path without
requiring a GPU or actual model weights. Every test mocks the relevant model
and cache structures so the suite can run in CI on any platform.
"""
import pytest
import types
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Fixtures & helpers
# ---------------------------------------------------------------------------
class FakeLinearModel(nn.Module):
"""Minimal nn.Module whose parameters consume measurable memory."""
def __init__(self, size_mb: float = 2.0):
"""Create a model with approximately ``size_mb`` MB of float32 params."""
super().__init__()
# Each float32 param = 4 bytes, so `n` params ≈ size_mb * 1024² / 4
n = int(size_mb * 1024 * 1024 / 4)
self.weight = nn.Parameter(torch.zeros(n, dtype=torch.float32))
class FakeModelPatcher:
"""Mimics the subset of ModelPatcher used by model_management.free_memory."""
def __init__(self, size_mb: float = 2.0):
"""Create a patcher wrapping a ``FakeLinearModel`` of the given size."""
self.model = FakeLinearModel(size_mb)
self._loaded_size = int(size_mb * 1024 * 1024)
def loaded_size(self):
"""Return reported loaded size in bytes."""
return self._loaded_size
def is_dynamic(self):
"""Static model — never dynamically offloaded."""
return False
class FakeLoadedModel:
"""Mimics LoadedModel entries in current_loaded_models."""
def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
"""Wrap a ``FakeModelPatcher`` as a loaded-model entry."""
self._model = patcher
self.currently_used = currently_used
@property
def model(self):
"""Return the underlying model patcher."""
return self._model
def model_memory(self):
"""Return memory footprint in bytes."""
return self._model.loaded_size()
def model_unload(self, _memory_to_free):
"""Simulate successful unload."""
return True
def model_load(self, _device, _keep_loaded):
"""No-op load for testing."""
pass
# ---------------------------------------------------------------------------
# 1. BasicCache.clear_all()
# ---------------------------------------------------------------------------
class TestBasicCacheClearAll:
"""Verify that BasicCache.clear_all() is a proper public API."""
def test_clear_all_empties_cache_and_subcaches(self):
"""clear_all() must remove every entry in both dicts."""
from comfy_execution.caching import BasicCache, CacheKeySetInputSignature
cache = BasicCache(CacheKeySetInputSignature)
cache.cache["key1"] = "value1"
cache.cache["key2"] = "value2"
cache.subcaches["sub1"] = "subvalue1"
cache.clear_all()
assert len(cache.cache) == 0
assert len(cache.subcaches) == 0
def test_clear_all_is_idempotent(self):
"""Calling clear_all() on an already-empty cache must not raise."""
from comfy_execution.caching import BasicCache, CacheKeySetInputSignature
cache = BasicCache(CacheKeySetInputSignature)
cache.clear_all() # should be a no-op
cache.clear_all() # still a no-op
assert len(cache.cache) == 0
def test_null_cache_clear_all_is_noop(self):
"""NullCache.clear_all() must not raise — it's the null backend."""
from comfy_execution.caching import NullCache
null = NullCache()
null.clear_all() # must not raise AttributeError
def test_lru_cache_clear_all_resets_metadata(self):
"""LRUCache.clear_all() must also reset used_generation, children, min_generation."""
from comfy_execution.caching import LRUCache, CacheKeySetInputSignature
cache = LRUCache(CacheKeySetInputSignature, max_size=10)
# Simulate some entries
cache.cache["k1"] = "v1"
cache.used_generation["k1"] = 5
cache.children["k1"] = ["child1"]
cache.min_generation = 3
cache.generation = 5
cache.clear_all()
assert len(cache.cache) == 0
assert len(cache.used_generation) == 0
assert len(cache.children) == 0
assert cache.min_generation == 0
# generation counter should NOT be reset (it's a monotonic counter)
assert cache.generation == 5
def test_ram_pressure_cache_clear_all_resets_timestamps(self):
"""RAMPressureCache.clear_all() must also reset timestamps."""
from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature
cache = RAMPressureCache(CacheKeySetInputSignature)
cache.cache["k1"] = "v1"
cache.used_generation["k1"] = 2
cache.timestamps["k1"] = 1234567890.0
cache.clear_all()
assert len(cache.cache) == 0
assert len(cache.used_generation) == 0
assert len(cache.timestamps) == 0
# ---------------------------------------------------------------------------
# 2. Callback registration & dispatch
# ---------------------------------------------------------------------------
class TestModelDestroyedCallbacks:
"""Validate the on_model_destroyed lifecycle callback system."""
def setup_method(self):
"""Save and reset the callback list before every test."""
import comfy.model_management as mm
self._original = mm._on_model_destroyed_callbacks.copy()
mm._on_model_destroyed_callbacks.clear()
def teardown_method(self):
"""Restore the original callback list after every test."""
import comfy.model_management as mm
mm._on_model_destroyed_callbacks.clear()
mm._on_model_destroyed_callbacks.extend(self._original)
def test_register_single_callback(self):
"""A single registered callback must be stored and callable."""
import comfy.model_management as mm
invocations = []
mm.register_model_destroyed_callback(lambda reason: invocations.append(reason))
assert len(mm._on_model_destroyed_callbacks) == 1
# Simulate dispatch
for cb in mm._on_model_destroyed_callbacks:
cb("test")
assert invocations == ["test"]
def test_register_multiple_callbacks(self):
"""Multiple registrants must all fire — no silent overwrites."""
import comfy.model_management as mm
results_a, results_b = [], []
mm.register_model_destroyed_callback(lambda r: results_a.append(r))
mm.register_model_destroyed_callback(lambda r: results_b.append(r))
for cb in mm._on_model_destroyed_callbacks:
cb("batch")
assert results_a == ["batch"]
assert results_b == ["batch"]
def test_callback_receives_reason_string(self):
"""The callback signature is (reason: str) -> None."""
import comfy.model_management as mm
captured = {}
def _cb(reason):
captured["reason"] = reason
captured["type"] = type(reason).__name__
mm.register_model_destroyed_callback(_cb)
for cb in mm._on_model_destroyed_callbacks:
cb("batch")
assert captured["reason"] == "batch"
assert captured["type"] == "str"
# ---------------------------------------------------------------------------
# 3. Meta-device destruction threshold
# ---------------------------------------------------------------------------
class TestMetaDeviceThreshold:
"""Verify that only models > 1 GB are queued for meta-device destruction."""
def test_small_model_not_destroyed(self):
"""A 160 MB model (VAE-sized) must NOT be moved to meta device."""
model = FakeLinearModel(size_mb=160)
# Simulate the threshold check from free_memory
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
threshold = 1024 * 1024 * 1024 # 1 GB
assert model_size < threshold, (
f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
)
# Confirm parameters are still on a real device
assert model.weight.device.type != "meta"
def test_large_model_above_threshold(self):
"""A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold."""
# Use a meta-device tensor to avoid allocating 2 GB of real memory.
# Meta tensors report correct numel/element_size but use zero storage.
n = int(2048 * 1024 * 1024 / 4) # 2 GB in float32 params
meta_weight = torch.empty(n, dtype=torch.float32, device="meta")
model_size = meta_weight.numel() * meta_weight.element_size()
threshold = 1024 * 1024 * 1024 # 1 GB
assert model_size > threshold, (
f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
)
def test_meta_device_move_releases_storage(self):
"""Moving parameters to 'meta' must place them on the meta device."""
model = FakeLinearModel(size_mb=2)
assert model.weight.device.type != "meta"
model.to(device="meta")
assert model.weight.device.type == "meta"
# Meta tensors retain their logical shape but live on a virtual device
# with no physical backing — this is what releases RAM.
assert model.weight.nelement() > 0 # still has logical shape
assert model.weight.untyped_storage().device.type == "meta"
# ---------------------------------------------------------------------------
# 4. MPS flush conditionality
# ---------------------------------------------------------------------------
class TestMpsFlushConditionality:
"""Verify the MPS flush only activates under correct conditions."""
def test_flush_requires_aggressive_offload_flag(self):
"""The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD."""
import comfy.model_management as mm
# When False, flush should NOT be injected
original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
try:
mm.AGGRESSIVE_OFFLOAD = False
assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
mm.AGGRESSIVE_OFFLOAD = True
assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
finally:
mm.AGGRESSIVE_OFFLOAD = original
def test_flush_requires_mps_device(self):
"""The flush must only activate on MPS devices, not CPU or CUDA."""
# Simulate CPU device — flush should not activate
cpu_device = torch.device("cpu")
assert cpu_device.type != "mps"
# Simulate MPS device string check
if torch.backends.mps.is_available():
mps_device = torch.device("mps")
assert mps_device.type == "mps"
# ---------------------------------------------------------------------------
# 5. AGGRESSIVE_OFFLOAD flag integration
# ---------------------------------------------------------------------------
class TestAggressiveOffloadFlag:
"""Verify the CLI flag is correctly exposed."""
def test_flag_exists_in_model_management(self):
"""AGGRESSIVE_OFFLOAD must be importable from model_management."""
import comfy.model_management as mm
assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
assert isinstance(mm.AGGRESSIVE_OFFLOAD, bool)
def test_flag_defaults_from_cli_args(self):
"""The flag should be wired from cli_args to model_management."""
import comfy.cli_args as cli_args
import comfy.model_management as mm
assert hasattr(cli_args.args, "aggressive_offload")
assert mm.AGGRESSIVE_OFFLOAD == cli_args.args.aggressive_offload