From e00688e134d622d8292e71f7b225100c915578a1 Mon Sep 17 00:00:00 2001
From: vivienfanghua
Date: Fri, 31 Oct 2025 07:40:58 +0000
Subject: [PATCH] add

---
 comfy/ldm/modules/attention.py | 82 ++++++++++++++++++----------------
 1 file changed, 44 insertions(+), 38 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 1f9c56b34..43991fd40 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -45,7 +45,8 @@ try:
     AITER_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.aiter_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
+        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
         exit(-1)
 
 REGISTERED_ATTENTION_FUNCTIONS = {}
@@ -629,42 +630,46 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 
-def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                             causal: bool = False, window_size: tuple = (-1, -1),
-                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                             deterministic: bool = False) -> torch.Tensor:
-    """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
-    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
-                                 causal=causal, window_size=window_size, bias=bias,
-                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
-                                 return_lse=False, return_attn_probs=False,
-                                 cu_seqlens_q=None, cu_seqlens_kv=None)
-
 @wrap_attn
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    # Store original inputs for fallback
+    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
+
     if skip_reshape:
         b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.transpose(1, 2),
+            (q, k, v),
+        )
     else:
         b, _, dim_head = q.shape
         dim_head //= heads
-        # reshape to (batch, seqlen, nheads, headdim) for aiter
         q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head),
            (q, k, v),
        )
+
+    # Convert mask to [sq, sk] format for aiter bias
+    bias = None
 
     if mask is not None:
-        # add a batch dimension if there isn't already one
         if mask.ndim == 2:
-            mask = mask.unsqueeze(0)
-        # add a heads dimension if there isn't already one
-        if mask.ndim == 3:
-            mask = mask.unsqueeze(1)
+            bias = mask
+        elif mask.ndim == 3:
+            seqlen_q = q.shape[1]
+            if mask.shape[-2] == 1:
+                # [1, 1, sk] -> expand to [sq, sk]
+                bias = mask.squeeze(0).expand(seqlen_q, -1)
+            else:
+                # [batch, sq, sk] -> take first batch
+                bias = mask[0]
+        elif mask.ndim == 4:
+            # [batch, heads, sq, sk] -> take first batch and head
+            bias = mask[0, 0]
 
     try:
         # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
-        out = aiter_flash_attn_wrapper(
+        out = aiter.flash_attn_func(
             q,
             k,
             v,
@@ -672,26 +677,29 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             softmax_scale=None,
             causal=False,
             window_size=(-1, -1),
-            bias=mask,
+            bias=bias,
             alibi_slopes=None,
             deterministic=False,
+            return_lse=False,
+            return_attn_probs=False,
+            cu_seqlens_q=None,
+            cu_seqlens_kv=None,
         )
+
+        if skip_output_reshape:
+            # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
+            out = out.transpose(1, 2)
+        else:
+            # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
+            out = out.reshape(b, -1, heads * dim_head)
+        return out
+
     except Exception as e:
-        logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
-        # fallback needs (batch, nheads, seqlen, headdim) format
-        q_sdpa = q.transpose(1, 2)
-        k_sdpa = k.transpose(1, 2)
-        v_sdpa = v.transpose(1, 2)
-        out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        out = out.transpose(1, 2)
-
-    if skip_output_reshape:
-        # output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
-        out = out.transpose(1, 2)
-    else:
-        # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
-        out = out.reshape(b, -1, heads * dim_head)
-    return out
+        logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
+        # Fallback to attention_pytorch with original inputs
+        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
+                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
+                                 skip_output_reshape=skip_output_reshape, **kwargs)
 
 
 optimized_attention = attention_basic
@@ -1170,5 +1178,3 @@ class SpatialVideoTransformer(SpatialTransformer):
         x = self.proj_out(x)
         out = x + x_in
         return out
-
-
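
Reviewer note on the mask handling above: attention_aiter now collapses whatever mask it receives into the 2-D [sq, sk] additive bias that the aiter path consumes, taking only the first batch/head slice for 3-D and 4-D masks. Below is a minimal, self-contained sketch of that conversion on dummy tensors; it needs only torch (no aiter/ROCm), and the helper name mask_to_bias plus the toy shapes are illustrative, not part of the patch.

import torch

def mask_to_bias(mask, seqlen_q):
    # Collapse a 2-D/3-D/4-D additive attention mask into the [sq, sk]
    # layout that the patched attention_aiter passes as `bias`.
    if mask.ndim == 2:                       # [sq, sk] -> use as-is
        return mask
    if mask.ndim == 3:
        if mask.shape[-2] == 1:              # [1, 1, sk] -> broadcast over queries
            return mask.squeeze(0).expand(seqlen_q, -1)
        return mask[0]                       # [batch, sq, sk] -> first batch entry
    if mask.ndim == 4:                       # [batch, heads, sq, sk] -> first batch and head
        return mask[0, 0]
    raise ValueError(f"unexpected mask rank: {mask.ndim}")

# tiny smoke test with dummy shapes
sq, sk = 4, 6
print(mask_to_bias(torch.zeros(sq, sk), sq).shape)         # torch.Size([4, 6])
print(mask_to_bias(torch.zeros(1, 1, sk), sq).shape)       # torch.Size([4, 6])
print(mask_to_bias(torch.zeros(2, sq, sk), sq).shape)      # torch.Size([4, 6])
print(mask_to_bias(torch.zeros(2, 8, sq, sk), sq).shape)   # torch.Size([4, 6])

One consequence of this reduction: when per-batch or per-head masks actually differ, the aiter path applies only the first slice, and the attention_pytorch fallback is reached only when aiter.flash_attn_func raises, not in that case.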