# mirror of https://github.com/comfyanonymous/ComfyUI.git
# synced 2026-04-15 13:02:35 +08:00
"""Tests for the aggressive-offload memory management feature.
|
|
|
|
These tests validate the Apple Silicon (MPS) memory optimisation path without
|
|
requiring a GPU or actual model weights. Every test mocks the relevant model
|
|
and cache structures so the suite can run in CI on any platform.
|
|
"""
|
|
|
|
import pytest
import types

import torch
import torch.nn as nn
|
|
|
|
# ---------------------------------------------------------------------------
# Fixtures & helpers
# ---------------------------------------------------------------------------
|
|
|
|
class FakeLinearModel(nn.Module):
    """Minimal nn.Module whose parameters consume measurable memory."""

    def __init__(self, size_mb: float = 2.0):
        super().__init__()
        # A float32 element occupies 4 bytes, so size_mb megabytes hold
        # size_mb * 1024**2 / 4 parameters.
        param_count = int(size_mb * 1024 * 1024 / 4)
        self.weight = nn.Parameter(torch.zeros(param_count, dtype=torch.float32))
|
|
|
|
|
|
class FakeModelPatcher:
    """Mimics the subset of ModelPatcher used by model_management.free_memory."""

    def __init__(self, size_mb: float = 2.0):
        self.model = FakeLinearModel(size_mb)
        # Cache the byte count up front so loaded_size() never has to
        # inspect the tensor.
        self._loaded_size = int(size_mb * 1024 * 1024)

    def loaded_size(self):
        """Report how many bytes the fake model occupies."""
        return self._loaded_size

    def is_dynamic(self):
        """This fake is always static (never dynamically sized)."""
        return False
|
|
|
|
|
|
class FakeLoadedModel:
    """Mimics LoadedModel entries in current_loaded_models."""

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        # The wrapped patcher; exposed read-only via the `model` property.
        self._model = patcher
        # Mirrors LoadedModel.currently_used so eviction-order tests can flag
        # a model as in-use.
        self.currently_used = currently_used

    @property
    def model(self):
        """Return the wrapped FakeModelPatcher (read-only)."""
        return self._model

    def model_memory(self):
        """Delegate the byte count to the wrapped patcher's loaded_size()."""
        return self._model.loaded_size()

    def model_unload(self, _memory_to_free):
        """Pretend the unload always succeeds; the argument is ignored."""
        return True

    def model_load(self, _device, _keep_loaded):
        """Loading is a no-op for this fake; both arguments are ignored."""
        pass
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 1. BasicCache.clear_all()
# ---------------------------------------------------------------------------
|
|
|
|
class TestBasicCacheClearAll:
    """Verify that BasicCache.clear_all() is a proper public API."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """clear_all() must remove every entry in both dicts."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        basic.cache["key1"] = "value1"
        basic.cache["key2"] = "value2"
        basic.subcaches["sub1"] = "subvalue1"

        basic.clear_all()

        assert not basic.cache
        assert not basic.subcaches

    def test_clear_all_is_idempotent(self):
        """Calling clear_all() on an already-empty cache must not raise."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        basic = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):  # repeated calls must both be no-ops
            basic.clear_all()

        assert not basic.cache

    def test_null_cache_clear_all_is_noop(self):
        """NullCache.clear_all() must not raise — it's the null backend."""
        from comfy_execution.caching import NullCache

        NullCache().clear_all()  # must not raise AttributeError

    def test_lru_cache_clear_all_resets_metadata(self):
        """LRUCache.clear_all() must also reset used_generation, children, min_generation."""
        from comfy_execution.caching import LRUCache, CacheKeySetInputSignature

        lru = LRUCache(CacheKeySetInputSignature, max_size=10)
        # Seed every bookkeeping structure so the wipe is fully observable.
        lru.cache["k1"] = "v1"
        lru.used_generation["k1"] = 5
        lru.children["k1"] = ["child1"]
        lru.min_generation = 3
        lru.generation = 5

        lru.clear_all()

        assert not lru.cache
        assert not lru.used_generation
        assert not lru.children
        assert lru.min_generation == 0
        # The generation counter is monotonic and must survive a clear.
        assert lru.generation == 5

    def test_ram_pressure_cache_clear_all_resets_timestamps(self):
        """RAMPressureCache.clear_all() must also reset timestamps."""
        from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature

        ram = RAMPressureCache(CacheKeySetInputSignature)
        ram.cache["k1"] = "v1"
        ram.used_generation["k1"] = 2
        ram.timestamps["k1"] = 1234567890.0

        ram.clear_all()

        assert not ram.cache
        assert not ram.used_generation
        assert not ram.timestamps
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 2. Callback registration & dispatch
# ---------------------------------------------------------------------------
|
|
|
|
class TestModelDestroyedCallbacks:
    """Validate the on_model_destroyed lifecycle callback system."""

    def setup_method(self):
        """Snapshot and then empty the callback list before every test."""
        import comfy.model_management as mm
        self._original = mm._on_model_destroyed_callbacks.copy()
        mm._on_model_destroyed_callbacks.clear()

    def teardown_method(self):
        """Restore the callback list captured in setup_method."""
        import comfy.model_management as mm
        mm._on_model_destroyed_callbacks.clear()
        mm._on_model_destroyed_callbacks.extend(self._original)

    @staticmethod
    def _dispatch(reason):
        """Fire every registered callback with *reason* (simulated dispatch)."""
        import comfy.model_management as mm
        for callback in mm._on_model_destroyed_callbacks:
            callback(reason)

    def test_register_single_callback(self):
        import comfy.model_management as mm

        seen = []
        mm.register_model_destroyed_callback(seen.append)

        assert len(mm._on_model_destroyed_callbacks) == 1

        self._dispatch("test")
        assert seen == ["test"]

    def test_register_multiple_callbacks(self):
        """Multiple registrants must all fire — no silent overwrites."""
        import comfy.model_management as mm

        first, second = [], []
        mm.register_model_destroyed_callback(first.append)
        mm.register_model_destroyed_callback(second.append)

        self._dispatch("batch")

        assert first == ["batch"]
        assert second == ["batch"]

    def test_callback_receives_reason_string(self):
        """The callback signature is (reason: str) -> None."""
        import comfy.model_management as mm

        seen = {}

        def record(reason):
            seen["reason"] = reason
            seen["type"] = type(reason).__name__

        mm.register_model_destroyed_callback(record)
        self._dispatch("batch")

        assert seen["reason"] == "batch"
        assert seen["type"] == "str"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 3. Meta-device destruction threshold
# ---------------------------------------------------------------------------
|
|
|
|
class TestMetaDeviceThreshold:
|
|
"""Verify that only models > 1 GB are queued for meta-device destruction."""
|
|
|
|
def test_small_model_not_destroyed(self):
|
|
"""A 160 MB model (VAE-sized) must NOT be moved to meta device."""
|
|
model = FakeLinearModel(size_mb=160)
|
|
|
|
# Simulate the threshold check from free_memory
|
|
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
|
|
threshold = 1024 * 1024 * 1024 # 1 GB
|
|
|
|
assert model_size < threshold, (
|
|
f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
|
|
)
|
|
# Confirm parameters are still on a real device
|
|
assert model.weight.device.type != "meta"
|
|
|
|
def test_large_model_above_threshold(self):
|
|
"""A 2 GB model (UNET/CLIP-sized) must BE above the destruction threshold."""
|
|
# Use a meta-device tensor to avoid allocating 2 GB of real memory.
|
|
# Meta tensors report correct numel/element_size but use zero storage.
|
|
n = int(2048 * 1024 * 1024 / 4) # 2 GB in float32 params
|
|
meta_weight = torch.empty(n, dtype=torch.float32, device="meta")
|
|
|
|
model_size = meta_weight.numel() * meta_weight.element_size()
|
|
threshold = 1024 * 1024 * 1024 # 1 GB
|
|
|
|
assert model_size > threshold, (
|
|
f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
|
|
)
|
|
|
|
def test_meta_device_move_releases_storage(self):
|
|
"""Moving parameters to 'meta' must place them on the meta device."""
|
|
model = FakeLinearModel(size_mb=2)
|
|
assert model.weight.device.type != "meta"
|
|
|
|
model.to(device="meta")
|
|
|
|
assert model.weight.device.type == "meta"
|
|
# Meta tensors retain their logical shape but live on a virtual device
|
|
# with no physical backing — this is what releases RAM.
|
|
assert model.weight.nelement() > 0 # still has logical shape
|
|
assert model.weight.untyped_storage().device.type == "meta"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 4. MPS flush conditionality
# ---------------------------------------------------------------------------
|
|
|
|
class TestMpsFlushConditionality:
|
|
"""Verify the MPS flush only activates under correct conditions."""
|
|
|
|
def test_flush_requires_aggressive_offload_flag(self):
|
|
"""The MPS flush in samplers is gated on AGGRESSIVE_OFFLOAD."""
|
|
import comfy.model_management as mm
|
|
|
|
# When False, flush should NOT be injected
|
|
original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
|
|
try:
|
|
mm.AGGRESSIVE_OFFLOAD = False
|
|
assert not (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
|
|
|
|
mm.AGGRESSIVE_OFFLOAD = True
|
|
assert (True and getattr(mm, "AGGRESSIVE_OFFLOAD", False))
|
|
finally:
|
|
mm.AGGRESSIVE_OFFLOAD = original
|
|
|
|
def test_flush_requires_mps_device(self):
|
|
"""The flush condition checks device.type == 'mps'."""
|
|
# Simulate CPU device — flush should not activate
|
|
cpu_device = torch.device("cpu")
|
|
assert cpu_device.type != "mps"
|
|
|
|
# Simulate MPS device string check
|
|
if torch.backends.mps.is_available():
|
|
mps_device = torch.device("mps")
|
|
assert mps_device.type == "mps"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 5. AGGRESSIVE_OFFLOAD flag integration
# ---------------------------------------------------------------------------
|
|
|
|
class TestAggressiveOffloadFlag:
    """Verify the CLI flag is correctly exposed."""

    def test_flag_exists_in_model_management(self):
        """AGGRESSIVE_OFFLOAD must be importable from model_management."""
        import comfy.model_management as mm

        # The attribute must both exist and be a plain boolean toggle.
        assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
        assert isinstance(mm.AGGRESSIVE_OFFLOAD, bool)

    def test_flag_defaults_from_cli_args(self):
        """The flag should be wired from cli_args to model_management."""
        import comfy.model_management as mm
        import comfy.cli_args as cli_args

        # The argparse namespace defines the flag...
        assert hasattr(cli_args.args, "aggressive_offload")
        # ...and model_management mirrors its value verbatim.
        assert mm.AGGRESSIVE_OFFLOAD == cli_args.args.aggressive_offload
|