commit e00688e134
parent 43de98b30e
Author: vivienfanghua
Date:   2025-10-31 07:40:58 +00:00


@@ -45,7 +45,8 @@ try:
AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
if model_management.aiter_attention_enabled():
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
@@ -629,42 +630,46 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
return out
- def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
- dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
- causal: bool = False, window_size: tuple = (-1, -1),
- bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
- deterministic: bool = False) -> torch.Tensor:
- """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
- return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
- causal=causal, window_size=window_size, bias=bias,
- alibi_slopes=alibi_slopes, deterministic=deterministic,
- return_lse=False, return_attn_probs=False,
- cu_seqlens_q=None, cu_seqlens_kv=None)
@wrap_attn
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+ # Store original inputs for fallback
+ q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
if skip_reshape:
b, _, _, dim_head = q.shape
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
else:
b, _, dim_head = q.shape
dim_head //= heads
# reshape to (batch, seqlen, nheads, headdim) for aiter
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
# Convert mask to [sq, sk] format for aiter bias
bias = None
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
bias = mask
elif mask.ndim == 3:
seqlen_q = q.shape[1]
if mask.shape[-2] == 1:
# [1, 1, sk] -> expand to [sq, sk]
bias = mask.squeeze(0).expand(seqlen_q, -1)
else:
# [batch, sq, sk] -> take first batch
bias = mask[0]
elif mask.ndim == 4:
# [batch, heads, sq, sk] -> take first batch and head
bias = mask[0, 0]
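# Note: the 3-D and 4-D branches above collapse `mask` to a single 2-D [sq, sk]
# bias (the first batch / first head slice), so a mask that varies across batch
# or heads is only partially represented in the aiter call.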
try:
# aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
- out = aiter_flash_attn_wrapper(
+ out = aiter.flash_attn_func(
q,
k,
v,
@@ -672,26 +677,29 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
softmax_scale=None,
causal=False,
window_size=(-1, -1),
- bias=mask,
+ bias=bias,
alibi_slopes=None,
deterministic=False,
+ return_lse=False,
+ return_attn_probs=False,
+ cu_seqlens_q=None,
+ cu_seqlens_kv=None,
)
if skip_output_reshape:
# output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
out = out.transpose(1, 2)
else:
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
out = out.reshape(b, -1, heads * dim_head)
return out
except Exception as e:
logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
# fallback needs (batch, nheads, seqlen, headdim) format
q_sdpa = q.transpose(1, 2)
k_sdpa = k.transpose(1, 2)
v_sdpa = v.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2)
if skip_output_reshape:
# output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
out = out.transpose(1, 2)
else:
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
out = out.reshape(b, -1, heads * dim_head)
return out
logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
# Fallback to attention_pytorch with original inputs
return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
attn_precision=attn_precision, skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape, **kwargs)
optimized_attention = attention_basic
@@ -1170,5 +1178,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out
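For reference, a minimal standalone sketch of the layout bookkeeping that attention_aiter performs around the flash-attention call. torch.nn.functional.scaled_dot_product_attention stands in for aiter.flash_attn_func purely to keep the snippet runnable; the (batch, seqlen, nheads, headdim) layout expected by aiter and the (batch, nheads, seqlen, headdim) layout expected by SDPA are taken from the comments in this diff, and the tensor sizes are arbitrary example values.

import torch
import torch.nn.functional as F

def layout_demo(q, k, v, heads):
    # ComfyUI-style input: (batch, seqlen, nheads * headdim)
    b, _, inner = q.shape
    dim_head = inner // heads
    # (b, s, h*d) -> (b, s, h, d): the layout built before calling aiter
    q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
    # SDPA (standing in for aiter.flash_attn_func) wants (b, h, s, d)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    # (b, h, s, d) -> (b, s, h*d): mirrors the transpose + reshape in the fallback path
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 16, 8 * 64)      # batch=2, seqlen=16, heads=8, head_dim=64
print(layout_demo(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])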