diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 76bab22b0..f9459c2dc 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -628,34 +628,17 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 
-try:
-    @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
-                                     causal=causal, window_size=window_size, bias=bias,
-                                     alibi_slopes=alibi_slopes, deterministic=deterministic,
-                                     return_lse=False, return_attn_probs=False,
-                                     cu_seqlens_q=None, cu_seqlens_kv=None)
-
-
-    @aiter_flash_attn_wrapper.register_fake
-    def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                              window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
-        # Output shape is the same as q
-        return q.new_empty(q.shape)
-except AttributeError as error:
-    AITER_ATTN_ERROR = error
-
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
+def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                             causal: bool = False, window_size: tuple = (-1, -1),
+                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                             deterministic: bool = False) -> torch.Tensor:
+    """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
+    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                 causal=causal, window_size=window_size, bias=bias,
+                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                 return_lse=False, return_attn_probs=False,
+                                 cu_seqlens_q=None, cu_seqlens_kv=None)
 
 @wrap_attn
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
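
For reference, a minimal sketch of how the simplified wrapper might be exercised once this patch is applied. It is an assumption-laden smoke test, not part of the diff: it presumes a ROCm build of PyTorch with the aiter package installed, that comfy.ldm.modules.attention imports cleanly in that environment, and that aiter.flash_attn_func uses the upstream flash-attn (batch, seqlen, heads, head_dim) layout.

# Hypothetical smoke test for the plain-Python aiter_flash_attn_wrapper above.
# Assumes a ROCm GPU, the aiter package, and the flash-attn (B, S, H, D) layout.
import torch
from comfy.ldm.modules.attention import aiter_flash_attn_wrapper

batch, seqlen, heads, head_dim = 1, 1024, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With return_lse and return_attn_probs fixed to False inside the wrapper,
# the call returns a single tensor with the same shape as q.
out = aiter_flash_attn_wrapper(q, k, v, softmax_scale=head_dim ** -0.5)
assert out.shape == q.shape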