multiple workers

This commit is contained in:
Yao Chi 2026-04-15 17:32:44 +08:00
parent 8f374716ee
commit 61bbab42bf
2 changed files with 30 additions and 8 deletions

View File

@@ -1202,7 +1202,7 @@ class PromptQueue:
self.server.queue_updated()
self.not_empty.notify()
def get(self, timeout=None):
def get(self, timeout=None, worker_id=0):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout)
@@ -1210,7 +1210,8 @@ class PromptQueue:
return None
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
# Store with worker_id to support multiple workers
self.currently_running[i] = {"worker_id": worker_id, "item": copy.deepcopy(item)}
self.task_counter += 1
self.server.queue_updated()
return (item, i)
@@ -1223,7 +1224,12 @@ class PromptQueue:
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
with self.mutex:
prompt = self.currently_running.pop(item_id)
running_entry = self.currently_running.pop(item_id)
# Support both old format (direct item) and new format (dict with worker_id and item)
if isinstance(running_entry, dict) and "item" in running_entry:
prompt = running_entry["item"]
else:
prompt = running_entry
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
@@ -1247,13 +1253,23 @@ class PromptQueue:
with self.mutex:
out = []
for x in self.currently_running.values():
out += [x]
# Support both old format (direct item) and new format (dict with worker_id and item)
if isinstance(x, dict) and "item" in x:
out += [x["item"]]
else:
out += [x]
return (out, copy.deepcopy(self.queue))
# read-safe as long as queue items are immutable
def get_current_queue_volatile(self):
with self.mutex:
running = [x for x in self.currently_running.values()]
running = []
for x in self.currently_running.values():
# Support both old format (direct item) and new format (dict with worker_id and item)
if isinstance(x, dict) and "item" in x:
running.append(x["item"])
else:
running.append(x)
queued = copy.copy(self.queue)
return (running, queued)

12
main.py
View File

@@ -273,7 +273,7 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
return paths
def prompt_worker(q, server_instance):
def prompt_worker(q, server_instance, worker_id=0):
current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
@@ -297,7 +297,7 @@ def prompt_worker(q, server_instance):
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
queue_item = q.get(timeout=timeout)
queue_item = q.get(timeout=timeout, worker_id=worker_id)
if queue_item is not None:
item, item_id = queue_item
execution_start_time = time.perf_counter()
@@ -478,7 +478,13 @@ def start_comfyui(asyncio_loop=None):
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
# Number of parallel workers - can be adjusted based on GPU memory and workload
# WARNING: Multiple workers will increase GPU memory usage significantly
NUM_WORKERS = int(os.environ.get("COMFYUI_NUM_WORKERS", "1"))
if NUM_WORKERS > 1:
logging.info(f"Starting {NUM_WORKERS} parallel prompt workers")
for worker_id in range(NUM_WORKERS):
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server, worker_id)).start()
if args.quick_test_for_ci:
exit(0)