Handle CPU torch more gracefully

This commit is contained in:
doctorpangloss 2024-07-05 10:47:06 -07:00
parent cf2eaedc5b
commit b6b97574dc

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import logging
import sys
import platform
import sys
import warnings
from enum import Enum
from threading import RLock
@ -186,6 +186,7 @@ def is_nvidia():
return True
return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
@ -193,6 +194,7 @@ def is_amd():
return True
return False
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
@ -287,7 +289,6 @@ try:
except:
logging.warning("Could not pick default device.")
current_loaded_models: List["LoadedModel"] = []
@ -506,6 +507,7 @@ def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
def loaded_models(only_currently_used=False):
with model_management_lock:
output = []
@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
output.append(m.model)
return output
def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
@ -704,6 +707,7 @@ def supports_dtype(device, dtype): # TODO
return True
return False
def supports_cast(device, dtype): # TODO
if dtype == torch.float32:
return True
@ -721,6 +725,7 @@ def supports_cast(device, dtype): #TODO
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking
@ -732,12 +737,14 @@ def device_supports_non_blocking(device):
return False
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@ -805,6 +812,7 @@ def pytorch_attention_flash_attention():
return True
return False
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
else:
return None
def get_free_memory(dev=None, torch_free_too=False):
global directml_device
if dev is None:
@ -906,13 +915,16 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
return True
try:
props = torch.cuda.get_device_properties("cuda")
if props.major >= 8:
return True
if props.major < 6:
return False
except AssertionError:
logging.warning("Torch was not compiled with cuda support")
return False
fp16_works = False
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if device is None:
device = torch.device("cuda")
try:
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
except AssertionError:
logging.warning("Torch was not compiled with CUDA support")
return False
bf16_works = torch.cuda.is_bf16_supported()