# ComfyUI/tests-unit/test_aggressive_offload.py
#
# 302 lines
# 11 KiB
# Python

"""Tests for the aggressive-offload memory management feature.
These tests validate the Apple Silicon (MPS) memory optimisation path without
requiring a GPU or actual model weights. Every test mocks the relevant model
and cache structures so the suite can run in CI on any platform.
"""
import pytest
import types
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Fixtures & helpers
# ---------------------------------------------------------------------------
class FakeLinearModel(nn.Module):
    """Tiny ``nn.Module`` holding a single flat float32 parameter.

    Args:
        size_mb: Approximate parameter payload in MiB. The element count is
            ``int(size_mb * 1024 * 1024 / 4)`` because each float32 element
            occupies 4 bytes.
    """

    def __init__(self, size_mb: float = 2.0):
        super().__init__()
        total_bytes = size_mb * 1024 * 1024
        # float32 elements are 4 bytes wide, hence the division.
        element_count = int(total_bytes / 4)
        flat_zeros = torch.zeros(element_count, dtype=torch.float32)
        self.weight = nn.Parameter(flat_zeros)
class FakeModelPatcher:
    """Stand-in for the part of ModelPatcher that model_management.free_memory uses.

    Args:
        size_mb: Size in MiB, used both to build the wrapped ``FakeLinearModel``
            and to derive the byte count returned by ``loaded_size()``.
    """

    def __init__(self, size_mb: float = 2.0):
        bytes_per_mib = 1024 * 1024
        self.model = FakeLinearModel(size_mb)
        self._loaded_size = int(size_mb * bytes_per_mib)

    def loaded_size(self):
        """Return the byte count computed at construction time."""
        return self._loaded_size

    def is_dynamic(self):
        """Always report a non-dynamic model."""
        return False
class FakeLoadedModel:
    """Lightweight double for the LoadedModel entries in current_loaded_models.

    Args:
        patcher: The fake patcher wrapped by this entry.
        currently_used: Keyword-only flag stored unchanged on the instance.
    """

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        self._model = patcher
        self.currently_used = currently_used

    @property
    def model(self):
        """Return the wrapped patcher (read-only accessor)."""
        return self._model

    def model_memory(self):
        """Return whatever the wrapped patcher's ``loaded_size()`` reports."""
        patcher = self._model
        return patcher.loaded_size()

    def model_unload(self, _memory_to_free):
        """Always report success; the argument is ignored."""
        return True

    def model_load(self, _device, _keep_loaded):
        """Do nothing; both arguments are ignored."""
        return None
# ---------------------------------------------------------------------------
# 1. BasicCache.clear_all()
# ---------------------------------------------------------------------------
class TestBasicCacheClearAll:
    """Exercise the public ``clear_all()`` API on each cache backend."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """Both ``cache`` and ``subcaches`` must be empty after clear_all()."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        for key, value in (("key1", "value1"), ("key2", "value2")):
            basic.cache[key] = value
        basic.subcaches["sub1"] = "subvalue1"

        basic.clear_all()

        for store in (basic.cache, basic.subcaches):
            assert len(store) == 0

    def test_clear_all_is_idempotent(self):
        """Repeated clear_all() calls on an empty cache must not raise."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):
            basic.clear_all()  # expected to be a no-op each time

        assert len(basic.cache) == 0

    def test_null_cache_clear_all_is_noop(self):
        """NullCache, the null backend, must accept clear_all() without raising."""
        from comfy_execution.caching import NullCache

        backend = NullCache()
        backend.clear_all()  # must not raise AttributeError

    def test_lru_cache_clear_all_resets_metadata(self):
        """LRUCache.clear_all() must also reset used_generation, children, min_generation."""
        from comfy_execution.caching import LRUCache, CacheKeySetInputSignature

        lru = LRUCache(CacheKeySetInputSignature, max_size=10)
        # Seed one entry plus the bookkeeping that goes with it.
        lru.cache["k1"] = "v1"
        lru.used_generation["k1"] = 5
        lru.children["k1"] = ["child1"]
        lru.min_generation = 3
        lru.generation = 5

        lru.clear_all()

        for store in (lru.cache, lru.used_generation, lru.children):
            assert len(store) == 0
        assert lru.min_generation == 0
        # The generation counter is monotonic, so it must survive the clear.
        assert lru.generation == 5

    def test_ram_pressure_cache_clear_all_resets_timestamps(self):
        """RAMPressureCache.clear_all() must also drop its timestamps."""
        from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature

        pressure = RAMPressureCache(CacheKeySetInputSignature)
        pressure.cache["k1"] = "v1"
        pressure.used_generation["k1"] = 2
        pressure.timestamps["k1"] = 1234567890.0

        pressure.clear_all()

        for store in (pressure.cache, pressure.used_generation, pressure.timestamps):
            assert len(store) == 0
# ---------------------------------------------------------------------------
# 2. Callback registration & dispatch
# ---------------------------------------------------------------------------
class TestModelDestroyedCallbacks:
    """Cover registration and hand-driven dispatch of model-destroyed callbacks."""

    def setup_method(self):
        """Stash the currently registered callbacks and start from an empty list."""
        import comfy.model_management as mm

        registry = mm._on_model_destroyed_callbacks
        self._original = registry.copy()
        registry.clear()

    def teardown_method(self):
        """Put back the callbacks that were registered before the test ran."""
        import comfy.model_management as mm

        registry = mm._on_model_destroyed_callbacks
        registry.clear()
        registry.extend(self._original)

    def test_register_single_callback(self):
        """One registration yields one entry, and that entry is callable."""
        import comfy.model_management as mm

        seen = []

        def record(reason):
            seen.append(reason)

        mm.register_model_destroyed_callback(record)
        assert len(mm._on_model_destroyed_callbacks) == 1

        # Simulate dispatch by invoking every registered callback by hand.
        for callback in mm._on_model_destroyed_callbacks:
            callback("test")
        assert seen == ["test"]

    def test_register_multiple_callbacks(self):
        """Every registrant must fire; a later one must not replace an earlier one."""
        import comfy.model_management as mm

        first_hits = []
        second_hits = []

        def record_first(reason):
            first_hits.append(reason)

        def record_second(reason):
            second_hits.append(reason)

        mm.register_model_destroyed_callback(record_first)
        mm.register_model_destroyed_callback(record_second)

        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert first_hits == ["batch"]
        assert second_hits == ["batch"]

    def test_callback_receives_reason_string(self):
        """A callback is invoked as ``cb(reason)`` with ``reason`` being a str."""
        import comfy.model_management as mm

        captured = {}

        def _cb(reason):
            captured.update(reason=reason, type=type(reason).__name__)

        mm.register_model_destroyed_callback(_cb)
        for callback in mm._on_model_destroyed_callbacks:
            callback("batch")

        assert captured["reason"] == "batch"
        assert captured["type"] == "str"
# ---------------------------------------------------------------------------
# 3. Meta-device destruction threshold
# ---------------------------------------------------------------------------
class TestMetaDeviceThreshold:
    """Size checks mirroring the 1 GB cutoff for meta-device destruction."""

    def test_small_model_not_destroyed(self):
        """A 160 MB (VAE-sized) model sits below the cutoff and stays off 'meta'."""
        model = FakeLinearModel(size_mb=160)

        # Re-create the threshold comparison performed in free_memory.
        one_gb = 1024 ** 3
        model_size = 0
        for param in model.parameters():
            model_size += param.numel() * param.element_size()

        assert model_size < one_gb, (
            f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )
        # The parameter must still live on a real (non-meta) device.
        assert model.weight.device.type != "meta"

    def test_large_model_above_threshold(self):
        """A 2 GB (UNET/CLIP-sized) payload must exceed the destruction cutoff."""
        # Build the tensor on the meta device so no 2 GB buffer is allocated;
        # numel() and element_size() still report the logical size.
        element_count = (2048 * 1024 * 1024) // 4  # 2 GB worth of float32
        meta_weight = torch.empty(element_count, dtype=torch.float32, device="meta")

        one_gb = 1024 ** 3
        model_size = meta_weight.element_size() * meta_weight.numel()

        assert model_size > one_gb, (
            f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )

    def test_meta_device_move_releases_storage(self):
        """After ``.to(device="meta")`` the parameter and its storage are on 'meta'."""
        model = FakeLinearModel(size_mb=2)
        assert model.weight.device.type != "meta"

        model.to(device="meta")

        weight = model.weight
        assert weight.device.type == "meta"
        # The logical shape survives the move even though the tensor now lives
        # on a virtual device with no physical backing.
        assert weight.numel() > 0
        assert weight.untyped_storage().device.type == "meta"
# ---------------------------------------------------------------------------
# 4. MPS flush conditionality
# ---------------------------------------------------------------------------
class TestMpsFlushConditionality:
    """Verify the conditions under which the MPS flush is meant to activate."""

    def test_flush_requires_aggressive_offload_flag(self, monkeypatch):
        """The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD.

        The flag is toggled through pytest's ``monkeypatch`` fixture so that
        the module is restored exactly afterwards. ``raising=False`` means an
        attribute that did not exist before the test is removed again rather
        than being left behind with a default value.

        NOTE(review): this only checks that the flag can be toggled and read
        back; it does not drive the sampler code path itself.
        """
        import comfy.model_management as mm

        # Flag off: the gate expression must evaluate falsy.
        monkeypatch.setattr(mm, "AGGRESSIVE_OFFLOAD", False, raising=False)
        assert not getattr(mm, "AGGRESSIVE_OFFLOAD", False)

        # Flag on: the gate expression must evaluate truthy.
        monkeypatch.setattr(mm, "AGGRESSIVE_OFFLOAD", True, raising=False)
        assert getattr(mm, "AGGRESSIVE_OFFLOAD", False)

    def test_flush_requires_mps_device(self):
        """The flush condition checks ``device.type == 'mps'``."""
        # A CPU device must never satisfy the MPS check.
        cpu_device = torch.device("cpu")
        assert cpu_device.type != "mps"

        # Only build an MPS device where the backend reports itself available.
        if torch.backends.mps.is_available():
            mps_device = torch.device("mps")
            assert mps_device.type == "mps"
# ---------------------------------------------------------------------------
# 5. AGGRESSIVE_OFFLOAD flag integration
# ---------------------------------------------------------------------------
class TestAggressiveOffloadFlag:
    """Check that the aggressive-offload CLI flag reaches model_management."""

    def test_flag_exists_in_model_management(self):
        """model_management must expose AGGRESSIVE_OFFLOAD as a bool."""
        import comfy.model_management as mm

        assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
        flag = mm.AGGRESSIVE_OFFLOAD
        assert isinstance(flag, bool)

    def test_flag_defaults_from_cli_args(self):
        """The module-level flag must equal the parsed ``aggressive_offload`` argument."""
        import comfy.cli_args as cli_args
        import comfy.model_management as mm

        parsed = cli_args.args
        assert hasattr(parsed, "aggressive_offload")
        assert mm.AGGRESSIVE_OFFLOAD == parsed.aggressive_offload