Mirror of https://github.com/comfyanonymous/ComfyUI.git
Handle CPU torch more gracefully
commit b6b97574dc
parent cf2eaedc5b
@@ -1,8 +1,8 @@
from __future__ import annotations

import logging
import sys
import platform
import sys
import warnings
from enum import Enum
from threading import RLock

@@ -186,6 +186,7 @@ def is_nvidia():
            return True
    return False


def is_amd():
    global cpu_state
    if cpu_state == CPUState.GPU:

@@ -193,6 +194,7 @@ def is_amd():
            return True
    return False


ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
    ENABLE_PYTORCH_ATTENTION = True

@@ -287,7 +289,6 @@ try:
except:
    logging.warning("Could not pick default device.")


current_loaded_models: List["LoadedModel"] = []

@@ -506,6 +507,7 @@ def load_model_gpu(model):
    with model_management_lock:
        return load_models_gpu([model])


def loaded_models(only_currently_used=False):
    with model_management_lock:
        output = []

@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                output.append(m.model)
        return output


def cleanup_models(keep_clone_weights_loaded=False):
    with model_management_lock:
        to_delete = []
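An aside on the context lines above: this tree serializes access to the shared current_loaded_models list behind model_management_lock, presumably a re-entrant lock given the "from threading import RLock" import in the first hunk. A minimal sketch of that pattern, with simplified helper bodies that are not taken from the commit:

from threading import RLock

model_management_lock = RLock()  # re-entrant: the owning thread may acquire it again
current_loaded_models = []

def loaded_models():
    # Simplified stand-in for the real helper, which also filters on currently_used.
    with model_management_lock:
        return list(current_loaded_models)

def load_model_gpu(model):
    # Because the lock is re-entrant, a helper that already holds it can call
    # other helpers that take it on the same thread without deadlocking.
    with model_management_lock:
        current_loaded_models.append(model)
        return loaded_models()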
@@ -704,6 +707,7 @@ def supports_dtype(device, dtype): # TODO
        return True
    return False


def supports_cast(device, dtype): # TODO
    if dtype == torch.float32:
        return True

@@ -721,6 +725,7 @@ def supports_cast(device, dtype): #TODO
        return True
    return False


def device_supports_non_blocking(device):
    if is_device_mps(device):
        return False # pytorch bug? mps doesn't support non blocking

@@ -732,12 +737,14 @@ def device_supports_non_blocking(device):
        return False
    return True


def device_should_use_non_blocking(device):
    if not device_supports_non_blocking(device):
        return False
    return False
    # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others


def force_channels_last():
    if args.force_channels_last:
        return True

@@ -805,6 +812,7 @@ def pytorch_attention_flash_attention():
        return True
    return False


def force_upcast_attention_dtype():
    upcast = args.force_upcast_attention
    try:

@@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
    else:
        return None


def get_free_memory(dev=None, torch_free_too=False):
    global directml_device
    if dev is None:

@@ -906,13 +915,16 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma

    if is_amd():
        return True

    try:
        props = torch.cuda.get_device_properties("cuda")
        if props.major >= 8:
            return True

        if props.major < 6:
            return False
    except AssertionError:
        logging.warning("Torch was not compiled with cuda support")
        return False

    fp16_works = False
    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled

@@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
    if device is None:
        device = torch.device("cuda")

    try:
        props = torch.cuda.get_device_properties(device)
        if props.major >= 8:
            return True
    except AssertionError:
        logging.warning("Torch was not compiled with CUDA support")
        return False

    bf16_works = torch.cuda.is_bf16_supported()
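The substantive change is in the last two hunks: should_use_fp16() and should_use_bf16() no longer assume torch.cuda.get_device_properties() will succeed. On a CPU-only PyTorch build that call raises AssertionError ("Torch not compiled with CUDA enabled"), so it is now wrapped in try/except, a warning is logged, and the function falls back to returning False. A standalone sketch of the same guard, using a hypothetical helper name cuda_compute_capability that is not part of the commit:

import logging

import torch

def cuda_compute_capability(device="cuda"):
    # Query the CUDA compute capability, degrading gracefully when torch was
    # built without CUDA: get_device_properties() then raises AssertionError.
    try:
        props = torch.cuda.get_device_properties(device)
        return props.major, props.minor
    except AssertionError:
        logging.warning("Torch was not compiled with CUDA support")
        return None

print(cuda_compute_capability())  # e.g. (8, 6) on an Ampere GPU, None on CPU-only torch

A CUDA-enabled build running on a machine with no usable GPU raises RuntimeError instead, so callers that need to cover that case as well may prefer to check torch.cuda.is_available() up front.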