mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
Add support for intuitive progress notifications when using comfyui as a library
This commit is contained in:
parent
35d890eafe
commit
173b1ce0ae
83
comfy/client/async_progress_iterable.py
Normal file
83
comfy/client/async_progress_iterable.py
Normal file
@ -0,0 +1,83 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from asyncio import Task, Future
|
||||
from typing import override, NamedTuple, Optional, AsyncIterable
|
||||
|
||||
from .client_types import V1QueuePromptResponse, ProgressNotification
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||
|
||||
|
||||
class _ProgressNotification(NamedTuple):
|
||||
event: SendSyncEvent
|
||||
data: SendSyncData
|
||||
sid: Optional[str] = None
|
||||
complete: bool = False
|
||||
|
||||
|
||||
class QueuePromptWithProgress:
|
||||
def __init__(self):
|
||||
self._progress_handler = _ProgressHandler()
|
||||
|
||||
def progress(self) -> AsyncIterable[ProgressNotification]:
|
||||
return self._progress_handler
|
||||
|
||||
async def get(self) -> V1QueuePromptResponse:
|
||||
return await self._progress_handler.fut
|
||||
|
||||
def future(self) -> Future[V1QueuePromptResponse]:
|
||||
return self._progress_handler.fut
|
||||
|
||||
@property
|
||||
def progress_handler(self) -> ExecutorToClientProgress:
|
||||
return self._progress_handler
|
||||
|
||||
def complete(self, task: Task[V1QueuePromptResponse]):
|
||||
self._progress_handler.complete(task)
|
||||
|
||||
|
||||
class _ProgressHandler(ExecutorToClientProgress, AsyncIterable[ProgressNotification]):
|
||||
def __init__(self, user_id: str = None):
|
||||
if user_id is None:
|
||||
self.client_id = str(uuid.uuid4())
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._queue: asyncio.Queue[_ProgressNotification] = asyncio.Queue()
|
||||
self.fut: Future[V1QueuePromptResponse] = asyncio.Future()
|
||||
|
||||
@override
|
||||
@property
|
||||
def receive_all_progress_notifications(self) -> bool:
|
||||
return True
|
||||
|
||||
@override
|
||||
@receive_all_progress_notifications.setter
|
||||
def receive_all_progress_notifications(self, value: bool):
|
||||
return
|
||||
|
||||
def send_sync(self,
|
||||
event: SendSyncEvent,
|
||||
data: SendSyncData,
|
||||
sid: Optional[str] = None):
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, _ProgressNotification(event, data, sid))
|
||||
|
||||
def complete(self, task: Task[V1QueuePromptResponse]):
|
||||
if task.exception() is not None:
|
||||
self.fut.set_exception(task.exception())
|
||||
else:
|
||||
self.fut.set_result(task.result())
|
||||
self._queue.put_nowait(_ProgressNotification(None, None, None, complete=True))
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
result: _ProgressNotification = await self._queue.get()
|
||||
self._queue.task_done()
|
||||
|
||||
if result.complete:
|
||||
if self.fut.exception() is not None:
|
||||
raise self.fut.exception()
|
||||
else:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return ProgressNotification(result.event, result.data, result.sid)
|
||||
@ -1,8 +1,10 @@
|
||||
import dataclasses
|
||||
from typing import List
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
from typing_extensions import TypedDict, Literal, NotRequired
|
||||
|
||||
from comfy.component_model.executor_types import SendSyncEvent, SendSyncData
|
||||
|
||||
|
||||
class FileOutput(TypedDict, total=False):
|
||||
filename: str
|
||||
@ -22,3 +24,9 @@ class Output(TypedDict, total=False):
|
||||
class V1QueuePromptResponse:
|
||||
urls: List[str]
|
||||
outputs: dict[str, Output]
|
||||
|
||||
|
||||
class ProgressNotification(NamedTuple):
|
||||
event: SendSyncEvent
|
||||
data: SendSyncData
|
||||
sid: Optional[str] = None
|
||||
|
||||
@ -10,12 +10,13 @@ import uuid
|
||||
from asyncio import get_event_loop
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import RLock
|
||||
from typing import Optional
|
||||
from typing import Optional, Generator
|
||||
|
||||
from opentelemetry import context, propagate
|
||||
from opentelemetry.context import Context, attach, detach
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
|
||||
from ..cmd.main_pre import tracer
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
@ -201,7 +202,8 @@ class Comfy:
|
||||
self._is_running = False
|
||||
|
||||
async def queue_prompt_api(self,
|
||||
prompt: PromptDict | str | dict) -> V1QueuePromptResponse:
|
||||
prompt: PromptDict | str | dict,
|
||||
progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
|
||||
"""
|
||||
Queues a prompt for execution, returning the output when it is complete.
|
||||
:param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
|
||||
@ -212,9 +214,31 @@ class Comfy:
|
||||
if isinstance(prompt, dict):
|
||||
from ..api.components.schema.prompt import Prompt
|
||||
prompt = Prompt.validate(prompt)
|
||||
outputs = await self.queue_prompt(prompt)
|
||||
outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
|
||||
return V1QueuePromptResponse(urls=[], outputs=outputs)
|
||||
|
||||
def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
|
||||
"""
|
||||
Queues a prompt with progress notifications.
|
||||
|
||||
>>> from comfy.client.embedded_comfy_client import Comfy
|
||||
>>> from comfy.client.client_types import ProgressNotification
|
||||
>>> async with Comfy() as comfy:
|
||||
>>> task = comfy.queue_with_progress({ ... })
|
||||
>>> # Raises an exception while iterating
|
||||
>>> notification: ProgressNotification
|
||||
>>> async for notification in task.progress():
|
||||
>>> print(notification.data)
|
||||
>>> # If you get this far, no errors occurred.
|
||||
>>> result = await task.get()
|
||||
:param prompt:
|
||||
:return:
|
||||
"""
|
||||
handler = QueuePromptWithProgress()
|
||||
task = asyncio.create_task(self.queue_prompt_api(prompt, progress_handler=handler.progress_handler))
|
||||
task.add_done_callback(handler.complete)
|
||||
return handler
|
||||
|
||||
@tracer.start_as_current_span("Queue Prompt")
|
||||
async def queue_prompt(self,
|
||||
prompt: PromptDict | dict,
|
||||
|
||||
@ -143,6 +143,13 @@ class ExecutorToClientProgress(Protocol):
|
||||
|
||||
@property
|
||||
def sockets_metadata(self) -> SocketsMetadataType:
|
||||
"""
|
||||
Metadata about what the socket supports
|
||||
|
||||
Currently used only by the frontend
|
||||
|
||||
:return: in the abstract base class, a static object that is used by the web server to ignore this; in the real classes, sometimes information about connected users
|
||||
"""
|
||||
return {"__unimplemented": True}
|
||||
|
||||
def send_sync(self,
|
||||
@ -160,6 +167,13 @@ class ExecutorToClientProgress(Protocol):
|
||||
pass
|
||||
|
||||
def send_progress_text(self, text: Union[bytes, bytearray, str], node_id: str, sid=None):
|
||||
"""
|
||||
Send text to the client
|
||||
:param text: the text to send
|
||||
:param node_id: the node this belongs to
|
||||
:param sid: websocket ID / the client ID to be responding to
|
||||
:return:
|
||||
"""
|
||||
message = encode_text_for_progress(node_id, text)
|
||||
|
||||
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
||||
@ -167,7 +181,7 @@ class ExecutorToClientProgress(Protocol):
|
||||
def queue_updated(self, queue_remaining: Optional[int] = None):
|
||||
"""
|
||||
Indicates that the local client's queue has been updated
|
||||
:return:
|
||||
:return: nothing
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -47,4 +47,20 @@ async def test_multithreaded_comfy():
|
||||
async with Comfy(max_workers=2) as client:
|
||||
prompt = sdxl_workflow_with_refiner("test")
|
||||
outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
|
||||
assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
|
||||
assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_notifications():
|
||||
async with Comfy() as client:
|
||||
prompt = sdxl_workflow_with_refiner("test")
|
||||
task = client.queue_with_progress(prompt)
|
||||
|
||||
notifications_received = []
|
||||
async for notification in task.progress():
|
||||
notifications_received.append(notification)
|
||||
|
||||
assert len(notifications_received) > 0, "Should have received progress notifications"
|
||||
|
||||
result = await task.get()
|
||||
assert result.outputs["13"]["images"][0]["abs_path"] is not None
|
||||
Loading…
Reference in New Issue
Block a user