Allow attn_mask in attention_pytorch.

commit 95df4b6174 (parent c60864b5e4)
Author: comfyanonymous
Date:   2023-10-11 20:24:17 -04:00


@@ -284,7 +284,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
         (q, k, v),
     )
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-    if exists(mask):
-        raise NotImplementedError
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
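
For reference, a minimal runnable sketch of what the function looks like after this change. Only the scaled_dot_product_attention(..., attn_mask=mask, ...) call is taken from the diff; the surrounding reshape logic and the usage example are assumptions based on the hunk's context lines and the usual (batch, tokens, heads * dim_head) layout.

import torch

def attention_pytorch(q, k, v, heads, mask=None):
    # Post-commit behavior: the mask is forwarded to PyTorch's fused SDPA
    # instead of hitting the removed `raise NotImplementedError` path.
    b, _, dim_head = q.shape
    dim_head //= heads
    # (b, n, heads * dim_head) -> (b, heads, n, dim_head)
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )
    # attn_mask may be a boolean mask (True = attend) or an additive float
    # bias, broadcastable to (b, heads, n_q, n_k).
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    # (b, heads, n, dim_head) -> (b, n, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

# Hypothetical usage: block attention to the second half of the keys.
q = k = v = torch.randn(2, 64, 512)
mask = torch.ones(64, 64, dtype=torch.bool)
mask[:, 32:] = False
out = attention_pytorch(q, k, v, heads=8, mask=mask)
print(out.shape)  # torch.Size([2, 64, 512])

Because scaled_dot_product_attention already accepts attn_mask, the change is purely a pass-through: no extra masking code is needed in the wrapper, and the fused kernel handles broadcasting the mask across batch and heads.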