vivienfanghua 2025-11-03 09:25:04 +00:00
parent e00688e134
commit 0f6f3c93fd

@@ -45,8 +45,8 @@ try:
     AITER_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.aiter_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
-        logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
+        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
         exit(-1)

 REGISTERED_ATTENTION_FUNCTIONS = {}
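
For context, the availability check above follows the usual optional-dependency pattern: try the import, record a module-level flag, and abort only when the backend was explicitly requested. Below is a minimal, self-contained sketch of that pattern; the `flash_attn_func` import target and the `aiter_attention_requested` flag are placeholders, since the hunk only shows the flag assignment and the error path (the real check calls `model_management.aiter_attention_enabled()`).

import logging
import sys

aiter_attention_requested = False  # stand-in for the --use-aiter-attention CLI flag

AITER_ATTENTION_IS_AVAILABLE = False
try:
    from aiter import flash_attn_func  # hypothetical import target, not shown in the hunk
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if aiter_attention_requested:
        # The backend was explicitly requested but cannot be used: fail fast with instructions.
        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
        sys.exit(-1)
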
@@ -634,7 +634,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     # Store original inputs for fallback
     q_orig, k_orig, v_orig, mask_orig = q, k, v, mask

     if skip_reshape:
         b, _, _, dim_head = q.shape
         q, k, v = map(
@@ -649,7 +649,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             (q, k, v),
         )

     # Convert mask to [sq, sk] format for aiter bias
     bias = None
     if mask is not None:
@@ -685,7 +685,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             cu_seqlens_q=None,
             cu_seqlens_kv=None,
         )

         if skip_output_reshape:
             # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
             out = out.transpose(1, 2)
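
This branch, together with the reshape in the next hunk, only juggles layouts: aiter returns the attention output as (batch, seqlen, nheads, headdim), and the caller wants either the heads kept as their own axis or folded into the feature dimension. A quick shape check of both conversions, assuming a plain torch tensor:

import torch

b, s, heads, dim_head = 2, 16, 8, 64
out = torch.randn(b, s, heads, dim_head)  # aiter-style layout: (batch, seqlen, nheads, headdim)

# skip_output_reshape=True: keep heads as a separate axis -> (batch, nheads, seqlen, headdim)
assert out.transpose(1, 2).shape == (b, heads, s, dim_head)

# default path: fold heads into the channel dimension -> (batch, seqlen, nheads * headdim)
assert out.reshape(b, -1, heads * dim_head).shape == (b, s, heads * dim_head)
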
@@ -693,12 +693,12 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
             out = out.reshape(b, -1, heads * dim_head)
         return out
     except Exception as e:
         logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
         # Fallback to attention_pytorch with original inputs
-        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
-                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
+        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
+                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
                                  skip_output_reshape=skip_output_reshape, **kwargs)
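
Taken together, the structure these hunks touch is: save the untouched q/k/v/mask, run aiter flash attention inside a try block, and on any exception log a warning and re-run the originals through the PyTorch backend. The following is a condensed sketch of that shape, not the module's actual code: it uses torch.nn.functional.scaled_dot_product_attention in place of attention_pytorch, a placeholder for the aiter call, and omits the skip_reshape/skip_output_reshape and bias handling.

import logging
import torch
import torch.nn.functional as F

def run_aiter_attention(q, k, v, heads, mask=None):
    # Placeholder for the aiter flash-attention call; raise so the fallback path runs.
    raise RuntimeError("aiter backend unavailable in this sketch")

def attention_with_fallback(q, k, v, heads, mask=None):
    # Keep the untouched inputs so the fallback sees exactly what the caller passed in.
    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
    try:
        return run_aiter_attention(q, k, v, heads, mask=mask)
    except Exception as e:
        logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
        b, _, inner = q_orig.shape
        dim_head = inner // heads
        # (batch, seqlen, heads * dim_head) -> (batch, heads, seqlen, dim_head)
        q_, k_, v_ = (t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q_orig, k_orig, v_orig))
        out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=mask_orig)
        # back to (batch, seqlen, heads * dim_head)
        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

# Usage: output shape is preserved even when the aiter path fails.
q = k = v = torch.randn(2, 16, 8 * 64)
print(attention_with_fallback(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])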