Coalesce directml_enabled and directml_device into one variable

Max Tretikov 2024-06-14 13:16:05 -06:00
parent 5cd4ca9906
commit 9cf4f9830f
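
The change collapses two pieces of module state into one: instead of a directml_enabled boolean plus a separate directml_device handle, the optional device handle alone signals whether DirectML is active (None when disabled, a device object when enabled). A minimal sketch of the pattern follows; it is illustrative only and not the actual module. args.directml is the command-line option handled in the diff below, and the trailing guard stands in for the many call sites updated by this commit:

    directml_device = None                    # was: directml_enabled = False
    if args.directml is not None:
        import torch_directml
        device_index = args.directml          # previously also set directml_enabled = True
        if device_index < 0:
            directml_device = torch_directml.device()   # default DirectML adapter

    # Every former "if directml_enabled:" guard becomes a truthiness check:
    if directml_device:
        pass  # DirectML-specific code path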


@@ -47,11 +47,10 @@ if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
-directml_enabled = False
+directml_device = None
 if args.directml is not None:
     import torch_directml # pylint: disable=import-error
-    directml_enabled = True
     device_index = args.directml
     if device_index < 0:
         directml_device = torch_directml.device()
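
For reference, torch_directml.device() with no argument returns the default DirectML adapter, which is why a negative index falls through to it; a specific adapter is chosen by passing its index. A short usage sketch, assuming the torch-directml package is installed (illustrative, not part of this commit):

    import torch
    import torch_directml

    # Enumerate the adapters DirectML can see.
    for i in range(torch_directml.device_count()):
        print(i, torch_directml.device_name(i))

    dml = torch_directml.device()      # default adapter
    # dml = torch_directml.device(1)   # or select an adapter by index

    x = torch.ones(2, 2, device=dml)   # tensors placed on the DirectML device
    print(x.device)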
@@ -90,10 +89,9 @@ def is_intel_xpu():
 def get_torch_device():
-    global directml_enabled
+    global directml_device
     global cpu_state
-    if directml_enabled:
-        global directml_device
+    if directml_device:
         return directml_device
     if cpu_state == CPUState.MPS:
         return torch.device("mps")
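
One subtlety in the new guard: "if directml_device:" relies on the device object being truthy. The handle returned by torch_directml.device() is an ordinary torch.device, which is always truthy, so the check is falsy only while the variable is still None and behaves exactly like the old flag. An equivalent, more explicit spelling (illustrative sketch only):

    # Same behaviour as "if directml_device:" above, with the intent spelled out.
    if directml_device is not None:
        return directml_device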
@@ -111,7 +109,7 @@ def get_torch_device():
 def get_total_memory(dev=None, torch_total_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
@@ -119,7 +117,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_total = 1024 * 1024 * 1024 # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
@@ -707,7 +705,7 @@ def supports_cast(device, dtype): #TODO
         return True
     if is_device_mps(device):
         return False
-    if directml_enabled: #TODO: test this
+    if directml_device: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -722,7 +720,7 @@ def device_supports_non_blocking(device):
         return False # pytorch bug? mps doesn't support non blocking
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return True
@@ -759,13 +757,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def xformers_enabled():
-    global directml_enabled
+    global directml_device
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
     if is_intel_xpu():
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -806,7 +804,7 @@ def force_upcast_attention_dtype():
         return None
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
@@ -814,7 +812,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_free_total = 1024 * 1024 * 1024 # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
@@ -864,7 +862,7 @@ def is_device_cuda(device):
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
-    global directml_enabled
+    global directml_device
 
     if device is not None:
         if is_device_cpu(device):
@@ -880,7 +878,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if mps_mode():
@@ -943,7 +941,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if cpu_mode() or mps_mode():