Add method to make this congruent with aio client

This commit is contained in:
doctorpangloss 2024-09-26 18:08:15 -07:00
parent ab1a1de7a4
commit dbc8ee92a5

View File

@ -14,6 +14,7 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cmd.main_pre import tracer
@ -132,6 +133,7 @@ class EmbeddedComfyClient:
In order to use this in blocking methods, learn more about asyncio online.
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
self._progress_handler = progress_handler or ServerStub()
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
@ -162,6 +164,11 @@ class EmbeddedComfyClient:
self._executor.shutdown(wait=True)
self._is_running = False
async def queue_prompt_api(self,
prompt: PromptDict) -> V1QueuePromptResponse:
outputs = await self.queue_prompt(prompt)
return V1QueuePromptResponse(**outputs)
@tracer.start_as_current_span("Queue Prompt")
async def queue_prompt(self,
prompt: PromptDict | dict,