Compare commits

...

3 Commits

Author SHA1 Message Date
Kacper Michajłow
3134211aa6
Merge 02d15cc85f into da2bfb5b0a 2025-12-13 11:33:19 -06:00
Kacper Michajłow
02d15cc85f
Enable pytorch attention by default on AMD gfx1200 2025-10-21 12:49:21 +02:00
Kacper Michajłow
9519e2d49d
Revert "Disable pytorch attention in VAE for AMD."
It causes crashes even without pytorch attention for big sizes, and for
resonable sizes it is significantly faster.

This reverts commit 1cd6cd6080.
2025-10-21 12:48:34 +02:00
2 changed files with 3 additions and 8 deletions

View File

@ -335,7 +335,7 @@ def vae_attention():
if model_management.xformers_enabled_vae():
logging.info("Using xformers attention in VAE")
return xformers_attention
elif model_management.pytorch_attention_enabled_vae():
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention in VAE")
return pytorch_attention
else:

View File

@ -354,8 +354,8 @@ try:
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
SUPPORT_FP8_OPS = True
@ -1221,11 +1221,6 @@ def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_enabled_vae():
if is_amd():
return False # enabling pytorch attention on AMD currently causes crash when doing high res
return pytorch_attention_enabled()
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION: