From 7883076f5cc9ad77c567b2013b3b79483348ae61 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Thu, 11 Dec 2025 14:16:47 -0800
Subject: [PATCH] Remove the Sage Attention3 switch, but retain the attention function registration.

---
 comfy/cli_args.py              |  1 -
 comfy/ldm/modules/attention.py | 20 ++++++--------------
 comfy/model_management.py      |  3 ---
 3 files changed, 6 insertions(+), 18 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 6becebcb5..209fc185b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,7 +112,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
-attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index d23a753f9..e9e93ec72 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -35,12 +35,7 @@ try:
     from sageattn3 import sageattn3_blackwell
     SAGE_ATTENTION3_IS_AVAILABLE = True
 except ImportError as e:
-    if model_management.sage_attention3_enabled():
-        if e.name == "sageattn3":
-            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
-        else:
-            raise e
-        exit(-1)
+    pass
 
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
@@ -589,7 +584,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )
-    
+
     if skip_reshape:
         B, H, L, D = q.shape
         if H != heads:
@@ -626,16 +621,16 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )
-    
+
     if not skip_reshape:
         q_s, k_s, v_s = map(
            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
            (q, k, v),
        )
     B, H, L, D = q_s.shape
-    
+
     try:
-        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)    
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
         exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
@@ -655,7 +650,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     if skip_reshape:
         if not skip_output_reshape:
             out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
-    else:    
+    else:
         if skip_output_reshape:
             pass
         else:
@@ -726,9 +721,6 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
-if model_management.sage_attention3_enabled():
-    logging.info("Using sage attention 3")
-    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
diff --git a/comfy/model_management.py b/comfy/model_management.py
index c971dd95f..aeddbaefe 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,9 +1189,6 @@ def unpin_memory(tensor):
 
 def sage_attention_enabled():
     return args.use_sage_attention
 
-def sage_attention3_enabled():
-    return args.use_sage_attention3
-
 def flash_attention_enabled():
     return args.use_flash_attention
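
Note (reviewer aid, not part of the patch): with the `--use-sage-attention3`
switch removed, the retained `attention3_sage` wrapper in
comfy/ldm/modules/attention.py is no longer reachable from the command line,
but it can still be called directly. The snippet below is a minimal usage
sketch, assuming the `sageattn3` package and a supported Blackwell GPU are
available; the tensor shapes and dtype are illustrative assumptions, not part
of this change.

    import torch
    from comfy.ldm.modules.attention import (
        SAGE_ATTENTION3_IS_AVAILABLE,
        attention3_sage,
    )

    if SAGE_ATTENTION3_IS_AVAILABLE:
        # 8 heads of dim 64, flattened to (batch, sequence, heads * dim_head);
        # assumed layout for the default skip_reshape=False path.
        q = torch.randn(1, 4096, 8 * 64, device="cuda", dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # Per the retained wrapper, any error from the SageAttention3 kernel
        # logs a message and falls back to pytorch attention.
        out = attention3_sage(q, k, v, heads=8)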