From 173b1ce0ae003bac3f28ac60da68f007b63321bd Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 26 Aug 2025 16:31:39 -0700 Subject: [PATCH] Add support for intuitive progress notifications when using comfyui as a library --- comfy/client/async_progress_iterable.py | 83 +++++++++++++++++++++++++ comfy/client/client_types.py | 10 ++- comfy/client/embedded_comfy_client.py | 30 ++++++++- comfy/component_model/executor_types.py | 16 ++++- tests/library/test_embedded_client.py | 18 +++++- 5 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 comfy/client/async_progress_iterable.py diff --git a/comfy/client/async_progress_iterable.py b/comfy/client/async_progress_iterable.py new file mode 100644 index 000000000..81f1ef331 --- /dev/null +++ b/comfy/client/async_progress_iterable.py @@ -0,0 +1,83 @@ +import asyncio +import uuid +from asyncio import Task, Future +from typing import override, NamedTuple, Optional, AsyncIterable + +from .client_types import V1QueuePromptResponse, ProgressNotification +from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData + + +class _ProgressNotification(NamedTuple): + event: SendSyncEvent + data: SendSyncData + sid: Optional[str] = None + complete: bool = False + + +class QueuePromptWithProgress: + def __init__(self): + self._progress_handler = _ProgressHandler() + + def progress(self) -> AsyncIterable[ProgressNotification]: + return self._progress_handler + + async def get(self) -> V1QueuePromptResponse: + return await self._progress_handler.fut + + def future(self) -> Future[V1QueuePromptResponse]: + return self._progress_handler.fut + + @property + def progress_handler(self) -> ExecutorToClientProgress: + return self._progress_handler + + def complete(self, task: Task[V1QueuePromptResponse]): + self._progress_handler.complete(task) + + +class _ProgressHandler(ExecutorToClientProgress, AsyncIterable[ProgressNotification]): + def __init__(self, user_id: str = None): + if user_id is None: + self.client_id = str(uuid.uuid4()) + + self._loop = asyncio.get_running_loop() + self._queue: asyncio.Queue[_ProgressNotification] = asyncio.Queue() + self.fut: Future[V1QueuePromptResponse] = asyncio.Future() + + @override + @property + def receive_all_progress_notifications(self) -> bool: + return True + + @override + @receive_all_progress_notifications.setter + def receive_all_progress_notifications(self, value: bool): + return + + def send_sync(self, + event: SendSyncEvent, + data: SendSyncData, + sid: Optional[str] = None): + self._loop.call_soon_threadsafe(self._queue.put_nowait, _ProgressNotification(event, data, sid)) + + def complete(self, task: Task[V1QueuePromptResponse]): + if task.exception() is not None: + self.fut.set_exception(task.exception()) + else: + self.fut.set_result(task.result()) + self._queue.put_nowait(_ProgressNotification(None, None, None, complete=True)) + + def __aiter__(self): + return self + + async def __anext__(self): + result: _ProgressNotification = await self._queue.get() + self._queue.task_done() + + if result.complete: + if self.fut.exception() is not None: + raise self.fut.exception() + else: + raise StopAsyncIteration() + else: + return ProgressNotification(result.event, result.data, result.sid) diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py index 22bd35ad4..d5fd6b1c4 100644 --- a/comfy/client/client_types.py +++ b/comfy/client/client_types.py @@ -1,8 +1,10 @@ import dataclasses -from typing import List +from typing import List, NamedTuple, Optional from typing_extensions import TypedDict, Literal, NotRequired +from comfy.component_model.executor_types import SendSyncEvent, SendSyncData + class FileOutput(TypedDict, total=False): filename: str @@ -22,3 +24,9 @@ class Output(TypedDict, total=False): class V1QueuePromptResponse: urls: List[str] outputs: dict[str, Output] + + +class ProgressNotification(NamedTuple): + event: SendSyncEvent + data: SendSyncData + sid: Optional[str] = None diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 3d7bfd5bc..d3af0acd6 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -10,12 +10,13 @@ import uuid from asyncio import get_event_loop from dataclasses import dataclass from multiprocessing import RLock -from typing import Optional +from typing import Optional, Generator from opentelemetry import context, propagate from opentelemetry.context import Context, attach, detach from opentelemetry.trace import Status, StatusCode +from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress from ..cmd.main_pre import tracer from .client_types import V1QueuePromptResponse from ..api.components.schema.prompt import PromptDict @@ -201,7 +202,8 @@ class Comfy: self._is_running = False async def queue_prompt_api(self, - prompt: PromptDict | str | dict) -> V1QueuePromptResponse: + prompt: PromptDict | str | dict, + progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse: """ Queues a prompt for execution, returning the output when it is complete. :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt @@ -212,9 +214,31 @@ class Comfy: if isinstance(prompt, dict): from ..api.components.schema.prompt import Prompt prompt = Prompt.validate(prompt) - outputs = await self.queue_prompt(prompt) + outputs = await self.queue_prompt(prompt, progress_handler=progress_handler) return V1QueuePromptResponse(urls=[], outputs=outputs) + def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress: + """ + Queues a prompt with progress notifications. + + >>> from comfy.client.embedded_comfy_client import Comfy + >>> from comfy.client.client_types import ProgressNotification + >>> async with Comfy() as comfy: + >>> task = comfy.queue_with_progress({ ... }) + >>> # Raises an exception while iterating + >>> notification: ProgressNotification + >>> async for notification in task.progress(): + >>> print(notification.data) + >>> # If you get this far, no errors occurred. + >>> result = await task.get() + :param prompt: + :return: + """ + handler = QueuePromptWithProgress() + task = asyncio.create_task(self.queue_prompt_api(prompt, progress_handler=handler.progress_handler)) + task.add_done_callback(handler.complete) + return handler + @tracer.start_as_current_span("Queue Prompt") async def queue_prompt(self, prompt: PromptDict | dict, diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 31ec1c65d..c8108fb16 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -143,6 +143,13 @@ class ExecutorToClientProgress(Protocol): @property def sockets_metadata(self) -> SocketsMetadataType: + """ + Metadata about what the socket supports + + Currently used only by the frontend + + :return: in the abstract base class, a static object that is used by the web server to ignore this; in the real classes, sometimes information about connected users + """ return {"__unimplemented": True} def send_sync(self, @@ -160,6 +167,13 @@ class ExecutorToClientProgress(Protocol): pass def send_progress_text(self, text: Union[bytes, bytearray, str], node_id: str, sid=None): + """ + Send text to the client + :param text: the text to send + :param node_id: the node this belongs to + :param sid: websocket ID / the client ID to be responding to + :return: + """ message = encode_text_for_progress(node_id, text) self.send_sync(BinaryEventTypes.TEXT, message, sid) @@ -167,7 +181,7 @@ class ExecutorToClientProgress(Protocol): def queue_updated(self, queue_remaining: Optional[int] = None): """ Indicates that the local client's queue has been updated - :return: + :return: nothing """ pass diff --git a/tests/library/test_embedded_client.py b/tests/library/test_embedded_client.py index 3c20de91c..6b51ce0fa 100644 --- a/tests/library/test_embedded_client.py +++ b/tests/library/test_embedded_client.py @@ -47,4 +47,20 @@ async def test_multithreaded_comfy(): async with Comfy(max_workers=2) as client: prompt = sdxl_workflow_with_refiner("test") outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)]) - assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter) \ No newline at end of file + assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter) + + +@pytest.mark.asyncio +async def test_progress_notifications(): + async with Comfy() as client: + prompt = sdxl_workflow_with_refiner("test") + task = client.queue_with_progress(prompt) + + notifications_received = [] + async for notification in task.progress(): + notifications_received.append(notification) + + assert len(notifications_received) > 0, "Should have received progress notifications" + + result = await task.get() + assert result.outputs["13"]["images"][0]["abs_path"] is not None \ No newline at end of file