From 2ca43c3ace031d3a9c63cd1bc09f1c0394475adc Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data"
Date: Sat, 3 Jun 2023 13:27:08 +0900
Subject: [PATCH] Change recursion-style execution to worklist-style execution

* Prevents stack overflow from deep recursion in Python.
* Handles worklists of arbitrary length.
* Allows future extension to loop structures.
* Eliminates the inefficiency of call-stack overhead.
* Improves control over the order in which nodes are processed.
---
 main.py               |   3 +-
 worklist_execution.py | 427 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 429 insertions(+), 1 deletion(-)
 create mode 100644 worklist_execution.py

diff --git a/main.py b/main.py
index 50d3b9a62..c3fc8daba 100644
--- a/main.py
+++ b/main.py
@@ -24,13 +24,14 @@ if __name__ == "__main__":
 import yaml
 
 import execution
+import worklist_execution
 import folder_paths
 import server
 from nodes import init_custom_nodes
 
 
 def prompt_worker(q, server):
-    e = execution.PromptExecutor(server)
+    e = worklist_execution.PromptExecutor(server)
     while True:
         item, item_id = q.get()
         e.execute(item[2], item[1], item[3], item[4])
diff --git a/worklist_execution.py b/worklist_execution.py
new file mode 100644
index 000000000..9b7a71143
--- /dev/null
+++ b/worklist_execution.py
@@ -0,0 +1,427 @@
# This module modifies the original code to experiment with a new execution
# structure intended to replace the original one.
# Original: https://github.com/comfyanonymous/ComfyUI/blob/master/execution.py

import sys
import copy
import traceback
import gc
import time

import torch
import nodes

import comfy.model_management
from execution import get_input_data, get_output_data, map_node_over_list, format_value, full_type_name
from queue import Queue


DEBUG_FLAG = True


def print_dbg(x):
    # `x` is a zero-argument callable so the debug message is only built
    # when debugging is enabled.
    if DEBUG_FLAG:
        print(f"[DBG] {x()}")
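
# The core idea of this module in miniature: replace recursive graph
# traversal, which is bounded by Python's recursion limit, with an explicit
# worklist whose depth is bounded only by memory. The sketch below is purely
# illustrative and is not used by the executor in this file; `graph` maps a
# node id to the ids of its children.
def _sketch_worklist_traversal(graph, seeds):
    from collections import deque

    visited = set()
    worklist = deque(seeds)
    order = []
    while worklist:
        node = worklist.popleft()
        if node in visited:
            continue
        visited.add(node)
        order.append(node)
        # Children are pushed onto the worklist instead of being visited by a
        # recursive call, so graph depth cannot overflow the call stack.
        worklist.extend(graph.get(node, ()))
    return order
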
# To handle virtual nodes, i.e. nodes without a registered class_type (such
# as component inputs/outputs).
class DummyNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ()
    FUNCTION = "doit"

    def doit(s, *args, **kwargs):
        # Pass a single input through unchanged; with any other number of
        # inputs there is nothing sensible to forward, so produce no output.
        if len(kwargs) == 1:
            key = list(kwargs.keys())[0]
            output = kwargs[key]
            return (output,)
        else:
            return None


def exception_helper(unique_id, get_input_data_all, executed, outputs, task):
    # `get_input_data_all` is a zero-argument callable (or None) so that inputs
    # gathered *inside* `task` are still visible here when an error is being
    # formatted; passing the value itself would snapshot it before `task` runs.
    try:
        task()
        return None
    except comfy.model_management.InterruptProcessingException as iex:
        print("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return executed, False, error_details, iex
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_all = get_input_data_all() if get_input_data_all is not None else None
        input_data_formatted = {}
        if input_data_all is not None:
            for name, inputs in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in inputs]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

        print("!!! Exception during processing !!!")
        print(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return executed, False, error_details, ex


def is_incomplete_input_slots(inputs, outputs):
    # A linked input slot is a list of [source_node_id, output_index]; a node
    # is incomplete while any of its source nodes has no cached output yet.
    for x in inputs:
        input_data = inputs[x]

        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                return True

    return False


def get_class_def(prompt, unique_id):
    if 'class_type' not in prompt[unique_id]:
        class_def = DummyNode
    else:
        class_type = prompt[unique_id]['class_type']
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    return class_def


def get_next_nodes_map(prompt):
    # Build the reverse edges of the graph: producer node id -> set of
    # consumer node ids.
    next_nodes = {}
    for key, value in prompt.items():
        inputs = value['inputs']

        for input_data in inputs.values():
            if isinstance(input_data, list):
                input_unique_id = input_data[0]
                if input_unique_id in next_nodes:
                    next_nodes[input_unique_id].add(key)
                else:
                    next_nodes[input_unique_id] = {key}
    return next_nodes
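
# Example (illustrative; node ids and class types are hypothetical) of the
# reverse-edge map built by get_next_nodes_map() for a two-node prompt where
# node "2" reads output slot 0 of node "1":
#
#     prompt = {
#         "1": {"class_type": "LoadImage", "inputs": {}},
#         "2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
#     }
#     get_next_nodes_map(prompt)  # -> {"1": {"2"}}
#
# The map answers "which nodes consume my outputs?", which is exactly what the
# main loop needs when deciding what may fire after a node finishes.
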
def worklist_execute(server, prompt, outputs, extra_data, prompt_id, outputs_ui, to_execute, next_nodes):
    worklist = Queue()
    executed = set()
    will_execute = {}  # multiset: node id -> number of pending entries in the worklist

    def add_work(item):
        worklist.put(item)
        cnt = will_execute.get(item, 0)
        will_execute[item] = cnt + 1

    def get_work():
        item = worklist.get()
        cnt = will_execute.get(item, 0)
        if cnt <= 1:
            # Drop exhausted counters so that get_progress() does not keep
            # counting nodes that are no longer pending.
            will_execute.pop(item, None)
        else:
            will_execute[item] = cnt - 1

        return str(item)

    def get_progress():
        total = len(executed) + len(will_execute)
        if total == 0:
            return 1.0
        return len(executed) / total

    # Initialize the seeds: the requested nodes that are not cached yet and
    # whose input slots are all satisfied by cached outputs are the seeds.
    for unique_id in to_execute:
        inputs = prompt[unique_id]['inputs']
        class_def = get_class_def(prompt, unique_id)

        if unique_id in outputs:
            continue

        if is_incomplete_input_slots(inputs, outputs):
            continue

        input_data_all = None

        def task():
            nonlocal input_data_all
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)

            if input_data_all is None:
                return

            add_work(unique_id)  # add as a seed once all of its inputs are available

        result = exception_helper(unique_id, lambda: input_data_all, executed, outputs, task)
        if result is not None:
            return result  # error state

    while not worklist.empty():
        unique_id = get_work()

        inputs = prompt[unique_id]['inputs']
        class_def = get_class_def(prompt, unique_id)

        print_dbg(lambda: f"work: {unique_id} ({class_def.__name__}) / worklist: {list(worklist.queue)}")

        input_data_all = None

        def task():
            nonlocal input_data_all
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)

            if server.client_id is not None:
                server.last_node_id = unique_id
                server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id, "progress": get_progress()},
                                 server.client_id)
            obj = class_def()

            output_data, output_ui = get_output_data(obj, input_data_all)

            outputs[unique_id] = output_data
            if len(output_ui) > 0:
                outputs_ui[unique_id] = output_ui
                if server.client_id is not None:
                    server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
                                     server.client_id)
            executed.add(unique_id)

        result = exception_helper(unique_id, lambda: input_data_all, executed, outputs, task)
        if result is not None:
            return result  # error state
        else:
            if unique_id in next_nodes:
                for next_node in next_nodes[unique_id]:
                    if next_node in to_execute:
                        # If not all of its input slots are complete yet, do not
                        # add the node to the worklist. This prevents duplicate
                        # entries of the same work in the worklist. For loop
                        # support, it is important that a node fires only once,
                        # at the moment its input slots become complete.
                        if not is_incomplete_input_slots(prompt[next_node]['inputs'], outputs):
                            add_work(next_node)

    return executed, True, None, None
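
# Why a node fires only when its inputs are complete: in a diamond-shaped
# graph
#
#        1
#       / \
#      2   3
#       \ /
#        4
#
# node 4 consumes outputs of both 2 and 3. When 2 finishes, 4 is still
# missing its input from 3, so the is_incomplete_input_slots() check above
# keeps it off the worklist; it is enqueued exactly once, when its last
# producer completes. Without the check, 4 would be enqueued (and executed)
# once per finished producer.
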
def worklist_will_execute(prompt, outputs, worklist):
    # Walk backwards from the requested outputs and collect every node that
    # still has to run (i.e. that has no cached output yet).
    visited = set()
    will_execute = []

    while worklist:
        unique_id = str(worklist.pop())

        if unique_id in visited:  # avoid infinite loops and redundant processing
            continue
        else:
            visited.add(unique_id)

        inputs = prompt[unique_id]['inputs']

        if unique_id in outputs:
            continue

        for x in inputs:
            input_data = inputs[x]
            if isinstance(input_data, list):
                input_unique_id = input_data[0]
                if input_unique_id not in outputs:
                    worklist.append(input_unique_id)

        will_execute.append(unique_id)

    return will_execute


def worklist_output_delete_if_changed(prompt, old_prompt, outputs, next_nodes):
    worklist = []
    deleted = set()

    # init seeds: collect the nodes whose cached output has become stale
    for unique_id, value in prompt.items():
        inputs = value['inputs']
        class_def = get_class_def(prompt, unique_id)

        is_changed_old = ''
        is_changed = ''
        to_delete = False

        if hasattr(class_def, 'IS_CHANGED'):
            if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
                is_changed_old = old_prompt[unique_id]['is_changed']
            if 'is_changed' not in value:
                input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
                if input_data_all is not None:
                    try:
                        is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                        value['is_changed'] = is_changed
                    except Exception:
                        to_delete = True
            else:
                is_changed = value['is_changed']

        if unique_id not in outputs:
            to_delete = True
        else:
            if not to_delete:
                if is_changed != is_changed_old:
                    to_delete = True
                elif unique_id not in old_prompt:
                    to_delete = True
                elif inputs == old_prompt[unique_id]['inputs']:
                    for x in inputs:
                        input_data = inputs[x]

                        if isinstance(input_data, list):
                            input_unique_id = input_data[0]

                            if input_unique_id not in outputs:
                                to_delete = True
                                break
                else:
                    to_delete = True

        if to_delete:
            worklist.append(unique_id)

    # cascade removal: evicting a node also evicts everything downstream of it
    while worklist:
        unique_id = worklist.pop()

        if unique_id in deleted:
            continue

        if unique_id in outputs:
            d = outputs.pop(unique_id)
            del d

        new_works = next_nodes.get(unique_id, [])
        worklist.extend(new_works)
        deleted.add(unique_id)

    return outputs
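
# Minimal standalone sketch of the eviction cascade above (assumed shapes:
# `outputs` is {node_id: data}, `next_nodes` is {node_id: set of consumer
# ids}, `changed` is an iterable of invalidated node ids). With outputs
# cached for "1" -> "2" -> "3" and next_nodes == {"1": {"2"}, "2": {"3"}},
# invalidating "1" evicts all three nodes.
def _sketch_cascade_invalidate(outputs, next_nodes, changed):
    worklist = list(changed)
    deleted = set()
    while worklist:
        node = worklist.pop()
        if node in deleted:
            continue
        outputs.pop(node, None)                     # evict this node's cached output
        worklist.extend(next_nodes.get(node, ()))   # then evict its consumers
        deleted.add(node)
    return outputs
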
class PromptExecutor:
    def __init__(self, server):
        self.outputs = {}
        self.outputs_ui = {}
        self.old_prompt = {}
        self.server = server

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        node_id = error["node_id"]
        if "class_type" in prompt[node_id]:
            class_type = prompt[node_id]["class_type"]
        else:
            class_type = "ComponentInput/Output"

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
        else:
            if self.server.client_id is not None:
                mes = {
                    "prompt_id": prompt_id,
                    "node_id": node_id,
                    "node_type": class_type,
                    "executed": list(executed),

                    "exception_message": error["exception_message"],
                    "exception_type": error["exception_type"],
                    "traceback": error["traceback"],
                    "current_inputs": error["current_inputs"],
                    "current_outputs": error["current_outputs"],
                }
                self.server.send_sync("execution_error", mes, self.server.client_id)

        # Next, remove the subsequent outputs since they will not be executed
        to_delete = []
        for o in self.outputs:
            if (o not in current_outputs) and (o not in executed):
                to_delete += [o]
                if o in self.old_prompt:
                    d = self.old_prompt.pop(o)
                    del d
        for o in to_delete:
            d = self.outputs.pop(o)
            del d

    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        execution_start_time = time.perf_counter()
        if self.server.client_id is not None:
            self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)

        with torch.inference_mode():
            # delete cached outputs whose nodes no longer exist in the prompt
            to_delete = []
            for o in self.outputs:
                if o not in prompt:
                    to_delete += [o]
            for o in to_delete:
                d = self.outputs.pop(o)
                del d

            next_nodes = get_next_nodes_map(prompt)
            worklist_output_delete_if_changed(prompt, self.old_prompt, self.outputs, next_nodes)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    d = self.outputs_ui.pop(x)
                    del d

            if self.server.client_id is not None:
                self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
                                      self.server.client_id)

            to_execute = worklist_will_execute(prompt, self.outputs, execute_outputs)

            # This call shouldn't raise anything if there's an error deep in
            # the actual SD code; instead it reports the node where the
            # error was raised
            executed, success, error, ex = worklist_execute(self.server, prompt, self.outputs, extra_data, prompt_id,
                                                            self.outputs_ui, to_execute, next_nodes)
            if success is not True:
                self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if self.server.client_id is not None:
                self.server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, self.server.client_id)

        print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
        gc.collect()
        comfy.model_management.soft_empty_cache()
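
# Minimal smoke-test harness (illustrative only; `_FakeServer` is a
# hypothetical stub whose attributes are assumptions inferred from how
# PromptExecutor uses the server object above, and `prompt` must reference
# node classes that exist in nodes.NODE_CLASS_MAPPINGS):
#
#     class _FakeServer:
#         client_id = None
#         last_node_id = None
#
#         def send_sync(self, event, data, sid=None):
#             print(event, data)
#
#     executor = PromptExecutor(_FakeServer())
#     executor.execute(prompt, "prompt-id-0", extra_data={}, execute_outputs=["9"])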