Improve recognition of node linkage

Honestly, I'm still a little concerned here. There's nothing stopping a
custom node from having a data type of ["str",int]. I've improved
recognition to at least prevent the detection of other types, but we
may still want a more systemic fix (e.g. wrapping literals within a
class when using them as inputs to nodes in subgraphs).
This commit is contained in:
Jacob Segal 2023-07-22 22:38:17 -07:00
parent 2520ade224
commit b66253b930
3 changed files with 24 additions and 12 deletions

View File

@ -1,6 +1,17 @@
import json
import random
def is_link(obj):
if not isinstance(obj, list):
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
def __init__(self, prefix = True):
@ -40,7 +51,7 @@ class GraphBuilder:
to_remove = []
for node in self.nodes.values():
for key, value in node.inputs.items():
if isinstance(value, list) and value[0] == node_id and value[1] == index:
if is_link(value) and value[0] == node_id and value[1] == index:
if new_value is None:
to_remove.append((node, key))
else:
@ -85,7 +96,7 @@ def add_graph_prefix(graph, outputs, prefix):
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
for input_name, input_value in node_info.get("inputs", {}).items():
if isinstance(input_value, list):
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
else:
new_node["inputs"][input_name] = input_value
@ -95,7 +106,7 @@ def add_graph_prefix(graph, outputs, prefix):
new_outputs = []
for n in range(len(outputs)):
output = outputs[n]
if isinstance(output, list): # This is a node link
if is_link(output):
new_outputs.append([prefix + output[0], output[1]])
else:
new_outputs.append(output)

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder
from comfy.graph_utils import GraphBuilder, is_link
NUM_FLOW_SOCKETS = 5
class WhileLoopOpen:
@ -63,7 +63,7 @@ class WhileLoopClose:
if "inputs" not in node_info:
return
for k, v in node_info["inputs"].items():
if isinstance(v, list) and len(v) == 2:
if is_link(v):
parent_id = v[0]
if parent_id not in upstream:
upstream[parent_id] = []
@ -107,7 +107,7 @@ class WhileLoopClose:
original_node = dynprompt.get_node(node_id)
node = graph.lookup_node(node_id)
for k, v in original_node["inputs"].items():
if isinstance(v, list) and len(v) == 2 and v[0] in contained:
if is_link(v) and v[0] in contained:
parent = graph.lookup_node(v[0])
node.set_input(k, parent.out(v[1]))
else:

View File

@ -14,6 +14,7 @@ import nodes
import comfy.model_management
import comfy.graph_utils
from comfy.graph_utils import is_link
class ExecutionResult(Enum):
SUCCESS = 0
@ -63,7 +64,7 @@ class ExecutionList:
if to_input not in inputs:
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
value = inputs[to_input]
if not isinstance(value, list):
if not is_link(value):
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
@ -88,7 +89,7 @@ class ExecutionList:
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if isinstance(value, list):
if is_link(value):
from_node_id, from_socket = value
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
if "lazy" not in input_info or not input_info["lazy"]:
@ -160,7 +161,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
if isinstance(input_data, list) and not input_info.get("rawLink", False):
if is_link(input_data) and not input_info.get("rawLink", False):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
@ -288,7 +289,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
else:
resolved_output = []
for r in result:
if isinstance(r, list) and len(r) == 2:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = outputs[source_node][source_output]
for o in node_output:
@ -348,7 +349,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
for node_id in new_output_ids:
execution_list.add_node(node_id)
for i in range(len(node_outputs)):
if isinstance(node_outputs[i], list) and len(node_outputs[i]) == 2:
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
execution_list.add_strong_link(from_node_id, from_socket, unique_id)
cached_outputs.append((True, node_outputs))
@ -430,7 +431,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
if is_link(input_data):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs: