mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
- Move workflows to distinct JSON files. - Add the comfy-org workflows for testing. - Fix issues where workflows from Windows users would not be compatible with backends running on Linux or macOS because of path separator differences. Because this codebase uses get_or_download wherever checkpoints, models, etc. are used, this is the only place where the comparison is gracefully handled for downloading. Validation code will correctly convert backslashes to forward slashes, assuming that wherever these values are used and compared against a list, they are intended to be paths and not strict symbols.
1094 lines
45 KiB
Python
1094 lines
45 KiB
Python
from __future__ import annotations
|
|
|
|
import copy
|
|
import heapq
|
|
import inspect
|
|
import logging
|
|
import sys
|
|
import threading
|
|
import time
|
|
import traceback
|
|
import typing
|
|
from os import PathLike
|
|
from pathlib import PurePath
|
|
from typing import List, Optional, Tuple
|
|
|
|
import lazy_object_proxy
|
|
import torch
|
|
from opentelemetry.trace import get_current_span, StatusCode, Status
|
|
|
|
from .main_pre import tracer
|
|
from .. import interruption
|
|
from .. import model_management
|
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
|
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
|
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
|
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
|
|
HistoryResultDict
|
|
from ..component_model.files import canonicalize_path
|
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
|
from ..execution_context import new_execution_context, ExecutionContext
|
|
from ..nodes.package import import_all_nodes_in_workspace
|
|
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
|
|
|
|
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
# lazy_object_proxy.Proxy defers calling import_all_nodes_in_workspace until `nodes` is first used, so the node
# import work happens on first access (e.g. nodes.NODE_CLASS_MAPPINGS) rather than when this module is imported.
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
|
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
|
from comfy.graph_utils import is_link, GraphBuilder
|
|
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
|
|
|
|
|
class IsChangedCache:
    """Memoizes each node's ``IS_CHANGED`` result for one prompt.

    Results are stored on this object (keyed by node id) and also written onto the node dict
    under ``"is_changed"``, so a value already present on the node is reused as-is.
    """

    def __init__(self, dynprompt, outputs_cache):
        # dynprompt: the prompt whose nodes are looked up via get_node().
        # outputs_cache: passed to get_input_data to resolve linked inputs for IS_CHANGED.
        self.dynprompt = dynprompt
        self.outputs_cache = outputs_cache
        self.is_changed = {}

    def get(self, node_id):
        """Return the ``IS_CHANGED`` value for *node_id*, computing it at most once.

        Returns:
            ``False`` when the node class has no ``IS_CHANGED``; otherwise a list with one
            entry per mapped call (``ExecutionBlocker`` entries become ``None``), or ``NaN``
            when calling ``IS_CHANGED`` raised an exception.
        """
        if node_id in self.is_changed:
            return self.is_changed[node_id]

        node = self.dynprompt.get_node(node_id)
        class_def = nodes.NODE_CLASS_MAPPINGS[node["class_type"]]
        if not hasattr(class_def, "IS_CHANGED"):
            self.is_changed[node_id] = False
            return self.is_changed[node_id]

        if "is_changed" in node:
            self.is_changed[node_id] = node["is_changed"]
            return self.is_changed[node_id]

        input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
        try:
            is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
            node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
        except Exception as exc:
            # A failing IS_CHANGED must not abort the prompt, but it should be visible in the logs.
            # NaN is presumably chosen because it never compares equal to itself, so the node is
            # treated as changed -- confirm against the cache-key logic.
            # Only Exception is caught: KeyboardInterrupt/SystemExit propagate untouched.
            logging.warning("IS_CHANGED failed for node %s: %s", node_id, exc)
            node["is_changed"] = float("NaN")
        self.is_changed[node_id] = node["is_changed"]
        return self.is_changed[node_id]
|
|
|
|
|
|
class CacheSet:
    """Holds the executor's three caches: node outputs, UI payloads and node instances."""

    def __init__(self, lru_size=None):
        # A missing or zero size selects the classic hierarchical caches; anything else is an LRU bound.
        if lru_size is not None and lru_size != 0:
            self.init_lru_cache(lru_size)
        else:
            self.init_classic_cache()
        self.all = [self.outputs, self.ui, self.objects]

    # Useful for those with ample RAM/VRAM -- allows experimenting without
    # blowing away the cache every time
    def init_lru_cache(self, cache_size):
        """Bound the output and UI caches to *cache_size* entries; node objects stay hierarchical."""
        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
        self.objects = HierarchicalCache(CacheKeySetID)

    # Performs like the old cache -- dump data ASAP
    def init_classic_cache(self):
        """Use hierarchical caches for outputs, UI data and node objects."""
        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
        self.ui = HierarchicalCache(CacheKeySetInputSignature)
        self.objects = HierarchicalCache(CacheKeySetID)

    def recursive_debug_dump(self):
        """Return the debug dumps of the output and UI caches (the object cache is not included)."""
        return {
            "outputs": self.outputs.recursive_debug_dump(),
            "ui": self.ui.recursive_debug_dump(),
        }
|
|
|
|
|
|
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
    """Resolve a node's prompt inputs into the per-input value lists passed to node functions.

    Args:
        inputs: the node's ``inputs`` mapping; each value is a constant or a link.
        class_def: the node class; ``INPUT_TYPES()`` supplies its hidden-input declarations.
        unique_id: the node id, injected for ``UNIQUE_ID`` hidden inputs.
        outputs: already computed outputs, queried with ``outputs.get(node_id)``. ``None`` is
            treated as empty, so every link is then reported as missing.
        dynprompt: the dynamic prompt, injected for ``PROMPT`` / ``DYNPROMPT`` hidden inputs.
        extra_data: extra request data; its ``extra_pnginfo`` entry feeds ``EXTRA_PNGINFO``.

    Returns:
        ``(input_data_all, missing_keys)`` -- input names mapped to value lists, and a dict whose
        keys are the linked inputs whose upstream output is not available (those inputs are
        set to ``(None,)``).
    """
    if extra_data is None:
        extra_data = {}
    if outputs is None:
        outputs = {}
    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    missing_keys = {}
    for x, input_data in inputs.items():
        _, input_category, input_info = get_input_info(class_def, x)
        if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
            input_unique_id, output_index = input_data[0], input_data[1]
            cached_output = outputs.get(input_unique_id)
            if cached_output is None or output_index >= len(cached_output):
                # The upstream output is not available (yet). This might be a lazily-evaluated
                # input, so report it as missing instead of failing here.
                missing_keys[x] = True
                input_data_all[x] = (None,)
                continue
            input_data_all[x] = cached_output[output_index]
        elif input_category is not None:
            input_data_all[x] = [input_data]

    # todo: this should be retrieved from the execution context
    if "hidden" in valid_inputs:
        hidden = valid_inputs["hidden"]
        for x, hidden_kind in hidden.items():
            if hidden_kind == "PROMPT":
                input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
            elif hidden_kind == "DYNPROMPT":
                input_data_all[x] = [dynprompt]
            elif hidden_kind == "EXTRA_PNGINFO":
                input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
            elif hidden_kind == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all, missing_keys
|
|
|
|
|
|
def _record_input_attributes(span, input_data_all):
    """Attach a primitive-only preview of each node input to *span* (best effort, never raises)."""
    for kwarg_name, kwarg_value in input_data_all.items():
        attribute_name = f"input_data_all.{kwarg_name}"
        if isinstance(kwarg_value, (str, bool, int, float)):
            span.set_attribute(attribute_name, kwarg_value)
            continue
        try:
            items_to_display = []
            if hasattr(kwarg_value, "shape"):
                # if the object has a shape attribute (likely a NumPy array or similar), get up to the first ten elements
                flat_values = kwarg_value.flatten() if hasattr(kwarg_value, "flatten") else kwarg_value
                items_to_display = [flat_values[i] for i in range(min(10, flat_values.size))]
            elif hasattr(kwarg_value, "__getitem__") and hasattr(kwarg_value, "__len__"):
                # If the object is indexable and has a length, get the first ten items
                items_to_display = [kwarg_value[i] for i in range(min(10, len(kwarg_value)))]

            filtered_items = [item for item in items_to_display if isinstance(item, (str, bool, int, float))]
            if filtered_items:
                span.set_attribute(attribute_name, filtered_items)
        except Exception:
            # Telemetry must never abort node execution. Examples: a mapping whose __getitem__
            # rejects integer keys, or a value whose ``size`` is a method rather than an int.
            logging.debug("could not record span attribute %s", attribute_name, exc_info=True)


@tracer.start_as_current_span("Execute Node")
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
    """Call ``getattr(obj, func)`` once per element of the node's input lists.

    Unless ``obj.INPUT_IS_LIST`` is truthy, the input lists are sliced index by index, and a
    shorter list repeats its last element, so the method runs once per element of the longest
    list. When there are no inputs (or every list is empty) it runs once with no arguments.

    Args:
        obj: a node instance or a node class.
        input_data_all: input name -> list of values; ``None`` is treated as no inputs.
        func: name of the method to call on *obj*.
        allow_interrupt: check for a pending interruption before each call.
        execution_block_cb: maps an ``ExecutionBlocker`` found among a call's inputs to the
            value recorded in place of calling *func*.
        pre_execute_cb: called with the call index before each indexed call.

    Returns:
        One entry per call: the method's return value, or the blocker recorded instead.
    """
    if input_data_all is None:
        input_data_all = {}

    span = get_current_span()
    # Node classes are passed too (IS_CHANGED, VALIDATE_INPUTS). For a class, obj.__class__ is
    # its metaclass (normally ``type``), which would mislabel the span.
    node_class = obj if isinstance(obj, type) else obj.__class__
    span.set_attribute("class_type", node_class.__name__)
    _record_input_attributes(span, input_data_all)

    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
    max_len_input = max((len(values) for values in input_data_all.values()), default=0)

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

    results = []

    def process_inputs(inputs, index=None):
        if allow_interrupt:
            interruption.throw_exception_if_processing_interrupted()
        execution_block = None
        for v in inputs.values():
            if isinstance(v, ExecutionBlocker):
                execution_block = execution_block_cb(v) if execution_block_cb else v
                break
        if execution_block is None:
            if pre_execute_cb is not None and index is not None:
                pre_execute_cb(index)
            results.append(getattr(obj, func)(**inputs))
        else:
            results.append(execution_block)

    if input_is_list:
        process_inputs(input_data_all, 0)
    elif max_len_input == 0:
        process_inputs({})
    else:
        for i in range(max_len_input):
            process_inputs(slice_dict(input_data_all, i), i)
    return results
|
|
|
|
|
|
def merge_result_data(results, obj):
    """Merge a node's per-call results into one list per output slot.

    *results* must be non-empty; each entry is one call's output tuple. Slots flagged in
    ``obj.OUTPUT_IS_LIST`` concatenate the per-call lists; the other slots collect one value per
    call. Only as many slots as ``OUTPUT_IS_LIST`` declares are returned.
    """
    slot_count = len(results[0])
    list_flags = getattr(obj, "OUTPUT_IS_LIST", [False] * slot_count)

    merged = []
    for slot, concatenate in zip(range(slot_count), list_flags):
        if concatenate:
            merged.append([item for call_result in results for item in call_result[slot]])
        else:
            merged.append([call_result[slot] for call_result in results])
    return merged
|
|
|
|
|
|
def _broadcast_execution_blocker(value, obj):
    """Replace a bare ExecutionBlocker with one blocker per declared output slot of *obj*."""
    if isinstance(value, ExecutionBlocker):
        return tuple([value] * len(obj.RETURN_TYPES))
    return value


def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
    """Run a node's ``FUNCTION`` over its inputs and split the return values.

    Returns:
        ``(output, ui, has_subgraph)``. If any call returned an ``expand`` graph, *output* is a
        list of ``(new_graph_or_None, result)`` pairs, one per call that produced a result or
        expansion. Otherwise it holds the merged per-slot lists from ``merge_result_data``
        (``[]`` when no call produced a result). *ui* maps every UI key seen in any call to the
        concatenation of that key's values across calls.
    """
    results = []
    uis = []
    subgraph_results = []
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
    has_subgraph = False
    for r in return_values:
        if isinstance(r, dict):
            if 'ui' in r:
                uis.append(r['ui'])
            if 'expand' in r:
                # Perform an expansion, but do not append results
                has_subgraph = True
                result = _broadcast_execution_blocker(r.get("result", None), obj)
                subgraph_results.append((r['expand'], result))
            elif 'result' in r:
                result = _broadcast_execution_blocker(r.get("result", None), obj)
                results.append(result)
                subgraph_results.append((None, result))
        else:
            r = _broadcast_execution_blocker(r, obj)
            results.append(r)
            subgraph_results.append((None, r))

    if has_subgraph:
        output = subgraph_results
    elif len(results) > 0:
        output = merge_result_data(results, obj)
    else:
        output = []

    ui = dict()
    if len(uis) > 0:
        # Gather keys from every call's UI dict (first-appearance order). A call that lacks a
        # key contributes nothing for it, so heterogeneous UI dicts from list-mapped calls
        # neither raise KeyError nor lose keys that only later calls produced.
        keys = list(dict.fromkeys(k for u in uis for k in u.keys()))
        ui = {k: [y for u in uis for y in u.get(k, [])] for k in keys}
    return output, ui, has_subgraph
|
|
|
|
|
|
def format_value(x) -> FormattedValue:
    """Return *x* itself when it is ``None`` or a JSON-friendly scalar, otherwise ``str(x)``."""
    passthrough = x is None or isinstance(x, (int, float, bool, str))
    return x if passthrough else str(x)
|
|
|
|
|
|
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
    """Execute one node of *dynprompt* and report the outcome.

    Returns a ``(result, error_details, exception)`` triple whose *result* is an ExecutionResult:

    * ``SUCCESS`` -- the node's outputs were already cached (the cached UI output, if any, is
      re-sent to the client), or the node ran and its outputs were stored in ``caches.outputs``.
    * ``PENDING`` -- the node must be revisited later: ``check_lazy_status`` asked for inputs
      that are not available yet (strong links are added to *execution_list*), or the node
      expanded into a subgraph whose results were parked in *pending_subgraph_results*.
    * ``FAILURE`` -- an exception was raised; *error_details* describes it and *exception* is
      the exception itself (interruptions carry only ``node_id``).

    The node id is added to *executed* only when the function reaches its final SUCCESS return
    (not on the early return for already-cached outputs).
    """
    unique_id = current_item
    real_node_id = dynprompt.get_real_node_id(unique_id)
    display_node_id = dynprompt.get_display_node_id(unique_id)
    parent_node_id = dynprompt.get_parent_node_id(unique_id)
    inputs = dynprompt.get_node(unique_id)['inputs']
    class_type = dynprompt.get_node(unique_id)['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if caches.outputs.get(unique_id) is not None:
        # Already computed: only replay the cached UI output to the client.
        if server.client_id is not None:
            cached_output = caches.ui.get(unique_id) or {}
            server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id)
        # NOTE(review): the other return sites use bare tuples; callers unpack positionally, so both forms work.
        return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)

    # Stays None on the pending-subgraph path, so the error handler below knows there is nothing to format.
    input_data_all = None
    try:
        if unique_id in pending_subgraph_results:
            # Revisit after an expansion: replace links in each parked result with the linked node's
            # cached output values, then merge them as if the node had produced them directly.
            cached_results = pending_subgraph_results[unique_id]
            resolved_outputs = []
            for is_subgraph, result in cached_results:
                if not is_subgraph:
                    resolved_outputs.append(result)
                else:
                    resolved_output = []
                    for r in result:
                        if is_link(r):
                            source_node, source_output = r[0], r[1]
                            node_output = caches.outputs.get(source_node)[source_output]
                            for o in node_output:
                                resolved_output.append(o)

                        else:
                            resolved_output.append(r)
                    resolved_outputs.append(tuple(resolved_output))
            output_data = merge_result_data(resolved_outputs, class_def)
            output_ui = []
            has_subgraph = False
        else:
            input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
            if server.client_id is not None:
                server.last_node_id = display_node_id
                server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)

            # Node instances are cached so the same object is reused across executions.
            obj = caches.objects.get(unique_id)
            if obj is None:
                obj = class_def()
                caches.objects.set(unique_id, obj)

            if hasattr(obj, "check_lazy_status"):
                # Each call may return a list of input names it needs; keep the string names whose
                # value is not available yet and make those inputs hard dependencies.
                required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
                required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
                required_inputs = [x for x in required_inputs if isinstance(x, str) and (
                    x not in input_data_all or x in missing_keys
                )]
                if len(required_inputs) > 0:
                    for i in required_inputs:
                        execution_list.make_input_strong_link(unique_id, i)
                    return (ExecutionResult.PENDING, None, None)

            def execution_block_cb(block):
                # A blocker with a message is reported to the client as an execution error and then
                # replaced by a silent blocker; a message-less blocker is passed through unchanged.
                if block.message is not None:
                    mes = {
                        "prompt_id": prompt_id,
                        "node_id": unique_id,
                        "node_type": class_type,
                        "executed": list(executed),

                        "exception_message": f"Execution Blocked: {block.message}",
                        "exception_type": "ExecutionBlocked",
                        "traceback": [],
                        "current_inputs": [],
                        "current_outputs": [],
                    }
                    server.send_sync("execution_error", mes, server.client_id)
                    return ExecutionBlocker(None)
                else:
                    return block

            def pre_execute_cb(call_index):
                # Give GraphBuilder a per-node, per-call prefix before each call.
                GraphBuilder.set_default_prefix(unique_id, call_index, 0)

            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
        if len(output_ui) > 0:
            caches.ui.set(unique_id, {
                "meta": {
                    "node_id": unique_id,
                    "display_node": display_node_id,
                    "parent_node": parent_node_id,
                    "real_node_id": real_node_id,
                },
                "output": output_ui
            })
            if server.client_id is not None:
                server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id},
                                 server.client_id)
        if has_subgraph:
            # Register the expanded nodes as ephemeral nodes, schedule any output nodes among them, and
            # park this node's results until the links they contain can be resolved.
            cached_outputs = []
            new_node_ids = []
            new_output_ids = []
            new_output_links = []
            for i in range(len(output_data)):
                new_graph, node_outputs = output_data[i]
                if new_graph is None:
                    cached_outputs.append((False, node_outputs))
                else:
                    # Check for conflicts
                    for node_id in new_graph.keys():
                        if dynprompt.has_node(node_id):
                            raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
                    for node_id, node_info in new_graph.items():
                        new_node_ids.append(node_id)
                        display_id = node_info.get("override_display_id", unique_id)
                        dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
                        # Figure out if the newly created node is an output node
                        # NOTE(review): this rebinds the outer class_type/class_def; harmless here because
                        # execution_block_cb (which reads class_type) has already run, but easy to trip over.
                        class_type = node_info["class_type"]
                        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
                        if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                            new_output_ids.append(node_id)
                    # NOTE(review): reuses `i` from the enclosing loop; safe because the outer for-loop
                    # rebinds it from its range iterator on each iteration.
                    for i in range(len(node_outputs)):
                        if is_link(node_outputs[i]):
                            from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
                            new_output_links.append((from_node_id, from_socket))
                    cached_outputs.append((True, node_outputs))
            new_node_ids = set(new_node_ids)
            for cache in caches.all:
                cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
            for node_id in new_output_ids:
                execution_list.add_node(node_id)
            for link in new_output_links:
                execution_list.add_strong_link(link[0], link[1], unique_id)
            pending_subgraph_results[unique_id] = cached_outputs
            return (ExecutionResult.PENDING, None, None)
        caches.outputs.set(unique_id, output_data)
    except interruption.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details: RecursiveExecutionErrorDetailsInterrupted = {
            "node_id": real_node_id,
        }

        return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            input_data_formatted = {}
            for name, inputs in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in inputs]

        # NOTE(review): exc_info=ex already logs the traceback; the format_exc() line logs it a second time.
        logging.error("An error occurred while executing a workflow", exc_info=ex)
        logging.error(traceback.format_exc())

        error_details: RecursiveExecutionErrorDetails = {
            "node_id": real_node_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted
        }
        return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)

    executed.add(unique_id)

    return ExecutionResult.SUCCESS, None, None
|
|
|
|
|
|
class PromptExecutor:
    """Runs prompts node by node and streams status messages to *server*.

    The caches are created by :meth:`reset` (called from ``__init__``) and reused by every
    subsequent :meth:`execute` call. ``success``, ``status_messages`` and ``history_result``
    describe the most recent run.
    """

    def __init__(self, server: ExecutorToClientProgress, lru_size=None):
        self.success = None
        self.lru_size = lru_size
        self.server = server
        # When True, handle_execution_error re-raises the node's exception after reporting it.
        self.raise_exceptions = False
        self.reset()
        self.history_result: HistoryResultDict | None = None

    def reset(self):
        """Recreate the caches, clear the status messages and mark the executor successful."""
        self.success = True
        self.caches = CacheSet(self.lru_size)
        self.status_messages = []

    def add_message(self, event, data, broadcast: bool):
        """Timestamp *data*, append it to ``status_messages`` and forward it to the client.

        The message is sent when a client id is set or when *broadcast* is true; either way it
        is addressed to ``server.client_id`` (which may be ``None``).
        """
        data = {
            **data,
            # todo: use a real time library
            "timestamp": int(time.time() * 1000),
        }
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        """Mark the run as failed and report *error* (or an interruption) to the client.

        Args:
            prompt_id: id of the running prompt.
            prompt: the original prompt, used to look up the failing node's class type.
            current_outputs: node ids whose outputs were cached when the run started.
            executed: node ids executed so far in this run.
            error: error details; always has ``node_id``, and for non-interruptions also the
                ``exception_message``/``exception_type``/``traceback``/``current_inputs`` fields.
            ex: the exception that caused the failure.

        Raises:
            ex, when ``raise_exceptions`` is set.
        """
        # Record the failure first so ``success`` is accurate even if reporting below raises.
        self.success = False
        current_span = get_current_span()
        current_span.set_status(Status(StatusCode.ERROR))
        current_span.record_exception(ex)

        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, interruption.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": list(current_outputs),
            }
            self.add_message("execution_error", mes, broadcast=False)

        if ex is not None and self.raise_exceptions:
            raise ex

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: Optional[List[str]] = None):
        """Execute *prompt* inside a new execution context bound to this executor's server."""
        with new_execution_context(ExecutionContext(self.server)):
            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)

    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: Optional[List[str]] = None):
        """Execute the nodes needed for *execute_outputs*, stopping at the first failure."""
        if execute_outputs is None:
            execute_outputs = []
        if extra_data is None:
            extra_data = {}
        interruption.interrupt_current_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        # Each run starts out successful; handle_execution_error flips this on failure.
        self.success = True
        self.status_messages = []
        self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode():
            dynamic_prompt = DynamicPrompt(prompt)
            is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
            for cache in self.caches.all:
                cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
                cache.clean_unused()

            current_outputs = self.caches.outputs.all_node_ids()

            model_management.cleanup_models(keep_clone_weights_loaded=True)
            self.add_message("execution_cached",
                             {"nodes": list(current_outputs), "prompt_id": prompt_id},
                             broadcast=False)
            pending_subgraph_results = {}
            executed = set()
            execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
            for node_id in list(execute_outputs):
                execution_list.add_node(node_id)

            while not execution_list.is_empty():
                node_id, error, ex = execution_list.stage_node_execution()
                if error is not None:
                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                    break

                result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
                if result == ExecutionResult.FAILURE:
                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                    break
                elif result == ExecutionResult.PENDING:
                    execution_list.unstage_node_execution()
                else:  # result == ExecutionResult.SUCCESS:
                    execution_list.complete_node_execution()

            ui_outputs = {}
            meta_outputs = {}
            all_node_ids = self.caches.ui.all_node_ids()
            for node_id in all_node_ids:
                ui_info = self.caches.ui.get(node_id)
                if ui_info is not None:
                    ui_outputs[node_id] = ui_info["output"]
                    meta_outputs[node_id] = ui_info["meta"]
            self.history_result = {
                "outputs": ui_outputs,
                "meta": meta_outputs,
            }
            self.server.last_node_id = None
            if model_management.DISABLE_SMART_MEMORY:
                model_management.unload_all_models()

    @property
    def outputs_ui(self) -> dict | None:
        """UI outputs of the most recent run, or ``None`` before any run has finished."""
        return self.history_result["outputs"] if self.history_result is not None else None
|
|
|
|
|
|
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
    """Validate one node's inputs, recursively validating every node it links to.

    Args:
        prompt: mapping of node id -> ``{"class_type": ..., "inputs": ...}``. Constant INT, FLOAT,
            STRING and BOOLEAN inputs are converted in place.
        item: id of the node to validate.
        validated: memo of results already computed. This node's result is stored here, and so is
            a failure for a linked node whose validation raised.

    Returns:
        ``ValidateInputsTuple(valid, errors, unique_id)``.
    """
    # todo: this should check if LoadImage / LoadImageMask paths exist
    # todo: or, nodes should provide a way to validate their values
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))

    error: ValidationErrorDict
    errors = []
    valid = True

    # todo: investigate if these are at the right indent level
    info: Optional[InputTypeSpec] = None
    val = None

    # Inputs named in VALIDATE_INPUTS' signature skip the built-in min/max/list checks below.
    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
    received_types = {}

    for x in valid_inputs:
        type_input, input_category, extra_info = get_input_info(obj_class, x)
        assert extra_info is not None
        if x not in inputs:
            if input_category == "required":
                error = {
                    "type": "required_input_missing",
                    "message": "Required input is missing",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x
                    }
                }
                errors.append(error)
            continue

        val = inputs[x]
        info = (type_input, extra_info)
        if isinstance(val, list):
            # A linked input: [source_node_id, source_output_index].
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            received_type = r[val[1]]
            received_types[x] = received_type
            any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
            if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r2 = validate_inputs(prompt, o_id, validated)
                if r2[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
                continue
        else:
            # A constant input: coerce primitive types and write the converted value back.
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
                if type_input == "BOOLEAN":
                    val = bool(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if x not in validate_function_inputs:
                has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
                if "min" in has_min_max and val < has_min_max["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in has_min_max and val > has_min_max["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

                if isinstance(type_input, list):
                    # Workflows saved on Windows may carry backslash-separated paths; normalize them
                    # for this comparison only (inputs[x] keeps the original value). Combo options
                    # are not always strings (ints, None), and `"\\" in val` raises TypeError for
                    # those, so only strings are normalized.
                    if isinstance(val, str) and "\\" in val:
                        val = canonicalize_path(val)
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {name: values for name, values in input_data_all.items() if name in validate_function_inputs}
        if 'input_types' in validate_function_inputs:
            input_filtered['input_types'] = [received_types]

        # VALIDATE_INPUTS is mapped over the inputs like a regular node function.
        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        # Every failing call is reported once per validated input name.
        for x in input_filtered:
            for r in ret:
                if r is not True and not isinstance(r, ExecutionBlocker):
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                        }
                    }
                    errors.append(error)

    if len(errors) > 0 or valid is not True:
        ret = ValidateInputsTuple(False, errors, unique_id)
    else:
        ret = ValidateInputsTuple(True, [], unique_id)

    validated[unique_id] = ret
    return ret
|
|
|
|
|
|
def full_type_name(klass):
    """Return the dotted ``module.QualName`` of *klass*, or just ``QualName`` for builtins."""
    module = klass.__module__
    qualified = klass.__qualname__
    return qualified if module == 'builtins' else f"{module}.{qualified}"
|
|
|
|
|
|
@tracer.start_as_current_span("Validate Prompt")
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
    """Validate *prompt*; when it is invalid, record the errors on the current tracing span.

    Only attribute values of simple types are copied to the span: top-level error fields of
    bool/str/bytes/int/float/list type, ``extra_info`` fields of str/list type, and per-node
    error fields of str/bool/int/float type.
    """
    res = _validate_prompt(prompt)
    if res.valid:
        return res

    span = get_current_span()
    span.set_status(Status(StatusCode.ERROR))

    top_error = res.error
    if top_error is not None and len(top_error) > 0:
        span.set_attributes({
            f"error.{key}": value
            for key, value in top_error.items()
            if isinstance(value, (bool, str, bytes, int, float, list))
        })
        if "extra_info" in top_error and isinstance(top_error["extra_info"], dict):
            extra_info: ValidationErrorExtraInfoDict = top_error["extra_info"]
            span.set_attributes({
                f"error.extra_info.{key}": value
                for key, value in extra_info.items()
                if isinstance(value, (str, list))
            })

    if len(res.node_errors) > 0:
        for failing_node, node_error in res.node_errors.items():
            for field, field_value in node_error.items():
                if isinstance(field_value, (str, bool, int, float)):
                    span.set_attribute(f"node_errors.{failing_node}.{field}", field_value)
    return res
|
|
|
|
|
|
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
    """
    Validates a prompt (node graph); the untraced implementation behind ``validate_prompt``.

    Validation runs in two phases:

    1. Every node must declare a ``class_type`` that is registered in ``nodes.NODE_CLASS_MAPPINGS``, and at least
       one node must be an output node (its class has ``OUTPUT_NODE is True``). The first failure here is returned
       immediately.
    2. Each output node is validated with ``validate_inputs``. Outputs that fail are logged and excluded, and the
       per-node failure reasons are collected into ``node_errors``.

    :param prompt: mapping of node id to node dictionary; each node is expected to contain ``class_type``
    :return: ``ValidationTuple(valid, error, good_outputs, node_errors)``. ``valid`` is ``True`` when at least one
        output validated; otherwise ``error`` describes the failure. ``node_errors`` is always a dict keyed by node
        id (empty when a phase-1 check fails).
    """
    # Phase 1: every node must name a known class; remember which nodes are outputs.
    outputs = set()
    for node_id, node in prompt.items():
        if 'class_type' not in node:
            error = {
                "type": "invalid_prompt",
                "message": "Cannot execute because a node is missing the class_type property.",
                "details": f"Node ID '#{node_id}'",
                "extra_info": {}
            }
            # node_errors is a dict, matching the declared type and the phase-2 return paths
            return ValidationTuple(False, error, [], {})

        class_type = node['class_type']
        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
        if class_ is None:
            error = {
                "type": "invalid_prompt",
                "message": f"Cannot execute because node {class_type} does not exist.",
                "details": f"Node ID '#{node_id}'",
                "extra_info": {}
            }
            return ValidationTuple(False, error, [], {})

        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
            outputs.add(node_id)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return ValidationTuple(False, error, [], {})

    # Phase 2: validate each output node together with its inputs.
    good_outputs = set()
    # (output id, reasons) for every output that failed
    errors: List[typing.Tuple[str, List[ValidationErrorDict]]] = []
    node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
    # passed to validate_inputs, which records a result per node id; shared across all outputs
    validated: typing.Dict[str, ValidateInputsTuple] = {}

    for o in outputs:
        try:
            result = validate_inputs(prompt, o, validated)
            valid = result[0]
            reasons: List[ValidationErrorDict] = result[1]
        except Exception as ex:
            valid = False
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": full_type_name(type(ex)),
                    "traceback": traceback.format_tb(ex.__traceback__)
                }
            }]
            validated[o] = ValidateInputsTuple(False, reasons, o)

        if valid is True:
            good_outputs.add(o)
            continue

        logging.error(f"Failed to validate prompt for output {o}:")
        if len(reasons) > 0:
            logging.error("* (prompt):")
            for reason in reasons:
                logging.error(f" - {reason['message']}: {reason['details']}")
        errors.append((o, reasons))

        for node_id, node_result in validated.items():
            node_valid = node_result[0]
            node_reasons = node_result[1]
            # If a node upstream has errors, the nodes downstream will also
            # be reported as invalid, but there will be no errors attached.
            # So don't return those nodes as having errors in the response.
            if node_valid is True or len(node_reasons) == 0:
                continue

            if node_id not in node_errors:
                node_class_type = prompt[node_id]['class_type']
                node_errors[node_id] = {
                    "errors": node_reasons,
                    "dependent_outputs": [],
                    "class_type": node_class_type
                }
                logging.error(f"* {node_class_type} {node_id}:")
                for reason in node_reasons:
                    logging.error(f" - {reason['message']}: {reason['details']}")
            # NOTE(review): `validated` is shared across outputs, so an invalid node found while validating an
            # earlier output is also credited to this output even if it is not upstream of it — confirm intent.
            node_errors[node_id]["dependent_outputs"].append(o)
        logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        details = "\n".join(
            f"{reason['message']}: {reason['details']}"
            for _, output_reasons in errors
            for reason in output_reasons
        )
        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": details,
            "extra_info": {}
        }
        return ValidationTuple(False, error, list(good_outputs), node_errors)

    return ValidationTuple(True, None, list(good_outputs), node_errors)
|
|
|
|
|
|
class PromptQueue(AbstractPromptQueue):
    """
    Thread-safe, in-process queue of prompts waiting to execute, plus a bounded history of finished prompts.

    Pending items are kept in a binary heap (``heapq``), so ``get`` returns the smallest ``QueueItem`` according to
    the item's own ordering. All state is guarded by one re-entrant lock. ``not_empty`` is a condition variable on
    that same lock; ``put`` and ``set_flag`` notify it to wake a blocked ``get``.
    """

    def __init__(self, server: ExecutorToClientProgress):
        # server.queue_updated() is called whenever pending or running items change
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        # heap of pending items
        self.queue: typing.List[QueueItem] = []
        # items handed out by get() and not yet passed to task_done(), keyed by prompt id
        self.currently_running: typing.Dict[str, QueueItem] = {}
        # history maps the prompt id (the second element of the queue tuple) to a dictionary with the keys
        # "prompt", "outputs" and "status", plus any additional keys passed to task_done
        # todo: use the new History class for the sake of simplicity
        self.history: typing.Dict[str, HistoryEntry] = {}
        # named flags stored by set_flag() and read/drained by get_flags()
        self.flags = {}

    def size(self) -> int:
        """Returns the number of pending (not yet running) items."""
        return len(self.queue)

    def put(self, item: QueueItem):
        """Adds ``item`` to the pending heap, notifies the server and wakes one waiting ``get``."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
        """
        Takes the next pending item and marks it as running.

        :param timeout: ``None`` to block until an item is available. Otherwise, the number of seconds to wait per
            wakeup; if the queue is still empty after a wakeup (the timeout elapsed, or ``set_flag`` notified the
            condition), ``None`` is returned.
        :return: a deep copy of the item's queue tuple and its prompt id, or ``None`` as described above
        """
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item_with_future: QueueItem = heapq.heappop(self.queue)
            # internal invariants: every queued item carries a unique, non-empty string prompt id
            assert item_with_future.prompt_id is not None
            assert item_with_future.prompt_id != ""
            assert item_with_future.prompt_id not in self.currently_running
            assert isinstance(item_with_future.prompt_id, str)
            task_id = item_with_future.prompt_id
            self.currently_running[task_id] = item_with_future
            self.server.queue_updated()
            # the caller gets a deep copy, so mutating it does not affect the tuple held in currently_running
            return copy.deepcopy(item_with_future.queue_tuple), task_id

    def task_done(self, item_id: str, outputs: dict,
                  status: Optional[ExecutionStatus]):
        """
        Marks a running prompt as finished, records it in the history and resolves its future.

        :param item_id: the prompt id returned by ``get``
        :param outputs: the history result for the prompt. It must contain an ``"outputs"`` key, and all of its keys
            are merged into the history entry.
        :param status: the execution status, or ``None``
        :raises KeyError: if ``item_id`` is not currently running
        """
        with self.mutex:
            queue_item = self.currently_running.pop(item_id)
            prompt = queue_item.queue_tuple
            # evict the oldest entry; dicts iterate in insertion order
            if len(self.history) > MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())

            outputs_ = outputs["outputs"]
            entry = {
                "prompt": prompt,
                "outputs": outputs_,
                'status': status_dict,
            }
            entry.update(outputs)
            # update() above stores the caller's outputs object. Replace it with a deep copy afterwards, so the
            # history does not share state with the result handed to the future below.
            entry["outputs"] = copy.deepcopy(outputs_)
            self.history[prompt[1]] = entry
            self.server.queue_updated()
            if queue_item.completed:
                queue_item.completed.set_result(outputs_)

    def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
        """
        Returns ``(running, pending)`` queue tuples.

        Pending tuples are deep copies, listed in heap order (not fully priority-sorted). Running tuples are the
        stored objects themselves.
        """
        with self.mutex:
            running: typing.List[QueueTuple] = [item.queue_tuple for item in self.currently_running.values()]
            return running, copy.deepcopy([item.queue_tuple for item in self.queue])

    def get_tasks_remaining(self):
        """Returns the number of pending plus running items."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """
        Drops every pending item and fails each item's future with ``Exception("queue cancelled")``.

        Running items are not affected.
        """
        with self.mutex:
            for item in self.queue:
                if item.completed:
                    item.completed.set_exception(Exception("queue cancelled"))
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """
        Removes the first pending item whose queue tuple satisfies ``function``.

        The removed item's future, if any, fails with ``Exception("queue item deleted")``. This is the same for every
        queue size; previously, deleting the only item went through ``wipe_queue`` and reported "queue cancelled".

        :param function: predicate called with each pending item's queue tuple
        :return: ``True`` if an item was removed, ``False`` otherwise
        """
        with self.mutex:
            for index, item in enumerate(self.queue):
                if function(item.queue_tuple):
                    self.queue.pop(index)
                    # popping from the middle breaks the heap invariant; restore it before anything can observe it
                    heapq.heapify(self.queue)
                    if item.completed:
                        item.completed.set_exception(Exception("queue item deleted"))
                    self.server.queue_updated()
                    return True
            return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
        """
        Returns history entries keyed by prompt id.

        :param prompt_id: when given, returns ``{prompt_id: entry}`` with a deep copy of the entry, or ``{}`` if the
            id is unknown
        :param max_items: maximum number of entries to return; ``None`` for no limit, ``<= 0`` for none
        :param offset: index (oldest first) of the first entry to return. A negative offset combined with
            ``max_items`` selects the newest ``max_items`` entries.
        :return: dict of prompt id to entry. When ``prompt_id`` is ``None``, the entries are the stored objects, not
            copies.
        """
        with self.mutex:
            if prompt_id is None:
                out = {}
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for i, k in enumerate(self.history):
                    # checked before inserting, so a non-positive max_items returns nothing
                    if max_items is not None and len(out) >= max_items:
                        break
                    if i >= offset:
                        out[k] = self.history[k]
                return out
            if prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            return {}

    def wipe_history(self):
        """Removes every history entry."""
        with self.mutex:
            self.history.clear()

    def delete_history_item(self, id_to_delete: str):
        """Removes the history entry for ``id_to_delete``, if present."""
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        """Stores ``data`` under ``name`` and notifies one waiter on ``not_empty``."""
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        """
        Returns the current flags.

        :param reset: when ``True``, returns the flags dict itself and replaces it with a new empty one. Otherwise,
            returns a shallow copy and leaves the flags in place.
        """
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()
|