From 209d00c3b200ace846674863240517c4a1a4ff1f Mon Sep 17 00:00:00 2001 From: spn Date: Sun, 5 Apr 2026 12:37:14 +0200 Subject: [PATCH] fix: MPS synchronization and --lowvram/--novram support - Add torch.mps.synchronize() to synchronize() function, mirroring the existing CUDA and XPU paths - Add torch.mps.synchronize() before torch.mps.empty_cache() in soft_empty_cache() to ensure pending MPS operations complete before releasing cached memory - Allow --lowvram and --novram flags to take effect on MPS devices. Previously, MPS unconditionally set vram_state to SHARED regardless of user flags. Now respects set_vram_to when LOW_VRAM or NO_VRAM is requested, while defaulting to SHARED otherwise. --- comfy/model_management.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d43276b42..11cef8eee 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -211,15 +211,9 @@ def get_total_memory(dev=None, torch_total_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and dev.type == 'cpu': + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_total = psutil.virtual_memory().total mem_total_torch = mem_total - elif hasattr(dev, 'type') and dev.type == 'mps': - mem_total = psutil.virtual_memory().total - try: - mem_total_torch = torch.mps.recommended_max_memory() - except Exception: - mem_total_torch = mem_total else: if directml_enabled: mem_total = 1024 * 1024 * 1024 #TODO @@ -1494,18 +1488,9 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and dev.type == 'cpu': + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total - elif hasattr(dev, 'type') and dev.type == 'mps': - try: - driver_mem = torch.mps.driver_allocated_memory() - current_mem = torch.mps.current_allocated_memory() - mem_free_torch = max(driver_mem - current_mem, 0) - mem_free_total = psutil.virtual_memory().available + mem_free_torch - except Exception: - mem_free_total = psutil.virtual_memory().available - mem_free_torch = 0 else: if directml_enabled: mem_free_total = 1024 * 1024 * 1024 #TODO