mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
Users can now configure their workers to panic if they have out of memory exceptions, which occur due to complex failures in custom nodes
This commit is contained in:
parent
d04288ce8d
commit
684d180446
@ -190,6 +190,17 @@ def _create_parser() -> EnhancedConfigArgParser:
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--panic-when',
|
||||||
|
action='append',
|
||||||
|
help="""
|
||||||
|
List of fully qualified exception class names to panic (os.exit(1)) when a workflow raises it.
|
||||||
|
Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a
|
||||||
|
comma-separated list.""",
|
||||||
|
type=str,
|
||||||
|
default=[]
|
||||||
|
)
|
||||||
|
|
||||||
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
def is_valid_directory(path: Optional[str]) -> Optional[str]:
|
||||||
"""Validate if the given path is a directory."""
|
"""Validate if the given path is a directory."""
|
||||||
if path is None:
|
if path is None:
|
||||||
|
|||||||
@ -122,6 +122,7 @@ class Configuration(dict):
|
|||||||
anthropic_api_key (str): Configures the Anthropic API key for its nodes related to Claude functionality. Visit https://console.anthropic.com/settings/keys to create this key.
|
anthropic_api_key (str): Configures the Anthropic API key for its nodes related to Claude functionality. Visit https://console.anthropic.com/settings/keys to create this key.
|
||||||
user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
|
user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
|
||||||
log_stdout (bool): Send normal process output to stdout instead of stderr (default)
|
log_stdout (bool): Send normal process output to stdout instead of stderr (default)
|
||||||
|
panic_when (list[str]): List of fully qualified exception class names to panic (os.exit(1)) when a workflow raises it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -220,6 +221,7 @@ class Configuration(dict):
|
|||||||
self.ideogram_api_key: Optional[str] = None
|
self.ideogram_api_key: Optional[str] = None
|
||||||
self.anthropic_api_key: Optional[str] = None
|
self.anthropic_api_key: Optional[str] = None
|
||||||
self.user_directory: Optional[str] = None
|
self.user_directory: Optional[str] = None
|
||||||
|
self.panic_when: list[str] = []
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
if item not in self:
|
if item not in self:
|
||||||
|
|||||||
@ -41,12 +41,14 @@ def _execute_prompt(
|
|||||||
span_context: Context = propagate.extract(span_context)
|
span_context: Context = propagate.extract(span_context)
|
||||||
token = attach(span_context)
|
token = attach(span_context)
|
||||||
try:
|
try:
|
||||||
return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)
|
# there is never an event loop running on a thread or process pool thread here
|
||||||
|
# this also guarantees nodes will be able to successfully call await
|
||||||
|
return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration))
|
||||||
finally:
|
finally:
|
||||||
detach(token)
|
detach(token)
|
||||||
|
|
||||||
|
|
||||||
def __execute_prompt(
|
async def __execute_prompt(
|
||||||
prompt: dict,
|
prompt: dict,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
client_id: str,
|
client_id: str,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import heapq
|
import heapq
|
||||||
import inspect
|
import inspect
|
||||||
@ -20,6 +21,7 @@ from opentelemetry.trace import get_current_span, StatusCode, Status
|
|||||||
from .main_pre import tracer
|
from .main_pre import tracer
|
||||||
from .. import interruption
|
from .. import interruption
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
|
from ..cli_args import args
|
||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||||
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
||||||
@ -28,6 +30,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
|
|||||||
from ..component_model.files import canonicalize_path
|
from ..component_model.files import canonicalize_path
|
||||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||||
from ..execution_context import context_execute_node, context_execute_prompt
|
from ..execution_context import context_execute_node, context_execute_prompt
|
||||||
|
from ..execution_ext import should_panic_on_exception
|
||||||
from ..nodes.package import import_all_nodes_in_workspace
|
from ..nodes.package import import_all_nodes_in_workspace
|
||||||
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||||
|
|
||||||
@ -110,6 +113,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
|
|
||||||
def mark_missing():
|
def mark_missing():
|
||||||
missing_keys[x] = True
|
missing_keys[x] = True
|
||||||
input_data_all[x] = (None,)
|
input_data_all[x] = (None,)
|
||||||
@ -480,6 +484,14 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
|
|||||||
logging.error("Got an OOM, unloading all loaded models.")
|
logging.error("Got an OOM, unloading all loaded models.")
|
||||||
model_management.unload_all_models()
|
model_management.unload_all_models()
|
||||||
|
|
||||||
|
if should_panic_on_exception(ex, args.panic_when):
|
||||||
|
logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
|
||||||
|
|
||||||
|
def sys_exit(*args):
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
asyncio.get_event_loop().call_soon_threadsafe(sys_exit, ())
|
||||||
|
|
||||||
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
|
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
|
||||||
|
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
|
import asyncio
|
||||||
import concurrent
|
import concurrent
|
||||||
import contextvars
|
import contextvars
|
||||||
|
import threading
|
||||||
import typing
|
import typing
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
56
comfy/execution_ext.py
Normal file
56
comfy/execution_ext.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import importlib
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def import_exception_class(fqn: str):
    """
    Imports an exception class from its fully qualified name.

    Example: 'torch.cuda.OutOfMemoryError' -> torch.cuda.OutOfMemoryError

    Args:
        fqn: Fully qualified, absolute name of the exception class, in the
            form 'package.module.ClassName'

    Returns:
        The exception class

    Raises:
        ValueError: If the name is malformed (not a string, no module part,
            no class part, or a relative name), the class cannot be imported,
            or it is not a subclass of Exception
    """
    if not isinstance(fqn, str):
        raise ValueError(f"Could not import exception class {fqn!r}: expected a string")

    module_path, _, class_name = fqn.rpartition('.')

    # Reject malformed names up front. A leading dot is a relative import, for
    # which importlib.import_module raises TypeError when no package is given;
    # callers only expect ValueError from this function.
    if not module_path or not class_name or module_path.startswith('.'):
        raise ValueError(
            f"Could not import exception class {fqn}: "
            "expected an absolute name like 'package.module.ClassName'"
        )

    try:
        module = importlib.import_module(module_path)
        exc_class = getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        # Keep the original failure attached as the cause for debugging.
        raise ValueError(f"Could not import exception class {fqn}: {str(e)}") from e

    if not isinstance(exc_class, type) or not issubclass(exc_class, Exception):
        raise ValueError(f"{fqn} is not an exception class")

    return exc_class
|
||||||
|
|
||||||
|
|
||||||
|
def should_panic_on_exception(exc: Exception, panic_classes: list[str]) -> bool:
    """
    Checks if the given exception matches any of the specified panic classes.

    Each entry of ``panic_classes`` may itself be a comma-separated list of
    names. Entries that cannot be resolved to an exception class are logged
    and skipped; they do not prevent the remaining entries from being checked.

    Args:
        exc: The exception to check
        panic_classes: List of fully qualified exception class names

    Returns:
        True if the exception is an instance of one of the specified classes
    """
    if not panic_classes:
        return False

    # Handle comma-separated lists (from config files or env vars)
    expanded_classes = [
        name.strip()
        for class_spec in panic_classes
        for name in class_spec.split(',')
    ]

    for name in expanded_classes:
        if not name:
            continue
        try:
            exc_type = import_exception_class(name)
        except ValueError as e:
            # One bad entry (typo, optional package not installed) must not
            # disable panicking for the other, valid entries.
            logging.warning("Ignoring panic class %r: %s", name, e)
            continue
        if isinstance(exc, exc_type):
            return True

    return False
|
||||||
182
tests/unit/test_panics.py
Normal file
182
tests/unit/test_panics.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args_types import Configuration
|
||||||
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||||
|
from comfy.cmd.execution import nodes
|
||||||
|
from comfy.component_model.make_mutable import make_mutable
|
||||||
|
from comfy.component_model.tensor_types import RGBImageBatch
|
||||||
|
from comfy.nodes.package_typing import CustomNode, ExportedNodes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_event_loop_callbacks():
    """Probe how callbacks scheduled on the running loop behave under pytest-asyncio.

    The callback is scheduled with ``call_soon`` first; ``call_soon_threadsafe``
    and ``run_in_executor`` are tried only if the previous attempt did not run
    it. The test passes as soon as any of them executed the callback.
    """
    loop = asyncio.get_running_loop()
    test_thread = threading.current_thread()
    callback_executed = False

    def callback(*args):
        nonlocal callback_executed
        print(f"Callback executing in thread: {threading.current_thread()}")
        print(f"Original thread was: {test_thread}")
        callback_executed = True

    print(f"Test running in thread: {test_thread}")
    print(f"Test using event loop: {loop}")

    # First attempt: plain call_soon, then yield once so the loop can run it
    loop.call_soon(callback)
    await asyncio.sleep(0)
    print(f"After sleep(0), callback_executed: {callback_executed}")

    # Second attempt: the thread-safe scheduling variant
    if not callback_executed:
        loop.call_soon_threadsafe(callback)
        await asyncio.sleep(0)
        print(f"After threadsafe callback, callback_executed: {callback_executed}")

    # Last resort: run the callback on the loop's default executor
    if not callback_executed:
        await loop.run_in_executor(None, callback)
        print(f"After run_in_executor, callback_executed: {callback_executed}")

    assert callback_executed, "Callback was never executed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_separate_thread_callback():
    """Test callbacks scheduled from a separate thread.

    A helper thread schedules ``event.set`` on the test's event loop with
    ``call_soon_threadsafe``; the test then waits (up to one second) for the
    event to be set, which proves the loop ran the callback.
    """
    event = threading.Event()
    main_loop = asyncio.get_running_loop()

    def thread_func():
        print(f"Thread function running in: {threading.current_thread()}")
        main_loop.call_soon_threadsafe(event.set)

    print(f"Test running in thread: {threading.current_thread()}")
    print(f"Test using event loop: {main_loop}")

    # Start thread that will schedule callback
    thread = threading.Thread(target=thread_func)
    thread.start()

    try:
        # The timeout is given to Event.wait itself so the executor thread
        # always returns; an untimed wait would stay blocked forever if the
        # callback never ran.
        event_was_set = await main_loop.run_in_executor(None, event.wait, 1.0)
    finally:
        # Always reap the helper thread, including on failure or cancellation.
        thread.join()

    if event_was_set:
        print("Event was set!")
    else:
        print("Timed out waiting for event!")

    assert event_was_set, "Event was never set"
|
||||||
|
|
||||||
|
|
||||||
|
# Custom test exception that we'll configure to panic on
|
||||||
|
class TestUnrecoverableError(Exception):
    """Exception raised by the test node; tests list it in ``panic_when``."""

    # The "Test" prefix matches pytest's default class-collection pattern and
    # exception classes have an __init__, which makes pytest emit a collection
    # warning. Opt out explicitly: this is a helper, not a test case.
    __test__ = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptionNode(CustomNode):
    """Node that raises a specific exception for testing"""

    # Helper class, not a test case: keep pytest from trying to collect it
    # because of its "Test" name prefix.
    __test__ = False

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required input, ``should_raise``, defaulting to True."""
        return {
            "required": {
                "should_raise": ("BOOL", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "raise_exception"
    CATEGORY = "Testing/Nodes"
    # NOTE(review): presumably this flag, not the IMAGE return type, is what
    # makes the executor treat the node as an output node; confirm.
    OUTPUT_NODE = True

    def raise_exception(self, should_raise=True) -> tuple[RGBImageBatch]:
        """Raise TestUnrecoverableError, or return a dummy image when ``should_raise`` is falsy."""
        if should_raise:
            raise TestUnrecoverableError("Test exception from node")
        else:
            # Return a dummy image if not raising
            return (torch.zeros([1, 64, 64, 3]),)
|
||||||
|
|
||||||
|
|
||||||
|
# Export the node mappings
|
||||||
|
# Node id -> node class; passed to ExportedNodes as NODE_CLASS_MAPPINGS by the tests.
TEST_NODE_CLASS_MAPPINGS: dict[str, type] = {"TestExceptionNode": TestExceptionNode}

# Node id -> display label; passed to ExportedNodes as NODE_DISPLAY_NAME_MAPPINGS.
TEST_NODE_DISPLAY_NAME_MAPPINGS: dict[str, str] = {"TestExceptionNode": "Test Exception Node"}
|
||||||
|
|
||||||
|
|
||||||
|
def create_failing_workflow():
    """Build a one-node workflow whose only node is the test node, set to raise."""
    failing_node = {
        "class_type": "TestExceptionNode",
        "inputs": {"should_raise": True},
    }
    # The node is registered under id "1"; make_mutable wraps the plain dict.
    return make_mutable({"1": failing_node})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_panic_on_exception():
    """Verify that an exception listed in panic_when leads to sys.exit(1).

    Registers the test node, lists TestUnrecoverableError in ``panic_when``,
    queues the failing workflow and checks that the patched ``sys.exit`` was
    called exactly once with 1.
    """
    # Register the test node with the node registry
    exported_nodes = ExportedNodes(
        NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
        NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS,
    )
    nodes.update(exported_nodes)

    # Configuration that names our test exception as unrecoverable
    panic_config = Configuration()
    panic_config.panic_when = [f"{__name__}.TestUnrecoverableError"]

    # sys.exit is patched so the test process survives and the call can be inspected
    with patch('sys.exit') as mock_exit:
        try:
            async with EmbeddedComfyClient(configuration=panic_config) as client:
                failing_workflow = create_failing_workflow()
                await client.queue_prompt(failing_workflow)
        except TestUnrecoverableError:
            # The node's exception propagating to the caller is expected
            pass

        # Yield once so the event loop can process the exit callback
        await asyncio.sleep(0)

        mock_exit.assert_called_once_with(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_no_panic_when_disabled():
    """Verify that the same exception doesn't trigger exit when not in panic_when"""
    # Register the test node with the node registry
    exported_nodes = ExportedNodes(
        NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
        NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS,
    )
    nodes.update(exported_nodes)

    # Default configuration: panic_when is left untouched
    default_config = Configuration()

    # sys.exit is patched so any unexpected call can be detected
    with patch('sys.exit') as mock_exit:
        try:
            async with EmbeddedComfyClient(configuration=default_config) as client:
                failing_workflow = create_failing_workflow()
                await client.queue_prompt(failing_workflow)
        except TestUnrecoverableError:
            # The node's exception propagating to the caller is expected
            pass

        # Give the event loop time to process any callbacks
        await asyncio.sleep(0.1)

        mock_exit.assert_not_called()
|
||||||
Loading…
Reference in New Issue
Block a user