diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..8a53bc752 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -32,7 +32,7 @@ except ImportError as e:
 
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn_interface import flash_attn_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
@@ -565,7 +565,7 @@ try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+        return flash_attn_func(q, k, v, causal=causal)
 
 
     @flash_attn_wrapper.register_fake
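A minimal sketch of how the two packages could coexist, in case that is useful context for review. This is not part of the diff: `IS_FA3` and `call_flash_attn` are hypothetical names, the FA2 fallback is not something this change adds, and the idea that the FA3 `flash_attn_func` from `flash_attn_interface` does not accept `dropout_p` is an assumption inferred only from the call-site change above.

```python
import torch

# Sketch only: layer the FA3 import over the FA2 one so machines without
# flash_attn_interface keep the pre-existing flash_attn path.
FLASH_ATTENTION_IS_AVAILABLE = False
IS_FA3 = False  # hypothetical flag, not ComfyUI API
try:
    # FlashAttention-3 bindings, as imported by this diff.
    from flash_attn_interface import flash_attn_func
    FLASH_ATTENTION_IS_AVAILABLE = True
    IS_FA3 = True
except ImportError:
    try:
        # FlashAttention-2 package that the diff replaces.
        from flash_attn import flash_attn_func
        FLASH_ATTENTION_IS_AVAILABLE = True
    except ImportError:
        pass


def call_flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
    """Dispatch to whichever flash_attn_func was importable (hypothetical helper)."""
    if not FLASH_ATTENTION_IS_AVAILABLE:
        raise RuntimeError("Neither flash_attn_interface (FA3) nor flash_attn (FA2) is installed")
    if IS_FA3:
        # The diff drops dropout_p at the call site, which suggests the FA3
        # entry point does not take it; mirror that here (assumption).
        return flash_attn_func(q, k, v, causal=causal)
    # FA2 path keeps the original keyword arguments.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
```

With a dispatch like this, callers that currently pass `dropout_p` would not need to change, and dropout would simply be forwarded only on the FA2 path.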