From 6b97e3f4cbd003a3fc17e796f257079ae409f768 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 6 May 2026 21:31:49 +0300
Subject: [PATCH] Only fall back to pytorch attention from sage for guide mask

---
 comfy/ldm/lightricks/model.py  | 13 ++++++++-----
 comfy/ldm/modules/attention.py |  4 ----
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index be078081c..7ccb2c42a 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -401,13 +401,16 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
     tracked_end = guide_start + guide_mask.tracked_count
 
     out = torch.empty_like(q)
-    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention_masked(
-        q[:, :guide_start, :], k, v, heads, guide_mask.noisy_mask,
+
+    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False,  # sageattn mask support is unreliable
     )
-    out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention_masked(
-        q[:, guide_start:tracked_end, :], k, v, heads, guide_mask.tracked_mask,
+    out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False,
     )
 
     return out
@@ -469,7 +472,7 @@ class CrossAttention(nn.Module):
         elif isinstance(mask, GuideAttentionMask):
             out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
         else:
-            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
+            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
 
         # Apply per-head gating if enabled
         if self.to_gate_logits is not None:
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index e9df36435..a68cb8439 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -544,10 +544,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     if kwargs.get("low_precision_attention", True) is False:
         return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
 
-    # sageattn's attn_mask support is unreliable, fall back to pytorch when a mask is set
-    if mask is not None:
-        return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
-
     exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
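
For context, a minimal self-contained sketch of the dispatch behavior after this patch. The attention_pytorch/attention_sage functions below are simplified stand-ins for the real ones in comfy/ldm/modules/attention.py (which also take heads, attn_precision, and the reshape flags); only the routing logic is illustrated. Before the patch, attention_sage fell back to attention_pytorch whenever any mask was set; afterwards, only callers that explicitly pass low_precision_attention=False (here, the guide-mask path in lightricks/model.py) take that fallback, so other masked calls can stay on the faster sage kernel.

import torch
import torch.nn.functional as F

def attention_pytorch(q, k, v, mask=None):
    # Exact-precision fallback path (plain torch SDPA);
    # simplified stand-in, not the real signature.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def attention_sage(q, k, v, mask=None, **kwargs):
    # After this patch: a mask alone no longer forces the pytorch
    # fallback; only an explicit low_precision_attention=False does.
    if kwargs.get("low_precision_attention", True) is False:
        return attention_pytorch(q, k, v, mask=mask)
    # Placeholder for the real low-precision sageattn kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = k = v = torch.randn(1, 8, 16, 64)        # (batch, heads, seq, dim_head)
mask = torch.ones(16, 16, dtype=torch.bool)  # broadcastable attention mask
attention_sage(q, k, v, mask=mask)                                 # stays on the sage path
attention_sage(q, k, v, mask=mask, low_precision_attention=False)  # pytorch fallback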