Enable MPS testing (passes)

This commit is contained in:
Benjamin Berman 2024-07-07 16:31:17 -07:00
parent 1902aeee0b
commit 8329bb0db3

View File

@ -50,21 +50,39 @@ def run_server(server_arguments: Configuration):
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
def has_gpu() -> bool: def has_gpu() -> bool:
# ipex # mps
has_gpu = False
try: try:
import intel_extension_for_pytorch as ipex import torch
has_gpu = ipex.xpu.device_count() > 0 has_gpu = torch.backends.mps.is_available() and torch.device("mps") is not None
if has_gpu:
from comfy import model_management
from comfy.model_management import CPUState
model_management.cpu_state = CPUState.MPS
except ImportError: except ImportError:
pass
if not has_gpu:
# ipex
try: try:
import torch import intel_extension_for_pytorch as ipex
has_gpu = torch.device(torch.cuda.current_device()) is not None has_gpu = ipex.xpu.device_count() > 0
except: except ImportError:
has_gpu = False has_gpu = False
if not has_gpu:
# cuda
try:
import torch
has_gpu = torch.device(torch.cuda.current_device()) is not None
except:
has_gpu = False
if has_gpu: if has_gpu:
from comfy import model_management from comfy import model_management
from comfy.model_management import CPUState from comfy.model_management import CPUState
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU if model_management.cpu_state != CPUState.MPS:
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
yield has_gpu yield has_gpu