Add external address parameter

This commit is contained in:
doctorpangloss 2024-02-15 17:39:15 -08:00
parent 7c6b8ecb02
commit 06e74226df
9 changed files with 101 additions and 29 deletions

View File

@ -378,7 +378,7 @@ paths:
200:
headers:
Location:
description: The URL to the file based on a hash of the request body.
description: The URL to the file based on a hash of the request body when exactly one SaveImage node is specified.
example: /api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3
schema:
type: string
@ -388,7 +388,7 @@ paths:
schema:
type: string
Content-Disposition:
description: The filename when a SaveImage node is specified.
description: The filename when exactly one SaveImage node is specified.
example: filename=ComfyUI_00001.png
schema:
type: string
@ -407,8 +407,11 @@ paths:
description: |
A list of URLs to retrieve the binary content of the image.
This will return two URLs. The first is the ordinary ComfyUI view image URL that exactly corresponds
to the UI call. The second is the URL that corresponds to sha256 hash of the request body.
The first URL is named by the digest of the prompt and references the image returned by the first
SaveImage URL, allowing you to exactly retrieve the image without re-running the prompt.
Then, for each SaveImage node, there will be two URLs: the internal URL returned by the worker, and
the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration.
Hashing function for web browsers:
@ -466,7 +469,7 @@ paths:
items:
type: string
example:
urls: [ "/api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3", "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output" ]
urls: [ "/api/v1/images/e5187160a7b2c496773c1c5a45bfd3ffbf25eaa5969328e6469d36f31cf240a3", "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&type=output", "https://comfyui.example.com/view?filename=ComfyUI_00001_.png&type=output" ]
204:
description: |
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
@ -781,6 +784,14 @@ components:
- class_type
- inputs
properties:
_meta:
type: object
properties:
title:
type: string
description: |
The title of the node when authored in the workflow. Set only when the end user changed it using the
panel properties in the UI.
class_type:
type: string
description: The node's class type, which maps to a class in NODE_CLASS_MAPPINGS.
@ -794,12 +805,11 @@ components:
- type: array
description: |
When this is specified, it is a node connection, followed by an output.
items:
minItems: 2
maxItems: 2
prefixItems:
- type: string
- type: integer
minItems: 2
maxItems: 2
prefixItems:
- type: string
- type: integer
description: The inputs for the node, which can be scalar values or references to other nodes' outputs.
is_changed:
oneOf:

View File

@ -135,7 +135,8 @@ parser.add_argument(
)
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
parser.add_argument("--external-address", required=False,
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
if options.args_parsing:
args, _ = parser.parse_known_args()

View File

@ -74,6 +74,7 @@ class Configuration(dict):
distributed_queue_frontend (bool): Frontends will start the web UI and connect to the provided AMQP URL to submit prompts.
distributed_queue_worker (bool): Workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
"""
def __init__(self, **kwargs):
super().__init__()
@ -135,6 +136,7 @@ class Configuration(dict):
self.distributed_queue_worker: bool = False
self.distributed_queue_frontend: bool = False
self.distributed_queue_name: str = "comfyui"
self.external_address: Optional[str] = None
for key, value in kwargs.items():
self[key] = value

View File

@ -28,6 +28,23 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop()
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
"""
Calls the API to queue a prompt.
:param prompt:
:return: a list of URLs corresponding to the SaveImage nodes in the prompt.
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
async with aiohttp.ClientSession() as session:
response: ClientResponse
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
if response.status == 200:
return (await response.json())["urls"]
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt(self, prompt: PromptDict) -> bytes:
"""
Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.

View File

@ -224,6 +224,8 @@ async def main():
loop = asyncio.get_event_loop()
server = server_module.PromptServer(loop)
if args.external_address is not None:
server.external_address = args.external_address
if args.distributed_queue_connection_uri is not None:
distributed = True
q = DistributedPromptQueue(

View File

@ -5,7 +5,7 @@ import glob
import struct
import sys
import shutil
from urllib.parse import quote
from urllib.parse import quote, urljoin
from pkg_resources import resource_filename
from PIL import Image, ImageOps
@ -16,7 +16,7 @@ import json
import os
import uuid
from asyncio import Future, AbstractEventLoop
from typing import List
from typing import List, Optional
import aiofiles
import aiohttp
@ -92,14 +92,16 @@ class PromptServer(ExecutorToClientProgress):
self.messages: asyncio.Queue = asyncio.Queue()
self.number: int = 0
self.port: int = 8188
self._external_address: Optional[str] = None
middlewares = [cache_control]
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app: web.Application = web.Application(client_max_size=max_upload_size, handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.app: web.Application = web.Application(client_max_size=max_upload_size,
handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.sockets = dict()
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
if not os.path.exists(web_root_path):
@ -254,7 +256,7 @@ class PromptServer(ExecutorToClientProgress):
if os.path.isfile(file):
with Image.open(file) as original_pil:
metadata = PngInfo()
if hasattr(original_pil,'text'):
if hasattr(original_pil, 'text'):
for key in original_pil.text:
metadata.add_text(key, original_pil.text[key])
original_pil = original_pil.convert('RGBA')
@ -407,7 +409,7 @@ class PromptServer(ExecutorToClientProgress):
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[
node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
@ -425,7 +427,8 @@ class PromptServer(ExecutorToClientProgress):
try:
out[x] = node_info(x)
except Exception as e:
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr)
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.",
file=sys.stderr)
traceback.print_exc()
return web.json_response(out)
@ -489,7 +492,7 @@ class PromptServer(ExecutorToClientProgress):
outputs_to_execute = valid[2]
self.prompt_queue.put(
QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute),
completed=None))
completed=None))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
@ -606,6 +609,7 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200,
headers=digest_headers_,
content_type="application/json",
body=json.dumps({'urls': [cache_url]}))
elif accept == "image/png":
return web.FileResponse(cache_path,
@ -622,7 +626,7 @@ class PromptServer(ExecutorToClientProgress):
self.number += 1
self.prompt_queue.put(
QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
completed=completed))
completed=completed))
try:
await completed
@ -654,17 +658,28 @@ class PromptServer(ExecutorToClientProgress):
pass
shutil.copy(image_, cache_path)
filename = os.path.basename(image_)
comfyui_url = f"http://{self.address}:{self.port}/view?filename={filename}&type=output"
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
"Location": f"/api/v1/images/{content_digest}",
"Content-Disposition": f"filename=\"{filename}\""
}
if accept == "application/json":
urls_ = [cache_url]
if len(output_images) == 1:
digest_headers_.update({
"Location": f"/api/v1/images/{content_digest}",
"Content-Disposition": f"filename=\"{filename}\""
})
for image_indv_ in output_images:
image_indv_filename_ = os.path.basename(image_indv_)
urls_ += [
f"http://{self.address}:{self.port}/view?filename={image_indv_filename_}&type=output",
urljoin(self.external_address, f"/view?filename={image_indv_filename_}&type=output")
]
if accept == "application/json":
return web.Response(status=200,
content_type="application/json",
headers=digest_headers_,
body=json.dumps({'urls': [cache_url, comfyui_url]}))
body=json.dumps({'urls': urls_}))
elif accept == "image/png":
return web.FileResponse(image_,
headers=digest_headers_)
@ -682,6 +697,14 @@ class PromptServer(ExecutorToClientProgress):
prompt = last_history_item['prompt'][2]
return web.json_response(prompt, status=200)
@property
def external_address(self):
return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}"
@external_address.setter
def external_address(self, value):
self._external_address = value
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes)

View File

@ -16,7 +16,8 @@ except:
custom_nodes: typing.Optional[types.ModuleType] = None
from .package_typing import ExportedNodes
from functools import reduce
from pkg_resources import resource_filename, iter_entry_points
from pkg_resources import resource_filename
from importlib.metadata import entry_points
_comfy_nodes = ExportedNodes()
@ -85,7 +86,7 @@ def import_all_nodes_in_workspace() -> ExportedNodes:
custom_nodes_mappings.update(_import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True))
# load from entrypoints
for entry_point in iter_entry_points(group='comfyui.custom_nodes'):
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point
module = entry_point.load()

View File

@ -17,11 +17,12 @@ def pytest_addoption(parser):
def run_server(args_pytest):
from comfy.cmd.main import main
from comfy.cli_args import args
import asyncio
args.output_directory = args_pytest["output_dir"]
args.listen = args_pytest["listen"]
args.port = args_pytest["port"]
print("running server anyway!")
main()
asyncio.run(main())
# This initializes args at the beginning of the test session

View File

@ -1,3 +1,5 @@
import random
import pytest
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@ -10,6 +12,7 @@ async def test_completes_prompt(comfy_background_server):
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
@pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient()
@ -17,3 +20,15 @@ async def test_completes_prompt_with_ui(comfy_background_server):
result_dict = await client.queue_prompt_ui(prompt)
# should contain one output
assert len(result_dict) == 1
@pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1,4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_list = await client.queue_prompt_uris(prompt)
assert len(result_list) == 3
result_list = await client.queue_prompt_uris(prompt)
# cached
assert len(result_list) == 1