xpu-workflows-opt

plusbang 2025-10-21 10:38:11 +08:00
parent 1c10b33f9b
commit 3dab05de8e
3 changed files with 12 additions and 1 deletion


@@ -493,7 +493,16 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)
     if SDP_BATCH_LIMIT >= b:
-        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        # out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        if mask is None:
+            k = k.contiguous()
+            v = v.contiguous()
+            import sdpa_kernels
+            out = sdpa_kernels.sdp_xmx(q, k, v, 1 / math.sqrt(dim_head), 0)
+            # print("here:", out.isnan().count_nonzero(), out.isinf().count_nonzero())
+            # breakpoint()
+        else:
+            out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
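Note: the hunk above imports sdpa_kernels unconditionally inside the unmasked branch, so the path breaks when the extension is not installed. Below is a minimal sketch of the same dispatch with a guarded import and fallback. It assumes sdpa_kernels.sdp_xmx(q, k, v, scale, dropout_p) behaves as called in the diff, uses stock torch.nn.functional.scaled_dot_product_attention instead of comfy.ops for self-containment, and the helper name sdp_attention_xpu is hypothetical rather than ComfyUI API.

import math

import torch
import torch.nn.functional as F

try:
    # XPU fused-SDPA extension referenced in the diff above; treated as optional here.
    import sdpa_kernels
    _HAS_XMX_SDPA = True
except ImportError:
    _HAS_XMX_SDPA = False


def sdp_attention_xpu(q, k, v, dim_head, mask=None):
    # Hypothetical helper: route unmasked attention through the XMX kernel,
    # otherwise (or if the extension is unavailable) fall back to PyTorch SDPA.
    if mask is None and _HAS_XMX_SDPA:
        # Mirror the call in the hunk: contiguous K/V, explicit softmax scale,
        # dropout_p = 0.
        k = k.contiguous()
        v = v.contiguous()
        return sdpa_kernels.sdp_xmx(q, k, v, 1 / math.sqrt(dim_head), 0)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

The fallback keeps masked attention (and any machine without the XPU extension) on the stock PyTorch path, which is the same behavior the else branch of the diff provides.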