Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-21 03:50:50 +08:00
Commit 0f6f3c93fd ("format")
Parent: e00688e134
@@ -45,8 +45,8 @@ try:
     AITER_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.aiter_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
-        logging.error(f"Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
+        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
+        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
         exit(-1)
 
 REGISTERED_ATTENTION_FUNCTIONS = {}
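The hunk above is the usual optional-dependency guard: try the import, record the result in a module-level flag, and only abort when the user explicitly asked for the backend. A minimal standalone sketch of that pattern follows; the command-line check is only an illustrative simplification of ComfyUI's model_management.aiter_attention_enabled().

import logging
import sys

AITER_ATTENTION_IS_AVAILABLE = False
try:
    import aiter  # ROCm's optional attention kernel package
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    # Hard failure only if the backend was explicitly requested.
    if "--use-aiter-attention" in sys.argv:
        logging.error("To use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
        sys.exit(-1)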
@@ -634,7 +634,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     # Store original inputs for fallback
     q_orig, k_orig, v_orig, mask_orig = q, k, v, mask
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
         q, k, v = map(
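The body of the map(...) call is outside this hunk, but with skip_reshape the inputs already arrive as (batch, heads, seqlen, dim_head), so the conversion is presumably just a layout change to the (batch, seqlen, heads, dim_head) form flash-style kernels expect. A hedged sketch with dummy tensors:

import torch

b, heads, seqlen, dim_head = 2, 8, 16, 64
q = k = v = torch.randn(b, heads, seqlen, dim_head)
q_orig, k_orig, v_orig = q, k, v  # kept around so a fallback path can use the untouched inputs

# (batch, heads, seqlen, dim_head) -> (batch, seqlen, heads, dim_head)
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))
print(q.shape)  # torch.Size([2, 16, 8, 64])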
@@ -649,7 +649,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             (q, k, v),
         )
 
 
     # Convert mask to [sq, sk] format for aiter bias
     bias = None
     if mask is not None:
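The exact conversion is outside the hunk, but "convert mask to [sq, sk] bias" normally means turning a boolean attention mask into an additive float tensor: zero where attention is allowed and a large negative value where it is blocked. A generic sketch, not the literal ComfyUI code:

import torch

sq, sk = 16, 16
bool_mask = torch.rand(sq, sk) > 0.1  # True = attend, False = masked out

bias = torch.zeros(sq, sk)
bias.masked_fill_(~bool_mask, float("-inf"))  # some kernels prefer a large finite negative instead

# A mask that is already additive (float) would instead be broadcast/expanded to [sq, sk].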
@@ -685,7 +685,7 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             cu_seqlens_q=None,
             cu_seqlens_kv=None,
         )
 
         if skip_output_reshape:
             # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
             out = out.transpose(1, 2)
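When skip_output_reshape is set the caller wants the heads-first layout back, so a single transpose is enough; the shapes named in the comment can be checked directly:

import torch

b, seqlen, heads, dim_head = 2, 16, 8, 64
out = torch.randn(b, seqlen, heads, dim_head)  # (batch, seqlen, nheads, headdim)
out = out.transpose(1, 2)
print(out.shape)  # torch.Size([2, 8, 16, 64]) == (batch, nheads, seqlen, headdim)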
@@ -693,12 +693,12 @@ def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
             out = out.reshape(b, -1, heads * dim_head)
         return out
 
     except Exception as e:
         logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
         # Fallback to attention_pytorch with original inputs
         return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
                                  attn_precision=attn_precision, skip_reshape=skip_reshape,
                                  skip_output_reshape=skip_output_reshape, **kwargs)
 
 
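The fallback at the end of the function is worth isolating: the untouched inputs saved at the top are what get handed to attention_pytorch when the aiter path throws. A small self-contained sketch of the same idea, where reference_attention and attention_with_fallback are illustrative stand-ins (attention_pytorch is ComfyUI's actual reference path):

import logging
import torch
import torch.nn.functional as F

def reference_attention(q, k, v, mask=None):
    # Stand-in for attention_pytorch: plain SDPA on (batch, heads, seqlen, dim_head) tensors.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

def attention_with_fallback(fast_attention, q, k, v, mask=None):
    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask  # untouched copies for the fallback
    try:
        return fast_attention(q, k, v, mask=mask)
    except Exception as e:
        logging.warning(f"optimized attention failed, falling back to the reference path: {e}")
        return reference_attention(q_orig, k_orig, v_orig, mask=mask_orig)

def broken_kernel(q, k, v, mask=None):
    raise RuntimeError("kernel unavailable")

q = k = v = torch.randn(1, 8, 16, 64)
out = attention_with_fallback(broken_kernel, q, k, v)  # logs a warning, returns the SDPA result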