From 4bdb0dddb74c8276a420eafd85a8a010d9508617 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Mon, 1 Dec 2025 09:54:19 -0800
Subject: [PATCH 1/3] Add support for sage attention 3 in ComfyUI, enable via new CLI arg --use-sage-attention3

---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 100 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 5f0dfaa10..a3c4a6bc6 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,6 +112,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..2dfc55bfa 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -30,6 +30,18 @@ except ImportError as e:
             raise e
         exit(-1)
 
+SAGE_ATTENTION3_IS_AVAILABLE = False
+try:
+    from sageattn3 import sageattn3_blackwell
+    SAGE_ATTENTION3_IS_AVAILABLE = True
+except ImportError as e:
+    if model_management.sage_attention3_enabled():
+        if e.name == "sageattn3":
+            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
+        else:
+            raise e
+        exit(-1)
+
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
     from flash_attn import flash_attn_func
@@ -560,6 +572,85 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     out = out.reshape(b, -1, heads * dim_head)
     return out
 
+@wrap_attn
+def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    if (q.device.type != "cuda" or
+        q.dtype not in (torch.float16, torch.bfloat16) or
+        mask is not None):
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        B, H, L, D = q.shape
+        if H != heads:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=True,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        q_s, k_s, v_s = q, k, v
+        N = q.shape[2]
+    else:
+        B, N, inner_dim = q.shape
+        if inner_dim % heads != 0:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=False,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        dim_head = inner_dim // heads
+
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
+    if dim_head >= 256 or N <= 2048:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=False,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    try:
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
+    except Exception as e:
+        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=False,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        if not skip_output_reshape:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+    else:
+        if skip_output_reshape:
+            pass
+        else:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+    return out
 
 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -624,6 +715,9 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+if model_management.sage_attention3_enabled():
+    logging.info("Using sage attention 3")
+    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
@@ -647,6 +741,8 @@ optimized_attention_masked = optimized_attention
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if SAGE_ATTENTION3_IS_AVAILABLE:
+    register_attention_function("sage3", attention3_sage)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
diff --git a/comfy/model_management.py b/comfy/model_management.py
index aeddbaefe..c971dd95f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def sage_attention3_enabled():
+    return args.use_sage_attention3
+
 def flash_attention_enabled():
     return args.use_flash_attention
 

From 648814b7516b369bf5b57a75887c4a3be43228b7 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Tue, 2 Dec 2025 08:19:40 -0800
Subject: [PATCH 2/3] Fix some bugs found in PR review. The sequence-length (N) threshold at which Sage Attention 3 takes effect is reduced to 1024 (although the improvement is not significant at that scale).
---
 comfy/ldm/modules/attention.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2dfc55bfa..d51e49da2 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -574,6 +574,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 
 @wrap_attn
 def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if (q.device.type != "cuda" or
         q.dtype not in (torch.float16, torch.bfloat16) or
         mask is not None):
@@ -599,6 +600,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             )
         q_s, k_s, v_s = q, k, v
         N = q.shape[2]
+        dim_head = D
     else:
         B, N, inner_dim = q.shape
         if inner_dim % heads != 0:
@@ -611,27 +613,33 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             **kwargs
         )
         dim_head = inner_dim // heads
-
+
+    if dim_head >= 256 or N <= 1024:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_reshape:
         q_s, k_s, v_s = map(
             lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
             (q, k, v),
         )
         B, H, L, D = q_s.shape
-
-    if dim_head >= 256 or N <= 2048:
-        return attention_pytorch(
-            q, k, v, heads,
-            mask=mask,
-            attn_precision=attn_precision,
-            skip_reshape=False,
-            skip_output_reshape=skip_output_reshape,
-            **kwargs
-        )
 
     try:
         out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
+        exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
         return attention_pytorch(
             q, k, v, heads,
             mask=mask,

From 7883076f5cc9ad77c567b2013b3b79483348ae61 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Thu, 11 Dec 2025 14:16:47 -0800
Subject: [PATCH 3/3] Remove the Sage Attention 3 CLI switch, but retain the attention function registration.

---
 comfy/cli_args.py              |  1 -
 comfy/ldm/modules/attention.py | 20 ++++++--------------
 comfy/model_management.py      |  3 ---
 3 files changed, 6 insertions(+), 18 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 6becebcb5..209fc185b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,7 +112,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
-attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index d23a753f9..e9e93ec72 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -35,12 +35,7 @@ try:
     from sageattn3 import sageattn3_blackwell
     SAGE_ATTENTION3_IS_AVAILABLE = True
 except ImportError as e:
-    if model_management.sage_attention3_enabled():
-        if e.name == "sageattn3":
-            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
-        else:
-            raise e
-        exit(-1)
+    pass
 
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
@@ -589,7 +584,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )
-    
+
     if skip_reshape:
         B, H, L, D = q.shape
         if H != heads:
@@ -626,16 +621,16 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )
-    
+
     if not skip_reshape:
         q_s, k_s, v_s = map(
             lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
             (q, k, v),
         )
         B, H, L, D = q_s.shape
-    
+
     try:
-        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) 
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
         exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
@@ -655,7 +650,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     if skip_reshape:
         if not skip_output_reshape:
             out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
-    else: 
+    else:
         if skip_output_reshape:
             pass
         else:
@@ -726,9 +721,6 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
-if model_management.sage_attention3_enabled():
-    logging.info("Using sage attention 3")
-    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
diff --git a/comfy/model_management.py b/comfy/model_management.py
index c971dd95f..aeddbaefe 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,9 +1189,6 @@ def unpin_memory(tensor):
 def sage_attention_enabled():
     return args.use_sage_attention
 
-def sage_attention3_enabled():
-    return args.use_sage_attention3
-
 def flash_attention_enabled():
     return args.use_flash_attention
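
Note: after PATCH 3/3 the backend is no longer selected by a CLI flag; it is only reachable through the attention function registered as "sage3" via register_attention_function. The following is a minimal, standalone sketch of the dispatch-and-fallback pattern that attention3_sage implements, written outside ComfyUI for illustration. It assumes the `sageattn3` package and a supported (Blackwell) GPU are available; the function name attention_with_sage3_fallback and the __main__ block are hypothetical and not part of the patch.

# Standalone sketch (not ComfyUI code): try SageAttention3 on CUDA fp16/bf16
# tensors in (B, H, L, D) layout, otherwise fall back to PyTorch SDPA.
import torch
import torch.nn.functional as F

try:
    from sageattn3 import sageattn3_blackwell
    SAGE3_OK = True
except ImportError:
    SAGE3_OK = False

def attention_with_sage3_fallback(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    use_sage3 = (
        SAGE3_OK
        and q.is_cuda
        and q.dtype in (torch.float16, torch.bfloat16)
        and q.shape[-1] < 256     # head-dim limit mirrored from the patch
        and q.shape[-2] > 1024    # sequence-length threshold from PATCH 2/3
    )
    if use_sage3:
        try:
            # call signature as used in the patch
            return sageattn3_blackwell(q, k, v, is_causal=False)
        except Exception:
            pass  # fall through to the PyTorch path, as attention3_sage does
    return F.scaled_dot_product_attention(q, k, v)

if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
    out = attention_with_sage3_fallback(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 4096, 64])

The eager fallback mirrors the patch's design choice: SageAttention3 is treated as an opportunistic fast path for large sequence lengths, and any failure or unsupported shape silently routes through the stock PyTorch attention so generation never breaks.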