diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 424d97d2b..1add78f38 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -267,10 +267,10 @@ class ExecutionList(TopologicalSort): try: self.staged_node_id = self.ux_friendly_pick_node(available) except Exception as ex: - # Picking a node is a scheduling heuristic that inspects node - # definitions; a malformed custom node must not crash the prompt - # worker thread silently. Blame an available node and surface the - # error to the frontend like any other execution error. + # Backstop: the ordering heuristics in ux_friendly_pick_node are + # defensive, but should anything else there fail, surface it as an + # execution error instead of letting it kill the prompt worker + # thread. Blame an available node (best effort). blamed_node = self.dynprompt.get_display_node_id(available[0]) exception_type = type(ex).__qualname__ if type(ex).__module__ != "builtins": @@ -290,29 +290,28 @@ class ExecutionList(TopologicalSort): # Technically this has no effect on the overall length of execution, but it feels better as a user # for a PreviewImage to display a result as soon as it can # Some other heuristics could probably be used here to improve the UX further. + # These node-ordering heuristics only affect *order*, never correctness. + # A malformed node (e.g. a FUNCTION typo, or a node whose schema-derived + # attributes raise) must not crash scheduling: failing a heuristic just + # means "not prioritized". The node then proceeds to normal execution, + # where the real error is raised and reported against the correct node. def is_output(node_id): class_type = self.dynprompt.get_node(node_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: - return True - return False + try: + return hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True + except Exception: + return False # If an available node is async, do that first. # This will execute the asynchronous function earlier, reducing the overall time. def is_async(node_id): class_type = self.dynprompt.get_node(node_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - # A malformed node (e.g. FUNCTION pointing at a method that does not - # exist because of a typo) must not crash scheduling here. Treat it as - # non-async so it proceeds to normal execution, where the missing - # method raises an error that is caught and reported to the frontend. - function_name = getattr(class_def, "FUNCTION", None) - if function_name is None: + try: + return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) + except Exception: return False - func = getattr(class_def, function_name, None) - if func is None: - return False - return inspect.iscoroutinefunction(func) for node_id in node_list: if is_output(node_id) or is_async(node_id): diff --git a/tests-unit/execution_test/scheduler_malformed_node_test.py b/tests-unit/execution_test/scheduler_malformed_node_test.py index 1397f5488..dd0c11b55 100644 --- a/tests-unit/execution_test/scheduler_malformed_node_test.py +++ b/tests-unit/execution_test/scheduler_malformed_node_test.py @@ -26,6 +26,26 @@ class _MalformedV1Node: return (None,) +class _RaisingDescriptor: + def __get__(self, obj, owner): + raise RuntimeError("schema error") + + +class _SchemaRaisesNode: + """A node whose schema-derived attribute access raises, as a broken V3 node would.""" + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "run" + OUTPUT_NODE = _RaisingDescriptor() + CATEGORY = "Test" + + def run(self): + return (None,) + + class _FakeOutputCache: def all_node_ids(self): return set() @@ -51,6 +71,15 @@ def test_malformed_function_does_not_crash_scheduler(): assert node_id == "1" +def test_schema_attribute_error_does_not_crash_scheduler(): + """A node whose attribute access raises during heuristics still schedules.""" + execution_list = _make_execution_list("SchemaRaisesNode", _SchemaRaisesNode) + node_id, error, ex = asyncio.run(execution_list.stage_node_execution()) + assert ex is None + assert error is None + assert node_id == "1" + + def test_pick_node_failure_is_reported_not_raised(): """An unexpected scheduling error is returned as an error, not raised.""" execution_list = _make_execution_list("MalformedV1Node", _MalformedV1Node)