diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml
index 87a15769a..9507bd248 100644
--- a/comfy/api/openapi.yaml
+++ b/comfy/api/openapi.yaml
@@ -15,6 +15,8 @@ paths:
           description: the index.html of the website
           content:
             text/html:
+              schema:
+                type: string
               example: "..."
   /embeddings:
     get:
@@ -390,7 +392,7 @@ paths:
             The complete outputs dictionary from the workflow.
 
             Additionally, a list of URLs to binary outputs, whenever save nodes are used.
-
+
             For each SaveImage node, there will be two URLs: the internal URL returned by the worker,
             and the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS`
             configuration.
@@ -454,9 +456,25 @@ paths:
                       type: string
                   outputs:
                     $ref: "#/components/schemas/Outputs"
-              example:
-                outputs: {}
-                urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
+              example:
+                outputs: { }
+                urls: [ "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
+            multipart/mixed:
+              encoding:
+                "^\\d+$":
+                  contentType: image/png, image/jpeg, image/webp
+              schema:
+                type: object
+                description: |
+                  Each output node's binary image values, together with the "outputs" JSON object.
+                properties:
+                  outputs:
+                    $ref: "#/components/schemas/Outputs"
+                additionalProperties:
+                  patternProperties:
+                    "^\\d+$":
+                      type: string
+                      format: binary
         204:
           description: |
             The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
@@ -489,9 +507,12 @@ paths:
             enum:
               - "application/json"
               - "image/png"
+              - "multipart/mixed"
           required: false
           description: |
             Specifies the media type the client is willing to receive.
+
+            `multipart/mixed` will soon be supported, returning all of the images produced by the workflow.
       requestBody:
         content:
           application/json:
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 0497478af..9cbf38422 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -178,12 +178,14 @@ def create_parser() -> argparse.ArgumentParser:
                         help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
     parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
     parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
+    parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
 
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
         try:
             plugin_callable: ConfigurationExtender | ModuleType = entry_point.load()
             if isinstance(plugin_callable, ModuleType):
+                # todo: find the configuration extender in the module
                 plugin_callable = ...
             else:
                 parser_result = plugin_callable(parser)
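A minimal sketch of exercising the new flag through the parser added above; the flag name and default come from this patch, while the sample value and the direct call to create_parser() are illustrative assumptions, not part of the change.

from comfy.cli_args import create_parser

# Parse the new flag the way the application would (the value 1024 is illustrative).
parser = create_parser()
options = parser.parse_args(["--max-queue-size", "1024"])
assert options.max_queue_size == 1024  # argparse maps --max-queue-size to the max_queue_size attribute
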
diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py
index 19ddf3b9f..69b889c80 100644
--- a/comfy/cli_args_types.py
+++ b/comfy/cli_args_types.py
@@ -80,6 +80,7 @@ class Configuration(dict):
         external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
         verbose (bool): Shows extra output for debugging purposes such as import errors of custom nodes.
         disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
+        max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
     """
 
     def __init__(self, **kwargs):
@@ -144,6 +145,7 @@ class Configuration(dict):
         self.distributed_queue_name: str = "comfyui"
         self.external_address: Optional[str] = None
         self.disable_known_models: bool = False
+        self.max_queue_size: int = 65536
         for key, value in kwargs.items():
             self[key] = value
 
diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py
index 02c8f939d..89ad8d6be 100644
--- a/comfy/client/aio_client.py
+++ b/comfy/client/aio_client.py
@@ -3,12 +3,14 @@ import uuid
 from asyncio import AbstractEventLoop
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional, List, Dict
+from typing import Optional, List
 from urllib.parse import urlparse, urljoin
 
 import aiohttp
 from aiohttp import WSMessage, ClientResponse
+from typing_extensions import Dict
 
+from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..api.api_client import JSONEncoder
 from ..api.components.schema.prompt_request import PromptRequest
@@ -17,6 +19,9 @@ from ..component_model.file_output_path import file_output_path
 
 
 class AsyncRemoteComfyClient:
+    """
+    An asynchronous client for remote ComfyUI servers.
+    """
     __json_encoder = JSONEncoder()
 
     def __init__(self, server_address: str = "http://localhost:8188", client_id: str = str(uuid.uuid4()),
@@ -28,11 +33,11 @@ class AsyncRemoteComfyClient:
                                          f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
         self.loop = loop or asyncio.get_event_loop()
 
-    async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
+    async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
         :param prompt:
-        :return: a list of URLs corresponding to the SaveImage nodes in the prompt.
+        :return: the API response from the server, containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         async with aiohttp.ClientSession() as session:
@@ -41,10 +46,18 @@ class AsyncRemoteComfyClient:
                                     headers={'Content-Type': 'application/json',
                                              'Accept': 'application/json'}) as response:
                 if response.status == 200:
-                    return (await response.json())["urls"]
+                    return V1QueuePromptResponse(**(await response.json()))
                 else:
                     raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
+    async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
+        """
+        Calls the API to queue a prompt.
+        :param prompt:
+        :return: a list of URLs corresponding to the SaveImage nodes in the prompt.
+        """
+        return (await self.queue_prompt_api(prompt)).urls
+
     async def queue_prompt(self, prompt: PromptDict) -> bytes:
         """
         Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.
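A minimal usage sketch of the queue_prompt_api method added above, assuming a ComfyUI server is reachable at the default address; the workflow helper and its arguments mirror the existing tests, and the example prompt text and seed are illustrative.

import asyncio

from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner


async def main():
    # Assumes a server listening on the default http://localhost:8188.
    client = AsyncRemoteComfyClient()
    prompt = sdxl_workflow_with_refiner("an example prompt", inference_steps=1, seed=42, refiner_steps=1)
    response = await client.queue_prompt_api(prompt)
    # response is a V1QueuePromptResponse: urls point at /view, outputs mirrors the UI output nodes.
    for url in response.urls:
        print(url)
    for node_id, output in response.outputs.items():
        print(node_id, [image["filename"] for image in output.get("images", [])])


asyncio.run(main())
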
diff --git a/comfy/client/client_types.py b/comfy/client/client_types.py
new file mode 100644
index 000000000..f22727736
--- /dev/null
+++ b/comfy/client/client_types.py
@@ -0,0 +1,22 @@
+import dataclasses
+from typing import List
+
+from typing_extensions import TypedDict, Literal, NotRequired, Dict
+
+
+class FileOutput(TypedDict, total=False):
+    filename: str
+    subfolder: str
+    type: Literal["output", "input", "temp"]
+    abs_path: str
+
+
+class Output(TypedDict, total=False):
+    latents: NotRequired[List[FileOutput]]
+    images: NotRequired[List[FileOutput]]
+
+
+@dataclasses.dataclass
+class V1QueuePromptResponse:
+    urls: List[str]
+    outputs: Dict[str, Output]
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 95b58bc7c..31246f214 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -96,7 +96,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
     return results
 
 
-def get_output_data(obj, input_data_all):
+def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[str, List[typing.Any]]]:
     results = []
     uis = []
     return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index fa7c1cce1..8e37704a8 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -1,43 +1,48 @@
 from __future__ import annotations
+
 import asyncio
-import traceback
 import glob
+import json
+import logging
+import mimetypes
+import os
 import struct
 import sys
-import shutil
-from urllib.parse import quote, urljoin
-from pkg_resources import resource_filename
-
-from PIL import Image, ImageOps
-from PIL.PngImagePlugin import PngInfo
-from io import BytesIO
-
-import logging
-import json
-import os
+import traceback
 import uuid
 from asyncio import Future, AbstractEventLoop
-from typing import List, Optional
+from io import BytesIO
+from typing import List, Optional, Dict
+from urllib.parse import quote, urlencode
+from posixpath import join as urljoin
+from can_ada import URL, parse as urlparse
 
 import aiofiles
 import aiohttp
+from PIL import Image, ImageOps
+from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
+from pkg_resources import resource_filename
+from typing_extensions import NamedTuple
 
 import comfy.interruption
-from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
+from .. import model_management
+from .. import utils
+from ..app.user_manager import UserManager
+from ..cli_args import args
+from ..client.client_types import Output, FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
-import mimetypes
-
-from ..digest import digest
-from ..cli_args import args
-from .. import utils
-from .. import model_management
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
+from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
+from ..digest import digest
 from ..nodes.package_typing import ExportedNodes
-from ..vendor.appdirs import user_data_dir
-from ..app.user_manager import UserManager
+
+
+class HeuristicPath(NamedTuple):
+    filename_heuristic: str
+    abs_path: str
 
 
 async def send_socket_catch_exception(function, message):
@@ -543,13 +548,6 @@ class PromptServer(ExecutorToClientProgress):
 
             return web.Response(status=200)
 
-        @routes.get("/api/v1/images/{content_digest}")
-        async def get_image(request: web.Request) -> web.FileResponse:
-            digest_ = request.match_info['content_digest']
-            path = str(os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), digest_))
-            return web.FileResponse(path,
-                                    headers={"Content-Disposition": f"filename=\"{digest_}.png\""})
-
         @routes.post("/api/v1/prompts")
         async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
             # check if the queue is too long
@@ -595,30 +593,6 @@ class PromptServer(ExecutorToClientProgress):
             if not valid[0]:
                 return web.Response(status=400, body=valid[1])
 
-            cache_path = os.path.join(user_data_dir("comfyui", "comfyanonymous", roaming=False), content_digest)
-            cache_url = f"/api/v1/images/{content_digest}"
-
-            if os.path.exists(cache_path):
-                filename__ = os.path.basename(cache_path)
-                digest_headers_ = {
-                    "Digest": f"SHA-256={content_digest}",
-                    "Location": f"/api/v1/images/{content_digest}"
-                }
-                if accept == "application/json":
-
-                    return web.Response(status=200,
-                                        headers=digest_headers_,
-                                        content_type="application/json",
-                                        body=json.dumps({'urls': [cache_url]}))
-                elif accept == "image/png":
-                    return web.FileResponse(cache_path,
-                                            headers={"Content-Disposition": f"filename=\"{filename__}\"",
-                                                     **digest_headers_})
-                else:
-                    return web.json_response(status=400, reason=f"invalid accept header {accept}")
-
-            # todo: check that the files specified in the InputFile nodes exist
-
             # convert a valid prompt to the queue tuple this expects
             completed: Future[TaskInvocation | dict] = self.loop.create_future()
             number = self.number
@@ -633,55 +607,57 @@ class PromptServer(ExecutorToClientProgress):
                 return web.Response(body=str(ex), status=503)
             # expect a single image
             result: TaskInvocation | dict = completed.result()
-            outputs_dict: dict = result.outputs if isinstance(result, TaskInvocation) else result
+            outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
             # find images and read them
-            output_images: List[str] = []
+            output_images: List[FileOutput] = []
             for node_id, node in outputs_dict.items():
-                images: List[dict] = []
+                images: List[FileOutput] = []
                 if 'images' in node:
                     images = node['images']
+                # todo: does this ever occur?
                 elif (isinstance(node, dict)
                       and 'ui' in node
                       and isinstance(node['ui'], dict)
                       and 'images' in node['ui']):
                     images = node['ui']['images']
                 for image_tuple in images:
-                    filename_ = image_tuple['abs_path']
-                    output_images.append(filename_)
+                    output_images.append(image_tuple)
 
             if len(output_images) > 0:
-                image_ = output_images[-1]
-                if not os.path.exists(os.path.dirname(cache_path)):
-                    try:
-                        os.makedirs(os.path.dirname(cache_path))
-                    except:
-                        pass
-                # shutil.copy(image_, cache_path)
-                filename = os.path.basename(image_)
+                main_image = output_images[-1]
+                filename = main_image["filename"]
                 digest_headers_ = {
                     "Digest": f"SHA-256={content_digest}",
                 }
-                urls_ = [cache_url]
+                urls_ = []
                 if len(output_images) == 1:
                     digest_headers_.update({
-                        "Location": f"/api/v1/images/{content_digest}",
                         "Content-Disposition": f"filename=\"{filename}\""
                     })
 
                 for image_indv_ in output_images:
-                    image_indv_filename_ = os.path.basename(image_indv_)
-                    urls_ += [
-                        f"http://{self.address}:{self.port}/view?filename={image_indv_filename_}&type=output",
-                        urljoin(self.external_address, f"/view?filename={image_indv_filename_}&type=output")
-                    ]
+                    local_address = f"http://{self.address}:{self.port}"
+                    external_address = self.external_address
+
+                    for base in (local_address, external_address):
+                        url: URL = urlparse(urljoin(base, "view"))
+                        url_search_dict: FileOutput = dict(image_indv_)
+                        del url_search_dict["abs_path"]
+                        if url_search_dict["subfolder"] == "":
+                            del url_search_dict["subfolder"]
+                        url.search = f"?{urlencode(url_search_dict)}"
+                        urls_.append(str(url))
 
                 if accept == "application/json":
                     return web.Response(status=200,
                                         content_type="application/json",
                                         headers=digest_headers_,
-                                        body=json.dumps({'urls': urls_}))
+                                        body=json.dumps({
+                                            'urls': urls_,
+                                            'outputs': outputs_dict
+                                        }))
                 elif accept == "image/png":
-                    return web.FileResponse(image_,
+                    return web.FileResponse(main_image["abs_path"],
                                             headers=digest_headers_)
                 else:
                     return web.Response(status=204)
@@ -750,7 +726,7 @@ class PromptServer(ExecutorToClientProgress):
                 if hasattr(Image, 'Resampling'):
                     resampling = Image.Resampling.BILINEAR
                 else:
-                    resampling = Image.ANTIALIAS
+                    resampling = Image.Resampling.LANCZOS
                 image = ImageOps.contain(image, (max_size, max_size), resampling)
 
                 type_num = 1
@@ -838,5 +814,4 @@ class PromptServer(ExecutorToClientProgress):
 
     @classmethod
     def get_too_busy_queue_size(cls):
-        # todo: what is too busy of a queue for API clients?
-        return 100
+        return args.max_queue_size
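For reference, a standalone sketch of the URL construction used by the reworked post_prompt handler above, using can_ada and urlencode exactly as the patch does; the FileOutput values and the two base addresses below are illustrative only.

from posixpath import join as urljoin
from urllib.parse import urlencode

from can_ada import parse as urlparse

# Illustrative FileOutput-shaped dict; real values come from the workflow's save nodes.
image = {"filename": "sdxl_00001_.png", "subfolder": "subdirtest", "type": "output",
         "abs_path": "/outputs/subdirtest/sdxl_00001_.png"}

for base in ("http://127.0.0.1:8188", "https://comfyui.example.com"):
    url = urlparse(urljoin(base, "view"))
    query = dict(image)
    del query["abs_path"]  # absolute server-side paths are never exposed to clients
    if query["subfolder"] == "":
        del query["subfolder"]
    url.search = f"?{urlencode(query)}"
    # e.g. http://127.0.0.1:8188/view?filename=sdxl_00001_.png&subfolder=subdirtest&type=output
    print(str(url))
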
diff --git a/custom_nodes/__init__.py b/custom_nodes/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/requirements.txt b/requirements.txt
index 615d4beb2..fa71afe49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,4 +32,5 @@ pyjwt[crypto]
 kornia>=0.7.1
 mpmath>=1.0,!=1.4.0a0
 huggingface_hub
-lazy-object-proxy
\ No newline at end of file
+lazy-object-proxy
+can_ada
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index a514d6d48..d407dc474 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,7 +21,6 @@ def run_server(args_pytest):
     args.output_directory = args_pytest["output_dir"]
     args.listen = args_pytest["listen"]
     args.port = args_pytest["port"]
-    print("running server anyway!")
     asyncio.run(main())
 
 
diff --git a/tests/distributed/test_asyncio_remote_client.py b/tests/distributed/test_asyncio_remote_client.py
index 6382857fe..8c54feaa1 100644
--- a/tests/distributed/test_asyncio_remote_client.py
+++ b/tests/distributed/test_asyncio_remote_client.py
@@ -1,6 +1,10 @@
 import random
+from urllib.parse import parse_qsl
 
+import aiohttp
 import pytest
+from can_ada import URL, parse
+
 from comfy.client.aio_client import AsyncRemoteComfyClient
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 
@@ -8,7 +12,7 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 @pytest.mark.asyncio
 async def test_completes_prompt(comfy_background_server):
     client = AsyncRemoteComfyClient()
-    random_seed = random.randint(1,4294967295)
+    random_seed = random.randint(1, 4294967295)
     prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
     png_image_bytes = await client.queue_prompt(prompt)
     assert len(png_image_bytes) > 1000
@@ -17,7 +21,7 @@ async def test_completes_prompt(comfy_background_server):
 @pytest.mark.asyncio
 async def test_completes_prompt_with_ui(comfy_background_server):
     client = AsyncRemoteComfyClient()
-    random_seed = random.randint(1,4294967295)
+    random_seed = random.randint(1, 4294967295)
     prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
     result_dict = await client.queue_prompt_ui(prompt)
     # should contain one output
@@ -27,10 +31,24 @@ async def test_completes_prompt_with_ui(comfy_background_server):
 @pytest.mark.asyncio
 async def test_completes_prompt_with_image_urls(comfy_background_server):
     client = AsyncRemoteComfyClient()
-    random_seed = random.randint(1,4294967295)
-    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
-    result_list = await client.queue_prompt_uris(prompt)
-    assert len(result_list) == 3
-    result_list = await client.queue_prompt_uris(prompt)
-    # cached
-    assert len(result_list) == 1
+    random_seed = random.randint(1, 4294967295)
+    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
+    result = await client.queue_prompt_api(prompt)
+    assert len(result.urls) == 2
+    for url_str in result.urls:
+        url: URL = parse(url_str)
+        assert url.hostname == "localhost" or url.hostname == "127.0.0.1" or url.hostname == "::1"
+        assert url.pathname == "/view"
+        search = {k: v for (k, v) in parse_qsl(url.search[1:])}
+        assert str(search["filename"]).startswith("sdxl")
+        assert search["subfolder"] == "subdirtest"
+        assert search["type"] == "output"
+        # get the actual image file and assert it works
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url_str) as response:
+                assert response.status == 200
+                assert response.headers['Content-Type'] == 'image/png'
+                content = await response.read()
+                assert len(content) > 1000
+    assert len(result.outputs) == 1
+    assert len(result.outputs["13"]["images"]) == 1
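Finally, a hedged sketch of calling the v1 endpoint directly with aiohttp, matching the OpenAPI and server changes above; the server address, prompt text, and seed are assumptions drawn from the tests rather than part of the patch.

import asyncio

import aiohttp

from comfy.api.api_client import JSONEncoder
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner

SERVER = "http://localhost:8188"  # assumed default address


async def main():
    prompt = sdxl_workflow_with_refiner("an example prompt", inference_steps=1, seed=42, refiner_steps=1)
    body = JSONEncoder().encode(prompt)
    async with aiohttp.ClientSession() as session:
        # Accept: application/json returns {"urls": [...], "outputs": {...}} per the updated spec;
        # Accept: image/png would stream the main image instead.
        async with session.post(f"{SERVER}/api/v1/prompts", data=body,
                                headers={"Content-Type": "application/json",
                                         "Accept": "application/json"}) as response:
            if response.status == 200:
                payload = await response.json()
                print(payload["urls"], list(payload["outputs"]))
            elif response.status == 204:
                print("the workflow had no SaveImage outputs")
            else:
                print("request rejected:", response.status, await response.text())


asyncio.run(main())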