Add support for intuitive progress notifications when using comfyui as a library

doctorpangloss 2025-08-26 16:31:39 -07:00
parent 35d890eafe
commit 173b1ce0ae
5 changed files with 151 additions and 6 deletions
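
In short, a library caller queues a prompt and consumes progress events as an async iterator. The sketch below is illustrative rather than part of the commit; it assumes an already-built Workflow API prompt dict and mirrors the docstring example added further down.

import asyncio
from comfy.client.embedded_comfy_client import Comfy

async def main(prompt: dict):
    # prompt is assumed to be a Workflow API prompt dict prepared elsewhere
    async with Comfy() as comfy:
        task = comfy.queue_with_progress(prompt)
        async for notification in task.progress():
            # each notification carries (event, data, sid); a failing workflow
            # raises its exception out of this loop
            print(notification.event, notification.data)
        return await task.get()  # V1QueuePromptResponse with the collected outputs

# asyncio.run(main(prompt))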

View File

@@ -0,0 +1,83 @@
import asyncio
import uuid
from asyncio import Task, Future
from typing import override, NamedTuple, Optional, AsyncIterable

from .client_types import V1QueuePromptResponse, ProgressNotification
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData


class _ProgressNotification(NamedTuple):
    event: SendSyncEvent
    data: SendSyncData
    sid: Optional[str] = None
    complete: bool = False


class QueuePromptWithProgress:
    def __init__(self):
        self._progress_handler = _ProgressHandler()

    def progress(self) -> AsyncIterable[ProgressNotification]:
        return self._progress_handler

    async def get(self) -> V1QueuePromptResponse:
        return await self._progress_handler.fut

    def future(self) -> Future[V1QueuePromptResponse]:
        return self._progress_handler.fut

    @property
    def progress_handler(self) -> ExecutorToClientProgress:
        return self._progress_handler

    def complete(self, task: Task[V1QueuePromptResponse]):
        self._progress_handler.complete(task)


class _ProgressHandler(ExecutorToClientProgress, AsyncIterable[ProgressNotification]):
    def __init__(self, user_id: Optional[str] = None):
        if user_id is None:
            self.client_id = str(uuid.uuid4())
        self._loop = asyncio.get_running_loop()
        self._queue: asyncio.Queue[_ProgressNotification] = asyncio.Queue()
        self.fut: Future[V1QueuePromptResponse] = asyncio.Future()

    @override
    @property
    def receive_all_progress_notifications(self) -> bool:
        return True

    @override
    @receive_all_progress_notifications.setter
    def receive_all_progress_notifications(self, value: bool):
        return

    def send_sync(self,
                  event: SendSyncEvent,
                  data: SendSyncData,
                  sid: Optional[str] = None):
        # called from the executor's worker thread; hand the notification to the
        # event loop thread-safely
        self._loop.call_soon_threadsafe(self._queue.put_nowait, _ProgressNotification(event, data, sid))

    def complete(self, task: Task[V1QueuePromptResponse]):
        if task.exception() is not None:
            self.fut.set_exception(task.exception())
        else:
            self.fut.set_result(task.result())
        # sentinel item that ends iteration in __anext__
        self._queue.put_nowait(_ProgressNotification(None, None, None, complete=True))

    def __aiter__(self):
        return self

    async def __anext__(self):
        result: _ProgressNotification = await self._queue.get()
        self._queue.task_done()
        if result.complete:
            if self.fut.exception() is not None:
                raise self.fut.exception()
            else:
                raise StopAsyncIteration()
        else:
            return ProgressNotification(result.event, result.data, result.sid)
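
The handler above bridges the executor's worker thread to the asyncio consumer: send_sync runs off-loop and uses call_soon_threadsafe to enqueue items, and a sentinel item ends iteration. A standalone sketch of that same pattern, with names that are illustrative and not part of the codebase:

import asyncio
import threading

def producer(loop: asyncio.AbstractEventLoop, queue: asyncio.Queue):
    # runs in a worker thread; asyncio.Queue is not thread-safe, so hop onto the loop
    for i in range(3):
        loop.call_soon_threadsafe(queue.put_nowait, ("progress", i))
    loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel: no more items

async def consume():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    threading.Thread(target=producer, args=(loop, queue)).start()
    while (item := await queue.get()) is not None:
        print(item)  # ('progress', 0), ('progress', 1), ('progress', 2)

asyncio.run(consume())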

View File

@@ -1,8 +1,10 @@
import dataclasses
from typing import List
from typing import List, NamedTuple, Optional

from typing_extensions import TypedDict, Literal, NotRequired
from comfy.component_model.executor_types import SendSyncEvent, SendSyncData


class FileOutput(TypedDict, total=False):
    filename: str
@@ -22,3 +24,9 @@ class Output(TypedDict, total=False):
class V1QueuePromptResponse:
    urls: List[str]
    outputs: dict[str, Output]


class ProgressNotification(NamedTuple):
    event: SendSyncEvent
    data: SendSyncData
    sid: Optional[str] = None
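
ProgressNotification re-exposes the executor's (event, data, sid) triple to library callers. A small illustrative consumer follows; the event name and payload keys shown are assumptions based on the usual ComfyUI websocket events, not definitions from this diff.

from comfy.client.client_types import ProgressNotification

def handle(notification: ProgressNotification):
    event, data, sid = notification              # NamedTuple unpacking
    if event == "progress":                      # assumed event name for sampler step updates
        # 'value' and 'max' are assumed payload keys for this event
        print(f"step {data.get('value')} / {data.get('max')}")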

View File

@@ -10,12 +10,13 @@ import uuid
from asyncio import get_event_loop
from dataclasses import dataclass
from multiprocessing import RLock
from typing import Optional
from typing import Optional, Generator

from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode

from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
from ..cmd.main_pre import tracer
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
@@ -201,7 +202,8 @@ class Comfy:
        self._is_running = False

    async def queue_prompt_api(self,
                               prompt: PromptDict | str | dict) -> V1QueuePromptResponse:
                               prompt: PromptDict | str | dict,
                               progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
        """
        Queues a prompt for execution, returning the output when it is complete.
        :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
@@ -212,9 +214,31 @@ class Comfy:
        if isinstance(prompt, dict):
            from ..api.components.schema.prompt import Prompt
            prompt = Prompt.validate(prompt)
        outputs = await self.queue_prompt(prompt)
        outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
        return V1QueuePromptResponse(urls=[], outputs=outputs)

    def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
        """
        Queues a prompt with progress notifications.

        >>> from comfy.client.embedded_comfy_client import Comfy
        >>> from comfy.client.client_types import ProgressNotification
        >>> async with Comfy() as comfy:
        ...     task = comfy.queue_with_progress({ ... })
        ...     # If the workflow fails, the exception is raised here, while iterating
        ...     notification: ProgressNotification
        ...     async for notification in task.progress():
        ...         print(notification.data)
        ...     # If you get this far, no errors occurred.
        ...     result = await task.get()

        :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
        :return: a QueuePromptWithProgress; iterate progress() for notifications and await get() for the final V1QueuePromptResponse
        """
        handler = QueuePromptWithProgress()
        task = asyncio.create_task(self.queue_prompt_api(prompt, progress_handler=handler.progress_handler))
        task.add_done_callback(handler.complete)
        return handler

    @tracer.start_as_current_span("Queue Prompt")
    async def queue_prompt(self,
                           prompt: PromptDict | dict,
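
Because queue_with_progress wires the prompt task's done-callback into the handler, a failure surfaces in two places: the exception is raised out of the async for loop, and task.get() / task.future() carry the same exception. A hedged sketch of handling both paths; the function and variable names here are illustrative.

async def run_with_progress(comfy, prompt):
    task = comfy.queue_with_progress(prompt)
    try:
        async for notification in task.progress():
            print(notification.event)
    except Exception as exc:
        # the workflow failed; task.get() would raise the same exception
        print(f"workflow failed: {exc}")
        raise
    return await task.get()  # only reached when iteration ended cleanly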

View File

@@ -143,6 +143,13 @@ class ExecutorToClientProgress(Protocol):
    @property
    def sockets_metadata(self) -> SocketsMetadataType:
        """
        Metadata about what the socket supports.
        Currently used only by the frontend.

        :return: in the abstract base class, a static placeholder that the web server recognizes and ignores; in the real classes, sometimes information about connected users
        """
        return {"__unimplemented": True}

    def send_sync(self,
@@ -160,6 +167,13 @@ class ExecutorToClientProgress(Protocol):
        pass

    def send_progress_text(self, text: Union[bytes, bytearray, str], node_id: str, sid=None):
        """
        Send text to the client.

        :param text: the text to send
        :param node_id: the node this text belongs to
        :param sid: the websocket ID / client ID to respond to
        :return: nothing
        """
        message = encode_text_for_progress(node_id, text)
        self.send_sync(BinaryEventTypes.TEXT, message, sid)
@@ -167,7 +181,7 @@ class ExecutorToClientProgress(Protocol):
    def queue_updated(self, queue_remaining: Optional[int] = None):
        """
        Indicates that the local client's queue has been updated.
        :return:
        :return: nothing
        """
        pass
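
Since queue_prompt_api now accepts any object satisfying ExecutorToClientProgress, a caller can also pass its own handler directly. A hedged, duck-typed sketch using only the members visible in this diff; the full protocol may require more attributes, and the class, names, and usage comment below are illustrative, not part of the codebase.

class PrintingProgress:
    # minimal illustrative handler
    client_id = "printing-progress"                # _ProgressHandler sets a client_id, so mirror that
    receive_all_progress_notifications = True      # ask the executor for every notification

    def send_sync(self, event, data, sid=None):
        print(f"{event}: {data}")

    def queue_updated(self, queue_remaining=None):
        print(f"queue remaining: {queue_remaining}")

# usage, inside an `async with Comfy() as comfy:` block:
#     outputs = await comfy.queue_prompt_api(prompt, progress_handler=PrintingProgress())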

View File

@@ -47,4 +47,20 @@ async def test_multithreaded_comfy():
    async with Comfy(max_workers=2) as client:
        prompt = sdxl_workflow_with_refiner("test")
        outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
        assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
        assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)


@pytest.mark.asyncio
async def test_progress_notifications():
    async with Comfy() as client:
        prompt = sdxl_workflow_with_refiner("test")
        task = client.queue_with_progress(prompt)
        notifications_received = []
        async for notification in task.progress():
            notifications_received.append(notification)
        assert len(notifications_received) > 0, "Should have received progress notifications"
        result = await task.get()
        assert result.outputs["13"]["images"][0]["abs_path"] is not None