diff --git a/comfy/model_management.py b/comfy/model_management.py
index d43276b42..11cef8eee 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -211,15 +211,9 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and dev.type == 'cpu':
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
-    elif hasattr(dev, 'type') and dev.type == 'mps':
-        mem_total = psutil.virtual_memory().total
-        try:
-            mem_total_torch = torch.mps.recommended_max_memory()
-        except Exception:
-            mem_total_torch = mem_total
     else:
         if directml_enabled:
             mem_total = 1024 * 1024 * 1024 #TODO
@@ -1494,18 +1488,9 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and dev.type == 'cpu':
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
-    elif hasattr(dev, 'type') and dev.type == 'mps':
-        try:
-            driver_mem = torch.mps.driver_allocated_memory()
-            current_mem = torch.mps.current_allocated_memory()
-            mem_free_torch = max(driver_mem - current_mem, 0)
-            mem_free_total = psutil.virtual_memory().available + mem_free_torch
-        except Exception:
-            mem_free_total = psutil.virtual_memory().available
-            mem_free_torch = 0
     else:
         if directml_enabled:
             mem_free_total = 1024 * 1024 * 1024 #TODO