from __future__ import annotations

import asyncio
import copy
import datetime
import heapq
import threading
import traceback
import typing
from dataclasses import dataclass
from typing import Tuple
import sys

import torch

from comfy.nodes.package import import_all_nodes_in_workspace

nodes = import_all_nodes_in_workspace()

import comfy.model_management

# A queued item, read through the accessors below as
# (priority, prompt id, prompt, extra data, good outputs).
QueueTuple = Tuple[float, int | str, dict, dict, list]


def get_queue_priority(t: QueueTuple):
    """Return the priority of a queued item (element 0, used for heap ordering)."""
    return t[0]


def get_prompt_id(t: QueueTuple):
    """Return the prompt id of a queued item (element 1)."""
    return t[1]


def get_prompt(t: QueueTuple):
    """Return the prompt graph of a queued item (element 2)."""
    return t[2]


def get_extra_data(t: QueueTuple):
    """Return the extra data dict of a queued item (element 3)."""
    return t[3]


def get_good_outputs(t: QueueTuple):
    """Return the list of output node ids of a queued item (element 4)."""
    return t[4]


class HistoryEntry(typing.TypedDict):
    prompt: QueueTuple
    outputs: dict
    timestamp: float


@dataclass
class QueueItem:
    """
    An item awaiting processing in the queue
    """
    queue_tuple: QueueTuple
    # Resolved by PromptQueue with the outputs (or an exception) when the item finishes
    # or is removed; None when nobody is waiting on it.
    completed: asyncio.Future | None

    def __lt__(self, other: QueueItem):
        # heapq orders items by their queue priority only.
        return get_queue_priority(self.queue_tuple) < get_queue_priority(other.queue_tuple)


def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
    """
    Build the keyword-argument mapping used to call a node.

    Every value in the returned dict is a list. A linked input takes the upstream node's
    cached output list. A literal input declared as "required" or "optional" is wrapped in a
    one-element list; undeclared literals are dropped. Declared hidden inputs (PROMPT,
    EXTRA_PNGINFO, UNIQUE_ID) are filled in from the arguments.

    :param inputs: the node's ``inputs`` mapping from the prompt; a list value is a link
        ``[upstream_node_id, output_index]``
    :param class_def: the node class; ``class_def.INPUT_TYPES()`` declares its inputs
    :param unique_id: this node's id, supplied for the ``UNIQUE_ID`` hidden input
    :param outputs: cached node outputs keyed by node id (defaults to empty)
    :param prompt: the whole prompt, supplied for the ``PROMPT`` hidden input (defaults to empty)
    :param extra_data: extra request data; its ``extra_pnginfo`` entry feeds ``EXTRA_PNGINFO``
        (defaults to empty)
    :return: the input mapping, or ``None`` if a linked upstream node has no cached output
    """
    # None sentinels avoid sharing mutable default dicts between calls.
    outputs = {} if outputs is None else outputs
    prompt = {} if prompt is None else prompt
    extra_data = {} if extra_data is None else extra_data

    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                return None
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj
        else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or (
                    "optional" in valid_inputs and x in valid_inputs["optional"]):
                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                if "extra_pnginfo" in extra_data:
                    input_data_all[x] = [extra_data['extra_pnginfo']]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all


def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    """
    Call ``getattr(obj, func)`` over list-valued inputs and collect the results.

    If ``obj.INPUT_IS_LIST`` is set, the method is called once with the full lists.
    Otherwise it is called once per index up to the longest input list, and shorter lists
    repeat their last element. With no inputs it is called once without arguments.

    :param obj: node instance or class
    :param input_data_all: mapping of input name -> list of values (see ``get_input_data``)
    :param func: name of the method to call
    :param allow_interrupt: when true, check for a pending interrupt before every call
    :return: list of the individual call results
    """
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max(len(x) for x in input_data_all.values())

    def slice_dict(d, i):
        # get a slice of inputs, repeat last input when list isn't long enough
        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

    def check_interrupt():
        # Same interrupt check on every branch (the zero-input branch previously called an
        # unrelated `nodes.before_node_execution`).
        if allow_interrupt:
            comfy.model_management.throw_exception_if_processing_interrupted()

    results = []
    if input_is_list:
        check_interrupt()
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        check_interrupt()
        results.append(getattr(obj, func)())
    else:
        for i in range(max_len_input):
            check_interrupt()
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
    return results


def get_output_data(obj, input_data_all):
    """
    Run a node's ``FUNCTION`` over its inputs and merge the per-call results.

    A call may return a plain tuple of outputs, or a dict with optional ``ui`` and
    ``result`` keys.

    :return: ``(output, ui)``. ``output`` holds one list per output slot; slots flagged in
        ``OUTPUT_IS_LIST`` are flattened across calls. ``ui`` maps each ui key to the
        concatenation of that key's lists over all calls.
    """
    results = []
    uis = []
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

    for r in return_values:
        if isinstance(r, dict):
            if 'ui' in r:
                uis.append(r['ui'])
            if 'result' in r:
                results.append(r['result'])
        else:
            results.append(r)

    output = []
    if len(results) > 0:
        # check which outputs need concatenating
        output_is_list = [False] * len(results[0])
        if hasattr(obj, "OUTPUT_IS_LIST"):
            output_is_list = obj.OUTPUT_IS_LIST

        # merge node execution results
        for i, is_list in zip(range(len(results[0])), output_is_list):
            if is_list:
                output.append([x for o in results for x in o[i]])
            else:
                output.append([o[i] for o in results])

    ui = dict()
    if len(uis) > 0:
        ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
    return output, ui


def format_value(x):
    """Return ``x`` unchanged if it is None or a JSON-friendly scalar, else ``str(x)``."""
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    else:
        return str(x)


def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
                      object_storage):
    """
    Execute ``current_item`` after first executing any upstream nodes with no cached output.

    Results are stored in ``outputs`` (and ``outputs_ui`` when the node produced ui data).
    Node instances are reused from ``object_storage``, keyed by ``(node_id, class_type)``.
    Successfully executed node ids are added to ``executed``.

    :return: ``(True, None, None)`` on success, or ``(False, error_details, exception)`` for
        the first node that failed or was interrupted
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    for x in inputs:
        input_data = inputs[x]

        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id,
                                           outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id)

        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
                                 server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        print("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            for name, values in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in values]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

        print("!!! Exception during processing !!!")
        print(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)


def recursive_will_execute(prompt, outputs, current_item):
    """
    List the node ids that executing ``current_item`` would run, dependencies first.

    Nodes already present in ``outputs`` are not listed (and their inputs are not followed).
    A node reachable through several links may appear more than once.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id)

    return will_execute + [unique_id]


def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    """
    Evict ``current_item``'s cached output if it, or anything it links to, has changed.

    A node counts as changed when any of these hold:
    - its ``IS_CHANGED`` result differs from the one recorded in ``old_prompt``, or
      ``IS_CHANGED`` raised;
    - it is new;
    - its inputs differ from ``old_prompt``;
    - an upstream node was evicted.

    The ``IS_CHANGED`` result is memoised into ``prompt[unique_id]['is_changed']``.

    :return: True if the node has no cached output or its output was evicted
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            if input_data_all is not None:
                try:
                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                    prompt[unique_id]['is_changed'] = is_changed
                except Exception:
                    # A failing IS_CHANGED means the cache cannot be trusted; re-execute.
                    to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            for x in inputs:
                input_data = inputs[x]

                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        outputs.pop(unique_id)
    return to_delete


class PromptExecutor:
    """Executes prompts, caching node outputs and node instances between runs."""

    def __init__(self, server):
        self.outputs = {}
        self.object_storage = {}
        self.outputs_ui = {}
        self.old_prompt = {}
        self.server = server

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        """
        Report a failed or interrupted execution to the client and drop partial outputs.

        Outputs produced neither before this run (``current_outputs``) nor by a node that
        finished (``executed``) are removed, along with their ``old_prompt`` entries.
        """
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
        else:
            if self.server.client_id is not None:
                mes = {
                    "prompt_id": prompt_id,
                    "node_id": node_id,
                    "node_type": class_type,
                    "executed": list(executed),
                    "exception_message": error["exception_message"],
                    "exception_type": error["exception_type"],
                    "traceback": error["traceback"],
                    "current_inputs": error["current_inputs"],
                    "current_outputs": error["current_outputs"],
                }
                self.server.send_sync("execution_error", mes, self.server.client_id)

        # Next, remove the subsequent outputs since they will not be executed
        to_delete = []
        for o in self.outputs:
            if (o not in current_outputs) and (o not in executed):
                to_delete += [o]
                if o in self.old_prompt:
                    self.old_prompt.pop(o)
        for o in to_delete:
            self.outputs.pop(o)

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
        """
        Execute a prompt, reusing cached outputs of unchanged nodes.

        :param prompt: mapping of node id -> ``{"class_type": ..., "inputs": ...}``
        :param prompt_id: id echoed back in every message sent to the client
        :param extra_data: extra request data; ``client_id`` selects the client to notify
            (defaults to empty)
        :param execute_outputs: output node ids to execute (defaults to none)
        """
        extra_data = {} if extra_data is None else extra_data
        execute_outputs = [] if execute_outputs is None else execute_outputs

        comfy.model_management.interrupt_current_processing(False)

        self.server.client_id = extra_data.get("client_id")
        if self.server.client_id is not None:
            self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)

        with torch.inference_mode():
            # delete cached outputs if nodes don't exist for them
            for o in [o for o in self.outputs if o not in prompt]:
                self.outputs.pop(o)

            # drop cached node instances whose node disappeared or changed class_type
            stale_objects = [key for key in self.object_storage
                             if key[0] not in prompt or key[1] != prompt[key[0]]['class_type']]
            for key in stale_objects:
                self.object_storage.pop(key)

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    self.outputs_ui.pop(x)

            if self.server.client_id is not None:
                self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
                                      self.server.client_id)
            executed = set()
            to_execute = [(0, node_id) for node_id in execute_outputs]

            while len(to_execute) > 0:
                # always execute the output that depends on the least amount of unexecuted nodes first
                to_execute = sorted(
                    (len(recursive_will_execute(prompt, self.outputs, node_id)), node_id)
                    for _, node_id in to_execute
                )
                output_node_id = to_execute.pop(0)[-1]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data,
                                                       executed, prompt_id, self.outputs_ui, self.object_storage)
                if success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None


def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], typing.Any]:
    """
    Validate the required inputs of node ``item``, recursing into linked upstream nodes.

    Results are memoised in ``validated`` (node id -> result). Literal INT/FLOAT/STRING
    values are coerced in place in ``prompt``.

    :return: ``(is_valid, list_of_error_dicts, node_id)``
    """
    # todo: this should check if LoadImage / LoadImageMask paths exist
    # todo: or, nodes should provide a way to validate their values
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']

    errors = []
    valid = True

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                received_type = r[val[1]]
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if hasattr(obj_class, "VALIDATE_INPUTS"):
                # A node-provided validator replaces the built-in list-membership check.
                input_data_all = get_input_data(inputs, obj_class, unique_id)
                ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
                for r in ret:
                    if r is not True:
                        details = f"{x}"
                        if r is not False:
                            details += f" - {str(r)}"

                        error = {
                            "type": "custom_validation_failed",
                            "message": "Custom validation failed for node",
                            "details": details,
                            "extra_info": {
                                "input_name": x,
                                "input_config": info,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
            else:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret


def full_type_name(klass):
    """Return ``klass``'s qualified name, prefixed by its module unless it is a builtin."""
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__


def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict] | None, typing.List[str], dict | list]:
    """
    Validate every output node (``OUTPUT_NODE == True``) of ``prompt``.

    :return: ``(ok, error, good_outputs, node_errors)``. ``ok`` is False when there are no
        outputs or none validated. ``node_errors`` maps each failing node id to its errors,
        class type and the outputs that depend on it.
    """
    outputs = set()
    for x in prompt:
        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            print(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                print("* (prompt):")
                for reason in reasons:
                    print(f" - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
                reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if valid is not True and len(reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        print(f"* {class_type} {node_id}:")
                        for reason in reasons:
                            print(f" - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            print("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
        # Use a distinct loop variable instead of shadowing the `errors` list being iterated.
        for _, output_reasons in errors:
            for error in output_reasons:
                errors_list.append(f"{error['message']}: {error['details']}")
        errors_list = "\n".join(errors_list)

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)


class PromptQueue:
    """A priority queue of prompts plus bookkeeping of running and finished tasks."""
    queue: typing.List[QueueItem]
    # task id (assigned by get()) -> item currently being processed
    currently_running: typing.Dict[int, QueueItem]
    # history maps the prompt id (second element of the queue tuple) to a dictionary with
    # keys "prompt", "outputs" and "timestamp"
    history: typing.Dict[int | str, HistoryEntry]

    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.next_task_id = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        server.prompt_queue = self

    def size(self) -> int:
        return len(self.queue)

    def put(self, item: QueueItem):
        """Push ``item`` onto the priority heap and wake one waiting ``get()``."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self) -> typing.Tuple[QueueTuple, int]:
        """
        Block until an item is available, then pop the lowest-priority-value item.

        :return: a deep copy of the item's queue tuple and the task id to pass to ``task_done``
        """
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait()
            item_with_future: QueueItem = heapq.heappop(self.queue)
            task_id = self.next_task_id
            self.currently_running[task_id] = item_with_future
            self.next_task_id += 1
            self.server.queue_updated()
            return copy.deepcopy(item_with_future.queue_tuple), task_id

    def task_done(self, item_id, outputs: dict):
        """Move task ``item_id`` into history and resolve its future with ``outputs``."""
        with self.mutex:
            queue_item = self.currently_running.pop(item_id)
            prompt = queue_item.queue_tuple
            # `timestamp()` is an instance method: it must be called on a datetime value,
            # not on the class itself.
            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": dict(outputs),
                "timestamp": datetime.datetime.now().timestamp(),
            }
            self.server.queue_updated()
            # Skip futures that are already done (e.g. cancelled by their awaiter);
            # set_result would raise InvalidStateError on them.
            # NOTE(review): asyncio.Future is not thread-safe; if this runs on a worker thread,
            # consider resolving it via the future's loop with call_soon_threadsafe.
            if queue_item.completed is not None and not queue_item.completed.done():
                queue_item.completed.set_result(outputs)

    def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
        """
        Gets the current state of the queue
        :return: A tuple containing (the currently running items, the items awaiting execution)
        """
        with self.mutex:
            out: typing.List[QueueTuple] = [x.queue_tuple for x in self.currently_running.values()]
            return out, copy.deepcopy([item.queue_tuple for item in self.queue])

    def get_tasks_remaining(self):
        """Return the number of queued plus currently running items."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Drop every pending item, failing their futures with "queue cancelled"."""
        with self.mutex:
            for item in self.queue:
                if item.completed is not None and not item.completed.done():
                    item.completed.set_exception(Exception("queue cancelled"))
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """
        Remove the first pending item for which ``function(item)`` is true.

        :return: True if an item was removed, False otherwise
        """
        with self.mutex:
            for index, candidate in enumerate(self.queue):
                if function(candidate):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        item = self.queue.pop(index)
                        if item.completed is not None and not item.completed.done():
                            item.completed.set_exception(Exception("queue item deleted"))
                        # Popping from the middle breaks the heap invariant.
                        heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None):
        """Return a deep copy of the whole history, or of the entry for ``prompt_id`` (``{}`` if absent)."""
        with self.mutex:
            if prompt_id is None:
                return copy.deepcopy(self.history)
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete: int):
        with self.mutex:
            self.history.pop(id_to_delete, None)