diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py
index e6aa50842..e8e401fd7 100644
--- a/comfy/ldm/trellis2/attention.py
+++ b/comfy/ldm/trellis2/attention.py
@@ -6,7 +6,7 @@ from comfy.ldm.trellis2.vae import VarLenTensor
 
 FLASH_ATTN_3_AVA = True
 try:
-    import flash_attn_interface as flash_attn_3
+    import flash_attn_interface as flash_attn_3  # noqa: F401
 except:
     FLASH_ATTN_3_AVA = False
 
@@ -53,8 +53,6 @@ def scaled_dot_product_attention(*args, **kwargs):
     elif num_all_args == 3:
         out = flash_attn_3.flash_attn_func(q, k, v)
     elif optimized_attention.__name__ == 'attention_pytorch':
-        if 'sdpa' not in globals():
-            from torch.nn.functional import scaled_dot_product_attention as sdpa
         if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2: