diff --git a/comfy/model_management.py b/comfy/model_management.py index bcf1399c4..88d242541 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -460,7 +460,10 @@ if cpu_state != CPUState.GPU: vram_state = VRAMState.DISABLED if cpu_state == CPUState.MPS: - vram_state = VRAMState.SHARED + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to + else: + vram_state = VRAMState.SHARED logging.info(f"Set vram state to: {vram_state.name}") @@ -1771,7 +1774,9 @@ def lora_compute_dtype(device): def synchronize(): if cpu_mode(): return - if is_intel_xpu(): + if mps_mode(): + torch.mps.synchronize() + elif is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() @@ -1781,6 +1786,7 @@ def soft_empty_cache(force=False): return global cpu_state if cpu_state == CPUState.MPS: + torch.mps.synchronize() torch.mps.empty_cache() elif is_intel_xpu(): torch.xpu.empty_cache() diff --git a/execution.py b/execution.py index 5e02dffb2..00e80ba43 100644 --- a/execution.py +++ b/execution.py @@ -780,6 +780,11 @@ class PromptExecutor: if self.cache_type == CacheType.RAM_PRESSURE: comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) comfy.memory_management.extra_ram_release(ram_headroom) + elif comfy.model_management.mps_mode(): + mem_free_total, mem_free_torch = comfy.model_management.get_free_memory( + comfy.model_management.get_torch_device(), torch_free_too=True) + if mem_free_torch < mem_free_total * 0.25: + comfy.model_management.soft_empty_cache() else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed