diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 51004fe6b..a25421750 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -420,7 +420,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 @pytorch_style_decl
-def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sageattn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 # pylint: disable=possibly-used-before-assignment
@@ -431,7 +431,10 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 optimized_attention = attention_basic
-if model_management.xformers_enabled():
+if model_management.sage_attention_enabled():
+    logging.debug("Using sage attention")
+    optimized_attention = attention_sageattn
+elif model_management.xformers_enabled():
     logging.debug("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
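
Note (not part of the patch): a minimal sketch of how the new backend selection behaves at import time. It assumes only that the sage-attention flag reflects whether the SageAttention package is installed and enabled; the helper names below (`pick_attention`, `_sdpa_fallback`) are illustrative, not ComfyUI API, and the `sageattention` import name is an assumption.

    import logging
    import torch.nn.functional as F

    def _sdpa_fallback(q, k, v, mask=None):
        # Reference path: PyTorch SDPA on (batch, heads, seq_len, head_dim) tensors.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    def pick_attention(sage_enabled: bool):
        # Mirrors the if/elif chain added by the patch: try SageAttention first,
        # otherwise fall back to another backend.
        if sage_enabled:
            try:
                from sageattention import sageattn  # import name assumed
                logging.debug("Using sage attention")
                return lambda q, k, v, mask=None: sageattn(
                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
            except ImportError:
                pass
        logging.debug("Falling back to pytorch attention")
        return _sdpa_fallback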