mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)
Handle CPU torch more gracefully
This commit is contained in:
parent cf2eaedc5b
commit b6b97574dc
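
The substance of the change is in the last two hunks: the CUDA device-property queries in should_use_fp16 and should_use_bf16 are wrapped in try/except AssertionError, so a torch build without CUDA logs a warning and falls back instead of raising during startup. A minimal standalone sketch of that guard pattern (the helper name below is illustrative, not part of ComfyUI):

import logging

import torch


def cuda_compute_capability():
    """Return (major, minor) for the default CUDA device, or None on a CPU-only torch build."""
    try:
        # torch.cuda calls raise AssertionError when torch was compiled without CUDA;
        # catching it lets callers fall back to CPU/fp32 paths instead of crashing.
        props = torch.cuda.get_device_properties("cuda")
        return props.major, props.minor
    except AssertionError:
        logging.warning("Torch was not compiled with CUDA support")
        return None


print(cuda_compute_capability())  # e.g. (8, 6) on an RTX 30-series GPU, None on CPU-only torch
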
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import logging
-import sys
+import platform
+import sys
 import warnings
 from enum import Enum
 from threading import RLock
@@ -49,7 +49,7 @@ if args.deterministic:
 
 directml_device = None
 if args.directml is not None:
-    import torch_directml # pylint: disable=import-error
+    import torch_directml # pylint: disable=import-error
 
     device_index = args.directml
     if device_index < 0:
@@ -61,7 +61,7 @@ if args.directml is not None:
     lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
-    import intel_extension_for_pytorch as ipex # pylint: disable=import-error
+    import intel_extension_for_pytorch as ipex # pylint: disable=import-error
 
     if torch.xpu.is_available():
         xpu_available = True
@@ -158,8 +158,8 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers # pylint: disable=import-error
-        import xformers.ops # pylint: disable=import-error
+        import xformers # pylint: disable=import-error
+        import xformers.ops # pylint: disable=import-error
 
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -186,6 +186,7 @@ def is_nvidia():
         return True
     return False
 
+
 def is_amd():
     global cpu_state
     if cpu_state == CPUState.GPU:
@@ -193,6 +194,7 @@ def is_amd():
         return True
     return False
 
+
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
@@ -287,7 +289,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
-
 
 current_loaded_models: List["LoadedModel"] = []
 
@@ -506,6 +507,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])
 
+
 def loaded_models(only_currently_used=False):
     with model_management_lock:
         output = []
@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                 output.append(m.model)
         return output
 
+
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -704,14 +707,15 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False
 
-def supports_cast(device, dtype): #TODO
+
+def supports_cast(device, dtype): # TODO
     if dtype == torch.float32:
         return True
     if dtype == torch.float16:
         return True
     if is_device_mps(device):
         return False
-    if directml_device: #TODO: test this
+    if directml_device: # TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -721,28 +725,31 @@ def supports_cast(device, dtype): #TODO
         return True
     return False
 
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False # pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
         return False
-    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
+    if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_device:
         return False
     return True
 
+
 def device_should_use_non_blocking(device):
     if not device_supports_non_blocking(device):
         return False
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
 
+
 def force_channels_last():
     if args.force_channels_last:
         return True
 
-    #TODO
+    # TODO
     return False
 
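
For context on the hunk above: the non_blocking flag these helpers gate is the one ultimately passed to Tensor.to when weights and latents are moved between devices. A rough sketch of the assumed call pattern (the function name is illustrative):

import torch


def move_tensor(tensor: torch.Tensor, device: torch.device, non_blocking: bool) -> torch.Tensor:
    # Per the helpers above, non_blocking is turned off for MPS, XPU, DirectML and
    # deterministic mode; elsewhere it permits asynchronous host<->CUDA copies.
    return tensor.to(device, non_blocking=non_blocking)
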
@@ -805,10 +812,11 @@ def pytorch_attention_flash_attention():
         return True
     return False
 
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
             upcast = True
     except:
         pass
@@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
     else:
         return None
 
+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_device
     if dev is None:
@@ -906,12 +915,15 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_amd():
         return True
-    props = torch.cuda.get_device_properties("cuda")
-    if props.major >= 8:
-        return True
-
-    if props.major < 6:
-        return False
+    try:
+        props = torch.cuda.get_device_properties("cuda")
+        if props.major >= 8:
+            return True
+        if props.major < 6:
+            return False
+    except AssertionError:
+        logging.warning("Torch was not compiled with cuda support")
+        return False
 
     fp16_works = False
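
In the hunk above, props.major is the CUDA compute capability major version: 8.x (Ampere) and newer always allow fp16, anything below 6.x (older than Pascal) never does, and the 6.x/7.x range is left to later checks in should_use_fp16. A tiny sketch of just that threshold logic (the helper itself is illustrative):

def fp16_verdict(major: int) -> str:
    # Mirrors the compute-capability thresholds in the diff above.
    if major >= 8:
        return "fp16"
    if major < 6:
        return "fp32 only"
    return "decided by later checks in should_use_fp16"
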
@@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if device is None:
         device = torch.device("cuda")
 
-    props = torch.cuda.get_device_properties(device)
-    if props.major >= 8:
-        return True
+    try:
+        props = torch.cuda.get_device_properties(device)
+        if props.major >= 8:
+            return True
+    except AssertionError:
+        logging.warning("Torch was not compiled with CUDA support")
+        return False
 
     bf16_works = torch.cuda.is_bf16_supported()
 
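
With both guards in place, a CPU-only torch wheel no longer aborts these precision checks: when the wrapped torch.cuda.get_device_properties call raises, the functions log the warning and return False, so callers drop back to fp32. A usage sketch, assuming the module keeps the upstream ComfyUI path comfy.model_management and the signatures shown in the hunk headers:

import comfy.model_management as mm

# On a torch build without CUDA these no longer raise AssertionError out of the
# CUDA property query; that path now logs a warning and returns False.
print(mm.should_use_fp16(model_params=1_000_000_000))
print(mm.should_use_bf16())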