Yousef Rafat 2026-02-18 22:01:09 +02:00
parent 0a49718194
commit b5feac202c


@@ -6,7 +6,7 @@ from comfy.ldm.trellis2.vae import VarLenTensor
 FLASH_ATTN_3_AVA = True
 try:
-    import flash_attn_interface as flash_attn_3
+    import flash_attn_interface as flash_attn_3  # noqa: F401
 except:
     FLASH_ATTN_3_AVA = False
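
Aside: the `# noqa: F401` added above tells the linter that the otherwise-unused import is intentional; the module is imported only to probe whether FlashAttention 3 is installed. For illustration, an equivalent probe that avoids the unused import entirely (not what this commit does, just a sketch):

import importlib.util

# Probe for the package without importing it; find_spec returns None
# when the top-level module is not installed.
FLASH_ATTN_3_AVA = importlib.util.find_spec("flash_attn_interface") is not None
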
@@ -53,8 +53,6 @@ def scaled_dot_product_attention(*args, **kwargs):
     elif num_all_args == 3:
         out = flash_attn_3.flash_attn_func(q, k, v)
     elif optimized_attention.__name__ == 'attention_pytorch':
-        if 'sdpa' not in globals():
-            from torch.nn.functional import scaled_dot_product_attention as sdpa
         if num_all_args == 1:
             q, k, v = qkv.unbind(dim=2)
         elif num_all_args == 2:
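
For context, a minimal sketch of the dispatch this hunk simplifies, assuming the wrapper accepts either a packed `qkv` tensor of shape [B, L, 3, H, D] or separate `q, k, v` tensors in [B, L, H, D] layout (the shapes and the fallback body are assumptions; only `num_all_args`, `qkv.unbind(dim=2)`, and the `flash_attn_func(q, k, v)` call appear in the diff). The deleted lines were ineffective anyway: a function-local `from ... import ... as sdpa` binds a local name, so the `'sdpa' not in globals()` check could never see it. After this change the PyTorch path can reference `torch.nn.functional.scaled_dot_product_attention` directly:

import torch
import torch.nn.functional as F

try:
    import flash_attn_interface as flash_attn_3
    FLASH_ATTN_3_AVA = True
except ImportError:
    FLASH_ATTN_3_AVA = False

def attention_dispatch(*args):
    # Sketch only (renamed from the file's scaled_dot_product_attention to
    # avoid shadowing the torch function): accept packed qkv or separate q, k, v.
    if len(args) == 1:
        qkv = args[0]                # assumed shape [B, L, 3, H, D]
        q, k, v = qkv.unbind(dim=2)  # -> three [B, L, H, D] tensors
    else:
        q, k, v = args
    if FLASH_ATTN_3_AVA:
        out = flash_attn_3.flash_attn_func(q, k, v)  # call mirrored from the diff
    else:
        # F.scaled_dot_product_attention expects [B, H, L, D]; transpose in
        # and back out (layout assumption, not taken from the file).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
    return out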