From 06e74226dfe7a5f5486c38b0f3403d17545e3fb6 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 15 Feb 2024 17:39:15 -0800 Subject: [PATCH] Add external address parameter --- comfy/api/openapi.yaml | 32 ++++++++---- comfy/cli_args.py | 3 +- comfy/cli_args_types.py | 2 + comfy/client/aio_client.py | 17 +++++++ comfy/cmd/main.py | 2 + comfy/cmd/server.py | 51 ++++++++++++++----- comfy/nodes/package.py | 5 +- tests/conftest.py | 3 +- .../distributed/test_asyncio_remote_client.py | 15 ++++++ 9 files changed, 101 insertions(+), 29 deletions(-) diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml index 25af0c332..6f25c81cf 100644 --- a/comfy/api/openapi.yaml +++ b/comfy/api/openapi.yaml @@ -378,7 +378,7 @@ paths: 200: headers: Location: - description: The URL to the file based on a hash of the request body. + description: The URL to the file based on a hash of the request body when exactly one SaveImage node is specified. example: /api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3 schema: type: string @@ -388,7 +388,7 @@ paths: schema: type: string Content-Disposition: - description: The filename when a SaveImage node is specified. + description: The filename when exactly one SaveImage node is specified. example: filename=ComfyUI_00001.png schema: type: string @@ -407,8 +407,11 @@ paths: description: | A list of URLs to retrieve the binary content of the image. - This will return two URLs. The first is the ordinary ComfyUI view image URL that exactly corresponds - to the UI call. The second is the URL that corresponds to sha256 hash of the request body. + The first URL is named by the digest of the prompt and references the image returned by the first + SaveImage URL, allowing you to exactly retrieve the image without re-running the prompt. + + Then, for each SaveImage node, there will be two URLs: the internal URL returned by the worker, and + the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration. Hashing function for web browsers: @@ -466,7 +469,7 @@ paths: items: type: string example: - urls: [ "/api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3", "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output" ] + urls: [ "/api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3", "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ] 204: description: | The prompt was run but did not contain any SaveImage outputs, so nothing will be returned. @@ -781,6 +784,14 @@ components: - class_type - inputs properties: + _meta: + type: object + properties: + title: + type: string + description: | + The title of the node when authored in the workflow. Set only when the end user changed it using the + panel properties in the UI. class_type: type: string description: The node's class type, which maps to a class in NODE_CLASS_MAPPINGS. @@ -794,12 +805,11 @@ components: - type: array description: | When this is specified, it is a node connection, followed by an output. - items: - minItems: 2 - maxItems: 2 - prefixItems: - - type: string - - type: integer + minItems: 2 + maxItems: 2 + prefixItems: + - type: string + - type: integer description: The inputs for the node, which can be scalar values or references to other nodes' outputs. is_changed: oneOf: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 4a829b78f..66d1b9e8a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -135,7 +135,8 @@ parser.add_argument( ) parser.add_argument("--distributed-queue-name", type=str, default="comfyui", help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID") - +parser.add_argument("--external-address", required=False, + help="Specifies a base URL for external addresses reported by the API, such as for image paths.") if options.args_parsing: args, _ = parser.parse_known_args() diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index b67f04873..9e10261f1 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -74,6 +74,7 @@ class Configuration(dict): distributed_queue_frontend (bool): Frontends will start the web UI and connect to the provided AMQP URL to submit prompts. distributed_queue_worker (bool): Workers will pull requests off the AMQP URL. distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID. + external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths. """ def __init__(self, **kwargs): super().__init__() @@ -135,6 +136,7 @@ class Configuration(dict): self.distributed_queue_worker: bool = False self.distributed_queue_frontend: bool = False self.distributed_queue_name: str = "comfyui" + self.external_address: Optional[str] = None for key, value in kwargs.items(): self[key] = value diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index f816baba1..02c8f939d 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -28,6 +28,23 @@ class AsyncRemoteComfyClient: f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}") self.loop = loop or asyncio.get_event_loop() + async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]: + """ + Calls the API to queue a prompt. + :param prompt: + :return: a list of URLs corresponding to the SaveImage nodes in the prompt. + """ + prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) + async with aiohttp.ClientSession() as session: + response: ClientResponse + async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: + + if response.status == 200: + return (await response.json())["urls"] + else: + raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + async def queue_prompt(self, prompt: PromptDict) -> bytes: """ Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node. diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 192843e2f..447878125 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -224,6 +224,8 @@ async def main(): loop = asyncio.get_event_loop() server = server_module.PromptServer(loop) + if args.external_address is not None: + server.external_address = args.external_address if args.distributed_queue_connection_uri is not None: distributed = True q = DistributedPromptQueue( diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 452672917..374c192a4 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -5,7 +5,7 @@ import glob import struct import sys import shutil -from urllib.parse import quote +from urllib.parse import quote, urljoin from pkg_resources import resource_filename from PIL import Image, ImageOps @@ -16,7 +16,7 @@ import json import os import uuid from asyncio import Future, AbstractEventLoop -from typing import List +from typing import List, Optional import aiofiles import aiohttp @@ -92,14 +92,16 @@ class PromptServer(ExecutorToClientProgress): self.messages: asyncio.Queue = asyncio.Queue() self.number: int = 0 self.port: int = 8188 + self._external_address: Optional[str] = None middlewares = [cache_control] if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) max_upload_size = round(args.max_upload_size * 1024 * 1024) - self.app: web.Application = web.Application(client_max_size=max_upload_size, handler_args={'max_field_size': 16380}, - middlewares=middlewares) + self.app: web.Application = web.Application(client_max_size=max_upload_size, + handler_args={'max_field_size': 16380}, + middlewares=middlewares) self.sockets = dict() web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web") if not os.path.exists(web_root_path): @@ -254,7 +256,7 @@ class PromptServer(ExecutorToClientProgress): if os.path.isfile(file): with Image.open(file) as original_pil: metadata = PngInfo() - if hasattr(original_pil,'text'): + if hasattr(original_pil, 'text'): for key in original_pil.text: metadata.add_text(key, original_pil.text[key]) original_pil = original_pil.convert('RGBA') @@ -407,7 +409,7 @@ class PromptServer(ExecutorToClientProgress): info['name'] = node_class info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[ node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' + info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else '' info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: info['output_node'] = True @@ -425,7 +427,8 @@ class PromptServer(ExecutorToClientProgress): try: out[x] = node_info(x) except Exception as e: - print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) + print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", + file=sys.stderr) traceback.print_exc() return web.json_response(out) @@ -489,7 +492,7 @@ class PromptServer(ExecutorToClientProgress): outputs_to_execute = valid[2] self.prompt_queue.put( QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute), - completed=None)) + completed=None)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: @@ -606,6 +609,7 @@ class PromptServer(ExecutorToClientProgress): return web.Response(status=200, headers=digest_headers_, + content_type="application/json", body=json.dumps({'urls': [cache_url]})) elif accept == "image/png": return web.FileResponse(cache_path, @@ -622,7 +626,7 @@ class PromptServer(ExecutorToClientProgress): self.number += 1 self.prompt_queue.put( QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), - completed=completed)) + completed=completed)) try: await completed @@ -654,17 +658,28 @@ class PromptServer(ExecutorToClientProgress): pass shutil.copy(image_, cache_path) filename = os.path.basename(image_) - comfyui_url = f"http://{self.address}:{self.port}/view?filename={filename}&type=output" digest_headers_ = { "Digest": f"SHA-256={content_digest}", - "Location": f"/api/v1/images/{content_digest}", - "Content-Disposition": f"filename=\"{filename}\"" } - if accept == "application/json": + urls_ = [cache_url] + if len(output_images) == 1: + digest_headers_.update({ + "Location": f"/api/v1/images/{content_digest}", + "Content-Disposition": f"filename=\"{filename}\"" + }) + for image_indv_ in output_images: + image_indv_filename_ = os.path.basename(image_indv_) + urls_ += [ + f"http://{self.address}:{self.port}/view?filename={image_indv_filename_}&type=output", + urljoin(self.external_address, f"/view?filename={image_indv_filename_}&type=output") + ] + + if accept == "application/json": return web.Response(status=200, + content_type="application/json", headers=digest_headers_, - body=json.dumps({'urls': [cache_url, comfyui_url]})) + body=json.dumps({'urls': urls_})) elif accept == "image/png": return web.FileResponse(image_, headers=digest_headers_) @@ -682,6 +697,14 @@ class PromptServer(ExecutorToClientProgress): prompt = last_history_item['prompt'][2] return web.json_response(prompt, status=200) + @property + def external_address(self): + return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}" + + @external_address.setter + def external_address(self, value): + self._external_address = value + def add_routes(self): self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index ee8a62db8..3dd4d370b 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -16,7 +16,8 @@ except: custom_nodes: typing.Optional[types.ModuleType] = None from .package_typing import ExportedNodes from functools import reduce -from pkg_resources import resource_filename, iter_entry_points +from pkg_resources import resource_filename +from importlib.metadata import entry_points _comfy_nodes = ExportedNodes() @@ -85,7 +86,7 @@ def import_all_nodes_in_workspace() -> ExportedNodes: custom_nodes_mappings.update(_import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True)) # load from entrypoints - for entry_point in iter_entry_points(group='comfyui.custom_nodes'): + for entry_point in entry_points().select(group='comfyui.custom_nodes'): # Load the module associated with the current entry point module = entry_point.load() diff --git a/tests/conftest.py b/tests/conftest.py index eadb6c16a..a514d6d48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,11 +17,12 @@ def pytest_addoption(parser): def run_server(args_pytest): from comfy.cmd.main import main from comfy.cli_args import args + import asyncio args.output_directory = args_pytest["output_dir"] args.listen = args_pytest["listen"] args.port = args_pytest["port"] print("running server anyway!") - main() + asyncio.run(main()) # This initializes args at the beginning of the test session diff --git a/tests/distributed/test_asyncio_remote_client.py b/tests/distributed/test_asyncio_remote_client.py index 0ea3dbe66..d0aab0650 100644 --- a/tests/distributed/test_asyncio_remote_client.py +++ b/tests/distributed/test_asyncio_remote_client.py @@ -1,3 +1,5 @@ +import random + import pytest from comfy.client.aio_client import AsyncRemoteComfyClient from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner @@ -10,6 +12,7 @@ async def test_completes_prompt(comfy_background_server): png_image_bytes = await client.queue_prompt(prompt) assert len(png_image_bytes) > 1000 + @pytest.mark.asyncio async def test_completes_prompt_with_ui(comfy_background_server): client = AsyncRemoteComfyClient() @@ -17,3 +20,15 @@ async def test_completes_prompt_with_ui(comfy_background_server): result_dict = await client.queue_prompt_ui(prompt) # should contain one output assert len(result_dict) == 1 + + +@pytest.mark.asyncio +async def test_completes_prompt_with_image_urls(comfy_background_server): + client = AsyncRemoteComfyClient() + random_seed = random.randint(1,4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) + result_list = await client.queue_prompt_uris(prompt) + assert len(result_list) == 3 + result_list = await client.queue_prompt_uris(prompt) + # cached + assert len(result_list) == 1