From 4a3feee1a25704f76cedff878785d0b2878bb550 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 23 Sep 2025 10:28:36 -0700
Subject: [PATCH] Tweaks

---
 comfy/ldm/modules/attention.py | 8 +++-----
 tests/unit/test_panics.py      | 4 ++--
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b4b4419ed..9837cc8a8 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -27,9 +27,8 @@ try:
     SAGE_ATTENTION_IS_AVAILABLE = True
 except (ImportError, ModuleNotFoundError) as e:
-    if e.name == "sageattention":
-        logger.error(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
-    sageattn = torch.nn.functional.scaled_dot_product_attention
+    if e.name == "sageattention" and model_management.sage_attention_enabled():
+        logger.debug(f"To use the `--use-sage-attention` feature, the `sageattention` package must be installed first.")
 
 flash_attn_func = torch.nn.functional.scaled_dot_product_attention
 FLASH_ATTENTION_IS_AVAILABLE = False
@@ -39,8 +38,7 @@ try:
     FLASH_ATTENTION_IS_AVAILABLE = True
 except (ImportError, ModuleNotFoundError):
     if model_management.flash_attention_enabled():
-        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
-    flash_attn_func = torch.nn.functional.scaled_dot_product_attention
+        logging.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
 
 REGISTERED_ATTENTION_FUNCTIONS = {}
diff --git a/tests/unit/test_panics.py b/tests/unit/test_panics.py
index 251ed9d97..c27439ffe 100644
--- a/tests/unit/test_panics.py
+++ b/tests/unit/test_panics.py
@@ -87,7 +87,7 @@ class UnrecoverableError(Exception):
     pass
 
 
-class TestExceptionNode(CustomNode):
+class ThrowsExceptionNode(CustomNode):
     """Node that raises a specific exception for testing"""
 
     @classmethod
@@ -113,7 +113,7 @@ class TestExceptionNode(CustomNode):
 
 # Export the node mappings
 TEST_NODE_CLASS_MAPPINGS = {
-    "TestExceptionNode": TestExceptionNode,
+    "TestExceptionNode": ThrowsExceptionNode,
 }
 
 TEST_NODE_DISPLAY_NAME_MAPPINGS = {