mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Improve response semantics
This commit is contained in:
parent
4f085c4d58
commit
1a2c4b4cc6
@ -53,6 +53,7 @@ from ..nodes.package_typing import ExportedNodes
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HeuristicPath(NamedTuple):
|
class HeuristicPath(NamedTuple):
|
||||||
filename_heuristic: str
|
filename_heuristic: str
|
||||||
abs_path: str
|
abs_path: str
|
||||||
@ -740,6 +741,8 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
@routes.post("/api/v1/prompts")
|
@routes.post("/api/v1/prompts")
|
||||||
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||||
accept = request.headers.get("accept", "application/json")
|
accept = request.headers.get("accept", "application/json")
|
||||||
|
if accept == '*/*':
|
||||||
|
accept = "application/json"
|
||||||
content_type = request.headers.get("content-type", "application/json")
|
content_type = request.headers.get("content-type", "application/json")
|
||||||
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
|
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
|
||||||
if "+" in content_type:
|
if "+" in content_type:
|
||||||
@ -747,6 +750,9 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
|
|
||||||
wait = not "respond-async" in preferences
|
wait = not "respond-async" in preferences
|
||||||
|
|
||||||
|
if accept not in ("application/json", "image/png"):
|
||||||
|
return web.json_response(status=400, reason=f"invalid accept content type, expected application/json or image/png, got {accept}")
|
||||||
|
|
||||||
# check if the queue is too long
|
# check if the queue is too long
|
||||||
queue_size = self.prompt_queue.size()
|
queue_size = self.prompt_queue.size()
|
||||||
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
|
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
|
||||||
@ -859,7 +865,10 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
external_address = self.external_address
|
external_address = self.external_address
|
||||||
|
|
||||||
for base in (local_address, external_address):
|
for base in (local_address, external_address):
|
||||||
url: URL = urlparse(urljoin(base, "view"))
|
try:
|
||||||
|
url: URL = urlparse(urljoin(base, "view"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
url_search_dict: FileOutput = dict(image_indv_)
|
url_search_dict: FileOutput = dict(image_indv_)
|
||||||
del url_search_dict["abs_path"]
|
del url_search_dict["abs_path"]
|
||||||
if "name" in url_search_dict:
|
if "name" in url_search_dict:
|
||||||
@ -877,9 +886,12 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
'urls': urls_,
|
'urls': urls_,
|
||||||
'outputs': result.outputs
|
'outputs': result.outputs
|
||||||
}))
|
}))
|
||||||
elif accept == "image/png":
|
elif accept == "image/png" or accept == "image/jpeg":
|
||||||
return web.FileResponse(main_image["abs_path"],
|
return web.FileResponse(main_image["abs_path"],
|
||||||
headers=digest_headers_)
|
headers=digest_headers_)
|
||||||
|
else:
|
||||||
|
return web.Response(status=500,
|
||||||
|
reason="unreachable")
|
||||||
else:
|
else:
|
||||||
return web.Response(status=204)
|
return web.Response(status=204)
|
||||||
|
|
||||||
@ -1028,21 +1040,28 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
|
|
||||||
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
|
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
|
||||||
address_print = "localhost"
|
address_print = "localhost"
|
||||||
port = 8188
|
address: str = "127.0.0.1"
|
||||||
address: str = None
|
port: int = 8188
|
||||||
runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900)
|
runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
for addr in addresses:
|
|
||||||
address = addr[0]
|
def is_ipv4(address: str, *args):
|
||||||
port = addr[1]
|
try:
|
||||||
|
parsed = ipaddress.ip_address(address)
|
||||||
|
return isinstance(parsed, ipaddress.IPv4Address)
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
addresses = sorted(addresses, key=lambda tuple: is_ipv4(*tuple))
|
||||||
|
for (address, port) in addresses:
|
||||||
site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
|
site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
|
||||||
await site.start()
|
await site.start()
|
||||||
|
|
||||||
if not hasattr(self, 'address'):
|
# preference for the ipv4 address achieved by sorting
|
||||||
self.address = address # TODO: remove this
|
self.address = "localhost" if address == "0.0.0.0" else address
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
if address == '::':
|
if address == '::' or address == "127.0.0.1" or address == "0.0.0.0":
|
||||||
address_print = "localhost"
|
address_print = "localhost"
|
||||||
elif ':' in address:
|
elif ':' in address:
|
||||||
address_print = "[{}]".format(address)
|
address_print = "[{}]".format(address)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user