Mirror of https://github.com/comfyanonymous/ComfyUI.git
Handle CPU torch more gracefully

parent cf2eaedc5b
commit b6b97574dc
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import logging
-import sys
 import platform
+import sys
 import warnings
 from enum import Enum
 from threading import RLock
@@ -49,7 +49,7 @@ if args.deterministic:
 
 directml_device = None
 if args.directml is not None:
     import torch_directml # pylint: disable=import-error
 
     device_index = args.directml
     if device_index < 0:
@@ -61,7 +61,7 @@ if args.directml is not None:
     lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
     import intel_extension_for_pytorch as ipex # pylint: disable=import-error
 
     if torch.xpu.is_available():
         xpu_available = True
@@ -158,8 +158,8 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
         import xformers # pylint: disable=import-error
         import xformers.ops # pylint: disable=import-error
 
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -186,6 +186,7 @@ def is_nvidia():
             return True
     return False
 
+
 def is_amd():
     global cpu_state
     if cpu_state == CPUState.GPU:
@@ -193,6 +194,7 @@ def is_amd():
             return True
     return False
 
+
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
@@ -287,7 +289,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
-
 current_loaded_models: List["LoadedModel"] = []
 
 
@@ -506,6 +507,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])
 
+
 def loaded_models(only_currently_used=False):
     with model_management_lock:
         output = []
@@ -517,6 +519,7 @@ def loaded_models(only_currently_used=False):
                 output.append(m.model)
         return output
 
+
 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -704,14 +707,15 @@ def supports_dtype(device, dtype): # TODO
         return True
     return False
 
-def supports_cast(device, dtype): #TODO
+
+def supports_cast(device, dtype): # TODO
     if dtype == torch.float32:
         return True
     if dtype == torch.float16:
         return True
     if is_device_mps(device):
         return False
-    if directml_device: #TODO: test this
+    if directml_device: # TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -721,28 +725,31 @@ def supports_cast(device, dtype): #TODO
         return True
     return False
 
+
 def device_supports_non_blocking(device):
     if is_device_mps(device):
         return False # pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
         return False
-    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
+    if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
     if directml_device:
         return False
     return True
 
+
 def device_should_use_non_blocking(device):
     if not device_supports_non_blocking(device):
         return False
     return False
     # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
 
+
 def force_channels_last():
     if args.force_channels_last:
         return True
 
-    #TODO
+    # TODO
     return False
 
 
@@ -805,10 +812,11 @@ def pytorch_attention_flash_attention():
         return True
     return False
 
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
             upcast = True
     except:
         pass
@@ -817,6 +825,7 @@ def force_upcast_attention_dtype():
     else:
         return None
 
+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_device
     if dev is None:
@@ -906,12 +915,15 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_amd():
         return True
 
-    props = torch.cuda.get_device_properties("cuda")
-    if props.major >= 8:
-        return True
-
-    if props.major < 6:
-        return False
+    try:
+        props = torch.cuda.get_device_properties("cuda")
+        if props.major >= 8:
+            return True
+        if props.major < 6:
+            return False
+    except AssertionError:
+        logging.warning("Torch was not compiled with cuda support")
+        return False
 
     fp16_works = False
@@ -967,9 +979,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if device is None:
         device = torch.device("cuda")
 
-    props = torch.cuda.get_device_properties(device)
-    if props.major >= 8:
-        return True
+    try:
+        props = torch.cuda.get_device_properties(device)
+        if props.major >= 8:
+            return True
+    except AssertionError:
+        logging.warning("Torch was not compiled with CUDA support")
+        return False
 
     bf16_works = torch.cuda.is_bf16_supported()
 
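Both should_use_fp16 and should_use_bf16 now use the same guard: on a CPU-only torch build, torch.cuda.get_device_properties raises an AssertionError rather than returning device properties, so the query is wrapped in try/except and the caller falls back to False. A minimal standalone sketch of that pattern, using a hypothetical helper name (cuda_major_compute_capability) that is not part of the patched module:

import logging
from typing import Optional

import torch


def cuda_major_compute_capability() -> Optional[int]:
    # Hypothetical helper, not part of the patched module: returns the major
    # CUDA compute capability, or None when torch was built without CUDA
    # support (a CPU-only wheel raises AssertionError from this query).
    try:
        props = torch.cuda.get_device_properties("cuda")
    except AssertionError:
        logging.warning("Torch was not compiled with CUDA support")
        return None
    return props.major


if __name__ == "__main__":
    major = cuda_major_compute_capability()
    if major is None:
        print("CPU-only torch build; fp16/bf16 autodetection would return False")
    else:
        print(f"CUDA compute capability (major): {major}")

In the patched functions the except branch simply logs the warning and returns False, so fp16/bf16 autodetection is disabled instead of crashing on CPU-only installs.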