Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 14:20:49 +08:00
Add support for intuitive progress notifications when using comfyui as a library
This commit is contained in:
parent 35d890eafe
commit 173b1ce0ae

83  comfy/client/async_progress_iterable.py  Normal file
@@ -0,0 +1,83 @@
+import asyncio
+import uuid
+from asyncio import Task, Future
+from typing import override, NamedTuple, Optional, AsyncIterable
+
+from .client_types import V1QueuePromptResponse, ProgressNotification
+from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
+
+
+class _ProgressNotification(NamedTuple):
+    event: SendSyncEvent
+    data: SendSyncData
+    sid: Optional[str] = None
+    complete: bool = False
+
+
+class QueuePromptWithProgress:
+    def __init__(self):
+        self._progress_handler = _ProgressHandler()
+
+    def progress(self) -> AsyncIterable[ProgressNotification]:
+        return self._progress_handler
+
+    async def get(self) -> V1QueuePromptResponse:
+        return await self._progress_handler.fut
+
+    def future(self) -> Future[V1QueuePromptResponse]:
+        return self._progress_handler.fut
+
+    @property
+    def progress_handler(self) -> ExecutorToClientProgress:
+        return self._progress_handler
+
+    def complete(self, task: Task[V1QueuePromptResponse]):
+        self._progress_handler.complete(task)
+
+
+class _ProgressHandler(ExecutorToClientProgress, AsyncIterable[ProgressNotification]):
+    def __init__(self, user_id: str = None):
+        if user_id is None:
+            self.client_id = str(uuid.uuid4())
+
+        self._loop = asyncio.get_running_loop()
+        self._queue: asyncio.Queue[_ProgressNotification] = asyncio.Queue()
+        self.fut: Future[V1QueuePromptResponse] = asyncio.Future()
+
+    @override
+    @property
+    def receive_all_progress_notifications(self) -> bool:
+        return True
+
+    @override
+    @receive_all_progress_notifications.setter
+    def receive_all_progress_notifications(self, value: bool):
+        return
+
+    def send_sync(self,
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
+        self._loop.call_soon_threadsafe(self._queue.put_nowait, _ProgressNotification(event, data, sid))
+
+    def complete(self, task: Task[V1QueuePromptResponse]):
+        if task.exception() is not None:
+            self.fut.set_exception(task.exception())
+        else:
+            self.fut.set_result(task.result())
+        self._queue.put_nowait(_ProgressNotification(None, None, None, complete=True))
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        result: _ProgressNotification = await self._queue.get()
+        self._queue.task_done()
+
+        if result.complete:
+            if self.fut.exception() is not None:
+                raise self.fut.exception()
+            else:
+                raise StopAsyncIteration()
+        else:
+            return ProgressNotification(result.event, result.data, result.sid)
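The handler above bridges ComfyUI's prompt executor, which may report progress from a worker thread (hence call_soon_threadsafe), to an async consumer on the event loop. Below is a minimal standalone sketch of that pattern; it is not part of the commit and every name in it is illustrative: a worker thread hands items to the loop via call_soon_threadsafe, an asyncio.Queue buffers them, and a sentinel ends consumption, mirroring what send_sync, complete and __anext__ do in _ProgressHandler.

import asyncio
import threading


async def main():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()  # plays the role of the complete=True notification

    def worker():
        # Runs off the event loop thread, like the executor calling send_sync().
        for step in range(3):
            loop.call_soon_threadsafe(queue.put_nowait, f"progress {step}")
        loop.call_soon_threadsafe(queue.put_nowait, sentinel)

    threading.Thread(target=worker).start()

    while True:
        item = await queue.get()
        if item is sentinel:
            break
        print(item)


asyncio.run(main())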
comfy/client/client_types.py

@@ -1,8 +1,10 @@
 import dataclasses
-from typing import List
+from typing import List, NamedTuple, Optional
 
 from typing_extensions import TypedDict, Literal, NotRequired
 
+from comfy.component_model.executor_types import SendSyncEvent, SendSyncData
+
 
 class FileOutput(TypedDict, total=False):
     filename: str
@@ -22,3 +24,9 @@ class Output(TypedDict, total=False):
 class V1QueuePromptResponse:
     urls: List[str]
     outputs: dict[str, Output]
+
+
+class ProgressNotification(NamedTuple):
+    event: SendSyncEvent
+    data: SendSyncData
+    sid: Optional[str] = None
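For illustration, a short hedged sketch of consuming the new ProgressNotification tuple; the event value compared against here is a placeholder assumption, not a list of events ComfyUI actually emits.

from comfy.client.client_types import ProgressNotification


def log_notification(notification: ProgressNotification) -> None:
    # NamedTuple fields can be unpacked positionally or accessed by name.
    event, data, sid = notification
    if event == "progress":  # placeholder event name, an assumption
        print(f"client {sid}: {data}")
    else:
        print(f"{event}: {data}")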
comfy/client/embedded_comfy_client.py

@@ -10,12 +10,13 @@ import uuid
 from asyncio import get_event_loop
 from dataclasses import dataclass
 from multiprocessing import RLock
-from typing import Optional
+from typing import Optional, Generator
 
 from opentelemetry import context, propagate
 from opentelemetry.context import Context, attach, detach
 from opentelemetry.trace import Status, StatusCode
 
+from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
 from ..cmd.main_pre import tracer
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
@@ -201,7 +202,8 @@ class Comfy:
         self._is_running = False
 
     async def queue_prompt_api(self,
                                prompt: PromptDict | str | dict) -> V1QueuePromptResponse:
-                               prompt: PromptDict | str | dict) -> V1QueuePromptResponse:
+                               prompt: PromptDict | str | dict,
+                               progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
         """
         Queues a prompt for execution, returning the output when it is complete.
         :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
@@ -212,9 +214,31 @@ class Comfy:
         if isinstance(prompt, dict):
             from ..api.components.schema.prompt import Prompt
             prompt = Prompt.validate(prompt)
-        outputs = await self.queue_prompt(prompt)
+        outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
         return V1QueuePromptResponse(urls=[], outputs=outputs)
 
+    def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
+        """
+        Queues a prompt with progress notifications.
+
+        >>> from comfy.client.embedded_comfy_client import Comfy
+        >>> from comfy.client.client_types import ProgressNotification
+        >>> async with Comfy() as comfy:
+        >>>     task = comfy.queue_with_progress({ ... })
+        >>>     # Raises an exception while iterating
+        >>>     notification: ProgressNotification
+        >>>     async for notification in task.progress():
+        >>>         print(notification.data)
+        >>>     # If you get this far, no errors occurred.
+        >>>     result = await task.get()
+        :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
+        :return: a QueuePromptWithProgress; iterate progress() for notifications and await get() for the final V1QueuePromptResponse
+        """
+        handler = QueuePromptWithProgress()
+        task = asyncio.create_task(self.queue_prompt_api(prompt, progress_handler=handler.progress_handler))
+        task.add_done_callback(handler.complete)
+        return handler
+
     @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
                            prompt: PromptDict | dict,
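Building on the docstring above, here is a hedged end-to-end sketch of queue_with_progress with error handling. The workflow_api dict and the broad except clause are assumptions for illustration; per _ProgressHandler.complete, an executor failure is re-raised while iterating and is also available from get()/future().

import asyncio

from comfy.client.embedded_comfy_client import Comfy
from comfy.client.client_types import ProgressNotification


async def run(workflow_api: dict):
    # workflow_api: a Workflow API prompt dict, assumed to exist for this sketch
    async with Comfy() as comfy:
        task = comfy.queue_with_progress(workflow_api)
        try:
            notification: ProgressNotification
            async for notification in task.progress():
                # event/data mirror what ComfyUI's websocket clients receive
                print(notification.event, notification.data)
        except Exception as exc:
            # Executor failures surface here, while iterating...
            print(f"prompt failed: {exc}")
            return None
        # ...and the same result (or exception) is exposed by the future.
        return await task.get()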
comfy/component_model/executor_types.py

@@ -143,6 +143,13 @@ class ExecutorToClientProgress(Protocol):
 
     @property
     def sockets_metadata(self) -> SocketsMetadataType:
+        """
+        Metadata about what the socket supports
+
+        Currently used only by the frontend
+
+        :return: in the abstract base class, a static object that tells the web server to ignore this; in the real classes, sometimes information about connected users
+        """
         return {"__unimplemented": True}
 
     def send_sync(self,
@@ -160,6 +167,13 @@ class ExecutorToClientProgress(Protocol):
         pass
 
     def send_progress_text(self, text: Union[bytes, bytearray, str], node_id: str, sid=None):
+        """
+        Send text to the client
+        :param text: the text to send
+        :param node_id: the node this belongs to
+        :param sid: websocket ID / the client ID to be responding to
+        :return:
+        """
         message = encode_text_for_progress(node_id, text)
 
         self.send_sync(BinaryEventTypes.TEXT, message, sid)
@@ -167,7 +181,7 @@ class ExecutorToClientProgress(Protocol):
     def queue_updated(self, queue_remaining: Optional[int] = None):
         """
         Indicates that the local client's queue has been updated
-        :return:
+        :return: nothing
         """
         pass
 
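The same protocol can back a hand-rolled handler instead of the iterator added by this commit. The hedged sketch below subclasses ExecutorToClientProgress (as _ProgressHandler does) and overrides only send_sync, then is passed through the new progress_handler parameter of queue_prompt_api. The class name, the client_id assignment, and the usage comment are assumptions for illustration, not part of the commit.

import uuid

from comfy.component_model.executor_types import (
    ExecutorToClientProgress,
    SendSyncEvent,
    SendSyncData,
)


class PrintingProgress(ExecutorToClientProgress):
    def __init__(self):
        # mirrors _ProgressHandler, which assigns itself a client_id
        self.client_id = str(uuid.uuid4())

    def send_sync(self, event: SendSyncEvent, data: SendSyncData, sid=None):
        # Called by the executor for every progress event; it may run off the
        # event loop thread, so do not touch asyncio objects here directly.
        print(f"[{event}] {data}")


# Usage (assuming an existing Comfy instance and a Workflow API dict):
#     outputs = await comfy.queue_prompt_api(workflow_api, progress_handler=PrintingProgress())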
@@ -48,3 +48,19 @@ async def test_multithreaded_comfy():
         prompt = sdxl_workflow_with_refiner("test")
         outputs_iter = await asyncio.gather(*[client.queue_prompt(prompt) for _ in range(4)])
         assert all(outputs["13"]["images"][0]["abs_path"] is not None for outputs in outputs_iter)
+
+
+@pytest.mark.asyncio
+async def test_progress_notifications():
+    async with Comfy() as client:
+        prompt = sdxl_workflow_with_refiner("test")
+        task = client.queue_with_progress(prompt)
+
+        notifications_received = []
+        async for notification in task.progress():
+            notifications_received.append(notification)
+
+        assert len(notifications_received) > 0, "Should have received progress notifications"
+
+        result = await task.get()
+        assert result.outputs["13"]["images"][0]["abs_path"] is not None