Compare commits

3 Commits

Kacper Michajłow · 3134211aa6 · 2025-12-13 11:33:19 -06:00
Merge 02d15cc85f into da2bfb5b0a

Kacper Michajłow · 02d15cc85f · 2025-10-21 12:49:21 +02:00
Enable pytorch attention by default on AMD gfx1200

Kacper Michajłow · 9519e2d49d · 2025-10-21 12:48:34 +02:00
Revert "Disable pytorch attention in VAE for AMD."

The VAE crashes at big sizes even without pytorch attention, and at
reasonable sizes pytorch attention is significantly faster.

This reverts commit 1cd6cd6080.
2 changed files with 3 additions and 8 deletions

@@ -335,7 +335,7 @@ def vae_attention():
     if model_management.xformers_enabled_vae():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
-    elif model_management.pytorch_attention_enabled_vae():
+    elif model_management.pytorch_attention_enabled():
         logging.info("Using pytorch attention in VAE")
         return pytorch_attention
     else:
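
For context, "pytorch attention" in this code path refers to torch's built-in scaled dot product attention. A minimal sketch of what the pytorch_attention backend selected above boils down to (the tensor layout and the helper name are illustrative assumptions, not this project's exact implementation):

import torch.nn.functional as F

def pytorch_attention_sketch(q, k, v):
    # q, k, v: (batch, heads, tokens, dim_head). SDPA dispatches to a fused
    # flash / memory-efficient / math kernel depending on platform support,
    # which is why it is enabled per-architecture on ROCm.
    return F.scaled_dot_product_attention(q, k, v)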

@@ -354,8 +354,8 @@ try:
         if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
             ENABLE_PYTORCH_ATTENTION = True
         if rocm_version >= (7, 0):
-            if any((a in arch) for a in ["gfx1201"]):
+            if any((a in arch) for a in ["gfx1200", "gfx1201"]):
                 ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                 SUPPORT_FP8_OPS = True
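
The arch value these substring checks run against is not shown in this diff; presumably it is the GPU architecture string reported by the ROCm build of torch. A hedged sketch of how the new gfx1200/gfx1201 gate could be reproduced on a given machine (the gcnArchName lookup and the rocm_version tuple are assumptions about the surrounding code, not part of this change):

import torch

def rdna4_pytorch_attention_default_sketch(rocm_version):
    # Assumes a ROCm build of torch where device properties expose gcnArchName
    # (e.g. "gfx1200"/"gfx1201" on RDNA4 GPUs). Returns whether pytorch
    # attention would be enabled by default after this change.
    props = torch.cuda.get_device_properties(0)
    arch = getattr(props, "gcnArchName", "")
    return rocm_version >= (7, 0) and any(a in arch for a in ["gfx1200", "gfx1201"])
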
@@ -1221,11 +1221,6 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
-def pytorch_attention_enabled_vae():
-    if is_amd():
-        return False # enabling pytorch attention on AMD currently causes crash when doing high res
-    return pytorch_attention_enabled()
-
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION: