mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-26 00:47:24 +08:00
Merge 0bd2a353bf into 3e3ed8cc2a
This commit is contained in:
commit
352e5cc190
@ -157,6 +157,7 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com
|
|||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
|
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||||
|
parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Moves all models larger than 1 GB to a virtual (meta) device between runs, preventing swap pressure on disk. Small models like the VAE are preserved. Trade-off: models are reloaded from disk on subsequent generations.")
|
||||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||||
|
|
||||||
class PerformanceFeature(enum.Enum):
|
class PerformanceFeature(enum.Enum):
|
||||||
|
|||||||
@ -461,10 +461,51 @@ if cpu_state == CPUState.MPS:
|
|||||||
logging.info(f"Set vram state to: {vram_state.name}")
|
logging.info(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||||
|
AGGRESSIVE_OFFLOAD = args.aggressive_offload
|
||||||
|
|
||||||
if DISABLE_SMART_MEMORY:
|
if DISABLE_SMART_MEMORY:
|
||||||
logging.info("Disabling smart memory management")
|
logging.info("Disabling smart memory management")
|
||||||
|
|
||||||
|
if AGGRESSIVE_OFFLOAD:
|
||||||
|
logging.info("Aggressive offload enabled: models will be freed from RAM after use (designed for Apple Silicon)")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model lifecycle callbacks — on_model_destroyed
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Why not comfy.hooks? The existing hook system (comfy/hooks.py) is scoped
|
||||||
|
# to *sampling conditioning* — LoRA weight injection, transformer_options,
|
||||||
|
# and keyframe scheduling. It has no concept of model-management lifecycle
|
||||||
|
# events such as "a model's parameters were deallocated".
|
||||||
|
#
|
||||||
|
# This lightweight callback list fills that gap. It is intentionally minimal
|
||||||
|
# (append-only, no priorities, no removal) because the only current consumer
|
||||||
|
# is the execution-engine cache invalidator registered in PromptExecutor.
|
||||||
|
# If upstream adopts a formal lifecycle-event bus in the future, these
|
||||||
|
# callbacks should migrate to that system.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry of listeners fired after a batch of models was moved to the
# ``meta`` device. Only ever appended to; dispatch order is registration order.
_on_model_destroyed_callbacks: list = []


def register_model_destroyed_callback(callback):
    """Register a listener for post-destruction lifecycle events.

    After ``free_memory`` moves one or more models to the ``meta`` device
    (aggressive offload), every registered callback is invoked once with a
    *reason* string describing the batch (e.g. ``"batch"``).

    Registration is idempotent: a callback that compares equal to one that
    is already registered is ignored, so it never fires twice per batch.

    Typical usage — executed by ``PromptExecutor.__init__``::

        def _invalidate(reason):
            executor.caches.outputs.clear_all()
        register_model_destroyed_callback(_invalidate)

    Args:
        callback: ``Callable[[str], None]`` — receives a human-readable
            reason string.  Must be safe to call from within the
            ``free_memory`` critical section (no heavy I/O, no model loads).

    Raises:
        TypeError: If ``callback`` is not callable.
    """
    # Fail at registration time; otherwise the error would only surface
    # later, in the middle of free_memory's dispatch loop.
    if not callable(callback):
        raise TypeError(
            f"model-destroyed callback must be callable, got {type(callback).__name__}"
        )
    # Equality (not identity) so that re-registering the same bound method
    # is also recognised as a duplicate.
    if callback in _on_model_destroyed_callbacks:
        return
    _on_model_destroyed_callbacks.append(callback)
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
@ -633,14 +674,21 @@ def offloaded_memory(loaded_models, device):
|
|||||||
WINDOWS = any(platform.win32_ver())
|
WINDOWS = any(platform.win32_ver())
|
||||||
|
|
||||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
if WINDOWS:
|
if cpu_state == CPUState.MPS and AGGRESSIVE_OFFLOAD:
|
||||||
|
# macOS with Apple Silicon + aggressive offload: shared memory means OS
|
||||||
|
# needs more headroom. Reserve 4 GB for macOS + system services to
|
||||||
|
# prevent swap thrashing during model destruction/reload cycles.
|
||||||
|
EXTRA_RESERVED_VRAM = 4 * 1024 * 1024 * 1024
|
||||||
|
logging.info("MPS detected with --aggressive-offload: reserving 4 GB for macOS system overhead")
|
||||||
|
elif WINDOWS:
|
||||||
import comfy.windows
|
import comfy.windows
|
||||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||||
def get_free_ram():
|
def get_free_ram():
|
||||||
return comfy.windows.get_free_ram()
|
return comfy.windows.get_free_ram()
|
||||||
else:
|
|
||||||
|
if not WINDOWS:
|
||||||
def get_free_ram():
|
def get_free_ram():
|
||||||
return psutil.virtual_memory().available
|
return psutil.virtual_memory().available
|
||||||
|
|
||||||
@ -663,14 +711,25 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
|
|
||||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if device is None or shift_model.device == device:
|
# On Apple Silicon SHARED mode, CPU RAM == GPU VRAM (same physical memory).
|
||||||
|
# Bypass the device filter so CPU-loaded models (like CLIP) can be freed.
|
||||||
|
device_match = (device is None or shift_model.device == device)
|
||||||
|
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
|
||||||
|
device_match = True
|
||||||
|
if device_match:
|
||||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
can_unload_sorted = sorted(can_unload)
|
can_unload_sorted = sorted(can_unload)
|
||||||
|
# Collect models to destroy via meta device AFTER the unload loop completes,
|
||||||
|
# so we don't kill weakrefs of models still being iterated.
|
||||||
|
_meta_destroy_queue = []
|
||||||
for x in can_unload_sorted:
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
|
# Guard: weakref may already be dead from a previous iteration
|
||||||
|
if current_loaded_models[i].model is None:
|
||||||
|
continue
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
pins_to_free = 1e32
|
pins_to_free = 1e32
|
||||||
if not DISABLE_SMART_MEMORY or device is None:
|
if not DISABLE_SMART_MEMORY or device is None:
|
||||||
@ -681,15 +740,72 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
#as that works on-demand.
|
#as that works on-demand.
|
||||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||||
memory_to_free = 0
|
memory_to_free = 0
|
||||||
|
|
||||||
|
# Aggressive offload for Apple Silicon: force-unload unused models
|
||||||
|
# regardless of free memory, since CPU RAM == GPU VRAM.
|
||||||
|
# Only force-unload models > 1 GB — small models like the VAE (160 MB)
|
||||||
|
# are preserved to avoid unnecessary reload from disk.
|
||||||
|
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
|
||||||
|
model_ref = current_loaded_models[i].model
|
||||||
|
if model_ref is not None and not current_loaded_models[i].currently_used:
|
||||||
|
model_size = current_loaded_models[i].model_memory()
|
||||||
|
if model_size > 1024 * 1024 * 1024: # 1 GB threshold
|
||||||
|
memory_to_free = 1e32 # Force unload
|
||||||
|
inner = getattr(model_ref, "model", None)
|
||||||
|
model_name = inner.__class__.__name__ if inner is not None else "unknown"
|
||||||
|
model_size_mb = model_size / (1024 * 1024)
|
||||||
|
logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM")
|
||||||
|
|
||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
# Queue for meta device destruction after loop completes.
|
||||||
|
# Only destroy large models (>1 GB) — small models like the VAE (160 MB)
|
||||||
|
# are kept because the execution cache may reuse their patcher across
|
||||||
|
# workflow nodes (e.g. vae_loader is cached while vae_decode runs later).
|
||||||
|
if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED:
|
||||||
|
if current_loaded_models[i].model is not None:
|
||||||
|
model_size = current_loaded_models[i].model_memory()
|
||||||
|
if model_size > 1024 * 1024 * 1024: # Only meta-destroy models > 1 GB
|
||||||
|
_meta_destroy_queue.append(i)
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
if pins_to_free > 0:
|
if pins_to_free > 0:
|
||||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
if current_loaded_models[i].model is not None:
|
||||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||||
|
|
||||||
|
# --- Phase 2: Deferred meta-device destruction -------------------------
|
||||||
|
# Move parameters of queued models to the 'meta' device. This replaces
|
||||||
|
# every nn.Parameter with a zero-storage meta tensor, releasing physical
|
||||||
|
# RAM on unified-memory systems (Apple Silicon). The operation is
|
||||||
|
# deferred until *after* the unload loop to avoid invalidating weakrefs
|
||||||
|
# that other iterations may still reference.
|
||||||
|
for i in _meta_destroy_queue:
|
||||||
|
try:
|
||||||
|
model_ref = current_loaded_models[i].model
|
||||||
|
if model_ref is None:
|
||||||
|
continue
|
||||||
|
inner_model = model_ref.model
|
||||||
|
model_name = inner_model.__class__.__name__
|
||||||
|
param_count = sum(p.numel() * p.element_size() for p in inner_model.parameters())
|
||||||
|
inner_model.to(device="meta")
|
||||||
|
logging.info(f"[aggressive-offload] Moved {model_name} params to meta device, freed {param_count / (1024**2):.0f} MB")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"[aggressive-offload] Failed to move model to meta: {e}")
|
||||||
|
|
||||||
|
# --- Phase 3: Notify lifecycle listeners --------------------------------
|
||||||
|
# Fire on_model_destroyed callbacks *once* after the entire batch has been
|
||||||
|
# processed, not per-model. This lets the execution engine clear its
|
||||||
|
# output cache in a single operation (see PromptExecutor.__init__).
|
||||||
|
if _meta_destroy_queue and _on_model_destroyed_callbacks:
|
||||||
|
for cb in _on_model_destroyed_callbacks:
|
||||||
|
cb("batch")
|
||||||
|
logging.info(f"[aggressive-offload] Invalidated execution cache after destroying {len(_meta_destroy_queue)} model(s)")
|
||||||
|
|
||||||
for x in can_unload_sorted:
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
|
# Guard: weakref may be dead after cache invalidation (meta device move)
|
||||||
|
if current_loaded_models[i].model is None:
|
||||||
|
continue
|
||||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||||
if ram_to_free <= 0 and i not in unloaded_model:
|
if ram_to_free <= 0 and i not in unloaded_model:
|
||||||
continue
|
continue
|
||||||
@ -702,6 +818,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
if AGGRESSIVE_OFFLOAD:
|
||||||
|
gc.collect() # Force Python GC to release model tensors
|
||||||
|
soft_empty_cache() # Second pass to free MPS allocator cache
|
||||||
elif device is not None:
|
elif device is not None:
|
||||||
if vram_state != VRAMState.HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||||
|
|||||||
@ -748,6 +748,18 @@ class KSAMPLER(Sampler):
|
|||||||
if callback is not None:
|
if callback is not None:
|
||||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||||
|
|
||||||
|
# On Apple Silicon MPS, flush the allocator pool between steps to prevent
|
||||||
|
# progressive memory fragmentation and swap thrashing. Wrapping the callback
|
||||||
|
# here (rather than patching individual samplers) covers all sampler variants.
|
||||||
|
import comfy.model_management
|
||||||
|
if noise.device.type == "mps" and getattr(comfy.model_management, "AGGRESSIVE_OFFLOAD", False):
|
||||||
|
_inner_callback = k_callback
|
||||||
|
def _mps_flush_callback(x):
|
||||||
|
if _inner_callback is not None:
|
||||||
|
_inner_callback(x)
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
k_callback = _mps_flush_callback
|
||||||
|
|
||||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||||
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
@ -191,6 +191,17 @@ class BasicCache:
|
|||||||
for key in to_remove:
|
for key in to_remove:
|
||||||
del self.subcaches[key]
|
del self.subcaches[key]
|
||||||
|
|
||||||
|
def clear_all(self):
    """Unconditionally discard every cached output.

    Public entry point for external subsystems (for example aggressive
    model offloading) that have to invalidate all cached results, such as
    after model parameters were moved to the ``meta`` device and the cached
    tensors can no longer be used.
    """
    # Empty the top-level cache first, then the nested sub-caches.
    for attr_name in ("cache", "subcaches"):
        getattr(self, attr_name).clear()
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
assert self.initialized
|
assert self.initialized
|
||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
@ -418,6 +429,10 @@ class NullCache:
|
|||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def clear_all(self):
    """Do nothing: the null cache backend holds no entries to invalidate."""
    return None
|
||||||
|
|
||||||
def poll(self, **kwargs):
|
def poll(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -451,6 +466,13 @@ class LRUCache(BasicCache):
|
|||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
|
|
||||||
|
def clear_all(self):
    """Discard all cached outputs, then reset the LRU bookkeeping."""
    super().clear_all()
    # Per-key tracking dicts are emptied; ``self.generation`` is left alone.
    for attr_name in ("used_generation", "children"):
        getattr(self, attr_name).clear()
    self.min_generation = 0
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||||
self.min_generation += 1
|
self.min_generation += 1
|
||||||
@ -509,6 +531,11 @@ class RAMPressureCache(LRUCache):
|
|||||||
super().__init__(key_class, 0, enable_providers=enable_providers)
|
super().__init__(key_class, 0, enable_providers=enable_providers)
|
||||||
self.timestamps = {}
|
self.timestamps = {}
|
||||||
|
|
||||||
|
def clear_all(self):
    """Discard all cached outputs, then forget every recorded timestamp."""
    super().clear_all()
    recorded = self.timestamps
    recorded.clear()
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
|||||||
11
execution.py
11
execution.py
@ -651,6 +651,17 @@ class PromptExecutor:
|
|||||||
self.cache_type = cache_type
|
self.cache_type = cache_type
|
||||||
self.server = server
|
self.server = server
|
||||||
self.reset()
|
self.reset()
|
||||||
|
# Register callback so model_management can invalidate cached outputs
|
||||||
|
# after destroying a model via meta device move (aggressive offload).
|
||||||
|
# NOTE: self.caches is resolved at call time (not capture time), so this
|
||||||
|
# callback remains valid even if reset() replaces self.caches later.
|
||||||
|
import comfy.model_management as mm
|
||||||
|
if mm.AGGRESSIVE_OFFLOAD:
|
||||||
|
executor = self
|
||||||
|
def _invalidate_cache(reason):
|
||||||
|
logging.info(f"[aggressive-offload] Invalidating execution cache ({reason})")
|
||||||
|
executor.caches.outputs.clear_all()
|
||||||
|
mm.register_model_destroyed_callback(_invalidate_cache)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||||
|
|||||||
311
tests-unit/test_aggressive_offload.py
Normal file
311
tests-unit/test_aggressive_offload.py
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
"""Tests for the aggressive-offload memory management feature.
|
||||||
|
|
||||||
|
These tests validate the Apple Silicon (MPS) memory optimisation path without
|
||||||
|
requiring a GPU or actual model weights. Every test mocks the relevant model
|
||||||
|
and cache structures so the suite can run in CI on any platform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import types
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures & helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class FakeLinearModel(nn.Module):
    """Tiny ``nn.Module`` holding one parameter of a chosen byte size."""

    def __init__(self, size_mb: float = 2.0):
        """Allocate roughly ``size_mb`` MB worth of float32 parameters."""
        super().__init__()
        # float32 occupies 4 bytes, so the element count is bytes / 4.
        total_bytes = size_mb * 1024 * 1024
        element_count = int(total_bytes / 4)
        zeros = torch.zeros(element_count, dtype=torch.float32)
        self.weight = nn.Parameter(zeros)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeModelPatcher:
    """Test stand-in for a ModelPatcher, exposing only a small surface."""

    def __init__(self, size_mb: float = 2.0):
        """Wrap a fresh ``FakeLinearModel`` of ``size_mb`` megabytes."""
        self.model = FakeLinearModel(size_mb)
        bytes_per_mb = 1024 * 1024
        self._loaded_size = int(size_mb * bytes_per_mb)

    def loaded_size(self):
        """Report the loaded size, in bytes, computed at construction."""
        size_in_bytes = self._loaded_size
        return size_in_bytes

    def is_dynamic(self):
        """Always ``False``: this fake is never a dynamic model."""
        return False
|
||||||
|
|
||||||
|
|
||||||
|
class FakeLoadedModel:
    """Test double for an entry of ``current_loaded_models``."""

    def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
        """Store the wrapped patcher together with its in-use flag."""
        self._model = patcher
        self.currently_used = currently_used

    @property
    def model(self):
        """The wrapped model patcher."""
        wrapped = self._model
        return wrapped

    def model_memory(self):
        """Memory footprint in bytes, as reported by the wrapped patcher."""
        patcher = self._model
        return patcher.loaded_size()

    def model_unload(self, _memory_to_free):
        """Pretend the unload succeeded, whatever amount was requested."""
        return True

    def model_load(self, _device, _keep_loaded):
        """Loading is a no-op in this fake."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. BasicCache.clear_all()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBasicCacheClearAll:
    """Exercise the public ``clear_all()`` API on each cache backend."""

    def test_clear_all_empties_cache_and_subcaches(self):
        """Both the ``cache`` and ``subcaches`` mappings end up empty."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        store = BasicCache(CacheKeySetInputSignature)
        for key, value in (("key1", "value1"), ("key2", "value2")):
            store.cache[key] = value
        store.subcaches["sub1"] = "subvalue1"

        store.clear_all()

        assert len(store.cache) == 0
        assert len(store.subcaches) == 0

    def test_clear_all_is_idempotent(self):
        """Clearing an already-empty cache, twice, does not raise."""
        from comfy_execution.caching import BasicCache, CacheKeySetInputSignature

        store = BasicCache(CacheKeySetInputSignature)
        for _ in range(2):  # both calls are expected to be no-ops
            store.clear_all()

        assert len(store.cache) == 0

    def test_null_cache_clear_all_is_noop(self):
        """``NullCache.clear_all()`` exists and can be called safely."""
        from comfy_execution.caching import NullCache

        backend = NullCache()
        backend.clear_all()  # must not raise AttributeError

    def test_lru_cache_clear_all_resets_metadata(self):
        """LRU bookkeeping is reset while the generation counter is kept."""
        from comfy_execution.caching import LRUCache, CacheKeySetInputSignature

        lru = LRUCache(CacheKeySetInputSignature, max_size=10)
        # Populate one entry together with its tracking metadata.
        lru.cache["k1"] = "v1"
        lru.used_generation["k1"] = 5
        lru.children["k1"] = ["child1"]
        lru.min_generation = 3
        lru.generation = 5

        lru.clear_all()

        for mapping in (lru.cache, lru.used_generation, lru.children):
            assert len(mapping) == 0
        assert lru.min_generation == 0
        # The generation counter is monotonic and must survive the clear.
        assert lru.generation == 5

    def test_ram_pressure_cache_clear_all_resets_timestamps(self):
        """``RAMPressureCache.clear_all()`` also empties ``timestamps``."""
        from comfy_execution.caching import RAMPressureCache, CacheKeySetInputSignature

        pressure = RAMPressureCache(CacheKeySetInputSignature)
        pressure.cache["k1"] = "v1"
        pressure.used_generation["k1"] = 2
        pressure.timestamps["k1"] = 1234567890.0

        pressure.clear_all()

        for mapping in (pressure.cache, pressure.used_generation, pressure.timestamps):
            assert len(mapping) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. Callback registration & dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestModelDestroyedCallbacks:
    """Check registration and manual dispatch of model-destroyed callbacks."""

    def setup_method(self):
        """Snapshot the module-level callback list, then empty it."""
        import comfy.model_management as mm

        self._original = list(mm._on_model_destroyed_callbacks)
        del mm._on_model_destroyed_callbacks[:]

    def teardown_method(self):
        """Put the snapshotted callbacks back, in place."""
        import comfy.model_management as mm

        # Slice assignment keeps the same list object that the module uses.
        mm._on_model_destroyed_callbacks[:] = self._original

    def test_register_single_callback(self):
        """One registered callback is stored and can be invoked."""
        import comfy.model_management as mm

        received = []
        mm.register_model_destroyed_callback(lambda reason: received.append(reason))

        assert len(mm._on_model_destroyed_callbacks) == 1

        # Dispatch by hand, the same way the registry is iterated.
        for listener in mm._on_model_destroyed_callbacks:
            listener("test")
        assert received == ["test"]

    def test_register_multiple_callbacks(self):
        """Every registrant fires; a later one does not replace an earlier one."""
        import comfy.model_management as mm

        first_log = []
        second_log = []
        mm.register_model_destroyed_callback(lambda r: first_log.append(r))
        mm.register_model_destroyed_callback(lambda r: second_log.append(r))

        for listener in mm._on_model_destroyed_callbacks:
            listener("batch")

        assert first_log == ["batch"]
        assert second_log == ["batch"]

    def test_callback_receives_reason_string(self):
        """The callback is handed the reason as a ``str``."""
        import comfy.model_management as mm

        captured = {}

        def _record(reason):
            captured["reason"] = reason
            captured["type"] = type(reason).__name__

        mm.register_model_destroyed_callback(_record)
        for listener in mm._on_model_destroyed_callbacks:
            listener("batch")

        assert captured["reason"] == "batch"
        assert captured["type"] == "str"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. Meta-device destruction threshold
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMetaDeviceThreshold:
    """Check model sizes against the 1 GB meta-destruction threshold."""

    def test_small_model_not_destroyed(self):
        """A 160 MB (VAE-sized) model falls below the 1 GB threshold."""
        model = FakeLinearModel(size_mb=160)

        # Recompute the size from the parameters, as a byte count.
        model_size = 0
        for param in model.parameters():
            model_size += param.numel() * param.element_size()
        threshold = 1024 * 1024 * 1024  # 1 GB

        assert model_size < threshold, (
            f"160 MB model should be below 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )
        # Nothing moved the parameters, so they are still on a real device.
        assert model.weight.device.type != "meta"

    def test_large_model_above_threshold(self):
        """A 2 GB (UNET/CLIP-sized) tensor exceeds the 1 GB threshold."""
        # Built on the meta device so no real 2 GB allocation takes place;
        # numel/element_size are still reported for the logical shape.
        element_count = int(2048 * 1024 * 1024 / 4)  # 2 GB of float32
        meta_weight = torch.empty(element_count, dtype=torch.float32, device="meta")

        threshold = 1024 * 1024 * 1024  # 1 GB
        model_size = meta_weight.numel() * meta_weight.element_size()

        assert model_size > threshold, (
            f"2 GB model should be above 1 GB threshold, got {model_size / (1024**2):.0f} MB"
        )

    def test_meta_device_move_releases_storage(self):
        """``.to(device="meta")`` leaves the parameter on the meta device."""
        model = FakeLinearModel(size_mb=2)
        assert model.weight.device.type != "meta"

        model.to(device="meta")

        weight = model.weight
        assert weight.device.type == "meta"
        # The logical shape survives the move even though the storage is meta.
        assert weight.nelement() > 0
        assert weight.untyped_storage().device.type == "meta"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. MPS flush conditionality
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestMpsFlushConditionality:
    """Check the two conditions that gate the per-step MPS flush."""

    def test_flush_requires_aggressive_offload_flag(self):
        """The gate expression follows ``mm.AGGRESSIVE_OFFLOAD``."""
        import comfy.model_management as mm

        original = getattr(mm, "AGGRESSIVE_OFFLOAD", False)
        try:
            # Flag off: the gate must be falsy.
            mm.AGGRESSIVE_OFFLOAD = False
            gate_when_off = True and getattr(mm, "AGGRESSIVE_OFFLOAD", False)
            assert not gate_when_off

            # Flag on: the gate must be truthy.
            mm.AGGRESSIVE_OFFLOAD = True
            gate_when_on = True and getattr(mm, "AGGRESSIVE_OFFLOAD", False)
            assert gate_when_on
        finally:
            mm.AGGRESSIVE_OFFLOAD = original

    def test_flush_requires_mps_device(self):
        """A CPU device is not ``mps``; an MPS device is, when available."""
        cpu_device = torch.device("cpu")
        assert cpu_device.type != "mps"

        # The positive check only runs where an MPS backend is available.
        if not torch.backends.mps.is_available():
            return
        mps_device = torch.device("mps")
        assert mps_device.type == "mps"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. AGGRESSIVE_OFFLOAD flag integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAggressiveOffloadFlag:
    """Check that the CLI flag is exposed on ``comfy.model_management``."""

    def test_flag_exists_in_model_management(self):
        """``AGGRESSIVE_OFFLOAD`` is a ``bool`` attribute of the module."""
        import comfy.model_management as mm

        assert hasattr(mm, "AGGRESSIVE_OFFLOAD")
        flag = mm.AGGRESSIVE_OFFLOAD
        assert isinstance(flag, bool)

    def test_flag_defaults_from_cli_args(self):
        """The module-level flag equals the parsed ``aggressive_offload`` arg."""
        import comfy.cli_args as cli_args
        import comfy.model_management as mm

        parsed = cli_args.args
        assert hasattr(parsed, "aggressive_offload")
        assert mm.AGGRESSIVE_OFFLOAD == parsed.aggressive_offload
|
||||||
Loading…
Reference in New Issue
Block a user