mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 05:17:50 +08:00
Disable sage attention in stable audio dit and VAE. (#14148)
This commit is contained in:
parent
987a937658
commit
85a403d1ea
@ -433,11 +433,11 @@ class Attention(nn.Module):
|
|||||||
if self.differential:
|
if self.differential:
|
||||||
q, q_diff = q.unbind(dim=1)
|
q, q_diff = q.unbind(dim=1)
|
||||||
k, k_diff = k.unbind(dim=1)
|
k, k_diff = k.unbind(dim=1)
|
||||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
|
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||||
out = out - out_diff
|
out = out - out_diff
|
||||||
else:
|
else:
|
||||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||||
|
|
||||||
out = self.to_out(out)
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
|||||||
@ -138,11 +138,11 @@ class Attention(nn.Module):
|
|||||||
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
|
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
|
||||||
|
|
||||||
if self.differential:
|
if self.differential:
|
||||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
|
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
|
||||||
del q, k, v, q_diff, k_diff
|
del q, k, v, q_diff, k_diff
|
||||||
else:
|
else:
|
||||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user