diff --git a/comfy/model_management.py b/comfy/model_management.py
index e134440f5..dd8a2a28f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -25,7 +25,6 @@ import sys
 import platform
 import weakref
 import gc
-import zluda
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -237,10 +236,23 @@ try:
 except:
     pass
 
-#if ENABLE_PYTORCH_ATTENTION:
-#    torch.backends.cuda.enable_math_sdp(True)
-#    torch.backends.cuda.enable_flash_sdp(True)
-#    torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+try:
+    if is_amd():
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+        logging.info("AMD arch: {}".format(arch))
+        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if int(torch_version[0]) >= 2 and int(torch_version[2]) >= 7: # works on 2.6 but doesn't actually seem to improve much
+                if arch in ["gfx1100"]: #TODO: more arches
+                    ENABLE_PYTORCH_ATTENTION = True
+except:
+    pass
+
+
+if ENABLE_PYTORCH_ATTENTION:
+    torch.backends.cuda.enable_math_sdp(True)
+    torch.backends.cuda.enable_flash_sdp(True)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)
 
 try:
     if is_nvidia() and args.fast:
@@ -307,6 +319,7 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
+
 current_loaded_models = []
 
 def module_size(module):
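
For review context, a minimal standalone sketch of the gating logic the second hunk introduces, assuming a ROCm build of PyTorch in which torch.cuda.get_device_properties() exposes gcnArchName. The helper name maybe_enable_pytorch_attention, the device_index parameter, and the (major, minor) version parsing are illustrative and not part of the patch.

# Sketch only: mirrors the gating logic of the hunk above as a self-contained helper.
import logging

import torch

def maybe_enable_pytorch_attention(device_index=0):
    # Hypothetical helper, not part of the patch: returns True if the SDP
    # attention backends were enabled for the given AMD GPU.
    if torch.version.hip is None:
        return False  # not a ROCm build, so not an AMD device in this sense
    try:
        arch = torch.cuda.get_device_properties(device_index).gcnArchName
    except AttributeError:
        return False  # older builds may not expose gcnArchName
    logging.info("AMD arch: %s", arch)
    major, minor = (int(x) for x in torch.__version__.split(".")[:2])
    if (major, minor) < (2, 7):  # per the patch comment: works on 2.6 but little gain
        return False
    if arch.split(":")[0] not in ("gfx1100",):  # patch TODO: more arches
        return False
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    return True

Unlike the patch's character-indexing check (int(torch_version[0]) and int(torch_version[2])), the sketch compares a parsed (major, minor) tuple, which also reads versions such as 2.10 correctly; whether that matters here depends on how torch_version is produced upstream.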