diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 209fc185b..9e5ce893c 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,6 +112,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a8800ded0..d3eca8814 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -39,6 +39,16 @@ except ImportError:
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
 
+AITER_ATTENTION_IS_AVAILABLE = False
+try:
+    import aiter
+    AITER_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+    if model_management.aiter_attention_enabled():
+        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
+        exit(-1)
+
 REGISTERED_ATTENTION_FUNCTIONS = {}
 def register_attention_function(name: str, func: Callable):
     # avoid replacing existing functions
@@ -615,6 +625,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     except Exception as e:
         logging.warning(f"Flash Attention failed, using default SDPA: {e}")
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
     if not skip_output_reshape:
         out = (
             out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -622,11 +633,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 
+@wrap_attn
+def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    # Store original inputs for fallback
+    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
+
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        q, k, v = map(
+            lambda t: t.transpose(1, 2),
+            (q, k, v),
+        )
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head),
+            (q, k, v),
+        )
+
+
+    # Convert mask to [sq, sk] format for aiter bias
+    bias = None
+    if mask is not None:
+        if mask.ndim == 2:
+            bias = mask
+        elif mask.ndim == 3:
+            seqlen_q = q.shape[1]
+            if mask.shape[-2] == 1:
+                # [1, 1, sk] -> expand to [sq, sk]
+                bias = mask.squeeze(0).expand(seqlen_q, -1)
+            else:
+                # [batch, sq, sk] -> take first batch
+                bias = mask[0]
+        elif mask.ndim == 4:
+            # [batch, heads, sq, sk] -> take first batch and head
+            bias = mask[0, 0]
+
+    try:
+        # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
+        out = aiter.flash_attn_func(
+            q,
+            k,
+            v,
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+            window_size=(-1, -1),
+            bias=bias,
+            alibi_slopes=None,
+            deterministic=False,
+            return_lse=False,
+            return_attn_probs=False,
+            cu_seqlens_q=None,
+            cu_seqlens_kv=None,
+        )
+
+        if skip_output_reshape:
+            # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
+            out = out.transpose(1, 2)
+        else:
+            # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
+            out = out.reshape(b, -1, heads * dim_head)
+        return out
+
+    except Exception as e:
+        logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
+        # Fallback to attention_pytorch with original inputs
+        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
+                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
+                                 skip_output_reshape=skip_output_reshape, **kwargs)
+
+
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+elif model_management.aiter_attention_enabled():
+    logging.info("Using aiter attention")
+    optimized_attention = attention_aiter
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
@@ -650,6 +736,8 @@ optimized_attention_masked = optimized_attention
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if AITER_ATTENTION_IS_AVAILABLE:
+    register_attention_function("aiter", attention_aiter)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
@@ -1093,5 +1181,3 @@ class SpatialVideoTransformer(SpatialTransformer):
         x = self.proj_out(x)
         out = x + x_in
         return out
-
-
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 40717b1e4..a4bbcdea1 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
 
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def aiter_attention_enabled():
+    return args.use_aiter_attention
+
 def flash_attention_enabled():
     return args.use_flash_attention
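Usage note (not part of the patch): with the `aiter` package installed, the backend is selected at launch via the new `--use-aiter-attention` flag, e.g. `python main.py --use-aiter-attention` (assuming the standard ComfyUI entry point), and `attention_aiter` is additionally registered under the name "aiter" through `register_attention_function`. The sketch below is a minimal, standalone illustration in plain PyTorch of the layout and mask-to-bias handling the new function performs before calling `aiter.flash_attn_func`; the shapes and variable names are illustrative assumptions, and `aiter` itself is not required to run it.

```python
# Standalone sketch (plain PyTorch) of what attention_aiter does to its inputs
# before handing them to aiter.flash_attn_func. Shapes are illustrative.
import torch

b, heads, seq_q, seq_k, dim_head = 2, 8, 16, 24, 64

# ComfyUI passes q/k/v either as (b, seq, heads * dim_head) or, with
# skip_reshape=True, as (b, heads, seq, dim_head).
q = torch.randn(b, seq_q, heads * dim_head)
k = torch.randn(b, seq_k, heads * dim_head)
v = torch.randn(b, seq_k, heads * dim_head)

# The patch reshapes to the (batch, seqlen, nheads, headdim) layout that
# aiter.flash_attn_func expects.
q_, k_, v_ = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
print(q_.shape)  # torch.Size([2, 16, 8, 64])

# A broadcastable 3-D mask of shape (1, 1, sk) is expanded to a (sq, sk) bias.
mask3 = torch.zeros(1, 1, seq_k)
bias3 = mask3.squeeze(0).expand(seq_q, -1)
print(bias3.shape)  # torch.Size([16, 24])

# A 4-D mask (batch, heads, sq, sk) is collapsed to a single (sq, sk) bias
# by taking the first batch and head, as in the patch.
mask4 = torch.zeros(b, heads, seq_q, seq_k)
bias4 = mask4[0, 0]
print(bias4.shape)  # torch.Size([16, 24])
```

On failure inside the try block, the patch falls back to `attention_pytorch` with the untouched original tensors, so a bad mask or an unsupported head dimension degrades to the SDPA path rather than erroring out.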