Only fall to pytorch attention from sage for guide mask

This commit is contained in:
kijai 2026-05-06 21:31:49 +03:00
parent f2beaa5802
commit 6b97e3f4cb
2 changed files with 8 additions and 9 deletions

View File

@ -401,13 +401,16 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention_masked(
q[:, :guide_start, :], k, v, heads, guide_mask.noisy_mask,
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention_masked(
q[:, guide_start:tracked_end, :], k, v, heads, guide_mask.tracked_mask,
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
return out
@ -469,7 +472,7 @@ class CrossAttention(nn.Module):
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:

View File

@ -544,10 +544,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
# sageattn's attn_mask support is unreliable, fall back to pytorch when a mask is set
if mask is not None:
return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape