Users can now configure their workers to panic (exit) on out-of-memory exceptions, which typically arise from complex failures in custom nodes

doctorpangloss 2025-02-18 10:57:23 -08:00
parent d04288ce8d
commit 684d180446
7 changed files with 269 additions and 2 deletions


@@ -190,6 +190,17 @@ def _create_parser() -> EnhancedConfigArgParser:
         """,
     )
+    parser.add_argument(
+        '--panic-when',
+        action='append',
+        help="""
+             List of fully qualified exception class names to panic (sys.exit(1)) on when a workflow raises one of them.
+             Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a
+             comma-separated list.""",
+        type=str,
+        default=[]
+    )
 def is_valid_directory(path: Optional[str]) -> Optional[str]:
     """Validate if the given path is a directory."""
     if path is None:
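
For reference, a minimal sketch of enabling the flag at launch. The `comfyui` command name here is hypothetical (substitute your actual worker entry point); the flag and the torch.cuda.OutOfMemoryError example come from the help text above:

    comfyui --panic-when=torch.cuda.OutOfMemoryError
    comfyui --panic-when=torch.cuda.OutOfMemoryError --panic-when=builtins.MemoryError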


@@ -122,6 +122,7 @@ class Configuration(dict):
         anthropic_api_key (str): Configures the Anthropic API key for its nodes related to Claude functionality. Visit https://console.anthropic.com/settings/keys to create this key.
         user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
         log_stdout (bool): Send normal process output to stdout instead of stderr (default)
+        panic_when (list[str]): List of fully qualified exception class names that panic (sys.exit(1)) the worker when a workflow raises one of them.
     """

     def __init__(self, **kwargs):
@@ -220,6 +221,7 @@ class Configuration(dict):
         self.ideogram_api_key: Optional[str] = None
         self.anthropic_api_key: Optional[str] = None
         self.user_directory: Optional[str] = None
+        self.panic_when: list[str] = []

     def __getattr__(self, item):
         if item not in self:
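
For embedded use, the same option can be set programmatically; a minimal sketch mirroring the Configuration usage in the tests later in this diff:

    from comfy.cli_args_types import Configuration

    config = Configuration()
    # Exit the worker (sys.exit(1)) whenever a workflow raises a CUDA OOM.
    config.panic_when = ["torch.cuda.OutOfMemoryError"]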


@@ -41,12 +41,14 @@ def _execute_prompt(
     span_context: Context = propagate.extract(span_context)
     token = attach(span_context)
     try:
-        return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)
+        # there is never an event loop running on a thread or process pool thread here
+        # this also guarantees nodes will be able to successfully call await
+        return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration))
     finally:
         detach(token)


-def __execute_prompt(
+async def __execute_prompt(
     prompt: dict,
     prompt_id: str,
     client_id: str,
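
The change above drives each prompt to completion on a fresh event loop owned by the pool thread, so node implementations can await reliably. A minimal, self-contained sketch of the pattern (illustrative names, not from the codebase):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def do_work() -> str:
        await asyncio.sleep(0)  # stands in for a node that awaits
        return "done"

    def pool_entry() -> str:
        # No event loop runs on a ThreadPoolExecutor thread, so asyncio.run()
        # can create one, drive the coroutine to completion, and tear it down.
        return asyncio.run(do_work())

    with ThreadPoolExecutor(max_workers=1) as pool:
        assert pool.submit(pool_entry).result() == "done"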


@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import copy
 import heapq
 import inspect
@@ -20,6 +21,7 @@ from opentelemetry.trace import get_current_span, StatusCode, Status
 from .main_pre import tracer

 from .. import interruption
 from .. import model_management
+from ..cli_args import args
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
@@ -28,6 +30,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import context_execute_node, context_execute_prompt
+from ..execution_ext import should_panic_on_exception
 from ..nodes.package import import_all_nodes_in_workspace
 from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -110,6 +113,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
     for x in inputs:
         input_data = inputs[x]
         input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)

         def mark_missing():
             missing_keys[x] = True
             input_data_all[x] = (None,)
@@ -480,6 +484,14 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
             logging.error("Got an OOM, unloading all loaded models.")
             model_management.unload_all_models()
+            if should_panic_on_exception(ex, args.panic_when):
+                logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
+
+                def sys_exit(*args):
+                    sys.exit(1)
+
+                asyncio.get_event_loop().call_soon_threadsafe(sys_exit, ())
+
         return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)

     executed.add(unique_id)
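
Because _execute runs on a worker thread, the panic path does not call sys.exit directly; it schedules the exit on an event loop so SystemExit is raised on the loop's thread. A minimal, self-contained sketch of that pattern (illustrative names, not from the codebase):

    import asyncio
    import sys
    import threading

    def worker(loop: asyncio.AbstractEventLoop) -> None:
        # call_soon_threadsafe is safe to invoke from any thread; the
        # scheduled sys.exit(1) raises SystemExit on the loop's thread.
        loop.call_soon_threadsafe(sys.exit, 1)

    async def main() -> None:
        loop = asyncio.get_running_loop()
        threading.Thread(target=worker, args=(loop,)).start()
        await asyncio.sleep(1)  # the process exits with status 1 before this completes

    asyncio.run(main())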


@@ -1,5 +1,7 @@
+import asyncio
 import concurrent
 import contextvars
+import threading
 import typing

 from concurrent.futures import Future, ThreadPoolExecutor
 from functools import partial

comfy/execution_ext.py (new file, 56 lines)

@@ -0,0 +1,56 @@
import importlib


def import_exception_class(fqn: str):
    """
    Imports an exception class from its fully qualified name.
    Example: 'torch.cuda.OutOfMemoryError' -> torch.cuda.OutOfMemoryError

    Args:
        fqn: Fully qualified name of the exception class

    Returns:
        The exception class

    Raises:
        ValueError: If the class cannot be imported or is not a subclass of Exception
    """
    try:
        module_path, class_name = fqn.rsplit('.', 1)
        module = importlib.import_module(module_path)
        exc_class = getattr(module, class_name)
        if not isinstance(exc_class, type) or not issubclass(exc_class, Exception):
            raise ValueError(f"{fqn} is not an exception class")
        return exc_class
    except (ImportError, AttributeError) as e:
        raise ValueError(f"Could not import exception class {fqn}: {str(e)}")


def should_panic_on_exception(exc: Exception, panic_classes: list[str]) -> bool:
    """
    Checks if the given exception matches any of the specified panic classes.

    Args:
        exc: The exception to check
        panic_classes: List of fully qualified exception class names

    Returns:
        True if the exception is an instance of one of the specified classes
    """
    # Handle comma-separated lists (from config files or env vars)
    expanded_classes = []
    for class_spec in panic_classes:
        expanded_classes.extend(name.strip() for name in class_spec.split(','))

    # Import all exception classes
    try:
        exception_types = [import_exception_class(name)
                           for name in expanded_classes if name]
    except ValueError as e:
        print(f"Warning: {str(e)}")
        return False

    # Check if exception matches any of the specified types
    return any(isinstance(exc, exc_type) for exc_type in exception_types)
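
A quick illustration of the matching helper's behavior, using only stdlib exception names; matching is isinstance-based, so subclasses of a listed exception also trigger a panic:

    from comfy.execution_ext import should_panic_on_exception

    assert should_panic_on_exception(MemoryError(), ["builtins.MemoryError"])
    # Comma-separated entries are expanded before matching.
    assert should_panic_on_exception(MemoryError(), ["builtins.ValueError,builtins.MemoryError"])
    assert not should_panic_on_exception(ValueError(), ["builtins.MemoryError"])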

tests/unit/test_panics.py (new file, 182 lines)

@@ -0,0 +1,182 @@
import asyncio
import threading
from unittest.mock import patch

import pytest
import torch

from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.cmd.execution import nodes
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.nodes.package_typing import CustomNode, ExportedNodes


@pytest.mark.asyncio
async def test_event_loop_callbacks():
    """Test to understand event loop callback behavior in pytest-asyncio"""
    callback_executed = False
    current_thread = threading.current_thread()
    current_loop = asyncio.get_running_loop()

    def callback(*args):
        nonlocal callback_executed
        print(f"Callback executing in thread: {threading.current_thread()}")
        print(f"Original thread was: {current_thread}")
        callback_executed = True

    print(f"Test running in thread: {current_thread}")
    print(f"Test using event loop: {current_loop}")

    # Try different ways of scheduling the callback
    current_loop.call_soon(callback)
    await asyncio.sleep(0)
    print(f"After sleep(0), callback_executed: {callback_executed}")

    if not callback_executed:
        current_loop.call_soon_threadsafe(callback)
        await asyncio.sleep(0)
        print(f"After threadsafe callback, callback_executed: {callback_executed}")

    if not callback_executed:
        # Try running callback in event loop directly
        await asyncio.get_event_loop().run_in_executor(None, callback)
        print(f"After run_in_executor, callback_executed: {callback_executed}")

    assert callback_executed, "Callback was never executed"


@pytest.mark.asyncio
async def test_separate_thread_callback():
    """Test callbacks scheduled from a separate thread"""
    callback_executed = False
    event = threading.Event()
    main_loop = asyncio.get_running_loop()

    def thread_func():
        print(f"Thread function running in: {threading.current_thread()}")
        main_loop.call_soon_threadsafe(lambda *_: event.set())

    print(f"Test running in thread: {threading.current_thread()}")
    print(f"Test using event loop: {main_loop}")

    # Start thread that will schedule callback
    thread = threading.Thread(target=thread_func)
    thread.start()

    # Wait for event with timeout
    try:
        await asyncio.wait_for(
            asyncio.get_event_loop().run_in_executor(None, event.wait),
            timeout=1.0
        )
        print("Event was set!")
    except asyncio.TimeoutError:
        print("Timed out waiting for event!")
        assert False, "Event was never set"

    thread.join()


# Custom test exception that we'll configure to panic on
class TestUnrecoverableError(Exception):
    pass


class TestExceptionNode(CustomNode):
    """Node that raises a specific exception for testing"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "should_raise": ("BOOL", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)  # Make it an output node by returning IMAGE
    FUNCTION = "raise_exception"
    CATEGORY = "Testing/Nodes"
    OUTPUT_NODE = True

    def raise_exception(self, should_raise=True) -> tuple[RGBImageBatch]:
        if should_raise:
            raise TestUnrecoverableError("Test exception from node")
        else:
            # Return a dummy image if not raising
            return (torch.zeros([1, 64, 64, 3]),)


# Export the node mappings
TEST_NODE_CLASS_MAPPINGS = {
    "TestExceptionNode": TestExceptionNode,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestExceptionNode": "Test Exception Node",
}


def create_failing_workflow():
    """Create a workflow that uses our test node to raise an exception"""
    return make_mutable({
        "1": {
            "class_type": "TestExceptionNode",
            "inputs": {
                "should_raise": True
            }
        }
    })


@pytest.mark.asyncio
async def test_panic_on_exception():
    # Set up the test nodes
    nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))

    # Create configuration with our test exception in panic_when
    config = Configuration()
    config.panic_when = [f"{__name__}.TestUnrecoverableError"]

    # Mock sys.exit to prevent actual exit and verify it's called
    with patch('sys.exit') as mock_exit:
        try:
            async with EmbeddedComfyClient(configuration=config) as client:
                # Queue our failing workflow
                await client.queue_prompt(create_failing_workflow())
        except TestUnrecoverableError:
            # We expect the exception to be raised here
            pass

        # Give the event loop a chance to process the exit callback
        await asyncio.sleep(0)

        # Verify sys.exit was called with code 1
        mock_exit.assert_called_once_with(1)


@pytest.mark.asyncio
async def test_no_panic_when_disabled():
    """Verify that the same exception doesn't trigger exit when not in panic_when"""
    # Set up the test nodes
    nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))

    # Create configuration without the exception in panic_when
    config = Configuration()

    # Mock sys.exit to verify it's not called
    with patch('sys.exit') as mock_exit:
        try:
            async with EmbeddedComfyClient(configuration=config) as client:
                # Queue our failing workflow
                await client.queue_prompt(create_failing_workflow())
        except TestUnrecoverableError:
            # We expect the exception to be raised here
            pass

        # Give the event loop a chance to process any callbacks
        await asyncio.sleep(0.1)

        # Verify sys.exit was not called
        mock_exit.assert_not_called()
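
These tests rely on the pytest-asyncio plugin; a standard invocation to run just this file would be:

    pytest tests/unit/test_panics.py -q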