Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-05-10 01:02:56 +08:00
Only fall to pytorch attention from sage for guide mask
parent f2beaa5802
commit 6b97e3f4cb
@@ -401,13 +401,16 @@ def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, trans
     tracked_end = guide_start + guide_mask.tracked_count
 
     out = torch.empty_like(q)
-    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention_masked(
-        q[:, :guide_start, :], k, v, heads, guide_mask.noisy_mask,
+    out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False, # sageattn mask support is unreliable
     )
-    out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention_masked(
-        q[:, guide_start:tracked_end, :], k, v, heads, guide_mask.tracked_mask,
+    out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
+        q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
         attn_precision=attn_precision, transformer_options=transformer_options,
+        low_precision_attention=False,
     )
     return out
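For reference, a minimal self-contained sketch of the per-segment pattern this hunk uses: the noisy prefix and the tracked segment of the query each attend to the full key/value under their own mask, and the results are written into a preallocated output. This sketch uses plain torch.nn.functional.scaled_dot_product_attention rather than ComfyUI's optimized_attention, omits the multi-head reshape, and its names and shapes are illustrative assumptions, not the repository's API.

import torch
import torch.nn.functional as F

def attention_by_segment(q, k, v, guide_start, tracked_end, noisy_mask, tracked_mask):
    # q, k, v: (batch, seq_len, dim); each mask broadcastable to (batch, segment_len, seq_len).
    out = torch.empty_like(q)
    # Noisy prefix attends to the full k/v under its own mask.
    out[:, :guide_start, :] = F.scaled_dot_product_attention(
        q[:, :guide_start, :], k, v, attn_mask=noisy_mask)
    # Tracked segment attends to the full k/v under the tracked mask.
    out[:, guide_start:tracked_end, :] = F.scaled_dot_product_attention(
        q[:, guide_start:tracked_end, :], k, v, attn_mask=tracked_mask)
    # Any positions beyond tracked_end are left unfilled in this sketch.
    return out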
@@ -469,7 +472,7 @@ class CrossAttention(nn.Module):
         elif isinstance(mask, GuideAttentionMask):
             out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
         else:
-            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
+            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
 
         # Apply per-head gating if enabled
         if self.to_gate_logits is not None:
@@ -544,10 +544,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     if kwargs.get("low_precision_attention", True) is False:
         return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
 
-    # sageattn's attn_mask support is unreliable, fall back to pytorch when a mask is set
-    if mask is not None:
-        return attention_pytorch(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
-
     exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
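A hypothetical sketch of the fallback mechanism this commit narrows: the sage backend checks a low_precision_attention keyword and defers to the pytorch path only when a caller (here, the guide-mask path) opts out, instead of rerouting every masked call. The names fast_attention and reference_attention are placeholders, not the ComfyUI API.

import torch
import torch.nn.functional as F

def reference_attention(q, k, v, mask=None, **kwargs):
    # Stand-in for the pytorch attention path: plain SDPA with an optional mask.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def fast_attention(q, k, v, mask=None, **kwargs):
    # Stand-in for the sage backend: callers that cannot rely on the
    # low-precision kernel's mask handling opt out per call via the kwarg,
    # rather than every masked call being rerouted unconditionally.
    if kwargs.get("low_precision_attention", True) is False:
        return reference_attention(q, k, v, mask=mask)
    # The low-precision kernel would run here; this sketch reuses the reference path.
    return reference_attention(q, k, v, mask=mask)

# Example: the guide-mask path opts out; everything else keeps the fast kernel.
q = k = v = torch.randn(1, 8, 64)
out = fast_attention(q, k, v, mask=None, low_precision_attention=False)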