Coalesce directml_enabled and directml_device into one variable

Max Tretikov 2024-06-14 13:16:05 -06:00
parent 5cd4ca9906
commit 9cf4f9830f


@@ -47,11 +47,10 @@ if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
-directml_enabled = False
+directml_device = None
 if args.directml is not None:
     import torch_directml # pylint: disable=import-error
-    directml_enabled = True
     device_index = args.directml
     if device_index < 0:
         directml_device = torch_directml.device()
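In effect, the commit folds the boolean flag into the device handle itself: directml_device stays None when DirectML is not requested and holds the torch_directml device otherwise, so every former "if directml_enabled:" check becomes a truthiness test on the device. A minimal sketch of the resulting setup, mirroring the module-level code in this hunk (the else branch for an explicit adapter index is not shown above and is included here as an assumption; args comes from the surrounding module, as in the diff):

# Sketch of the coalesced setup: one variable carries both "is DirectML
# enabled" and "which device to use".
directml_device = None  # stays None when --directml was not passed

if args.directml is not None:
    import torch_directml  # pylint: disable=import-error
    device_index = args.directml
    if device_index < 0:
        directml_device = torch_directml.device()  # default adapter
    else:
        # assumed branch: select the adapter the user asked for by index
        directml_device = torch_directml.device(device_index)

# Later checks then read naturally:
#     if directml_device:   # truthy only when DirectML is in use
#         return directml_device

Note that the new checks rely on the device object being truthy; an explicit "is not None" comparison would be the more defensive equivalent, but the diff uses plain truthiness throughout.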
@@ -90,10 +89,9 @@ def is_intel_xpu():
 
 def get_torch_device():
-    global directml_enabled
+    global directml_device
     global cpu_state
-    if directml_enabled:
-        global directml_device
+    if directml_device:
         return directml_device
     if cpu_state == CPUState.MPS:
         return torch.device("mps")
@@ -111,7 +109,7 @@ def get_torch_device():
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
@@ -119,7 +117,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_total = 1024 * 1024 * 1024 # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
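Both memory helpers treat DirectML the same way: no memory query is wired up for torch_directml in this code path, so a fixed placeholder is reported (the TODO above). A small illustrative sketch of that branch, using a hypothetical helper name not taken from the source:

# Hypothetical helper mirroring the DirectML branch of get_total_memory /
# get_free_memory: report a fixed 1 GiB until a real query is available.
def _directml_memory_placeholder():
    mem = 1024 * 1024 * 1024  # 1 GiB in bytes (1024**3)
    return mem, mem  # total and torch-visible amounts are reported identically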
@@ -707,7 +705,7 @@ def supports_cast(device, dtype): #TODO
         return True
     if is_device_mps(device):
         return False
-    if directml_enabled: #TODO: test this
+    if directml_device: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -722,7 +720,7 @@ def device_supports_non_blocking(device):
         return False # pytorch bug? mps doesn't support non blocking
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return True
@@ -759,13 +757,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
 
 def xformers_enabled():
-    global directml_enabled
+    global directml_device
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
     if is_intel_xpu():
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -806,7 +804,7 @@ def force_upcast_attention_dtype():
     return None
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
@@ -814,7 +812,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_free_total = 1024 * 1024 * 1024 # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
@@ -864,7 +862,7 @@ def is_device_cuda(device):
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
-    global directml_enabled
+    global directml_device
     if device is not None:
         if is_device_cpu(device):
@@ -880,7 +878,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     if mps_mode():
@@ -943,7 +941,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     if cpu_mode() or mps_mode():