From dd3bafb40b37e377ffa630edee238c19c03a6d44 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 20 Apr 2024 22:40:38 -0700 Subject: [PATCH] Display an error for dependency cycles Previously, dependency cycles that were created during node expansion would cause the application to quit (due to an uncaught exception). Now, we'll throw a proper error to the UI. We also make an attempt to 'blame' the most relevant node in the UI. --- comfy/graph.py | 42 +++++++++++++++++-- execution.py | 6 ++- tests/inference/test_execution.py | 30 +++++++++++++ .../testing-pack/specific_tests.py | 32 ++++++++++++++ 4 files changed, 106 insertions(+), 4 deletions(-) diff --git a/comfy/graph.py b/comfy/graph.py index bf4be4ae1..b20c7bf38 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -144,10 +144,27 @@ class ExecutionList(TopologicalSort): def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): - return None + return None, None, None available = self.get_ready_nodes() if len(available) == 0: - raise DependencyCycleError("Dependency cycle detected") + cycled_nodes = self.get_nodes_in_cycle() + # Because cycles composed entirely of static nodes are caught during initial validation, + # we will 'blame' the first node in the cycle that is not a static node. + blamed_node = cycled_nodes[0] + for node_id in cycled_nodes: + display_node_id = self.dynprompt.get_display_node_id(node_id) + if display_node_id != node_id: + blamed_node = display_node_id + break + ex = DependencyCycleError("Dependency cycle detected") + error_details = { + "node_id": blamed_node, + "exception_message": str(ex), + "exception_type": "graph.DependencyCycleError", + "traceback": [], + "current_inputs": [] + } + return None, error_details, ex next_node = available[0] # If an output node is available, do that first. # Technically this has no effect on the overall length of execution, but it feels better as a user @@ -160,7 +177,7 @@ class ExecutionList(TopologicalSort): next_node = node_id break self.staged_node_id = next_node - return self.staged_node_id + return self.staged_node_id, None, None def unstage_node_execution(self): assert self.staged_node_id is not None @@ -171,6 +188,25 @@ class ExecutionList(TopologicalSort): self.pop_node(node_id) self.staged_node_id = None + def get_nodes_in_cycle(self): + # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. + # We're skipping some of the performance optimizations from the original TopologicalSort to keep + # the code simple (and because having a cycle in the first place is a catastrophic error) + blocked_by = { node_id: {} for node_id in self.pendingNodes } + for from_node_id in self.blocking: + for to_node_id in self.blocking[from_node_id]: + if True in self.blocking[from_node_id][to_node_id].values(): + blocked_by[to_node_id][from_node_id] = True + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + while len(to_remove) > 0: + for node_id in to_remove: + for to_node_id in blocked_by: + if node_id in blocked_by[to_node_id]: + del blocked_by[to_node_id][node_id] + del blocked_by[node_id] + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + return list(blocked_by.keys()) + # Return this from a node and any users will be blocked with the given error message. class ExecutionBlocker: def __init__(self, message): diff --git a/execution.py b/execution.py index ecd0850af..ee4637a38 100644 --- a/execution.py +++ b/execution.py @@ -473,7 +473,11 @@ class PromptExecutor: execution_list.add_node(node_id) while not execution_list.is_empty(): - node_id = execution_list.stage_node_execution() + node_id, error, ex = execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 6a4fa3dd1..0ae70b8ca 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -234,6 +234,7 @@ class TestExecution: try: client.run(g) + assert False, "Should have raised an error" except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" @@ -298,6 +299,35 @@ class TestExecution: else: client.run(g) + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix2.out(0)) + + # When the cycle exists on initial submission, it should raise a validation error + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0)) + g.node("SaveImage", images=generator.out(0)) + + # When the cycle is in a graph that is generated dynamically, it should raise a runtime error + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 8c103c18a..56b8f70b2 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,5 +1,6 @@ import torch from .tools import VariantSupport +from comfy.graph_utils import GraphBuilder class TestLazyMixImages: @classmethod @@ -191,6 +192,35 @@ class TestCustomValidation3: result = input1 * input2 return (result,) +class TestDynamicDependencyCycle: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + "input2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "dynamic_dependency_cycle" + + CATEGORY = "Testing/Nodes" + + def dynamic_dependency_cycle(self, input1, input2): + g = GraphBuilder() + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0)) + mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0)) + + # Create the cyle + mix1.set_input("image2", mix2.out(0)) + + return { + "result": (mix2.out(0),), + "expand": g.finalize(), + } + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -198,6 +228,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestCustomValidation1": TestCustomValidation1, "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, + "TestDynamicDependencyCycle": TestDynamicDependencyCycle, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -207,4 +238,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestCustomValidation1": "Custom Validation 1", "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", + "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", }