merging upstream with hiddenswitch

- deprecate our preview image fork of the frontend because upstream now has the needed functionality
 - merge the context executor from upstream into ours
This commit is contained in:
doctorpangloss 2025-07-14 16:55:47 -07:00
parent bd6f28e3bd
commit 499f9be5fa
15 changed files with 242 additions and 89 deletions

View File

@ -146,6 +146,7 @@ class Configuration(dict):
front_end_root (Optional[str]): The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.
comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
"""
def __init__(self, **kwargs):
@ -212,6 +213,7 @@ class Configuration(dict):
self.windows_standalone_build: bool = False
self.disable_metadata: bool = False
self.disable_all_custom_nodes: bool = False
self.whitelist_custom_nodes: list[str] = []
self.multi_user: bool = False
self.plausible_analytics_base_url: Optional[str] = None
self.plausible_analytics_domain: Optional[str] = None

View File

@ -19,12 +19,16 @@ from typing import List, Optional, Tuple, Literal
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \
BasicCache
# order matters
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, DependencyAwareCache, \
BasicCache
from ..cli_args import args
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
@ -36,12 +40,10 @@ from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
# order matters
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
ProgressRegistry
from ..validation import validate_node_input
_module_properties = create_module_properties()
@ -135,6 +137,7 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
if extra_data is None:
extra_data = {}
@ -187,9 +190,11 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
raise ValueError("")
async def resolve_map_node_over_list_results(results):
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
if len(remaining) == 0:
@ -245,6 +250,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = []
async def process_inputs(inputs, index=None, input_is_list=False):
if allow_interrupt:
interruption.throw_exception_if_processing_interrupted()
@ -264,8 +270,10 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
# todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility
with CurrentNodeContext(prompt_id, unique_id, list_index):
return await f(**args)
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
@ -322,6 +330,7 @@ async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_b
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
return output, ui, has_subgraph, False
def get_output_from_returns(return_values, obj):
results = []
uis = []
@ -464,9 +473,9 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
if hasattr(obj, "check_lazy_status"):
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
x not in input_data_all or x in missing_keys
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
x not in input_data_all or x in missing_keys
)]
if len(required_inputs) > 0:
for i in required_inputs:
@ -500,10 +509,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
async def await_completion():
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
@ -683,11 +694,12 @@ class PromptExecutor:
# torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
await self._execute_async(prompt, prompt_id, extra_data, execute_outputs)
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode):
await self._execute_async(dynamic_prompt, prompt_id, extra_data, execute_outputs)
async def _execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
async def _execute_async(self, prompt: DynamicPrompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
if execute_outputs is None:
execute_outputs = []
if extra_data is None:
@ -704,8 +716,8 @@ class PromptExecutor:
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode() if inference_mode else nullcontext():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
dynamic_prompt = prompt
prompt: dict = prompt.original_prompt
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
@ -722,7 +734,7 @@ class PromptExecutor:
{"nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@ -988,7 +1000,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
# ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:

View File

@ -1,7 +1,4 @@
from ..component_model import queue_types
class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
PREVIEW_IMAGE_WITH_METADATA = 4
# todo: should protocol really be all of queue_types?
BinaryEventTypes = queue_types.BinaryEventTypes

View File

@ -12,6 +12,7 @@ import socket
import struct
import sys
import traceback
import typing
import urllib
import uuid
from asyncio import Future, AbstractEventLoop, Task
@ -45,7 +46,8 @@ from ..cmd import execution
from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.encode_text_for_progress import encode_text_for_progress
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo, \
UnencodedPreviewImageMessage
from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
ExecutionStatus
@ -53,6 +55,7 @@ from ..digest import digest
from ..images import open_image
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
from ..nodes.package_typing import ExportedNodes
from ..progress_types import PreviewImageMetadata
logger = logging.getLogger(__name__)
@ -1049,7 +1052,7 @@ class PromptServer(ExecutorToClientProgress):
prompt_info['exec_info'] = exec_info
return prompt_info
async def send(self, event, data, sid=None):
async def send(self, event, data: UnencodedPreviewImageMessage | tuple[UnencodedPreviewImageMessage, PreviewImageMetadata] | bytes | bytearray | dict, sid=None):
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
await self.send_image(data, sid=sid)
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
@ -1061,7 +1064,7 @@ class PromptServer(ExecutorToClientProgress):
else:
await self.send_json(event, data, sid)
def encode_bytes(self, event: int | Enum | str, data):
def encode_bytes(self, event: int | Enum | str, data: bytes | bytearray | typing.Sequence[int]):
# todo: investigate what is propagating these spurious, string-repr'd previews
if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
@ -1077,14 +1080,14 @@ class PromptServer(ExecutorToClientProgress):
message.extend(data)
return message
async def send_image(self, image_data, sid=None):
async def send_image(self, image_data: UnencodedPreviewImageMessage, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
preview_bytes = encode_preview_image(image, image_type, max_size)
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
async def send_image_with_metadata(self, image_data: UnencodedPreviewImageMessage, metadata: Optional[PreviewImageMetadata] = None, sid=None):
image_type = image_data[0]
image = image_data[1]
max_size = image_data[2]
@ -1104,7 +1107,6 @@ class PromptServer(ExecutorToClientProgress):
metadata["image_type"] = mimetype
# Serialize metadata as JSON
import json
metadata_json = json.dumps(metadata).encode('utf-8')
metadata_length = len(metadata_json)
@ -1131,7 +1133,7 @@ class PromptServer(ExecutorToClientProgress):
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
async def send_json(self, event, data, sid=None):
async def send_json(self, event, data: dict, sid=None):
message = {"type": event, "data": data}
if sid is None:

View File

@ -12,6 +12,7 @@ from .outputs_types import OutputsDict
from .queue_types import BinaryEventTypes
from ..cli_args_types import Configuration
from ..nodes.package_typing import InputTypeSpec
from ..progress_types import PreviewImageMetadata
class ExecInfo(TypedDict):
@ -82,7 +83,7 @@ ExecutedMessage = ExecutingMessage
SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]
class ExecutorToClientProgress(Protocol):

View File

@ -4,6 +4,7 @@ import asyncio
from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple
from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
@ -142,6 +143,7 @@ class BinaryEventTypes(Enum):
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
TEXT = 3
PREVIEW_IMAGE_WITH_METADATA = 4
class ExecutorToClientMessage(TypedDict, total=False):

View File

@ -10,6 +10,7 @@ from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
from .distributed.server_stub import ServerStub
from .nodes.package_typing import ExportedNodes, exported_nodes_view
from .progress_types import AbstractProgressRegistry, ProgressRegistryStub
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
# enables context var propagation across process boundaries for process pool executors
@ -23,10 +24,21 @@ class ExecutionContext:
custom_nodes: ExportedNodes
node_id: Optional[str] = None
task_id: Optional[str] = None
list_index: Optional[int] = None
inference_mode: bool = True
progress_registry: Optional[AbstractProgressRegistry] = None
def __iter__(self):
"""
Provides tuple-like unpacking behavior, similar to a NamedTuple.
Yields task_id, node_id, and list_index.
"""
yield self.task_id
yield self.node_id
yield self.list_index
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes()))
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
def current_execution_context() -> ExecutionContext:
@ -51,9 +63,9 @@ def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
@contextmanager
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry)
with _new_execution_context(new_ctx):
yield new_ctx
@ -85,3 +97,17 @@ def context_add_custom_nodes(exported_nodes: ExportedNodes):
new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Optional[int] = None):
"""
A context manager to set the prompt_id (task_id), node_id, and optional list_index for the current execution.
This is useful for fine-grained context setting within a node's execution, especially for batch processing.
Replaces the @guill code upstream
"""
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index)
with _new_execution_context(new_ctx):
yield new_ctx

View File

@ -194,17 +194,15 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
# the way community custom nodes is pretty radioactive
from ..cmd import cuda_malloc, folder_paths, latent_preview
from .. import graph, graph_utils, caching
from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
from .. import node_helpers
from .. import __version__
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
sys.modules['nodes'] = base_nodes
sys.modules['comfy_execution.graph'] = graph
sys.modules['comfy_execution.graph_utils'] = graph_utils
sys.modules['comfy_execution.caching'] = caching
# apparently this is also something that happens
sys.modules['comfy.nodes'] = base_nodes
comfyui_version = types.ModuleType('comfyui_version', '')
setattr(comfyui_version, "__version__", __version__)
sys.modules['comfyui_version'] = comfyui_version

View File

@ -7,11 +7,17 @@ from PIL import Image
from tqdm import tqdm
from typing_extensions import override
from .component_model.module_property import create_module_properties
from .execution_context import current_execution_context
from .progress_types import AbstractProgressRegistry
if TYPE_CHECKING:
from .graph import DynamicPrompt
from protocol import BinaryEventTypes
from comfy_execution.graph import DynamicPrompt
from .cmd.protocol import BinaryEventTypes
from comfy_api import feature_flags
_module_properties = create_module_properties()
class NodeState(Enum):
Pending = "pending"
@ -234,7 +240,7 @@ class WebUIProgressHandler(ProgressHandler):
self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
class ProgressRegistry(AbstractProgressRegistry):
"""
Registry that maintains node progress state and notifies registered handlers.
"""
@ -320,18 +326,25 @@ class ProgressRegistry:
# Global registry instance
global_progress_registry: ProgressRegistry = None
@_module_properties.getter
def _global_progress_registry() -> ProgressRegistry:
return current_execution_context().progress_registry
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global global_progress_registry
"""
the caller must create a new progress registry
:param prompt_id:
:param dynprompt:
:return: None
"""
global_progress_registry = _global_progress_registry()
# Reset existing handlers if registry exists
if global_progress_registry is not None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
# XXX caller now creates new progress registry
def add_progress_handler(handler: ProgressHandler) -> None:
@ -341,11 +354,4 @@ def add_progress_handler(handler: ProgressHandler) -> None:
def get_progress_state() -> ProgressRegistry:
global global_progress_registry
if global_progress_registry is None:
from .graph import DynamicPrompt
global_progress_registry = ProgressRegistry(
prompt_id="", dynprompt=DynamicPrompt({})
)
return global_progress_registry
return _global_progress_registry()

103
comfy/progress_types.py Normal file
View File

@ -0,0 +1,103 @@
from abc import ABCMeta, abstractmethod
from typing_extensions import TypedDict, NotRequired
class PreviewImageMetadata(TypedDict, total=True):
"""
Metadata associated with a preview image sent to the UI.
"""
node_id: str
prompt_id: str
display_node_id: str
parent_node_id: str
real_node_id: str
image_type: NotRequired[str]
class AbstractProgressRegistry(metaclass=ABCMeta):
@abstractmethod
def register_handler(self, handler):
"""Register a progress handler"""
pass
@abstractmethod
def unregister_handler(self, handler_name):
"""Unregister a progress handler"""
pass
@abstractmethod
def enable_handler(self, handler_name):
"""Enable a progress handler"""
pass
@abstractmethod
def disable_handler(self, handler_name):
"""Disable a progress handler"""
pass
@abstractmethod
def ensure_entry(self, node_id):
"""Ensure a node entry exists"""
pass
@abstractmethod
def start_progress(self, node_id):
"""Start progress tracking for a node"""
pass
@abstractmethod
def update_progress(self, node_id, value, max_value, image):
"""Update progress for a node"""
pass
@abstractmethod
def finish_progress(self, node_id):
"""Finish progress tracking for a node"""
pass
@abstractmethod
def reset_handlers(self):
"""Reset all handlers"""
pass
class ProgressRegistryStub(AbstractProgressRegistry):
"""A stub implementation of AbstractProgressRegistry that performs no operations."""
def register_handler(self, handler):
"""Register a progress handler"""
pass
def unregister_handler(self, handler_name):
"""Unregister a progress handler"""
pass
def enable_handler(self, handler_name):
"""Enable a progress handler"""
pass
def disable_handler(self, handler_name):
"""Disable a progress handler"""
pass
def ensure_entry(self, node_id):
"""Ensure a node entry exists"""
pass
def start_progress(self, node_id):
"""Start progress tracking for a node"""
pass
def update_progress(self, node_id, value, max_value, image):
"""Update progress for a node"""
pass
def finish_progress(self, node_id):
"""Finish progress tracking for a node"""
pass
def reset_handlers(self):
"""Reset all handlers"""
pass

View File

View File

@ -3,7 +3,7 @@ from typing import Sequence, Mapping, Dict
from .graph import DynamicPrompt
from .graph_utils import is_link
from .nodes_context import get_nodes
from comfy.nodes_context import get_nodes
from abc import ABC, abstractmethod

View File

@ -4,11 +4,11 @@ import asyncio
import inspect
from typing import Optional, Type, Literal
from .comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
DependencyExecutionErrorMessage
from comfy.nodes_context import get_nodes
from .graph_utils import is_link
from .nodes_context import get_nodes
class DynamicPrompt:

View File

@ -1,46 +1,53 @@
import contextvars
from typing import Optional, NamedTuple
from __future__ import annotations
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
from typing import Optional
Attributes:
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
"""
prompt_id: str
node_id: str
list_index: Optional[int]
from comfy import execution_context as core_execution_context
ExecutionContext = core_execution_context.ExecutionContext
"""
Context information about the currently executing node.
This is a compatibility wrapper around the core execution context.
Attributes:
prompt_id: The ID of the currently executing prompt (task_id in core context)
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
"""
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> Optional[ExecutionContext]:
return current_executing_context.get(None)
"""
Gets the current execution context from the core context provider.
Returns a compatibility ExecutionContext object or None if not in an execution context.
"""
ctx = core_execution_context.current_execution_context()
if ctx.task_id is None or ctx.node_id is None:
return None
return ctx
class CurrentNodeContext:
"""
Context manager for setting the current executing node context.
Sets the current_executing_context on enter and resets it on exit.
This is a wrapper around the core `context_set_node_and_prompt` context manager.
Example:
with CurrentNodeContext(node_id="123", list_index=0):
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
self.context = ExecutionContext(
prompt_id= prompt_id,
node_id= node_id,
list_index= list_index
self._cm = core_execution_context.context_set_node_and_prompt(
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index
)
self.token = None
def __enter__(self):
self.token = current_executing_context.set(self.context)
self._cm.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.token is not None:
current_executing_context.reset(self.token)
self._cm.__exit__(exc_type, exc_val, exc_tb)

View File

@ -220,9 +220,6 @@ torchvision = [
torchaudio = [
{ index = "pytorch-cpu", extra = "cpu" },
]
comfyui-frontend-package = [
{ git = "https://github.com/appmana/appmana-comfyui-frontend", subdirectory = "comfyui_frontend_package" },
]
"sageattention" = [
{ git = "https://github.com/thu-ml/SageAttention.git", extra = "attention", marker = "sys_platform == 'Linux' or sys_platform == 'win32'" },
]
@ -244,4 +241,4 @@ exclude = ["*.ipynb"]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"]
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/"]