mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 14:22:35 +08:00
multiple workers
This commit is contained in:
parent
8f374716ee
commit
61bbab42bf
26
execution.py
26
execution.py
@ -1202,7 +1202,7 @@ class PromptQueue:
|
|||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
self.not_empty.notify()
|
self.not_empty.notify()
|
||||||
|
|
||||||
def get(self, timeout=None):
|
def get(self, timeout=None, worker_id=0):
|
||||||
with self.not_empty:
|
with self.not_empty:
|
||||||
while len(self.queue) == 0:
|
while len(self.queue) == 0:
|
||||||
self.not_empty.wait(timeout=timeout)
|
self.not_empty.wait(timeout=timeout)
|
||||||
@ -1210,7 +1210,8 @@ class PromptQueue:
|
|||||||
return None
|
return None
|
||||||
item = heapq.heappop(self.queue)
|
item = heapq.heappop(self.queue)
|
||||||
i = self.task_counter
|
i = self.task_counter
|
||||||
self.currently_running[i] = copy.deepcopy(item)
|
# Store with worker_id to support multiple workers
|
||||||
|
self.currently_running[i] = {"worker_id": worker_id, "item": copy.deepcopy(item)}
|
||||||
self.task_counter += 1
|
self.task_counter += 1
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
return (item, i)
|
return (item, i)
|
||||||
@ -1223,7 +1224,12 @@ class PromptQueue:
|
|||||||
def task_done(self, item_id, history_result,
|
def task_done(self, item_id, history_result,
|
||||||
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
|
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
prompt = self.currently_running.pop(item_id)
|
running_entry = self.currently_running.pop(item_id)
|
||||||
|
# Support both old format (direct item) and new format (dict with worker_id and item)
|
||||||
|
if isinstance(running_entry, dict) and "item" in running_entry:
|
||||||
|
prompt = running_entry["item"]
|
||||||
|
else:
|
||||||
|
prompt = running_entry
|
||||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||||
self.history.pop(next(iter(self.history)))
|
self.history.pop(next(iter(self.history)))
|
||||||
|
|
||||||
@ -1247,13 +1253,23 @@ class PromptQueue:
|
|||||||
with self.mutex:
|
with self.mutex:
|
||||||
out = []
|
out = []
|
||||||
for x in self.currently_running.values():
|
for x in self.currently_running.values():
|
||||||
out += [x]
|
# Support both old format (direct item) and new format (dict with worker_id and item)
|
||||||
|
if isinstance(x, dict) and "item" in x:
|
||||||
|
out += [x["item"]]
|
||||||
|
else:
|
||||||
|
out += [x]
|
||||||
return (out, copy.deepcopy(self.queue))
|
return (out, copy.deepcopy(self.queue))
|
||||||
|
|
||||||
# read-safe as long as queue items are immutable
|
# read-safe as long as queue items are immutable
|
||||||
def get_current_queue_volatile(self):
|
def get_current_queue_volatile(self):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
running = [x for x in self.currently_running.values()]
|
running = []
|
||||||
|
for x in self.currently_running.values():
|
||||||
|
# Support both old format (direct item) and new format (dict with worker_id and item)
|
||||||
|
if isinstance(x, dict) and "item" in x:
|
||||||
|
running.append(x["item"])
|
||||||
|
else:
|
||||||
|
running.append(x)
|
||||||
queued = copy.copy(self.queue)
|
queued = copy.copy(self.queue)
|
||||||
return (running, queued)
|
return (running, queued)
|
||||||
|
|
||||||
|
|||||||
12
main.py
12
main.py
@ -273,7 +273,7 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
|||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance, worker_id=0):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_ram = args.cache_ram
|
cache_ram = args.cache_ram
|
||||||
if cache_ram < 0:
|
if cache_ram < 0:
|
||||||
@ -297,7 +297,7 @@ def prompt_worker(q, server_instance):
|
|||||||
if need_gc:
|
if need_gc:
|
||||||
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
|
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
|
||||||
|
|
||||||
queue_item = q.get(timeout=timeout)
|
queue_item = q.get(timeout=timeout, worker_id=worker_id)
|
||||||
if queue_item is not None:
|
if queue_item is not None:
|
||||||
item, item_id = queue_item
|
item, item_id = queue_item
|
||||||
execution_start_time = time.perf_counter()
|
execution_start_time = time.perf_counter()
|
||||||
@ -478,7 +478,13 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
prompt_server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(prompt_server)
|
hijack_progress(prompt_server)
|
||||||
|
|
||||||
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
|
# Number of parallel workers - can be adjusted based on GPU memory and workload
|
||||||
|
# WARNING: Multiple workers will increase GPU memory usage significantly
|
||||||
|
NUM_WORKERS = int(os.environ.get("COMFYUI_NUM_WORKERS", "1"))
|
||||||
|
if NUM_WORKERS > 1:
|
||||||
|
logging.info(f"Starting {NUM_WORKERS} parallel prompt workers")
|
||||||
|
for worker_id in range(NUM_WORKERS):
|
||||||
|
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server, worker_id)).start()
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user