Improve API support

- Removed /api/v1/images because you should use your own CDN style
   image host and /view for maximum compatibility
 - The /api/v1/prompts POST application/json response will now return
   the outputs dictionary
 - Caching has been removed
 - More tests
 - Subdirectory prefixes are now supported
 - Fixed an issue where a Linux frontend and Windows backend would have
   paths that could not interact with each other correctly
This commit is contained in:
doctorpangloss 2024-03-21 16:24:22 -07:00
parent d73b116446
commit 0db040cc47
11 changed files with 151 additions and 98 deletions

View File

@ -15,6 +15,8 @@ paths:
description: the index.html of the website description: the index.html of the website
content: content:
text/html: text/html:
schema:
type: string
example: "<!DOCTYPE html>..." example: "<!DOCTYPE html>..."
/embeddings: /embeddings:
get: get:
@ -454,9 +456,25 @@ paths:
type: string type: string
outputs: outputs:
$ref: "#/components/schemas/Outputs" $ref: "#/components/schemas/Outputs"
example: example:
outputs: {} outputs: { }
urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ] urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
multipart/mixed:
encoding:
"^\\d+$":
contentType: image/png, image/jpeg, image/webp
schema:
type: object
description: |
Each of the output nodes' binary values for images, and the "outputs" json object.
properties:
outputs:
$ref: "#/components/schemas/Outputs"
additionalProperties:
patternProperties:
"^\\d+$":
type: string
format: binary
204: 204:
description: | description: |
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned. The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
@ -489,9 +507,12 @@ paths:
enum: enum:
- "application/json" - "application/json"
- "image/png" - "image/png"
- "multipart/mixed"
required: false required: false
description: | description: |
Specifies the media type the client is willing to receive. Specifies the media type the client is willing to receive.
multipart/mixed will soon be supported to return all the images from the workflow.
requestBody: requestBody:
content: content:
application/json: application/json:

View File

@ -178,12 +178,14 @@ def create_parser() -> argparse.ArgumentParser:
help="Specifies a base URL for external addresses reported by the API, such as for image paths.") help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.") parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
# now give plugins a chance to add configuration # now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'): for entry_point in entry_points().select(group='comfyui.custom_config'):
try: try:
plugin_callable: ConfigurationExtender | ModuleType = entry_point.load() plugin_callable: ConfigurationExtender | ModuleType = entry_point.load()
if isinstance(plugin_callable, ModuleType): if isinstance(plugin_callable, ModuleType):
# todo: find the configuration extender in the module
plugin_callable = ... plugin_callable = ...
else: else:
parser_result = plugin_callable(parser) parser_result = plugin_callable(parser)

View File

@ -80,6 +80,7 @@ class Configuration(dict):
external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths. external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
verbose (bool): Shows extra output for debugging purposes such as import errors of custom nodes. verbose (bool): Shows extra output for debugging purposes such as import errors of custom nodes.
disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI. disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -144,6 +145,7 @@ class Configuration(dict):
self.distributed_queue_name: str = "comfyui" self.distributed_queue_name: str = "comfyui"
self.external_address: Optional[str] = None self.external_address: Optional[str] = None
self.disable_known_models: bool = False self.disable_known_models: bool = False
self.max_queue_size: int = 65536
for key, value in kwargs.items(): for key, value in kwargs.items():
self[key] = value self[key] = value

View File

@ -3,12 +3,14 @@ import uuid
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict from typing import Optional, List
from urllib.parse import urlparse, urljoin from urllib.parse import urlparse, urljoin
import aiohttp import aiohttp
from aiohttp import WSMessage, ClientResponse from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..api.api_client import JSONEncoder from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt_request import PromptRequest from ..api.components.schema.prompt_request import PromptRequest
@ -17,6 +19,9 @@ from ..component_model.file_output_path import file_output_path
class AsyncRemoteComfyClient: class AsyncRemoteComfyClient:
"""
An asynchronous client for remote servers
"""
__json_encoder = JSONEncoder() __json_encoder = JSONEncoder()
def __init__(self, server_address: str = "http://localhost:8188", client_id: str = str(uuid.uuid4()), def __init__(self, server_address: str = "http://localhost:8188", client_id: str = str(uuid.uuid4()),
@ -28,11 +33,11 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}") f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]: async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
""" """
Calls the API to queue a prompt. Calls the API to queue a prompt.
:param prompt: :param prompt:
:return: a list of URLs corresponding to the SaveImage nodes in the prompt. :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -41,10 +46,18 @@ class AsyncRemoteComfyClient:
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
if response.status == 200: if response.status == 200:
return (await response.json())["urls"] return V1QueuePromptResponse(**(await response.json()))
else: else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
"""
Calls the API to queue a prompt.
:param prompt:
:return: a list of URLs corresponding to the SaveImage nodes in the prompt.
"""
return (await self.queue_prompt_api(prompt)).urls
async def queue_prompt(self, prompt: PromptDict) -> bytes: async def queue_prompt(self, prompt: PromptDict) -> bytes:
""" """
Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node. Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.

View File

@ -0,0 +1,22 @@
import dataclasses
from typing import List
from typing_extensions import TypedDict, Literal, NotRequired, Dict
class FileOutput(TypedDict, total=False):
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
abs_path: str
class Output(TypedDict, total=False):
latents: NotRequired[List[FileOutput]]
images: NotRequired[List[FileOutput]]
@dataclasses.dataclass
class V1QueuePromptResponse:
urls: List[str]
outputs: Dict[str, Output]

View File

@ -96,7 +96,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
return results return results
def get_output_data(obj, input_data_all): def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[str, List[typing.Any]]]:
results = [] results = []
uis = [] uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

View File

@ -1,43 +1,48 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import traceback
import glob import glob
import json
import logging
import mimetypes
import os
import struct import struct
import sys import sys
import shutil import traceback
from urllib.parse import quote, urljoin
from pkg_resources import resource_filename
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
import logging
import json
import os
import uuid import uuid
from asyncio import Future, AbstractEventLoop from asyncio import Future, AbstractEventLoop
from typing import List, Optional from io import BytesIO
from typing import List, Optional, Dict
from urllib.parse import quote, urlencode
from posixpath import join as urljoin
from can_ada import URL, parse as urlparse
import aiofiles import aiofiles
import aiohttp import aiohttp
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from aiohttp import web from aiohttp import web
from pkg_resources import resource_filename
from typing_extensions import NamedTuple
import comfy.interruption import comfy.interruption
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation from .. import model_management
from .. import utils
from ..app.user_manager import UserManager
from ..cli_args import args
from ..client.client_types import Output, FileOutput
from ..cmd import execution from ..cmd import execution
from ..cmd import folder_paths from ..cmd import folder_paths
import mimetypes
from ..digest import digest
from ..cli_args import args
from .. import utils
from .. import model_management
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from ..digest import digest
from ..nodes.package_typing import ExportedNodes from ..nodes.package_typing import ExportedNodes
from ..vendor.appdirs import user_data_dir
from ..app.user_manager import UserManager
class HeuristicPath(NamedTuple):
filename_heuristic: str
abs_path: str
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
@ -543,13 +548,6 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200) return web.Response(status=200)
@routes.get("/api/v1/images/{content_digest}")
async def get_image(request: web.Request) -> web.FileResponse:
digest_ = request.match_info['content_digest']
path = str(os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), digest_))
return web.FileResponse(path,
headers={"Content-Disposition": f"filename=\"{digest_}.png\""})
@routes.post("/api/v1/prompts") @routes.post("/api/v1/prompts")
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse: async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long # check if the queue is too long
@ -595,30 +593,6 @@ class PromptServer(ExecutorToClientProgress):
if not valid[0]: if not valid[0]:
return web.Response(status=400, body=valid[1]) return web.Response(status=400, body=valid[1])
cache_path = os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), content_digest)
cache_url = f"/api/v1/images/{content_digest}"
if os.path.exists(cache_path):
filename__ = os.path.basename(cache_path)
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/images/{content_digest}"
}
if accept == "application/json":
return web.Response(status=200,
headers=digest_headers_,
content_type="application/json",
body=json.dumps({'urls': [cache_url]}))
elif accept == "image/png":
return web.FileResponse(cache_path,
headers={"Content-Disposition": f"filename=\"{filename__}\"",
**digest_headers_})
else:
return web.json_response(status=400, reason=f"invalid accept header {accept}")
# todo: check that the files specified in the InputFile nodes exist
# convert a valid prompt to the queue tuple this expects # convert a valid prompt to the queue tuple this expects
completed: Future[TaskInvocation | dict] = self.loop.create_future() completed: Future[TaskInvocation | dict] = self.loop.create_future()
number = self.number number = self.number
@ -633,55 +607,57 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(body=str(ex), status=503) return web.Response(body=str(ex), status=503)
# expect a single image # expect a single image
result: TaskInvocation | dict = completed.result() result: TaskInvocation | dict = completed.result()
outputs_dict: dict = result.outputs if isinstance(result, TaskInvocation) else result outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
# find images and read them # find images and read them
output_images: List[str] = [] output_images: List[FileOutput] = []
for node_id, node in outputs_dict.items(): for node_id, node in outputs_dict.items():
images: List[dict] = [] images: List[FileOutput] = []
if 'images' in node: if 'images' in node:
images = node['images'] images = node['images']
# todo: does this ever occur?
elif (isinstance(node, dict) elif (isinstance(node, dict)
and 'ui' in node and isinstance(node['ui'], dict) and 'ui' in node and isinstance(node['ui'], dict)
and 'images' in node['ui']): and 'images' in node['ui']):
images = node['ui']['images'] images = node['ui']['images']
for image_tuple in images: for image_tuple in images:
filename_ = image_tuple['abs_path'] output_images.append(image_tuple)
output_images.append(filename_)
if len(output_images) > 0: if len(output_images) > 0:
image_ = output_images[-1] main_image = output_images[-1]
if not os.path.exists(os.path.dirname(cache_path)): filename = main_image["filename"]
try:
os.makedirs(os.path.dirname(cache_path))
except:
pass
# shutil.copy(image_, cache_path)
filename = os.path.basename(image_)
digest_headers_ = { digest_headers_ = {
"Digest": f"SHA-256={content_digest}", "Digest": f"SHA-256={content_digest}",
} }
urls_ = [cache_url] urls_ = []
if len(output_images) == 1: if len(output_images) == 1:
digest_headers_.update({ digest_headers_.update({
"Location": f"/api/v1/images/{content_digest}",
"Content-Disposition": f"filename=\"{filename}\"" "Content-Disposition": f"filename=\"{filename}\""
}) })
for image_indv_ in output_images: for image_indv_ in output_images:
image_indv_filename_ = os.path.basename(image_indv_) local_address = f"http://{self.address}:{self.port}"
urls_ += [ external_address = self.external_address
f"http://{self.address}:{self.port}/view?filename={image_indv_filename_}&type=output",
urljoin(self.external_address, f"/view?filename={image_indv_filename_}&type=output") for base in (local_address, external_address):
] url: URL = urlparse(urljoin(base, "view"))
url_search_dict: FileOutput = dict(image_indv_)
del url_search_dict["abs_path"]
if url_search_dict["subfolder"] == "":
del url_search_dict["subfolder"]
url.search = f"?{urlencode(url_search_dict)}"
urls_.append(str(url))
if accept == "application/json": if accept == "application/json":
return web.Response(status=200, return web.Response(status=200,
content_type="application/json", content_type="application/json",
headers=digest_headers_, headers=digest_headers_,
body=json.dumps({'urls': urls_})) body=json.dumps({
'urls': urls_,
'outputs': outputs_dict
}))
elif accept == "image/png": elif accept == "image/png":
return web.FileResponse(image_, return web.FileResponse(main_image["abs_path"],
headers=digest_headers_) headers=digest_headers_)
else: else:
return web.Response(status=204) return web.Response(status=204)
@ -750,7 +726,7 @@ class PromptServer(ExecutorToClientProgress):
if hasattr(Image, 'Resampling'): if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR resampling = Image.Resampling.BILINEAR
else: else:
resampling = Image.ANTIALIAS resampling = Image.Resampling.LANCZOS
image = ImageOps.contain(image, (max_size, max_size), resampling) image = ImageOps.contain(image, (max_size, max_size), resampling)
type_num = 1 type_num = 1
@ -838,5 +814,4 @@ class PromptServer(ExecutorToClientProgress):
@classmethod @classmethod
def get_too_busy_queue_size(cls): def get_too_busy_queue_size(cls):
# todo: what is too busy of a queue for API clients? return args.max_queue_size
return 100

View File

@ -33,3 +33,4 @@ kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0 mpmath>=1.0,!=1.4.0a0
huggingface_hub huggingface_hub
lazy-object-proxy lazy-object-proxy
can_ada

View File

@ -21,7 +21,6 @@ def run_server(args_pytest):
args.output_directory = args_pytest["output_dir"] args.output_directory = args_pytest["output_dir"]
args.listen = args_pytest["listen"] args.listen = args_pytest["listen"]
args.port = args_pytest["port"] args.port = args_pytest["port"]
print("running server anyway!")
asyncio.run(main()) asyncio.run(main())

View File

@ -1,6 +1,10 @@
import random import random
from urllib.parse import parse_qsl
import aiohttp
import pytest import pytest
from can_ada import URL, parse
from comfy.client.aio_client import AsyncRemoteComfyClient from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@ -8,7 +12,7 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server): async def test_completes_prompt(comfy_background_server):
client = AsyncRemoteComfyClient() client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt) png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000 assert len(png_image_bytes) > 1000
@ -17,7 +21,7 @@ async def test_completes_prompt(comfy_background_server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server): async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient() client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt) result_dict = await client.queue_prompt_ui(prompt)
# should contain one output # should contain one output
@ -27,10 +31,24 @@ async def test_completes_prompt_with_ui(comfy_background_server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server): async def test_completes_prompt_with_image_urls(comfy_background_server):
client = AsyncRemoteComfyClient() client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
result_list = await client.queue_prompt_uris(prompt) result = await client.queue_prompt_api(prompt)
assert len(result_list) == 3 assert len(result.urls) == 2
result_list = await client.queue_prompt_uris(prompt) for url_str in result.urls:
# cached url: URL = parse(url_str)
assert len(result_list) == 1 assert url.hostname == "localhost" or url.hostname == "127.0.0.1" or url.hostname == "::1"
assert url.pathname == "/view"
search = {k: v for (k, v) in parse_qsl(url.search[1:])}
assert str(search["filename"]).startswith("sdxl")
assert search["subfolder"] == "subdirtest"
assert search["type"] == "output"
# get the actual image file and assert it works
async with aiohttp.ClientSession() as session:
async with session.get(url_str) as response:
assert response.status == 200
assert response.headers['Content-Type'] == 'image/png'
content = await response.read()
assert len(content) > 1000
assert len(result.outputs) == 1
assert len(result.outputs["13"]["images"]) == 1