mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 23:39:35 +08:00
Compare commits
3 Commits
3bd7d8b15d
...
553f71aa9e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
553f71aa9e | ||
|
|
739191df33 | ||
|
|
c6217aa799 |
@ -719,7 +719,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False):
|
||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
@ -737,7 +737,10 @@ def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_resha
|
||||
v = v.transpose(1, 2)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
return out.transpose(1, 2).values()
|
||||
out = out.transpose(1, 2)
|
||||
if not skip_output_reshape:
|
||||
return out.values().reshape(-1, heads * (q.shape[-1]))
|
||||
return out.values()
|
||||
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
optimized_attention = attention_basic
|
||||
|
||||
@ -808,7 +808,7 @@ class NaSwinAttention(NaMMAttention):
|
||||
q=concat_win(vid_q, txt_q),
|
||||
k=concat_win(vid_k, txt_k),
|
||||
v=concat_win(vid_v, txt_v),
|
||||
heads=self.heads, skip_reshape=True,
|
||||
heads=self.heads, skip_reshape=True, skip_output_reshape=True,
|
||||
cu_seqlens_q=cache_win(
|
||||
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user