From 61bbab42bf41485ae329203948c9ff81734f3469 Mon Sep 17 00:00:00 2001
From: Yao Chi
Date: Wed, 15 Apr 2026 17:32:44 +0800
Subject: [PATCH] multiple workers

---
 execution.py | 26 +++++++++++++++++++++-----
 main.py      | 12 +++++++++---
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/execution.py b/execution.py
index 5e02dffb2..2223d845d 100644
--- a/execution.py
+++ b/execution.py
@@ -1202,7 +1202,7 @@ class PromptQueue:
             self.server.queue_updated()
             self.not_empty.notify()
 
-    def get(self, timeout=None):
+    def get(self, timeout=None, worker_id=0):
         with self.not_empty:
             while len(self.queue) == 0:
                 self.not_empty.wait(timeout=timeout)
@@ -1210,7 +1210,8 @@
                 return None
             item = heapq.heappop(self.queue)
             i = self.task_counter
-            self.currently_running[i] = copy.deepcopy(item)
+            # Store with worker_id to support multiple workers
+            self.currently_running[i] = {"worker_id": worker_id, "item": copy.deepcopy(item)}
             self.task_counter += 1
             self.server.queue_updated()
             return (item, i)
@@ -1223,7 +1224,12 @@
     def task_done(self, item_id, history_result,
                   status: Optional['PromptQueue.ExecutionStatus'],
                   process_item=None):
         with self.mutex:
-            prompt = self.currently_running.pop(item_id)
+            running_entry = self.currently_running.pop(item_id)
+            # Support both old format (direct item) and new format (dict with worker_id and item)
+            if isinstance(running_entry, dict) and "item" in running_entry:
+                prompt = running_entry["item"]
+            else:
+                prompt = running_entry
             if len(self.history) > MAXIMUM_HISTORY_SIZE:
                 self.history.pop(next(iter(self.history)))
@@ -1247,13 +1253,23 @@
         with self.mutex:
             out = []
             for x in self.currently_running.values():
-                out += [x]
+                # Support both old format (direct item) and new format (dict with worker_id and item)
+                if isinstance(x, dict) and "item" in x:
+                    out += [x["item"]]
+                else:
+                    out += [x]
             return (out, copy.deepcopy(self.queue))
 
     # read-safe as long as queue items are immutable
     def get_current_queue_volatile(self):
         with self.mutex:
-            running = [x for x in self.currently_running.values()]
+            running = []
+            for x in self.currently_running.values():
+                # Support both old format (direct item) and new format (dict with worker_id and item)
+                if isinstance(x, dict) and "item" in x:
+                    running.append(x["item"])
+                else:
+                    running.append(x)
             queued = copy.copy(self.queue)
             return (running, queued)
 
diff --git a/main.py b/main.py
index 12b04719d..37930bbe2 100644
--- a/main.py
+++ b/main.py
@@ -273,7 +273,7 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
     return paths
 
 
-def prompt_worker(q, server_instance):
+def prompt_worker(q, server_instance, worker_id=0):
     current_time: float = 0.0
     cache_ram = args.cache_ram
     if cache_ram < 0:
@@ -297,7 +297,7 @@
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
 
-        queue_item = q.get(timeout=timeout)
+        queue_item = q.get(timeout=timeout, worker_id=worker_id)
         if queue_item is not None:
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
@@ -478,7 +478,13 @@ def start_comfyui(asyncio_loop=None):
     prompt_server.add_routes()
     hijack_progress(prompt_server)
 
-    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
+    # Number of parallel workers - can be adjusted based on GPU memory and workload
+    # WARNING: Multiple workers will increase GPU memory usage significantly
+    NUM_WORKERS = int(os.environ.get("COMFYUI_NUM_WORKERS", "1"))
+    if NUM_WORKERS > 1:
+        logging.info(f"Starting {NUM_WORKERS} parallel prompt workers")
+    for worker_id in range(NUM_WORKERS):
+        threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server, worker_id)).start()
 
     if args.quick_test_for_ci:
         exit(0)