From 1a2c4b4cc662f08a5f4f44b34b79c41e55559e5c Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 6 Dec 2024 11:22:09 -0800 Subject: [PATCH] Improve response semantics --- comfy/cmd/server.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 22ef781ee..bec98a372 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -53,6 +53,7 @@ from ..nodes.package_typing import ExportedNodes logger = logging.getLogger(__name__) + class HeuristicPath(NamedTuple): filename_heuristic: str abs_path: str @@ -740,6 +741,8 @@ class PromptServer(ExecutorToClientProgress): @routes.post("/api/v1/prompts") async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: accept = request.headers.get("accept", "application/json") + if accept == '*/*': + accept = "application/json" content_type = request.headers.get("content-type", "application/json") preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type if "+" in content_type: @@ -747,6 +750,9 @@ class PromptServer(ExecutorToClientProgress): wait = not "respond-async" in preferences + if accept not in ("application/json", "image/png"): + return web.json_response(status=400, reason=f"invalid accept content type, expected application/json or image/png, got {accept}") + # check if the queue is too long queue_size = self.prompt_queue.size() queue_too_busy_size = PromptServer.get_too_busy_queue_size() @@ -859,7 +865,10 @@ class PromptServer(ExecutorToClientProgress): external_address = self.external_address for base in (local_address, external_address): - url: URL = urlparse(urljoin(base, "view")) + try: + url: URL = urlparse(urljoin(base, "view")) + except ValueError: + continue url_search_dict: FileOutput = dict(image_indv_) del url_search_dict["abs_path"] if "name" in url_search_dict: @@ -877,9 +886,12 @@ class PromptServer(ExecutorToClientProgress): 'urls': urls_, 'outputs': result.outputs })) - elif accept == "image/png": + elif accept == "image/png" or accept == "image/jpeg": return web.FileResponse(main_image["abs_path"], headers=digest_headers_) + else: + return web.Response(status=500, + reason="unreachable") else: return web.Response(status=204) @@ -1028,21 +1040,28 @@ class PromptServer(ExecutorToClientProgress): async def start_multi_address(self, addresses, call_on_start=None, verbose=True): address_print = "localhost" - port = 8188 - address: str = None + address: str = "127.0.0.1" + port: int = 8188 runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900) await runner.setup() - for addr in addresses: - address = addr[0] - port = addr[1] + + def is_ipv4(address: str, *args): + try: + parsed = ipaddress.ip_address(address) + return isinstance(parsed, ipaddress.IPv4Address) + except: + return False + + addresses = sorted(addresses, key=lambda tuple: is_ipv4(*tuple)) + for (address, port) in addresses: site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size()) await site.start() - if not hasattr(self, 'address'): - self.address = address # TODO: remove this - self.port = port + # preference for the ipv4 address achieved by sorting + self.address = "localhost" if address == "0.0.0.0" else address + self.port = port - if address == '::': + if address == '::' or address == "127.0.0.1" or address == "0.0.0.0": address_print = "localhost" elif ':' in address: address_print = "[{}]".format(address)