From 33b967d30ab80c6ad58fba5e34069f0f8c673ef8 Mon Sep 17 00:00:00 2001 From: spn Date: Sat, 4 Apr 2026 21:14:58 +0200 Subject: [PATCH 1/2] fix: MPS memory reporting and cache synchronization - Split MPS from CPU in get_total_memory() and get_free_memory() to use torch.mps APIs (recommended_max_memory, driver_allocated_memory, current_allocated_memory) instead of relying solely on psutil - Add torch.mps.synchronize() to synchronize() and soft_empty_cache() - Add inter-node MPS cache flush when torch free memory is low - Move MPS SHARED vram assignment before --lowvram/--novram checks so users can override it for large models --- comfy/model_management.py | 29 +++++++++++++++++++++++++---- execution.py | 5 +++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0eebf1ded..d43276b42 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -211,9 +211,15 @@ def get_total_memory(dev=None, torch_total_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + if hasattr(dev, 'type') and dev.type == 'cpu': mem_total = psutil.virtual_memory().total mem_total_torch = mem_total + elif hasattr(dev, 'type') and dev.type == 'mps': + mem_total = psutil.virtual_memory().total + try: + mem_total_torch = torch.mps.recommended_max_memory() + except Exception: + mem_total_torch = mem_total else: if directml_enabled: mem_total = 1024 * 1024 * 1024 #TODO @@ -460,7 +466,10 @@ if cpu_state != CPUState.GPU: vram_state = VRAMState.DISABLED if cpu_state == CPUState.MPS: - vram_state = VRAMState.SHARED + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to + else: + vram_state = VRAMState.SHARED logging.info(f"Set vram state to: {vram_state.name}") @@ -1485,9 +1494,18 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and (dev.type 
== 'cpu' or dev.type == 'mps'): + if hasattr(dev, 'type') and dev.type == 'cpu': mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total + elif hasattr(dev, 'type') and dev.type == 'mps': + try: + driver_mem = torch.mps.driver_allocated_memory() + current_mem = torch.mps.current_allocated_memory() + mem_free_torch = max(driver_mem - current_mem, 0) + mem_free_total = psutil.virtual_memory().available + mem_free_torch + except Exception: + mem_free_total = psutil.virtual_memory().available + mem_free_torch = 0 else: if directml_enabled: mem_free_total = 1024 * 1024 * 1024 #TODO @@ -1756,7 +1774,9 @@ def lora_compute_dtype(device): def synchronize(): if cpu_mode(): return - if is_intel_xpu(): + if mps_mode(): + torch.mps.synchronize() + elif is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() @@ -1766,6 +1786,7 @@ def soft_empty_cache(force=False): return global cpu_state if cpu_state == CPUState.MPS: + torch.mps.synchronize() torch.mps.empty_cache() elif is_intel_xpu(): torch.xpu.empty_cache() diff --git a/execution.py b/execution.py index 5e02dffb2..00e80ba43 100644 --- a/execution.py +++ b/execution.py @@ -780,6 +780,11 @@ class PromptExecutor: if self.cache_type == CacheType.RAM_PRESSURE: comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) comfy.memory_management.extra_ram_release(ram_headroom) + elif comfy.model_management.mps_mode(): + mem_free_total, mem_free_torch = comfy.model_management.get_free_memory( + comfy.model_management.get_torch_device(), torch_free_too=True) + if mem_free_torch < mem_free_total * 0.25: + comfy.model_management.soft_empty_cache() else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed From 209d00c3b200ace846674863240517c4a1a4ff1f Mon Sep 17 00:00:00 2001 From: spn Date: Sun, 5 Apr 2026 12:37:14 +0200 Subject: [PATCH 2/2] fix: MPS 
memory reporting: revert MPS-specific branches - Revert the dedicated MPS branches added to get_total_memory() and get_free_memory() in the previous commit. MPS devices once again share the CPU path and report totals via psutil, dropping the torch.mps.recommended_max_memory() / driver_allocated_memory() based accounting. (The torch.mps.synchronize() additions and the --lowvram/--novram handling from the previous commit are unchanged by this diff.) --- comfy/model_management.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d43276b42..11cef8eee 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -211,15 +211,9 @@ def get_total_memory(dev=None, torch_total_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and dev.type == 'cpu': + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_total = psutil.virtual_memory().total mem_total_torch = mem_total - elif hasattr(dev, 'type') and dev.type == 'mps': - mem_total = psutil.virtual_memory().total - try: - mem_total_torch = torch.mps.recommended_max_memory() - except Exception: - mem_total_torch = mem_total else: if directml_enabled: mem_total = 1024 * 1024 * 1024 #TODO @@ -1494,18 +1488,9 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and dev.type == 'cpu': + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total - elif hasattr(dev, 'type') and dev.type == 'mps': - try: - driver_mem = torch.mps.driver_allocated_memory() - current_mem = torch.mps.current_allocated_memory() - mem_free_torch = max(driver_mem - current_mem, 0) - mem_free_total = psutil.virtual_memory().available + 
mem_free_torch - except Exception: - mem_free_total = psutil.virtual_memory().available - mem_free_torch = 0 else: if directml_enabled: mem_free_total = 1024 * 1024 * 1024 #TODO