vivienfanghua 2025-10-29 10:46:17 +08:00
parent 1c5a503715
commit 43de98b30e

@@ -621,6 +621,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
    except Exception as e:
        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
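
For reference, below is a minimal, self-contained sketch of the fallback pattern this hunk touches: attempt the flash-attention path, and if it raises, log a warning and fall back to torch.nn.functional.scaled_dot_product_attention, then optionally collapse the heads dimension. The helper run_flash_attention and the function signature are hypothetical stand-ins for illustration, not the repository's actual code; q/k/v are assumed to arrive in (batch, heads, seq_len, dim_head) layout as SDPA expects.

import logging
import torch


def run_flash_attention(q, k, v, mask=None):
    # Hypothetical placeholder for the optimized flash-attention path; it
    # simply raises here so the example always exercises the fallback branch.
    raise RuntimeError("flash attention unavailable in this sketch")


def attention_with_fallback(q, k, v, heads, mask=None, skip_output_reshape=False):
    b, _, _, dim_head = q.shape
    try:
        out = run_flash_attention(q, k, v, mask=mask)
    except Exception as e:
        # Same recovery path as in the diff: warn, then fall back to PyTorch SDPA.
        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
    if not skip_output_reshape:
        # Collapse (b, heads, seq, dim_head) back to (b, seq, heads * dim_head).
        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out


if __name__ == "__main__":
    q = k = v = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, dim_head)
    print(attention_with_fallback(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])

Catching a broad Exception keeps the attention call usable on hardware or input shapes the optimized kernel cannot handle, at the cost of a per-call warning when the fallback is taken.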