From 684d180446ef5297aaee0f985abb9b5006f8cde5 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 18 Feb 2025 10:57:23 -0800 Subject: [PATCH] Users can now configure their workers to panic if they have out of memory exceptions, which occur due to complex failures in custom nodes --- comfy/cli_args.py | 11 ++ comfy/cli_args_types.py | 2 + comfy/client/embedded_comfy_client.py | 6 +- comfy/cmd/execution.py | 12 ++ comfy/distributed/executors.py | 2 + comfy/execution_ext.py | 56 ++++++++ tests/unit/test_panics.py | 182 ++++++++++++++++++++++++++ 7 files changed, 269 insertions(+), 2 deletions(-) create mode 100644 comfy/execution_ext.py create mode 100644 tests/unit/test_panics.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 53e747752..eb5dbc545 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -190,6 +190,17 @@ def _create_parser() -> EnhancedConfigArgParser: """, ) + parser.add_argument( + '--panic-when', + action='append', + help=""" + List of fully qualified exception class names to panic (os.exit(1)) when a workflow raises it. + Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a + comma-separated list.""", + type=str, + default=[] + ) + def is_valid_directory(path: Optional[str]) -> Optional[str]: """Validate if the given path is a directory.""" if path is None: diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index a7a87cb01..d6e0e6859 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -122,6 +122,7 @@ class Configuration(dict): anthropic_api_key (str): Configures the Anthropic API key for its nodes related to Claude functionality. Visit https://console.anthropic.com/settings/keys to create this key. user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path. 
log_stdout (bool): Send normal process output to stdout instead of stderr (default) + panic_when (list[str]): List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it. """ def __init__(self, **kwargs): @@ -220,6 +221,7 @@ class Configuration(dict): self.ideogram_api_key: Optional[str] = None self.anthropic_api_key: Optional[str] = None self.user_directory: Optional[str] = None + self.panic_when: list[str] = [] def __getattr__(self, item): if item not in self: diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 0f83e6bcc..4824e76a6 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -41,12 +41,14 @@ def _execute_prompt( span_context: Context = propagate.extract(span_context) token = attach(span_context) try: - return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration) + # there is never an event loop running on a thread or process pool thread here + # this also guarantees nodes will be able to successfully call await + return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)) finally: detach(token) -def __execute_prompt( +async def __execute_prompt( prompt: dict, prompt_id: str, client_id: str, diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index cbccf1e00..4e9059763 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import heapq import inspect @@ -20,6 +21,7 @@ from opentelemetry.trace import get_current_span, StatusCode, Status from .main_pre import tracer from .. import interruption from .. 
import model_management +from ..cli_args import args from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ @@ -28,6 +30,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio from ..component_model.files import canonicalize_path from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import context_execute_node, context_execute_prompt +from ..execution_ext import should_panic_on_exception from ..nodes.package import import_all_nodes_in_workspace from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode @@ -110,6 +113,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -480,6 +484,14 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, logging.error("Got an OOM, unloading all loaded models.") model_management.unload_all_models() + if should_panic_on_exception(ex, args.panic_when): + logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit") + + def sys_exit(*args): + sys.exit(1) + + asyncio.get_event_loop().call_soon_threadsafe(sys_exit, ()) + return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) diff --git a/comfy/distributed/executors.py b/comfy/distributed/executors.py index 626489731..97f1a4798 100644 --- a/comfy/distributed/executors.py +++ b/comfy/distributed/executors.py @@ -1,5 +1,7 @@ +import asyncio import concurrent 
import contextvars +import threading import typing from concurrent.futures import Future, ThreadPoolExecutor from functools import partial diff --git a/comfy/execution_ext.py b/comfy/execution_ext.py new file mode 100644 index 000000000..24012337e --- /dev/null +++ b/comfy/execution_ext.py @@ -0,0 +1,56 @@ +import importlib + + +def import_exception_class(fqn: str): + """ + Imports an exception class from its fully qualified name. + Example: 'torch.cuda.OutOfMemoryError' -> torch.cuda.OutOfMemoryError + + Args: + fqn: Fully qualified name of the exception class + + Returns: + The exception class + + Raises: + ValueError: If the class cannot be imported or is not a subclass of Exception + """ + try: + module_path, class_name = fqn.rsplit('.', 1) + module = importlib.import_module(module_path) + exc_class = getattr(module, class_name) + + if not isinstance(exc_class, type) or not issubclass(exc_class, Exception): + raise ValueError(f"{fqn} is not an exception class") + + return exc_class + except (ImportError, AttributeError) as e: + raise ValueError(f"Could not import exception class {fqn}: {str(e)}") + + +def should_panic_on_exception(exc: Exception, panic_classes: list[str]) -> bool: + """ + Checks if the given exception matches any of the specified panic classes. 
+ + Args: + exc: The exception to check + panic_classes: List of fully qualified exception class names + + Returns: + True if the exception is an instance of one of the specified classes + """ + # Handle comma-separated lists (from config files or env vars) + expanded_classes = [] + for class_spec in panic_classes: + expanded_classes.extend(name.strip() for name in class_spec.split(',')) + + # Import all exception classes + try: + exception_types = [import_exception_class(name) + for name in expanded_classes if name] + except ValueError as e: + print(f"Warning: {str(e)}") + return False + + # Check if exception matches any of the specified types + return any(isinstance(exc, exc_type) for exc_type in exception_types) diff --git a/tests/unit/test_panics.py b/tests/unit/test_panics.py new file mode 100644 index 000000000..3966db7ec --- /dev/null +++ b/tests/unit/test_panics.py @@ -0,0 +1,182 @@ +import asyncio +import threading +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args_types import Configuration +from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfy.cmd.execution import nodes +from comfy.component_model.make_mutable import make_mutable +from comfy.component_model.tensor_types import RGBImageBatch +from comfy.nodes.package_typing import CustomNode, ExportedNodes + + +@pytest.mark.asyncio +async def test_event_loop_callbacks(): + """Test to understand event loop callback behavior in pytest-asyncio""" + callback_executed = False + current_thread = threading.current_thread() + current_loop = asyncio.get_running_loop() + + def callback(*args): + nonlocal callback_executed + print(f"Callback executing in thread: {threading.current_thread()}") + print(f"Original thread was: {current_thread}") + callback_executed = True + + print(f"Test running in thread: {current_thread}") + print(f"Test using event loop: {current_loop}") + + # Try different ways of scheduling the callback + 
current_loop.call_soon(callback) + await asyncio.sleep(0) + print(f"After sleep(0), callback_executed: {callback_executed}") + + if not callback_executed: + current_loop.call_soon_threadsafe(callback) + await asyncio.sleep(0) + print(f"After threadsafe callback, callback_executed: {callback_executed}") + + if not callback_executed: + # Try running callback in event loop directly + await asyncio.get_event_loop().run_in_executor(None, callback) + print(f"After run_in_executor, callback_executed: {callback_executed}") + + assert callback_executed, "Callback was never executed" + + +@pytest.mark.asyncio +async def test_separate_thread_callback(): + """Test callbacks scheduled from a separate thread""" + callback_executed = False + event = threading.Event() + main_loop = asyncio.get_running_loop() + + def thread_func(): + print(f"Thread function running in: {threading.current_thread()}") + main_loop.call_soon_threadsafe(lambda *_: event.set()) + + print(f"Test running in thread: {threading.current_thread()}") + print(f"Test using event loop: {main_loop}") + + # Start thread that will schedule callback + thread = threading.Thread(target=thread_func) + thread.start() + + # Wait for event with timeout + try: + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, event.wait), + timeout=1.0 + ) + print("Event was set!") + except asyncio.TimeoutError: + print("Timed out waiting for event!") + assert False, "Event was never set" + + thread.join() + + +# Custom test exception that we'll configure to panic on +class TestUnrecoverableError(Exception): + pass + + +class TestExceptionNode(CustomNode): + """Node that raises a specific exception for testing""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "should_raise": ("BOOL", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE",) # Make it an output node by returning IMAGE + FUNCTION = "raise_exception" + CATEGORY = "Testing/Nodes" + OUTPUT_NODE = True + + def raise_exception(self, 
should_raise=True) -> tuple[RGBImageBatch]: + if should_raise: + raise TestUnrecoverableError("Test exception from node") + else: + # Return a dummy image if not raising + return (torch.zeros([1, 64, 64, 3]),) + + +# Export the node mappings +TEST_NODE_CLASS_MAPPINGS = { + "TestExceptionNode": TestExceptionNode, +} + +TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestExceptionNode": "Test Exception Node", +} + + +def create_failing_workflow(): + """Create a workflow that uses our test node to raise an exception""" + return make_mutable({ + "1": { + "class_type": "TestExceptionNode", + "inputs": { + "should_raise": True + } + } + }) + + +@pytest.mark.asyncio +async def test_panic_on_exception(): + # Set up the test nodes + nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)) + + # Create configuration with our test exception in panic_when + config = Configuration() + config.panic_when = [f"{__name__}.TestUnrecoverableError"] + + # Mock sys.exit to prevent actual exit and verify it's called + with patch('sys.exit') as mock_exit: + try: + async with EmbeddedComfyClient(configuration=config) as client: + # Queue our failing workflow + await client.queue_prompt(create_failing_workflow()) + except TestUnrecoverableError: + # We expect the exception to be raised here + pass + + # Give the event loop a chance to process the exit callback + await asyncio.sleep(0) + + # Verify sys.exit was called with code 1 + mock_exit.assert_called_once_with(1) + + +@pytest.mark.asyncio +async def test_no_panic_when_disabled(): + """Verify that the same exception doesn't trigger exit when not in panic_when""" + # Set up the test nodes + nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)) + + # Create configuration without the exception in panic_when + config = Configuration() + + # Mock sys.exit to verify it's not called + with 
patch('sys.exit') as mock_exit: + try: + async with EmbeddedComfyClient(configuration=config) as client: + # Queue our failing workflow + await client.queue_prompt(create_failing_workflow()) + except TestUnrecoverableError: + # We expect the exception to be raised here + pass + + # Give the event loop a chance to process any callbacks + await asyncio.sleep(0.1) + + # Verify sys.exit was not called + mock_exit.assert_not_called()