diff --git a/execution.py b/execution.py index c265cbb49..69e21cc4b 100644 --- a/execution.py +++ b/execution.py @@ -1,33 +1,42 @@ +import asyncio import copy import heapq import inspect import logging -import psutil import sys import threading import time import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union -import asyncio +import comfy_aimdo.model_vbar +import psutil import torch -from comfy.cli_args import args import comfy.memory_management import comfy.model_management import comfy.model_prefetch -import comfy_aimdo.model_vbar - -from latent_preview import set_preview_method import nodes +from comfy.cli_args import args +from comfy_api.internal import ( + _ComfyNodeInternal, + _NodeOutputInternal, + first_real_override, + is_class, + make_locked_method_func, +) +from comfy_api.latest import _io, io +from comfy_execution.asset_enrichment import enrich_output_with_assets +from comfy_execution.cache_provider import _get_cache_providers, _has_cache_providers +from comfy_execution.cache_provider import _logger as _cache_logger from comfy_execution.caching import ( BasicCache, CacheKeySetID, CacheKeySetInputSignature, - NullCache, HierarchicalCache, LRUCache, + NullCache, RAMPressureCache, ) from comfy_execution.graph import ( @@ -37,13 +46,15 @@ from comfy_execution.graph import ( get_input_info, ) from comfy_execution.graph_utils import GraphBuilder, is_link -from comfy_execution.validation import validate_node_input -from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.progress import ( + WebUIProgressHandler, + add_progress_handler, + get_progress_state, + reset_progress_state, +) from comfy_execution.utils import CurrentNodeContext -from comfy_execution.asset_enrichment import enrich_output_with_assets -from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func -from comfy_api.latest import io, _io -from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +from comfy_execution.validation import validate_node_input +from latent_preview import set_preview_method class ExecutionResult(Enum): @@ -51,11 +62,15 @@ class ExecutionResult(Enum): FAILURE = 1 PENDING = 2 + class DuplicateNodeError(Exception): pass + class IsChangedCache: - def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): + def __init__( + self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache + ): self.prompt_id = prompt_id self.dynprompt = dynprompt self.outputs_cache = outputs_cache @@ -70,7 +85,10 @@ class IsChangedCache: class_def = nodes.NODE_CLASS_MAPPINGS[class_type] has_is_changed = False is_changed_name = None - if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: + if ( + issubclass(class_def, _ComfyNodeInternal) + and first_real_override(class_def, "fingerprint_inputs") is not None + ): has_is_changed = True is_changed_name = "fingerprint_inputs" elif hasattr(class_def, "IS_CHANGED"): @@ -85,11 +103,22 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data( + node["inputs"], class_def, node_id, None + ) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) + is_changed = await _async_map_node_over_list( + self.prompt_id, + node_id, + class_def, + input_data_all, + is_changed_name, + v3_data=v3_data, + ) is_changed = await resolve_map_node_over_list_results(is_changed) - node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + node["is_changed"] = [ + None if isinstance(x, ExecutionBlocker) else x for x in is_changed + ] except Exception as e: logging.warning("WARNING: {}".format(e)) node["is_changed"] = float("NaN") @@ -145,15 +174,21 @@ class CacheSet: # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True) + self.outputs = HierarchicalCache( + CacheKeySetInputSignature, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True) + self.outputs = LRUCache( + CacheKeySetInputSignature, max_size=cache_size, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True) + self.outputs = RAMPressureCache( + CacheKeySetInputSignature, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -166,37 +201,51 @@ class CacheSet: } return result + SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): + +def get_input_data( + inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={} +): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} hidden_inputs_v3 = {} valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs( + valid_inputs, inputs + ) input_data_all = {} missing_keys = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) + def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) - if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): + + if is_link(input_data) and ( + not input_info or not input_info.get("rawLink", False) + ): input_unique_id = input_data[0] output_index = input_data[1] if execution_list is None: mark_missing() - continue # This might be a lazily-evaluated input + continue # This might be a lazily-evaluated input cached = execution_list.get_cache(input_unique_id, unique_id) if cached is None or cached.outputs is None: # If this is a bounded-feedback link whose source hasn't # executed yet, supply the type-appropriate initial value # (e.g. step_index=0) so the feedback chain can evaluate # before the iteration-producing node runs. - if _is_feedback_link(execution_list, unique_id, input_unique_id, output_index): - default_val = _get_feedback_default(dynprompt, input_unique_id, output_index) + if _is_feedback_link( + execution_list, unique_id, input_unique_id, output_index + ): + default_val = _get_feedback_default( + dynprompt, input_unique_id, output_index + ) obj = default_val if isinstance(obj, (int, float, bool, str)): obj = (obj,) @@ -220,29 +269,41 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if is_v3: if hidden is not None: if io.Hidden.prompt.name in hidden: - hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + hidden_inputs_v3[io.Hidden.prompt] = ( + dynprompt.get_original_prompt() if dynprompt is not None else {} + ) if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt if io.Hidden.extra_pnginfo.name in hidden: - hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get( + "extra_pnginfo", None + ) if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id if io.Hidden.auth_token_comfy_org.name in hidden: - hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get( + "auth_token_comfy_org", None + ) if io.Hidden.api_key_comfy_org.name in hidden: - hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get( + "api_key_comfy_org", None + ) if io.Hidden.comfy_usage_source.name in hidden: - hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None) + hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get( + "comfy_usage_source", None + ) else: if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + input_data_all[x] = [ + dynprompt.get_original_prompt() if dynprompt is not None else {} + ] if h[x] == "DYNPROMPT": input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": - input_data_all[x] = [extra_data.get('extra_pnginfo', None)] + input_data_all[x] = [extra_data.get("extra_pnginfo", None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] if h[x] == "AUTH_TOKEN_COMFY_ORG": @@ -254,7 +315,9 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= v3_data["hidden_inputs"] = hidden_inputs_v3 return input_data_all, missing_keys, v3_data -map_node_over_list = None #Don't hook this please + +map_node_over_list = None # Don't hook this please + async def resolve_map_node_over_list_results(results): remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()] @@ -268,7 +331,18 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + +async def _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + func, + allow_interrupt=False, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, +): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -282,6 +356,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] + async def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: nodes.before_node_execution() @@ -299,7 +374,9 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if pre_execute_cb is not None and index is not None: pre_execute_cb(index) # V3 - if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): + if isinstance(obj, _ComfyNodeInternal) or ( + is_class(obj) and issubclass(obj, _ComfyNodeInternal) + ): # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj @@ -318,10 +395,14 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f else: f = getattr(obj, func) if inspect.iscoroutinefunction(f): + async def async_wrapper(f, prompt_id, unique_id, list_index, args): with CurrentNodeContext(prompt_id, unique_id, list_index): return await f(**args) - task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) + + task = asyncio.create_task( + async_wrapper(f, prompt_id, unique_id, index, args=inputs) + ) # Give the task a chance to execute without yielding await asyncio.sleep(0) if task.done(): @@ -368,14 +449,36 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) - has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) + +async def get_output_data( + prompt_id, + unique_id, + obj, + input_data_all, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, +): + return_values = await _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + obj.FUNCTION, + allow_interrupt=True, + execution_block_cb=execution_block_cb, + pre_execute_cb=pre_execute_cb, + v3_data=v3_data, + ) + has_pending_task = any( + isinstance(r, asyncio.Task) and not r.done() for r in return_values + ) if has_pending_task: return return_values, {}, False, has_pending_task output, ui, has_subgraph = get_output_from_returns(return_values, obj) return output, ui, has_subgraph, False + def get_output_from_returns(return_values, obj): results = [] uis = [] @@ -384,17 +487,17 @@ def get_output_from_returns(return_values, obj): for i in range(len(return_values)): r = return_values[i] if isinstance(r, dict): - if 'ui' in r: - uis.append(r['ui']) - if 'expand' in r: + if "ui" in r: + uis.append(r["ui"]) + if "expand" in r: # Perform an expansion, but do not append results has_subgraph = True - new_graph = r['expand'] + new_graph = r["expand"] result = r.get("result", None) if isinstance(result, ExecutionBlocker): result = tuple([result] * len(obj.RETURN_TYPES)) subgraph_results.append((new_graph, result)) - elif 'result' in r: + elif "result" in r: result = r.get("result", None) if isinstance(result, ExecutionBlocker): result = tuple([result] * len(obj.RETURN_TYPES)) @@ -412,12 +515,16 @@ def get_output_from_returns(return_values, obj): new_graph = r.expand result = r.result if r.block_execution is not None: - result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + result = tuple( + [ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES) + ) subgraph_results.append((new_graph, result)) elif r.result is not None: result = r.result if r.block_execution is not None: - result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + result = tuple( + [ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES) + ) results.append(result) subgraph_results.append((None, result)) else: @@ -441,6 +548,7 @@ def get_output_from_returns(return_values, obj): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui, has_subgraph + def format_value(x): if x is None: return None @@ -449,31 +557,56 @@ def format_value(x): else: return str(x) + def _is_intermediate_output(dynprompt, node_id): class_type = dynprompt.get_node(node_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) + return getattr(class_def, "HAS_INTERMEDIATE_OUTPUT", False) def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): if server.client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executed", + { + "node": node_id, + "display_node": display_node_id, + "output": cached_ui.get("output", None), + "prompt_id": prompt_id, + }, + server.client_id, + ) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): + +async def execute( + server, + dynprompt, + caches, + current_item, + extra_data, + executed, + prompt_id, + execution_list, + pending_subgraph_results, + pending_async_nodes, + ui_outputs, +): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) parent_node_id = dynprompt.get_parent_node_id(unique_id) - inputs = dynprompt.get_node(unique_id)['inputs'] - class_type = dynprompt.get_node(unique_id)['class_type'] + inputs = dynprompt.get_node(unique_id)["inputs"] + class_type = dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) + _send_cached_ui( + server, unique_id, display_node_id, cached, prompt_id, ui_outputs + ) get_progress_state().finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -493,7 +626,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: results.append(r) del pending_async_nodes[unique_id] - output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) + output_data, output_ui, has_subgraph = get_output_from_returns( + results, class_def + ) elif unique_id in pending_subgraph_results: cached_results = pending_subgraph_results[unique_id] resolved_outputs = [] @@ -505,7 +640,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_cached = execution_list.get_cache(source_node, unique_id) + node_cached = execution_list.get_cache( + source_node, unique_id + ) for o in node_cached.outputs[source_output]: resolved_output.append(o) @@ -518,10 +655,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data( + inputs, class_def, unique_id, execution_list, dynprompt, extra_data + ) if server.client_id is not None: server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executing", + { + "node": unique_id, + "display_node": display_node_id, + "prompt_id": prompt_id, + }, + server.client_id, + ) obj = await caches.objects.get(unique_id) if obj is None: @@ -529,19 +676,38 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): - lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None + lazy_status_present = ( + first_real_override(class_def, "check_lazy_status") is not None + ) else: - lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + lazy_status_present = ( + getattr(obj, "check_lazy_status", None) is not None + ) if lazy_status_present: # for check_lazy_status, the returned data should include the original key of the input v3_data_lazy = v3_data.copy() v3_data_lazy["create_dynamic_tuple"] = True - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) - required_inputs = await resolve_map_node_over_list_results(required_inputs) - required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if isinstance(x,str) and ( - x not in input_data_all or x in missing_keys - )] + required_inputs = await _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + "check_lazy_status", + allow_interrupt=True, + v3_data=v3_data_lazy, + ) + required_inputs = await resolve_map_node_over_list_results( + required_inputs + ) + required_inputs = set( + sum([r for r in required_inputs if isinstance(r, list)], []) + ) + required_inputs = [ + x + for x in required_inputs + if isinstance(x, str) + and (x not in input_data_all or x in missing_keys) + ] if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) @@ -554,7 +720,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "node_id": unique_id, "node_type": class_type, "executed": list(executed), - "exception_message": f"Execution Blocked: {block.message}", "exception_type": "ExecutionBlocked", "traceback": [], @@ -565,12 +730,26 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, return ExecutionBlocker(None) else: return block + def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) try: - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + ( + output_data, + output_ui, + has_subgraph, + has_pending_tasks, + ) = await get_output_data( + prompt_id, + unique_id, + obj, + input_data_all, + execution_block_cb=execution_block_cb, + pre_execute_cb=pre_execute_cb, + v3_data=v3_data, + ) finally: if comfy.memory_management.aimdo_enabled: if args.verbose == "DEBUG": @@ -582,10 +761,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) + async def await_completion(): tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -600,10 +781,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "parent_node": parent_node_id, "real_node_id": real_node_id, }, - "output": output_ui + "output": output_ui, } if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executed", + { + "node": unique_id, + "display_node": display_node_id, + "output": output_ui, + "prompt_id": prompt_id, + }, + server.client_id, + ) if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -617,15 +807,23 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) - dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + dynprompt.add_ephemeral_node( + node_id, node_info, unique_id, display_id + ) # Figure out if the newly created node is an output node class_type = node_info["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + if ( + hasattr(class_def, "OUTPUT_NODE") + and class_def.OUTPUT_NODE == True + ): new_output_ids.append(node_id) for i in range(len(node_outputs)): if is_link(node_outputs[i]): - from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + from_node_id, from_socket = ( + node_outputs[i][0], + node_outputs[i][1], + ) new_output_links.append((from_node_id, from_socket)) cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) @@ -668,18 +866,26 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." - logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary())) + logging.info( + "Memory summary:\n{}".format( + comfy.model_management.debug_memory_summary() + ) + ) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() - elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type: - tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected." + elif ( + isinstance(ex, RuntimeError) + and ("mat1 and mat2 shapes" in str(ex)) + and "Sampler" in class_type + ): + tips = '\n\nTIPS: If you have any "Load CLIP" or "*CLIP Loader" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected.' error_details = { "node_id": real_node_id, "exception_message": "{}\n{}".format(ex, tips), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted + "current_inputs": input_data_formatted, } return (ExecutionResult.FAILURE, error_details, ex) @@ -711,8 +917,14 @@ def _get_feedback_default(dynprompt, from_node_id, from_socket): return 0 -def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, - cfg_injections, sampler_injections): +def _build_feedback_fns( + dynamic_prompt, + from_node_id, + from_socket, + to_node_id, + cfg_injections, + sampler_injections, +): """Try to build per-step update functions from a feedback edge. Walks forward from the feedback-receiving node through intermediate @@ -738,24 +950,35 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, return from simpleeval import simple_eval + from comfy_extras.nodes_math import MATH_FUNCTIONS # ---- helpers ---- - def _find_consumers(source_id): + def _find_consumers(source_id, source_socket=0): + """Return all nodes consuming output *source_socket* of *source_id*. + + Each result is ``(consumer_id, class_type, input_name)``. + """ consumers = [] for nid, n in prompt.items(): for iname, ival in n.get("inputs", {}).items(): - if isinstance(ival, list) and len(ival) == 2 \ - and ival[0] == source_id and ival[1] == 0: + if ( + isinstance(ival, list) + and len(ival) == 2 + and ival[0] == source_id + and ival[1] == source_socket + ): consumers.append((nid, n.get("class_type"), iname)) return consumers def _is_sampler_target(class_type): - # Sampler-producing nodes whose parameters can be updated per-step - # via KSAMPLER.extra_options. - return (class_type is not None - and "Sampler" in class_type - and class_type != "SamplerCustomAdvanced") + """Return True if *class_type* is a sampler that accepts per-step + parameter updates via :class:`KSAMPLER.extra_options`.""" + return ( + class_type is not None + and "Sampler" in class_type + and class_type != "SamplerCustomAdvanced" + ) def _resolve_input_value(source_node_id, source_socket): """Try to resolve a non-feedback linked input to a static value. @@ -794,14 +1017,19 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"): if source_socket == 0 and len(wv) > 0: return wv[0] - if cls is not None and source_socket < len(req_names) and source_socket < len(wv): + if ( + cls is not None + and source_socket < len(req_names) + and source_socket < len(wv) + ): return wv[source_socket] return None except Exception: return None - def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket, - feedback_var_name): + def _collect_extra_names( + node_id, feedback_from_node, feedback_from_socket, feedback_var_name + ): """Collect non-feedback linked inputs from a MathExpression node and resolve them to values. Returns dict of name→value.""" extra = {} @@ -814,8 +1042,9 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, continue src_id, src_socket = inp_val[0], inp_val[1] # Skip the feedback-linked input — that's the iteration variable - if (src_id == str(feedback_from_node) - and int(src_socket) == int(feedback_from_socket)): + if src_id == str(feedback_from_node) and int(src_socket) == int( + feedback_from_socket + ): continue # This is an additional linked input — try to resolve it val = _resolve_input_value(src_id, src_socket) @@ -846,16 +1075,19 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, var_name = None for input_name, input_val in node.get("inputs", {}).items(): - if isinstance(input_val, list) and len(input_val) == 2 \ - and input_val[0] == from_node and input_val[1] == from_socket: + if ( + isinstance(input_val, list) + and len(input_val) == 2 + and input_val[0] == from_node + and input_val[1] == from_socket + ): var_name = input_name.rsplit(".", 1)[-1] break if var_name is None: return [] # Collect additional (non-feedback) input values for this node - extra_names = _collect_extra_names(start_id, from_node, from_socket, - var_name) + extra_names = _collect_extra_names(start_id, from_node, from_socket, var_name) new_chain = chain + [(expression, var_name, extra_names)] results = [] @@ -870,8 +1102,9 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, return results # ---- compose functions from discovered chains ---- - for target_type, target_id, param_name, chain in \ - _dfs(to_node_id, from_node_id, from_socket, []): + for target_type, target_id, param_name, chain in _dfs( + to_node_id, from_node_id, from_socket, [] + ): if not chain: continue @@ -881,8 +1114,11 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, for expr_str, var, extra_names in _chain: ctx = dict(extra_names) if extra_names else {} ctx[var] = val - val = float(simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS)) + val = float( + simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS) + ) return val + return _fn if target_type == "guider": @@ -913,7 +1149,9 @@ class PromptExecutor: if self.server.client_id is not None or broadcast: self.server.send_sync(event, data, self.server.client_id) - def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + def handle_execution_error( + self, prompt_id, prompt, current_outputs, executed, error, ex + ): node_id = error["node_id"] class_type = prompt[node_id]["class_type"] @@ -952,7 +1190,9 @@ class PromptExecutor: elif event == "end": provider.on_prompt_end(prompt_id) except Exception as e: - _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + _cache_logger.warning( + f"Cache provider {provider.__class__.__name__} error on {event}: {e}" + ) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -968,22 +1208,32 @@ class PromptExecutor: self.server.client_id = None self.status_messages = [] - self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False) self._notify_prompt_lifecycle("start", prompt_id) - ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) - ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) - ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None - comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) + ram_headroom = int(self.cache_args["ram"] * (1024**3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024**3)) + ram_release_callback = ( + self.caches.outputs.ram_release + if self.cache_type == CacheType.RAM_PRESSURE + else None + ) + comfy.memory_management.set_ram_cache_release_state( + ram_release_callback, ram_headroom + ) try: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) reset_progress_state(prompt_id, dynamic_prompt) add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + is_changed_cache = IsChangedCache( + prompt_id, dynamic_prompt, self.caches.outputs + ) for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + await cache.set_prompt( + dynamic_prompt, prompt.keys(), is_changed_cache + ) cache.clean_unused() node_ids = list(prompt.keys()) @@ -991,16 +1241,19 @@ class PromptExecutor: *(self.caches.outputs.get(node_id) for node_id in node_ids) ) cached_nodes = [ - node_id for node_id, result in zip(node_ids, cache_results) + node_id + for node_id, result in zip(node_ids, cache_results) if result is not None ] comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) + self.add_message( + "execution_cached", + {"nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False, + ) pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) @@ -1014,35 +1267,66 @@ class PromptExecutor: # These are injected into the guider / sampler after the # target node executes so the sampler can vary parameters # (cfg, s_noise, ...) with step_index. - _feedback_cfg_injections = {} # guider_node_id → cfg_fn - _feedback_sampler_injections = {} # sampler_node_id → {param: fn} + _feedback_cfg_injections = {} # guider_node_id → cfg_fn + _feedback_sampler_injections = {} # sampler_node_id → {param: fn} for to_node_id, edges in execution_list.feedback_links.items(): for from_node_id, from_socket in edges: try: _build_feedback_fns( - dynamic_prompt, from_node_id, from_socket, - to_node_id, _feedback_cfg_injections, + dynamic_prompt, + from_node_id, + from_socket, + to_node_id, + _feedback_cfg_injections, _feedback_sampler_injections, ) except Exception: - pass # non-critical – feedback just wonʼt vary per step + pass # non-critical – feedback just wonʼt vary per step # ----------------------------------------------------------------- while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + self.handle_execution_error( + prompt_id, + dynamic_prompt.original_prompt, + current_outputs, + executed, + error, + ex, + ) break - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + assert node_id is not None, ( + "Node ID should not be None at this point" + ) + result, error, ex = await execute( + self.server, + dynamic_prompt, + self.caches, + node_id, + extra_data, + executed, + prompt_id, + execution_list, + pending_subgraph_results, + pending_async_nodes, + ui_node_outputs, + ) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + self.handle_execution_error( + prompt_id, + dynamic_prompt.original_prompt, + current_outputs, + executed, + error, + ex, + ) break elif result == ExecutionResult.PENDING: execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: + else: # result == ExecutionResult.SUCCESS: # ---- bounded-feedback injection ---- # If this node just produced a guider or sampler # that is part of a feedback cycle, inject per-step @@ -1050,19 +1334,31 @@ class PromptExecutor: if node_id in _feedback_cfg_injections: try: output = self.caches.outputs.get_local(node_id) - if output is not None and output.outputs is not None \ - and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + if ( + output is not None + and output.outputs is not None + and len(output.outputs) > 0 + and len(output.outputs[0]) > 0 + ): guider = output.outputs[0][0] - guider._feedback_cfg_fn = _feedback_cfg_injections[node_id] + guider._feedback_cfg_fn = _feedback_cfg_injections[ + node_id + ] except Exception: pass if node_id in _feedback_sampler_injections: try: output = self.caches.outputs.get_local(node_id) - if output is not None and output.outputs is not None \ - and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + if ( + output is not None + and output.outputs is not None + and len(output.outputs) > 0 + and len(output.outputs[0]) > 0 + ): sampler_obj = output.outputs[0][0] - sampler_obj._feedback_param_fns = _feedback_sampler_injections[node_id] + sampler_obj._feedback_param_fns = ( + _feedback_sampler_injections[node_id] + ) except Exception: pass # --------------------------------------- @@ -1071,9 +1367,11 @@ class PromptExecutor: if self.cache_type == CacheType.RAM_PRESSURE: ram_release_callback(ram_inactive_headroom) ram_shortfall = ram_headroom - psutil.virtual_memory().available - freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + freed = comfy.model_management.free_pins( + ram_shortfall + 512 * (1024**2) + ) if freed < ram_shortfall: - if freed > 64 * (1024 ** 2): + if freed > 64 * (1024**2): # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. time.sleep(0.05) ram_release_callback(ram_headroom, free_active=True) @@ -1087,9 +1385,20 @@ class PromptExecutor: continue cached = await self.caches.outputs.get(node_id) if cached is not None: - display_node_id = dynamic_prompt.get_display_node_id(node_id) - _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs) - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + display_node_id = dynamic_prompt.get_display_node_id( + node_id + ) + _send_cached_ui( + self.server, + node_id, + display_node_id, + cached, + prompt_id, + ui_node_outputs, + ) + self.add_message( + "execution_success", {"prompt_id": prompt_id}, broadcast=False + ) ui_outputs = {} meta_outputs = {} @@ -1120,17 +1429,17 @@ def _is_bounded_feedback_cycle(prompt, visiting, unique_id): infinite cycle — the DAG can safely allow it and the execution engine will break the feedback edge by seeding the iteration output with an initial value. """ - cycle_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = visiting[visiting.index(unique_id) :] + [unique_id] for node_id in cycle_nodes: if node_id not in prompt: continue - class_type = prompt[node_id].get('class_type') + class_type = prompt[node_id].get("class_type") if class_type is None: continue obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type) if obj_class is None: continue - bounded = getattr(obj_class, 'BOUNDED_FEEDBACK', None) + bounded = getattr(obj_class, "BOUNDED_FEEDBACK", None) if bounded: return True return False @@ -1145,7 +1454,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): return validated[unique_id] if unique_id in visiting: - cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_path_nodes = visiting[visiting.index(unique_id) :] + [unique_id] cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) # A bounded feedback cycle is one where at least one node in the cycle @@ -1160,21 +1469,30 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validated[unique_id] = (True, [], unique_id) return validated[unique_id] - cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + cycle_path = " -> ".join( + f"{node_id} ({prompt[node_id]['class_type']})" + for node_id in cycle_path_nodes + ) for node_id in cycle_nodes: - validated[node_id] = (False, [{ - "type": "dependency_cycle", - "message": "Dependency cycle detected", - "details": cycle_path, - "extra_info": { - "node_id": node_id, - "cycle_nodes": cycle_nodes, - } - }], node_id) + validated[node_id] = ( + False, + [ + { + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + }, + } + ], + node_id, + ) return validated[unique_id] - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + inputs = prompt[unique_id]["inputs"] + class_type = prompt[unique_id]["class_type"] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] errors = [] @@ -1199,10 +1517,14 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validate_has_kwargs = argspec.varkw is not None received_types = {} - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + valid_inputs = set(class_inputs.get("required", {})).union( + set(class_inputs.get("optional", {})) + ) for x in valid_inputs: - input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info( + obj_class, x, class_inputs + ) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -1211,9 +1533,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "type": "required_input_missing", "message": "Required input is missing", "details": details, - "extra_info": { - "input_name": x - } + "extra_info": {"input_name": x}, } errors.append(error) continue @@ -1229,18 +1549,21 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "extra_info": { "input_name": x, "input_config": info, - "received_value": val - } + "received_value": val, + }, } errors.append(error) continue o_id = val[0] - o_class_type = prompt[o_id]['class_type'] + o_class_type = prompt[o_id]["class_type"] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + if ( + "input_types" not in validate_function_inputs + and not validate_node_input(received_type, input_type) + ): details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", @@ -1250,15 +1573,17 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": info, "received_type": received_type, - "linked_node": val - } + "linked_node": val, + }, } errors.append(error) continue try: visiting.append(unique_id) try: - r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + r = await validate_inputs( + prompt_id, prompt, o_id, validated, visiting + ) finally: visiting.pop() if r[0] is False: @@ -1269,19 +1594,21 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_inner_validation", - "message": "Exception when validating inner node", - "details": str(ex), - "extra_info": { - "input_name": x, - "input_config": info, - "exception_message": str(ex), - "exception_type": exception_type, - "traceback": traceback.format_tb(tb), - "linked_node": val + reasons = [ + { + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val, + }, } - }] + ] validated[o_id] = (False, reasons, o_id) continue else: @@ -1316,8 +1643,8 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": info, "received_value": val, - "exception_message": str(ex) - } + "exception_message": str(ex), + }, } errors.append(error) continue @@ -1326,26 +1653,30 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if "min" in extra_info and val < extra_info["min"]: error = { "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "message": "Value {} smaller than min of {}".format( + val, extra_info["min"] + ), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, - } + }, } errors.append(error) continue if "max" in extra_info and val > extra_info["max"]: error = { "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "message": "Value {} bigger than max of {}".format( + val, extra_info["max"] + ), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, - } + }, } errors.append(error) continue @@ -1380,7 +1711,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": input_config, "received_value": val, - } + }, } errors.append(error) continue @@ -1391,10 +1722,17 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: input_filtered[x] = input_data_all[x] - if 'input_types' in validate_function_inputs: - input_filtered['input_types'] = [received_types] + if "input_types" in validate_function_inputs: + input_filtered["input_types"] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) + ret = await _async_map_node_over_list( + prompt_id, + unique_id, + obj_class, + input_filtered, + validate_function_name, + v3_data=v3_data, + ) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): @@ -1409,7 +1747,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "details": details, "extra_info": { "input_name": x, - } + }, } errors.append(error) continue @@ -1425,18 +1763,22 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validated[unique_id] = ret return ret + def full_type_name(klass): module = klass.__module__ - if module == 'builtins': + if module == "builtins": return klass.__qualname__ - return module + '.' + klass.__qualname__ + return module + "." + klass.__qualname__ -async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): + +async def validate_prompt( + prompt_id, prompt, partial_execution_list: Union[list[str], None] +): outputs = set() for x in prompt: - if 'class_type' not in prompt[x]: + if "class_type" not in prompt[x]: node_data = prompt[x] - node_title = node_data.get('_meta', {}).get('title') + node_title = node_data.get("_meta", {}).get("title") error = { "type": "missing_node_type", "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.", @@ -1444,16 +1786,16 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "extra_info": { "node_id": x, "class_type": None, - "node_title": node_title - } + "node_title": node_title, + }, } return (False, error, [], {}) - class_type = prompt[x]['class_type'] + class_type = prompt[x]["class_type"] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: node_data = prompt[x] - node_title = node_data.get('_meta', {}).get('title', class_type) + node_title = node_data.get("_meta", {}).get("title", class_type) error = { "type": "missing_node_type", "message": f"Node '{node_title}' not found. The custom node may not be installed.", @@ -1461,12 +1803,12 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "extra_info": { "node_id": x, "class_type": class_type, - "node_title": node_title - } + "node_title": node_title, + }, } return (False, error, [], {}) - if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: + if hasattr(class_, "OUTPUT_NODE") and class_.OUTPUT_NODE is True: if partial_execution_list is None or x in partial_execution_list: outputs.add(x) @@ -1475,7 +1817,7 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "type": "prompt_no_outputs", "message": "Prompt has no outputs", "details": "", - "extra_info": {} + "extra_info": {}, } return (False, error, [], {}) @@ -1494,15 +1836,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", - "details": str(ex), - "extra_info": { - "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + reasons = [ + { + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + }, } - }] + ] validated[o] = (False, reasons, o) if valid is True: @@ -1522,15 +1866,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ # So don't return those nodes as having errors in the response. if valid is not True and len(reasons) > 0: if node_id not in node_errors: - class_type = prompt[node_id]['class_type'] + class_type = prompt[node_id]["class_type"] node_errors[node_id] = { "errors": reasons, "dependent_outputs": [], - "class_type": class_type + "class_type": class_type, } logging.error(f"* {class_type} {node_id}:") for reason in reasons: - logging.error(f" - {reason['message']}: {reason['details']}") + logging.error( + f" - {reason['message']}: {reason['details']}" + ) node_errors[node_id]["dependent_outputs"].append(o) logging.error("Output will be ignored") @@ -1545,15 +1891,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "type": "prompt_outputs_failed_validation", "message": "Prompt outputs failed validation", "details": errors_list, - "extra_info": {} + "extra_info": {}, } return (False, error, list(good_outputs), node_errors) return (True, None, list(good_outputs), node_errors) + MAXIMUM_HISTORY_SIZE = 10000 + class PromptQueue: def __init__(self, server): self.server = server @@ -1585,12 +1933,17 @@ class PromptQueue: return (item, i) class ExecutionStatus(NamedTuple): - status_str: Literal['success', 'error'] + status_str: Literal["success", "error"] completed: bool messages: List[str] - def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus'], process_item=None): + def task_done( + self, + item_id, + history_result, + status: Optional["PromptQueue.ExecutionStatus"], + process_item=None, + ): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -1606,7 +1959,7 @@ class PromptQueue: self.history[prompt[1]] = { "prompt": prompt, "outputs": {}, - 'status': status_dict, + "status": status_dict, } self.history[prompt[1]].update(history_result) self.server.queue_updated()