Try to fix ace text encoder slowness on some configs. (#12290)

This commit is contained in:
comfyanonymous 2026-02-04 16:37:05 -08:00 committed by GitHub
parent 26dd7eb421
commit c8fcbd66ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,6 +54,8 @@ try:
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
def scaled_dot_product_attention(q, k, v, *args, **kwargs): def scaled_dot_product_attention(q, k, v, *args, **kwargs):
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
else: else: