Enable MPS testing (passes)

This commit is contained in:
Benjamin Berman 2024-07-07 16:31:17 -07:00
parent 1902aeee0b
commit 8329bb0db3

View File

@ -50,21 +50,39 @@ def run_server(server_arguments: Configuration):
@pytest.fixture(scope="function", autouse=False)
def has_gpu() -> bool:
# ipex
# mps
has_gpu = False
try:
import intel_extension_for_pytorch as ipex
has_gpu = ipex.xpu.device_count() > 0
import torch
has_gpu = torch.backends.mps.is_available() and torch.device("mps") is not None
if has_gpu:
from comfy import model_management
from comfy.model_management import CPUState
model_management.cpu_state = CPUState.MPS
except ImportError:
pass
if not has_gpu:
# ipex
try:
import torch
has_gpu = torch.device(torch.cuda.current_device()) is not None
except:
import intel_extension_for_pytorch as ipex
has_gpu = ipex.xpu.device_count() > 0
except ImportError:
has_gpu = False
if not has_gpu:
# cuda
try:
import torch
has_gpu = torch.device(torch.cuda.current_device()) is not None
except:
has_gpu = False
if has_gpu:
from comfy import model_management
from comfy.model_management import CPUState
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
if model_management.cpu_state != CPUState.MPS:
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
yield has_gpu