From 99f0fa8b5045614e4716a69806b5a43155aecefe Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Wed, 9 Oct 2024 09:27:05 -0700
Subject: [PATCH] Enable sage attention autodetection

---
 comfy/ldm/modules/attention.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 51004fe6b..a25421750 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -420,7 +420,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 
 @pytorch_style_decl
-def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+def attention_sageattn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
 # pylint: disable=possibly-used-before-assignment
@@ -431,7 +431,10 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 
 optimized_attention = attention_basic
 
-if model_management.xformers_enabled():
+if model_management.sage_attention_enabled():
+    logging.debug("Using sage attention")
+    optimized_attention = attention_sageattn
+elif model_management.xformers_enabled():
     logging.debug("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
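
Note (not part of the patch): the sage_attention_enabled() check called above would live in comfy.model_management, which this diff does not touch. Below is a minimal sketch of how that autodetection could look, assuming a use-sage-attention / disable-sage-attention pair of CLI flags and an import probe for the sageattention package; the flag names, the helper, and the package name are illustrative assumptions, not taken from this patch.

    # Hypothetical sketch of the model_management side of the autodetection.
    # The flag names and the import probe are assumptions, not part of this patch.
    import importlib.util

    from comfy.cli_args import args  # assumed location of the parsed CLI arguments


    def _sageattn_available() -> bool:
        # Check whether the sageattention package is installed without importing it.
        return importlib.util.find_spec("sageattention") is not None


    def sage_attention_enabled() -> bool:
        # Honor explicit flags first, then fall back to autodetection:
        # a plain install of sageattention is enough to switch the default.
        if getattr(args, "disable_sage_attention", False):
            return False
        if getattr(args, "use_sage_attention", False):
            return True
        return _sageattn_available()

Checking the explicit flags before the import probe keeps the existing opt-in/opt-out behaviour intact while letting an installed sageattention package change the default, which is what the "autodetection" wording in the subject line suggests.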