"Boolean value of Tensor with more than one value is ambiguous" fix

This commit is contained in:
patientx 2025-05-11 20:39:42 +03:00 committed by GitHub
parent 8abcc4ec4f
commit cd7eb9bd36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -261,7 +261,7 @@ def do_hijack():
def amd_flash_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
try:
if (query.shape[-1] <= 128 and
not attn_mask and
attn_mask is None and # fix flash-attention error : "Flash attention error: Boolean value of Tensor with more than one value is ambiguous"
query.dtype != torch.float32):
if scale is None:
scale = query.shape[-1] ** -0.5