mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
merging upstream with hiddenswitch
- deprecate our preview image fork of the frontend because upstream now has the needed functionality - merge the context executor from upstream into ours
This commit is contained in:
parent
bd6f28e3bd
commit
499f9be5fa
@ -146,6 +146,7 @@ class Configuration(dict):
|
||||
front_end_root (Optional[str]): The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.
|
||||
comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org)
|
||||
database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.
|
||||
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -212,6 +213,7 @@ class Configuration(dict):
|
||||
self.windows_standalone_build: bool = False
|
||||
self.disable_metadata: bool = False
|
||||
self.disable_all_custom_nodes: bool = False
|
||||
self.whitelist_custom_nodes: list[str] = []
|
||||
self.multi_user: bool = False
|
||||
self.plausible_analytics_base_url: Optional[str] = None
|
||||
self.plausible_analytics_domain: Optional[str] = None
|
||||
|
||||
@ -19,12 +19,16 @@ from typing import List, Optional, Tuple, Literal
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
||||
DependencyAwareCache, \
|
||||
BasicCache
|
||||
# order matters
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from .main_pre import tracer
|
||||
from .. import interruption
|
||||
from .. import model_management
|
||||
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, DependencyAwareCache, \
|
||||
BasicCache
|
||||
from ..cli_args import args
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||
@ -36,12 +40,10 @@ from ..component_model.module_property import create_module_properties
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..execution_context import context_execute_node, context_execute_prompt
|
||||
from ..execution_ext import should_panic_on_exception
|
||||
# order matters
|
||||
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from ..graph_utils import is_link, GraphBuilder
|
||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||
from ..nodes_context import get_nodes
|
||||
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
||||
ProgressRegistry
|
||||
from ..validation import validate_node_input
|
||||
|
||||
_module_properties = create_module_properties()
|
||||
@ -135,6 +137,7 @@ class CacheSet:
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
|
||||
if extra_data is None:
|
||||
extra_data = {}
|
||||
@ -187,9 +190,11 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
return input_data_all, missing_keys
|
||||
|
||||
|
||||
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
raise ValueError("")
|
||||
|
||||
|
||||
async def resolve_map_node_over_list_results(results):
|
||||
remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()]
|
||||
if len(remaining) == 0:
|
||||
@ -245,6 +250,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
return {k: v[i if len(v) > i else -1] for k, v in d.items()}
|
||||
|
||||
results = []
|
||||
|
||||
async def process_inputs(inputs, index=None, input_is_list=False):
|
||||
if allow_interrupt:
|
||||
interruption.throw_exception_if_processing_interrupted()
|
||||
@ -264,8 +270,10 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
# todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
return await f(**args)
|
||||
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
@ -322,6 +330,7 @@ async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_b
|
||||
output, ui, has_subgraph = get_output_from_returns(return_values, obj)
|
||||
return output, ui, has_subgraph, False
|
||||
|
||||
|
||||
def get_output_from_returns(return_values, obj):
|
||||
results = []
|
||||
uis = []
|
||||
@ -464,9 +473,9 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
|
||||
x not in input_data_all or x in missing_keys
|
||||
)]
|
||||
if len(required_inputs) > 0:
|
||||
for i in required_inputs:
|
||||
@ -500,10 +509,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
unblock = execution_list.add_external_block(unique_id)
|
||||
|
||||
async def await_completion():
|
||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
unblock()
|
||||
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
@ -683,11 +694,12 @@ class PromptExecutor:
|
||||
# torchao and potentially other optimization approaches break when the models are created in inference mode
|
||||
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
|
||||
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
|
||||
with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
|
||||
await self._execute_async(prompt, prompt_id, extra_data, execute_outputs)
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
with context_execute_prompt(self.server, prompt_id, progress_registry=ProgressRegistry(prompt_id, dynamic_prompt), inference_mode=inference_mode):
|
||||
await self._execute_async(dynamic_prompt, prompt_id, extra_data, execute_outputs)
|
||||
|
||||
|
||||
async def _execute_async(self, prompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
|
||||
async def _execute_async(self, prompt: DynamicPrompt, prompt_id, extra_data=None, execute_outputs: list[str] = None, inference_mode: bool = True):
|
||||
if execute_outputs is None:
|
||||
execute_outputs = []
|
||||
if extra_data is None:
|
||||
@ -704,8 +716,8 @@ class PromptExecutor:
|
||||
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode() if inference_mode else nullcontext():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
dynamic_prompt = prompt
|
||||
prompt: dict = prompt.original_prompt
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
@ -722,7 +734,7 @@ class PromptExecutor:
|
||||
{"nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@ -988,7 +1000,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
||||
if 'input_types' in validate_function_inputs:
|
||||
input_filtered['input_types'] = [received_types]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
# ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
ret = await resolve_map_node_over_list_results(ret)
|
||||
for x in input_filtered:
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
from ..component_model import queue_types
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
PREVIEW_IMAGE_WITH_METADATA = 4
|
||||
|
||||
# todo: should protocol really be all of queue_types?
|
||||
BinaryEventTypes = queue_types.BinaryEventTypes
|
||||
@ -12,6 +12,7 @@ import socket
|
||||
import struct
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
import urllib
|
||||
import uuid
|
||||
from asyncio import Future, AbstractEventLoop, Task
|
||||
@ -45,7 +46,8 @@ from ..cmd import execution
|
||||
from ..cmd import folder_paths
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
|
||||
from ..component_model.encode_text_for_progress import encode_text_for_progress
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo, \
|
||||
UnencodedPreviewImageMessage
|
||||
from ..component_model.file_output_path import file_output_path
|
||||
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
|
||||
ExecutionStatus
|
||||
@ -53,6 +55,7 @@ from ..digest import digest
|
||||
from ..images import open_image
|
||||
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..progress_types import PreviewImageMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1049,7 +1052,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
prompt_info['exec_info'] = exec_info
|
||||
return prompt_info
|
||||
|
||||
async def send(self, event, data, sid=None):
|
||||
async def send(self, event, data: UnencodedPreviewImageMessage | tuple[UnencodedPreviewImageMessage, PreviewImageMetadata] | bytes | bytearray | dict, sid=None):
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
await self.send_image(data, sid=sid)
|
||||
elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
|
||||
@ -1061,7 +1064,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
else:
|
||||
await self.send_json(event, data, sid)
|
||||
|
||||
def encode_bytes(self, event: int | Enum | str, data):
|
||||
def encode_bytes(self, event: int | Enum | str, data: bytes | bytearray | typing.Sequence[int]):
|
||||
# todo: investigate what is propagating these spurious, string-repr'd previews
|
||||
if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
|
||||
event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
|
||||
@ -1077,14 +1080,14 @@ class PromptServer(ExecutorToClientProgress):
|
||||
message.extend(data)
|
||||
return message
|
||||
|
||||
async def send_image(self, image_data, sid=None):
|
||||
async def send_image(self, image_data: UnencodedPreviewImageMessage, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
preview_bytes = encode_preview_image(image, image_type, max_size)
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_image_with_metadata(self, image_data, metadata=None, sid=None):
|
||||
async def send_image_with_metadata(self, image_data: UnencodedPreviewImageMessage, metadata: Optional[PreviewImageMetadata] = None, sid=None):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
@ -1104,7 +1107,6 @@ class PromptServer(ExecutorToClientProgress):
|
||||
metadata["image_type"] = mimetype
|
||||
|
||||
# Serialize metadata as JSON
|
||||
import json
|
||||
metadata_json = json.dumps(metadata).encode('utf-8')
|
||||
metadata_length = len(metadata_json)
|
||||
|
||||
@ -1131,7 +1133,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
|
||||
|
||||
async def send_json(self, event, data, sid=None):
|
||||
async def send_json(self, event, data: dict, sid=None):
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if sid is None:
|
||||
|
||||
@ -12,6 +12,7 @@ from .outputs_types import OutputsDict
|
||||
from .queue_types import BinaryEventTypes
|
||||
from ..cli_args_types import Configuration
|
||||
from ..nodes.package_typing import InputTypeSpec
|
||||
from ..progress_types import PreviewImageMetadata
|
||||
|
||||
|
||||
class ExecInfo(TypedDict):
|
||||
@ -82,7 +83,7 @@ ExecutedMessage = ExecutingMessage
|
||||
|
||||
SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
|
||||
|
||||
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
|
||||
SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, tuple[UnencodedPreviewImageMessage, PreviewImageMetadata], bytes, bytearray, str, None]
|
||||
|
||||
|
||||
class ExecutorToClientProgress(Protocol):
|
||||
|
||||
@ -4,6 +4,7 @@ import asyncio
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional, List, Literal, Sequence
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .outputs_types import OutputsDict
|
||||
@ -142,6 +143,7 @@ class BinaryEventTypes(Enum):
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
TEXT = 3
|
||||
PREVIEW_IMAGE_WITH_METADATA = 4
|
||||
|
||||
|
||||
class ExecutorToClientMessage(TypedDict, total=False):
|
||||
|
||||
@ -10,6 +10,7 @@ from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.folder_path_types import FolderNames
|
||||
from .distributed.server_stub import ServerStub
|
||||
from .nodes.package_typing import ExportedNodes, exported_nodes_view
|
||||
from .progress_types import AbstractProgressRegistry, ProgressRegistryStub
|
||||
|
||||
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
|
||||
# enables context var propagation across process boundaries for process pool executors
|
||||
@ -23,10 +24,21 @@ class ExecutionContext:
|
||||
custom_nodes: ExportedNodes
|
||||
node_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
list_index: Optional[int] = None
|
||||
inference_mode: bool = True
|
||||
progress_registry: Optional[AbstractProgressRegistry] = None
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Provides tuple-like unpacking behavior, similar to a NamedTuple.
|
||||
Yields task_id, node_id, and list_index.
|
||||
"""
|
||||
yield self.task_id
|
||||
yield self.node_id
|
||||
yield self.list_index
|
||||
|
||||
|
||||
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes()))
|
||||
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
|
||||
|
||||
|
||||
def current_execution_context() -> ExecutionContext:
|
||||
@ -51,9 +63,9 @@ def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
|
||||
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, progress_registry: AbstractProgressRegistry, inference_mode: bool = True):
|
||||
current_ctx = current_execution_context()
|
||||
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
|
||||
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode, progress_registry=progress_registry)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
@ -85,3 +97,17 @@ def context_add_custom_nodes(exported_nodes: ExportedNodes):
|
||||
new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
"""
|
||||
A context manager to set the prompt_id (task_id), node_id, and optional list_index for the current execution.
|
||||
This is useful for fine-grained context setting within a node's execution, especially for batch processing.
|
||||
|
||||
Replaces the @guill code upstream
|
||||
"""
|
||||
current_ctx = current_execution_context()
|
||||
new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
@ -194,17 +194,15 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
||||
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
|
||||
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
||||
# the way community custom nodes is pretty radioactive
|
||||
from ..cmd import cuda_malloc, folder_paths, latent_preview
|
||||
from .. import graph, graph_utils, caching
|
||||
from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
|
||||
from .. import node_helpers
|
||||
from .. import __version__
|
||||
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
|
||||
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
|
||||
module_short_name = module.__name__.split(".")[-1]
|
||||
sys.modules[module_short_name] = module
|
||||
sys.modules['nodes'] = base_nodes
|
||||
sys.modules['comfy_execution.graph'] = graph
|
||||
sys.modules['comfy_execution.graph_utils'] = graph_utils
|
||||
sys.modules['comfy_execution.caching'] = caching
|
||||
# apparently this is also something that happens
|
||||
sys.modules['comfy.nodes'] = base_nodes
|
||||
comfyui_version = types.ModuleType('comfyui_version', '')
|
||||
setattr(comfyui_version, "__version__", __version__)
|
||||
sys.modules['comfyui_version'] = comfyui_version
|
||||
|
||||
@ -7,11 +7,17 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import override
|
||||
|
||||
from .component_model.module_property import create_module_properties
|
||||
from .execution_context import current_execution_context
|
||||
from .progress_types import AbstractProgressRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import DynamicPrompt
|
||||
from protocol import BinaryEventTypes
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from .cmd.protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
_module_properties = create_module_properties()
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
Pending = "pending"
|
||||
@ -234,7 +240,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
self._send_progress_state(prompt_id, self.registry.nodes)
|
||||
|
||||
|
||||
class ProgressRegistry:
|
||||
class ProgressRegistry(AbstractProgressRegistry):
|
||||
"""
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
@ -320,18 +326,25 @@ class ProgressRegistry:
|
||||
|
||||
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry = None
|
||||
@_module_properties.getter
|
||||
def _global_progress_registry() -> ProgressRegistry:
|
||||
return current_execution_context().progress_registry
|
||||
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global global_progress_registry
|
||||
"""
|
||||
the caller must create a new progress registry
|
||||
:param prompt_id:
|
||||
:param dynprompt:
|
||||
:return: None
|
||||
"""
|
||||
global_progress_registry = _global_progress_registry()
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
if global_progress_registry is not None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
# XXX caller now creates new progress registry
|
||||
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
@ -341,11 +354,4 @@ def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
|
||||
|
||||
def get_progress_state() -> ProgressRegistry:
|
||||
global global_progress_registry
|
||||
if global_progress_registry is None:
|
||||
from .graph import DynamicPrompt
|
||||
|
||||
global_progress_registry = ProgressRegistry(
|
||||
prompt_id="", dynprompt=DynamicPrompt({})
|
||||
)
|
||||
return global_progress_registry
|
||||
return _global_progress_registry()
|
||||
|
||||
103
comfy/progress_types.py
Normal file
103
comfy/progress_types.py
Normal file
@ -0,0 +1,103 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
|
||||
class PreviewImageMetadata(TypedDict, total=True):
|
||||
"""
|
||||
Metadata associated with a preview image sent to the UI.
|
||||
"""
|
||||
node_id: str
|
||||
prompt_id: str
|
||||
display_node_id: str
|
||||
parent_node_id: str
|
||||
real_node_id: str
|
||||
image_type: NotRequired[str]
|
||||
|
||||
|
||||
class AbstractProgressRegistry(metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
def register_handler(self, handler):
|
||||
"""Register a progress handler"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unregister_handler(self, handler_name):
|
||||
"""Unregister a progress handler"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enable_handler(self, handler_name):
|
||||
"""Enable a progress handler"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disable_handler(self, handler_name):
|
||||
"""Disable a progress handler"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def ensure_entry(self, node_id):
|
||||
"""Ensure a node entry exists"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_progress(self, node_id):
|
||||
"""Start progress tracking for a node"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_progress(self, node_id, value, max_value, image):
|
||||
"""Update progress for a node"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def finish_progress(self, node_id):
|
||||
"""Finish progress tracking for a node"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_handlers(self):
|
||||
"""Reset all handlers"""
|
||||
pass
|
||||
|
||||
|
||||
class ProgressRegistryStub(AbstractProgressRegistry):
|
||||
"""A stub implementation of AbstractProgressRegistry that performs no operations."""
|
||||
|
||||
def register_handler(self, handler):
|
||||
"""Register a progress handler"""
|
||||
pass
|
||||
|
||||
def unregister_handler(self, handler_name):
|
||||
"""Unregister a progress handler"""
|
||||
pass
|
||||
|
||||
def enable_handler(self, handler_name):
|
||||
"""Enable a progress handler"""
|
||||
pass
|
||||
|
||||
def disable_handler(self, handler_name):
|
||||
"""Disable a progress handler"""
|
||||
pass
|
||||
|
||||
def ensure_entry(self, node_id):
|
||||
"""Ensure a node entry exists"""
|
||||
pass
|
||||
|
||||
def start_progress(self, node_id):
|
||||
"""Start progress tracking for a node"""
|
||||
pass
|
||||
|
||||
def update_progress(self, node_id, value, max_value, image):
|
||||
"""Update progress for a node"""
|
||||
pass
|
||||
|
||||
def finish_progress(self, node_id):
|
||||
"""Finish progress tracking for a node"""
|
||||
pass
|
||||
|
||||
def reset_handlers(self):
|
||||
"""Reset all handlers"""
|
||||
pass
|
||||
0
comfy_execution/__init__.py
Normal file
0
comfy_execution/__init__.py
Normal file
@ -3,7 +3,7 @@ from typing import Sequence, Mapping, Dict
|
||||
|
||||
from .graph import DynamicPrompt
|
||||
from .graph_utils import is_link
|
||||
from .nodes_context import get_nodes
|
||||
from comfy.nodes_context import get_nodes
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
@ -4,11 +4,11 @@ import asyncio
|
||||
import inspect
|
||||
from typing import Optional, Type, Literal
|
||||
|
||||
from .comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
||||
DependencyExecutionErrorMessage
|
||||
from comfy.nodes_context import get_nodes
|
||||
from .graph_utils import is_link
|
||||
from .nodes_context import get_nodes
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
|
||||
@ -1,46 +1,53 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
from __future__ import annotations
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
from typing import Optional
|
||||
|
||||
Attributes:
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
from comfy import execution_context as core_execution_context
|
||||
|
||||
ExecutionContext = core_execution_context.ExecutionContext
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
This is a compatibility wrapper around the core execution context.
|
||||
|
||||
Attributes:
|
||||
prompt_id: The ID of the currently executing prompt (task_id in core context)
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
"""
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
return current_executing_context.get(None)
|
||||
"""
|
||||
Gets the current execution context from the core context provider.
|
||||
Returns a compatibility ExecutionContext object or None if not in an execution context.
|
||||
"""
|
||||
ctx = core_execution_context.current_execution_context()
|
||||
if ctx.task_id is None or ctx.node_id is None:
|
||||
return None
|
||||
return ctx
|
||||
|
||||
|
||||
class CurrentNodeContext:
|
||||
"""
|
||||
Context manager for setting the current executing node context.
|
||||
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
This is a wrapper around the core `context_set_node_and_prompt` context manager.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
self._cm = core_execution_context.context_set_node_and_prompt(
|
||||
prompt_id=prompt_id,
|
||||
node_id=node_id,
|
||||
list_index=list_index
|
||||
)
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
self.token = current_executing_context.set(self.context)
|
||||
self._cm.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.token is not None:
|
||||
current_executing_context.reset(self.token)
|
||||
self._cm.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
@ -220,9 +220,6 @@ torchvision = [
|
||||
torchaudio = [
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
]
|
||||
comfyui-frontend-package = [
|
||||
{ git = "https://github.com/appmana/appmana-comfyui-frontend", subdirectory = "comfyui_frontend_package" },
|
||||
]
|
||||
"sageattention" = [
|
||||
{ git = "https://github.com/thu-ml/SageAttention.git", extra = "attention", marker = "sys_platform == 'Linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
@ -244,4 +241,4 @@ exclude = ["*.ipynb"]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"]
|
||||
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/"]
|
||||
Loading…
Reference in New Issue
Block a user