diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 479ee8a53..1add78f38 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -3,6 +3,7 @@ from typing import Type, Literal
 import nodes
 import asyncio
 import inspect
+import traceback
 from comfy_execution.graph_utils import is_link, ExecutionBlocker
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
 
@@ -263,7 +264,25 @@ class ExecutionList(TopologicalSort):
             }
             return None, error_details, ex
 
-        self.staged_node_id = self.ux_friendly_pick_node(available)
+        try:
+            self.staged_node_id = self.ux_friendly_pick_node(available)
+        except Exception as ex:
+            # Backstop: the ordering heuristics in ux_friendly_pick_node are
+            # defensive, but should anything else there fail, surface it as an
+            # execution error instead of letting it kill the prompt worker
+            # thread. Blame an available node (best effort).
+            blamed_node = self.dynprompt.get_display_node_id(available[0])
+            exception_type = type(ex).__qualname__
+            if type(ex).__module__ != "builtins":
+                exception_type = type(ex).__module__ + "." + exception_type
+            error_details = {
+                "node_id": blamed_node,
+                "exception_message": str(ex),
+                "exception_type": exception_type,
+                "traceback": traceback.format_tb(ex.__traceback__),
+                "current_inputs": []
+            }
+            return None, error_details, ex
         return self.staged_node_id, None, None
 
     def ux_friendly_pick_node(self, node_list):
@@ -271,19 +290,28 @@ class ExecutionList(TopologicalSort):
         # Technically this has no effect on the overall length of execution, but it feels better as a user
         # for a PreviewImage to display a result as soon as it can
         # Some other heuristics could probably be used here to improve the UX further.
+        # These node-ordering heuristics only affect *order*, never correctness.
+        # A malformed node (e.g. a FUNCTION typo, or a node whose schema-derived
+        # attributes raise) must not crash scheduling: failing a heuristic just
+        # means "not prioritized". The node then proceeds to normal execution,
+        # where the real error is raised and reported against the correct node.
         def is_output(node_id):
             class_type = self.dynprompt.get_node(node_id)["class_type"]
             class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
-                return True
-            return False
+            try:
+                return hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True
+            except Exception:
+                return False
 
         # If an available node is async, do that first.
         # This will execute the asynchronous function earlier, reducing the overall time.
         def is_async(node_id):
             class_type = self.dynprompt.get_node(node_id)["class_type"]
             class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-            return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
+            try:
+                return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
+            except Exception:
+                return False
 
         for node_id in node_list:
             if is_output(node_id) or is_async(node_id):
diff --git a/tests-unit/execution_test/scheduler_malformed_node_test.py b/tests-unit/execution_test/scheduler_malformed_node_test.py
new file mode 100644
index 000000000..dd0c11b55
--- /dev/null
+++ b/tests-unit/execution_test/scheduler_malformed_node_test.py
@@ -0,0 +1,105 @@
+"""Regression tests for scheduler resilience to malformed nodes.
+
+A node whose FUNCTION points at a method that does not exist (e.g. a typo in a
+custom node) used to raise inside the scheduling heuristic, escaping the prompt
+worker's error handling and silently killing the worker thread. Scheduling must
+instead either proceed (so the error surfaces through normal execution) or report
+the failure as an execution error.
+"""
+import asyncio
+
+import nodes
+from comfy_execution.graph import DynamicPrompt, ExecutionList
+
+
+class _MalformedV1Node:
+    """A V1-style node whose FUNCTION names a method that does not exist."""
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {}}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "invert"  # the actual method below is misspelled
+    OUTPUT_NODE = True
+    CATEGORY = "Test"
+
+    def invvert(self):
+        return (None,)
+
+
+class _RaisingDescriptor:
+    """Descriptor that raises RuntimeError whenever the attribute is read."""
+    def __get__(self, obj, owner):
+        raise RuntimeError("schema error")
+
+
+class _SchemaRaisesNode:
+    """A node whose schema-derived attribute access raises, as a broken V3 node would."""
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {}}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "run"
+    OUTPUT_NODE = _RaisingDescriptor()
+    CATEGORY = "Test"
+
+    def run(self):
+        return (None,)
+
+
+class _FakeOutputCache:
+    """Output-cache stand-in that reports no cached nodes."""
+    def all_node_ids(self):
+        return set()
+
+    async def get(self, node_id):
+        return None
+
+
+def _make_execution_list(monkeypatch, class_type, class_def):
+    """Return an ExecutionList holding a single node "1" of class_type.
+
+    class_def is registered via monkeypatch so the process-wide
+    nodes.NODE_CLASS_MAPPINGS is restored when the test finishes.
+    """
+    monkeypatch.setitem(nodes.NODE_CLASS_MAPPINGS, class_type, class_def)
+    prompt = {"1": {"class_type": class_type, "inputs": {}}}
+    execution_list = ExecutionList(DynamicPrompt(prompt), _FakeOutputCache())
+    execution_list.add_node("1")
+    return execution_list
+
+
+def test_malformed_function_does_not_crash_scheduler(monkeypatch):
+    """A FUNCTION-typo node schedules without raising; the error surfaces later."""
+    execution_list = _make_execution_list(monkeypatch, "MalformedV1Node", _MalformedV1Node)
+    node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
+    assert ex is None
+    assert error is None
+    assert node_id == "1"
+
+
+def test_schema_attribute_error_does_not_crash_scheduler(monkeypatch):
+    """A node whose attribute access raises during heuristics still schedules."""
+    execution_list = _make_execution_list(monkeypatch, "SchemaRaisesNode", _SchemaRaisesNode)
+    node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
+    assert ex is None
+    assert error is None
+    assert node_id == "1"
+
+
+def test_pick_node_failure_is_reported_not_raised(monkeypatch):
+    """An unexpected scheduling error is returned as an error, not raised."""
+    execution_list = _make_execution_list(monkeypatch, "MalformedV1Node", _MalformedV1Node)
+
+    def raise_on_pick(_available):
+        raise RuntimeError("boom")
+
+    execution_list.ux_friendly_pick_node = raise_on_pick
+    node_id, error, ex = asyncio.run(execution_list.stage_node_execution())
+    assert node_id is None
+    assert isinstance(ex, RuntimeError)
+    assert error["node_id"] == "1"
+    assert error["exception_type"] == "RuntimeError"
+    assert error["exception_message"] == "boom"
+    assert error["traceback"]