mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 14:32:49 +08:00
.
This commit is contained in:
parent
0a49718194
commit
b5feac202c
@ -6,7 +6,7 @@ from comfy.ldm.trellis2.vae import VarLenTensor
|
|||||||
|
|
||||||
FLASH_ATTN_3_AVA = True
|
FLASH_ATTN_3_AVA = True
|
||||||
try:
|
try:
|
||||||
import flash_attn_interface as flash_attn_3
|
import flash_attn_interface as flash_attn_3 # noqa: F401
|
||||||
except:
|
except:
|
||||||
FLASH_ATTN_3_AVA = False
|
FLASH_ATTN_3_AVA = False
|
||||||
|
|
||||||
@ -53,8 +53,6 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|||||||
elif num_all_args == 3:
|
elif num_all_args == 3:
|
||||||
out = flash_attn_3.flash_attn_func(q, k, v)
|
out = flash_attn_3.flash_attn_func(q, k, v)
|
||||||
elif optimized_attention.__name__ == 'attention_pytorch':
|
elif optimized_attention.__name__ == 'attention_pytorch':
|
||||||
if 'sdpa' not in globals():
|
|
||||||
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=2)
|
q, k, v = qkv.unbind(dim=2)
|
||||||
elif num_all_args == 2:
|
elif num_all_args == 2:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user