This commit is contained in:
Yousef Rafat 2025-12-26 23:14:33 +02:00
parent 3039c7ba14
commit fadc7839cc

View File

@ -446,7 +446,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
if v.ndim == 3: v = v.unsqueeze(0) if v.ndim == 3: v = v.unsqueeze(0)
dim_head = q.shape[-1] dim_head = q.shape[-1]
target_output_shape = (q.shape[1], -1) target_output_shape = (q.shape[1], -1)
b = 1 b = 1
elif skip_reshape: elif skip_reshape:
# b h k d -> b k h d # b h k d -> b k h d