Raise exception for bad get_node calls.

This commit is contained in:
Jacob Segal 2024-04-21 16:10:01 -07:00
parent 2dda3f2827
commit 06f3ce9200
2 changed files with 8 additions and 2 deletions

View File

@ -8,6 +8,9 @@ class DependencyCycleError(Exception):
class NodeInputError(Exception): class NodeInputError(Exception):
pass pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt: class DynamicPrompt:
def __init__(self, original_prompt): def __init__(self, original_prompt):
# The original prompt provided by the user # The original prompt provided by the user
@ -22,7 +25,10 @@ class DynamicPrompt:
return self.ephemeral_prompt[node_id] return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt: if node_id in self.original_prompt:
return self.original_prompt[node_id] return self.original_prompt[node_id]
return None raise NodeNotFoundError(f"Node {node_id} not found")
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info self.ephemeral_prompt[node_id] = node_info

View File

@ -349,7 +349,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
else: else:
# Check for conflicts # Check for conflicts
for node_id in new_graph.keys(): for node_id in new_graph.keys():
if dynprompt.get_node(node_id) is not None: if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items(): for node_id, node_info in new_graph.items():
new_node_ids.append(node_id) new_node_ids.append(node_id)