diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2dfc55bfa..d51e49da2 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -574,6 +574,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 
 @wrap_attn
 def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if (q.device.type != "cuda"
         or q.dtype not in (torch.float16, torch.bfloat16)
         or mask is not None):
@@ -599,6 +600,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         )
         q_s, k_s, v_s = q, k, v
         N = q.shape[2]
+        dim_head = D
     else:
         B, N, inner_dim = q.shape
         if inner_dim % heads != 0:
@@ -611,27 +613,33 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 **kwargs
             )
         dim_head = inner_dim // heads
-
+
+    if dim_head >= 256 or N <= 1024:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_reshape:
         q_s, k_s, v_s = map(
             lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
             (q, k, v),
         )
     B, H, L, D = q_s.shape
-
-    if dim_head >= 256 or N <= 2048:
-        return attention_pytorch(
-            q, k, v, heads,
-            mask=mask,
-            attn_precision=attn_precision,
-            skip_reshape=False,
-            skip_output_reshape=skip_output_reshape,
-            **kwargs
-        )
 
     try:
         out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
+        exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
         return attention_pytorch(
             q, k, v, heads,
             mask=mask,