mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Add support for flash attention 3
This commit is contained in:
parent 7d6103325e
commit 603a721405
@@ -32,7 +32,7 @@ except ImportError as e:

 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn_interface import flash_attn_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
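
The first hunk only changes where flash_attn_func is imported from: Flash Attention 3 exposes it through the flash_attn_interface module instead of the flash_attn package used by Flash Attention 2. Below is a minimal, hedged sketch of such an availability probe; the FA2 fallback branch is an assumption added purely for illustration and is not what this commit does (the commit switches the import outright).

    # Illustrative sketch only -- the fallback to flash_attn (FA2) is an assumption,
    # not part of this commit, which replaces the import entirely.
    FLASH_ATTENTION_IS_AVAILABLE = False
    flash_attn_func = None
    try:
        # Flash Attention 3 ships its Python API as flash_attn_interface.
        from flash_attn_interface import flash_attn_func
        FLASH_ATTENTION_IS_AVAILABLE = True
    except ImportError:
        try:
            # Fall back to the Flash Attention 2 package if FA3 is not installed.
            from flash_attn import flash_attn_func
            FLASH_ATTENTION_IS_AVAILABLE = True
        except ImportError:
            pass

    print("flash attention available:", FLASH_ATTENTION_IS_AVAILABLE)

One caveat with such a fallback: the FA2 and FA3 versions of flash_attn_func differ in signature (FA3 has no dropout_p argument), which is exactly what the second hunk below accounts for.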
@@ -565,7 +565,7 @@ try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+        return flash_attn_func(q, k, v, causal=causal)


     @flash_attn_wrapper.register_fake
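
The second hunk adjusts the torch.library.custom_op wrapper: the FA3 flash_attn_func, unlike the FA2 one, does not take a dropout_p argument, so it is dropped from the call. The sketch below shows the same custom-op pattern (a real implementation plus a register_fake meta implementation so the op can be traced by torch.compile), but it wraps torch's built-in scaled_dot_product_attention so it runs without flash-attn installed; the op name example::sdpa and the function names are hypothetical, used only for illustration.

    # Hedged sketch of the custom_op / register_fake pattern; "example::sdpa"
    # and the wrapper names are made up for this example.
    import torch
    import torch.nn.functional as F


    @torch.library.custom_op("example::sdpa", mutates_args=())
    def sdpa_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     causal: bool = False) -> torch.Tensor:
        # Real implementation: built-in SDPA standing in for flash_attn_func.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)


    @sdpa_wrapper.register_fake
    def _(q, k, v, causal=False):
        # Fake (meta) implementation: only shape/dtype matter; output matches q.
        return q.new_empty(q.shape)


    q = k = v = torch.randn(1, 4, 16, 8)
    print(sdpa_wrapper(q, k, v, causal=True).shape)  # torch.Size([1, 4, 16, 8])

The fake implementation never runs real attention; it only describes the output's shape and dtype, which is what lets torch.compile trace through the op without invoking the flash-attn kernels, regardless of which flash-attn version backs the real implementation.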