Implement conditional Execution Blocking

Execution blocking can be done by returning an `ExecutionBlocker`
(imported from graph_utils) either in place of results or as a specific
output. Any node that uses an `ExecutionBlocker` as input will be
skipped. This operates on a per-entry basis when inputs are lists.

If the `ExecutionBlocker` is initialized with an error message, that
message will be displayed on the first node it's used on (and further
downstream nodes will be silently skipped).
This commit is contained in:
Jacob Segal 2023-07-28 22:28:18 -07:00
parent b09620f89c
commit 5d72965863
5 changed files with 205 additions and 14 deletions

View File

@ -63,6 +63,11 @@ class GraphBuilder:
id = self.prefix + id
del self.nodes[id]
# Return this from a node and any users will be blocked with the given error message.
class ExecutionBlocker:
def __init__(self, message):
self.message = message
class Node:
def __init__(self, id, class_type, inputs):
self.id = id

View File

@ -147,11 +147,42 @@ class ToBoolNode:
return (result,)
class BoolOperationNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"a": ("BOOL",),
"b": ("BOOL",),
"op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
},
}
RETURN_TYPES = ("BOOL",)
FUNCTION = "bool_operation"
CATEGORY = "InversionDemo Nodes"
def bool_operation(self, a, b, op):
if op == "a AND b":
return (a and b,)
elif op == "a OR b":
return (a or b,)
elif op == "a XOR b":
return (a ^ b,)
elif op == "NOT a":
return (not a,)
CONDITION_NODE_CLASS_MAPPINGS = {
"IntConditions": IntConditions,
"FloatConditions": FloatConditions,
"StringConditions": StringConditions,
"ToBoolNode": ToBoolNode,
"BoolOperationNode": BoolOperationNode,
}
CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
@ -159,4 +190,5 @@ CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
"FloatConditions": "Float Condition",
"StringConditions": "String Condition",
"ToBoolNode": "To Bool",
"BoolOperationNode": "Bool Operation",
}

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder, is_link
from comfy.graph_utils import GraphBuilder, is_link, ExecutionBlocker
NUM_FLOW_SOCKETS = 5
class WhileLoopOpen:
@ -124,11 +124,39 @@ class WhileLoopClose:
"expand": graph.finalize(),
}
class ExecutionBlockerNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
inputs = {
"required": {
"input": ("*",),
"block": ("BOOL",),
"verbose": ("BOOL", {"default": False}),
},
}
return inputs
RETURN_TYPES = ("*",)
RETURN_NAMES = ("output",)
FUNCTION = "execution_blocker"
CATEGORY = "Flow Control"
def execution_blocker(self, input, block, verbose):
if block:
return (ExecutionBlocker("Blocked Execution" if verbose else None),)
return (input,)
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
"WhileLoopOpen": WhileLoopOpen,
"WhileLoopClose": WhileLoopClose,
"ExecutionBlocker": ExecutionBlockerNode,
}
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
"WhileLoopOpen": "While Loop Open",
"WhileLoopClose": "While Loop Close",
"ExecutionBlocker": "Execution Blocker",
}

View File

@ -1,4 +1,5 @@
from comfy.graph_utils import GraphBuilder
import torch
class AccumulateNode:
def __init__(self):
@ -228,6 +229,82 @@ class ForLoopClose:
"expand": graph.finalize(),
}
class DebugPrint:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("*",),
"label": ("STRING", {"multiline": False}),
},
}
RETURN_TYPES = ("*",)
FUNCTION = "debug_print"
CATEGORY = "InversionDemo Nodes"
def debugtype(self, value):
if isinstance(value, list):
result = "["
for i, v in enumerate(value):
result += (self.debugtype(v) + ",")
result += "]"
elif isinstance(value, tuple):
result = "("
for i, v in enumerate(value):
result += (self.debugtype(v) + ",")
result += ")"
elif isinstance(value, dict):
result = "{"
for k, v in value.items():
result += ("%s: %s," % (self.debugtype(k), self.debugtype(v)))
result += "}"
elif isinstance(value, str):
result = "'%s'" % value
elif isinstance(value, bool) or isinstance(value, int) or isinstance(value, float):
result = str(value)
elif isinstance(value, torch.Tensor):
result = "Tensor[%s]" % str(value.shape)
else:
result = type(value).__name__
return result
def debug_print(self, value, label):
print("[%s]: %s" % (label, self.debugtype(value)))
return (value,)
NUM_LIST_SOCKETS = 10
class MakeListNode:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value1": ("*",),
},
"optional": {
"value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS)
},
}
RETURN_TYPES = ("*",)
FUNCTION = "make_list"
OUTPUT_IS_LIST = (True,)
CATEGORY = "InversionDemo Nodes"
def make_list(self, **kwargs):
result = []
for i in range(NUM_LIST_SOCKETS):
if "value%d" % i in kwargs:
result.append(kwargs["value%d" % i])
return (result,)
UTILITY_NODE_CLASS_MAPPINGS = {
"AccumulateNode": AccumulateNode,
@ -238,6 +315,8 @@ UTILITY_NODE_CLASS_MAPPINGS = {
"ForLoopOpen": ForLoopOpen,
"ForLoopClose": ForLoopClose,
"IntMathOperation": IntMathOperation,
"DebugPrint": DebugPrint,
"MakeListNode": MakeListNode,
}
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
"AccumulateNode": "Accumulate",
@ -248,4 +327,6 @@ UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
"ForLoopOpen": "For Loop Open",
"ForLoopClose": "For Loop Close",
"IntMathOperation": "Int Math Operation",
"DebugPrint": "Debug Print",
"MakeListNode": "Make List",
}

View File

@ -14,7 +14,7 @@ import nodes
import comfy.model_management
import comfy.graph_utils
from comfy.graph_utils import is_link
from comfy.graph_utils import is_link, ExecutionBlocker
class ExecutionResult(Enum):
SUCCESS = 0
@ -185,7 +185,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
input_data_all[x] = [unique_id]
return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
@ -204,12 +204,31 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
if intput_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
execution_block = None
for k, v in input_data_all.items():
for input in v:
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
results.append(getattr(obj, func)(**input_data_all))
else:
results.append(execution_block)
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
input_dict = slice_dict(input_data_all, i)
execution_block = None
for k, v in input_dict.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
results.append(getattr(obj, func)(**input_dict))
else:
results.append(execution_block)
return results
def merge_result_data(results, obj):
@ -227,12 +246,12 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all):
def get_output_data(obj, input_data_all, execution_block_cb=None):
results = []
uis = []
subgraph_results = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
@ -243,11 +262,19 @@ def get_output_data(obj, input_data_all):
# Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
subgraph_results.append((new_graph, r.get("result", None)))
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif 'result' in r:
results.append(r['result'])
subgraph_results.append((None, r['result']))
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
if has_subgraph:
@ -315,13 +342,31 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
if hasattr(obj, "check_lazy_status"):
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if x not in input_data_all]
required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.SLEEPING, None, None)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all)
def execution_block_cb(block):
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": "Execution Blocked: %s" % block.message,
"exception_type": "ExecutionBlocked",
"traceback": [],
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
return ExecutionBlocker(None)
else:
return block
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb)
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
@ -414,7 +459,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except:
to_delete = True
else:
@ -719,7 +764,7 @@ def validate_inputs(prompt, item, validated):
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
if r is not True and not isinstance(r, ExecutionBlocker):
details = f"{x}"
if r is not False:
details += f" - {str(r)}"