Fix some bugs found in PR review. The sequence-length (N) threshold at which SageAttention3 takes effect is reduced to 1024 (although the improvement is not significant at this scale).
Jianqiao Huang 2025-12-02 08:19:40 -08:00
parent 4bdb0dddb7
commit 648814b751
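For context, here is a minimal standalone sketch of the dispatch rule this commit arrives at, assuming the (B, N, heads * dim_head) input layout seen in the diff below; the helper name should_use_sageattention3 is illustrative and is not part of the repository.

import torch

def should_use_sageattention3(q: torch.Tensor, heads: int, mask=None) -> bool:
    # Illustrative restatement of the eligibility checks in attention3_sage below.
    if (q.device.type != "cuda" or
            q.dtype not in (torch.float16, torch.bfloat16) or
            mask is not None):
        return False
    B, N, inner_dim = q.shape          # (batch, tokens, heads * dim_head)
    if inner_dim % heads != 0:
        return False
    dim_head = inner_dim // heads
    # SageAttention3 is only attempted for head dims below 256 and sequences
    # longer than 1024 tokens; per the commit message, the speedup near
    # N == 1024 is still modest.
    return dim_head < 256 and N > 1024

# Example (requires a CUDA device): 4096 fp16 tokens with 8 heads of dim 64 qualify.
# q = torch.randn(1, 4096, 8 * 64, dtype=torch.float16, device="cuda")
# should_use_sageattention3(q, heads=8)  # -> True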


@@ -574,6 +574,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    exception_fallback = False
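    # Eligibility check: SageAttention3 is only attempted for unmasked
    # fp16/bf16 CUDA tensors.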
    if (q.device.type != "cuda" or
            q.dtype not in (torch.float16, torch.bfloat16) or
            mask is not None):
@@ -599,6 +600,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
        )
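        # With skip_reshape the inputs already arrive in head-first
        # (B, H, N, D) form, so they are used as-is.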
        q_s, k_s, v_s = q, k, v
        N = q.shape[2]
        dim_head = D
    else:
        B, N, inner_dim = q.shape
        if inner_dim % heads != 0:
@@ -611,27 +613,33 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                **kwargs
            )
        dim_head = inner_dim // heads
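    # Short sequences and large head dims do not benefit from SageAttention3
    # (per the commit message, even N around 1024 shows little gain), so fall
    # back to PyTorch attention early, before any reshaping is done.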
    if dim_head >= 256 or N <= 1024:
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )
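    # Reshape (B, N, heads * dim_head) tensors into the head-first
    # (B, heads, N, dim_head) layout passed to sageattn3_blackwell below.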
    if not skip_reshape:
        q_s, k_s, v_s = map(
            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
            (q, k, v),
        )
    B, H, L, D = q_s.shape
    if dim_head >= 256 or N <= 2048:
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=False,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )
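    # Attempt the SageAttention3 Blackwell kernel; any runtime failure is
    # recorded so the function can fall back to PyTorch attention below.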
    try:
        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
    except Exception as e:
        exception_fallback = True
        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
    if exception_fallback:
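        # Drop the reshaped copies before retrying with the original q, k, v.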
        if not skip_reshape:
            del q_s, k_s, v_s
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,