mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-22 04:20:49 +08:00
add
This commit is contained in:
parent
43de98b30e
commit
e00688e134
@ -45,7 +45,8 @@ try:
|
|||||||
AITER_ATTENTION_IS_AVAILABLE = True
|
AITER_ATTENTION_IS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
if model_management.aiter_attention_enabled():
|
if model_management.aiter_attention_enabled():
|
||||||
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
|
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
|
||||||
|
logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||||
@ -629,42 +630,46 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
|
||||||
dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
|
|
||||||
causal: bool = False, window_size: tuple = (-1, -1),
|
|
||||||
bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
|
|
||||||
deterministic: bool = False) -> torch.Tensor:
|
|
||||||
"""Wrapper for aiter.flash_attn_func to handle its specific parameters"""
|
|
||||||
return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
|
|
||||||
causal=causal, window_size=window_size, bias=bias,
|
|
||||||
alibi_slopes=alibi_slopes, deterministic=deterministic,
|
|
||||||
return_lse=False, return_attn_probs=False,
|
|
||||||
cu_seqlens_q=None, cu_seqlens_kv=None)
|
|
||||||
|
|
||||||
@wrap_attn
|
@wrap_attn
|
||||||
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
# Store original inputs for fallback
|
||||||
|
q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
|
||||||
|
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.transpose(1, 2),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
b, _, dim_head = q.shape
|
b, _, dim_head = q.shape
|
||||||
dim_head //= heads
|
dim_head //= heads
|
||||||
# reshape to (batch, seqlen, nheads, headdim) for aiter
|
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.view(b, -1, heads, dim_head),
|
lambda t: t.view(b, -1, heads, dim_head),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Convert mask to [sq, sk] format for aiter bias
|
||||||
|
bias = None
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# add a batch dimension if there isn't already one
|
|
||||||
if mask.ndim == 2:
|
if mask.ndim == 2:
|
||||||
mask = mask.unsqueeze(0)
|
bias = mask
|
||||||
# add a heads dimension if there isn't already one
|
elif mask.ndim == 3:
|
||||||
if mask.ndim == 3:
|
seqlen_q = q.shape[1]
|
||||||
mask = mask.unsqueeze(1)
|
if mask.shape[-2] == 1:
|
||||||
|
# [1, 1, sk] -> expand to [sq, sk]
|
||||||
|
bias = mask.squeeze(0).expand(seqlen_q, -1)
|
||||||
|
else:
|
||||||
|
# [batch, sq, sk] -> take first batch
|
||||||
|
bias = mask[0]
|
||||||
|
elif mask.ndim == 4:
|
||||||
|
# [batch, heads, sq, sk] -> take first batch and head
|
||||||
|
bias = mask[0, 0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
|
# aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
|
||||||
out = aiter_flash_attn_wrapper(
|
out = aiter.flash_attn_func(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
@ -672,26 +677,29 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=False,
|
causal=False,
|
||||||
window_size=(-1, -1),
|
window_size=(-1, -1),
|
||||||
bias=mask,
|
bias=bias,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
deterministic=False,
|
deterministic=False,
|
||||||
|
return_lse=False,
|
||||||
|
return_attn_probs=False,
|
||||||
|
cu_seqlens_q=None,
|
||||||
|
cu_seqlens_kv=None,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
|
|
||||||
# fallback needs (batch, nheads, seqlen, headdim) format
|
|
||||||
q_sdpa = q.transpose(1, 2)
|
|
||||||
k_sdpa = k.transpose(1, 2)
|
|
||||||
v_sdpa = v.transpose(1, 2)
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
|
|
||||||
if skip_output_reshape:
|
if skip_output_reshape:
|
||||||
# output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
|
# output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
|
||||||
out = out.transpose(1, 2)
|
out = out.transpose(1, 2)
|
||||||
else:
|
else:
|
||||||
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
|
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
|
||||||
out = out.reshape(b, -1, heads * dim_head)
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
|
||||||
|
# Fallback to attention_pytorch with original inputs
|
||||||
|
return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
|
||||||
|
attn_precision=attn_precision, skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
@ -1170,5 +1178,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
out = x + x_in
|
out = x + x_in
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user