This commit is contained in:
doctorpangloss 2025-09-23 10:28:36 -07:00
parent dd9a781654
commit 4a3feee1a2
2 changed files with 5 additions and 7 deletions

View File

@ -27,9 +27,8 @@ try:
SAGE_ATTENTION_IS_AVAILABLE = True SAGE_ATTENTION_IS_AVAILABLE = True
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
if e.name == "sageattention": if e.name == "sageattention" and model_management.sage_attention_enabled():
logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.") logger.debug(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
sageattn = torch.nn.functional.scaled_dot_product_attention
flash_attn_func = torch.nn.functional.scaled_dot_product_attention flash_attn_func = torch.nn.functional.scaled_dot_product_attention
FLASH_ATTENTION_IS_AVAILABLE = False FLASH_ATTENTION_IS_AVAILABLE = False
@ -39,8 +38,7 @@ try:
FLASH_ATTENTION_IS_AVAILABLE = True FLASH_ATTENTION_IS_AVAILABLE = True
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
if model_management.flash_attention_enabled(): if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.") logging.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
flash_attn_func = torch.nn.functional.scaled_dot_product_attention
REGISTERED_ATTENTION_FUNCTIONS = {} REGISTERED_ATTENTION_FUNCTIONS = {}

View File

@ -87,7 +87,7 @@ class UnrecoverableError(Exception):
pass pass
class TestExceptionNode(CustomNode): class ThrowsExceptionNode(CustomNode):
"""Node that raises a specific exception for testing""" """Node that raises a specific exception for testing"""
@classmethod @classmethod
@ -113,7 +113,7 @@ class TestExceptionNode(CustomNode):
# Export the node mappings # Export the node mappings
TEST_NODE_CLASS_MAPPINGS = { TEST_NODE_CLASS_MAPPINGS = {
"TestExceptionNode": TestExceptionNode, "TestExceptionNode": ThrowsExceptionNode,
} }
TEST_NODE_DISPLAY_NAME_MAPPINGS = { TEST_NODE_DISPLAY_NAME_MAPPINGS = {