mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
added async tasks prompt api
This commit is contained in:
parent
2995a24725
commit
00a09956b9
17
execution.py
17
execution.py
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
@ -696,10 +697,18 @@ class PromptQueue:
|
|||||||
|
|
||||||
def put(self, item):
|
def put(self, item):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
heapq.heappush(self.queue, item)
|
heapq.heappush(self.queue, (item, None))
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
self.not_empty.notify()
|
self.not_empty.notify()
|
||||||
|
|
||||||
|
def put_async(self, item):
|
||||||
|
with self.mutex:
|
||||||
|
future = asyncio.Future()
|
||||||
|
heapq.heappush(self.queue, (item, future))
|
||||||
|
self.server.queue_updated()
|
||||||
|
self.not_empty.notify()
|
||||||
|
return future
|
||||||
|
|
||||||
def get(self, timeout=None):
|
def get(self, timeout=None):
|
||||||
with self.not_empty:
|
with self.not_empty:
|
||||||
while len(self.queue) == 0:
|
while len(self.queue) == 0:
|
||||||
@ -708,12 +717,12 @@ class PromptQueue:
|
|||||||
return None
|
return None
|
||||||
item = heapq.heappop(self.queue)
|
item = heapq.heappop(self.queue)
|
||||||
i = self.task_counter
|
i = self.task_counter
|
||||||
self.currently_running[i] = copy.deepcopy(item)
|
self.currently_running[i] = copy.deepcopy(item[0])
|
||||||
self.task_counter += 1
|
self.task_counter += 1
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
return (item, i)
|
return (item, i)
|
||||||
|
|
||||||
def task_done(self, item_id, outputs):
|
def task_done(self, item_id, outputs, future = None):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||||
@ -722,6 +731,8 @@ class PromptQueue:
|
|||||||
for o in outputs:
|
for o in outputs:
|
||||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
if future is not None:
|
||||||
|
future.set_result(outputs)
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
|
|||||||
4
main.py
4
main.py
@ -99,12 +99,12 @@ def prompt_worker(q, server):
|
|||||||
|
|
||||||
queue_item = q.get(timeout=timeout)
|
queue_item = q.get(timeout=timeout)
|
||||||
if queue_item is not None:
|
if queue_item is not None:
|
||||||
item, item_id = queue_item
|
(item, future), item_id = queue_item
|
||||||
execution_start_time = time.perf_counter()
|
execution_start_time = time.perf_counter()
|
||||||
prompt_id = item[1]
|
prompt_id = item[1]
|
||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
q.task_done(item_id, e.outputs_ui)
|
q.task_done(item_id, e.outputs_ui, future)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
|
|||||||
42
server.py
42
server.py
@ -488,6 +488,48 @@ class PromptServer():
|
|||||||
else:
|
else:
|
||||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||||
|
|
||||||
|
@routes.post("/prompt_async")
|
||||||
|
async def post_prompt_async(request):
|
||||||
|
print("yeeehaw")
|
||||||
|
resp_code = 200
|
||||||
|
out_string = ""
|
||||||
|
json_data = await request.json()
|
||||||
|
json_data = self.trigger_on_prompt(json_data)
|
||||||
|
|
||||||
|
if "number" in json_data:
|
||||||
|
number = float(json_data['number'])
|
||||||
|
else:
|
||||||
|
number = self.number
|
||||||
|
if "front" in json_data:
|
||||||
|
if json_data['front']:
|
||||||
|
number = -number
|
||||||
|
|
||||||
|
self.number += 1
|
||||||
|
|
||||||
|
if "prompt" in json_data:
|
||||||
|
prompt = json_data["prompt"]
|
||||||
|
valid = execution.validate_prompt(prompt)
|
||||||
|
extra_data = {}
|
||||||
|
if "extra_data" in json_data:
|
||||||
|
extra_data = json_data["extra_data"]
|
||||||
|
|
||||||
|
if "client_id" in json_data:
|
||||||
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
|
if valid[0]:
|
||||||
|
prompt_id = str(uuid.uuid4())
|
||||||
|
outputs_to_execute = valid[2]
|
||||||
|
future = self.prompt_queue.put_async((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||||
|
await future
|
||||||
|
|
||||||
|
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3], "results": future.result()}
|
||||||
|
return web.json_response(response)
|
||||||
|
else:
|
||||||
|
print("invalid prompt:", valid[1])
|
||||||
|
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/queue")
|
@routes.post("/queue")
|
||||||
async def post_queue(request):
|
async def post_queue(request):
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user