Update model_management.py

patientx 2025-02-14 12:30:19 +03:00 committed by GitHub
parent 4d66aa9709
commit 99d2824d5a

model_management.py

@@ -25,7 +25,6 @@ import sys
 import platform
 import weakref
 import gc
-import zluda

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -237,10 +236,23 @@ try:
 except:
     pass

-#if ENABLE_PYTORCH_ATTENTION:
-#    torch.backends.cuda.enable_math_sdp(True)
-#    torch.backends.cuda.enable_flash_sdp(True)
-#    torch.backends.cuda.enable_mem_efficient_sdp(True)
+try:
+    if is_amd():
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+        logging.info("AMD arch: {}".format(arch))
+        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
+                if arch in ["gfx1100"]: #TODO: more arches
+                    ENABLE_PYTORCH_ATTENTION = True
+except:
+    pass
+
+if ENABLE_PYTORCH_ATTENTION:
+    torch.backends.cuda.enable_math_sdp(True)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)

 try:
     if is_nvidia() and args.fast:
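
Note (not part of the commit): the three torch.backends.cuda toggles enabled above control which kernels torch.nn.functional.scaled_dot_product_attention may dispatch to. A minimal standalone sketch of the effect, assuming a PyTorch 2.x CUDA or ROCm build with a visible GPU:

    import torch
    import torch.nn.functional as F

    # Mirror the flags the commit turns on once ENABLE_PYTORCH_ATTENTION is set.
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)

    # Shapes are (batch, heads, seq_len, head_dim); fp16 keeps the flash and
    # memory-efficient kernels eligible for dispatch.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # PyTorch picks whichever enabled backend supports these inputs.
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 128, 64])
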
@@ -307,6 +319,7 @@ try:
 except:
     logging.warning("Could not pick default device.")

 current_loaded_models = []

 def module_size(module):
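
A caveat on the version gate in the new hunk (illustrative sketch, not part of the commit): torch_version is indexed as a string, which only works while both version components are single digits. The values below are hypothetical.

    # torch_version in model_management.py is a version string, e.g. "2.7.0".
    torch_version = "2.7.0"
    arch = "gfx1100"  # e.g. torch.cuda.get_device_properties(dev).gcnArchName on ROCm

    # Character indexing: torch_version[0] == "2", torch_version[2] == "7".
    # This misreads two-digit components: "2.10.0"[2] == "1", so 2.10 would fail the gate.
    enable = int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7 and arch in ["gfx1100"]
    print(enable)  # True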