Improve API return values and tracing reports

doctorpangloss 2024-05-08 15:52:17 -07:00
parent aa0cfb54ce
commit 0d8924442a
7 changed files with 136 additions and 46 deletions

View File

@@ -483,14 +483,23 @@ paths:
         400:
           description: |
             The prompt is invalid.
+          content:
+            application/json:
+              description: A validation error dictionary from the ComfyUI frontend.
+              schema:
+                $ref: "#/components/schemas/ValidationErrorDict"
         429:
           description: |
             The queue is currently too long to process your request.
         500:
           description: |
-            An unexpected exception occurred and it is being passed to you.
-            This can occur if a file was referenced in a LoadImage / LoadImageMask that doesn't exist.
+            An execution error occurred while processing your prompt.
+          content:
+            application/json:
+              description:
+                An execution status directly from the workers
+              schema:
+                $ref: "#/components/schemas/ExecutionStatusAsDict"
         507:
           description: |
             The server had an IOError like running out of disk space.
@@ -803,4 +812,48 @@ components:
         type:
           type: string
         abs_path:
-          type: string
\ No newline at end of file
+          type: string
+    ValidationErrorDict:
+      type: object
+      properties:
+        type:
+          type: string
+        message:
+          type: string
+        details:
+          type: string
+        extra_info:
+          type: object
+          required:
+            - exception_type
+            - traceback
+          properties:
+            exception_type:
+              type: string
+            traceback:
+              type: array
+              items:
+                type: string
+      required:
+        - type
+        - details
+        - extra_info
+        - message
+    ExecutionStatusAsDict:
+      type: object
+      required:
+        - status_str
+        - completed
+        - messages
+      properties:
+        status_str:
+          type: string
+          enum:
+            - "success"
+            - "error"
+        completed:
+          type: boolean
+        messages:
+          type: array
+          items:
+            type: string

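As a rough illustration (not part of this commit), a client consuming this spec can branch on the status code and parse the documented bodies. The /api/v1/prompts path and the prompt payload are assumptions; only the 400/429/500 body shapes come from the spec above:

# Hypothetical client sketch for the responses documented above. The endpoint
# path and prompt payload are assumptions, not confirmed by this diff.
import requests

def submit_prompt(server: str, prompt: dict) -> dict:
    resp = requests.post(f"{server}/api/v1/prompts", json=prompt,
                         headers={"Accept": "application/json"})
    if resp.status_code == 400:
        err = resp.json()  # ValidationErrorDict
        raise ValueError(f"{err['type']}: {err['message']} ({err['details']})")
    if resp.status_code == 429:
        raise RuntimeError("queue too long, retry later")
    if resp.status_code == 500:
        # ExecutionStatusAsDict when the worker reported the error (plain text otherwise)
        if resp.headers.get("Content-Type", "").startswith("application/json"):
            raise RuntimeError("; ".join(resp.json()["messages"]))
        raise RuntimeError(resp.text)
    resp.raise_for_status()
    return resp.json()
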
View File

@@ -14,7 +14,7 @@ from asyncio import Future, AbstractEventLoop
 from enum import Enum
 from io import BytesIO
 from posixpath import join as urljoin
-from typing import List, Optional, Dict
+from typing import List, Optional
 from urllib.parse import quote, urlencode

 import aiofiles
@@ -32,13 +32,14 @@ from .. import model_management
 from .. import utils
 from ..app.user_manager import UserManager
 from ..cli_args import args
-from ..client.client_types import Output, FileOutput
+from ..client.client_types import FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
-from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
+from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
+    ExecutionStatus
 from ..digest import digest
 from ..images import open_image
 from ..nodes.package_typing import ExportedNodes
@@ -602,26 +603,34 @@ class PromptServer(ExecutorToClientProgress):
         number = self.number
         self.number += 1
         result: TaskInvocation
         completed: Future[TaskInvocation | dict] = self.loop.create_future()
         item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
-        if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
-            # this enables span propagation seamlessly
-            result = await self.prompt_queue.put_async(item)
-            if result is None:
-                return web.Response(body="the queue is shutting down", status=503)
-        else:
-            try:
+        try:
+            if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
+                # this enables span propagation seamlessly
+                result = await self.prompt_queue.put_async(item)
+                if result is None:
+                    return web.Response(body="the queue is shutting down", status=503)
+            else:
                 self.prompt_queue.put(item)
                 await completed
-            except Exception as ex:
-                return web.Response(body=str(ex), status=503)
             # expect a single image
-            result: TaskInvocation | dict = completed.result()
-        outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
+            task_invocation_or_dict: TaskInvocation | dict = completed.result()
+            if isinstance(task_invocation_or_dict, dict):
+                result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, []))
+            else:
+                result = task_invocation_or_dict
+        except ExecutionError as exec_exc:
+            result = exec_exc.as_task_invocation()
+        except Exception as ex:
+            return web.Response(body=str(ex), status=500)
+        if result.status is not None and result.status.status_str == "error":
+            return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
         # find images and read them
         output_images: List[FileOutput] = []
-        for node_id, node in outputs_dict.items():
+        for node_id, node in result.outputs.items():
             images: List[FileOutput] = []
             if 'images' in node:
                 images = node['images']
@@ -666,7 +675,7 @@ class PromptServer(ExecutorToClientProgress):
                 headers=digest_headers_,
                 body=json.dumps({
                     'urls': urls_,
-                    'outputs': outputs_dict
+                    'outputs': result.outputs
                 }))
         elif accept == "image/png":
             return web.FileResponse(main_image["abs_path"],

View File

@@ -26,7 +26,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
     @abstractmethod
     def put(self, item: QueueItem):
         """
-        Puts an item on the queue.
+        Puts an item on the queue. Does not block or wait.
         :param item: a queue item
         :return:
         """
@@ -120,11 +120,17 @@
         pass


-class AsyncAbstractPromptQueue(AbstractPromptQueue):
+class AsyncAbstractPromptQueue(metaclass=ABCMeta):
     @abstractmethod
     async def put_async(self, queue_item) -> TaskInvocation | None:
         """
         Puts the item on the queue, and waits until it is complete
         :param queue_item:
         :return:
+        :raises: ExecutionError when the worker returns an error, which can be cast to a task invocation
         """
         pass

     @abstractmethod
     async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
-        pass
\ No newline at end of file
+        pass
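
A minimal caller sketch of the contract the revised docstring documents: put_async returns the TaskInvocation, returns None while shutting down, or raises an ExecutionError that converts back into a task invocation. The queue and item names are illustrative:

# Illustrative use of the put_async contract; `queue` is any
# AsyncAbstractPromptQueue implementation and `item` a QueueItem.
async def run(queue, item):
    try:
        invocation = await queue.put_async(item)
        if invocation is None:
            return None  # the queue is shutting down
        return invocation
    except ExecutionError as exc:
        # the error carries an ExecutionStatus and casts back to a TaskInvocation
        return exc.as_task_invocation()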

View File

@@ -1,12 +1,13 @@
 from __future__ import annotations

 import asyncio
-from enum import Enum
-from typing import NamedTuple, Optional, List, Literal
-from typing_extensions import NotRequired, TypedDict
+from dataclasses import dataclass
+from enum import Enum
+from typing import NamedTuple, Optional, List, Literal, Sequence
 from typing import Tuple
+from typing_extensions import NotRequired, TypedDict

 QueueTuple = Tuple[float, str, dict, dict, list]

 MAXIMUM_HISTORY_SIZE = 10000
@@ -23,6 +24,28 @@ class ExecutionStatus(NamedTuple):
     messages: List[str]


+class ExecutionError(RuntimeError):
+    def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args):
+        super().__init__(*args)
+        self._task_id = task_id
+        if status is not None:
+            self._status = status
+        elif exceptions is not None:
+            self._status = ExecutionStatus('error', False, [str(ex) for ex in exceptions])
+        else:
+            self._status = ExecutionStatus('error', False, [])
+
+    @property
+    def status(self) -> ExecutionStatus:
+        return self._status
+
+    def as_task_invocation(self) -> TaskInvocation:
+        return TaskInvocation(self._task_id, {}, self.status)
+
+    def __str__(self):
+        return ",".join(self._status.messages)
+
+
 class ExecutionStatusAsDict(TypedDict):
     status_str: Literal['success', 'error']
     completed: bool
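
To make the round trip concrete, a short sketch of how the new ExecutionError behaves; it serializes into the same ExecutionStatusAsDict shape the API's 500 response uses (the prompt id here is illustrative):

# Sketch: an ExecutionError built from worker exceptions carries an "error"
# ExecutionStatus whose _asdict() matches ExecutionStatusAsDict.
err = ExecutionError("prompt-123", exceptions=[RuntimeError("file does not exist")])
assert err.status == ExecutionStatus("error", False, ["file does not exist"])
assert err.status._asdict() == {"status_str": "error", "completed": False, "messages": ["file does not exist"]}
assert str(err) == "file does not exist"
inv = err.as_task_invocation()
assert inv.outputs == {} and inv.status.status_str == "error"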

View File

@@ -20,12 +20,13 @@ from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
 from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
-from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
+from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
-from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
+from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation, \
+    ExecutionError


-class DistributedPromptQueue(AsyncAbstractPromptQueue):
+class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
     """
     A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
@@ -44,7 +45,7 @@ class DistributedPromptQueue(AsyncAbstractPromptQueue):
     async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
         assert self._is_caller
         assert self._rpc is not None
+        reply: TaskInvocation
         if self._closing:
             return None
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
@@ -71,33 +72,31 @@ class DistributedPromptQueue(AsyncAbstractPromptQueue):
             assert self._caller_progress_handlers is not None
             await self._caller_progress_handlers.register_progress(user_id)
             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            res: TaskInvocation = RpcReply(
-                **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
-            self._caller_history.put(queue_item, res.outputs, res.status)
+            reply = RpcReply(**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
+            self._caller_history.put(queue_item, reply.outputs, reply.status)
             if self._caller_server is not None:
                 self._caller_server.queue_updated()

             # if this has a completion future, complete it
             if queue_item.completed is not None:
-                queue_item.completed.set_result(res)
-            return res
-        except Exception as e:
+                queue_item.completed.set_result(reply)
+        except Exception as exc:
             # if a caller-side error occurred, use the passed error for the messages
             # we didn't receive any outputs here
-            self._caller_history.put(queue_item, outputs={},
-                                     status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))
+            as_exec_exc = ExecutionError(queue_item.prompt_id, exceptions=[exc])
+            self._caller_history.put(queue_item, outputs={}, status=as_exec_exc.status)
             # if we have a completer, propagate the exception to it
             if queue_item.completed is not None:
-                queue_item.completed.set_exception(e)
-            raise e
+                queue_item.completed.set_exception(as_exec_exc)
+            raise as_exec_exc
         finally:
             self._caller_local_in_progress.pop(queue_item.prompt_id)
             if self._caller_server is not None:
                 # todo: this ensures that the web ui is notified about the completed task, but it should really be done by worker
-                self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id},
-                                              self._caller_server.client_id)
+                self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id}, self._caller_server.client_id)
                 self._caller_server.queue_updated()
+        return reply

     def put(self, item: QueueItem):
         # caller: execute on main thread

View File

@@ -127,7 +127,7 @@ async def test_frontend_backend_workers():
     ]
     processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))

-    server_address = f"http://{get_lan_ip()}:8188"
+    server_address = f"http://{get_lan_ip()}:9001"
     start_time = time.time()
     while time.time() - start_time < 60:
         try:

View File

@@ -210,12 +210,12 @@ def test_image_exif_merge():
     assert res[1].exif["a"] == "1"


-@freeze_time("2012-01-14 03:21:34", tz_offset=-4)
+@freeze_time("2024-01-14 03:21:34", tz_offset=-4)
 def test_image_exif_creation_date_and_batch_number():
     assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
     n = ImageExifCreationDateAndBatchNumber()
     res, = n.execute(images=[_image_1x1, _image_1x1])
-    mock_now = datetime(2012, 1, 13, 23, 21, 34)
+    mock_now = datetime(2024, 1, 13, 23, 21, 34)
     now_formatted = mock_now.strftime("%Y:%m:%d %H:%M:%S%z")
     assert res[0].exif["ImageNumber"] == "0"