mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-16 17:42:58 +08:00)
Remove the Sage Attention3 switch, but retain the attention function registration.

commit 7883076f5c (parent 483ba1e98b)
@@ -112,7 +112,6 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
-attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
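For reference, a minimal sketch of the argparse pattern these flags use. The parser construction below is illustrative only and is not ComfyUI's actual cli_args module (attn_group, in particular, may be built differently there):

# Illustrative sketch: boolean "store_true" switches in an argument group,
# read back as attributes on the parsed namespace (dashes become underscores).
import argparse

parser = argparse.ArgumentParser(description="sketch, not ComfyUI's real parser")
attn_group = parser.add_argument_group("attention")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

args = parser.parse_args(["--use-sage-attention"])
print(args.use_sage_attention)   # True
print(args.use_flash_attention)  # False

Removing the --use-sage-attention3 switch above is what makes args.use_sage_attention3 disappear from the parsed namespace, which is why the accessor in the model_management hunk further down has to go as well.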
@@ -35,12 +35,7 @@ try:
     from sageattn3 import sageattn3_blackwell
     SAGE_ATTENTION3_IS_AVAILABLE = True
 except ImportError as e:
-    if model_management.sage_attention3_enabled():
-        if e.name == "sageattn3":
-            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
-        else:
-            raise e
-        exit(-1)
+    pass

 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
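Restated as a self-contained sketch, this hunk is the usual optional-dependency guard: attempt the import, record availability in a module-level flag, and (after this change) swallow the ImportError so the module still loads when the package is missing. The flag initialization before the try block is assumed here, mirroring FLASH_ATTENTION_IS_AVAILABLE:

import logging

SAGE_ATTENTION3_IS_AVAILABLE = False
try:
    from sageattn3 import sageattn3_blackwell  # optional kernel; absent on most installs
    SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError:
    # Failure is now silent; callers are expected to check the flag
    # before routing work to this backend.
    pass

logging.info("sageattn3 available: %s", SAGE_ATTENTION3_IS_AVAILABLE)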
@@ -589,7 +584,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )

     if skip_reshape:
         B, H, L, D = q.shape
         if H != heads:
@@ -626,16 +621,16 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )

     if not skip_reshape:
         q_s, k_s, v_s = map(
             lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
             (q, k, v),
         )
         B, H, L, D = q_s.shape

     try:
         out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
         exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
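Taken together, these hunks show the shape handling around the optimized kernel: inputs arriving as (B, L, heads * dim_head) are split per head and permuted to the (B, H, L, D) layout, and any kernel failure falls back to plain PyTorch attention. A hedged sketch of that flow which runs without the sageattn3 package (the call to scaled_dot_product_attention stands in for the module's actual fallback path):

import torch
import torch.nn.functional as F

def attention_with_fallback(q, k, v, heads):
    B, L, inner = q.shape
    dim_head = inner // heads
    # (B, L, H*D) -> (B, H, L, D), matching the map/permute above
    q_s, k_s, v_s = (t.view(B, L, heads, dim_head).permute(0, 2, 1, 3).contiguous()
                     for t in (q, k, v))
    try:
        # The real code calls sageattn3_blackwell(q_s, k_s, v_s, is_causal=False);
        # force the except path here so the sketch runs anywhere.
        raise RuntimeError("sageattn3 not installed in this sketch")
    except Exception:
        out = F.scaled_dot_product_attention(q_s, k_s, v_s)  # fallback attention
    # (B, H, L, D) -> (B, L, H*D), mirroring the output reshape in the next hunk
    return out.permute(0, 2, 1, 3).reshape(B, L, heads * dim_head)

q = k = v = torch.randn(2, 16, 8 * 64)
print(attention_with_fallback(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])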
@@ -655,7 +650,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     if skip_reshape:
         if not skip_output_reshape:
             out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
     else:
         if skip_output_reshape:
             pass
         else:
@@ -726,9 +721,6 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
-if model_management.sage_attention3_enabled():
-    logging.info("Using sage attention 3")
-    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
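The dispatch being edited here is a module-level selection done once at import time. A hedged, self-contained sketch (the _FakeFlags stub stands in for model_management, and the attention callables are placeholders):

import logging

def attention_basic(*args, **kwargs): ...
def attention_sage(*args, **kwargs): ...
def attention_xformers(*args, **kwargs): ...

class _FakeFlags:  # stands in for model_management in this sketch
    @staticmethod
    def sage_attention_enabled(): return True
    @staticmethod
    def xformers_enabled(): return False

model_management = _FakeFlags()

optimized_attention = attention_basic
if model_management.sage_attention_enabled():
    logging.info("Using sage attention")
    optimized_attention = attention_sage
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers

print(optimized_attention.__name__)  # attention_sage

Note that with the sage-attention-3 block removed, the xformers elif now chains directly off the sage-attention check.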
@@ -1189,9 +1189,6 @@ def unpin_memory(tensor):
 def sage_attention_enabled():
     return args.use_sage_attention

-def sage_attention3_enabled():
-    return args.use_sage_attention3
-
 def flash_attention_enabled():
     return args.use_flash_attention

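Dropping this accessor together with the CLI switch keeps the two files consistent: once --use-sage-attention3 is no longer registered on the parser, args.use_sage_attention3 would raise AttributeError. A tiny sketch of that coupling (the parser and accessors here are illustrative, not the real modules):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-sage-attention", action="store_true")
args = parser.parse_args([])

def sage_attention_enabled():
    return args.use_sage_attention       # fine: the flag is registered

def sage_attention3_enabled():
    return args.use_sage_attention3      # AttributeError: flag was never added

print(sage_attention_enabled())          # False
# sage_attention3_enabled()              # would raise AttributeError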