mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`, that argument will be a dictionary giving the socket type of every incoming connection. If that argument exists, normal socket type validation will not occur. This removes the last hurdle to implementing variant types entirely within custom nodes, so I've removed that command-line option. I've added appropriate unit tests for these changes.
174 lines
5.8 KiB
Python
174 lines
5.8 KiB
Python
from comfy.graph_utils import GraphBuilder, is_link
|
|
from comfy.graph import ExecutionBlocker
|
|
from .tools import VariantSupport
|
|
|
|
# Number of pass-through value sockets on the while-loop node pair: each node
# exposes optional inputs initial_value0..initial_value{N-1}, and the nodes
# return value0..value{N-1}.
NUM_FLOW_SOCKETS = 5
|
|
@VariantSupport()
class TestWhileLoopOpen:
    """Entry node of the test while-loop.

    Forwards the optional ``initial_valueN`` inputs unchanged to its
    ``valueN`` outputs. These are preceded by a placeholder ``FLOW_CONTROL``
    output, which ``TestWhileLoopClose`` receives through its
    ``flow_control`` input.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the required ``condition`` plus the optional pass-through sockets."""
        optional_sockets = {
            "initial_value%d" % idx: ("*",) for idx in range(NUM_FLOW_SOCKETS)
        }
        return {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": optional_sockets,
        }

    RETURN_TYPES = ("FLOW_CONTROL",) + ("*",) * NUM_FLOW_SOCKETS
    RETURN_NAMES = ("FLOW_CONTROL",) + tuple(
        "value%d" % idx for idx in range(NUM_FLOW_SOCKETS)
    )
    FUNCTION = "while_loop_open"

    CATEGORY = "Testing/Flow"

    def while_loop_open(self, condition, **kwargs):
        """Return the ``"stub"`` flow-control placeholder followed by each initial value.

        ``condition`` is accepted but not read here. Any ``initial_valueN``
        that was not supplied is returned as ``None``.
        """
        passthrough = tuple(
            kwargs.get("initial_value%d" % idx, None)
            for idx in range(NUM_FLOW_SOCKETS)
        )
        return ("stub",) + passthrough
|
|
|
|
@VariantSupport()
class TestWhileLoopClose:
    """Exit node of the test while-loop.

    While ``condition`` is truthy, ``while_loop_close`` does the following:

    - clones every node between the linked open node and itself (inclusive)
      into a new graph;
    - feeds this iteration's values into the cloned open node;
    - returns the cloned close node's outputs as ``"result"``, together with
      the new graph as ``"expand"``.

    Once ``condition`` is falsy, the current values are returned directly.
    """

    # Id given to the clone of this close node inside the expanded graph.
    # We use the default prefix for the other clones, but a fixed name here
    # keeps node names from growing exponentially in size across iterations.
    _RECURSE_ID = "Recurse"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the flow-control link, loop condition, pass-through and hidden inputs."""
        inputs = {
            "required": {
                "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {"forceInput": True}),
            },
            "optional": {
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
            }
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"]["initial_value%d" % i] = ("*",)
        return inputs

    RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
    RETURN_NAMES = tuple(["value%d" % i for i in range(NUM_FLOW_SOCKETS)])
    FUNCTION = "while_loop_close"

    CATEGORY = "Testing/Flow"

    def explore_dependencies(self, node_id, dynprompt, upstream):
        """Record, for every ancestor of ``node_id``, the nodes that consume it.

        Walks input links backwards from ``node_id`` and fills ``upstream``
        (mutated in place) as ``{parent_id: [child_id, ...]}``. Each parent is
        recursed into only the first time it is encountered.
        """
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for value in node_info["inputs"].values():
            if not is_link(value):
                continue
            parent_id = value[0]
            if parent_id not in upstream:
                upstream[parent_id] = []
                self.explore_dependencies(parent_id, dynprompt, upstream)
            upstream[parent_id].append(node_id)

    def collect_contained(self, node_id, upstream, contained):
        """Add every recorded descendant of ``node_id`` to ``contained``.

        Follows the child lists built by ``explore_dependencies``.
        ``contained`` is a dict used as an insertion-ordered set and is
        mutated in place.
        """
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                self.collect_contained(child_id, upstream, contained)

    def _graph_id(self, node_id, unique_id):
        """Return the id used for ``node_id``'s clone in the expanded graph."""
        return self._RECURSE_ID if node_id == unique_id else node_id

    def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
        """Finish the loop, or expand into another iteration of it.

        Args:
            flow_control: Raw link to the open node's ``FLOW_CONTROL`` output.
                Element 0 is the open node's id.
            condition: Keep looping while truthy.
            dynprompt: Hidden ``DYNPROMPT`` input. Must not be ``None``.
            unique_id: Hidden ``UNIQUE_ID`` of this node.
            **kwargs: Optional ``initial_valueN`` values for this iteration.

        Returns:
            When the loop ends, a tuple of ``NUM_FLOW_SOCKETS`` values (``None``
            for any that were not supplied). Otherwise, a dict with
            ``"result"`` (the outputs of the cloned close node) and
            ``"expand"`` (the finalized cloned graph).
        """
        assert dynprompt is not None
        values = [kwargs.get("initial_value%d" % i, None) for i in range(NUM_FLOW_SOCKETS)]
        if not condition:
            # We're done with the loop
            return tuple(values)

        # We want to loop.
        # Get the list of all nodes between the open and close nodes.
        upstream = {}
        self.explore_dependencies(unique_id, dynprompt, upstream)

        contained = {}
        open_node = flow_control[0]
        self.collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        graph = GraphBuilder()
        # First pass: create a clone of every contained node so that links
        # between them can be resolved in the second pass.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], self._graph_id(node_id, unique_id))
            node.set_override_display_id(node_id)
        # Second pass: copy inputs, rewiring links between contained nodes to the clones.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node(self._graph_id(node_id, unique_id))
            assert node is not None
            # explore_dependencies tolerates nodes without an "inputs" entry,
            # so treat a missing entry as "no inputs" here as well.
            for key, value in original_node.get("inputs", {}).items():
                if is_link(value) and value[0] in contained:
                    parent = graph.lookup_node(value[0])
                    assert parent is not None
                    node.set_input(key, parent.out(value[1]))
                else:
                    # Constants and links to nodes outside the loop are kept as-is.
                    node.set_input(key, value)

        # Feed this iteration's values into the cloned open node.
        new_open = graph.lookup_node(open_node)
        assert new_open is not None
        for i, value in enumerate(values):
            new_open.set_input("initial_value%d" % i, value)

        my_clone = graph.lookup_node(self._RECURSE_ID)
        assert my_clone is not None
        return {
            "result": tuple(my_clone.out(i) for i in range(NUM_FLOW_SOCKETS)),
            "expand": graph.finalize(),
        }
|
|
|
|
@VariantSupport()
class TestExecutionBlockerNode:
    """Pass ``input`` through unchanged, or emit an ``ExecutionBlocker`` when ``block`` is set."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the value to forward, the block switch and the verbosity flag."""
        return {
            "required": {
                "input": ("*",),
                "block": ("BOOLEAN",),
                "verbose": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "execution_blocker"

    CATEGORY = "Testing/Flow"

    def execution_blocker(self, input, block, verbose):
        """Return ``(input,)``, or ``(ExecutionBlocker(...),)`` when ``block`` is truthy.

        The blocker carries the message ``"Blocked Execution"`` only when
        ``verbose`` is truthy. Otherwise its message is ``None``.
        """
        if not block:
            return (input,)
        message = "Blocked Execution" if verbose else None
        return (ExecutionBlocker(message),)
|
|
|
|
# Node classes defined in this module, keyed by the name they are exported under.
FLOW_CONTROL_NODE_CLASS_MAPPINGS = dict(
    TestWhileLoopOpen=TestWhileLoopOpen,
    TestWhileLoopClose=TestWhileLoopClose,
    TestExecutionBlocker=TestExecutionBlockerNode,
)

# Human-readable labels for the same keys as above.
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = dict(
    TestWhileLoopOpen="While Loop Open",
    TestWhileLoopClose="While Loop Close",
    TestExecutionBlocker="Execution Blocker",
)
|