diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 723c4ec57..1e9ed1b10 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -146,6 +146,7 @@ class Configuration(dict): front_end_root (Optional[str]): The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org) database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'. + whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled. """ def __init__(self, **kwargs): @@ -212,6 +213,7 @@ class Configuration(dict): self.windows_standalone_build: bool = False self.disable_metadata: bool = False self.disable_all_custom_nodes: bool = False + self.whitelist_custom_nodes: list[str] = [] self.multi_user: bool = False self.plausible_analytics_base_url: Optional[str] = None self.plausible_analytics_domain: Optional[str] = None diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index df3e094a2..cf6d1c0c2 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,12 +19,16 @@ from typing import List, Optional, Tuple, Literal import torch from opentelemetry.trace import get_current_span, StatusCode, Status +from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ + DependencyAwareCache, \ + BasicCache +# order matters +from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext from .main_pre import tracer from .. import interruption from .. import model_management -from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, DependencyAwareCache, \ - BasicCache from ..cli_args import args from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ @@ -36,12 +40,10 @@ from ..component_model.module_property import create_module_properties from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import context_execute_node, context_execute_prompt from ..execution_ext import should_panic_on_exception -# order matters -from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker -from ..graph_utils import is_link, GraphBuilder from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode from ..nodes_context import get_nodes -from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \ + ProgressRegistry from ..validation import validate_node_input _module_properties = create_module_properties() @@ -135,6 +137,7 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") + def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None): if extra_data is None: extra_data = {} @@ -187,9 +190,11 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] return input_data_all, missing_keys + def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): raise ValueError("") + async def resolve_map_node_over_list_results(results): remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()] if len(remaining) == 0: @@ -245,6 +250,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] + async def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: interruption.throw_exception_if_processing_interrupted() @@ -264,8 +270,10 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f f = getattr(obj, func) if inspect.iscoroutinefunction(f): async def async_wrapper(f, prompt_id, unique_id, list_index, args): + # todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility with CurrentNodeContext(prompt_id, unique_id, list_index): return await f(**args) + task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) # Give the task a chance to execute without yielding await asyncio.sleep(0) @@ -322,6 +330,7 @@ async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_b output, ui, has_subgraph = get_output_from_returns(return_values, obj) return output, ui, has_subgraph, False + def get_output_from_returns(return_values, obj): results = [] uis = [] @@ -464,9 +473,9 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra if hasattr(obj, "check_lazy_status"): required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = await resolve_map_node_over_list_results(required_inputs) - required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if isinstance(x,str) and ( - x not in input_data_all or x in missing_keys + required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x, str) and ( + x not in input_data_all or x in missing_keys )] if len(required_inputs) > 0: for i in required_inputs: @@ -500,10 +509,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) + async def await_completion(): tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -683,11 +694,12 @@ class PromptExecutor: # torchao and potentially other optimization approaches break when the models are created in inference mode # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) - with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode): - await self._execute_async(prompt, prompt_id, extra_data, execute_outputs) + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode): + await self._execute_async(dynamic_prompt, prompt_id, extra_data, execute_outputs) - - async def _execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True): + async def _execute_async(self, prompt: DynamicPrompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True): if execute_outputs is None: execute_outputs = [] if extra_data is None: @@ -704,8 +716,8 @@ class PromptExecutor: self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False) with torch.inference_mode() if inference_mode else nullcontext(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) + dynamic_prompt = prompt + prompt: dict = prompt.original_prompt add_progress_handler(WebUIProgressHandler(self.server)) is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: @@ -722,7 +734,7 @@ class PromptExecutor: {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -988,7 +1000,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + # ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS") ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: diff --git a/comfy/cmd/protocol.py b/comfy/cmd/protocol.py index 038a0a840..ab6393fb8 100644 --- a/comfy/cmd/protocol.py +++ b/comfy/cmd/protocol.py @@ -1,7 +1,4 @@ +from ..component_model import queue_types -class BinaryEventTypes: - PREVIEW_IMAGE = 1 - UNENCODED_PREVIEW_IMAGE = 2 - TEXT = 3 - PREVIEW_IMAGE_WITH_METADATA = 4 - +# todo: should protocol really be all of queue_types? +BinaryEventTypes = queue_types.BinaryEventTypes \ No newline at end of file diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 9068d6ea4..0758ceccd 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -12,6 +12,7 @@ import socket import struct import sys import traceback +import typing import urllib import uuid from asyncio import Future, AbstractEventLoop, Task @@ -45,7 +46,8 @@ from ..cmd import execution from ..cmd import folder_paths from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue from ..component_model.encode_text_for_progress import encode_text_for_progress -from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo +from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo, \ + UnencodedPreviewImageMessage from ..component_model.file_output_path import file_output_path from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \ ExecutionStatus @@ -53,6 +55,7 @@ from ..digest import digest from ..images import open_image from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version from ..nodes.package_typing import ExportedNodes +from ..progress_types import PreviewImageMetadata logger = logging.getLogger(__name__) @@ -1049,7 +1052,7 @@ class PromptServer(ExecutorToClientProgress): prompt_info['exec_info'] = exec_info return prompt_info - async def send(self, event, data, sid=None): + async def send(self, event, data: UnencodedPreviewImageMessage | tuple[UnencodedPreviewImageMessage, PreviewImageMetadata] | bytes | bytearray | dict, sid=None): if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: await self.send_image(data, sid=sid) elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: @@ -1061,7 +1064,7 @@ class PromptServer(ExecutorToClientProgress): else: await self.send_json(event, data, sid) - def encode_bytes(self, event: int | Enum | str, data): + def encode_bytes(self, event: int | Enum | str, data: bytes | bytearray | typing.Sequence[int]): # todo: investigate what is propagating these spurious, string-repr'd previews if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE): event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value @@ -1077,14 +1080,14 @@ class PromptServer(ExecutorToClientProgress): message.extend(data) return message - async def send_image(self, image_data, sid=None): + async def send_image(self, image_data: UnencodedPreviewImageMessage, sid=None): image_type = image_data[0] image = image_data[1] max_size = image_data[2] preview_bytes = encode_preview_image(image, image_type, max_size) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) - async def send_image_with_metadata(self, image_data, metadata=None, sid=None): + async def send_image_with_metadata(self, image_data: UnencodedPreviewImageMessage, metadata: Optional[PreviewImageMetadata] = None, sid=None): image_type = image_data[0] image = image_data[1] max_size = image_data[2] @@ -1104,7 +1107,6 @@ class PromptServer(ExecutorToClientProgress): metadata["image_type"] = mimetype # Serialize metadata as JSON - import json metadata_json = json.dumps(metadata).encode('utf-8') metadata_length = len(metadata_json) @@ -1131,7 +1133,7 @@ class PromptServer(ExecutorToClientProgress): elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_bytes, message) - async def send_json(self, event, data, sid=None): + async def send_json(self, event, data: dict, sid=None): message = {"type": event, "data": data} if sid is None: diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index d966f9c0b..44b033ed6 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -12,6 +12,7 @@ from .outputs_types import OutputsDict from .queue_types import BinaryEventTypes from ..cli_args_types import Configuration from ..nodes.package_typing import InputTypeSpec +from ..progress_types import PreviewImageMetadata class ExecInfo(TypedDict): @@ -82,7 +83,7 @@ ExecutedMessage = ExecutingMessage SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None] -SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None] +SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None] class ExecutorToClientProgress(Protocol): diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py index 4fcf84e3d..9b1471b84 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -4,6 +4,7 @@ import asyncio from enum import Enum from typing import NamedTuple, Optional, List, Literal, Sequence from typing import Tuple + from typing_extensions import NotRequired, TypedDict from .outputs_types import OutputsDict @@ -142,6 +143,7 @@ class BinaryEventTypes(Enum): PREVIEW_IMAGE = 1 UNENCODED_PREVIEW_IMAGE = 2 TEXT = 3 + PREVIEW_IMAGE_WITH_METADATA = 4 class ExecutorToClientMessage(TypedDict, total=False): diff --git a/comfy/execution_context.py b/comfy/execution_context.py index 4c1b2a656..30ebd4894 100644 --- a/comfy/execution_context.py +++ b/comfy/execution_context.py @@ -10,6 +10,7 @@ from .component_model.executor_types import ExecutorToClientProgress from .component_model.folder_path_types import FolderNames from .distributed.server_stub import ServerStub from .nodes.package_typing import ExportedNodes, exported_nodes_view +from .progress_types import AbstractProgressRegistry, ProgressRegistryStub comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context") # enables context var propagation across process boundaries for process pool executors @@ -23,10 +24,21 @@ class ExecutionContext: custom_nodes: ExportedNodes node_id: Optional[str] = None task_id: Optional[str] = None + list_index: Optional[int] = None inference_mode: bool = True + progress_registry: Optional[AbstractProgressRegistry] = None + + def __iter__(self): + """ + Provides tuple-like unpacking behavior, similar to a NamedTuple. + Yields task_id, node_id, and list_index. + """ + yield self.task_id + yield self.node_id + yield self.list_index -comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes())) +comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub())) def current_execution_context() -> ExecutionContext: @@ -51,9 +63,9 @@ def context_folder_names_and_paths(folder_names_and_paths: FolderNames): @contextmanager -def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True): +def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True): current_ctx = current_execution_context() - new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode) + new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry) with _new_execution_context(new_ctx): yield new_ctx @@ -84,4 +96,18 @@ def context_add_custom_nodes(exported_nodes: ExportedNodes): new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes) with _new_execution_context(new_ctx): - yield new_ctx \ No newline at end of file + yield new_ctx + + +@contextmanager +def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Optional[int] = None): + """ + A context manager to set the prompt_id (task_id), node_id, and optional list_index for the current execution. + This is useful for fine-grained context setting within a node's execution, especially for batch processing. + + Replaces the @guill code upstream + """ + current_ctx = current_execution_context() + new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index) + with _new_execution_context(new_ctx): + yield new_ctx diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index 712bd20f2..164536077 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -194,17 +194,15 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes: # this mitigation puts files that custom nodes expects are at the root of the repository back where they should be # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things, # the way community custom nodes is pretty radioactive - from ..cmd import cuda_malloc, folder_paths, latent_preview - from .. import graph, graph_utils, caching + from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol from .. import node_helpers from .. import __version__ - for module in (cuda_malloc, folder_paths, latent_preview, node_helpers): + for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol): module_short_name = module.__name__.split(".")[-1] sys.modules[module_short_name] = module sys.modules['nodes'] = base_nodes - sys.modules['comfy_execution.graph'] = graph - sys.modules['comfy_execution.graph_utils'] = graph_utils - sys.modules['comfy_execution.caching'] = caching + # apparently this is also something that happens + sys.modules['comfy.nodes'] = base_nodes comfyui_version = types.ModuleType('comfyui_version', '') setattr(comfyui_version, "__version__", __version__) sys.modules['comfyui_version'] = comfyui_version diff --git a/comfy/progress.py b/comfy/progress.py index d63dfba55..1711711f6 100644 --- a/comfy/progress.py +++ b/comfy/progress.py @@ -7,11 +7,17 @@ from PIL import Image from tqdm import tqdm from typing_extensions import override +from .component_model.module_property import create_module_properties +from .execution_context import current_execution_context +from .progress_types import AbstractProgressRegistry + if TYPE_CHECKING: - from .graph import DynamicPrompt -from protocol import BinaryEventTypes + from comfy_execution.graph import DynamicPrompt +from .cmd.protocol import BinaryEventTypes from comfy_api import feature_flags +_module_properties = create_module_properties() + class NodeState(Enum): Pending = "pending" @@ -234,7 +240,7 @@ class WebUIProgressHandler(ProgressHandler): self._send_progress_state(prompt_id, self.registry.nodes) -class ProgressRegistry: +class ProgressRegistry(AbstractProgressRegistry): """ Registry that maintains node progress state and notifies registered handlers. """ @@ -320,18 +326,25 @@ class ProgressRegistry: # Global registry instance -global_progress_registry: ProgressRegistry = None +@_module_properties.getter +def _global_progress_registry() -> ProgressRegistry: + return current_execution_context().progress_registry def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: - global global_progress_registry + """ + the caller must create a new progress registry + :param prompt_id: + :param dynprompt: + :return: None + """ + global_progress_registry = _global_progress_registry() # Reset existing handlers if registry exists if global_progress_registry is not None: global_progress_registry.reset_handlers() - # Create new registry - global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + # XXX caller now creates new progress registry def add_progress_handler(handler: ProgressHandler) -> None: @@ -341,11 +354,4 @@ def add_progress_handler(handler: ProgressHandler) -> None: def get_progress_state() -> ProgressRegistry: - global global_progress_registry - if global_progress_registry is None: - from .graph import DynamicPrompt - - global_progress_registry = ProgressRegistry( - prompt_id="", dynprompt=DynamicPrompt({}) - ) - return global_progress_registry + return _global_progress_registry() diff --git a/comfy/progress_types.py b/comfy/progress_types.py new file mode 100644 index 000000000..f7df4ed50 --- /dev/null +++ b/comfy/progress_types.py @@ -0,0 +1,103 @@ +from abc import ABCMeta, abstractmethod + +from typing_extensions import TypedDict, NotRequired + + +class PreviewImageMetadata(TypedDict, total=True): + """ + Metadata associated with a preview image sent to the UI. + """ + node_id: str + prompt_id: str + display_node_id: str + parent_node_id: str + real_node_id: str + image_type: NotRequired[str] + + +class AbstractProgressRegistry(metaclass=ABCMeta): + + @abstractmethod + def register_handler(self, handler): + """Register a progress handler""" + pass + + @abstractmethod + def unregister_handler(self, handler_name): + """Unregister a progress handler""" + pass + + @abstractmethod + def enable_handler(self, handler_name): + """Enable a progress handler""" + pass + + @abstractmethod + def disable_handler(self, handler_name): + """Disable a progress handler""" + pass + + @abstractmethod + def ensure_entry(self, node_id): + """Ensure a node entry exists""" + pass + + @abstractmethod + def start_progress(self, node_id): + """Start progress tracking for a node""" + pass + + @abstractmethod + def update_progress(self, node_id, value, max_value, image): + """Update progress for a node""" + pass + + @abstractmethod + def finish_progress(self, node_id): + """Finish progress tracking for a node""" + pass + + @abstractmethod + def reset_handlers(self): + """Reset all handlers""" + pass + + +class ProgressRegistryStub(AbstractProgressRegistry): + """A stub implementation of AbstractProgressRegistry that performs no operations.""" + + def register_handler(self, handler): + """Register a progress handler""" + pass + + def unregister_handler(self, handler_name): + """Unregister a progress handler""" + pass + + def enable_handler(self, handler_name): + """Enable a progress handler""" + pass + + def disable_handler(self, handler_name): + """Disable a progress handler""" + pass + + def ensure_entry(self, node_id): + """Ensure a node entry exists""" + pass + + def start_progress(self, node_id): + """Start progress tracking for a node""" + pass + + def update_progress(self, node_id, value, max_value, image): + """Update progress for a node""" + pass + + def finish_progress(self, node_id): + """Finish progress tracking for a node""" + pass + + def reset_handlers(self): + """Reset all handlers""" + pass diff --git a/comfy_execution/__init__.py b/comfy_execution/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index bdff5327a..4c9bf49cb 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -3,7 +3,7 @@ from typing import Sequence, Mapping, Dict from .graph import DynamicPrompt from .graph_utils import is_link -from .nodes_context import get_nodes +from comfy.nodes_context import get_nodes from abc import ABC, abstractmethod diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index e561485e8..2d291c247 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -4,11 +4,11 @@ import asyncio import inspect from typing import Optional, Type, Literal -from .comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions -from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions +from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ DependencyExecutionErrorMessage +from comfy.nodes_context import get_nodes from .graph_utils import is_link -from .nodes_context import get_nodes class DynamicPrompt: diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py index 62d32f101..6f94101f9 100644 --- a/comfy_execution/utils.py +++ b/comfy_execution/utils.py @@ -1,46 +1,53 @@ -import contextvars -from typing import Optional, NamedTuple +from __future__ import annotations -class ExecutionContext(NamedTuple): - """ - Context information about the currently executing node. +from typing import Optional - Attributes: - node_id: The ID of the currently executing node - list_index: The index in a list being processed (for operations on batches/lists) - """ - prompt_id: str - node_id: str - list_index: Optional[int] +from comfy import execution_context as core_execution_context + +ExecutionContext = core_execution_context.ExecutionContext +""" +Context information about the currently executing node. +This is a compatibility wrapper around the core execution context. + +Attributes: + prompt_id: The ID of the currently executing prompt (task_id in core context) + node_id: The ID of the currently executing node + list_index: The index in a list being processed (for operations on batches/lists) +""" -current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None) def get_executing_context() -> Optional[ExecutionContext]: - return current_executing_context.get(None) + """ + Gets the current execution context from the core context provider. + Returns a compatibility ExecutionContext object or None if not in an execution context. + """ + ctx = core_execution_context.current_execution_context() + if ctx.task_id is None or ctx.node_id is None: + return None + return ctx + class CurrentNodeContext: """ Context manager for setting the current executing node context. - - Sets the current_executing_context on enter and resets it on exit. + This is a wrapper around the core `context_set_node_and_prompt` context manager. Example: - with CurrentNodeContext(node_id="123", list_index=0): + with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0): # Code that should run with the current node context set process_image() """ + def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None): - self.context = ExecutionContext( - prompt_id= prompt_id, - node_id= node_id, - list_index= list_index + self._cm = core_execution_context.context_set_node_and_prompt( + prompt_id=prompt_id, + node_id=node_id, + list_index=list_index ) - self.token = None def __enter__(self): - self.token = current_executing_context.set(self.context) + self._cm.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): - if self.token is not None: - current_executing_context.reset(self.token) + self._cm.__exit__(exc_type, exc_val, exc_tb) diff --git a/pyproject.toml b/pyproject.toml index e38c8a932..c6d17e4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,9 +220,6 @@ torchvision = [ torchaudio = [ { index = "pytorch-cpu", extra = "cpu" }, ] -comfyui-frontend-package = [ - { git = "https://github.com/appmana/appmana-comfyui-frontend", subdirectory = "comfyui_frontend_package" }, -] "sageattention" = [ { git = "https://github.com/thu-ml/SageAttention.git", extra = "attention", marker = "sys_platform == 'Linux' or sys_platform == 'win32'" }, ] @@ -244,4 +241,4 @@ exclude = ["*.ipynb"] allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"] \ No newline at end of file +packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/"] \ No newline at end of file