Improve response semantics

This commit is contained in:
doctorpangloss 2024-12-06 11:22:09 -08:00
parent 4f085c4d58
commit 1a2c4b4cc6

View File

@ -53,6 +53,7 @@ from ..nodes.package_typing import ExportedNodes
logger = logging.getLogger(__name__)
class HeuristicPath(NamedTuple):
filename_heuristic: str
abs_path: str
@ -740,6 +741,8 @@ class PromptServer(ExecutorToClientProgress):
@routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
accept = request.headers.get("accept", "application/json")
if accept == '*/*':
accept = "application/json"
content_type = request.headers.get("content-type", "application/json")
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
if "+" in content_type:
@ -747,6 +750,9 @@ class PromptServer(ExecutorToClientProgress):
wait = not "respond-async" in preferences
if accept not in ("application/json", "image/png"):
return web.json_response(status=400, reason=f"invalid accept content type, expected application/json or image/png, got {accept}")
# check if the queue is too long
queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
@ -859,7 +865,10 @@ class PromptServer(ExecutorToClientProgress):
external_address = self.external_address
for base in (local_address, external_address):
url: URL = urlparse(urljoin(base, "view"))
try:
url: URL = urlparse(urljoin(base, "view"))
except ValueError:
continue
url_search_dict: FileOutput = dict(image_indv_)
del url_search_dict["abs_path"]
if "name" in url_search_dict:
@ -877,9 +886,12 @@ class PromptServer(ExecutorToClientProgress):
'urls': urls_,
'outputs': result.outputs
}))
elif accept == "image/png":
elif accept == "image/png" or accept == "image/jpeg":
return web.FileResponse(main_image["abs_path"],
headers=digest_headers_)
else:
return web.Response(status=500,
reason="unreachable")
else:
return web.Response(status=204)
@ -1028,21 +1040,28 @@ class PromptServer(ExecutorToClientProgress):
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
address_print = "localhost"
port = 8188
address: str = None
address: str = "127.0.0.1"
port: int = 8188
runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900)
await runner.setup()
for addr in addresses:
address = addr[0]
port = addr[1]
def is_ipv4(address: str, *args):
try:
parsed = ipaddress.ip_address(address)
return isinstance(parsed, ipaddress.IPv4Address)
except:
return False
addresses = sorted(addresses, key=lambda tuple: is_ipv4(*tuple))
for (address, port) in addresses:
site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
await site.start()
if not hasattr(self, 'address'):
self.address = address # TODO: remove this
self.port = port
# preference for the ipv4 address achieved by sorting
self.address = "localhost" if address == "0.0.0.0" else address
self.port = port
if address == '::':
if address == '::' or address == "127.0.0.1" or address == "0.0.0.0":
address_print = "localhost"
elif ':' in address:
address_print = "[{}]".format(address)