Handle CPU torch more gracefully

doctorpangloss 2024-07-05 10:47:06 -07:00
parent cf2eaedc5b
commit b6b97574dc


@@ -1,8 +1,8 @@
 from __future__ import annotations
 import logging
-import sys
 import platform
+import sys
 import warnings
 from enum import Enum
 from threading import RLock
@@ -49,7 +49,7 @@ if args.deterministic:
 directml_device = None
 if args.directml is not None:
     import torch_directml # pylint: disable=import-error
     device_index = args.directml
     if device_index < 0:
@ -61,7 +61,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try: try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available(): if torch.xpu.is_available():
xpu_available = True xpu_available = True
@@ -158,8 +158,8 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers # pylint: disable=import-error
         import xformers.ops # pylint: disable=import-error
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -186,6 +186,7 @@ def is_nvidia():
             return True
     return False

+
 def is_amd():
     global cpu_state
     if cpu_state == CPUState.GPU:
@ -193,6 +194,7 @@ def is_amd():
return True return True
return False return False
ENABLE_PYTORCH_ATTENTION = False ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention: if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
@@ -287,7 +289,6 @@ try:
 except:
     logging.warning("Could not pick default device.")

-
 current_loaded_models: List["LoadedModel"] = []
@@ -506,6 +507,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])

+
 def loaded_models(only_currently_used=False):
     with model_management_lock:
         output = []
@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                     output.append(m.model)
         return output

+
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -704,14 +707,15 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False

-def supports_cast(device, dtype): #TODO
+
+def supports_cast(device, dtype): # TODO
     if dtype == torch.float32:
         return True
     if dtype == torch.float16:
         return True
     if is_device_mps(device):
         return False
-    if directml_device: #TODO: test this
+    if directml_device: # TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -721,28 +725,31 @@ def supports_cast(device, dtype): #TODO
         return True
     return False

+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False # pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
         return False
-    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
+    if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_device:
         return False
     return True

+
 def device_should_use_non_blocking(device):
     if not device_supports_non_blocking(device):
         return False
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others

+
 def force_channels_last():
     if args.force_channels_last:
         return True

-    #TODO
+    # TODO
     return False
@@ -805,10 +812,11 @@ def pytorch_attention_flash_attention():
             return True
     return False

+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
             upcast = True
     except:
         pass
@@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
     else:
         return None

+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_device
     if dev is None:
@@ -906,12 +906,15 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_amd():
         return True

-    props = torch.cuda.get_device_properties("cuda")
-    if props.major >= 8:
-        return True
-
-    if props.major < 6:
-        return False
+    try:
+        props = torch.cuda.get_device_properties("cuda")
+        if props.major >= 8:
+            return True
+        if props.major < 6:
+            return False
+    except AssertionError:
+        logging.warning("Torch was not compiled with cuda support")
+        return False

     fp16_works = False
@@ -967,9 +967,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if device is None:
         device = torch.device("cuda")

-    props = torch.cuda.get_device_properties(device)
-    if props.major >= 8:
-        return True
+    try:
+        props = torch.cuda.get_device_properties(device)
+        if props.major >= 8:
+            return True
+    except AssertionError:
+        logging.warning("Torch was not compiled with CUDA support")
+        return False

     bf16_works = torch.cuda.is_bf16_supported()
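
For context, a minimal standalone sketch of the pattern this commit applies (the helper name cuda_supports_fast_fp16 is illustrative and not part of the codebase): on a CPU-only torch build, torch.cuda.get_device_properties() raises AssertionError ("Torch not compiled with CUDA enabled"), so the capability probes in should_use_fp16 and should_use_bf16 are wrapped in try/except and fall back to a safe default instead of crashing.

import logging

import torch


def cuda_supports_fast_fp16() -> bool:
    # Probe the current CUDA device's compute capability; 8.0+ is the
    # threshold the commit uses to return True early for fp16.
    try:
        props = torch.cuda.get_device_properties("cuda")
        return props.major >= 8
    except AssertionError:
        # CPU-only torch build: there is no CUDA runtime to query.
        logging.warning("Torch was not compiled with CUDA support")
        return False


if __name__ == "__main__":
    print(cuda_supports_fast_fp16())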