diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dbaadf723..b1943cd1a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -158,6 +158,7 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") +parser.add_argument("--aggressive-offload", action="store_true", help="Aggressively free models from RAM after use. Designed for Apple Silicon where CPU RAM and GPU VRAM are the same physical memory. Frees ~18GB during sampling by unloading text encoders after encoding. Trade-off: ~10s reload penalty per subsequent generation.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. 
# ---------------------------------------------------------------------------
# Model lifecycle callbacks — on_model_destroyed
# ---------------------------------------------------------------------------
# Why not comfy.hooks?  The existing hook system (comfy/hooks.py) is scoped
# to *sampling conditioning* — LoRA weight injection, transformer_options,
# and keyframe scheduling.  It has no concept of model-management lifecycle
# events such as "a model's parameters were deallocated".
#
# This registry fills that gap and is intentionally minimal: listeners are
# kept in registration order, there are no priorities and no removal.  If a
# formal lifecycle-event bus is adopted later, these callbacks should
# migrate to it.
# ---------------------------------------------------------------------------
_on_model_destroyed_callbacks: list = []


def register_model_destroyed_callback(callback):
    """Register a listener for post-destruction model lifecycle events.

    When ``free_memory`` has moved at least one model to the ``meta`` device
    (aggressive offload), it calls every registered listener once for the
    whole batch, passing a reason string (currently always ``"batch"``).

    Typical usage — executed by ``PromptExecutor.__init__``::

        def _invalidate(reason):
            executor.caches.outputs.clear_all()
        register_model_destroyed_callback(_invalidate)

    Args:
        callback: ``Callable[[str], None]`` receiving the reason string.
            ``free_memory`` invokes listeners without any exception
            handling, so a listener that raises aborts the unload pass;
            keep listeners cheap and non-raising (no heavy I/O, no model
            loads).

    Raises:
        TypeError: if ``callback`` is not callable.  Rejecting it here
            surfaces the mistake at the registration site instead of in the
            middle of a later ``free_memory`` call.
    """
    if not callable(callback):
        raise TypeError(
            "model-destroyed callback must be callable, got "
            f"{type(callback).__name__}"
        )
    # There is no removal API, so a listener registered twice would fire
    # twice per batch forever.  Membership uses ==, which also treats two
    # bound-method objects for the same instance/function as one listener.
    if callback in _on_model_destroyed_callbacks:
        return
    _on_model_destroyed_callbacks.append(callback)
+ device_match = (device is None or shift_model.device == device) + if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED: + device_match = True + if device_match: if shift_model not in keep_loaded and not shift_model.is_dead(): can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False can_unload_sorted = sorted(can_unload) + # Collect models to destroy via meta device AFTER the unload loop completes, + # so we don't kill weakrefs of models still being iterated. + _meta_destroy_queue = [] for x in can_unload_sorted: i = x[-1] + # Guard: weakref may already be dead from a previous iteration + if current_loaded_models[i].model is None: + continue memory_to_free = 1e32 pins_to_free = 1e32 if not DISABLE_SMART_MEMORY or device is None: @@ -687,15 +745,66 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() memory_to_free = 0 + + # Aggressive offload for Apple Silicon: force-unload unused models + # regardless of free memory, since CPU RAM == GPU VRAM. + if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED: + if not current_loaded_models[i].currently_used: + memory_to_free = 1e32 # Force unload + model_name = current_loaded_models[i].model.model.__class__.__name__ + model_size_mb = current_loaded_models[i].model_memory() / (1024 * 1024) + logging.info(f"[aggressive-offload] Force-unloading {model_name} ({model_size_mb:.0f} MB) from shared RAM") + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + # Queue for meta device destruction after loop completes. + # Only destroy large models (>1 GB) — small models like the VAE (160 MB) + # are kept because the execution cache may reuse their patcher across + # workflow nodes (e.g. 
vae_loader is cached while vae_decode runs later). + if AGGRESSIVE_OFFLOAD and vram_state == VRAMState.SHARED: + if current_loaded_models[i].model is not None: + model_size = current_loaded_models[i].model_memory() + if model_size > 1024 * 1024 * 1024: # Only meta-destroy models > 1 GB + _meta_destroy_queue.append(i) unloaded_model.append(i) if pins_to_free > 0: - logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(pins_to_free) + if current_loaded_models[i].model is not None: + logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(pins_to_free) + + # --- Phase 2: Deferred meta-device destruction ------------------------- + # Move parameters of queued models to the 'meta' device. This replaces + # every nn.Parameter with a zero-storage meta tensor, releasing physical + # RAM on unified-memory systems (Apple Silicon). The operation is + # deferred until *after* the unload loop to avoid invalidating weakrefs + # that other iterations may still reference. + for i in _meta_destroy_queue: + try: + model_ref = current_loaded_models[i].model + if model_ref is None: + continue + inner_model = model_ref.model + model_name = inner_model.__class__.__name__ + param_count = sum(p.numel() * p.element_size() for p in inner_model.parameters()) + inner_model.to(device="meta") + logging.info(f"[aggressive-offload] Moved {model_name} params to meta device, freed {param_count / (1024**2):.0f} MB") + except Exception as e: + logging.warning(f"[aggressive-offload] Failed to move model to meta: {e}") + + # --- Phase 3: Notify lifecycle listeners -------------------------------- + # Fire on_model_destroyed callbacks *once* after the entire batch has been + # processed, not per-model. This lets the execution engine clear its + # output cache in a single operation (see PromptExecutor.__init__). 
+ if _meta_destroy_queue and _on_model_destroyed_callbacks: + for cb in _on_model_destroyed_callbacks: + cb("batch") + logging.info(f"[aggressive-offload] Invalidated execution cache after destroying {len(_meta_destroy_queue)} model(s)") for x in can_unload_sorted: i = x[-1] + # Guard: weakref may be dead after cache invalidation (meta device move) + if current_loaded_models[i].model is None: + continue ram_to_free = ram_required - psutil.virtual_memory().available if ram_to_free <= 0 and i not in unloaded_model: continue @@ -708,6 +817,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if len(unloaded_model) > 0: soft_empty_cache() + if AGGRESSIVE_OFFLOAD: + gc.collect() # Force Python GC to release model tensors + soft_empty_cache() # Second pass to free MPS allocator cache elif device is not None: if vram_state != VRAMState.HIGH_VRAM: mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..c60d3fb5d 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -748,6 +748,18 @@ class KSAMPLER(Sampler): if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + # On Apple Silicon MPS, flush the allocator pool between steps to prevent + # progressive memory fragmentation and swap thrashing. Wrapping the callback + # here (rather than patching individual samplers) covers all sampler variants. 
+ import comfy.model_management + if noise.device.type == "mps" and getattr(comfy.model_management, "AGGRESSIVE_OFFLOAD", False): + _inner_callback = k_callback + def _mps_flush_callback(x): + if _inner_callback is not None: + _inner_callback(x) + torch.mps.empty_cache() + k_callback = _mps_flush_callback + samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples diff --git a/execution.py b/execution.py index 5e02dffb2..074d561f9 100644 --- a/execution.py +++ b/execution.py @@ -651,6 +651,17 @@ class PromptExecutor: self.cache_type = cache_type self.server = server self.reset() + # Register callback so model_management can invalidate cached outputs + # after destroying a model via meta device move (aggressive offload). + # NOTE: self.caches is resolved at call time (not capture time), so this + # callback remains valid even if reset() replaces self.caches later. + import comfy.model_management as mm + if mm.AGGRESSIVE_OFFLOAD: + executor = self + def _invalidate_cache(reason): + logging.info(f"[aggressive-offload] Invalidating execution cache ({reason})") + executor.caches.outputs.clear_all() + mm.register_model_destroyed_callback(_invalidate_cache) def reset(self): self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)