Enable sage attention autodetection

doctorpangloss 2024-10-09 09:27:05 -07:00
parent 388dad67d5
commit 99f0fa8b50


@@ -420,7 +420,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 @pytorch_style_decl
-def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sageattn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) # pylint: disable=possibly-used-before-assignment
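
Note: the pylint suppression on the return line points at a guarded import of the optional SageAttention package. The sketch below shows that pattern under stated assumptions: the sageattention package and its sageattn function are real, but the try/except wrapper is an assumption, not this repository's exact code.

# Minimal sketch (assumption, not this repository's exact code) of the guarded
# import implied by the "possibly-used-before-assignment" suppression above:
# the name sageattn only exists once the optional package imports successfully.
try:
    from sageattention import sageattn  # public entry point of the SageAttention package
except ImportError:
    sageattn = None  # the backend selection below then never picks attention_sageattn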
@@ -431,7 +431,10 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 optimized_attention = attention_basic
-if model_management.xformers_enabled():
+if model_management.sage_attention_enabled():
+    logging.debug("Using sage attention")
+    optimized_attention = attention_sageattn
+elif model_management.xformers_enabled():
     logging.debug("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
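
Note: model_management.sage_attention_enabled() is the autodetection gate this commit puts at the front of the backend chain; its implementation is not shown in the diff. Below is a hedged sketch of what such a check could look like, assuming "autodetection" means the optional sageattention package is importable and CUDA is present; only the function name comes from the diff, the body is an assumption.

import importlib.util

import torch


def sage_attention_enabled() -> bool:
    # Assumed autodetection: prefer the sage backend only when the optional
    # sageattention package can be imported and a CUDA device is available.
    return torch.cuda.is_available() and importlib.util.find_spec("sageattention") is not None

Whichever branch wins, callers keep calling the selected optimized_attention(q, k, v, heads, mask=mask) entry point; only the backend behind it changes.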