diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 52273181e..4c814a544 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -719,7 +719,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     )
     return out
 
-def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False):
+def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
     if not skip_reshape:
         # assumes 2D q, k,v [total_tokens, embed_dim]
         total_tokens, embed_dim = q.shape
@@ -737,7 +737,10 @@ def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_resha
         v = v.transpose(1, 2)
 
     out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-    return out.transpose(1, 2).values()
+    out = out.transpose(1, 2)
+    if not skip_output_reshape:
+        return out.values().reshape(-1, heads * (q.shape[-1]))
+    return out.values()
 
 optimized_var_attention = var_attention_pytorch
 optimized_attention = attention_basic
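The sketch below is not part of the patch; it is a minimal, self-contained illustration of what the new skip_output_reshape flag is expected to control. In the real function, out is a jagged nested tensor and out.values() returns its flat [total_tokens, heads, dim_head] buffer; here a plain dense tensor of that flat shape stands in for that buffer so the two return paths can be exercised in isolation. All names and shapes are illustrative assumptions, not the repository's API.

import torch

heads, dim_head = 4, 8
total_tokens = 8  # e.g. two packed sequences of lengths 3 and 5

# Stand-in for out.transpose(1, 2).values(): the flat per-token, per-head buffer.
flat_out = torch.randn(total_tokens, heads, dim_head)

# skip_output_reshape=False: collapse the head dim back into the embedding dim,
# recovering the 2D [total_tokens, embed_dim] layout the function accepts as input.
merged = flat_out.reshape(-1, heads * dim_head)
assert merged.shape == (total_tokens, heads * dim_head)

# skip_output_reshape=True: keep the per-head layout and let the caller reshape.
per_head = flat_out
assert per_head.shape == (total_tokens, heads, dim_head)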