Use optimized attention in Dino2AttentionBlock.

This commit is contained in:
Talmaj Marinc 2026-06-08 12:46:58 +02:00
parent 15f4dc401a
commit 8cbdd8f72e

View File

@@ -53,8 +53,7 @@ class Dino2AttentionBlock(torch.nn.Module):
if rope is not None and pos is not None:
q = rope(q, pos)
k = rope(k, pos)
-out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-out = out.transpose(1, 2).reshape(B, N, C)
+out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
return self.output(out)