mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-22 20:40:49 +08:00
add
This commit is contained in:
parent
3bea4efc6b
commit
897d2662be
@ -111,6 +111,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
|
|||||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||||
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
|
||||||
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
|
||||||
|
attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
|
||||||
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
|
||||||
|
|
||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|||||||
@ -39,6 +39,15 @@ except ImportError:
|
|||||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
AITER_ATTENTION_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
import aiter
|
||||||
|
AITER_ATTENTION_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
if model_management.aiter_attention_enabled():
|
||||||
|
logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||||
def register_attention_function(name: str, func: Callable):
|
def register_attention_function(name: str, func: Callable):
|
||||||
# avoid replacing existing functions
|
# avoid replacing existing functions
|
||||||
@ -619,11 +628,96 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
@torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
|
||||||
|
def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
|
||||||
|
causal: bool = False, window_size: tuple = (-1, -1),
|
||||||
|
bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
|
deterministic: bool = False) -> torch.Tensor:
|
||||||
|
return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
|
||||||
|
causal=causal, window_size=window_size, bias=bias,
|
||||||
|
alibi_slopes=alibi_slopes, deterministic=deterministic,
|
||||||
|
return_lse=False, return_attn_probs=False,
|
||||||
|
cu_seqlens_q=None, cu_seqlens_kv=None)
|
||||||
|
|
||||||
|
|
||||||
|
@aiter_flash_attn_wrapper.register_fake
|
||||||
|
def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||||
|
window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
|
||||||
|
# Output shape is the same as q
|
||||||
|
return q.new_empty(q.shape)
|
||||||
|
except AttributeError as error:
|
||||||
|
AITER_ATTN_ERROR = error
|
||||||
|
|
||||||
|
def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||||
|
dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
|
||||||
|
causal: bool = False, window_size: tuple = (-1, -1),
|
||||||
|
bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
|
||||||
|
deterministic: bool = False) -> torch.Tensor:
|
||||||
|
assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
|
||||||
|
|
||||||
|
@wrap_attn
|
||||||
|
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if skip_reshape:
|
||||||
|
b, _, _, dim_head = q.shape
|
||||||
|
else:
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
# reshape to (batch, seqlen, nheads, headdim) for aiter
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.view(b, -1, heads, dim_head),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
# add a batch dimension if there isn't already one
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
# add a heads dimension if there isn't already one
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
|
||||||
|
out = aiter_flash_attn_wrapper(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
dropout_p=0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=False,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
bias=mask,
|
||||||
|
alibi_slopes=None,
|
||||||
|
deterministic=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
|
||||||
|
# fallback needs (batch, nheads, seqlen, headdim) format
|
||||||
|
q_sdpa = q.transpose(1, 2)
|
||||||
|
k_sdpa = k.transpose(1, 2)
|
||||||
|
v_sdpa = v.transpose(1, 2)
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
|
out = out.transpose(1, 2)
|
||||||
|
|
||||||
|
if skip_output_reshape:
|
||||||
|
# output is already in (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
|
||||||
|
out = out.transpose(1, 2)
|
||||||
|
else:
|
||||||
|
# reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
|
||||||
|
out = out.reshape(b, -1, heads * dim_head)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
logging.info("Using sage attention")
|
logging.info("Using sage attention")
|
||||||
optimized_attention = attention_sage
|
optimized_attention = attention_sage
|
||||||
|
elif model_management.aiter_attention_enabled():
|
||||||
|
logging.info("Using aiter attention")
|
||||||
|
optimized_attention = attention_aiter
|
||||||
elif model_management.xformers_enabled():
|
elif model_management.xformers_enabled():
|
||||||
logging.info("Using xformers attention")
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
@ -647,6 +741,8 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
|
if AITER_ATTENTION_IS_AVAILABLE:
|
||||||
|
register_attention_function("aiter", attention_aiter)
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
|
|||||||
@ -1083,6 +1083,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
def sage_attention_enabled():
|
def sage_attention_enabled():
|
||||||
return args.use_sage_attention
|
return args.use_sage_attention
|
||||||
|
|
||||||
|
def aiter_attention_enabled():
|
||||||
|
return args.use_aiter_attention
|
||||||
|
|
||||||
def flash_attention_enabled():
|
def flash_attention_enabled():
|
||||||
return args.use_flash_attention
|
return args.use_flash_attention
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user