diff --git a/comfy/model_management.py b/comfy/model_management.py
index 0eebf1ded..d43276b42 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -211,9 +211,15 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == 'cpu':
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
+    elif hasattr(dev, 'type') and dev.type == 'mps':
+        mem_total = psutil.virtual_memory().total
+        try:
+            mem_total_torch = torch.mps.recommended_max_memory()
+        except Exception:
+            mem_total_torch = mem_total
     else:
         if directml_enabled:
             mem_total = 1024 * 1024 * 1024 #TODO
@@ -460,7 +466,10 @@
 if cpu_state != CPUState.GPU:
     vram_state = VRAMState.DISABLED
 
 if cpu_state == CPUState.MPS:
-    vram_state = VRAMState.SHARED
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
+    else:
+        vram_state = VRAMState.SHARED
 
 logging.info(f"Set vram state to: {vram_state.name}")
@@ -1485,9 +1494,18 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == 'cpu':
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and dev.type == 'mps':
+        try:
+            driver_mem = torch.mps.driver_allocated_memory()
+            current_mem = torch.mps.current_allocated_memory()
+            mem_free_torch = max(driver_mem - current_mem, 0)
+            mem_free_total = psutil.virtual_memory().available + mem_free_torch
+        except Exception:
+            mem_free_total = psutil.virtual_memory().available
+            mem_free_torch = 0
     else:
         if directml_enabled:
             mem_free_total = 1024 * 1024 * 1024 #TODO
@@ -1756,7 +1774,9 @@ def lora_compute_dtype(device):
 def synchronize():
     if cpu_mode():
         return
-    if is_intel_xpu():
+    if mps_mode():
+        torch.mps.synchronize()
+    elif is_intel_xpu():
         torch.xpu.synchronize()
     elif torch.cuda.is_available():
         torch.cuda.synchronize()
@@ -1766,6 +1786,7 @@ def soft_empty_cache(force=False):
         return
     global cpu_state
     if cpu_state == CPUState.MPS:
+        torch.mps.synchronize()
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
diff --git a/execution.py b/execution.py
index 5e02dffb2..00e80ba43 100644
--- a/execution.py
+++ b/execution.py
@@ -780,6 +780,11 @@ class PromptExecutor:
             if self.cache_type == CacheType.RAM_PRESSURE:
                 comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
                 comfy.memory_management.extra_ram_release(ram_headroom)
+            elif comfy.model_management.mps_mode():
+                mem_free_total, mem_free_torch = comfy.model_management.get_free_memory(
+                    comfy.model_management.get_torch_device(), torch_free_too=True)
+                if mem_free_torch > mem_free_total * 0.25:
+                    comfy.model_management.soft_empty_cache()
             else:
                 # Only execute when the while-loop ends without break
                 # Send cached UI for intermediate output nodes that weren't executed