Coalesce directml_enabled and directml_device into one variable

parent 5cd4ca9906
commit 9cf4f9830f
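The change collapses a boolean flag (directml_enabled) and an optional device handle (directml_device) into one module-level variable: None means DirectML is off, and a torch_directml device object means it is on. Below is a minimal sketch of the pattern, not the real module: init_directml and get_torch_device_sketch are hypothetical names standing in for the import-time startup code and the real get_torch_device.

directml_device = None  # None => DirectML off; a device object => DirectML on

def init_directml(device_index=None):
    # Hypothetical helper mirroring the args.directml handling in the diff.
    global directml_device
    if device_index is None:
        return  # DirectML was not requested; directml_device stays None
    import torch_directml  # pylint: disable=import-error
    if device_index < 0:
        directml_device = torch_directml.device()  # default DirectML device
    else:
        directml_device = torch_directml.device(device_index)

def get_torch_device_sketch():
    # A single truthiness test replaces the old pair of reads
    # (checking directml_enabled, then returning directml_device).
    if directml_device:
        return directml_device
    # ...fall through to CUDA / MPS / CPU selection as in the real module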
@@ -47,11 +47,10 @@ if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
-directml_enabled = False
+directml_device = None
 if args.directml is not None:
     import torch_directml  # pylint: disable=import-error
 
-    directml_enabled = True
     device_index = args.directml
     if device_index < 0:
         directml_device = torch_directml.device()
@@ -90,10 +89,9 @@ def is_intel_xpu():
 
 
 def get_torch_device():
-    global directml_enabled
+    global directml_device
     global cpu_state
-    if directml_enabled:
-        global directml_device
+    if directml_device:
         return directml_device
     if cpu_state == CPUState.MPS:
         return torch.device("mps")
@@ -111,7 +109,7 @@ def get_torch_device():
 
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -119,7 +117,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_total = 1024 * 1024 * 1024 # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
@@ -707,7 +705,7 @@ def supports_cast(device, dtype): #TODO
         return True
     if is_device_mps(device):
         return False
-    if directml_enabled: #TODO: test this
+    if directml_device: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -722,7 +720,7 @@ def device_supports_non_blocking(device):
         return False # pytorch bug? mps doesn't support non blocking
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return True
 
@@ -759,13 +757,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
 
 
 def xformers_enabled():
-    global directml_enabled
+    global directml_device
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
     if is_intel_xpu():
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
 
@@ -806,7 +804,7 @@ def force_upcast_attention_dtype():
     return None
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -814,7 +812,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_free_total = 1024 * 1024 * 1024 # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
@@ -864,7 +862,7 @@ def is_device_cuda(device):
 
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
-    global directml_enabled
+    global directml_device
 
     if device is not None:
         if is_device_cpu(device):
@@ -880,7 +878,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if mps_mode():
@@ -943,7 +941,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if cpu_mode() or mps_mode():
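A behavioral note the coalesced guards rely on but the diff does not spell out: every former check of directml_enabled becomes a plain truthiness test on directml_device rather than an explicit "is not None". That works because torch device objects are always truthy, so once the variable only ever holds None or a device, the two spellings agree. A quick sanity check, assuming a stock PyTorch install:

import torch

# torch.device instances are truthy, so "if directml_device:" behaves
# like "if directml_device is not None:" for this variable's two states.
assert bool(torch.device("cpu")) is True
assert bool(None) is False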