Display an error for dependency cycles

Previously, dependency cycles that were created during node expansion
would cause the application to quit (due to an uncaught exception). Now,
we'll throw a proper error to the UI. We also make an attempt to 'blame'
the most relevant node in the UI.
This commit is contained in:
Jacob Segal 2024-04-20 22:40:38 -07:00
parent 5dc13651b0
commit dd3bafb40b
4 changed files with 106 additions and 4 deletions

View File

@ -144,10 +144,27 @@ class ExecutionList(TopologicalSort):
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None
return None, None, None
available = self.get_ready_nodes()
if len(available) == 0:
raise DependencyCycleError("Dependency cycle detected")
cycled_nodes = self.get_nodes_in_cycle()
# Because cycles composed entirely of static nodes are caught during initial validation,
# we will 'blame' the first node in the cycle that is not a static node.
blamed_node = cycled_nodes[0]
for node_id in cycled_nodes:
display_node_id = self.dynprompt.get_display_node_id(node_id)
if display_node_id != node_id:
blamed_node = display_node_id
break
ex = DependencyCycleError("Dependency cycle detected")
error_details = {
"node_id": blamed_node,
"exception_message": str(ex),
"exception_type": "graph.DependencyCycleError",
"traceback": [],
"current_inputs": []
}
return None, error_details, ex
next_node = available[0]
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
@ -160,7 +177,7 @@ class ExecutionList(TopologicalSort):
next_node = node_id
break
self.staged_node_id = next_node
return self.staged_node_id
return self.staged_node_id, None, None
def unstage_node_execution(self):
assert self.staged_node_id is not None
@ -171,6 +188,25 @@ class ExecutionList(TopologicalSort):
self.pop_node(node_id)
self.staged_node_id = None
def get_nodes_in_cycle(self):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes }
for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values():
blocked_by[to_node_id][from_node_id] = True
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
while len(to_remove) > 0:
for node_id in to_remove:
for to_node_id in blocked_by:
if node_id in blocked_by[to_node_id]:
del blocked_by[to_node_id][node_id]
del blocked_by[node_id]
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
# Return this from a node and any users will be blocked with the given error message.
class ExecutionBlocker:
def __init__(self, message):

View File

@ -473,7 +473,11 @@ class PromptExecutor:
execution_list.add_node(node_id)
while not execution_list.is_empty():
node_id = execution_list.stage_node_execution()
node_id, error, ex = execution_list.stage_node_execution()
if error is not None:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)

View File

@ -234,6 +234,7 @@ class TestExecution:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
@ -298,6 +299,35 @@ class TestExecution:
else:
client.run(g)
def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0))
lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0))
g.node("SaveImage", images=lazy_mix2.out(0))
# When the cycle exists on initial submission, it should raise a validation error
with pytest.raises(urllib.error.HTTPError):
client.run(g)
def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0))
g.node("SaveImage", images=generator.out(0))
# When the cycle is in a graph that is generated dynamically, it should raise a runtime error
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug

View File

@ -1,5 +1,6 @@
import torch
from .tools import VariantSupport
from comfy.graph_utils import GraphBuilder
class TestLazyMixImages:
@classmethod
@ -191,6 +192,35 @@ class TestCustomValidation3:
result = input1 * input2
return (result,)
class TestDynamicDependencyCycle:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE",),
"input2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "dynamic_dependency_cycle"
CATEGORY = "Testing/Nodes"
def dynamic_dependency_cycle(self, input1, input2):
g = GraphBuilder()
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))
# Create the cyle
mix1.set_input("image2", mix2.out(0))
return {
"result": (mix2.out(0),),
"expand": g.finalize(),
}
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@ -198,6 +228,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@ -207,4 +238,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
}