diff --git a/execution.py b/execution.py index 5f2ce6c75..79fec295e 100644 --- a/execution.py +++ b/execution.py @@ -1,602 +1,602 @@ -import os -import sys -import copy -import json -import threading -import heapq -import traceback -import gc -import time -import itertools - -import torch -import nodes - -import comfy.model_management - -def get_input_data_batches(input_data_all): - """Given input data that can contain combinatorial input values, returns all - the possible batches that can be made by combining the different input - values together.""" - - input_to_index = {} - index_to_values = [] - - i = 0 - for input_name, value in input_data_all.items(): - if isinstance(value, dict) and "combinatorial" in value: - input_to_index[input_name] = i - index_to_values.append(value["values"]) - i += 1 - - if len(index_to_values) == 0: - # No combinatorial options. - return [input_data_all] - - batches = [] - - for combination in list(itertools.product(*index_to_values)): - batch = {} - for input_name, value in input_data_all.items(): - if isinstance(value, dict) and "combinatorial" in value: - combination_index = input_to_index[input_name] - batch[input_name] = [combination[combination_index]] # - else: - # already made into a list by get_input_data - batch[input_name] = value - batches.append(batch) - - return batches - - -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): - """Given input data from the prompt, returns a list of input data dicts for - each combinatorial batch.""" - valid_inputs = class_def.INPUT_TYPES() - input_data_all = {} - for x in inputs: - input_data = inputs[x] - required_or_optional = ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]) - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - return None - - # This is a list of outputs for each batch of combinatorial inputs. - # Without any combinatorial inputs, it's a list of length 1. - outputs_for_all_batches = outputs[input_unique_id] - - def flatten(list_of_lists): - return list(itertools.chain.from_iterable(list_of_lists)) - - if len(outputs_for_all_batches) == 1: - # Single batch, no combinatorial stuff - input_data_all[x] = outputs_for_all_batches[0][output_index] - else: - # Make the outputs into a list for map-over-list use - # (they are themselves lists so flatten them afterwards) - input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] - input_values = flatten(input_values) - input_data_all[x] = input_values - elif is_combinatorial_input(input_data): - if required_or_optional: - input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } - else: - if required_or_optional: - input_data_all[x] = [input_data] - - if "hidden" in valid_inputs: - h = valid_inputs["hidden"] - for x in h: - if h[x] == "PROMPT": - input_data_all[x] = [prompt] - if h[x] == "EXTRA_PNGINFO": - if "extra_pnginfo" in extra_data: - input_data_all[x] = [extra_data['extra_pnginfo']] - if h[x] == "UNIQUE_ID": - input_data_all[x] = [unique_id] - - input_data_all_batches = get_input_data_batches(input_data_all) - - return input_data_all_batches - -def slice_lists_into_dict(d, i): - """ - get a slice of inputs, repeat last input when list isn't long enough - d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } - """ - d_new = {} - for k, v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new - -def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): - # check if node wants the lists - inputs_are_lists = False - if hasattr(obj, "INPUTS_ARE_LISTS"): - inputs_are_lists = obj.INPUTS_ARE_LISTS - - def format_dict(d): - s = [] - for k,v in d.items(): - st = f"{k}: " - if isinstance(v, list): - st += f"list[len: {len(v)}][" - i = [] - for v2 in v: - i.append(v2.__class__.__name__) - st += ",".join(i) + "]" - else: - st += str(type(v)) - s.append(st) - return "( " + ", ".join(s) + " )" - - max_len_input = max(len(x) for x in input_data_all.values()) - - results = [] - if inputs_are_lists: - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) - else: - for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) - return results - -def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): - all_outputs = [] - all_outputs_ui = [] - total_batches = len(input_data_all_batches) - - for batch_num, batch in enumerate(input_data_all_batches): - return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) - - uis = [] - results = [] - - for r in return_values: - if isinstance(r, dict): - if 'ui' in r: - uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) - else: - results.append(r) - - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - - output_ui = None - if len(uis) > 0: - output_ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - - # update the UI after each batch finishes - if server.client_id is not None: - message = { - "node": unique_id, - "output": output_ui, - "prompt_id": prompt_id, - "batch_num": batch_num, - "total_batches": total_batches - } - server.send_sync("executed", message, server.client_id) - - all_outputs.append(output) - all_outputs_ui.append(output_ui) - - return all_outputs, all_outputs_ui - -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) - - input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) - obj = class_def() - - output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) - outputs[unique_id] = output_data_from_batches - outputs_ui[unique_id] = output_ui_from_batches - executed.add(unique_id) - -def recursive_will_execute(prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) - - return will_execute + [unique_id] - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all_batches is not None: - try: - #is_changed = class_def.IS_CHANGED(**input_data_all) - for batch in input_data_all_batches: - if map_node_over_list(class_def, batch, "IS_CHANGED"): - is_changed = True - break - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True - - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete - -class PromptExecutor: - def __init__(self, server): - self.outputs = {} - self.outputs_ui = {} - self.old_prompt = {} - self.server = server - - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - nodes.interrupt_processing(False) - - if "client_id" in extra_data: - self.server.client_id = extra_data["client_id"] - else: - self.server.client_id = None - - execution_start_time = time.perf_counter() - if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) - - with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d - - if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) - executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] - - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] - - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) - - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) - - print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) - gc.collect() - comfy.model_management.soft_empty_cache() - - -def is_combinatorial_input(val): - return isinstance(val, dict) and "__inputType__" in val - - -def get_raw_inputs(raw_val): - if isinstance(raw_val, list): - # link to another node - return [raw_val] - elif is_combinatorial_input(raw_val): - return raw_val["values"] - return [raw_val] - - -def clamp_input(val, info, class_type, obj_class, x): - if is_combinatorial_input(val): - if len(val["values"]) == 0: - return (False, "Combinatorial input has no values in its list. {}, {}".format(class_type, x)) - for i, val_choice in enumerate(val["values"]): - r = clamp_input(val_choice, info, class_type, obj_class, x) - if r[0] == False: - return r - val["values"][i] = r[1] - return (True, val) - - type_input = info[0] - - if type_input == "INT": - val = int(val) - if type_input == "FLOAT": - val = float(val) - if type_input == "STRING": - val = str(val) - - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) - if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) - - return (True, val) - -def validate_inputs(prompt, item, validated): - unique_id = item - if unique_id in validated: - return validated[unique_id] - - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - - class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] - for x in required_inputs: - if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) - - val = inputs[x] - - info = required_inputs[x] - type_input = info[0] - - if isinstance(val, list): - if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) - o_id = val[0] - o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r - else: - r = clamp_input(val, info, class_type, obj_class, x) - if r[0] == False: - return r - - inputs[x] = r[1] - - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all_batches = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - for batch in input_data_all_batches: - ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r)) - else: - if isinstance(type_input, list): - # Account for more than one combinatorial value - raw_vals = get_raw_inputs(val) - for raw_val in raw_vals: - if raw_val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, raw_val, type_input)) - - ret = (True, "") - validated[unique_id] = ret - return ret - -def validate_prompt(prompt): - outputs = set() - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True: - outputs.add(x) - - if len(outputs) == 0: - return (False, "Prompt has no outputs") - - good_outputs = set() - errors = [] - validated = {} - for o in outputs: - valid = False - reason = "" - try: - m = validate_inputs(prompt, o, validated) - valid = m[0] - reason = m[1] - except Exception as e: - print(traceback.format_exc()) - valid = False - reason = "Parsing error" - - if valid == True: - good_outputs.add(o) - else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - - if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - - return (True, "", list(good_outputs)) - - -class PromptQueue: - def __init__(self, server): - self.server = server - self.mutex = threading.RLock() - self.not_empty = threading.Condition(self.mutex) - self.task_counter = 0 - self.queue = [] - self.currently_running = {} - self.history = {} - server.prompt_queue = self - - def put(self, item): - with self.mutex: - heapq.heappush(self.queue, item) - self.server.queue_updated() - self.not_empty.notify() - - def get(self): - with self.not_empty: - while len(self.queue) == 0: - self.not_empty.wait() - item = heapq.heappop(self.queue) - i = self.task_counter - self.currently_running[i] = copy.deepcopy(item) - self.task_counter += 1 - self.server.queue_updated() - return (item, i) - - def task_done(self, item_id, outputs): - with self.mutex: - prompt = self.currently_running.pop(item_id) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] - self.server.queue_updated() - - def get_current_queue(self): - with self.mutex: - out = [] - for x in self.currently_running.values(): - out += [x] - return (out, copy.deepcopy(self.queue)) - - def get_tasks_remaining(self): - with self.mutex: - return len(self.queue) + len(self.currently_running) - - def wipe_queue(self): - with self.mutex: - self.queue = [] - self.server.queue_updated() - - def delete_queue_item(self, function): - with self.mutex: - for x in range(len(self.queue)): - if function(self.queue[x]): - if len(self.queue) == 1: - self.wipe_queue() - else: - self.queue.pop(x) - heapq.heapify(self.queue) - self.server.queue_updated() - return True - return False - - def get_history(self): - with self.mutex: - return copy.deepcopy(self.history) - - def wipe_history(self): - with self.mutex: - self.history = {} - - def delete_history_item(self, id_to_delete): - with self.mutex: - self.history.pop(id_to_delete, None) +import os +import sys +import copy +import json +import threading +import heapq +import traceback +import gc +import time +import itertools + +import torch +import nodes + +import comfy.model_management + +def get_input_data_batches(input_data_all): + """Given input data that can contain combinatorial input values, returns all + the possible batches that can be made by combining the different input + values together.""" + + input_to_index = {} + index_to_values = [] + + i = 0 + for input_name, value in input_data_all.items(): + if isinstance(value, dict) and "combinatorial" in value: + input_to_index[input_name] = i + index_to_values.append(value["values"]) + i += 1 + + if len(index_to_values) == 0: + # No combinatorial options. + return [input_data_all] + + batches = [] + + for combination in list(itertools.product(*index_to_values)): + batch = {} + for input_name, value in input_data_all.items(): + if isinstance(value, dict) and "combinatorial" in value: + combination_index = input_to_index[input_name] + batch[input_name] = [combination[combination_index]] # + else: + # already made into a list by get_input_data + batch[input_name] = value + batches.append(batch) + + return batches + + +def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): + """Given input data from the prompt, returns a list of input data dicts for + each combinatorial batch.""" + valid_inputs = class_def.INPUT_TYPES() + input_data_all = {} + for x in inputs: + input_data = inputs[x] + required_or_optional = ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]) + + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id not in outputs: + return None + + # This is a list of outputs for each batch of combinatorial inputs. + # Without any combinatorial inputs, it's a list of length 1. + outputs_for_all_batches = outputs[input_unique_id] + + def flatten(list_of_lists): + return list(itertools.chain.from_iterable(list_of_lists)) + + if len(outputs_for_all_batches) == 1: + # Single batch, no combinatorial stuff + input_data_all[x] = outputs_for_all_batches[0][output_index] + else: + # Make the outputs into a list for map-over-list use + # (they are themselves lists so flatten them afterwards) + input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] + input_values = flatten(input_values) + input_data_all[x] = input_values + elif is_combinatorial_input(input_data): + if required_or_optional: + input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } + else: + if required_or_optional: + input_data_all[x] = [input_data] + + if "hidden" in valid_inputs: + h = valid_inputs["hidden"] + for x in h: + if h[x] == "PROMPT": + input_data_all[x] = [prompt] + if h[x] == "EXTRA_PNGINFO": + if "extra_pnginfo" in extra_data: + input_data_all[x] = [extra_data['extra_pnginfo']] + if h[x] == "UNIQUE_ID": + input_data_all[x] = [unique_id] + + input_data_all_batches = get_input_data_batches(input_data_all) + + return input_data_all_batches + +def slice_lists_into_dict(d, i): + """ + get a slice of inputs, repeat last input when list isn't long enough + d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } + """ + d_new = {} + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + inputs_are_lists = False + if hasattr(obj, "INPUTS_ARE_LISTS"): + inputs_are_lists = obj.INPUTS_ARE_LISTS + + def format_dict(d): + s = [] + for k,v in d.items(): + st = f"{k}: " + if isinstance(v, list): + st += f"list[len: {len(v)}][" + i = [] + for v2 in v: + i.append(v2.__class__.__name__) + st += ",".join(i) + "]" + else: + st += str(type(v)) + s.append(st) + return "( " + ", ".join(s) + " )" + + max_len_input = max(len(x) for x in input_data_all.values()) + + results = [] + if inputs_are_lists: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): + all_outputs = [] + all_outputs_ui = [] + total_batches = len(input_data_all_batches) + + for batch_num, batch in enumerate(input_data_all_batches): + return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) + + uis = [] + results = [] + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + output_ui = None + if len(uis) > 0: + output_ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + + # update the UI after each batch finishes + if server.client_id is not None: + message = { + "node": unique_id, + "output": output_ui, + "prompt_id": prompt_id, + "batch_num": batch_num, + "total_batches": total_batches + } + server.send_sync("executed", message, server.client_id) + + all_outputs.append(output) + all_outputs_ui.append(output_ui) + + return all_outputs, all_outputs_ui + +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + class_type = prompt[unique_id]['class_type'] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if unique_id in outputs: + return + + for x in inputs: + input_data = inputs[x] + + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id not in outputs: + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) + if server.client_id is not None: + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) + obj = class_def() + + output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) + outputs[unique_id] = output_data_from_batches + outputs_ui[unique_id] = output_ui_from_batches + executed.add(unique_id) + +def recursive_will_execute(prompt, outputs, current_item): + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + will_execute = [] + if unique_id in outputs: + return [] + + for x in inputs: + input_data = inputs[x] + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id not in outputs: + will_execute += recursive_will_execute(prompt, outputs, input_unique_id) + + return will_execute + [unique_id] + +def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): + unique_id = current_item + inputs = prompt[unique_id]['inputs'] + class_type = prompt[unique_id]['class_type'] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + + is_changed_old = '' + is_changed = '' + to_delete = False + if hasattr(class_def, 'IS_CHANGED'): + if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: + is_changed_old = old_prompt[unique_id]['is_changed'] + if 'is_changed' not in prompt[unique_id]: + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs) + if input_data_all_batches is not None: + try: + #is_changed = class_def.IS_CHANGED(**input_data_all) + for batch in input_data_all_batches: + if map_node_over_list(class_def, batch, "IS_CHANGED"): + is_changed = True + break + prompt[unique_id]['is_changed'] = is_changed + except: + to_delete = True + else: + is_changed = prompt[unique_id]['is_changed'] + + if unique_id not in outputs: + return True + + if not to_delete: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: + to_delete = True + elif inputs == old_prompt[unique_id]['inputs']: + for x in inputs: + input_data = inputs[x] + + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id in outputs: + to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) + else: + to_delete = True + if to_delete: + break + else: + to_delete = True + + if to_delete: + d = outputs.pop(unique_id) + del d + return to_delete + +class PromptExecutor: + def __init__(self, server): + self.outputs = {} + self.outputs_ui = {} + self.old_prompt = {} + self.server = server + + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + nodes.interrupt_processing(False) + + if "client_id" in extra_data: + self.server.client_id = extra_data["client_id"] + else: + self.server.client_id = None + + execution_start_time = time.perf_counter() + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + + with torch.inference_mode(): + #delete cached outputs if nodes don't exist for them + to_delete = [] + for o in self.outputs: + if o not in prompt: + to_delete += [o] + for o in to_delete: + d = self.outputs.pop(o) + del d + + for x in prompt: + recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) + + current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + executed = set() + try: + to_execute = [] + for x in list(execute_outputs): + to_execute += [(0, x)] + + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + x = to_execute.pop(0)[-1] + + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) + except Exception as e: + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + finally: + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) + gc.collect() + comfy.model_management.soft_empty_cache() + + +def is_combinatorial_input(val): + return isinstance(val, dict) and "__inputType__" in val + + +def get_raw_inputs(raw_val): + if isinstance(raw_val, list): + # link to another node + return [raw_val] + elif is_combinatorial_input(raw_val): + return raw_val["values"] + return [raw_val] + + +def clamp_input(val, info, class_type, obj_class, x): + if is_combinatorial_input(val): + if len(val["values"]) == 0: + return (False, "Combinatorial input has no values in its list. {}, {}".format(class_type, x)) + for i, val_choice in enumerate(val["values"]): + r = clamp_input(val_choice, info, class_type, obj_class, x) + if r[0] == False: + return r + val["values"][i] = r[1] + return (True, val) + + type_input = info[0] + + if type_input == "INT": + val = int(val) + if type_input == "FLOAT": + val = float(val) + if type_input == "STRING": + val = str(val) + + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + return (False, "Value smaller than min. {}, {}".format(class_type, x)) + if "max" in info[1] and val > info[1]["max"]: + return (False, "Value bigger than max. {}, {}".format(class_type, x)) + + return (True, val) + +def validate_inputs(prompt, item, validated): + unique_id = item + if unique_id in validated: + return validated[unique_id] + + inputs = prompt[unique_id]['inputs'] + class_type = prompt[unique_id]['class_type'] + obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] + + class_inputs = obj_class.INPUT_TYPES() + required_inputs = class_inputs['required'] + for x in required_inputs: + if x not in inputs: + return (False, "Required input is missing. {}, {}".format(class_type, x)) + + val = inputs[x] + + info = required_inputs[x] + type_input = info[0] + + if isinstance(val, list): + if len(val) != 2: + return (False, "Bad Input. {}, {}".format(class_type, x)) + o_id = val[0] + o_class_type = prompt[o_id]['class_type'] + r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + if r[val[1]] != type_input: + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + r = validate_inputs(prompt, o_id, validated) + if r[0] == False: + validated[o_id] = r + return r + else: + r = clamp_input(val, info, class_type, obj_class, x) + if r[0] == False: + return r + + inputs[x] = r[1] + + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all_batches = get_input_data(inputs, obj_class, unique_id) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + for batch in input_data_all_batches: + ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) + else: + if isinstance(type_input, list): + # Account for more than one combinatorial value + raw_vals = get_raw_inputs(val) + for raw_val in raw_vals: + if raw_val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, raw_val, type_input)) + + ret = (True, "") + validated[unique_id] = ret + return ret + +def validate_prompt(prompt): + outputs = set() + for x in prompt: + class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] + if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True: + outputs.add(x) + + if len(outputs) == 0: + return (False, "Prompt has no outputs") + + good_outputs = set() + errors = [] + validated = {} + for o in outputs: + valid = False + reason = "" + try: + m = validate_inputs(prompt, o, validated) + valid = m[0] + reason = m[1] + except Exception as e: + print(traceback.format_exc()) + valid = False + reason = "Parsing error" + + if valid == True: + good_outputs.add(o) + else: + print("Failed to validate prompt for output {} {}".format(o, reason)) + print("output will be ignored") + errors += [(o, reason)] + + if len(good_outputs) == 0: + errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + + return (True, "", list(good_outputs)) + + +class PromptQueue: + def __init__(self, server): + self.server = server + self.mutex = threading.RLock() + self.not_empty = threading.Condition(self.mutex) + self.task_counter = 0 + self.queue = [] + self.currently_running = {} + self.history = {} + server.prompt_queue = self + + def put(self, item): + with self.mutex: + heapq.heappush(self.queue, item) + self.server.queue_updated() + self.not_empty.notify() + + def get(self): + with self.not_empty: + while len(self.queue) == 0: + self.not_empty.wait() + item = heapq.heappop(self.queue) + i = self.task_counter + self.currently_running[i] = copy.deepcopy(item) + self.task_counter += 1 + self.server.queue_updated() + return (item, i) + + def task_done(self, item_id, outputs): + with self.mutex: + prompt = self.currently_running.pop(item_id) + self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } + for o in outputs: + self.history[prompt[1]]["outputs"][o] = outputs[o] + self.server.queue_updated() + + def get_current_queue(self): + with self.mutex: + out = [] + for x in self.currently_running.values(): + out += [x] + return (out, copy.deepcopy(self.queue)) + + def get_tasks_remaining(self): + with self.mutex: + return len(self.queue) + len(self.currently_running) + + def wipe_queue(self): + with self.mutex: + self.queue = [] + self.server.queue_updated() + + def delete_queue_item(self, function): + with self.mutex: + for x in range(len(self.queue)): + if function(self.queue[x]): + if len(self.queue) == 1: + self.wipe_queue() + else: + self.queue.pop(x) + heapq.heapify(self.queue) + self.server.queue_updated() + return True + return False + + def get_history(self): + with self.mutex: + return copy.deepcopy(self.history) + + def wipe_history(self): + with self.mutex: + self.history = {} + + def delete_history_item(self, id_to_delete): + with self.mutex: + self.history.pop(id_to_delete, None)