Disable sage attention on ace step 1.5 (#12297)

This commit is contained in:
comfyanonymous 2026-02-04 19:15:30 -08:00 committed by GitHub
parent 6125b80979
commit a50c32d63f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 1 deletions

View File

@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
else:
attn_bias = window_bias
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
attn_output = self.o_proj(attn_output)
return attn_output

View File

@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape