Mirror of https://github.com/comfyanonymous/ComfyUI.git
commit 0f6f3c93fd
parent e00688e134

    format
@@ -45,8 +45,8 @@ try:
     AITER_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.aiter_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
-        logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
+        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
         exit(-1)

 REGISTERED_ATTENTION_FUNCTIONS = {}
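The only change in this hunk is dropping the `f` prefix from string literals that contain no placeholders; pyflakes/ruff report this as F541. A minimal sketch of the before and after (the message text here is shortened):

import logging

# Before: the f prefix is inert because the literal has no {} placeholders;
# pyflakes/ruff flag it as F541 ("f-string without any placeholders").
logging.error(f"the `aiter` package must be installed first.")  # noqa: F541

# After: a plain string literal -- identical output, no lint warning.
logging.error("the `aiter` package must be installed first.")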
@@ -634,7 +634,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     # Store original inputs for fallback
     q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
-
+
     if skip_reshape:
         b, _, _, dim_head = q.shape
         q, k, v = map(
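The hunk is cut off inside the `map(` call, so the reshape body itself is not shown. As a rough sketch of what the unpacking `b, _, _, dim_head = q.shape` implies about the expected layouts (shapes and names below are illustrative assumptions, not the file's actual code):

import torch

b, seqlen, heads, dim_head = 2, 16, 8, 64

# skip_reshape=False: q/k/v arrive as (batch, seqlen, heads * dim_head)
# and have to be split into per-head form before the kernel runs.
q_flat = torch.randn(b, seqlen, heads * dim_head)
q_split = q_flat.view(b, seqlen, heads, dim_head)

# skip_reshape=True: the inputs are already per-head, laid out as
# (batch, heads, seqlen, dim_head), which is what the hunk's
# `b, _, _, dim_head = q.shape` unpacking reads off.
q_pre = torch.randn(b, heads, seqlen, dim_head)
b_, _, _, d_ = q_pre.shape
assert (b_, d_) == (b, dim_head)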
@@ -649,7 +649,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             (q, k, v),
         )

-
+
     # Convert mask to [sq, sk] format for aiter bias
     bias = None
     if mask is not None:
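The hunk ends just as the mask handling starts, so the file's actual conversion is not visible here. A hypothetical sketch of turning a boolean mask into an additive [sq, sk] bias, which is the standard way such a bias feeds into softmax:

import torch

sq, sk = 16, 16

# Boolean mask: True where a query position may attend to a key position.
mask = torch.rand(sq, sk) > 0.1

# Additive bias in [sq, sk] form: 0.0 where attention is allowed,
# -inf where it is blocked, so softmax assigns those positions zero weight.
bias = torch.zeros(sq, sk)
bias.masked_fill_(~mask, float("-inf"))

scores = torch.randn(sq, sk) + bias
probs = torch.softmax(scores, dim=-1)  # masked positions get zero weight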
@@ -685,7 +685,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             cu_seqlens_q=None,
             cu_seqlens_kv=None,
         )
-
+
         if skip_output_reshape:
             # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
             out = out.transpose(1, 2)
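A quick sketch confirming what the in-line comment describes: `transpose(1, 2)` swaps the seqlen and nheads axes (the shape values are illustrative):

import torch

b, seqlen, nheads, headdim = 2, 16, 8, 64
out = torch.randn(b, seqlen, nheads, headdim)

# transpose(1, 2) swaps the seqlen and nheads axes, producing the
# (batch, nheads, seqlen, headdim) layout requested by skip_output_reshape.
out = out.transpose(1, 2)
assert out.shape == (b, nheads, seqlen, headdim)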
@@ -693,12 +693,12 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
             out = out.reshape(b, -1, heads * dim_head)
         return out
-
+
     except Exception as e:
         logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
         # Fallback to attention_pytorch with original inputs
-        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
-                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
+        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
+                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
                                  skip_output_reshape=skip_output_reshape, **kwargs)

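Taken together, the hunks outline a common pattern for optional fast backends: try the accelerated kernel, and on any failure fall back to the portable implementation using the untouched original inputs saved at the top of the function. A minimal sketch of that pattern (the wrapper below is illustrative; only `attention_pytorch` and the saved `*_orig` names come from the diff, and PyTorch's `scaled_dot_product_attention` stands in for the fallback):

import logging

import torch
import torch.nn.functional as F

def attention_with_fallback(q, k, v, fast_kernel=None):
    # Keep untouched references so the fallback never sees tensors the
    # fast path may have reshaped or cast before it failed.
    q_orig, k_orig, v_orig = q, k, v
    try:
        if fast_kernel is None:
            raise RuntimeError("fast kernel unavailable")
        return fast_kernel(q, k, v)
    except Exception as e:
        logging.warning("fast attention failed, falling back to pytorch attention: %s", e)
        return F.scaled_dot_product_attention(q_orig, k_orig, v_orig)

# Usage: (batch, heads, seqlen, dim_head), the layout SDPA expects.
q = k = v = torch.randn(1, 8, 16, 64)
out = attention_with_fallback(q, k, v)  # no fast kernel, so the except path runs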