From 8329bb0db3cc41104701695ce530f571ea7b9a0a Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Sun, 7 Jul 2024 16:31:17 -0700 Subject: [PATCH] Enable MPS testing (passes) --- tests/conftest.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6727711a9..1c8c973ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,21 +50,39 @@ def run_server(server_arguments: Configuration): @pytest.fixture(scope="function", autouse=False) def has_gpu() -> bool: - # ipex + # mps + has_gpu = False try: - import intel_extension_for_pytorch as ipex - has_gpu = ipex.xpu.device_count() > 0 + import torch + has_gpu = torch.backends.mps.is_available() and torch.device("mps") is not None + if has_gpu: + from comfy import model_management + from comfy.model_management import CPUState + model_management.cpu_state = CPUState.MPS except ImportError: + pass + + if not has_gpu: + # ipex try: - import torch - has_gpu = torch.device(torch.cuda.current_device()) is not None - except: + import intel_extension_for_pytorch as ipex + has_gpu = ipex.xpu.device_count() > 0 + except ImportError: has_gpu = False + if not has_gpu: + # cuda + try: + import torch + has_gpu = torch.device(torch.cuda.current_device()) is not None + except: + has_gpu = False + if has_gpu: from comfy import model_management from comfy.model_management import CPUState - model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU + if model_management.cpu_state != CPUState.MPS: + model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU yield has_gpu