Disable sage attention in stable audio dit and VAE. (#14148)

This commit is contained in:
comfyanonymous 2026-05-27 17:35:03 -07:00 committed by GitHub
parent 987a937658
commit 85a403d1ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 6 deletions

View File

@ -433,11 +433,11 @@ class Attention(nn.Module):
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = self.to_out(out)

View File

@ -138,11 +138,11 @@ class Attention(nn.Module):
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
if self.differential:
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
del q, k, v, q_diff, k_diff
else:
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
del q, k, v
return self.to_out(out)