Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-16 17:42:58 +08:00
Fix some bugs found in PR review. The minimum N (sequence length) at which Sage Attention 3 takes effect is reduced from 2048 to 1024 (although the improvement is not significant at that scale).
This commit is contained in:
parent 4bdb0dddb7
commit 648814b751
@@ -574,6 +574,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 
 @wrap_attn
 def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if (q.device.type != "cuda" or
         q.dtype not in (torch.float16, torch.bfloat16) or
         mask is not None):
@@ -599,6 +600,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         )
         q_s, k_s, v_s = q, k, v
         N = q.shape[2]
+        dim_head = D
     else:
         B, N, inner_dim = q.shape
         if inner_dim % heads != 0:
@@ -612,26 +614,32 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             )
         dim_head = inner_dim // heads
 
+    if dim_head >= 256 or N <= 1024:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_reshape:
         q_s, k_s, v_s = map(
             lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
             (q, k, v),
         )
         B, H, L, D = q_s.shape
 
-    if dim_head >= 256 or N <= 2048:
-        return attention_pytorch(
-            q, k, v, heads,
-            mask=mask,
-            attn_precision=attn_precision,
-            skip_reshape=False,
-            skip_output_reshape=skip_output_reshape,
-            **kwargs
-        )
     try:
         out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
+        exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
 
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
         return attention_pytorch(
             q, k, v, heads,
             mask=mask,
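Taken together, the changes gate Sage Attention 3 behind a device/dtype/mask check plus a head-dimension and sequence-length check (now N > 1024 instead of N > 2048), and fall back to PyTorch attention either when the gate fails or when the kernel raises. Below is a minimal, self-contained sketch of that dispatch pattern, not the ComfyUI implementation: the helper name sage3_or_sdpa is made up, torch.nn.functional.scaled_dot_product_attention stands in for ComfyUI's attention_pytorch, the import path for sageattn3_blackwell is an assumption, and the @wrap_attn decorator plus the skip_reshape/skip_output_reshape plumbing are omitted.

import logging

import torch
import torch.nn.functional as F


def sage3_or_sdpa(q, k, v, heads, mask=None, min_seq_len=1024, max_dim_head=256):
    # q, k, v: (B, N, heads * dim_head). Returns a tensor of the same shape.
    B, N, inner_dim = q.shape
    dim_head = inner_dim // heads

    def split_heads(t):
        # (B, N, heads * dim_head) -> (B, heads, N, dim_head)
        return t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3)

    def merge_heads(t):
        # (B, heads, N, dim_head) -> (B, N, heads * dim_head)
        return t.permute(0, 2, 1, 3).reshape(B, -1, heads * dim_head)

    def sdpa_fallback():
        # Plain PyTorch attention, standing in for ComfyUI's attention_pytorch().
        return merge_heads(F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v), attn_mask=mask))

    # Gate: Sage Attention 3 wants CUDA fp16/bf16, no mask, head_dim < 256,
    # and (after this commit) a sequence length above 1024.
    if (q.device.type != "cuda"
            or q.dtype not in (torch.float16, torch.bfloat16)
            or mask is not None
            or dim_head >= max_dim_head
            or N <= min_seq_len):
        return sdpa_fallback()

    q_s, k_s, v_s = (split_heads(t).contiguous() for t in (q, k, v))
    try:
        # Assumed import path; the real package/function name may differ.
        from sageattention import sageattn3_blackwell
        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
    except Exception as e:
        # Mirrors the commit's exception_fallback path: log, free the reshaped
        # copies, then recompute attention on the original tensors.
        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
        del q_s, k_s, v_s
        return sdpa_fallback()
    return merge_heads(out)

The del before the fallback mirrors the new exception_fallback branch in the diff: the contiguous reshaped copies are released before attention is recomputed on the original tensors, which keeps peak memory lower when the kernel fails at large N.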