mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 03:22:33 +08:00
Disable sage attention on ace step 1.5 (#12297)
This commit is contained in:
parent
6125b80979
commit
a50c32d63f
@ -183,7 +183,7 @@ class AceStepAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
attn_bias = window_bias
|
attn_bias = window_bias
|
||||||
|
|
||||||
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
|
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|||||||
@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if kwargs.get("low_precision_attention", True) is False:
|
||||||
|
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||||
|
|
||||||
exception_fallback = False
|
exception_fallback = False
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user