From 897d2662be3d593c78ee88a3eae24601cfef5499 Mon Sep 17 00:00:00 2001 From: vivienfanghua Date: Tue, 28 Oct 2025 19:13:17 +0800 Subject: [PATCH 1/5] add --- comfy/cli_args.py | 1 + comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++ comfy/model_management.py | 3 ++ 3 files changed, 100 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482..d4219cdf3 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -111,6 +111,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") +attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.") attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7437e0567..76bab22b0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -39,6 +39,15 @@ except ImportError: logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +AITER_ATTENTION_IS_AVAILABLE = False +try: + import aiter + AITER_ATTENTION_IS_AVAILABLE = True +except ImportError: + if model_management.aiter_attention_enabled(): + logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter") + exit(-1) + REGISTERED_ATTENTION_FUNCTIONS = {} def register_attention_function(name: str, func: Callable): # avoid replacing existing functions @@ -619,11 +628,96 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape return out +try: + @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=()) + def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, softmax_scale: Optional[float] = None, + causal: bool = False, window_size: tuple = (-1, -1), + bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False) -> torch.Tensor: + return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, + causal=causal, window_size=window_size, bias=bias, + alibi_slopes=alibi_slopes, deterministic=deterministic, + return_lse=False, return_attn_probs=False, + cu_seqlens_q=None, cu_seqlens_kv=None) + + + @aiter_flash_attn_wrapper.register_fake + def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False): + # Output shape is the same as q + return q.new_empty(q.shape) +except AttributeError as error: + AITER_ATTN_ERROR = error + + def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, softmax_scale: Optional[float] = None, + causal: bool = False, window_size: tuple = (-1, -1), + bias: Optional[torch.Tensor] = None, alibi_slopes: 
Optional[torch.Tensor] = None, + deterministic: bool = False) -> torch.Tensor: + assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}" + +@wrap_attn +def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + # reshape to (batch, seqlen, nheads, headdim) for aiter + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head), + (q, k, v), + ) + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + try: + # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format + out = aiter_flash_attn_wrapper( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + bias=mask, + alibi_slopes=None, + deterministic=False, + ) + except Exception as e: + logging.warning(f"Aiter Attention failed, using default SDPA: {e}") + # fallback needs (batch, nheads, seqlen, headdim) format + q_sdpa = q.transpose(1, 2) + k_sdpa = k.transpose(1, 2) + v_sdpa = v.transpose(1, 2) + out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = out.transpose(1, 2) + + if skip_output_reshape: + # output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim) + out = out.transpose(1, 2) + else: + # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim) + out = out.reshape(b, -1, heads * dim_head) + return out + + optimized_attention = attention_basic if model_management.sage_attention_enabled(): logging.info("Using sage attention") optimized_attention = attention_sage +elif model_management.aiter_attention_enabled(): + logging.info("Using aiter attention") + optimized_attention = attention_aiter elif model_management.xformers_enabled(): logging.info("Using xformers attention") optimized_attention = attention_xformers @@ -647,6 +741,8 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if AITER_ATTENTION_IS_AVAILABLE: + register_attention_function("aiter", attention_aiter) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): diff --git a/comfy/model_management.py b/comfy/model_management.py index afe78f36e..350bee6c6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1083,6 +1083,9 @@ def cast_to_device(tensor, device, dtype, copy=False): def sage_attention_enabled(): return args.use_sage_attention +def aiter_attention_enabled(): + return args.use_aiter_attention + def flash_attention_enabled(): return args.use_flash_attention From 1c5a503715b3cf1d438f1032d7338bfd203885db Mon Sep 17 00:00:00 2001 From: vivienfanghua Date: Tue, 28 Oct 2025 19:31:23 +0800 Subject: [PATCH 2/5] add --- comfy/ldm/modules/attention.py | 39 ++++++++++------------------------ 1 file changed, 11 insertions(+), 28 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 76bab22b0..f9459c2dc 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -628,34 +628,17 @@ def attention_flash(q, k, v, heads, mask=None, 
attn_precision=None, skip_reshape return out -try: - @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=()) - def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - dropout_p: float = 0.0, softmax_scale: Optional[float] = None, - causal: bool = False, window_size: tuple = (-1, -1), - bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False) -> torch.Tensor: - return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, - causal=causal, window_size=window_size, bias=bias, - alibi_slopes=alibi_slopes, deterministic=deterministic, - return_lse=False, return_attn_probs=False, - cu_seqlens_q=None, cu_seqlens_kv=None) - - - @aiter_flash_attn_wrapper.register_fake - def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, - window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False): - # Output shape is the same as q - return q.new_empty(q.shape) -except AttributeError as error: - AITER_ATTN_ERROR = error - - def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - dropout_p: float = 0.0, softmax_scale: Optional[float] = None, - causal: bool = False, window_size: tuple = (-1, -1), - bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False) -> torch.Tensor: - assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}" +def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, softmax_scale: Optional[float] = None, + causal: bool = False, window_size: tuple = (-1, -1), + bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False) -> torch.Tensor: + """Wrapper for aiter.flash_attn_func to handle its specific parameters""" + return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, + causal=causal, window_size=window_size, bias=bias, + alibi_slopes=alibi_slopes, deterministic=deterministic, + return_lse=False, return_attn_probs=False, + cu_seqlens_q=None, cu_seqlens_kv=None) @wrap_attn def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): From 43de98b30e294c6067272124996a88412419ae73 Mon Sep 17 00:00:00 2001 From: vivienfanghua Date: Wed, 29 Oct 2025 10:46:17 +0800 Subject: [PATCH 3/5] add --- comfy/ldm/modules/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f9459c2dc..1f9c56b34 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -621,6 +621,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape except Exception as e: logging.warning(f"Flash Attention failed, using default SDPA: {e}") out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) From e00688e134d622d8292e71f7b225100c915578a1 Mon Sep 17 00:00:00 2001 From: vivienfanghua Date: Fri, 31 Oct 2025 07:40:58 +0000 Subject: [PATCH 4/5] add --- comfy/ldm/modules/attention.py | 82 ++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 1f9c56b34..43991fd40 100644 --- 
a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -45,7 +45,8 @@ try: AITER_ATTENTION_IS_AVAILABLE = True except ImportError: if model_management.aiter_attention_enabled(): - logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter") + logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.") + logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation") exit(-1) REGISTERED_ATTENTION_FUNCTIONS = {} @@ -629,42 +630,46 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape return out -def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - dropout_p: float = 0.0, softmax_scale: Optional[float] = None, - causal: bool = False, window_size: tuple = (-1, -1), - bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, - deterministic: bool = False) -> torch.Tensor: - """Wrapper for aiter.flash_attn_func to handle its specific parameters""" - return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, - causal=causal, window_size=window_size, bias=bias, - alibi_slopes=alibi_slopes, deterministic=deterministic, - return_lse=False, return_attn_probs=False, - cu_seqlens_q=None, cu_seqlens_kv=None) - @wrap_attn def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + # Store original inputs for fallback + q_orig, k_orig, v_orig, mask_orig = q, k, v, mask + if skip_reshape: b, _, _, dim_head = q.shape + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) else: b, _, dim_head = q.shape dim_head //= heads - # reshape to (batch, seqlen, nheads, headdim) for aiter q, k, v = map( lambda t: t.view(b, -1, heads, dim_head), (q, k, v), ) + + # Convert mask to [sq, sk] format for aiter bias + bias = None if mask is not None: - # add a batch dimension if there isn't already one if mask.ndim == 2: - mask = mask.unsqueeze(0) - # add a heads dimension if there isn't already one - if mask.ndim == 3: - mask = mask.unsqueeze(1) + bias = mask + elif mask.ndim == 3: + seqlen_q = q.shape[1] + if mask.shape[-2] == 1: + # [1, 1, sk] -> expand to [sq, sk] + bias = mask.squeeze(0).expand(seqlen_q, -1) + else: + # [batch, sq, sk] -> take first batch + bias = mask[0] + elif mask.ndim == 4: + # [batch, heads, sq, sk] -> take first batch and head + bias = mask[0, 0] try: # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format - out = aiter_flash_attn_wrapper( + out = aiter.flash_attn_func( q, k, v, @@ -672,26 +677,29 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape softmax_scale=None, causal=False, window_size=(-1, -1), - bias=mask, + bias=bias, alibi_slopes=None, deterministic=False, + return_lse=False, + return_attn_probs=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, ) + + if skip_output_reshape: + # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim) + out = out.transpose(1, 2) + else: + # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim) + out = out.reshape(b, -1, heads * dim_head) + return out + except Exception as e: - logging.warning(f"Aiter Attention failed, using default SDPA: {e}") - # fallback needs (batch, nheads, seqlen, headdim) format - q_sdpa = q.transpose(1, 2) - k_sdpa = 
k.transpose(1, 2) - v_sdpa = v.transpose(1, 2) - out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False) - out = out.transpose(1, 2) - - if skip_output_reshape: - # output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim) - out = out.transpose(1, 2) - else: - # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim) - out = out.reshape(b, -1, heads * dim_head) - return out + logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}") + # Fallback to attention_pytorch with original inputs + return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig, + attn_precision=attn_precision, skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, **kwargs) optimized_attention = attention_basic @@ -1170,5 +1178,3 @@ class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out - - From 0f6f3c93fda308e1c10badccf9498007946682ce Mon Sep 17 00:00:00 2001 From: vivienfanghua Date: Mon, 3 Nov 2025 09:25:04 +0000 Subject: [PATCH 5/5] format --- comfy/ldm/modules/attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 43991fd40..c18266672 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -45,8 +45,8 @@ try: AITER_ATTENTION_IS_AVAILABLE = True except ImportError: if model_management.aiter_attention_enabled(): - logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.") - logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation") + logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.") + logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation") exit(-1) REGISTERED_ATTENTION_FUNCTIONS = {} @@ -634,7 +634,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): # Store original inputs for fallback q_orig, k_orig, v_orig, mask_orig = q, k, v, mask - + if skip_reshape: b, _, _, dim_head = q.shape q, k, v = map( @@ -649,7 +649,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape (q, k, v), ) - + # Convert mask to [sq, sk] format for aiter bias bias = None if mask is not None: @@ -685,7 +685,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape cu_seqlens_q=None, cu_seqlens_kv=None, ) - + if skip_output_reshape: # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim) out = out.transpose(1, 2) @@ -693,12 +693,12 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim) out = out.reshape(b, -1, heads * dim_head) return out - + except Exception as e: logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}") # Fallback to attention_pytorch with original inputs - return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig, - attn_precision=attn_precision, skip_reshape=skip_reshape, + return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig, + 
attn_precision=attn_precision, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
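
Note on the layout handling in attention_aiter introduced above: ComfyUI hands the function q/k/v as (batch, seqlen, heads * dim_head) in the default path, or as (batch, heads, seqlen, dim_head) when skip_reshape=True, while aiter.flash_attn_func wants (batch, seqlen, nheads, headdim); the default output is flattened back to (batch, seqlen, heads * dim_head). The following is a shape-only sketch with synthetic sizes, no aiter call, purely illustrative and not part of the patch:

    import torch

    b, s, heads, dim_head = 2, 16, 8, 64

    # Default path: inputs arrive as (batch, seqlen, heads * dim_head).
    q = torch.randn(b, s, heads * dim_head)
    q_aiter = q.view(b, -1, heads, dim_head)           # (b, s, heads, dim_head) for aiter
    print(q_aiter.shape)                               # torch.Size([2, 16, 8, 64])

    # skip_reshape=True path: inputs arrive as (batch, heads, seqlen, dim_head).
    q_bhsd = torch.randn(b, heads, s, dim_head)
    print(q_bhsd.transpose(1, 2).shape)                # torch.Size([2, 16, 8, 64])

    # Output: (b, s, heads, dim_head) back to (b, s, heads * dim_head) by default,
    # or transposed to (b, heads, s, dim_head) when skip_output_reshape=True.
    out = torch.randn(b, s, heads, dim_head)
    print(out.reshape(b, -1, heads * dim_head).shape)  # torch.Size([2, 16, 512])
    print(out.transpose(1, 2).shape)                   # torch.Size([2, 8, 16, 64])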
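
Note on the mask handling added in patch 4: attention_aiter collapses whatever mask ComfyUI passes in (2D [sq, sk], 3D broadcast or batched, or 4D [batch, heads, sq, sk]) into the single [sq, sk] bias that aiter.flash_attn_func accepts. Below is a standalone sketch of that conversion with synthetic shapes; the helper name and toy tensors are illustrative only, not part of the patch:

    import torch

    def mask_to_aiter_bias(mask: torch.Tensor, seqlen_q: int) -> torch.Tensor:
        # Mirrors the ndim checks added in attention_aiter (patch 4).
        if mask.ndim == 2:
            return mask                                      # already [sq, sk]
        if mask.ndim == 3:
            if mask.shape[-2] == 1:
                return mask.squeeze(0).expand(seqlen_q, -1)  # [1, 1, sk] -> [sq, sk]
            return mask[0]                                   # [batch, sq, sk] -> first batch
        return mask[0, 0]                                    # [batch, heads, sq, sk] -> first batch/head

    broadcast_mask = torch.zeros(1, 1, 6)                    # toy [1, 1, sk] mask
    print(mask_to_aiter_bias(broadcast_mask, seqlen_q=4).shape)  # torch.Size([4, 6])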