diff --git a/comfy/model_management.py b/comfy/model_management.py
index 5c3033127..a4840f688 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -47,11 +47,10 @@ if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
-directml_enabled = False
+directml_device = None
 if args.directml is not None:
     import torch_directml # pylint: disable=import-error
-    directml_enabled = True
     device_index = args.directml
     if device_index < 0:
         directml_device = torch_directml.device()
@@ -90,10 +89,9 @@ def is_intel_xpu():
 
 
 def get_torch_device():
-    global directml_enabled
+    global directml_device
     global cpu_state
-    if directml_enabled:
-        global directml_device
+    if directml_device:
         return directml_device
     if cpu_state == CPUState.MPS:
         return torch.device("mps")
@@ -111,7 +109,7 @@ def get_torch_device():
 
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -119,7 +117,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_total = 1024 * 1024 * 1024 # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
@@ -707,7 +705,7 @@ def supports_cast(device, dtype): #TODO
         return True
     if is_device_mps(device):
         return False
-    if directml_enabled: #TODO: test this
+    if directml_device: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -722,7 +720,7 @@ def device_supports_non_blocking(device):
         return False # pytorch bug? mps doesn't support non blocking
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return True
 
@@ -759,13 +757,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
 
 
 def xformers_enabled():
-    global directml_enabled
+    global directml_device
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
     if is_intel_xpu():
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
 
@@ -806,7 +804,7 @@ def force_upcast_attention_dtype():
     return None
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -814,7 +812,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_free_total = 1024 * 1024 * 1024 # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
@@ -864,7 +862,7 @@ def is_device_cuda(device):
 
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
-    global directml_enabled
+    global directml_device
 
     if device is not None:
         if is_device_cpu(device):
@@ -880,7 +878,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if mps_mode():
@@ -943,7 +941,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if cpu_mode() or mps_mode():