mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
- _FEEDBACK_DEFAULTS: add module-level docstring comment - _find_consumers: add source_socket parameter (default 0), add docstring - _is_sampler_target: add docstring
2052 lines
78 KiB
Python
2052 lines
78 KiB
Python
import asyncio
|
||
import copy
|
||
import heapq
|
||
import inspect
|
||
import logging
|
||
import sys
|
||
import threading
|
||
import time
|
||
import traceback
|
||
from enum import Enum
|
||
from typing import List, Literal, NamedTuple, Optional, Union
|
||
|
||
import comfy_aimdo.model_vbar
|
||
import psutil
|
||
import torch
|
||
|
||
import comfy.memory_management
|
||
import comfy.model_management
|
||
import comfy.model_prefetch
|
||
import nodes
|
||
from comfy.cli_args import args
|
||
from comfy_api.internal import (
|
||
_ComfyNodeInternal,
|
||
_NodeOutputInternal,
|
||
first_real_override,
|
||
is_class,
|
||
make_locked_method_func,
|
||
)
|
||
from comfy_api.latest import _io, io
|
||
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||
from comfy_execution.cache_provider import _get_cache_providers, _has_cache_providers
|
||
from comfy_execution.cache_provider import _logger as _cache_logger
|
||
from comfy_execution.caching import (
|
||
BasicCache,
|
||
CacheKeySetID,
|
||
CacheKeySetInputSignature,
|
||
HierarchicalCache,
|
||
LRUCache,
|
||
NullCache,
|
||
RAMPressureCache,
|
||
)
|
||
from comfy_execution.graph import (
|
||
DynamicPrompt,
|
||
ExecutionBlocker,
|
||
ExecutionList,
|
||
get_input_info,
|
||
)
|
||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||
from comfy_execution.progress import (
|
||
WebUIProgressHandler,
|
||
add_progress_handler,
|
||
get_progress_state,
|
||
reset_progress_state,
|
||
)
|
||
from comfy_execution.utils import CurrentNodeContext
|
||
from comfy_execution.validation import validate_node_input
|
||
from latent_preview import set_preview_method
|
||
|
||
|
||
class ExecutionResult(Enum):
    """Outcome reported by ``execute`` for one node."""

    # The node finished, or its outputs were served from the output cache.
    SUCCESS = 0
    # The node raised; error details and the exception accompany this value.
    FAILURE = 1
    # The node is not finished yet: lazy inputs were requested, async tasks
    # are still running, or a subgraph expansion was queued.
    PENDING = 2
|
||
|
||
|
||
class DuplicateNodeError(Exception):
    """Exception type for duplicate-node conditions.

    Not raised within this part of the module; presumably raised elsewhere
    in the file (confirm).
    """
|
||
|
||
|
||
class IsChangedCache:
    """Per-prompt memo of each node's "has it changed" fingerprint.

    The fingerprint comes from the node class's ``fingerprint_inputs``
    override (V3 nodes) or its ``IS_CHANGED`` attribute (V1 nodes). Nodes
    with neither hook map to ``False``. Computed fingerprints are memoised in
    ``self.is_changed`` and also stored on the prompt node dict under
    ``"is_changed"``, which is reused on later lookups.
    """

    def __init__(
        self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache
    ):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        self.outputs_cache = outputs_cache
        self.is_changed = {}

    async def get(self, node_id):
        """Return the fingerprint for *node_id*, computing it on first use.

        The value is ``False`` when the node has no fingerprint hook. It is
        the stored ``node["is_changed"]`` value when one is already present.
        Otherwise it is the list of hook results, with ExecutionBlocker
        entries replaced by ``None``. If the hook raises, the exception is
        logged and ``float("NaN")`` is stored. NaN never equals itself, so
        presumably the node is then always treated as changed (confirm).
        """
        memo = self.is_changed
        if node_id in memo:
            return memo[node_id]

        node = self.dynprompt.get_node(node_id)
        class_def = nodes.NODE_CLASS_MAPPINGS[node["class_type"]]

        # V3 nodes override fingerprint_inputs; V1 nodes expose IS_CHANGED.
        if issubclass(class_def, _ComfyNodeInternal) and (
            first_real_override(class_def, "fingerprint_inputs") is not None
        ):
            hook_name = "fingerprint_inputs"
        elif hasattr(class_def, "IS_CHANGED"):
            hook_name = "IS_CHANGED"
        else:
            memo[node_id] = False
            return memo[node_id]

        if "is_changed" in node:
            memo[node_id] = node["is_changed"]
            return memo[node_id]

        # No execution list is passed, so every linked input is reported as
        # missing: the hook only ever sees constant (widget) inputs.
        input_data_all, _, v3_data = get_input_data(
            node["inputs"], class_def, node_id, None
        )
        try:
            hook_results = await _async_map_node_over_list(
                self.prompt_id,
                node_id,
                class_def,
                input_data_all,
                hook_name,
                v3_data=v3_data,
            )
            hook_results = await resolve_map_node_over_list_results(hook_results)
            node["is_changed"] = [
                None if isinstance(value, ExecutionBlocker) else value
                for value in hook_results
            ]
        except Exception as e:
            logging.warning("WARNING: {}".format(e))
            node["is_changed"] = float("NaN")
        finally:
            memo[node_id] = node["is_changed"]
        return memo[node_id]
|
||
|
||
|
||
class CacheEntry(NamedTuple):
    """Result of one node execution, as stored in the output cache."""

    # UI payload recorded for the node; ``execute`` stores
    # ``ui_outputs.get(unique_id)`` here, so it may be None.
    ui: dict
    # The node's outputs: one entry per output socket.
    outputs: list
|
||
|
||
|
||
class CacheType(Enum):
    """Output-cache strategy used when constructing a ``CacheSet``."""

    CLASSIC = 0  # HierarchicalCache for outputs
    LRU = 1  # LRUCache for outputs, sized by cache_args["lru"]
    NONE = 2  # NullCache for outputs and objects (caching disabled)
    RAM_PRESSURE = 3  # RAMPressureCache for outputs
|
||
|
||
|
||
# Seed values for bounded-feedback links, keyed by ComfyUI output type name.
# When the graph contains a feedback loop (for example
# step_index -> ... -> cfg -> guider -> sampler), get_input_data() gives the
# consumer the default for the source output's type, so the downstream chain
# can be evaluated before the iteration-producing node has run.
# _get_feedback_default() falls back to 0 for types not listed here.
_FEEDBACK_DEFAULTS = dict(
    INT=0,
    FLOAT=0.0,
    BOOLEAN=False,
    STRING="",
    NUMBER=0,
    PRIMITIVE=0,
)
|
||
|
||
|
||
class CacheSet:
    """Pair of caches used while executing a prompt.

    * ``outputs`` holds node results (``CacheEntry`` values), keyed via
      ``CacheKeySetInputSignature`` (``NullCache`` when caching is disabled).
    * ``objects`` holds instantiated node objects, keyed via
      ``CacheKeySetID`` (``NullCache`` when caching is disabled).

    ``all`` lists both caches so callers can apply an operation to each.
    """

    def __init__(self, cache_type=None, cache_args=None):
        """Build the caches for *cache_type*.

        Args:
            cache_type: A ``CacheType``. ``None`` or any unrecognised value
                selects the classic hierarchical cache.
            cache_args: Optional settings dict. ``"ram"`` (default ``16.0``)
                is passed to ``init_ram_cache`` for ``CacheType.RAM_PRESSURE``.
                ``"lru"`` (default ``0``) is the LRU size for
                ``CacheType.LRU``. ``None`` means no settings.
        """
        # ``None`` rather than a shared ``{}`` default: the dict is only read
        # here, but a mutable default argument is shared across all calls.
        if cache_args is None:
            cache_args = {}
        if cache_type == CacheType.NONE:
            self.init_null_cache()
            logging.info("Disabling intermediate node cache.")
        elif cache_type == CacheType.RAM_PRESSURE:
            cache_ram = cache_args.get("ram", 16.0)
            self.init_ram_cache(cache_ram)
            logging.info("Using RAM pressure cache.")
        elif cache_type == CacheType.LRU:
            cache_size = cache_args.get("lru", 0)
            self.init_lru_cache(cache_size)
            logging.info("Using LRU cache")
        else:
            self.init_classic_cache()

        self.all = [self.outputs, self.objects]

    # Performs like the old cache -- dump data ASAP
    def init_classic_cache(self):
        """Use hierarchical caches for both outputs and node objects."""
        self.outputs = HierarchicalCache(
            CacheKeySetInputSignature, enable_providers=True
        )
        self.objects = HierarchicalCache(CacheKeySetID)

    def init_lru_cache(self, cache_size):
        """Use an LRU output cache holding at most *cache_size* entries."""
        self.outputs = LRUCache(
            CacheKeySetInputSignature, max_size=cache_size, enable_providers=True
        )
        self.objects = HierarchicalCache(CacheKeySetID)

    def init_ram_cache(self, min_headroom):
        """Use a RAM-pressure output cache.

        NOTE(review): *min_headroom* is accepted but not passed to
        ``RAMPressureCache`` here; presumably the headroom is applied by
        whoever polls this cache. Confirm.
        """
        self.outputs = RAMPressureCache(
            CacheKeySetInputSignature, enable_providers=True
        )
        self.objects = HierarchicalCache(CacheKeySetID)

    def init_null_cache(self):
        """Disable caching: both caches are ``NullCache`` instances."""
        self.outputs = NullCache()
        self.objects = NullCache()

    def recursive_debug_dump(self):
        """Return a debug dump of the output cache (objects are not included)."""
        result = {
            "outputs": self.outputs.recursive_debug_dump(),
        }
        return result
|
||
|
||
|
||
# extra_data keys that carry Comfy.org credentials. These are the same keys
# get_input_data() reads for the AUTH_TOKEN_COMFY_ORG / API_KEY_COMFY_ORG
# hidden inputs. Not referenced in this part of the module; presumably used
# elsewhere to keep these values out of stored/logged data (confirm).
SENSITIVE_EXTRA_DATA_KEYS = (
    "auth_token_comfy_org",
    "api_key_comfy_org",
)
|
||
|
||
|
||
def get_input_data(
    inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data=None
):
    """Resolve a node's prompt inputs into the argument lists used to call it.

    Args:
        inputs: The node's ``inputs`` mapping from the prompt. A linked input
            is a ``[source_node_id, output_index]`` pair (``is_link``). Any
            other value is a constant.
        class_def: The node class (V1, or V3 when it subclasses
            ``_ComfyNodeInternal``).
        unique_id: Id of the node whose inputs are being resolved.
        execution_list: Source of upstream cached outputs. When ``None``,
            every linked input is reported as missing (used for lazy-input
            and IS_CHANGED evaluation).
        dynprompt: Dynamic prompt, used for PROMPT/DYNPROMPT hidden inputs
            and for feedback-link default lookups. May be ``None``.
        extra_data: Prompt-level extra data (``extra_pnginfo``, Comfy.org
            credentials, ``comfy_usage_source``). ``None`` means empty.

    Returns:
        ``(input_data_all, missing_keys, v3_data)``:

        * ``input_data_all`` maps each input name to a sequence of values
          (one per list index). Missing linked inputs map to ``(None,)``.
        * ``missing_keys`` maps the names of unresolved linked inputs to
          ``True``.
        * ``v3_data`` is the V3 metadata from
          ``_io.get_finalized_class_inputs`` (``{}`` for V1 nodes), plus a
          ``"hidden_inputs"`` entry.
    """
    # ``None`` replaces the former shared ``{}`` default; the dict is only read.
    if extra_data is None:
        extra_data = {}

    is_v3 = issubclass(class_def, _ComfyNodeInternal)
    v3_data: io.V3Data = {}
    hidden_inputs_v3 = {}
    valid_inputs = class_def.INPUT_TYPES()
    if is_v3:
        valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(
            valid_inputs, inputs
        )

    input_data_all = {}
    missing_keys = {}
    for x in inputs:
        input_data = inputs[x]
        _, input_category, input_info = get_input_info(class_def, x, valid_inputs)

        def mark_missing():
            missing_keys[x] = True
            input_data_all[x] = (None,)

        # Inputs flagged ``rawLink`` receive the link itself, unresolved.
        if is_link(input_data) and (
            not input_info or not input_info.get("rawLink", False)
        ):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if execution_list is None:
                mark_missing()
                continue  # This might be a lazily-evaluated input
            cached = execution_list.get_cache(input_unique_id, unique_id)
            if cached is None or cached.outputs is None:
                # The source has not produced outputs yet. For a
                # bounded-feedback link, seed the type-appropriate initial
                # value (e.g. step_index=0) so the feedback chain can evaluate
                # before the iteration-producing node runs. Otherwise report
                # the input as missing.
                if _is_feedback_link(
                    execution_list, unique_id, input_unique_id, output_index
                ):
                    obj = _get_feedback_default(
                        dynprompt, input_unique_id, output_index
                    )
                    if isinstance(obj, (int, float, bool, str)):
                        obj = (obj,)
                    input_data_all[x] = obj
                else:
                    mark_missing()
                continue
            if output_index >= len(cached.outputs):
                mark_missing()
                continue
            obj = cached.outputs[output_index]
            # Wrap atomic types (int, float, bool, str) in a tuple so
            # _async_map_node_over_list can call len() on every input.
            # The slice_dict helper then unwraps: (val,)[0] == val.
            if isinstance(obj, (int, float, bool, str)):
                obj = (obj,)
            input_data_all[x] = obj
        elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
            input_data_all[x] = [input_data]

    if is_v3:
        # V3 hidden inputs are collected separately and passed via v3_data.
        if hidden is not None:
            if io.Hidden.prompt.name in hidden:
                hidden_inputs_v3[io.Hidden.prompt] = (
                    dynprompt.get_original_prompt() if dynprompt is not None else {}
                )
            if io.Hidden.dynprompt.name in hidden:
                hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
            if io.Hidden.extra_pnginfo.name in hidden:
                hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get(
                    "extra_pnginfo", None
                )
            if io.Hidden.unique_id.name in hidden:
                hidden_inputs_v3[io.Hidden.unique_id] = unique_id
            if io.Hidden.auth_token_comfy_org.name in hidden:
                hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get(
                    "auth_token_comfy_org", None
                )
            if io.Hidden.api_key_comfy_org.name in hidden:
                hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get(
                    "api_key_comfy_org", None
                )
            if io.Hidden.comfy_usage_source.name in hidden:
                hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get(
                    "comfy_usage_source", None
                )
    else:
        # V1 hidden inputs are injected as ordinary single-element inputs.
        if "hidden" in valid_inputs:
            h = valid_inputs["hidden"]
            for x in h:
                if h[x] == "PROMPT":
                    input_data_all[x] = [
                        dynprompt.get_original_prompt() if dynprompt is not None else {}
                    ]
                if h[x] == "DYNPROMPT":
                    input_data_all[x] = [dynprompt]
                if h[x] == "EXTRA_PNGINFO":
                    input_data_all[x] = [extra_data.get("extra_pnginfo", None)]
                if h[x] == "UNIQUE_ID":
                    input_data_all[x] = [unique_id]
                if h[x] == "AUTH_TOKEN_COMFY_ORG":
                    input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
                if h[x] == "API_KEY_COMFY_ORG":
                    input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
                if h[x] == "COMFY_USAGE_SOURCE":
                    input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
    v3_data["hidden_inputs"] = hidden_inputs_v3
    return input_data_all, missing_keys, v3_data
|
||
|
||
|
||
# Placeholder for the old public name: the implementation lives in the
# private _async_map_node_over_list. Keeping the name bound to None
# presumably avoids NameError/ImportError for external code that still
# references it (confirm).
map_node_over_list = None  # Don't hook this please
|
||
|
||
|
||
async def resolve_map_node_over_list_results(results):
    """Await any unfinished asyncio tasks in *results* and unwrap them.

    Every task in *results* is replaced by its ``.result()``, and all other
    entries are returned as-is, in the original order. If a task awaited here
    finished with an exception, that exception is raised. A task that had
    already failed before the call raises from ``.result()``.
    """
    unfinished = [
        item for item in results if isinstance(item, asyncio.Task) and not item.done()
    ]
    if unfinished:
        finished, _ = await asyncio.wait(unfinished)
        for task in finished:
            error = task.exception()
            if error is not None:
                raise error
    return [
        item.result() if isinstance(item, asyncio.Task) else item for item in results
    ]
|
||
|
||
|
||
async def _async_map_node_over_list(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    func,
    allow_interrupt=False,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
):
    """Call method *func* of node *obj* once per input slice and collect the returns.

    Call pattern:

    * ``obj.INPUT_IS_LIST`` true: one call with the full value lists (index 0).
    * no inputs at all: one call with no arguments (index ``None``).
    * otherwise: ``max(len(values))`` calls. Call *i* gets element *i* of each
      input, and shorter inputs repeat their last element.

    Args:
        prompt_id: Prompt id, passed to ``CurrentNodeContext`` for each call.
        unique_id: Node id, passed to ``CurrentNodeContext`` for each call.
        obj: Node instance or class. V3 nodes (``_ComfyNodeInternal``) are
            called through ``make_locked_method_func`` on a prepared class
            clone. V1 nodes are called via ``getattr(obj, func)``.
        input_data_all: Maps input names to sequences of values.
        func: Name of the method to call.
        allow_interrupt: If true, ``nodes.before_node_execution()`` runs before
            each call (presumably the interrupt check; confirm).
        execution_block_cb: Optional transform applied to an
            ``ExecutionBlocker`` found among a call's inputs. The function is
            not called, and the (transformed) blocker is recorded instead.
        pre_execute_cb: Called with the slice index before each indexed call.
        v3_data: V3 metadata used to prepare the class clone and to rebuild
            nested (dynamic) inputs.

    Returns:
        One entry per call. Each entry is the method's return value, an
        ``ExecutionBlocker`` (or whatever ``execution_block_cb`` returned), or
        an ``asyncio.Task`` if a coroutine method had not finished after one
        event-loop turn.
    """
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max(len(x) for x in input_data_all.values())

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

    results = []

    async def process_inputs(inputs, index=None, input_is_list=False):
        # Performs one call (or records a blocker) and appends to ``results``.
        if allow_interrupt:
            nodes.before_node_execution()
        execution_block = None
        # Any ExecutionBlocker among the inputs (or, in list mode, inside an
        # input list) short-circuits this call.
        for k, v in inputs.items():
            if input_is_list:
                for e in v:
                    if isinstance(e, ExecutionBlocker):
                        v = e
                        break
            if isinstance(v, ExecutionBlocker):
                execution_block = execution_block_cb(v) if execution_block_cb else v
                break
        if execution_block is None:
            if pre_execute_cb is not None and index is not None:
                pre_execute_cb(index)
            # V3
            if isinstance(obj, _ComfyNodeInternal) or (
                is_class(obj) and issubclass(obj, _ComfyNodeInternal)
            ):
                # if is just a class, then assign no state, just create clone
                if is_class(obj):
                    type_obj = obj
                    obj.VALIDATE_CLASS()
                    class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
                # otherwise, use class instance to populate/reuse some fields
                else:
                    type_obj = type(obj)
                    type_obj.VALIDATE_CLASS()
                    class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
                f = make_locked_method_func(type_obj, func, class_clone)
                # in case of dynamic inputs, restructure inputs to expected nested dict
                if v3_data is not None:
                    inputs = _io.build_nested_inputs(inputs, v3_data)
            # V1
            else:
                f = getattr(obj, func)
            if inspect.iscoroutinefunction(f):

                async def async_wrapper(f, prompt_id, unique_id, list_index, args):
                    # Runs the coroutine inside the node's execution context.
                    with CurrentNodeContext(prompt_id, unique_id, list_index):
                        return await f(**args)

                task = asyncio.create_task(
                    async_wrapper(f, prompt_id, unique_id, index, args=inputs)
                )
                # Give the task a chance to execute without yielding
                await asyncio.sleep(0)
                if task.done():
                    # Finished in one loop turn: store the value directly
                    # (task.result() re-raises if the coroutine failed).
                    result = task.result()
                    results.append(result)
                else:
                    # Still running: the caller must resolve the task later.
                    results.append(task)
            else:
                with CurrentNodeContext(prompt_id, unique_id, index):
                    result = f(**inputs)
                results.append(result)
        else:
            results.append(execution_block)

    if input_is_list:
        await process_inputs(input_data_all, 0, input_is_list=input_is_list)
    elif max_len_input == 0:
        await process_inputs({})
    else:
        for i in range(max_len_input):
            input_dict = slice_dict(input_data_all, i)
            await process_inputs(input_dict, i)
    return results
|
||
|
||
|
||
def merge_result_data(results, obj):
    """Transpose per-call node results into per-output value lists.

    *results* holds one sequence per node call, with one value per output
    socket. The return value has one list per output socket:

    * Outputs flagged true in ``obj.OUTPUT_IS_LIST`` are concatenated, because
      each call already returned a list for that socket. An
      ``ExecutionBlocker`` in such a slot is kept as a single element.
    * All other outputs become ``[call_0_value, call_1_value, ...]``.

    Outputs beyond the end of ``OUTPUT_IS_LIST``, or all outputs when the
    attribute is absent or ``None``, are treated as non-list outputs. An empty
    *results* yields ``[]``.
    """
    if not results:
        return []
    num_outputs = len(results[0])
    # Pad or trim the flags to the real output count. A bare zip() here used
    # to silently drop every output past the end of a short OUTPUT_IS_LIST.
    output_is_list = list(getattr(obj, "OUTPUT_IS_LIST", None) or ())[:num_outputs]
    output_is_list += [False] * (num_outputs - len(output_is_list))

    output = []
    for i, is_list in enumerate(output_is_list):
        if is_list:
            value = []
            for o in results:
                if isinstance(o[i], ExecutionBlocker):
                    value.append(o[i])
                else:
                    value.extend(o[i])
            output.append(value)
        else:
            output.append([o[i] for o in results])
    return output
|
||
|
||
|
||
async def get_output_data(
    prompt_id,
    unique_id,
    obj,
    input_data_all,
    execution_block_cb=None,
    pre_execute_cb=None,
    v3_data=None,
):
    """Run the node's ``FUNCTION`` over its inputs and collect the outputs.

    Returns ``(output, ui, has_subgraph, has_pending_task)``.

    If any call returned an asyncio task that is still running, the raw
    per-call return values are handed back unprocessed, with ``ui`` set to
    ``{}``, ``has_subgraph`` set to ``False`` and ``has_pending_task`` set to
    ``True``. Otherwise the returns are normalised by
    ``get_output_from_returns`` and ``has_pending_task`` is ``False``.
    """
    raw_returns = await _async_map_node_over_list(
        prompt_id,
        unique_id,
        obj,
        input_data_all,
        obj.FUNCTION,
        allow_interrupt=True,
        execution_block_cb=execution_block_cb,
        pre_execute_cb=pre_execute_cb,
        v3_data=v3_data,
    )
    still_running = any(
        isinstance(item, asyncio.Task) and not item.done() for item in raw_returns
    )
    if still_running:
        return raw_returns, {}, False, True
    output, ui, has_subgraph = get_output_from_returns(raw_returns, obj)
    return output, ui, has_subgraph, False
|
||
|
||
|
||
def get_output_from_returns(return_values, obj):
    """Normalise the per-call return values of a node into outputs and UI.

    Each element of *return_values* is what one call of the node returned:

    * a ``dict`` (V1) with optional ``"ui"``, ``"expand"`` (a subgraph) and
      ``"result"`` keys;
    * a ``_NodeOutputInternal`` (V3) with ``ui``, ``expand``, ``result`` and
      ``block_execution`` attributes;
    * anything else, which is taken as the result tuple itself.

    An ``ExecutionBlocker`` result, or a V3 ``block_execution``, is expanded
    to one blocker per entry of ``obj.RETURN_TYPES``.

    Returns:
        ``(output, ui, has_subgraph)``. If any call requested an expansion,
        ``output`` is a list of ``(new_graph_or_None, result)`` pairs.
        Otherwise it is ``merge_result_data`` of the results, or ``[]`` when
        there are none. ``ui`` merges all calls' UI dicts by concatenating the
        lists stored under each key, across the union of keys in first-seen
        order.
    """
    results = []
    uis = []
    subgraph_results = []
    has_subgraph = False

    def blocked(blocker):
        # One blocker per declared output socket.
        return tuple([blocker] * len(obj.RETURN_TYPES))

    for r in return_values:
        if isinstance(r, dict):
            if "ui" in r:
                uis.append(r["ui"])
            if "expand" in r:
                # Perform an expansion, but do not append results
                has_subgraph = True
                new_graph = r["expand"]
                result = r.get("result", None)
                if isinstance(result, ExecutionBlocker):
                    result = blocked(result)
                subgraph_results.append((new_graph, result))
            elif "result" in r:
                result = r.get("result", None)
                if isinstance(result, ExecutionBlocker):
                    result = blocked(result)
                results.append(result)
                subgraph_results.append((None, result))
        elif isinstance(r, _NodeOutputInternal):
            # V3
            if r.ui is not None:
                uis.append(r.ui if isinstance(r.ui, dict) else r.ui.as_dict())
            if r.expand is not None:
                has_subgraph = True
                new_graph = r.expand
                result = r.result
                if r.block_execution is not None:
                    result = blocked(ExecutionBlocker(r.block_execution))
                subgraph_results.append((new_graph, result))
            elif r.result is not None:
                result = r.result
                if r.block_execution is not None:
                    result = blocked(ExecutionBlocker(r.block_execution))
                results.append(result)
                subgraph_results.append((None, result))
        else:
            if isinstance(r, ExecutionBlocker):
                r = blocked(r)
            results.append(r)
            subgraph_results.append((None, r))

    if has_subgraph:
        output = subgraph_results
    elif len(results) > 0:
        output = merge_result_data(results, obj)
    else:
        output = []

    # TODO: Think there's an existing bug here
    # If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
    # They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
    # any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
    ui = {}
    # Merge over the union of keys. Only using the first call's keys raised
    # KeyError when a later call lacked one of them, and silently dropped keys
    # that only later calls produced.
    for entry in uis:
        for key, values in entry.items():
            ui.setdefault(key, []).extend(values)
    return output, ui, has_subgraph
|
||
|
||
|
||
def format_value(x):
    """Make *x* safe for an error report.

    ``None`` and plain scalars (int, float, bool, str) are returned unchanged.
    Anything else is converted with ``str()``.
    """
    if x is not None and not isinstance(x, (int, float, bool, str)):
        return str(x)
    return x
|
||
|
||
|
||
def _is_intermediate_output(dynprompt, node_id):
    """Return the ``HAS_INTERMEDIATE_OUTPUT`` flag of *node_id*'s node class.

    Returns ``False`` when the class does not define the flag.
    """
    node_class = nodes.NODE_CLASS_MAPPINGS[dynprompt.get_node(node_id)["class_type"]]
    return getattr(node_class, "HAS_INTERMEDIATE_OUTPUT", False)
|
||
|
||
|
||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
    """Replay a cached node's UI to the connected client.

    Does nothing when ``server.client_id`` is ``None``. Otherwise sends an
    ``"executed"`` event carrying ``cached.ui["output"]`` (``None`` when
    absent). When the cache entry has UI data, it is also recorded in
    *ui_outputs* under *node_id*.
    """
    client = server.client_id
    if client is None:
        return
    ui = cached.ui
    payload = {
        "node": node_id,
        "display_node": display_node_id,
        "output": (ui or {}).get("output", None),
        "prompt_id": prompt_id,
    }
    server.send_sync("executed", payload, client)
    if ui is not None:
        ui_outputs[node_id] = ui
|
||
|
||
|
||
async def execute(
    server,
    dynprompt,
    caches,
    current_item,
    extra_data,
    executed,
    prompt_id,
    execution_list,
    pending_subgraph_results,
    pending_async_nodes,
    ui_outputs,
):
    """Execute, resume, or serve from cache the node *current_item*.

    The first matching case applies:

    1. Outputs already in ``caches.outputs``: the cached UI is replayed, the
       entry is pushed to *execution_list*, and ``SUCCESS`` is returned.
    2. Node in *pending_async_nodes*: the results of its async tasks are
       collected. A failed task's exception propagates and is reported as a
       failure.
    3. Node in *pending_subgraph_results*: links into its expanded subgraph
       are resolved from *execution_list* caches and merged into the outputs.
    4. Otherwise the node runs. Inputs are resolved, and
       ``check_lazy_status`` may request more inputs (``PENDING``). Then the
       node function is called. Still-running async tasks are parked in
       *pending_async_nodes* (``PENDING``).

    If the node expanded into a subgraph, the new nodes are registered and
    ``PENDING`` is returned. Otherwise the outputs are cached and the node is
    added to *executed*.

    Returns:
        ``(ExecutionResult, error_details, exception)``. ``SUCCESS`` and
        ``PENDING`` carry ``(None, None)``. ``FAILURE`` carries an
        error-details dict and the exception that was raised.
    """
    unique_id = current_item
    real_node_id = dynprompt.get_real_node_id(unique_id)
    display_node_id = dynprompt.get_display_node_id(unique_id)
    parent_node_id = dynprompt.get_parent_node_id(unique_id)
    inputs = dynprompt.get_node(unique_id)["inputs"]
    class_type = dynprompt.get_node(unique_id)["class_type"]
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    cached = await caches.outputs.get(unique_id)
    if cached is not None:
        _send_cached_ui(
            server, unique_id, display_node_id, cached, prompt_id, ui_outputs
        )
        get_progress_state().finish_progress(unique_id)
        execution_list.cache_update(unique_id, cached)
        return (ExecutionResult.SUCCESS, None, None)

    input_data_all = None
    try:
        if unique_id in pending_async_nodes:
            results = []
            for r in pending_async_nodes[unique_id]:
                if isinstance(r, asyncio.Task):
                    try:
                        results.append(r.result())
                    except Exception:
                        # An async task failed - propagate the exception up
                        del pending_async_nodes[unique_id]
                        raise
                else:
                    results.append(r)
            del pending_async_nodes[unique_id]
            output_data, output_ui, has_subgraph = get_output_from_returns(
                results, class_def
            )
        elif unique_id in pending_subgraph_results:
            cached_results = pending_subgraph_results[unique_id]
            resolved_outputs = []
            for is_subgraph, result in cached_results:
                if not is_subgraph:
                    resolved_outputs.append(result)
                else:
                    # Replace links into the expanded subgraph with the
                    # values those subgraph nodes produced.
                    resolved_output = []
                    for r in result:
                        if is_link(r):
                            source_node, source_output = r[0], r[1]
                            node_cached = execution_list.get_cache(
                                source_node, unique_id
                            )
                            resolved_output.extend(node_cached.outputs[source_output])
                        else:
                            resolved_output.append(r)
                    resolved_outputs.append(tuple(resolved_output))
            output_data = merge_result_data(resolved_outputs, class_def)
            output_ui = {}
            del pending_subgraph_results[unique_id]
            has_subgraph = False
        else:
            get_progress_state().start_progress(unique_id)
            input_data_all, missing_keys, v3_data = get_input_data(
                inputs, class_def, unique_id, execution_list, dynprompt, extra_data
            )
            if server.client_id is not None:
                server.last_node_id = display_node_id
                server.send_sync(
                    "executing",
                    {
                        "node": unique_id,
                        "display_node": display_node_id,
                        "prompt_id": prompt_id,
                    },
                    server.client_id,
                )

            obj = await caches.objects.get(unique_id)
            if obj is None:
                obj = class_def()
                await caches.objects.set(unique_id, obj)

            if issubclass(class_def, _ComfyNodeInternal):
                lazy_status_present = (
                    first_real_override(class_def, "check_lazy_status") is not None
                )
            else:
                lazy_status_present = (
                    getattr(obj, "check_lazy_status", None) is not None
                )
            if lazy_status_present:
                # for check_lazy_status, the returned data should include the original key of the input
                v3_data_lazy = v3_data.copy()
                v3_data_lazy["create_dynamic_tuple"] = True
                required_inputs = await _async_map_node_over_list(
                    prompt_id,
                    unique_id,
                    obj,
                    input_data_all,
                    "check_lazy_status",
                    allow_interrupt=True,
                    v3_data=v3_data_lazy,
                )
                required_inputs = await resolve_map_node_over_list_results(
                    required_inputs
                )
                required_inputs = set(
                    sum([r for r in required_inputs if isinstance(r, list)], [])
                )
                # Only inputs that are still unresolved need to be scheduled.
                required_inputs = [
                    x
                    for x in required_inputs
                    if isinstance(x, str)
                    and (x not in input_data_all or x in missing_keys)
                ]
                if len(required_inputs) > 0:
                    for i in required_inputs:
                        execution_list.make_input_strong_link(unique_id, i)
                    return (ExecutionResult.PENDING, None, None)

            def execution_block_cb(block):
                # A blocker with a message is reported to the client once and
                # replaced by a silent blocker.
                if block.message is not None:
                    mes = {
                        "prompt_id": prompt_id,
                        "node_id": unique_id,
                        "node_type": class_type,
                        "executed": list(executed),
                        "exception_message": f"Execution Blocked: {block.message}",
                        "exception_type": "ExecutionBlocked",
                        "traceback": [],
                        "current_inputs": [],
                        "current_outputs": [],
                    }
                    server.send_sync("execution_error", mes, server.client_id)
                    return ExecutionBlocker(None)
                else:
                    return block

            def pre_execute_cb(call_index):
                # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
                GraphBuilder.set_default_prefix(unique_id, call_index, 0)

            try:
                (
                    output_data,
                    output_ui,
                    has_subgraph,
                    has_pending_tasks,
                ) = await get_output_data(
                    prompt_id,
                    unique_id,
                    obj,
                    input_data_all,
                    execution_block_cb=execution_block_cb,
                    pre_execute_cb=pre_execute_cb,
                    v3_data=v3_data,
                )
            finally:
                if comfy.memory_management.aimdo_enabled:
                    if args.verbose == "DEBUG":
                        # NOTE(review): only comfy_aimdo.model_vbar is imported
                        # at the top of this module; comfy_aimdo.control is
                        # presumably imported elsewhere (e.g. by
                        # comfy.memory_management). Confirm, otherwise this
                        # raises AttributeError.
                        comfy_aimdo.control.analyze()
                    comfy.model_management.reset_cast_buffers()
                    comfy.model_prefetch.cleanup_prefetch_queues()
                    comfy_aimdo.model_vbar.vbars_reset_watermark_limits()

            if has_pending_tasks:
                # Park the raw returns and block this node until every task is
                # done. The node is then re-entered via the pending branch.
                pending_async_nodes[unique_id] = output_data
                unblock = execution_list.add_external_block(unique_id)

                async def await_completion():
                    tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
                    await asyncio.gather(*tasks, return_exceptions=True)
                    unblock()

                asyncio.create_task(await_completion())
                return (ExecutionResult.PENDING, None, None)
        if len(output_ui) > 0:
            # Enrich at output-processing time (not in the send path) so assets
            # are registered even when no client is connected, and the asset id
            # flows into ui_outputs and the cache alongside the raw entries.
            output_ui = enrich_output_with_assets(output_ui)
            ui_outputs[unique_id] = {
                "meta": {
                    "node_id": unique_id,
                    "display_node": display_node_id,
                    "parent_node": parent_node_id,
                    "real_node_id": real_node_id,
                },
                "output": output_ui,
            }
            if server.client_id is not None:
                server.send_sync(
                    "executed",
                    {
                        "node": unique_id,
                        "display_node": display_node_id,
                        "output": output_ui,
                        "prompt_id": prompt_id,
                    },
                    server.client_id,
                )
        if has_subgraph:
            cached_outputs = []
            new_node_ids = []
            new_output_ids = []
            new_output_links = []
            for new_graph, node_outputs in output_data:
                if new_graph is None:
                    cached_outputs.append((False, node_outputs))
                else:
                    for node_id, node_info in new_graph.items():
                        new_node_ids.append(node_id)
                        display_id = node_info.get("override_display_id", unique_id)
                        dynprompt.add_ephemeral_node(
                            node_id, node_info, unique_id, display_id
                        )
                        # Figure out if the newly created node is an output node.
                        # Separate names so class_type/class_def keep describing
                        # the node being executed (used by the error handler).
                        new_class_type = node_info["class_type"]
                        new_class_def = nodes.NODE_CLASS_MAPPINGS[new_class_type]
                        if (
                            hasattr(new_class_def, "OUTPUT_NODE")
                            and new_class_def.OUTPUT_NODE == True
                        ):
                            new_output_ids.append(node_id)
                    for node_output in node_outputs:
                        if is_link(node_output):
                            from_node_id, from_socket = node_output[0], node_output[1]
                            new_output_links.append((from_node_id, from_socket))
                    cached_outputs.append((True, node_outputs))
            new_node_ids = set(new_node_ids)
            for cache in caches.all:
                subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
                subcache.clean_unused()
            for node_id in new_output_ids:
                execution_list.add_node(node_id)
                execution_list.cache_link(node_id, unique_id)
            for link in new_output_links:
                execution_list.add_strong_link(link[0], link[1], unique_id)
            pending_subgraph_results[unique_id] = cached_outputs
            return (ExecutionResult.PENDING, None, None)

        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
        execution_list.cache_update(unique_id, cache_entry)
        await caches.outputs.set(unique_id, cache_entry)

    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": real_node_id,
        }

        return (ExecutionResult.FAILURE, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            for name, values in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in values]

        logging.error(f"!!! Exception during processing !!! {ex}")
        logging.error(traceback.format_exc())
        tips = ""

        if comfy.model_management.is_oom(ex):
            tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
            logging.info(
                "Memory summary:\n{}".format(
                    comfy.model_management.debug_memory_summary()
                )
            )
            logging.error("Got an OOM, unloading all loaded models.")
            comfy.model_management.unload_all_models()
        elif (
            isinstance(ex, RuntimeError)
            and ("mat1 and mat2 shapes" in str(ex))
            and "Sampler" in class_type
        ):
            tips = '\n\nTIPS: If you have any "Load CLIP" or "*CLIP Loader" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected.'

        error_details = {
            "node_id": real_node_id,
            "exception_message": "{}\n{}".format(ex, tips),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
        }

        return (ExecutionResult.FAILURE, error_details, ex)

    get_progress_state().finish_progress(unique_id)
    executed.add(unique_id)

    return (ExecutionResult.SUCCESS, None, None)
|
||
|
||
|
||
def _is_feedback_link(execution_list, to_node_id, from_node_id, from_socket):
    """Return True if the edge ``from_node_id:from_socket -> to_node_id`` is a
    bounded-feedback edge.

    The check looks the edge up in ``execution_list.feedback_links``, the
    per-consumer list of ``(source_node_id, source_socket)`` pairs recorded
    during graph construction.
    """
    recorded_edges = execution_list.feedback_links.get(to_node_id, [])
    edge = (from_node_id, from_socket)
    return edge in recorded_edges
|
||
|
||
|
||
def _get_feedback_default(dynprompt, from_node_id, from_socket):
    """Return the seed value for a bounded-feedback link.

    The ComfyUI type of output *from_socket* on node *from_node_id* is read
    from the node class's ``RETURN_TYPES`` and mapped through
    ``_FEEDBACK_DEFAULTS`` (e.g. ``0`` for INT, ``0.0`` for FLOAT).

    Falls back to ``0`` when:

    * the node or class cannot be resolved (including ``dynprompt=None``);
    * the socket index is out of range;
    * the declared type is not a type-name string (e.g. a list), which would
      otherwise raise TypeError on the dict lookup.
    """
    try:
        class_type = dynprompt.get_node(from_node_id)["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        return_types = class_def.RETURN_TYPES
    except Exception:
        return 0
    if not 0 <= from_socket < len(return_types):
        return 0
    output_type = return_types[from_socket]
    if not isinstance(output_type, str):
        return 0
    return _FEEDBACK_DEFAULTS.get(output_type, 0)
|
||
|
||
|
||
def _build_feedback_fns(
|
||
dynamic_prompt,
|
||
from_node_id,
|
||
from_socket,
|
||
to_node_id,
|
||
cfg_injections,
|
||
sampler_injections,
|
||
):
|
||
"""Try to build per-step update functions from a feedback edge.
|
||
|
||
Walks forward from the feedback-receiving node through intermediate
|
||
ComfyMathExpression nodes to find targets that need per-step callables.
|
||
Handles two target types:
|
||
|
||
* **CFGGuider** — populates *cfg_injections* keyed by guider node id
|
||
with a ``cfg_fn(step, total_steps)`` callable.
|
||
* **Sampler-producing nodes** (any node whose class_type starts with
|
||
"Sampler" except the iteration node itself) — populates
|
||
*sampler_injections* keyed by (sampler_node_id, param_name) with a
|
||
``param_fn(step, total_steps)`` callable.
|
||
|
||
Supports multi-hop chains like::
|
||
|
||
iteration_node ──(step_index)──→ MathExpr_A ──→ MathExpr_B ──→ CFGGuider
|
||
├─→ SamplerXXX
|
||
└─→ ...
|
||
"""
|
||
try:
|
||
prompt = dynamic_prompt.original_prompt
|
||
except Exception:
|
||
return
|
||
|
||
from simpleeval import simple_eval
|
||
|
||
from comfy_extras.nodes_math import MATH_FUNCTIONS
|
||
|
||
# ---- helpers ----
|
||
def _find_consumers(source_id, source_socket=0):
|
||
"""Return all nodes consuming output *source_socket* of *source_id*.
|
||
|
||
Each result is ``(consumer_id, class_type, input_name)``.
|
||
"""
|
||
consumers = []
|
||
for nid, n in prompt.items():
|
||
for iname, ival in n.get("inputs", {}).items():
|
||
if (
|
||
isinstance(ival, list)
|
||
and len(ival) == 2
|
||
and ival[0] == source_id
|
||
and ival[1] == source_socket
|
||
):
|
||
consumers.append((nid, n.get("class_type"), iname))
|
||
return consumers
|
||
|
||
def _is_sampler_target(class_type):
|
||
"""Return True if *class_type* is a sampler that accepts per-step
|
||
parameter updates via :class:`KSAMPLER.extra_options`."""
|
||
return (
|
||
class_type is not None
|
||
and "Sampler" in class_type
|
||
and class_type != "SamplerCustomAdvanced"
|
||
)
|
||
|
||
def _resolve_input_value(source_node_id, source_socket):
|
||
"""Try to resolve a non-feedback linked input to a static value.
|
||
|
||
First checks the source node's ``inputs`` dict (API format) for a
|
||
direct scalar value at the socket. Falls back to ``widgets_values``
|
||
positional mapping (workflow-file format). Returns the resolved
|
||
value, or None if unresolvable.
|
||
"""
|
||
try:
|
||
snode = prompt.get(str(source_node_id))
|
||
if snode is None:
|
||
return None
|
||
class_type = snode.get("class_type", "")
|
||
inputs = snode.get("inputs", {})
|
||
|
||
# API format: inputs are named — find the name that maps to
|
||
# *source_socket* via the class's INPUT_TYPES ordering.
|
||
cls = nodes.NODE_CLASS_MAPPINGS.get(class_type)
|
||
if cls is not None:
|
||
try:
|
||
input_types = cls.INPUT_TYPES()
|
||
except Exception:
|
||
input_types = {}
|
||
required = input_types.get("required", {})
|
||
req_names = list(required.keys())
|
||
if source_socket < len(req_names):
|
||
name = req_names[source_socket]
|
||
val = inputs.get(name)
|
||
if val is not None and not isinstance(val, list):
|
||
return val
|
||
|
||
# Fallback: widgets_values positional mapping (workflow-file format)
|
||
wv = snode.get("widgets_values", [])
|
||
if wv:
|
||
if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"):
|
||
if source_socket == 0 and len(wv) > 0:
|
||
return wv[0]
|
||
if (
|
||
cls is not None
|
||
and source_socket < len(req_names)
|
||
and source_socket < len(wv)
|
||
):
|
||
return wv[source_socket]
|
||
return None
|
||
except Exception:
|
||
return None
|
||
|
||
def _collect_extra_names(
|
||
node_id, feedback_from_node, feedback_from_socket, feedback_var_name
|
||
):
|
||
"""Collect non-feedback linked inputs from a MathExpression node
|
||
and resolve them to values. Returns dict of name→value."""
|
||
extra = {}
|
||
try:
|
||
snode = prompt.get(str(node_id))
|
||
if snode is None:
|
||
return extra
|
||
for inp_name, inp_val in snode.get("inputs", {}).items():
|
||
if not isinstance(inp_val, list) or len(inp_val) != 2:
|
||
continue
|
||
src_id, src_socket = inp_val[0], inp_val[1]
|
||
# Skip the feedback-linked input — that's the iteration variable
|
||
if src_id == str(feedback_from_node) and int(src_socket) == int(
|
||
feedback_from_socket
|
||
):
|
||
continue
|
||
# This is an additional linked input — try to resolve it
|
||
val = _resolve_input_value(src_id, src_socket)
|
||
if val is not None:
|
||
var_name = inp_name.rsplit(".", 1)[-1]
|
||
extra[var_name] = val
|
||
except Exception:
|
||
pass
|
||
return extra
|
||
|
||
# Each chain element is now (expression, feedback_var, extra_names_dict)
|
||
# ---- depth-first search ----
|
||
def _dfs(start_id, from_node, from_socket, chain):
|
||
"""Walk the MathExpr chain looking for any target node that needs
|
||
per-step updates. Returns a list of (target_type, target_id,
|
||
input_name, full_chain) tuples, where target_type is 'guider'
|
||
or 'sampler'."""
|
||
try:
|
||
node = dynamic_prompt.get_node(start_id)
|
||
except Exception:
|
||
return []
|
||
if node.get("class_type") != "ComfyMathExpression":
|
||
return []
|
||
|
||
expression = node.get("inputs", {}).get("expression", "")
|
||
if not expression or not expression.strip():
|
||
return []
|
||
|
||
var_name = None
|
||
for input_name, input_val in node.get("inputs", {}).items():
|
||
if (
|
||
isinstance(input_val, list)
|
||
and len(input_val) == 2
|
||
and input_val[0] == from_node
|
||
and input_val[1] == from_socket
|
||
):
|
||
var_name = input_name.rsplit(".", 1)[-1]
|
||
break
|
||
if var_name is None:
|
||
return []
|
||
|
||
# Collect additional (non-feedback) input values for this node
|
||
extra_names = _collect_extra_names(start_id, from_node, from_socket, var_name)
|
||
|
||
new_chain = chain + [(expression, var_name, extra_names)]
|
||
results = []
|
||
|
||
for cid, ctype, ciname in _find_consumers(start_id):
|
||
if ctype == "CFGGuider":
|
||
results.append(("guider", cid, None, new_chain))
|
||
elif _is_sampler_target(ctype):
|
||
results.append(("sampler", cid, ciname, new_chain))
|
||
elif ctype == "ComfyMathExpression":
|
||
results.extend(_dfs(cid, start_id, 0, new_chain))
|
||
return results
|
||
|
||
# ---- compose functions from discovered chains ----
|
||
for target_type, target_id, param_name, chain in _dfs(
|
||
to_node_id, from_node_id, from_socket, []
|
||
):
|
||
if not chain:
|
||
continue
|
||
|
||
def _make_fn(_chain):
|
||
def _fn(step, total_steps):
|
||
val = step
|
||
for expr_str, var, extra_names in _chain:
|
||
ctx = dict(extra_names) if extra_names else {}
|
||
ctx[var] = val
|
||
val = float(
|
||
simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS)
|
||
)
|
||
return val
|
||
|
||
return _fn
|
||
|
||
if target_type == "guider":
|
||
cfg_injections[target_id] = _make_fn(chain)
|
||
elif target_type == "sampler" and param_name:
|
||
sampler_injections[target_id] = sampler_injections.get(target_id, {})
|
||
sampler_injections[target_id][param_name] = _make_fn(chain)
|
||
|
||
|
||
class PromptExecutor:
    """Execute prompts node by node and report progress to the server.

    Holds the caches (``CacheSet``) across prompts plus the per-prompt
    ``status_messages``, ``success`` flag and ``history_result``.
    """

    def __init__(self, server, cache_type=False, cache_args=None):
        self.cache_args = cache_args
        self.cache_type = cache_type
        self.server = server
        self.reset()

    def reset(self):
        """Recreate the caches and clear per-prompt status."""
        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
        self.status_messages = []
        self.success = True

    def add_message(self, event, data: dict, broadcast: bool):
        """Record *event* (with a millisecond ``timestamp``) and forward it.

        The message is always appended to ``status_messages``; it is sent to
        the server only when a client id is set or *broadcast* is true.
        """
        data = {
            **data,
            "timestamp": int(time.time() * 1000),
        }
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

    def handle_execution_error(
        self, prompt_id, prompt, current_outputs, executed, error, ex
    ):
        """Report a failed or interrupted node to the frontend.

        An ``InterruptProcessingException`` is broadcast as
        ``execution_interrupted``; any other failure is sent as
        ``execution_error`` with the exception details carried in *error*.
        """
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": list(current_outputs),
            }
            self.add_message("execution_error", mes, broadcast=False)

    def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
        """Forward a prompt ``"start"`` / ``"end"`` event to every cache provider.

        Exceptions raised by an individual provider are logged and swallowed.
        """
        if not _has_cache_providers():
            return

        for provider in _get_cache_providers():
            try:
                if event == "start":
                    provider.on_prompt_start(prompt_id)
                elif event == "end":
                    provider.on_prompt_end(prompt_id)
            except Exception as e:
                _cache_logger.warning(
                    f"Cache provider {provider.__class__.__name__} error on {event}: {e}"
                )

    @staticmethod
    def _collect_feedback_injections(dynamic_prompt, execution_list):
        """Build per-step update functions for every bounded-feedback edge.

        Returns ``(cfg_injections, sampler_injections)``: guider node id ->
        ``cfg_fn`` and sampler node id -> ``{param_name: fn}``. An edge whose
        functions cannot be built is logged and skipped; the corresponding
        parameter then simply does not vary per step.
        """
        cfg_injections = {}
        sampler_injections = {}
        for to_node_id, edges in execution_list.feedback_links.items():
            for from_node_id, from_socket in edges:
                try:
                    _build_feedback_fns(
                        dynamic_prompt,
                        from_node_id,
                        from_socket,
                        to_node_id,
                        cfg_injections,
                        sampler_injections,
                    )
                except Exception:
                    logging.warning(
                        "Could not build bounded-feedback functions for edge %s[%s] -> %s",
                        from_node_id,
                        from_socket,
                        to_node_id,
                        exc_info=True,
                    )
        return cfg_injections, sampler_injections

    def _attach_feedback_fn(self, node_id, attr_name, value):
        """Set *attr_name* = *value* on the first value of *node_id*'s cached output.

        Used to hand a per-step update function to the guider / sampler
        object a node just produced. Best effort: failures are logged and
        never abort the prompt.
        """
        try:
            output = self.caches.outputs.get_local(node_id)
            if (
                output is not None
                and output.outputs is not None
                and len(output.outputs) > 0
                and len(output.outputs[0]) > 0
            ):
                setattr(output.outputs[0][0], attr_name, value)
        except Exception:
            logging.warning(
                "Could not attach bounded-feedback function %s to node %s",
                attr_name,
                node_id,
                exc_info=True,
            )

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
        """Synchronous wrapper around :meth:`execute_async`."""
        asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))

    async def execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
        """Execute *prompt* and publish progress / result messages.

        Args:
            prompt: API-format prompt dict (node id -> node definition).
            prompt_id: id attached to every status message.
            extra_data: optional per-request data. ``preview_method`` and
                ``client_id`` are read here and the dict is passed on to
                ``execute``. Defaults to a fresh empty dict per call.
            execute_outputs: node ids to schedule (defaults to none).

        Sets ``self.success`` and ``self.history_result`` as side effects.
        """
        # Fresh containers per call: the old ``{}`` / ``[]`` defaults were shared
        # between calls, and extra_data is handed on to ``execute``.
        if extra_data is None:
            extra_data = {}
        if execute_outputs is None:
            execute_outputs = []

        set_preview_method(extra_data.get("preview_method"))

        nodes.interrupt_processing(False)

        self.server.client_id = extra_data.get("client_id")

        self.status_messages = []
        self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

        ram_headroom = int(self.cache_args["ram"] * (1024**3))
        ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024**3))
        ram_release_callback = (
            self.caches.outputs.ram_release
            if self.cache_type == CacheType.RAM_PRESSURE
            else None
        )
        comfy.memory_management.set_ram_cache_release_state(
            ram_release_callback, ram_headroom
        )

        try:
            # Notify inside the try so every "start" is paired with the "end"
            # sent from the finally clause below.
            self._notify_prompt_lifecycle("start", prompt_id)
            with torch.inference_mode():
                dynamic_prompt = DynamicPrompt(prompt)
                reset_progress_state(prompt_id, dynamic_prompt)
                add_progress_handler(WebUIProgressHandler(self.server))
                is_changed_cache = IsChangedCache(
                    prompt_id, dynamic_prompt, self.caches.outputs
                )
                for cache in self.caches.all:
                    await cache.set_prompt(
                        dynamic_prompt, prompt.keys(), is_changed_cache
                    )
                    cache.clean_unused()

                node_ids = list(prompt.keys())
                cache_results = await asyncio.gather(
                    *(self.caches.outputs.get(node_id) for node_id in node_ids)
                )
                cached_nodes = [
                    node_id
                    for node_id, result in zip(node_ids, cache_results)
                    if result is not None
                ]

                comfy.model_management.cleanup_models_gc()
                self.add_message(
                    "execution_cached",
                    {"nodes": cached_nodes, "prompt_id": prompt_id},
                    broadcast=False,
                )
                pending_subgraph_results = {}
                pending_async_nodes = {}  # TODO - Unify this with pending_subgraph_results
                ui_node_outputs = {}
                executed = set()
                execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
                current_outputs = self.caches.outputs.all_node_ids()
                for node_id in list(execute_outputs):
                    execution_list.add_node(node_id)

                # Per-step update functions for feedback chains that pass through
                # ComfyMathExpression -> CFGGuider / SamplerXXX. They are attached
                # to the guider / sampler object after its node executes so the
                # sampler can vary parameters (cfg, s_noise, ...) with step_index.
                feedback_cfg, feedback_sampler = self._collect_feedback_injections(
                    dynamic_prompt, execution_list
                )

                while not execution_list.is_empty():
                    node_id, error, ex = await execution_list.stage_node_execution()
                    if error is not None:
                        self.handle_execution_error(
                            prompt_id,
                            dynamic_prompt.original_prompt,
                            current_outputs,
                            executed,
                            error,
                            ex,
                        )
                        break

                    assert node_id is not None, (
                        "Node ID should not be None at this point"
                    )
                    result, error, ex = await execute(
                        self.server,
                        dynamic_prompt,
                        self.caches,
                        node_id,
                        extra_data,
                        executed,
                        prompt_id,
                        execution_list,
                        pending_subgraph_results,
                        pending_async_nodes,
                        ui_node_outputs,
                    )
                    self.success = result != ExecutionResult.FAILURE
                    if result == ExecutionResult.FAILURE:
                        self.handle_execution_error(
                            prompt_id,
                            dynamic_prompt.original_prompt,
                            current_outputs,
                            executed,
                            error,
                            ex,
                        )
                        break
                    elif result == ExecutionResult.PENDING:
                        execution_list.unstage_node_execution()
                    else:  # result == ExecutionResult.SUCCESS:
                        if node_id in feedback_cfg:
                            self._attach_feedback_fn(
                                node_id, "_feedback_cfg_fn", feedback_cfg[node_id]
                            )
                        if node_id in feedback_sampler:
                            self._attach_feedback_fn(
                                node_id, "_feedback_param_fns", feedback_sampler[node_id]
                            )
                        execution_list.complete_node_execution()

                    if self.cache_type == CacheType.RAM_PRESSURE:
                        ram_release_callback(ram_inactive_headroom)
                        ram_shortfall = ram_headroom - psutil.virtual_memory().available
                        freed = comfy.model_management.free_pins(
                            ram_shortfall + 512 * (1024**2)
                        )
                        if freed < ram_shortfall:
                            if freed > 64 * (1024**2):
                                # AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
                                time.sleep(0.05)
                            ram_release_callback(ram_headroom, free_active=True)
                else:
                    # Only execute when the while-loop ends without break
                    # Send cached UI for intermediate output nodes that weren't executed
                    for node_id in dynamic_prompt.all_node_ids():
                        if node_id in executed:
                            continue
                        if not _is_intermediate_output(dynamic_prompt, node_id):
                            continue
                        cached = await self.caches.outputs.get(node_id)
                        if cached is not None:
                            display_node_id = dynamic_prompt.get_display_node_id(
                                node_id
                            )
                            _send_cached_ui(
                                self.server,
                                node_id,
                                display_node_id,
                                cached,
                                prompt_id,
                                ui_node_outputs,
                            )
                    self.add_message(
                        "execution_success", {"prompt_id": prompt_id}, broadcast=False
                    )

                ui_outputs = {}
                meta_outputs = {}
                for node_id, ui_info in ui_node_outputs.items():
                    ui_outputs[node_id] = ui_info["output"]
                    meta_outputs[node_id] = ui_info["meta"]
                self.history_result = {
                    "outputs": ui_outputs,
                    "meta": meta_outputs,
                }
                self.server.last_node_id = None
                if comfy.model_management.DISABLE_SMART_MEMORY:
                    comfy.model_management.unload_all_models()
        finally:
            comfy.memory_management.set_ram_cache_release_state(None, 0)
            self._notify_prompt_lifecycle("end", prompt_id)
|
||
|
||
|
||
def _is_bounded_feedback_cycle(prompt, visiting, unique_id):
|
||
"""Check whether a detected dependency cycle is a *bounded* feedback loop.
|
||
|
||
A cycle is bounded when at least one node in it declares ``BOUNDED_FEEDBACK``,
|
||
i.e. the node has a finite internal iteration whose step / index variable
|
||
feeds back upstream to control its own parameters (e.g. a sampler's
|
||
``step_index`` flowing through a math expression to set ``cfg``).
|
||
|
||
Because the iteration is bounded (N steps, then terminates) this isn't an
|
||
infinite cycle — the DAG can safely allow it and the execution engine will
|
||
break the feedback edge by seeding the iteration output with an initial value.
|
||
"""
|
||
cycle_nodes = visiting[visiting.index(unique_id) :] + [unique_id]
|
||
for node_id in cycle_nodes:
|
||
if node_id not in prompt:
|
||
continue
|
||
class_type = prompt[node_id].get("class_type")
|
||
if class_type is None:
|
||
continue
|
||
obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type)
|
||
if obj_class is None:
|
||
continue
|
||
bounded = getattr(obj_class, "BOUNDED_FEEDBACK", None)
|
||
if bounded:
|
||
return True
|
||
return False
|
||
|
||
|
||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
    """Recursively validate node *item* and every node it is linked to.

    Results are memoized in *validated* as ``(is_valid, errors, node_id)``
    tuples; the tuple for *item* is returned. *visiting* is the list of node
    ids on the current recursion path and is used to detect dependency
    cycles. A cycle containing a ``BOUNDED_FEEDBACK`` node is accepted (see
    ``_is_bounded_feedback_cycle``); any other cycle marks every node in it
    invalid with a ``dependency_cycle`` error.

    Side effect: widget values in ``prompt[item]["inputs"]`` are rewritten in
    place (``__value__`` wrappers unwrapped, INT/FLOAT/STRING/BOOLEAN coerced).
    """
    if visiting is None:
        visiting = []

    unique_id = item
    # Already validated (or provisionally marked during cycle handling).
    if unique_id in validated:
        return validated[unique_id]

    # Reaching a node that is already on the recursion path means a cycle.
    if unique_id in visiting:
        cycle_path_nodes = visiting[visiting.index(unique_id) :] + [unique_id]
        # Order-preserving de-duplication of the cycle's node ids.
        cycle_nodes = list(dict.fromkeys(cycle_path_nodes))

        # A bounded feedback cycle is one where at least one node in the cycle
        # declares BOUNDED_FEEDBACK — meaning its internal iteration is finite
        # and its iteration output(s) can safely flow back upstream without
        # causing an infinite loop (e.g. a sampler's step_index controlling cfg).
        if _is_bounded_feedback_cycle(prompt, visiting, unique_id):
            # Mark the repeated node as valid and continue the traversal on
            # other branches. The execution layer handles the feedback edge
            # by breaking it and seeding the iteration output with an initial
            # value (e.g. step_index = 0).
            # NOTE(review): this is a provisional result for a node whose own
            # validation is still in progress; other branches that reach it
            # before that finishes get this (True, ...) from the memo check above.
            validated[unique_id] = (True, [], unique_id)
            return validated[unique_id]

        cycle_path = " -> ".join(
            f"{node_id} ({prompt[node_id]['class_type']})"
            for node_id in cycle_path_nodes
        )
        # Every node participating in an unbounded cycle gets the same error.
        for node_id in cycle_nodes:
            validated[node_id] = (
                False,
                [
                    {
                        "type": "dependency_cycle",
                        "message": "Dependency cycle detected",
                        "details": cycle_path,
                        "extra_info": {
                            "node_id": node_id,
                            "cycle_nodes": cycle_nodes,
                        },
                    }
                ],
                node_id,
            )
        return validated[unique_id]

    inputs = prompt[unique_id]["inputs"]
    class_type = prompt[unique_id]["class_type"]
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    errors = []
    valid = True

    v3_data = None
    validate_function_inputs = []
    validate_has_kwargs = False
    # V3 nodes use a `validate_inputs` override; legacy nodes use VALIDATE_INPUTS.
    if issubclass(obj_class, _ComfyNodeInternal):
        obj_class: _io._ComfyNodeBaseInternal
        class_inputs = obj_class.INPUT_TYPES()
        class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
        validate_function_name = "validate_inputs"
        validate_function = first_real_override(obj_class, validate_function_name)
    else:
        class_inputs = obj_class.INPUT_TYPES()
        validate_function_name = "VALIDATE_INPUTS"
        validate_function = getattr(obj_class, validate_function_name, None)
    # Inputs named by the custom validator skip the built-in min/max/combo checks below.
    if validate_function is not None:
        argspec = inspect.getfullargspec(validate_function)
        validate_function_inputs = argspec.args
        validate_has_kwargs = argspec.varkw is not None
    # Input name -> RETURN_TYPES entry of the linked upstream output.
    received_types = {}

    valid_inputs = set(class_inputs.get("required", {})).union(
        set(class_inputs.get("optional", {}))
    )

    for x in valid_inputs:
        input_type, input_category, extra_info = get_input_info(
            obj_class, x, class_inputs
        )
        assert extra_info is not None
        if x not in inputs:
            if input_category == "required":
                details = f"{x}" if not v3_data else x.split(".")[-1]
                error = {
                    "type": "required_input_missing",
                    "message": "Required input is missing",
                    "details": details,
                    "extra_info": {"input_name": x},
                }
                errors.append(error)
            continue

        val = inputs[x]
        info = (input_type, extra_info)
        # A list value is a link: [source_node_id, output_slot_index].
        if isinstance(val, list):
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                    },
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]["class_type"]
            # NOTE: `r` is reused below for the recursive validation result.
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            received_type = r[val[1]]
            received_types[x] = received_type
            if (
                "input_types" not in validate_function_inputs
                and not validate_node_input(received_type, input_type)
            ):
                details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val,
                    },
                }
                errors.append(error)
                continue
            try:
                # Put this node on the path so the callee can detect a cycle back here.
                visiting.append(unique_id)
                try:
                    r = await validate_inputs(
                        prompt_id, prompt, o_id, validated, visiting
                    )
                finally:
                    visiting.pop()
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [
                    {
                        "type": "exception_during_inner_validation",
                        "message": "Exception when validating inner node",
                        "details": str(ex),
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "exception_message": str(ex),
                            "exception_type": exception_type,
                            "traceback": traceback.format_tb(tb),
                            "linked_node": val,
                        },
                    }
                ]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            try:
                # Unwraps values wrapped in __value__ key or typed wrapper.
                # This is used to pass list widget values to execution,
                # as by default list value is reserved to represent the
                # connection between nodes.
                if isinstance(val, dict):
                    if "__value__" in val:
                        val = val["__value__"]
                        inputs[x] = val

                # Coerce primitive widget values in place.
                if input_type == "INT":
                    val = int(val)
                    inputs[x] = val
                if input_type == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if input_type == "STRING":
                    val = str(val)
                    inputs[x] = val
                if input_type == "BOOLEAN":
                    val = bool(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {input_type} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex),
                    },
                }
                errors.append(error)
                continue

            # Built-in range / combo checks apply only when no custom validator takes this input.
            if x not in validate_function_inputs and not validate_has_kwargs:
                if "min" in extra_info and val < extra_info["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(
                            val, extra_info["min"]
                        ),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        },
                    }
                    errors.append(error)
                    continue
                if "max" in extra_info and val > extra_info["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(
                            val, extra_info["max"]
                        ),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        },
                    }
                    errors.append(error)
                    continue

                # Combo inputs: legacy list-typed inputs or V3 Combo with "options".
                if isinstance(input_type, list) or input_type == io.Combo.io_type:
                    if input_type == io.Combo.io_type:
                        combo_options = extra_info.get("options", [])
                    else:
                        combo_options = input_type
                    is_multiselect = extra_info.get("multiselect", False)
                    if is_multiselect and isinstance(val, list):
                        invalid_vals = [v for v in val if v not in combo_options]
                    else:
                        invalid_vals = [val] if val not in combo_options else []
                    if invalid_vals:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(combo_options) > 20:
                            list_info = f"(list of length {len(combo_options)})"
                            input_config = None
                        else:
                            list_info = str(combo_options)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            },
                        }
                        errors.append(error)
                        continue

    # Run the node's own validation hook, passing only the arguments it declares
    # (or everything when it accepts **kwargs).
    if len(validate_function_inputs) > 0 or validate_has_kwargs:
        input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs or validate_has_kwargs:
                input_filtered[x] = input_data_all[x]
        if "input_types" in validate_function_inputs:
            input_filtered["input_types"] = [received_types]

        ret = await _async_map_node_over_list(
            prompt_id,
            unique_id,
            obj_class,
            input_filtered,
            validate_function_name,
            v3_data=v3_data,
        )
        ret = await resolve_map_node_over_list_results(ret)
        # NOTE: a failing hook result is recorded once per filtered input name;
        # the `continue` only advances the inner loop.
        for x in input_filtered:
            for i, r in enumerate(ret):
                if r is not True and not isinstance(r, ExecutionBlocker):
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                        },
                    }
                    errors.append(error)
                    continue

    ret = validated.get(unique_id, (True, [], unique_id))
    # Recursive cycle detection may have already populated an error on us. Join it.
    ret = (
        ret[0] and valid is True and not errors,
        ret[1] + [error for error in errors if error not in ret[1]],
        unique_id,
    )

    validated[unique_id] = ret
    return ret
|
||
|
||
|
||
def full_type_name(klass):
    """Return the dotted name ``module.QualName`` of class *klass*.

    Builtins are returned by bare qualified name (``ValueError``). So are
    classes whose ``__module__`` is None: such classes exist, and
    concatenating None would raise TypeError inside the exception handlers
    that call this helper, masking the original error.

    >>> full_type_name(ValueError)
    'ValueError'
    """
    module = klass.__module__
    if module is None or module == "builtins":
        return klass.__qualname__
    return module + "." + klass.__qualname__
|
||
|
||
|
||
async def validate_prompt(
    prompt_id, prompt, partial_execution_list: Union[list[str], None]
):
    """Validate *prompt* and determine which output nodes can run.

    Args:
        prompt_id: id forwarded to node validation hooks.
        prompt: API-format prompt dict (node id -> node definition).
        partial_execution_list: when not None, only output nodes whose ids
            are in this list are considered.

    Returns:
        ``(ok, error, good_outputs, node_errors)``. *ok* is True when at
        least one output validated; *error* is a prompt-level error dict or
        None; *good_outputs* lists the valid output node ids; *node_errors*
        maps failing node ids to ``{"errors", "dependent_outputs",
        "class_type"}``.
    """
    outputs = set()
    for x in prompt:
        node_data = prompt[x]
        if "class_type" not in node_data:
            node_title = node_data.get("_meta", {}).get("title")
            error = {
                "type": "missing_node_type",
                "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.",
                "details": f"Node ID '#{x}'",
                "extra_info": {
                    "node_id": x,
                    "class_type": None,
                    "node_title": node_title,
                },
            }
            return (False, error, [], {})

        class_type = node_data["class_type"]
        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
        if class_ is None:
            node_title = node_data.get("_meta", {}).get("title", class_type)
            error = {
                "type": "missing_node_type",
                "message": f"Node '{node_title}' not found. The custom node may not be installed.",
                "details": f"Node ID '#{x}'",
                "extra_info": {
                    "node_id": x,
                    "class_type": class_type,
                    "node_title": node_title,
                },
            }
            return (False, error, [], {})

        if getattr(class_, "OUTPUT_NODE", False) is True:
            if partial_execution_list is None or x in partial_execution_list:
                outputs.add(x)

    if not outputs:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {},
        }
        return (False, error, [], {})

    good_outputs = set()
    failed_outputs = []  # (output_id, reasons) for every output that failed
    node_errors = {}
    validated = {}  # shared memo across outputs; see validate_inputs
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = await validate_inputs(prompt_id, prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [
                {
                    "type": "exception_during_validation",
                    "message": "Exception when validating node",
                    "details": str(ex),
                    "extra_info": {
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                    },
                }
            ]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
            continue

        logging.error(f"Failed to validate prompt for output {o}:")
        if reasons:
            logging.error("* (prompt):")
            for reason in reasons:
                logging.error(f" - {reason['message']}: {reason['details']}")
        failed_outputs.append((o, reasons))
        for node_id, result in validated.items():
            node_valid, node_reasons = result[0], result[1]
            # If a node upstream has errors, the nodes downstream will also
            # be reported as invalid, but there will be no errors attached.
            # So don't return those nodes as having errors in the response.
            if node_valid is True or not node_reasons:
                continue
            if node_id not in node_errors:
                class_type = prompt[node_id]["class_type"]
                node_errors[node_id] = {
                    "errors": node_reasons,
                    "dependent_outputs": [],
                    "class_type": class_type,
                }
                logging.error(f"* {class_type} {node_id}:")
                for reason in node_reasons:
                    logging.error(
                        f" - {reason['message']}: {reason['details']}"
                    )
            node_errors[node_id]["dependent_outputs"].append(o)
        logging.error("Output will be ignored")

    if not good_outputs:
        errors_list = "\n".join(
            f"{error['message']}: {error['details']}"
            for _, output_reasons in failed_outputs
            for error in output_reasons
        )

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {},
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)
|
||
|
||
|
||
# History-size threshold used by PromptQueue.task_done when evicting the
# oldest finished prompts from PromptQueue.history.
MAXIMUM_HISTORY_SIZE = 10_000
|
||
|
||
|
||
class PromptQueue:
    """Thread-safe priority queue of prompts plus a bounded execution history.

    ``queue`` is a ``heapq`` heap ordered by the items themselves. Items taken
    with :meth:`get` are tracked in ``currently_running`` until
    :meth:`task_done` moves them into ``history``. Changes to the queue or the
    running set call ``server.queue_updated()``.
    """

    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        # Built on self.mutex, so holding either one guards the same state.
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        self.flags = {}

    def put(self, item):
        """Push *item* onto the heap and wake one waiting :meth:`get`."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
        """Pop the smallest queued item and mark it as running.

        Blocks until an item is available. With a *timeout*, returns None if
        the queue is still empty after a wait. Otherwise returns
        ``(item, task_id)``, where *task_id* is later passed to :meth:`task_done`.
        """
        with self.not_empty:
            while not self.queue:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and not self.queue:
                    return None
            item = heapq.heappop(self.queue)
            task_id = self.task_counter
            self.currently_running[task_id] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, task_id)

    class ExecutionStatus(NamedTuple):
        status_str: Literal["success", "error"]
        completed: bool
        messages: List[str]

    def task_done(
        self,
        item_id,
        history_result,
        status: Optional["PromptQueue.ExecutionStatus"],
        process_item=None,
    ):
        """Move running task *item_id* into ``history`` (keyed by ``prompt[1]``).

        Args:
            item_id: task id returned by :meth:`get`.
            history_result: dict merged into the new history entry.
            status: final status, stored as a plain dict (or None).
            process_item: optional callable applied to the stored prompt first.
        """
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            # Evict the oldest entries (dicts keep insertion order) before
            # inserting, so the history does not grow past MAXIMUM_HISTORY_SIZE.
            # The old `>` check let it settle at MAXIMUM_HISTORY_SIZE + 1.
            while self.history and len(self.history) >= MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(status._asdict())

            if process_item is not None:
                prompt = process_item(prompt)

            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": {},
                "status": status_dict,
            }
            self.history[prompt[1]].update(history_result)
            self.server.queue_updated()

    # Note: slow
    def get_current_queue(self):
        """Return ``(running_items, queued_items)``; the queued list is deep-copied."""
        with self.mutex:
            running = list(self.currently_running.values())
            return (running, copy.deepcopy(self.queue))

    # read-safe as long as queue items are immutable
    def get_current_queue_volatile(self):
        """Return ``(running_items, queued_items)`` using shallow copies only."""
        with self.mutex:
            running = list(self.currently_running.values())
            queued = copy.copy(self.queue)
            return (running, queued)

    def get_tasks_remaining(self):
        """Return the number of queued plus running tasks."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Drop every queued (not yet running) item."""
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """Remove the first queued item for which ``function(item)`` is true.

        Returns True if an item was removed, False otherwise. The server is
        notified exactly once per removal.
        """
        with self.mutex:
            for index, item in enumerate(self.queue):
                if function(item):
                    self.queue.pop(index)
                    heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1, map_function=None):
        """Return history entries as a ``{prompt_id: entry}`` dict.

        With *prompt_id*, return only that entry (deep-copied unless
        *map_function* is given), or ``{}`` if unknown. Otherwise return entries
        in insertion order starting at *offset*; a negative offset together
        with *max_items* selects the last *max_items* entries. *map_function*,
        when given, is applied to each returned entry.
        """
        with self.mutex:
            if prompt_id is None:
                out = {}
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for index, key in enumerate(self.history):
                    if index < offset:
                        continue
                    entry = self.history[key]
                    if map_function is not None:
                        entry = map_function(entry)
                    out[key] = entry
                    if max_items is not None and len(out) >= max_items:
                        break
                return out
            if prompt_id in self.history:
                entry = self.history[prompt_id]
                if map_function is None:
                    entry = copy.deepcopy(entry)
                else:
                    entry = map_function(entry)
                return {prompt_id: entry}
            return {}

    def wipe_history(self):
        """Remove every history entry."""
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        """Remove one history entry; unknown ids are ignored."""
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        """Store a flag and wake one thread waiting in :meth:`get`."""
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        """Return the current flags; with *reset* the stored flags are cleared."""
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            return self.flags.copy()