From 4bdb0dddb74c8276a420eafd85a8a010d9508617 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Mon, 1 Dec 2025 09:54:19 -0800
Subject: [PATCH] Add support for sage attention 3 in comfyui, enable via new
 cli arg --use-sage-attention3

---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 100 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 5f0dfaa10..a3c4a6bc6 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,6 +112,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..2dfc55bfa 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -30,6 +30,18 @@ except ImportError as e:
             raise e
         exit(-1)
 
+SAGE_ATTENTION3_IS_AVAILABLE = False
+try:
+    from sageattn3 import sageattn3_blackwell
+    SAGE_ATTENTION3_IS_AVAILABLE = True
+except ImportError as e:
+    if model_management.sage_attention3_enabled():
+        if e.name == "sageattn3":
+            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
+        else:
+            raise e
+        exit(-1)
+
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
     from flash_attn import flash_attn_func
@@ -560,6 +572,85 @@
     out = out.reshape(b, -1, heads * dim_head)
     return out
 
+@wrap_attn
+def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    # SageAttention3 only handles fp16/bf16 CUDA tensors and has no mask support,
+    # so fall back to the pytorch implementation for anything else.
+    if (q.device.type != "cuda" or
+        q.dtype not in (torch.float16, torch.bfloat16) or
+        mask is not None):
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        B, H, L, D = q.shape
+        if H != heads:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=True,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        q_s, k_s, v_s = q, k, v
+        N = L
+        dim_head = D
+    else:
+        B, N, inner_dim = q.shape
+        if inner_dim % heads != 0:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=False,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        dim_head = inner_dim // heads
+
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
+    # Only use the SageAttention3 kernel for head dims below 256 and sequences
+    # longer than 2048 tokens; fall back to pytorch attention otherwise.
+    if dim_head >= 256 or N <= 2048:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    try:
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
+    except Exception as e:
+        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    # Convert (B, H, L, D) back to (B, L, H * D) unless the caller asked for the head-separated layout.
+    if not skip_output_reshape:
+        out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+    return out

 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -624,6 +715,9 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+elif model_management.sage_attention3_enabled():
+    logging.info("Using sage attention 3")
+    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
@@ -647,6 +741,8 @@ optimized_attention_masked = optimized_attention
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if SAGE_ATTENTION3_IS_AVAILABLE:
+    register_attention_function("sage3", attention3_sage)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
diff --git a/comfy/model_management.py b/comfy/model_management.py
index aeddbaefe..c971dd95f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def sage_attention3_enabled():
+    return args.use_sage_attention3
+
 def flash_attention_enabled():
     return args.use_flash_attention
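
Not part of the patch: a minimal smoke-test sketch for the new code path. It assumes a Blackwell GPU with the `sageattn3` package installed, fp16 CUDA tensors in the (batch, tokens, heads * dim_head) layout used above, and that the `@wrap_attn`-wrapped function can be called directly:

    import torch
    from comfy.ldm.modules.attention import attention3_sage

    # dim_head = 64 < 256 and tokens = 4096 > 2048, so the SageAttention3 kernel
    # is exercised rather than the pytorch fallback.
    heads, dim_head, tokens = 8, 64, 4096
    q = torch.randn(1, tokens, heads * dim_head, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention3_sage(q, k, v, heads)
    print(out.shape)  # expected: torch.Size([1, 4096, 512])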