diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 43991fd40..c18266672 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -45,8 +45,8 @@ try:
     AITER_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.aiter_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
-        logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
+        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
         exit(-1)
 
 REGISTERED_ATTENTION_FUNCTIONS = {}
@@ -634,7 +634,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     # Store original inputs for fallback
     q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
-    
+
     if skip_reshape:
         b, _, _, dim_head = q.shape
         q, k, v = map(
@@ -649,7 +649,7 @@
             (q, k, v),
         )
 
-    
+
     # Convert mask to [sq, sk] format for aiter bias
     bias = None
     if mask is not None:
@@ -685,7 +685,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             cu_seqlens_q=None,
             cu_seqlens_kv=None,
         )
-        
+
         if skip_output_reshape:
             # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
             out = out.transpose(1, 2)
@@ -693,12 +693,12 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
             out = out.reshape(b, -1, heads * dim_head)
         return out
-    
+
     except Exception as e:
         logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
         # Fallback to attention_pytorch with original inputs
-        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig, 
-                                 attn_precision=attn_precision, skip_reshape=skip_reshape, 
+        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
+                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
                                  skip_output_reshape=skip_output_reshape, **kwargs)
 
 
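
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the fallback pattern the diff relies on. The untouched q/k/v/mask are captured before any reshaping, so that if the optimized kernel raises, the reference path receives exactly the tensors the caller passed in. Here, fast_attention and reference_attention are hypothetical stand-ins for the aiter kernel and ComfyUI's attention_pytorch; they are not real library APIs, and the shapes assume the simple (batch, seq, heads * dim) layout.

# Minimal sketch of the "store originals, try fast path, fall back with the
# original inputs" pattern used by attention_aiter in the patch above.
# fast_attention and reference_attention are hypothetical stand-ins, not
# aiter or ComfyUI functions.
import logging

import torch
import torch.nn.functional as F


def _split_heads(t, heads):
    # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
    b, seq, inner = t.shape
    return t.view(b, seq, heads, inner // heads).transpose(1, 2)


def fast_attention(q, k, v):
    # Stand-in for an optimized kernel with limited input support; here it
    # simply refuses fp32 so the example always exercises the fallback.
    if q.dtype == torch.float32:
        raise RuntimeError("fp32 not supported by this kernel")
    return F.scaled_dot_product_attention(q, k, v)


def reference_attention(q, k, v, heads, mask=None):
    # Stand-in for attention_pytorch: does its own reshaping and masking.
    b, seq, _ = q.shape
    q, k, v = (_split_heads(t, heads) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, seq, -1)


def attention_with_fallback(q, k, v, heads, mask=None):
    # Capture the caller's tensors before any reshaping, mirroring
    # `q_orig, k_orig, v_orig, mask_orig = q, k, v, mask` in the patch.
    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask

    b, seq, _ = q.shape
    q, k, v = (_split_heads(t, heads) for t in (q, k, v))
    try:
        out = fast_attention(q, k, v)
        return out.transpose(1, 2).reshape(b, seq, -1)
    except Exception as e:
        logging.warning("fast path failed, falling back to pytorch attention: %s", e)
        # Hand the *original* inputs to the reference path, just as the patch
        # passes q_orig/k_orig/v_orig/mask_orig to attention_pytorch.
        return reference_attention(q_orig, k_orig, v_orig, heads, mask=mask_orig)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 16, 64)   # fp32, so the fallback path is taken
    print(attention_with_fallback(q, k, v, heads=8).shape)  # torch.Size([1, 16, 64])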