From b66253b930b86e197238c7567c6920fdb8563e61 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 22 Jul 2023 22:38:17 -0700 Subject: [PATCH] Improve recognition of node linkage Honestly, I'm still a little concerned here. There's nothing stopping a custom node from having a data type of ["str",int]. I've improved recognition to at least prevent the detection of other types, but we may still want a more systemic fix (e.g. wrapping literals within a class when using them as inputs to nodes in subgraphs). --- comfy/graph_utils.py | 17 ++++++++++++++--- .../flow_control.py | 6 +++--- execution.py | 13 +++++++------ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py index d1a4e7187..9103a9622 100644 --- a/comfy/graph_utils.py +++ b/comfy/graph_utils.py @@ -1,6 +1,17 @@ import json import random +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + # The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end class GraphBuilder: def __init__(self, prefix = True): @@ -40,7 +51,7 @@ class GraphBuilder: to_remove = [] for node in self.nodes.values(): for key, value in node.inputs.items(): - if isinstance(value, list) and value[0] == node_id and value[1] == index: + if is_link(value) and value[0] == node_id and value[1] == index: if new_value is None: to_remove.append((node, key)) else: @@ -85,7 +96,7 @@ def add_graph_prefix(graph, outputs, prefix): new_node_id = prefix + node_id new_node = { "class_type": node_info["class_type"], "inputs": {} } for input_name, input_value in node_info.get("inputs", {}).items(): - if isinstance(input_value, list): + if is_link(input_value): new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] else: new_node["inputs"][input_name] = input_value @@ -95,7 +106,7 @@ def add_graph_prefix(graph, outputs, prefix): new_outputs = [] for n in range(len(outputs)): output = outputs[n] - if isinstance(output, list): # This is a node link + if is_link(output): new_outputs.append([prefix + output[0], output[1]]) else: new_outputs.append(output) diff --git a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py index 7ec8d12db..626a82f39 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py +++ b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py @@ -1,4 +1,4 @@ -from comfy.graph_utils import GraphBuilder +from comfy.graph_utils import GraphBuilder, is_link NUM_FLOW_SOCKETS = 5 class WhileLoopOpen: @@ -63,7 +63,7 @@ class WhileLoopClose: if "inputs" not in node_info: return for k, v in node_info["inputs"].items(): - if isinstance(v, list) and len(v) == 2: + if is_link(v): parent_id = v[0] if parent_id not in upstream: upstream[parent_id] = [] @@ -107,7 +107,7 @@ class WhileLoopClose: original_node = dynprompt.get_node(node_id) node = graph.lookup_node(node_id) for k, v in original_node["inputs"].items(): - if isinstance(v, list) and len(v) == 2 and v[0] in contained: + if is_link(v) and v[0] in contained: parent = graph.lookup_node(v[0]) node.set_input(k, parent.out(v[1])) else: diff --git a/execution.py b/execution.py index 9a391d390..5c31b1a7d 100644 --- a/execution.py +++ b/execution.py @@ -14,6 +14,7 @@ import nodes import comfy.model_management import comfy.graph_utils +from comfy.graph_utils import is_link class ExecutionResult(Enum): SUCCESS = 0 @@ -63,7 +64,7 @@ class ExecutionList: if to_input not in inputs: raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input)) value = inputs[to_input] - if not isinstance(value, list): + if not is_link(value): raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) from_node_id, from_socket = value self.add_strong_link(from_node_id, from_socket, to_node_id) @@ -88,7 +89,7 @@ class ExecutionList: inputs = self.dynprompt.get_node(unique_id)["inputs"] for input_name in inputs: value = inputs[input_name] - if isinstance(value, list): + if is_link(value): from_node_id, from_socket = value input_type, input_category, input_info = self.get_input_info(unique_id, input_name) if "lazy" not in input_info or not input_info["lazy"]: @@ -160,7 +161,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x) - if isinstance(input_data, list) and not input_info.get("rawLink", False): + if is_link(input_data) and not input_info.get("rawLink", False): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: @@ -288,7 +289,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, else: resolved_output = [] for r in result: - if isinstance(r, list) and len(r) == 2: + if is_link(r): source_node, source_output = r[0], r[1] node_output = outputs[source_node][source_output] for o in node_output: @@ -348,7 +349,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, for node_id in new_output_ids: execution_list.add_node(node_id) for i in range(len(node_outputs)): - if isinstance(node_outputs[i], list) and len(node_outputs[i]) == 2: + if is_link(node_outputs[i]): from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] execution_list.add_strong_link(from_node_id, from_socket, unique_id) cached_outputs.append((True, node_outputs)) @@ -430,7 +431,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + if is_link(input_data): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id in outputs: