Handle CPU torch more gracefully

This commit is contained in:
doctorpangloss 2024-07-05 10:47:06 -07:00
parent cf2eaedc5b
commit b6b97574dc

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import logging
import sys
import platform
import sys
import warnings
from enum import Enum
from threading import RLock
@ -49,7 +49,7 @@ if args.deterministic:
directml_device = None
if args.directml is not None:
import torch_directml # pylint: disable=import-error
import torch_directml # pylint: disable=import-error
device_index = args.directml
if device_index < 0:
@ -61,7 +61,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available():
xpu_available = True
@ -158,8 +158,8 @@ if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False
else:
try:
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
XFORMERS_IS_AVAILABLE = True
try:
@ -186,6 +186,7 @@ def is_nvidia():
return True
return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
@ -193,6 +194,7 @@ def is_amd():
return True
return False
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
@ -287,7 +289,6 @@ try:
except:
logging.warning("Could not pick default device.")
current_loaded_models: List["LoadedModel"] = []
@ -506,6 +507,7 @@ def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
def loaded_models(only_currently_used=False):
with model_management_lock:
output = []
@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
output.append(m.model)
return output
def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
@ -704,14 +707,15 @@ def supports_dtype(device, dtype): # TODO
return True
return False
def supports_cast(device, dtype): #TODO
def supports_cast(device, dtype): # TODO
if dtype == torch.float32:
return True
if dtype == torch.float16:
return True
if is_device_mps(device):
return False
if directml_device: #TODO: test this
if directml_device: # TODO: test this
return False
if dtype == torch.bfloat16:
return True
@ -721,28 +725,31 @@ def supports_cast(device, dtype): #TODO
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
return False
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_device:
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
#TODO
# TODO
return False
@ -805,10 +812,11 @@ def pytorch_attention_flash_attention():
return True
return False
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
upcast = True
except:
pass
@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
else:
return None
def get_free_memory(dev=None, torch_free_too=False):
global directml_device
if dev is None:
@ -906,12 +915,15 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
return True
try:
props = torch.cuda.get_device_properties("cuda")
if props.major >= 8:
return True
props = torch.cuda.get_device_properties("cuda")
if props.major >= 8:
return True
if props.major < 6:
if props.major < 6:
return False
except AssertionError:
logging.warning("Torch was not compiled with cuda support")
return False
fp16_works = False
@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
try:
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
except AssertionError:
logging.warning("Torch was not compiled with CUDA support")
return False
bf16_works = torch.cuda.is_bf16_supported()