From 85a403d1ea868a9188d4b32de2e6f0be9174cfed Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 27 May 2026 17:35:03 -0700 Subject: [PATCH] Disable sage attention in stable audio dit and VAE. (#14148) --- comfy/ldm/audio/dit.py | 6 +++--- comfy/ldm/audio/vae_sa3.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index a6258b755..c28be5b49 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -433,11 +433,11 @@ class Attention(nn.Module): if self.differential: q, q_diff = q.unbind(dim=1) k, k_diff = k.unbind(dim=1) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) - out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = out - out_diff else: - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = self.to_out(out) diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py index 276846444..8be36d6ee 100644 --- a/comfy/ldm/audio/vae_sa3.py +++ b/comfy/ldm/audio/vae_sa3.py @@ -138,11 +138,11 @@ class Attention(nn.Module): k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) if self.differential: - out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) - - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)) del q, k, v, q_diff, k_diff else: - out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) del q, k, v return self.to_out(out)