# ComfyUI/comfy/client/async_progress_iterable.py
#
# 85 lines
# 2.7 KiB
# Python

import asyncio
import uuid
from asyncio import Task, Future
from typing import NamedTuple, Optional, AsyncIterable
from typing_extensions import override
from .client_types import V1QueuePromptResponse, ProgressNotification
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
class _ProgressNotification(NamedTuple):
    """Internal queue item: one progress event, or the end-of-stream marker.

    A regular item holds the ``event`` / ``data`` / ``sid`` triple handed to
    ``_ProgressHandler.send_sync``. The terminal item has ``complete=True``
    and ``None`` in the other fields.
    """

    event: SendSyncEvent
    data: SendSyncData
    # Session id supplied by the sender, if any.
    sid: Optional[str] = None
    # True only for the sentinel that marks the end of the progress stream.
    complete: bool = False
class QueuePromptWithProgress:
    """Bundle a queued prompt's eventual result with a stream of its progress events.

    All state lives in a private ``_ProgressHandler``. This class exposes that
    handler through narrower views: an async iterable of notifications, an
    awaitable result, the raw result future, and the handler typed as an
    ``ExecutorToClientProgress``.
    """

    def __init__(self):
        # The handler grabs the running event loop on construction, so this
        # object must be created from within a running loop.
        self._progress_handler = _ProgressHandler()

    def progress(self) -> AsyncIterable[ProgressNotification]:
        """Return an async iterable yielding this prompt's progress notifications."""
        handler = self._progress_handler
        return handler

    async def get(self) -> V1QueuePromptResponse:
        """Wait for the prompt's result; a stored exception is re-raised by the await."""
        pending = self._progress_handler.fut
        return await pending

    def future(self) -> Future[V1QueuePromptResponse]:
        """Return the future that resolves with the prompt's result."""
        handler = self._progress_handler
        return handler.fut

    @property
    def progress_handler(self) -> ExecutorToClientProgress:
        """The underlying handler, viewed as an ``ExecutorToClientProgress`` sink."""
        return self._progress_handler

    def complete(self, task: Task[V1QueuePromptResponse]):
        """Forward the finished ``task`` to the handler, which records its outcome and ends the stream."""
        handler = self._progress_handler
        handler.complete(task)
class _ProgressHandler(ExecutorToClientProgress, AsyncIterable[ProgressNotification]):
    """Progress sink that buffers executor events and replays them via ``async for``.

    Events arrive through :meth:`send_sync`, which may be called from a thread
    other than the event loop's, and are yielded in arrival order. Once
    :meth:`complete` is called with the finished prompt task, the outcome is
    stored in :attr:`fut` and iteration ends:

    - cleanly (``StopAsyncIteration``) if the task succeeded;
    - by re-raising the task's exception if it failed;
    - with ``asyncio.CancelledError`` if the task (or :attr:`fut`) was cancelled.
    """

    def __init__(self, user_id: Optional[str] = None):
        # Always define client_id: use the caller's id when given, otherwise
        # generate a random one.
        self.client_id = str(uuid.uuid4()) if user_id is None else user_id
        # Must be constructed while the loop is running; send_sync uses this
        # loop to hop events back onto the loop thread.
        self._loop = asyncio.get_running_loop()
        self._queue: asyncio.Queue[_ProgressNotification] = asyncio.Queue()
        self.fut: Future[V1QueuePromptResponse] = self._loop.create_future()

    @override
    @property
    def receive_all_progress_notifications(self) -> bool:
        """Always True: this handler wants every notification, not a filtered subset."""
        return True

    @override
    @receive_all_progress_notifications.setter
    def receive_all_progress_notifications(self, value: bool):
        # Deliberately ignored; the getter is hard-wired to True.
        return

    def send_sync(self,
                  event: SendSyncEvent,
                  data: SendSyncData,
                  sid: Optional[str] = None):
        """Enqueue one progress event; safe to call from any thread."""
        self._loop.call_soon_threadsafe(self._queue.put_nowait, _ProgressNotification(event, data, sid))

    def complete(self, task: Task[V1QueuePromptResponse]):
        """Record the finished ``task``'s outcome in :attr:`fut` and end the progress stream.

        Intended to run on the event loop thread after ``task`` is done (for
        example as a done-callback). A cancelled task cancels :attr:`fut`
        rather than raising here. If :attr:`fut` is already resolved (for
        example cancelled by a consumer, or ``complete`` called twice), its
        state is left untouched but the stream is still terminated.
        """
        if not self.fut.done():
            if task.cancelled():
                # task.exception() would raise CancelledError here.
                self.fut.cancel()
            else:
                exc = task.exception()
                if exc is not None:
                    self.fut.set_exception(exc)
                else:
                    self.fut.set_result(task.result())
        # Scheduled the same way as send_sync, so the sentinel lands after any
        # progress events that were already scheduled from other threads.
        self._loop.call_soon_threadsafe(
            self._queue.put_nowait,
            _ProgressNotification(None, None, None, complete=True),
        )

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next progress notification, or finish once the stream is complete.

        Raises:
            StopAsyncIteration: the task finished successfully.
            asyncio.CancelledError: the task or :attr:`fut` was cancelled.
            BaseException: the task's own exception, if it failed.
        """
        item: _ProgressNotification = await self._queue.get()
        self._queue.task_done()
        if not item.complete:
            return ProgressNotification(item.event, item.data, item.sid)
        # Put the sentinel back so later calls (a second ``async for``, or
        # another concurrent consumer) terminate instead of blocking forever
        # on an empty queue.
        self._queue.put_nowait(item)
        if self.fut.cancelled():
            raise asyncio.CancelledError()
        exc = self.fut.exception()
        if exc is not None:
            raise exc
        raise StopAsyncIteration()