merging upstream with hiddenswitch

- deprecate our preview image fork of the frontend because upstream now has the needed functionality
- merge the context executor from upstream into ours
doctorpangloss 2025-07-14 16:55:47 -07:00
parent bd6f28e3bd
commit 499f9be5fa
15 changed files with 242 additions and 89 deletions


@@ -146,6 +146,7 @@ class Configuration(dict):
         front_end_root (Optional[str]): The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.
         comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
         database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
+        whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
     """

     def __init__(self, **kwargs):
@@ -212,6 +213,7 @@ class Configuration(dict):
         self.windows_standalone_build: bool = False
         self.disable_metadata: bool = False
         self.disable_all_custom_nodes: bool = False
+        self.whitelist_custom_nodes: list[str] = []
         self.multi_user: bool = False
         self.plausible_analytics_base_url: Optional[str] = None
         self.plausible_analytics_domain: Optional[str] = None
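
Note: a minimal sketch of how the new option pairs with --disable-all-custom-nodes. Only Configuration and its two fields come from this diff; the loader predicate is a hypothetical illustration:

    from comfy.cli_args_types import Configuration

    config = Configuration()
    config.disable_all_custom_nodes = True
    config.whitelist_custom_nodes = ["ComfyUI-Manager"]

    def should_load(folder_name: str) -> bool:
        # assumption: the whitelist is consulted only when loading is globally disabled
        if not config.disable_all_custom_nodes:
            return True
        return folder_name in config.whitelist_custom_nodes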


@@ -19,12 +19,16 @@ from typing import List, Optional, Tuple, Literal
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status

+from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
+    DependencyAwareCache, \
+    BasicCache
+# order matters
+from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+from comfy_execution.graph_utils import is_link, GraphBuilder
 from comfy_execution.utils import CurrentNodeContext
 from .main_pre import tracer
 from .. import interruption
 from .. import model_management
-from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, DependencyAwareCache, \
-    BasicCache
 from ..cli_args import args
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
@@ -36,12 +40,10 @@ from ..component_model.module_property import create_module_properties
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import context_execute_node, context_execute_prompt
 from ..execution_ext import should_panic_on_exception
-# order matters
-from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
-from ..graph_utils import is_link, GraphBuilder
 from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
 from ..nodes_context import get_nodes
-from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
+from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
+    ProgressRegistry
 from ..validation import validate_node_input

 _module_properties = create_module_properties()
@@ -135,6 +137,7 @@ class CacheSet:

 SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")

+
 def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
     if extra_data is None:
         extra_data = {}
@@ -187,9 +190,11 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
             input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
     return input_data_all, missing_keys

+
 def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     raise ValueError("")

+
 async def resolve_map_node_over_list_results(results):
     remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
     if len(remaining) == 0:
@@ -245,6 +250,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

    results = []
+
    async def process_inputs(inputs, index=None, input_is_list=False):
        if allow_interrupt:
            interruption.throw_exception_if_processing_interrupted()
@@ -264,8 +270,10 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
                f = getattr(obj, func)
                if inspect.iscoroutinefunction(f):
                    async def async_wrapper(f, prompt_id, unique_id, list_index, args):
+                        # todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility
                        with CurrentNodeContext(prompt_id, unique_id, list_index):
                            return await f(**args)
+
                    task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
                    # Give the task a chance to execute without yielding
                    await asyncio.sleep(0)
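
Note: the wrapper exists because an asyncio task must carry its (prompt_id, node_id, list_index) context with it after the executor's call stack has moved on. A self-contained sketch of the same pattern, assuming only CurrentNodeContext from comfy_execution.utils:

    import asyncio
    from comfy_execution.utils import CurrentNodeContext

    async def run_node_method(f, prompt_id, unique_id, index, inputs):
        # wrap the coroutine so the node context is active for its whole lifetime
        async def wrapper():
            with CurrentNodeContext(prompt_id, unique_id, index):
                return await f(**inputs)
        task = asyncio.create_task(wrapper())
        await asyncio.sleep(0)  # let the task start before returning it
        return task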
@@ -322,6 +330,7 @@ async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_b
     output, ui, has_subgraph = get_output_from_returns(return_values, obj)
     return output, ui, has_subgraph, False

+
 def get_output_from_returns(return_values, obj):
     results = []
     uis = []
@@ -464,9 +473,9 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
            if hasattr(obj, "check_lazy_status"):
                required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
                required_inputs = await resolve_map_node_over_list_results(required_inputs)
-                required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
-                required_inputs = [x for x in required_inputs if isinstance(x,str) and (
+                required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
+                required_inputs = [x for x in required_inputs if isinstance(x, str) and (
                    x not in input_data_all or x in missing_keys
                )]
                if len(required_inputs) > 0:
                    for i in required_inputs:
@@ -500,10 +509,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
        if has_pending_tasks:
            pending_async_nodes[unique_id] = output_data
            unblock = execution_list.add_external_block(unique_id)
+
            async def await_completion():
                tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
                await asyncio.gather(*tasks, return_exceptions=True)
                unblock()
+
            asyncio.create_task(await_completion())
            return (ExecutionResult.PENDING, None, None)
        if len(output_ui) > 0:
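
Note: the PENDING branch parks a node whose tasks are still running: the node is blocked in the execution list and a watcher task lifts the block once every task settles. A sketch of the pattern with plain asyncio; add_external_block returning an unblock callable is taken from the diff, the rest is illustrative:

    import asyncio

    def park_until_done(execution_list, unique_id, output_data):
        unblock = execution_list.add_external_block(unique_id)

        async def await_completion():
            tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
            # return_exceptions=True: a failed task must still lift the block
            await asyncio.gather(*tasks, return_exceptions=True)
            unblock()

        asyncio.create_task(await_completion())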
@@ -683,11 +694,12 @@ class PromptExecutor:
        # torchao and potentially other optimization approaches break when the models are created in inference mode
        # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
        inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
-        with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
-            await self._execute_async(prompt, prompt_id, extra_data, execute_outputs)
+        dynamic_prompt = DynamicPrompt(prompt)
+        reset_progress_state(prompt_id, dynamic_prompt)
+        with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode):
+            await self._execute_async(dynamic_prompt, prompt_id, extra_data, execute_outputs)

-    async def _execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
+    async def _execute_async(self, prompt: DynamicPrompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
        if execute_outputs is None:
            execute_outputs = []
        if extra_data is None:
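
Note: the progress registry is now created per prompt and scoped through the execution context instead of living in a module-level global. A sketch of the resulting call shape, with module paths assumed from this commit's imports:

    from comfy.execution_context import context_execute_prompt
    from comfy.progress import ProgressRegistry, reset_progress_state
    from comfy_execution.graph import DynamicPrompt

    dynamic_prompt = DynamicPrompt(prompt_dict)  # prompt_dict: the raw prompt dict
    reset_progress_state(prompt_id, dynamic_prompt)
    registry = ProgressRegistry(prompt_id, dynamic_prompt)
    with context_execute_prompt(server, prompt_id, progress_registry=registry, inference_mode=True):
        ...  # run the prompt; everything inside sees this registry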
@@ -704,8 +716,8 @@ class PromptExecutor:
        self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode() if inference_mode else nullcontext():
-            dynamic_prompt = DynamicPrompt(prompt)
-            reset_progress_state(prompt_id, dynamic_prompt)
+            dynamic_prompt = prompt
+            prompt: dict = prompt.original_prompt
            add_progress_handler(WebUIProgressHandler(self.server))
            is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
            for cache in self.caches.all:
@@ -722,7 +734,7 @@ class PromptExecutor:
                                 {"nodes": cached_nodes, "prompt_id": prompt_id},
                                 broadcast=False)
            pending_subgraph_results = {}
-            pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+            pending_async_nodes = {}  # TODO - Unify this with pending_subgraph_results
            executed = set()
            execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
            current_outputs = self.caches.outputs.all_node_ids()
@@ -988,7 +1000,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
        if 'input_types' in validate_function_inputs:
            input_filtered['input_types'] = [received_types]

-        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
+        # ret = obj_class.VALIDATE_INPUTS(**input_filtered)
        ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
        ret = await resolve_map_node_over_list_results(ret)
        for x in input_filtered:


@@ -1,7 +1,4 @@
-class BinaryEventTypes:
-    PREVIEW_IMAGE = 1
-    UNENCODED_PREVIEW_IMAGE = 2
-    TEXT = 3
-    PREVIEW_IMAGE_WITH_METADATA = 4
+from ..component_model import queue_types
+
+# todo: should protocol really be all of queue_types?
+BinaryEventTypes = queue_types.BinaryEventTypes
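
Note: after this change the legacy class becomes an alias of the Enum in queue_types, so the old and new import paths resolve to the same object. A quick equivalence sketch, with package paths inferred from the imports elsewhere in this commit:

    from comfy.cmd import protocol
    from comfy.component_model import queue_types

    assert protocol.BinaryEventTypes is queue_types.BinaryEventTypes
    assert protocol.BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA.value == 4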


@@ -12,6 +12,7 @@ import socket
 import struct
 import sys
 import traceback
+import typing
 import urllib
 import uuid
 from asyncio import Future, AbstractEventLoop, Task
@@ -45,7 +46,8 @@ from ..cmd import execution
 from ..cmd import folder_paths
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
 from ..component_model.encode_text_for_progress import encode_text_for_progress
-from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
+from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo, \
+    UnencodedPreviewImageMessage
 from ..component_model.file_output_path import file_output_path
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
     ExecutionStatus
@@ -53,6 +55,7 @@ from ..digest import digest
 from ..images import open_image
 from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
 from ..nodes.package_typing import ExportedNodes
+from ..progress_types import PreviewImageMetadata

 logger = logging.getLogger(__name__)
@@ -1049,7 +1052,7 @@ class PromptServer(ExecutorToClientProgress):
            prompt_info['exec_info'] = exec_info
            return prompt_info

-    async def send(self, event, data, sid=None):
+    async def send(self, event, data: UnencodedPreviewImageMessage | tuple[UnencodedPreviewImageMessage, PreviewImageMetadata] | bytes | bytearray | dict, sid=None):
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
        elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
@@ -1061,7 +1064,7 @@ class PromptServer(ExecutorToClientProgress):
        else:
            await self.send_json(event, data, sid)

-    def encode_bytes(self, event: int | Enum | str, data):
+    def encode_bytes(self, event: int | Enum | str, data: bytes | bytearray | typing.Sequence[int]):
        # todo: investigate what is propagating these spurious, string-repr'd previews
        if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
            event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
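
Note: encode_bytes appears to frame a binary websocket message as a 4-byte big-endian event id followed by the raw payload (an assumption based on the bytearray framing in the next hunk). A hypothetical client-side decoder under that assumption:

    import struct

    def decode_binary_message(message: bytes) -> tuple[int, bytes]:
        (event,) = struct.unpack_from(">I", message, 0)  # 4-byte big-endian event id
        return event, message[4:]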
@@ -1077,14 +1080,14 @@ class PromptServer(ExecutorToClientProgress):
        message.extend(data)
        return message

-    async def send_image(self, image_data, sid=None):
+    async def send_image(self, image_data: UnencodedPreviewImageMessage, sid=None):
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        preview_bytes = encode_preview_image(image, image_type, max_size)
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

-    async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
+    async def send_image_with_metadata(self, image_data: UnencodedPreviewImageMessage, metadata: Optional[PreviewImageMetadata] = None, sid=None):
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
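
Note: per the positional accessors above, UnencodedPreviewImageMessage is an (image_type, image, max_size) triple. A sketch of producing one; the JPEG/512 values are illustrative:

    from PIL import Image

    image = Image.new("RGB", (512, 512))
    preview = ("JPEG", image, 512)  # (image_type, PIL image, max_size)
    # await server.send(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview, sid=sid)
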
@@ -1104,7 +1107,6 @@ class PromptServer(ExecutorToClientProgress):
            metadata["image_type"] = mimetype

        # Serialize metadata as JSON
-        import json
        metadata_json = json.dumps(metadata).encode('utf-8')
        metadata_length = len(metadata_json)
@@ -1131,7 +1133,7 @@ class PromptServer(ExecutorToClientProgress):
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

-    async def send_json(self, event, data, sid=None):
+    async def send_json(self, event, data: dict, sid=None):
        message = {"type": event, "data": data}
        if sid is None:


@@ -12,6 +12,7 @@ from .outputs_types import OutputsDict
 from .queue_types import BinaryEventTypes
 from ..cli_args_types import Configuration
 from ..nodes.package_typing import InputTypeSpec
+from ..progress_types import PreviewImageMetadata


 class ExecInfo(TypedDict):
@@ -82,7 +83,7 @@ ExecutedMessage = ExecutingMessage
 SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]

-SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]


 class ExecutorToClientProgress(Protocol):


@@ -4,6 +4,7 @@ import asyncio
 from enum import Enum
 from typing import NamedTuple, Optional, List, Literal, Sequence
 from typing import Tuple
+
 from typing_extensions import NotRequired, TypedDict

 from .outputs_types import OutputsDict
@@ -142,6 +143,7 @@ class BinaryEventTypes(Enum):
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
     TEXT = 3
+    PREVIEW_IMAGE_WITH_METADATA = 4


 class ExecutorToClientMessage(TypedDict, total=False):


@@ -10,6 +10,7 @@ from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.folder_path_types import FolderNames
 from .distributed.server_stub import ServerStub
 from .nodes.package_typing import ExportedNodes, exported_nodes_view
+from .progress_types import AbstractProgressRegistry, ProgressRegistryStub

 comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
 # enables context var propagation across process boundaries for process pool executors
@@ -23,10 +24,21 @@ class ExecutionContext:
     custom_nodes: ExportedNodes
     node_id: Optional[str] = None
     task_id: Optional[str] = None
+    list_index: Optional[int] = None
     inference_mode: bool = True
+    progress_registry: Optional[AbstractProgressRegistry] = None
+
+    def __iter__(self):
+        """
+        Provides tuple-like unpacking behavior, similar to a NamedTuple.
+        Yields task_id, node_id, and list_index.
+        """
+        yield self.task_id
+        yield self.node_id
+        yield self.list_index


-comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes()))
+comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))


 def current_execution_context() -> ExecutionContext:
@@ -51,9 +63,9 @@ def context_folder_names_and_paths(folder_names_and_paths: FolderNames):

 @contextmanager
-def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
+def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True):
     current_ctx = current_execution_context()
-    new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
+    new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry)
     with _new_execution_context(new_ctx):
         yield new_ctx
@ -84,4 +96,18 @@ def context_add_custom_nodes(exported_nodes: ExportedNodes):
new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes) new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes)
with _new_execution_context(new_ctx): with _new_execution_context(new_ctx):
yield new_ctx yield new_ctx
@contextmanager
def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Optional[int] = None):
"""
A context manager to set the prompt_id (task_id), node_id, and optional list_index for the current execution.
This is useful for fine-grained context setting within a node's execution, especially for batch processing.
Replaces the @guill code upstream
"""
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index)
with _new_execution_context(new_ctx):
yield new_ctx
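
Note: because ExecutionContext.__iter__ yields task_id, node_id, and list_index, code written against the old NamedTuple keeps unpacking. A usage sketch with illustrative ids:

    from comfy.execution_context import context_set_node_and_prompt, current_execution_context

    with context_set_node_and_prompt("prompt-1", "node-7", list_index=0):
        prompt_id, node_id, list_index = current_execution_context()
        assert (prompt_id, node_id, list_index) == ("prompt-1", "node-7", 0)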


@@ -194,17 +194,15 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     # this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
     # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
     # the way community custom nodes is pretty radioactive
-    from ..cmd import cuda_malloc, folder_paths, latent_preview
-    from .. import graph, graph_utils, caching
+    from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
     from .. import node_helpers
     from .. import __version__
-    for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
+    for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
     sys.modules['nodes'] = base_nodes
-    sys.modules['comfy_execution.graph'] = graph
-    sys.modules['comfy_execution.graph_utils'] = graph_utils
-    sys.modules['comfy_execution.caching'] = caching
+    # apparently this is also something that happens
+    sys.modules['comfy.nodes'] = base_nodes
     comfyui_version = types.ModuleType('comfyui_version', '')
     setattr(comfyui_version, "__version__", __version__)
     sys.modules['comfyui_version'] = comfyui_version
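
Note: the mitigation works by publishing modules under the bare top-level names that vanilla custom nodes import. A minimal sketch of the aliasing trick itself (the helper name is hypothetical):

    import sys
    import types

    def alias_module(module: types.ModuleType, *names: str) -> None:
        # afterwards `import <name>` returns this module without touching disk
        for name in names:
            sys.modules[name] = module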


@@ -7,11 +7,17 @@ from PIL import Image
 from tqdm import tqdm
 from typing_extensions import override

+from .component_model.module_property import create_module_properties
+from .execution_context import current_execution_context
+from .progress_types import AbstractProgressRegistry
+
 if TYPE_CHECKING:
-    from .graph import DynamicPrompt
+    from comfy_execution.graph import DynamicPrompt

-from protocol import BinaryEventTypes
+from .cmd.protocol import BinaryEventTypes
 from comfy_api import feature_flags

+_module_properties = create_module_properties()
+

 class NodeState(Enum):
     Pending = "pending"
@@ -234,7 +240,7 @@ class WebUIProgressHandler(ProgressHandler):
            self._send_progress_state(prompt_id, self.registry.nodes)


-class ProgressRegistry:
+class ProgressRegistry(AbstractProgressRegistry):
    """
    Registry that maintains node progress state and notifies registered handlers.
    """
@@ -320,18 +326,25 @@ class ProgressRegistry:
 # Global registry instance
-global_progress_registry: ProgressRegistry = None
+@_module_properties.getter
+def _global_progress_registry() -> ProgressRegistry:
+    return current_execution_context().progress_registry


 def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
-    global global_progress_registry
+    """
+    the caller must create a new progress registry

+    :param prompt_id:
+    :param dynprompt:
+    :return: None
+    """
+    global_progress_registry = _global_progress_registry()
     # Reset existing handlers if registry exists
     if global_progress_registry is not None:
         global_progress_registry.reset_handlers()

-    # Create new registry
-    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
+    # XXX caller now creates new progress registry


 def add_progress_handler(handler: ProgressHandler) -> None:
@@ -341,11 +354,4 @@ def add_progress_handler(handler: ProgressHandler) -> None:

 def get_progress_state() -> ProgressRegistry:
-    global global_progress_registry
-    if global_progress_registry is None:
-        from .graph import DynamicPrompt
-        global_progress_registry = ProgressRegistry(
-            prompt_id="", dynprompt=DynamicPrompt({})
-        )
-    return global_progress_registry
+    return _global_progress_registry()
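
Note: with the module property in place, the registry is resolved from whatever execution context is ambient, so concurrent prompts no longer share one global. A sketch of the resulting invariant:

    from comfy.execution_context import current_execution_context
    from comfy.progress import get_progress_state

    assert get_progress_state() is current_execution_context().progress_registry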

comfy/progress_types.py (new file, 103 lines)

@@ -0,0 +1,103 @@
+from abc import ABCMeta, abstractmethod
+
+from typing_extensions import TypedDict, NotRequired
+
+
+class PreviewImageMetadata(TypedDict, total=True):
+    """
+    Metadata associated with a preview image sent to the UI.
+    """
+    node_id: str
+    prompt_id: str
+    display_node_id: str
+    parent_node_id: str
+    real_node_id: str
+    image_type: NotRequired[str]
+
+
+class AbstractProgressRegistry(metaclass=ABCMeta):
+    @abstractmethod
+    def register_handler(self, handler):
+        """Register a progress handler"""
+        pass
+
+    @abstractmethod
+    def unregister_handler(self, handler_name):
+        """Unregister a progress handler"""
+        pass
+
+    @abstractmethod
+    def enable_handler(self, handler_name):
+        """Enable a progress handler"""
+        pass
+
+    @abstractmethod
+    def disable_handler(self, handler_name):
+        """Disable a progress handler"""
+        pass
+
+    @abstractmethod
+    def ensure_entry(self, node_id):
+        """Ensure a node entry exists"""
+        pass
+
+    @abstractmethod
+    def start_progress(self, node_id):
+        """Start progress tracking for a node"""
+        pass
+
+    @abstractmethod
+    def update_progress(self, node_id, value, max_value, image):
+        """Update progress for a node"""
+        pass
+
+    @abstractmethod
+    def finish_progress(self, node_id):
+        """Finish progress tracking for a node"""
+        pass
+
+    @abstractmethod
+    def reset_handlers(self):
+        """Reset all handlers"""
+        pass
+
+
+class ProgressRegistryStub(AbstractProgressRegistry):
+    """A stub implementation of AbstractProgressRegistry that performs no operations."""
+
+    def register_handler(self, handler):
+        """Register a progress handler"""
+        pass
+
+    def unregister_handler(self, handler_name):
+        """Unregister a progress handler"""
+        pass
+
+    def enable_handler(self, handler_name):
+        """Enable a progress handler"""
+        pass
+
+    def disable_handler(self, handler_name):
+        """Disable a progress handler"""
+        pass
+
+    def ensure_entry(self, node_id):
+        """Ensure a node entry exists"""
+        pass
+
+    def start_progress(self, node_id):
+        """Start progress tracking for a node"""
+        pass
+
+    def update_progress(self, node_id, value, max_value, image):
+        """Update progress for a node"""
+        pass
+
+    def finish_progress(self, node_id):
+        """Finish progress tracking for a node"""
+        pass
+
+    def reset_handlers(self):
+        """Reset all handlers"""
+        pass
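
Note: the ABC exists so the server can boot with a no-op registry (the stub above) and so alternative frontends can supply their own. A hypothetical implementation that only logs transitions, using just the abstract surface from this file:

    import logging

    from comfy.progress_types import AbstractProgressRegistry

    class LoggingProgressRegistry(AbstractProgressRegistry):
        def register_handler(self, handler): pass
        def unregister_handler(self, handler_name): pass
        def enable_handler(self, handler_name): pass
        def disable_handler(self, handler_name): pass
        def ensure_entry(self, node_id): pass
        def reset_handlers(self): pass

        def start_progress(self, node_id):
            logging.info("node %s started", node_id)

        def update_progress(self, node_id, value, max_value, image):
            logging.info("node %s: %s/%s", node_id, value, max_value)

        def finish_progress(self, node_id):
            logging.info("node %s finished", node_id)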


@@ -3,7 +3,7 @@ from typing import Sequence, Mapping, Dict
 from .graph import DynamicPrompt
 from .graph_utils import is_link
-from .nodes_context import get_nodes
+from comfy.nodes_context import get_nodes

 from abc import ABC, abstractmethod


@@ -4,11 +4,11 @@ import asyncio
 import inspect
 from typing import Optional, Type, Literal

-from .comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
-from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
+from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
+from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
     DependencyExecutionErrorMessage
+from comfy.nodes_context import get_nodes
 from .graph_utils import is_link
-from .nodes_context import get_nodes


 class DynamicPrompt:


@@ -1,46 +1,53 @@
-import contextvars
-from typing import Optional, NamedTuple
-
-
-class ExecutionContext(NamedTuple):
-    """
-    Context information about the currently executing node.
-
-    Attributes:
-        node_id: The ID of the currently executing node
-        list_index: The index in a list being processed (for operations on batches/lists)
-    """
-    prompt_id: str
-    node_id: str
-    list_index: Optional[int]
-
-
-current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
-
-
-def get_executing_context() -> Optional[ExecutionContext]:
-    return current_executing_context.get(None)
-
-
-class CurrentNodeContext:
-    """
-    Context manager for setting the current executing node context.
-    Sets the current_executing_context on enter and resets it on exit.
-
-    Example:
-        with CurrentNodeContext(node_id="123", list_index=0):
-            # Code that should run with the current node context set
-            process_image()
-    """
-    def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
-        self.context = ExecutionContext(
-            prompt_id= prompt_id,
-            node_id= node_id,
-            list_index= list_index
-        )
-        self.token = None
-
-    def __enter__(self):
-        self.token = current_executing_context.set(self.context)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.token is not None:
-            current_executing_context.reset(self.token)
+from __future__ import annotations
+
+from typing import Optional
+
+from comfy import execution_context as core_execution_context
+
+ExecutionContext = core_execution_context.ExecutionContext
+"""
+Context information about the currently executing node.
+This is a compatibility wrapper around the core execution context.
+
+Attributes:
+    prompt_id: The ID of the currently executing prompt (task_id in core context)
+    node_id: The ID of the currently executing node
+    list_index: The index in a list being processed (for operations on batches/lists)
+"""
+
+
+def get_executing_context() -> Optional[ExecutionContext]:
+    """
+    Gets the current execution context from the core context provider.
+    Returns a compatibility ExecutionContext object or None if not in an execution context.
+    """
+    ctx = core_execution_context.current_execution_context()
+    if ctx.task_id is None or ctx.node_id is None:
+        return None
+    return ctx
+
+
+class CurrentNodeContext:
+    """
+    Context manager for setting the current executing node context.
+    This is a wrapper around the core `context_set_node_and_prompt` context manager.
+
+    Example:
+        with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
+            # Code that should run with the current node context set
+            process_image()
+    """
+    def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
+        self._cm = core_execution_context.context_set_node_and_prompt(
+            prompt_id=prompt_id,
+            node_id=node_id,
+            list_index=list_index
+        )
+
+    def __enter__(self):
+        self._cm.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._cm.__exit__(exc_type, exc_val, exc_tb)


@@ -220,9 +220,6 @@ torchvision = [
 torchaudio = [
     { index = "pytorch-cpu", extra = "cpu" },
 ]
-comfyui-frontend-package = [
-    { git = "https://github.com/appmana/appmana-comfyui-frontend", subdirectory = "comfyui_frontend_package" },
-]
 "sageattention" = [
     { git = "https://github.com/thu-ml/SageAttention.git", extra = "attention", marker = "sys_platform == 'Linux' or sys_platform == 'win32'" },
 ]
@@ -244,4 +241,4 @@ exclude = ["*.ipynb"]
 allow-direct-references = true

 [tool.hatch.build.targets.wheel]
-packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"]
+packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/"]