mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-15 16:02:32 +08:00
added async tasks prompt api
This commit is contained in:
parent
2995a24725
commit
00a09956b9
17
execution.py
17
execution.py
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
@ -696,10 +697,18 @@ class PromptQueue:
|
||||
|
||||
def put(self, item):
|
||||
with self.mutex:
|
||||
heapq.heappush(self.queue, item)
|
||||
heapq.heappush(self.queue, (item, None))
|
||||
self.server.queue_updated()
|
||||
self.not_empty.notify()
|
||||
|
||||
def put_async(self, item):
|
||||
with self.mutex:
|
||||
future = asyncio.Future()
|
||||
heapq.heappush(self.queue, (item, future))
|
||||
self.server.queue_updated()
|
||||
self.not_empty.notify()
|
||||
return future
|
||||
|
||||
def get(self, timeout=None):
|
||||
with self.not_empty:
|
||||
while len(self.queue) == 0:
|
||||
@ -708,12 +717,12 @@ class PromptQueue:
|
||||
return None
|
||||
item = heapq.heappop(self.queue)
|
||||
i = self.task_counter
|
||||
self.currently_running[i] = copy.deepcopy(item)
|
||||
self.currently_running[i] = copy.deepcopy(item[0])
|
||||
self.task_counter += 1
|
||||
self.server.queue_updated()
|
||||
return (item, i)
|
||||
|
||||
def task_done(self, item_id, outputs):
|
||||
def task_done(self, item_id, outputs, future = None):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
@ -722,6 +731,8 @@ class PromptQueue:
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.server.queue_updated()
|
||||
if future is not None:
|
||||
future.set_result(outputs)
|
||||
|
||||
def get_current_queue(self):
|
||||
with self.mutex:
|
||||
|
||||
4
main.py
4
main.py
@ -99,12 +99,12 @@ def prompt_worker(q, server):
|
||||
|
||||
queue_item = q.get(timeout=timeout)
|
||||
if queue_item is not None:
|
||||
item, item_id = queue_item
|
||||
(item, future), item_id = queue_item
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
q.task_done(item_id, e.outputs_ui, future)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
|
||||
42
server.py
42
server.py
@ -488,6 +488,48 @@ class PromptServer():
|
||||
else:
|
||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||
|
||||
@routes.post("/prompt_async")
|
||||
async def post_prompt_async(request):
|
||||
print("yeeehaw")
|
||||
resp_code = 200
|
||||
out_string = ""
|
||||
json_data = await request.json()
|
||||
json_data = self.trigger_on_prompt(json_data)
|
||||
|
||||
if "number" in json_data:
|
||||
number = float(json_data['number'])
|
||||
else:
|
||||
number = self.number
|
||||
if "front" in json_data:
|
||||
if json_data['front']:
|
||||
number = -number
|
||||
|
||||
self.number += 1
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
valid = execution.validate_prompt(prompt)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
extra_data = json_data["extra_data"]
|
||||
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
future = self.prompt_queue.put_async((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
await future
|
||||
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3], "results": future.result()}
|
||||
return web.json_response(response)
|
||||
else:
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
else:
|
||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||
|
||||
|
||||
@routes.post("/queue")
|
||||
async def post_queue(request):
|
||||
json_data = await request.json()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user