mirror of https://github.com/comfyanonymous/ComfyUI.git
commit 1c5a503715 (parent 897d2662be): add
@@ -628,34 +628,17 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 
-try:
-    @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
-                                     causal=causal, window_size=window_size, bias=bias,
-                                     alibi_slopes=alibi_slopes, deterministic=deterministic,
-                                     return_lse=False, return_attn_probs=False,
-                                     cu_seqlens_q=None, cu_seqlens_kv=None)
-
-
-    @aiter_flash_attn_wrapper.register_fake
-    def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                              window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
-        # Output shape is the same as q
-        return q.new_empty(q.shape)
-except AttributeError as error:
-    AITER_ATTN_ERROR = error
-
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
+def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                             causal: bool = False, window_size: tuple = (-1, -1),
+                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                             deterministic: bool = False) -> torch.Tensor:
+    """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
+    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                 causal=causal, window_size=window_size, bias=bias,
+                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                 return_lse=False, return_attn_probs=False,
+                                 cu_seqlens_q=None, cu_seqlens_kv=None)
 
 
 @wrap_attn
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
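The block on the old side of this hunk uses PyTorch's custom-op pattern: torch.library.custom_op registers the opaque aiter kernel as a proper operator, and register_fake supplies a shape-only "fake" implementation (returning q.new_empty(q.shape)), presumably so FakeTensor tracing and torch.compile can propagate shapes without running the kernel; the try/except keeps the module importable when that API raises AttributeError. Below is a minimal, self-contained sketch of the same pattern with a hypothetical stand-in kernel (not aiter, not part of this commit):

import torch

# Hypothetical example op; any opaque kernel could sit behind the real implementation.
@torch.library.custom_op("example_attention::scaled_copy", mutates_args=())
def scaled_copy(q: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Real implementation: runs only on real (non-fake) tensors.
    return q * scale

@scaled_copy.register_fake
def scaled_copy_fake(q, scale=1.0):
    # Shape-only implementation used during FakeTensor tracing / torch.compile:
    # the output has the same shape and dtype as q, nothing is computed.
    return q.new_empty(q.shape)

x = torch.randn(2, 1024, 8, 64)
print(scaled_copy(x, 0.5).shape)  # torch.Size([2, 1024, 8, 64])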
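For orientation only, a hedged usage sketch of the retained wrapper, assuming a ROCm build with aiter installed and assuming aiter.flash_attn_func follows the usual flash-attention (batch, seqlen, heads, head_dim) layout; the import path and the default-scale behavior are assumptions, and this snippet is not part of the commit:

import torch
# from comfy.ldm.modules.attention import aiter_flash_attn_wrapper  # path assumed

b, s, h, d = 1, 4096, 8, 64
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# softmax_scale=None is assumed to fall back to 1/sqrt(head_dim) in the kernel;
# window_size=(-1, -1) means no sliding-window restriction.
out = aiter_flash_attn_wrapper(q, k, v, causal=False)
print(out.shape)  # expected: (1, 4096, 8, 64), same as q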