mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
output reshape
This commit is contained in:
parent
3b418dab2c
commit
c6217aa799
@ -719,7 +719,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False):
|
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
if not skip_reshape:
|
if not skip_reshape:
|
||||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||||
total_tokens, embed_dim = q.shape
|
total_tokens, embed_dim = q.shape
|
||||||
@ -737,7 +737,10 @@ def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_resha
|
|||||||
v = v.transpose(1, 2)
|
v = v.transpose(1, 2)
|
||||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
|
|
||||||
return out.transpose(1, 2).values()
|
out = out.transpose(1, 2)
|
||||||
|
if not skip_output_reshape:
|
||||||
|
return out.values().reshape(-1, heads * (q.shape[-1]))
|
||||||
|
return out.values()
|
||||||
|
|
||||||
optimized_var_attention = var_attention_pytorch
|
optimized_var_attention = var_attention_pytorch
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user