Improve API support

- Removed /api/v1/images because you should use your own CDN style
   image host and /view for maximum compatibility
 - The /api/v1/prompts POST application/json response will now return
   the outputs dictionary
 - Caching has been removed
 - More tests
 - Subdirectory prefixes are now supported
 - Fixed an issue where a Linux frontend and Windows backend would have
   paths that could not interact with each other correctly
This commit is contained in:
doctorpangloss 2024-03-21 16:24:22 -07:00
parent d73b116446
commit 0db040cc47
11 changed files with 151 additions and 98 deletions

View File

@ -15,6 +15,8 @@ paths:
description: the index.html of the website
content:
text/html:
schema:
type: string
example: "<!DOCTYPE html>..."
/embeddings:
get:
@ -390,7 +392,7 @@ paths:
The complete outputs dictionary from the workflow.
Additionally, a list of URLs to binary outputs, whenever save nodes are used.
For each SaveImage node, there will be two URLs: the internal URL returned by the worker, and
the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration.
@ -454,9 +456,25 @@ paths:
type: string
outputs:
$ref: "#/components/schemas/Outputs"
example:
outputs: {}
urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
example:
outputs: { }
urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
multipart/mixed:
encoding:
"^\\d+$":
contentType: image/png, image/jpeg, image/webp
schema:
type: object
description: |
Each of the output nodes' binary values for images, and the "outputs" json object.
properties:
outputs:
$ref: "#/components/schemas/Outputs"
additionalProperties:
patternProperties:
"^\\d+$":
type: string
format: binary
204:
description: |
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
@ -489,9 +507,12 @@ paths:
enum:
- "application/json"
- "image/png"
- "multipart/mixed"
required: false
description: |
Specifies the media type the client is willing to receive.
multipart/mixed will soon be supported to return all the images from the workflow.
requestBody:
content:
application/json:

View File

@ -178,12 +178,14 @@ def create_parser() -> argparse.ArgumentParser:
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
# now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'):
try:
plugin_callable: ConfigurationExtender | ModuleType = entry_point.load()
if isinstance(plugin_callable, ModuleType):
# todo: find the configuration extender in the module
plugin_callable = ...
else:
parser_result = plugin_callable(parser)

View File

@ -80,6 +80,7 @@ class Configuration(dict):
external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
verbose (bool): Shows extra output for debugging purposes such as import errors of custom nodes.
disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
"""
def __init__(self, **kwargs):
@ -144,6 +145,7 @@ class Configuration(dict):
self.distributed_queue_name: str = "comfyui"
self.external_address: Optional[str] = None
self.disable_known_models: bool = False
self.max_queue_size: int = 65536
for key, value in kwargs.items():
self[key] = value

View File

@ -3,12 +3,14 @@ import uuid
from asyncio import AbstractEventLoop
from collections import defaultdict
from pathlib import Path
from typing import Optional, List, Dict
from typing import Optional, List
from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt_request import PromptRequest
@ -17,6 +19,9 @@ from ..component_model.file_output_path import file_output_path
class AsyncRemoteComfyClient:
"""
An asynchronous client for remote servers
"""
__json_encoder = JSONEncoder()
def __init__(self, server_address: str = "http://localhost:8188", client_id: str = str(uuid.uuid4()),
@ -28,11 +33,11 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop()
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
"""
Calls the API to queue a prompt.
:param prompt:
:return: a list of URLs corresponding to the SaveImage nodes in the prompt.
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
async with aiohttp.ClientSession() as session:
@ -41,10 +46,18 @@ class AsyncRemoteComfyClient:
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
if response.status == 200:
return (await response.json())["urls"]
return V1QueuePromptResponse(**(await response.json()))
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
"""
Calls the API to queue a prompt.
:param prompt:
:return: a list of URLs corresponding to the SaveImage nodes in the prompt.
"""
return (await self.queue_prompt_api(prompt)).urls
async def queue_prompt(self, prompt: PromptDict) -> bytes:
"""
Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.

View File

@ -0,0 +1,22 @@
import dataclasses
from typing import List
from typing_extensions import TypedDict, Literal, NotRequired, Dict
class FileOutput(TypedDict, total=False):
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
abs_path: str
class Output(TypedDict, total=False):
latents: NotRequired[List[FileOutput]]
images: NotRequired[List[FileOutput]]
@dataclasses.dataclass
class V1QueuePromptResponse:
urls: List[str]
outputs: Dict[str, Output]

View File

@ -96,7 +96,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
return results
def get_output_data(obj, input_data_all):
def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[str, List[typing.Any]]]:
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

View File

@ -1,43 +1,48 @@
from __future__ import annotations
import asyncio
import traceback
import glob
import json
import logging
import mimetypes
import os
import struct
import sys
import shutil
from urllib.parse import quote, urljoin
from pkg_resources import resource_filename
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
import logging
import json
import os
import traceback
import uuid
from asyncio import Future, AbstractEventLoop
from typing import List, Optional
from io import BytesIO
from typing import List, Optional, Dict
from urllib.parse import quote, urlencode
from posixpath import join as urljoin
from can_ada import URL, parse as urlparse
import aiofiles
import aiohttp
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from pkg_resources import resource_filename
from typing_extensions import NamedTuple
import comfy.interruption
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from .. import model_management
from .. import utils
from ..app.user_manager import UserManager
from ..cli_args import args
from ..client.client_types import Output, FileOutput
from ..cmd import execution
from ..cmd import folder_paths
import mimetypes
from ..digest import digest
from ..cli_args import args
from .. import utils
from .. import model_management
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from ..digest import digest
from ..nodes.package_typing import ExportedNodes
from ..vendor.appdirs import user_data_dir
from ..app.user_manager import UserManager
class HeuristicPath(NamedTuple):
filename_heuristic: str
abs_path: str
async def send_socket_catch_exception(function, message):
@ -543,13 +548,6 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200)
@routes.get("/api/v1/images/{content_digest}")
async def get_image(request: web.Request) -> web.FileResponse:
digest_ = request.match_info['content_digest']
path = str(os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), digest_))
return web.FileResponse(path,
headers={"Content-Disposition": f"filename=\"{digest_}.png\""})
@routes.post("/api/v1/prompts")
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long
@ -595,30 +593,6 @@ class PromptServer(ExecutorToClientProgress):
if not valid[0]:
return web.Response(status=400, body=valid[1])
cache_path = os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), content_digest)
cache_url = f"/api/v1/images/{content_digest}"
if os.path.exists(cache_path):
filename__ = os.path.basename(cache_path)
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/images/{content_digest}"
}
if accept == "application/json":
return web.Response(status=200,
headers=digest_headers_,
content_type="application/json",
body=json.dumps({'urls': [cache_url]}))
elif accept == "image/png":
return web.FileResponse(cache_path,
headers={"Content-Disposition": f"filename=\"{filename__}\"",
**digest_headers_})
else:
return web.json_response(status=400, reason=f"invalid accept header {accept}")
# todo: check that the files specified in the InputFile nodes exist
# convert a valid prompt to the queue tuple this expects
completed: Future[TaskInvocation | dict] = self.loop.create_future()
number = self.number
@ -633,55 +607,57 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(body=str(ex), status=503)
# expect a single image
result: TaskInvocation | dict = completed.result()
outputs_dict: dict = result.outputs if isinstance(result, TaskInvocation) else result
outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
# find images and read them
output_images: List[str] = []
output_images: List[FileOutput] = []
for node_id, node in outputs_dict.items():
images: List[dict] = []
images: List[FileOutput] = []
if 'images' in node:
images = node['images']
# todo: does this ever occur?
elif (isinstance(node, dict)
and 'ui' in node and isinstance(node['ui'], dict)
and 'images' in node['ui']):
images = node['ui']['images']
for image_tuple in images:
filename_ = image_tuple['abs_path']
output_images.append(filename_)
output_images.append(image_tuple)
if len(output_images) > 0:
image_ = output_images[-1]
if not os.path.exists(os.path.dirname(cache_path)):
try:
os.makedirs(os.path.dirname(cache_path))
except:
pass
# shutil.copy(image_, cache_path)
filename = os.path.basename(image_)
main_image = output_images[-1]
filename = main_image["filename"]
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
}
urls_ = [cache_url]
urls_ = []
if len(output_images) == 1:
digest_headers_.update({
"Location": f"/api/v1/images/{content_digest}",
"Content-Disposition": f"filename=\"{filename}\""
})
for image_indv_ in output_images:
image_indv_filename_ = os.path.basename(image_indv_)
urls_ += [
f"http://{self.address}:{self.port}/view?filename={image_indv_filename_}&type=output",
urljoin(self.external_address, f"/view?filename={image_indv_filename_}&type=output")
]
local_address = f"http://{self.address}:{self.port}"
external_address = self.external_address
for base in (local_address, external_address):
url: URL = urlparse(urljoin(base, "view"))
url_search_dict: FileOutput = dict(image_indv_)
del url_search_dict["abs_path"]
if url_search_dict["subfolder"] == "":
del url_search_dict["subfolder"]
url.search = f"?{urlencode(url_search_dict)}"
urls_.append(str(url))
if accept == "application/json":
return web.Response(status=200,
content_type="application/json",
headers=digest_headers_,
body=json.dumps({'urls': urls_}))
body=json.dumps({
'urls': urls_,
'outputs': outputs_dict
}))
elif accept == "image/png":
return web.FileResponse(image_,
return web.FileResponse(main_image["abs_path"],
headers=digest_headers_)
else:
return web.Response(status=204)
@ -750,7 +726,7 @@ class PromptServer(ExecutorToClientProgress):
if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR
else:
resampling = Image.ANTIALIAS
resampling = Image.Resampling.LANCZOS
image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1
@ -838,5 +814,4 @@ class PromptServer(ExecutorToClientProgress):
@classmethod
def get_too_busy_queue_size(cls):
# todo: what is too busy of a queue for API clients?
return 100
return args.max_queue_size

View File

@ -32,4 +32,5 @@ pyjwt[crypto]
kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0
huggingface_hub
lazy-object-proxy
lazy-object-proxy
can_ada

View File

@ -21,7 +21,6 @@ def run_server(args_pytest):
args.output_directory = args_pytest["output_dir"]
args.listen = args_pytest["listen"]
args.port = args_pytest["port"]
print("running server anyway!")
asyncio.run(main())

View File

@ -1,6 +1,10 @@
import random
from urllib.parse import parse_qsl
import aiohttp
import pytest
from can_ada import URL, parse
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@ -8,7 +12,7 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295)
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
@ -17,7 +21,7 @@ async def test_completes_prompt(comfy_background_server):
@pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295)
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt)
# should contain one output
@ -27,10 +31,24 @@ async def test_completes_prompt_with_ui(comfy_background_server):
@pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_list = await client.queue_prompt_uris(prompt)
assert len(result_list) == 3
result_list = await client.queue_prompt_uris(prompt)
# cached
assert len(result_list) == 1
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
result = await client.queue_prompt_api(prompt)
assert len(result.urls) == 2
for url_str in result.urls:
url: URL = parse(url_str)
assert url.hostname == "localhost" or url.hostname == "127.0.0.1" or url.hostname == "::1"
assert url.pathname == "/view"
search = {k: v for (k, v) in parse_qsl(url.search[1:])}
assert str(search["filename"]).startswith("sdxl")
assert search["subfolder"] == "subdirtest"
assert search["type"] == "output"
# get the actual image file and assert it works
async with aiohttp.ClientSession() as session:
async with session.get(url_str) as response:
assert response.status == 200
assert response.headers['Content-Type'] == 'image/png'
content = await response.read()
assert len(content) > 1000
assert len(result.outputs) == 1
assert len(result.outputs["13"]["images"]) == 1