mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 05:52:33 +08:00
Implement conditional Execution Blocking
Execution blocking can be done by returning an `ExecutionBlocker` (imported from graph_utils) either in place of results or as a specific output. Any node that uses an `ExecutionBlocker` as input will be skipped. This operates on a per-entry basis when inputs are lists. If the `ExecutionBlocker` is initialized with an error message, that message will be displayed on the first node it's used on (and further downstream nodes will be silently skipped).
This commit is contained in:
parent
b09620f89c
commit
5d72965863
@ -63,6 +63,11 @@ class GraphBuilder:
|
||||
id = self.prefix + id
|
||||
del self.nodes[id]
|
||||
|
||||
# Return this from a node and any users will be blocked with the given error message.
|
||||
class ExecutionBlocker:
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
class Node:
|
||||
def __init__(self, id, class_type, inputs):
|
||||
self.id = id
|
||||
|
||||
@ -147,11 +147,42 @@ class ToBoolNode:
|
||||
|
||||
return (result,)
|
||||
|
||||
class BoolOperationNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"a": ("BOOL",),
|
||||
"b": ("BOOL",),
|
||||
"op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("BOOL",)
|
||||
FUNCTION = "bool_operation"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def bool_operation(self, a, b, op):
|
||||
if op == "a AND b":
|
||||
return (a and b,)
|
||||
elif op == "a OR b":
|
||||
return (a or b,)
|
||||
elif op == "a XOR b":
|
||||
return (a ^ b,)
|
||||
elif op == "NOT a":
|
||||
return (not a,)
|
||||
|
||||
|
||||
CONDITION_NODE_CLASS_MAPPINGS = {
|
||||
"IntConditions": IntConditions,
|
||||
"FloatConditions": FloatConditions,
|
||||
"StringConditions": StringConditions,
|
||||
"ToBoolNode": ToBoolNode,
|
||||
"BoolOperationNode": BoolOperationNode,
|
||||
}
|
||||
|
||||
CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -159,4 +190,5 @@ CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FloatConditions": "Float Condition",
|
||||
"StringConditions": "String Condition",
|
||||
"ToBoolNode": "To Bool",
|
||||
"BoolOperationNode": "Bool Operation",
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from comfy.graph_utils import GraphBuilder, is_link
|
||||
from comfy.graph_utils import GraphBuilder, is_link, ExecutionBlocker
|
||||
|
||||
NUM_FLOW_SOCKETS = 5
|
||||
class WhileLoopOpen:
|
||||
@ -124,11 +124,39 @@ class WhileLoopClose:
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
class ExecutionBlockerNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"input": ("*",),
|
||||
"block": ("BOOL",),
|
||||
"verbose": ("BOOL", {"default": False}),
|
||||
},
|
||||
}
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
RETURN_NAMES = ("output",)
|
||||
FUNCTION = "execution_blocker"
|
||||
|
||||
CATEGORY = "Flow Control"
|
||||
|
||||
def execution_blocker(self, input, block, verbose):
|
||||
if block:
|
||||
return (ExecutionBlocker("Blocked Execution" if verbose else None),)
|
||||
return (input,)
|
||||
|
||||
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
|
||||
"WhileLoopOpen": WhileLoopOpen,
|
||||
"WhileLoopClose": WhileLoopClose,
|
||||
"ExecutionBlocker": ExecutionBlockerNode,
|
||||
}
|
||||
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WhileLoopOpen": "While Loop Open",
|
||||
"WhileLoopClose": "While Loop Close",
|
||||
"ExecutionBlocker": "Execution Blocker",
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from comfy.graph_utils import GraphBuilder
|
||||
import torch
|
||||
|
||||
class AccumulateNode:
|
||||
def __init__(self):
|
||||
@ -228,6 +229,82 @@ class ForLoopClose:
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
class DebugPrint:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("*",),
|
||||
"label": ("STRING", {"multiline": False}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "debug_print"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def debugtype(self, value):
|
||||
if isinstance(value, list):
|
||||
result = "["
|
||||
for i, v in enumerate(value):
|
||||
result += (self.debugtype(v) + ",")
|
||||
result += "]"
|
||||
elif isinstance(value, tuple):
|
||||
result = "("
|
||||
for i, v in enumerate(value):
|
||||
result += (self.debugtype(v) + ",")
|
||||
result += ")"
|
||||
elif isinstance(value, dict):
|
||||
result = "{"
|
||||
for k, v in value.items():
|
||||
result += ("%s: %s," % (self.debugtype(k), self.debugtype(v)))
|
||||
result += "}"
|
||||
elif isinstance(value, str):
|
||||
result = "'%s'" % value
|
||||
elif isinstance(value, bool) or isinstance(value, int) or isinstance(value, float):
|
||||
result = str(value)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
result = "Tensor[%s]" % str(value.shape)
|
||||
else:
|
||||
result = type(value).__name__
|
||||
return result
|
||||
|
||||
def debug_print(self, value, label):
|
||||
print("[%s]: %s" % (label, self.debugtype(value)))
|
||||
return (value,)
|
||||
|
||||
NUM_LIST_SOCKETS = 10
|
||||
class MakeListNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value1": ("*",),
|
||||
},
|
||||
"optional": {
|
||||
"value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "make_list"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def make_list(self, **kwargs):
|
||||
result = []
|
||||
for i in range(NUM_LIST_SOCKETS):
|
||||
if "value%d" % i in kwargs:
|
||||
result.append(kwargs["value%d" % i])
|
||||
return (result,)
|
||||
|
||||
UTILITY_NODE_CLASS_MAPPINGS = {
|
||||
"AccumulateNode": AccumulateNode,
|
||||
@ -238,6 +315,8 @@ UTILITY_NODE_CLASS_MAPPINGS = {
|
||||
"ForLoopOpen": ForLoopOpen,
|
||||
"ForLoopClose": ForLoopClose,
|
||||
"IntMathOperation": IntMathOperation,
|
||||
"DebugPrint": DebugPrint,
|
||||
"MakeListNode": MakeListNode,
|
||||
}
|
||||
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"AccumulateNode": "Accumulate",
|
||||
@ -248,4 +327,6 @@ UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ForLoopOpen": "For Loop Open",
|
||||
"ForLoopClose": "For Loop Close",
|
||||
"IntMathOperation": "Int Math Operation",
|
||||
"DebugPrint": "Debug Print",
|
||||
"MakeListNode": "Make List",
|
||||
}
|
||||
|
||||
71
execution.py
71
execution.py
@ -14,7 +14,7 @@ import nodes
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.graph_utils
|
||||
from comfy.graph_utils import is_link
|
||||
from comfy.graph_utils import is_link, ExecutionBlocker
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
@ -185,7 +185,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
@ -204,12 +204,31 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
if intput_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
execution_block = None
|
||||
for k, v in input_data_all.items():
|
||||
for input in v:
|
||||
if isinstance(v, ExecutionBlocker):
|
||||
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
|
||||
break
|
||||
|
||||
if execution_block is None:
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
results.append(execution_block)
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
execution_block = None
|
||||
for k, v in input_dict.items():
|
||||
if isinstance(v, ExecutionBlocker):
|
||||
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
|
||||
break
|
||||
if execution_block is None:
|
||||
results.append(getattr(obj, func)(**input_dict))
|
||||
else:
|
||||
results.append(execution_block)
|
||||
return results
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
@ -227,12 +246,12 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
@ -243,11 +262,19 @@ def get_output_data(obj, input_data_all):
|
||||
# Perform an expansion, but do not append results
|
||||
has_subgraph = True
|
||||
new_graph = r['expand']
|
||||
subgraph_results.append((new_graph, r.get("result", None)))
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif 'result' in r:
|
||||
results.append(r['result'])
|
||||
subgraph_results.append((None, r['result']))
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
results.append(r)
|
||||
|
||||
if has_subgraph:
|
||||
@ -315,13 +342,31 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if x not in input_data_all]
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
|
||||
if len(required_inputs) > 0:
|
||||
for i in required_inputs:
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all)
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": unique_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": "Execution Blocked: %s" % block.message,
|
||||
"exception_type": "ExecutionBlocked",
|
||||
"traceback": [],
|
||||
"current_inputs": [],
|
||||
"current_outputs": [],
|
||||
}
|
||||
server.send_sync("execution_error", mes, server.client_id)
|
||||
return ExecutionBlocker(None)
|
||||
else:
|
||||
return block
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb)
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
@ -414,7 +459,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except:
|
||||
to_delete = True
|
||||
else:
|
||||
@ -719,7 +764,7 @@ def validate_inputs(prompt, item, validated):
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user