From 5d729658630c5a80756af2a30988e9237ced5ee6 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Fri, 28 Jul 2023 22:28:18 -0700 Subject: [PATCH] Implement conditional Execution Blocking Execution blocking can be done by returning an `ExecutionBlocker` (imported from graph_utils) either in place of results or as a specific output. Any node that uses an `ExecutionBlocker` as input will be skipped. This operates on a per-entry basis when inputs are lists. If the `ExecutionBlocker` is initialized with an error message, that message will be displayed on the first node it's used on (and further downstream nodes will be silently skipped). --- comfy/graph_utils.py | 5 ++ .../conditions.py | 32 ++++++++ .../flow_control.py | 30 ++++++- .../utility_nodes.py | 81 +++++++++++++++++++ execution.py | 71 +++++++++++++--- 5 files changed, 205 insertions(+), 14 deletions(-) diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py index e436840ff..869f6154d 100644 --- a/comfy/graph_utils.py +++ b/comfy/graph_utils.py @@ -63,6 +63,11 @@ class GraphBuilder: id = self.prefix + id del self.nodes[id] +# Return this from a node and any users will be blocked with the given error message. +class ExecutionBlocker: + def __init__(self, message): + self.message = message + class Node: def __init__(self, id, class_type, inputs): self.id = id diff --git a/custom_nodes/execution-inversion-demo-comfyui/conditions.py b/custom_nodes/execution-inversion-demo-comfyui/conditions.py index 0d34b10cd..a7c71a362 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/conditions.py +++ b/custom_nodes/execution-inversion-demo-comfyui/conditions.py @@ -147,11 +147,42 @@ class ToBoolNode: return (result,) +class BoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOL",), + "b": ("BOOL",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOL",) + FUNCTION = "bool_operation" + + CATEGORY = "InversionDemo Nodes" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + CONDITION_NODE_CLASS_MAPPINGS = { "IntConditions": IntConditions, "FloatConditions": FloatConditions, "StringConditions": StringConditions, "ToBoolNode": ToBoolNode, + "BoolOperationNode": BoolOperationNode, } CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { @@ -159,4 +190,5 @@ CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { "FloatConditions": "Float Condition", "StringConditions": "String Condition", "ToBoolNode": "To Bool", + "BoolOperationNode": "Bool Operation", } diff --git a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py index 38ddea8e1..39daa2eb3 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py +++ b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py @@ -1,4 +1,4 @@ -from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph_utils import GraphBuilder, is_link, ExecutionBlocker NUM_FLOW_SOCKETS = 5 class WhileLoopOpen: @@ -124,11 +124,39 @@ class WhileLoopClose: "expand": graph.finalize(), } +class ExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOL",), + "verbose": ("BOOL", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "Flow Control" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + FLOW_CONTROL_NODE_CLASS_MAPPINGS = { "WhileLoopOpen": WhileLoopOpen, "WhileLoopClose": WhileLoopClose, + "ExecutionBlocker": ExecutionBlockerNode, } FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { "WhileLoopOpen": "While Loop Open", "WhileLoopClose": "While Loop Close", + "ExecutionBlocker": "Execution Blocker", } diff --git a/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py b/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py index 656b0da0a..4ee2b2ae4 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py +++ b/custom_nodes/execution-inversion-demo-comfyui/utility_nodes.py @@ -1,4 +1,5 @@ from comfy.graph_utils import GraphBuilder +import torch class AccumulateNode: def __init__(self): @@ -228,6 +229,82 @@ class ForLoopClose: "expand": graph.finalize(), } +class DebugPrint: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + "label": ("STRING", {"multiline": False}), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "debug_print" + + CATEGORY = "InversionDemo Nodes" + + def debugtype(self, value): + if isinstance(value, list): + result = "[" + for i, v in enumerate(value): + result += (self.debugtype(v) + ",") + result += "]" + elif isinstance(value, tuple): + result = "(" + for i, v in enumerate(value): + result += (self.debugtype(v) + ",") + result += ")" + elif isinstance(value, dict): + result = "{" + for k, v in value.items(): + result += ("%s: %s," % (self.debugtype(k), self.debugtype(v))) + result += "}" + elif isinstance(value, str): + result = "'%s'" % value + elif isinstance(value, bool) or isinstance(value, int) or isinstance(value, float): + result = str(value) + elif isinstance(value, torch.Tensor): + result = "Tensor[%s]" % str(value.shape) + else: + result = type(value).__name__ + return result + + def debug_print(self, value, label): + print("[%s]: %s" % (label, self.debugtype(value))) + return (value,) + +NUM_LIST_SOCKETS = 10 +class MakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + "value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "InversionDemo Nodes" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if "value%d" % i in kwargs: + result.append(kwargs["value%d" % i]) + return (result,) UTILITY_NODE_CLASS_MAPPINGS = { "AccumulateNode": AccumulateNode, @@ -238,6 +315,8 @@ UTILITY_NODE_CLASS_MAPPINGS = { "ForLoopOpen": ForLoopOpen, "ForLoopClose": ForLoopClose, "IntMathOperation": IntMathOperation, + "DebugPrint": DebugPrint, + "MakeListNode": MakeListNode, } UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { "AccumulateNode": "Accumulate", @@ -248,4 +327,6 @@ UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { "ForLoopOpen": "For Loop Open", "ForLoopClose": "For Loop Close", "IntMathOperation": "Int Math Operation", + "DebugPrint": "Debug Print", + "MakeListNode": "Make List", } diff --git a/execution.py b/execution.py index c537dc313..9fcd26b97 100644 --- a/execution.py +++ b/execution.py @@ -14,7 +14,7 @@ import nodes import comfy.model_management import comfy.graph_utils -from comfy.graph_utils import is_link +from comfy.graph_utils import is_link, ExecutionBlocker class ExecutionResult(Enum): SUCCESS = 0 @@ -185,7 +185,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp input_data_all[x] = [unique_id] return input_data_all -def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None): # check if node wants the lists intput_is_list = False if hasattr(obj, "INPUT_IS_LIST"): @@ -204,12 +204,31 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): if intput_is_list: if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) + execution_block = None + for k, v in input_data_all.items(): + for input in v: + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb is not None else v + break + + if execution_block is None: + results.append(getattr(obj, func)(**input_data_all)) + else: + results.append(execution_block) else: for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + input_dict = slice_dict(input_data_all, i) + execution_block = None + for k, v in input_dict.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb is not None else v + break + if execution_block is None: + results.append(getattr(obj, func)(**input_dict)) + else: + results.append(execution_block) return results def merge_result_data(results, obj): @@ -227,12 +246,12 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all): +def get_output_data(obj, input_data_all, execution_block_cb=None): results = [] uis = [] subgraph_results = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb) has_subgraph = False for i in range(len(return_values)): r = return_values[i] @@ -243,11 +262,19 @@ def get_output_data(obj, input_data_all): # Perform an expansion, but do not append results has_subgraph = True new_graph = r['expand'] - subgraph_results.append((new_graph, r.get("result", None))) + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) elif 'result' in r: - results.append(r['result']) - subgraph_results.append((None, r['result'])) + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) if has_subgraph: @@ -315,13 +342,31 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, if hasattr(obj, "check_lazy_status"): required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if x not in input_data_all] + required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) return (ExecutionResult.SLEEPING, None, None) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all) + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": "Execution Blocked: %s" % block.message, + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb) if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: @@ -414,7 +459,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item try: #is_changed = class_def.IS_CHANGED(**input_data_all) is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed + prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except: to_delete = True else: @@ -719,7 +764,7 @@ def validate_inputs(prompt, item, validated): #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}"