From b6b97574dc5f76ca47dc3053a6a712c19d1eb59c Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Fri, 5 Jul 2024 10:47:06 -0700
Subject: [PATCH] Handle CPU torch more gracefully

---
 comfy/model_management.py | 54 +++++++++++++++++++++++++--------------
 1 file changed, 35 insertions(+), 19 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index a1f078391..9d061600d 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import logging
-import sys
 import platform
+import sys
 import warnings
 from enum import Enum
 from threading import RLock
@@ -49,7 +49,7 @@ if args.deterministic:
 
 directml_device = None
 if args.directml is not None:
-    import torch_directml # pylint: disable=import-error
+    import torch_directml  # pylint: disable=import-error
 
     device_index = args.directml
     if device_index < 0:
@@ -61,7 +61,7 @@ if args.directml is not None:
     lowvram_available = False  # TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
-    import intel_extension_for_pytorch as ipex # pylint: disable=import-error
+    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error
 
     if torch.xpu.is_available():
         xpu_available = True
@@ -158,8 +158,8 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers # pylint: disable=import-error
-        import xformers.ops # pylint: disable=import-error
+        import xformers  # pylint: disable=import-error
+        import xformers.ops  # pylint: disable=import-error
 
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -186,6 +186,7 @@ def is_nvidia():
             return True
     return False
 
+
 def is_amd():
     global cpu_state
     if cpu_state == CPUState.GPU:
@@ -193,6 +194,7 @@ def is_amd():
             return True
     return False
 
+
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
@@ -287,7 +289,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
 
-
 current_loaded_models: List["LoadedModel"] = []
 
@@ -506,6 +507,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])
 
+
 def loaded_models(only_currently_used=False):
     with model_management_lock:
         output = []
@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                 output.append(m.model)
         return output
 
+
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -704,14 +707,15 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False
 
-def supports_cast(device, dtype): #TODO
+
+def supports_cast(device, dtype):  # TODO
     if dtype == torch.float32:
         return True
     if dtype == torch.float16:
         return True
     if is_device_mps(device):
         return False
-    if directml_device: #TODO: test this
+    if directml_device:  # TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -721,28 +725,31 @@ def supports_cast(device, dtype): #TODO
         return True
     return False
 
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False  # pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
         return False
-    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
+    if args.deterministic:  # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_device:
         return False
     return True
 
+
 def device_should_use_non_blocking(device):
     if not device_supports_non_blocking(device):
         return False
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
 
+
 def force_channels_last():
     if args.force_channels_last:
         return True
 
-    #TODO
+    # TODO
     return False
 
 
@@ -805,10 +812,11 @@ def pytorch_attention_flash_attention():
             return True
     return False
 
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        if platform.mac_ver()[0] in ['14.5']:  # black image bug on OSX Sonoma 14.5
             upcast = True
     except:
         pass
@@ -817,6 +825,7 @@
     if upcast:
         return torch.float32
     else:
         return None
 
+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_device
     if dev is None:
@@ -906,12 +915,15 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_amd():
         return True
+    try:
+        props = torch.cuda.get_device_properties("cuda")
+        if props.major >= 8:
+            return True
 
-    props = torch.cuda.get_device_properties("cuda")
-    if props.major >= 8:
-        return True
-
-    if props.major < 6:
+        if props.major < 6:
+            return False
+    except AssertionError:
+        logging.warning("Torch was not compiled with cuda support")
         return False
 
     fp16_works = False
@@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if device is None:
         device = torch.device("cuda")
 
-    props = torch.cuda.get_device_properties(device)
-    if props.major >= 8:
-        return True
+    try:
+        props = torch.cuda.get_device_properties(device)
+        if props.major >= 8:
+            return True
+    except AssertionError:
+        logging.warning("Torch was not compiled with CUDA support")
+        return False
 
     bf16_works = torch.cuda.is_bf16_supported()
 
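
Reviewer note, not part of the commit above: a minimal sketch of the failure mode the new try/except blocks guard against, assuming a CPU-only torch wheel. On such a build, torch.cuda.get_device_properties() raises AssertionError (typically "Torch not compiled with CUDA enabled") instead of returning device properties, which is why should_use_fp16() and should_use_bf16() now catch AssertionError and fall back to returning False. The helper name probe_cuda_major below is hypothetical and only illustrates that pattern.

    import logging

    import torch


    def probe_cuda_major():
        # Return the CUDA compute capability major version, or None when torch
        # was built without CUDA support (the call below then raises
        # AssertionError rather than returning properties).
        try:
            return torch.cuda.get_device_properties("cuda").major
        except AssertionError:
            logging.warning("Torch was not compiled with cuda support")
            return None


    if __name__ == "__main__":
        major = probe_cuda_major()
        # Mirrors the props.major >= 8 check used for fp16/bf16 in the patch.
        print("Ampere or newer:", major is not None and major >= 8)

With a probe like this, dtype selection degrades gracefully to the CPU/fp32 path instead of crashing at model-load time on CPU-only installs.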