added async tasks prompt api

This commit is contained in:
user 2023-12-03 17:11:20 -05:00
parent 2995a24725
commit 00a09956b9
3 changed files with 58 additions and 5 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import os import os
import sys import sys
import copy import copy
@ -696,10 +697,18 @@ class PromptQueue:
def put(self, item): def put(self, item):
with self.mutex: with self.mutex:
heapq.heappush(self.queue, item) heapq.heappush(self.queue, (item, None))
self.server.queue_updated() self.server.queue_updated()
self.not_empty.notify() self.not_empty.notify()
def put_async(self, item):
with self.mutex:
future = asyncio.Future()
heapq.heappush(self.queue, (item, future))
self.server.queue_updated()
self.not_empty.notify()
return future
def get(self, timeout=None): def get(self, timeout=None):
with self.not_empty: with self.not_empty:
while len(self.queue) == 0: while len(self.queue) == 0:
@ -708,12 +717,12 @@ class PromptQueue:
return None return None
item = heapq.heappop(self.queue) item = heapq.heappop(self.queue)
i = self.task_counter i = self.task_counter
self.currently_running[i] = copy.deepcopy(item) self.currently_running[i] = copy.deepcopy(item[0])
self.task_counter += 1 self.task_counter += 1
self.server.queue_updated() self.server.queue_updated()
return (item, i) return (item, i)
def task_done(self, item_id, outputs): def task_done(self, item_id, outputs, future = None):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
@ -722,6 +731,8 @@ class PromptQueue:
for o in outputs: for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o] self.history[prompt[1]]["outputs"][o] = outputs[o]
self.server.queue_updated() self.server.queue_updated()
if future is not None:
future.set_result(outputs)
def get_current_queue(self): def get_current_queue(self):
with self.mutex: with self.mutex:

View File

@ -99,12 +99,12 @@ def prompt_worker(q, server):
queue_item = q.get(timeout=timeout) queue_item = q.get(timeout=timeout)
if queue_item is not None: if queue_item is not None:
item, item_id = queue_item (item, future), item_id = queue_item
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
prompt_id = item[1] prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id, e.outputs_ui, future)
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

View File

@ -488,6 +488,48 @@ class PromptServer():
else: else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400) return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@routes.post("/prompt_async")
async def post_prompt_async(request):
print("yeeehaw")
resp_code = 200
out_string = ""
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)
if "number" in json_data:
number = float(json_data['number'])
else:
number = self.number
if "front" in json_data:
if json_data['front']:
number = -number
self.number += 1
if "prompt" in json_data:
prompt = json_data["prompt"]
valid = execution.validate_prompt(prompt)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
future = self.prompt_queue.put_async((number, prompt_id, prompt, extra_data, outputs_to_execute))
await future
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3], "results": future.result()}
return web.json_response(response)
else:
print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@routes.post("/queue") @routes.post("/queue")
async def post_queue(request): async def post_queue(request):
json_data = await request.json() json_data = await request.json()