This commit is contained in:
doctorpangloss 2025-09-23 10:28:36 -07:00
parent dd9a781654
commit 4a3feee1a2
2 changed files with 5 additions and 7 deletions

View File

@ -27,9 +27,8 @@ try:
SAGE_ATTENTION_IS_AVAILABLE = True
except (ImportError, ModuleNotFoundError) as e:
if e.name == "sageattention":
logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
sageattn = torch.nn.functional.scaled_dot_product_attention
if e.name == "sageattention" and model_management.sage_attention_enabled():
logger.debug(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
flash_attn_func = torch.nn.functional.scaled_dot_product_attention
FLASH_ATTENTION_IS_AVAILABLE = False
@ -39,8 +38,7 @@ try:
FLASH_ATTENTION_IS_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
flash_attn_func = torch.nn.functional.scaled_dot_product_attention
logging.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
REGISTERED_ATTENTION_FUNCTIONS = {}

View File

@ -87,7 +87,7 @@ class UnrecoverableError(Exception):
pass
class TestExceptionNode(CustomNode):
class ThrowsExceptionNode(CustomNode):
"""Node that raises a specific exception for testing"""
@classmethod
@ -113,7 +113,7 @@ class TestExceptionNode(CustomNode):
# Export the node mappings
TEST_NODE_CLASS_MAPPINGS = {
"TestExceptionNode": TestExceptionNode,
"TestExceptionNode": ThrowsExceptionNode,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {