Handle CPU torch more gracefully

doctorpangloss 2024-07-05 10:47:06 -07:00
parent cf2eaedc5b
commit b6b97574dc

@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import logging
-import sys
 import platform
+import sys
 import warnings
 from enum import Enum
 from threading import RLock
@@ -186,6 +186,7 @@ def is_nvidia():
             return True
     return False
 
+
 def is_amd():
     global cpu_state
     if cpu_state == CPUState.GPU:
@@ -193,6 +194,7 @@ def is_amd():
             return True
     return False
 
+
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
@@ -287,7 +289,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
 
-
 current_loaded_models: List["LoadedModel"] = []
 
@@ -506,6 +507,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])
 
+
 def loaded_models(only_currently_used=False):
     with model_management_lock:
         output = []
@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                 output.append(m.model)
         return output
 
+
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -704,6 +707,7 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False
 
+
 def supports_cast(device, dtype): # TODO
     if dtype == torch.float32:
         return True
@@ -721,6 +725,7 @@ def supports_cast(device, dtype): #TODO
         return True
     return False
 
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False # pytorch bug? mps doesn't support non blocking
@@ -732,12 +737,14 @@ def device_supports_non_blocking(device):
         return False
     return True
 
+
 def device_should_use_non_blocking(device):
     if not device_supports_non_blocking(device):
         return False
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
 
+
 def force_channels_last():
     if args.force_channels_last:
         return True
@@ -805,6 +812,7 @@ def pytorch_attention_flash_attention():
             return True
     return False
 
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
@@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
     else:
         return None
 
+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_device
     if dev is None:
@@ -906,13 +915,16 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_amd():
         return True
 
-    props = torch.cuda.get_device_properties("cuda")
-    if props.major >= 8:
-        return True
-
-    if props.major < 6:
-        return False
+    try:
+        props = torch.cuda.get_device_properties("cuda")
+        if props.major >= 8:
+            return True
+        if props.major < 6:
+            return False
+    except AssertionError:
+        logging.warning("Torch was not compiled with cuda support")
+        return False
 
     fp16_works = False
     # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
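
The guard above relies on a specific PyTorch behaviour: on a CPU-only torch build, torch.cuda.get_device_properties triggers CUDA lazy initialization, which raises AssertionError("Torch not compiled with CUDA enabled"). A minimal sketch of the same pattern in isolation; the probe_cuda_capability helper name is hypothetical and not part of this commit:

    import logging

    import torch

    def probe_cuda_capability():
        # Mirrors the guarded probe in the hunk above: a CPU-only torch build
        # raises AssertionError during CUDA lazy init, which we treat as
        # "no usable CUDA device" instead of letting it crash the caller.
        try:
            props = torch.cuda.get_device_properties("cuda")
            return props.major, props.minor
        except AssertionError:
            logging.warning("Torch was not compiled with CUDA support")
            return None
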
@@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
 
     if device is None:
         device = torch.device("cuda")
 
-    props = torch.cuda.get_device_properties(device)
-    if props.major >= 8:
-        return True
+    try:
+        props = torch.cuda.get_device_properties(device)
+        if props.major >= 8:
+            return True
+    except AssertionError:
+        logging.warning("Torch was not compiled with CUDA support")
+        return False
 
     bf16_works = torch.cuda.is_bf16_supported()
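
Catching AssertionError keeps the existing control flow intact and only touches the code paths that actually query device properties. For comparison, a non-exception probe could answer the same question up front; this is an illustrative alternative under the same assumptions, not what the commit does:

    import torch

    def cuda_runtime_usable():
        # torch.backends.cuda.is_built(): was this wheel compiled with CUDA at all?
        # torch.cuda.is_available(): is a CUDA driver and device actually reachable?
        return torch.backends.cuda.is_built() and torch.cuda.is_available()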