Update with our changes

doctorpangloss 2024-08-16 14:31:26 -07:00
parent f04b582744
commit 24a9eb2600
4 changed files with 288 additions and 153 deletions

View File

@@ -1,10 +1,10 @@
import itertools
from typing import Sequence, Mapping
from comfy_execution.graph import DynamicPrompt
import nodes
from .cmd.execution import nodes
from .graph import DynamicPrompt
from .graph_utils import is_link
from comfy_execution.graph_utils import is_link
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
@@ -29,10 +29,12 @@ class CacheKeySet:
def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
@@ -46,6 +48,7 @@ def to_hashable(obj):
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
@@ -60,6 +63,7 @@ class CacheKeySetID(CacheKeySet):
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
@@ -98,7 +102,7 @@ class CacheKeySetInputSignature(CacheKeySet):
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
@@ -122,6 +126,7 @@ class CacheKeySetInputSignature(CacheKeySet):
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
@@ -207,6 +212,7 @@ class BasicCache:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
@@ -245,6 +251,7 @@ class HierarchicalCache(BasicCache):
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
@@ -296,4 +303,3 @@ class LRUCache(BasicCache):
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self

View File

@@ -1,30 +1,44 @@
import sys
from __future__ import annotations
import copy
import logging
import threading
import heapq
import inspect
import logging
import sys
import threading
import time
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
import typing
from os import PathLike
from typing import List, Optional, Tuple
import lazy_object_proxy
import torch
import nodes
from opentelemetry.trace import get_current_span, StatusCode, Status
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
HistoryResultDict
from ..component_model.files import canonicalize_path
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
PENDING = 2
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
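# lazy_object_proxy.Proxy wraps a zero-argument factory and only calls it the first time the
# proxy is actually used, so import_all_nodes_in_workspace() runs on the first access to
# `nodes` (for example nodes.NODE_CLASS_MAPPINGS below) rather than at import time.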
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
# order matters
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
class DuplicateNodeError(Exception):
pass
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
@@ -49,19 +63,19 @@ class IsChangedCache:
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e:
logging.warning("WARNING: {}".format(e))
except:
node["is_changed"] = float("NaN")
finally:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
@@ -86,22 +100,29 @@ class CacheSet:
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
if extra_data is None:
extra_data = {}
if outputs is None:
outputs = {}
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
missing_keys = {}
for x in inputs:
input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x)
def mark_missing():
missing_keys[x] = True
input_data_all[x] = (None,)
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if outputs is None:
mark_missing()
continue # This might be a lazily-evaluated input
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
mark_missing()
@@ -114,6 +135,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None:
input_data_all[x] = [input_data]
# todo: this should be retrieved from the execution context
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
@@ -127,9 +149,35 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [unique_id]
return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
@tracer.start_as_current_span("Execute Node")
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
span = get_current_span()
class_type = obj.__class__.__name__
span.set_attribute("class_type", class_type)
if input_data_all is not None:
for kwarg_name, kwarg_value in input_data_all.items():
if isinstance(kwarg_value, str) or isinstance(kwarg_value, bool) or isinstance(kwarg_value, int) or isinstance(kwarg_value, float):
span.set_attribute(f"input_data_all.{kwarg_name}", kwarg_value)
else:
try:
items_to_display = []
if hasattr(kwarg_value, "shape"):
# if the object has a shape attribute (likely a NumPy array or similar), get up to the first ten elements
flat_values = kwarg_value.flatten() if hasattr(kwarg_value, "flatten") else kwarg_value
items_to_display = [flat_values[i] for i in range(min(10, flat_values.size))]
elif hasattr(kwarg_value, "__getitem__") and hasattr(kwarg_value, "__len__"):
# If the object is indexable and has a length, get the first ten items
items_to_display = [kwarg_value[i] for i in range(min(10, len(kwarg_value)))]
filtered_items = [
item for item in items_to_display if isinstance(item, (str, bool, int, float))
]
if filtered_items:
span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items)
except TypeError:
pass
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -137,15 +185,16 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
max_len_input = 0
else:
max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
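# e.g. slice_dict({"a": [1, 2, 3], "b": [10]}, 2) == {"a": 3, "b": 10}: shorter inputs repeat
# their last element, so every index yields a complete set of keyword arguments.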
results = []
def process_inputs(inputs, index=None):
if allow_interrupt:
nodes.before_node_execution()
interruption.throw_exception_if_processing_interrupted()
execution_block = None
for k, v in inputs.items():
if isinstance(v, ExecutionBlocker):
@@ -162,12 +211,13 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
process_inputs(input_data_all, 0)
elif max_len_input == 0:
process_inputs({})
else:
else:
for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i)
return results
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
@@ -183,12 +233,12 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
@@ -214,19 +264,20 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
subgraph_results.append((None, r))
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
ui = dict()
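# the per-item ui dicts are merged key-wise below, e.g. [{"images": [a]}, {"images": [b]}] becomes {"images": [a, b]}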
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph
def format_value(x):
def format_value(x) -> FormattedValue:
if x is None:
return None
elif isinstance(x, (int, float, bool, str)):
@@ -234,6 +285,7 @@ def format_value(x):
else:
return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -245,8 +297,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None)
server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id)
return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
@@ -275,7 +327,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
obj = caches.objects.get(unique_id)
if obj is None:
@@ -283,10 +335,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
x not in input_data_all or x in missing_keys
)]
if len(required_inputs) > 0:
for i in required_inputs:
@@ -311,8 +363,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return ExecutionBlocker(None)
else:
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
@@ -325,7 +379,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
"output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id},
server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
@@ -364,15 +419,15 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
except interruption.InterruptProcessingException as iex:
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
error_details: RecursiveExecutionErrorDetailsInterrupted = {
"node_id": real_node_id,
}
return (ExecutionResult.FAILURE, error_details, iex)
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
@@ -382,40 +437,46 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
logging.error(f"!!! Exception during processing !!! {ex}")
logging.error("An error occurred while executing a workflow", exc_info=ex)
logging.error(traceback.format_exc())
error_details = {
error_details: RecursiveExecutionErrorDetails = {
"node_id": real_node_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (ExecutionResult.FAILURE, error_details, ex)
if isinstance(ex, model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
model_management.unload_all_models()
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None)
return ExecutionResult.SUCCESS, None, None
class PromptExecutor:
def __init__(self, server, lru_size=None):
def __init__(self, server: ExecutorToClientProgress, lru_size=None):
self.success = None
self.lru_size = lru_size
self.server = server
self.raise_exceptions = False
self.reset()
self.history_result: HistoryResultDict | None = None
def reset(self):
self.success = True
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True
def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
# todo: use a real time library
"timestamp": int(time.time() * 1000),
}
self.status_messages.append((event, data))
@@ -423,12 +484,16 @@ class PromptExecutor:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
current_span = get_current_span()
current_span.set_status(Status(StatusCode.ERROR))
current_span.record_exception(ex)
node_id = error["node_id"]
class_type = prompt[node_id]["class_type"]
# First, send back the status to the frontend depending
# on the exception type
if isinstance(ex, comfy.model_management.InterruptProcessingException):
if isinstance(ex, interruption.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
@@ -449,9 +514,20 @@ class PromptExecutor:
"current_outputs": list(current_outputs),
}
self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if ex is not None and self.raise_exceptions:
raise ex
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
with new_execution_context(ExecutionContext(self.server)):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
if execute_outputs is None:
execute_outputs = []
if extra_data is None:
extra_data = {}
interruption.interrupt_current_processing(False)
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
@@ -459,7 +535,7 @@ class PromptExecutor:
self.server.client_id = None
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
@@ -473,10 +549,10 @@ class PromptExecutor:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
@@ -496,7 +572,7 @@ class PromptExecutor:
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
@@ -515,12 +591,17 @@ class PromptExecutor:
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
if model_management.DISABLE_SMART_MEMORY:
model_management.unload_all_models()
@property
def outputs_ui(self) -> dict | None:
return self.history_result["outputs"] if self.history_result is not None else None
def validate_inputs(prompt, item, validated):
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
# todo: this should check if LoadImage / LoadImageMask paths exist
# todo: or, nodes should provide a way to validate their values
unique_id = item
if unique_id in validated:
return validated[unique_id]
@@ -530,11 +611,16 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
error: ValidationErrorDict
errors = []
valid = True
# todo: investigate if these are at the right indent level
info: Optional[InputTypeSpec] = None
val = None
validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
@@ -560,7 +646,7 @@ def validate_inputs(prompt, item, validated):
continue
val = inputs[x]
info = (type_input, extra_info)
info: InputTypeSpec = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@@ -581,7 +667,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
@@ -597,8 +684,8 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
r2 = validate_inputs(prompt, o_id, validated)
if r2[0] is False:
# `r2` will be set in `validated[o_id]` already
valid = False
continue
@@ -619,7 +706,7 @@ def validate_inputs(prompt, item, validated):
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
continue
else:
try:
@@ -650,11 +737,12 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if x not in validate_function_inputs and not validate_has_kwargs:
if "min" in extra_info and val < extra_info["min"]:
if x not in validate_function_inputs:
has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
if "min" in has_min_max and val < has_min_max["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -664,10 +752,10 @@ def validate_inputs(prompt, item, validated):
}
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
if "max" in has_min_max and val > has_min_max["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
@@ -679,6 +767,11 @@ def validate_inputs(prompt, item, validated):
continue
if isinstance(type_input, list):
if "\\" in val:
# try to normalize paths for comparison purposes
val = canonicalize_path(val)
if all(isinstance(item, (str, PathLike)) for item in type_input):
type_input = [canonicalize_path(item) for item in type_input]
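# presumably so that e.g. a value of "ckpts\\model.safetensors" can match "ckpts/model.safetensors" in the allowed list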
if val not in type_input:
input_config = info
list_info = ""
@@ -713,8 +806,8 @@ def validate_inputs(prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
# ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker):
@@ -734,20 +827,45 @@ def validate_inputs(prompt, item, validated):
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
ret = ValidateInputsTuple(False, errors, unique_id)
else:
ret = (True, [], unique_id)
ret = ValidateInputsTuple(True, [], unique_id)
validated[unique_id] = ret
return ret
def full_type_name(klass):
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
@tracer.start_as_current_span("Validate Prompt")
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
res = _validate_prompt(prompt)
if not res.valid:
span = get_current_span()
span.set_status(Status(StatusCode.ERROR))
if res.error is not None and len(res.error) > 0:
span.set_attributes({
f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list))
})
if "extra_info" in res.error and isinstance(res.error["extra_info"], dict):
extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"]
span.set_attributes({
f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list))
})
if len(res.node_errors) > 0:
for node_id, node_error in res.node_errors.items():
for node_error_field, node_error_value in node_error.items():
if isinstance(node_error_value, (str, bool, int, float)):
span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value)
return res
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
@@ -757,7 +875,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return ValidationTuple(False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@@ -768,7 +886,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return ValidationTuple(False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@@ -780,15 +898,15 @@ def validate_prompt(prompt):
"details": "",
"extra_info": {}
}
return (False, error, [], [])
return ValidationTuple(False, error, [], [])
good_outputs = set()
errors = []
node_errors = {}
validated = {}
node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
validated: typing.Dict[str, ValidateInputsTuple] = {}
for o in outputs:
valid = False
reasons = []
reasons: List[ValidationErrorDict] = []
try:
m = validate_inputs(prompt, o, validated)
valid = m[0]
@@ -806,7 +924,7 @@ def validate_prompt(prompt):
"traceback": traceback.format_tb(tb)
}
}]
validated[o] = (False, reasons, o)
validated[o] = ValidateInputsTuple(False, reasons, o)
if valid is True:
good_outputs.add(o)
@@ -839,8 +957,8 @@ def validate_prompt(prompt):
if len(good_outputs) == 0:
errors_list = []
for o, errors in errors:
for error in errors:
for o, _errors in errors:
for error in _errors:
errors_list.append(f"{error['message']}: {error['details']}")
errors_list = "\n".join(errors_list)
@@ -851,73 +969,78 @@ def validate_prompt(prompt):
"extra_info": {}
}
return (False, error, list(good_outputs), node_errors)
return ValidationTuple(False, error, list(good_outputs), node_errors)
return (True, None, list(good_outputs), node_errors)
return ValidationTuple(True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue:
def __init__(self, server):
class PromptQueue(AbstractPromptQueue):
def __init__(self, server: ExecutorToClientProgress):
self.server = server
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0
self.queue = []
self.currently_running = {}
self.history = {}
self.queue: typing.List[QueueItem] = []
self.currently_running: typing.Dict[str, QueueItem] = {}
# history maps the prompt id (the second element of the queue tuple) to a dictionary with keys "prompt" and "outputs"
# todo: use the new History class for the sake of simplicity
self.history: typing.Dict[str, HistoryEntry] = {}
self.flags = {}
server.prompt_queue = self
def put(self, item):
def size(self) -> int:
return len(self.queue)
def put(self, item: QueueItem):
with self.mutex:
heapq.heappush(self.queue, item)
self.server.queue_updated()
self.not_empty.notify()
def get(self, timeout=None):
def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
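# blocks until an item is queued (or the timeout elapses); heapq pops the lowest-ordered queue item first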
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0:
return None
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
self.task_counter += 1
item_with_future: QueueItem = heapq.heappop(self.queue)
assert item_with_future.prompt_id is not None
assert item_with_future.prompt_id != ""
assert item_with_future.prompt_id not in self.currently_running
assert isinstance(item_with_future.prompt_id, str)
task_id = item_with_future.prompt_id
self.currently_running[task_id] = item_with_future
self.server.queue_updated()
return (item, i)
return copy.deepcopy(item_with_future.queue_tuple), task_id
class ExecutionStatus(NamedTuple):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
def task_done(self, item_id: str, outputs: dict,
status: Optional[ExecutionStatus]):
history_result = outputs
with self.mutex:
prompt = self.currently_running.pop(item_id)
queue_item = self.currently_running.pop(item_id)
prompt = queue_item.queue_tuple
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
status_dict: Optional[dict] = None
if status is not None:
status_dict = copy.deepcopy(status._asdict())
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
outputs_ = history_result["outputs"]
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": {},
"outputs": copy.deepcopy(outputs_),
'status': status_dict,
}
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
if queue_item.completed:
queue_item.completed.set_result(outputs_)
def get_current_queue(self):
def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
with self.mutex:
out = []
out: typing.List[QueueTuple] = []
for x in self.currently_running.values():
out += [x]
return (out, copy.deepcopy(self.queue))
out += [x.queue_tuple]
return out, copy.deepcopy([item.queue_tuple for item in self.queue])
def get_tasks_remaining(self):
with self.mutex:
@@ -925,17 +1048,22 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
for item in self.queue:
if item.completed:
item.completed.set_exception(Exception("queue cancelled"))
self.queue = []
self.server.queue_updated()
def delete_queue_item(self, function):
with self.mutex:
for x in range(len(self.queue)):
if function(self.queue[x]):
if function(self.queue[x].queue_tuple):
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
item = self.queue.pop(x)
if item.completed:
item.completed.set_exception(Exception("queue item deleted"))
heapq.heapify(self.queue)
self.server.queue_updated()
return True
@@ -962,9 +1090,9 @@ class PromptQueue:
def wipe_history(self):
with self.mutex:
self.history = {}
self.history.clear()
def delete_history_item(self, id_to_delete):
def delete_history_item(self, id_to_delete: str):
with self.mutex:
self.history.pop(id_to_delete, None)

View File

@@ -1,15 +1,7 @@
import nodes
from .cmd.execution import nodes
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError
from .graph_utils import is_link
from comfy_execution.graph_utils import is_link
class DependencyCycleError(Exception):
pass
class NodeInputError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt:
def __init__(self, original_prompt):
@@ -54,6 +46,7 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
@@ -76,12 +69,13 @@ def get_input_info(class_def, input_name):
extra_info = {}
return input_type, input_category, extra_info
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -136,11 +130,13 @@ class TopologicalSort:
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
"""
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
it can still be returned to the graph after having further dependencies added.
"""
def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
@@ -203,7 +199,7 @@ class ExecutionList(TopologicalSort):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes }
blocked_by = {node_id: {} for node_id in self.pendingNodes}
for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values():
@@ -218,6 +214,7 @@ class ExecutionList(TopologicalSort):
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys())
class ExecutionBlocker:
"""
Return this from a node and any users will be blocked with the given error message.
@@ -232,6 +229,6 @@ class ExecutionBlocker:
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
"""
def __init__(self, message):
self.message = message
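As a rough sketch of what the ExecutionBlocker docstring above describes, a custom node can return an ExecutionBlocker in place of one of its outputs; the class below (its name, input, and failure condition are hypothetical, chosen only to illustrate the node convention this codebase uses) would block every downstream consumer of its output:

class CheckedPassthrough:
    # hypothetical example node: forwards an IMAGE, or blocks its users
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"

    def run(self, image):
        if image is None:
            # downstream nodes that consume this output are blocked with this message
            return (ExecutionBlocker("upstream produced no image"),)
        return (image,)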

View File

@@ -9,13 +9,16 @@ def is_link(obj):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
"""
The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
"""
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
def __init__(self, prefix=None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
@@ -24,7 +27,7 @@ class GraphBuilder:
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
def set_default_prefix(cls, prefix_root, call_index, graph_index=0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
cls._default_prefix_graph_index = graph_index
@@ -80,6 +83,7 @@ class GraphBuilder:
id = self.prefix + id
del self.nodes[id]
class Node:
def __init__(self, id, class_type, inputs):
self.id = id
@@ -112,13 +116,14 @@ class Node:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links
new_graph = {}
for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
new_node = {"class_type": node_info["class_type"], "inputs": {}}
for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
@@ -136,4 +141,3 @@ def add_graph_prefix(graph, outputs, prefix):
new_outputs.append(output)
return new_graph, tuple(new_outputs)