mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
fix(server): serialize PromptMetadataStore access with a lock
Addresses comfyanonymous's review nit on PR #13905. The store is touched from three thread contexts — the aiohttp event loop (``register`` via ``post_prompt``), the worker thread (``unregister`` via the ``prompt_worker`` try/finally and ``execution_error`` paths), and any thread that fires ``send_sync`` (``inject``). Individual ``dict`` operations are GIL-atomic, but ``register``'s ``len -> pop -> setitem`` and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose interleaving without a lock is racy. A single ``threading.Lock`` keeps the FIFO cap honest and fetches the envelope under the lock, so the spread runs on a stable dict reference rather than racing with a concurrent map mutation. Adds a stress test that runs concurrent register/unregister/inject for 100 ms across five threads and asserts that no exception escapes and that the capacity bound holds once the threads stop.
This commit is contained in:
parent
fc9820ebb9
commit
63784baed5
@ -15,6 +15,7 @@ instance and the helpers can be unit-tested without the rest of the app.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
@ -177,28 +178,49 @@ class PromptMetadataStore:
|
||||
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
||||
backstop: if any cleanup hook is ever skipped, the store sheds the
|
||||
oldest entry instead of growing without bound.
|
||||
|
||||
Access is serialized through a ``threading.Lock``. ``register`` runs
|
||||
on the aiohttp event-loop thread, ``unregister`` runs on the
|
||||
``prompt_worker`` thread, and ``inject`` runs on whichever thread
|
||||
fires ``send_sync`` (event loop, worker, asset seeder). Individual
|
||||
``dict`` ops are GIL-atomic, but ``register``'s
|
||||
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
|
||||
are multi-step compounds whose interleaving without a lock is
|
||||
racy. The lock is uncontended in steady state (sub-microsecond
|
||||
critical sections) so the cost is negligible.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
||||
self._envelopes: dict[str, dict] = {}
|
||||
self._capacity = capacity
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, prompt_id: str, extra_data: Any) -> None:
|
||||
envelope = extract_envelope_from_extra_data(extra_data)
|
||||
if envelope is None:
|
||||
return
|
||||
if len(self._envelopes) >= self._capacity:
|
||||
self._envelopes.pop(next(iter(self._envelopes)))
|
||||
self._envelopes[prompt_id] = envelope
|
||||
with self._lock:
|
||||
if len(self._envelopes) >= self._capacity:
|
||||
self._envelopes.pop(next(iter(self._envelopes)))
|
||||
self._envelopes[prompt_id] = envelope
|
||||
|
||||
def unregister(self, prompt_id: str) -> None:
|
||||
self._envelopes.pop(prompt_id, None)
|
||||
with self._lock:
|
||||
self._envelopes.pop(prompt_id, None)
|
||||
|
||||
def inject(self, data: Any) -> Any:
|
||||
return inject_envelope(data, self._envelopes.get)
|
||||
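# Fetch the envelope reference under the lock so the spread in
|
||||
# ``inject_envelope`` runs on a stable dict object even if a
|
||||
# concurrent ``register``/``unregister`` is mutating the map.
|
||||
# No copy is made: this presumes stored envelopes are never mutated
|
||||
# in place (``register`` only replaces whole entries).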
|
||||
def locked_lookup(prompt_id: str) -> Optional[dict]:
|
||||
with self._lock:
|
||||
return self._envelopes.get(prompt_id)
|
||||
return inject_envelope(data, locked_lookup)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._envelopes)
|
||||
with self._lock:
|
||||
return len(self._envelopes)
|
||||
|
||||
def __contains__(self, prompt_id: str) -> bool:
|
||||
return prompt_id in self._envelopes
|
||||
with self._lock:
|
||||
return prompt_id in self._envelopes
|
||||
|
||||
@ -348,3 +348,65 @@ class TestPromptMetadataStore:
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
})
|
||||
|
||||
def test_concurrent_access_does_not_corrupt_or_raise(self):
|
||||
"""Smoke test for the store's lock. ``register`` is called from
|
||||
the aiohttp event-loop thread, ``unregister`` from the worker
|
||||
thread, and ``inject`` fires on every ``send_sync`` from
|
||||
whichever thread emits the event. Run all three concurrently
|
||||
and assert no exception escapes and the store stays internally
|
||||
consistent (the FIFO cap is never exceeded)."""
|
||||
import threading
|
||||
|
||||
store = PromptMetadataStore(capacity=64)
|
||||
stop = threading.Event()
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def registrar():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.register(
|
||||
f"p{i % 100}",
|
||||
{"metadata": {"workflow_id": f"wf-{i}"}},
|
||||
)
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def canceller():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.unregister(f"p{i % 100}")
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def injector():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=canceller),
|
||||
threading.Thread(target=injector),
|
||||
threading.Thread(target=injector),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
# Brief burst — long enough to interleave many ops, short enough
|
||||
# not to slow CI.
|
||||
threading.Event().wait(0.1)
|
||||
stop.set()
|
||||
for t in threads:
|
||||
t.join(timeout=2.0)
|
||||
|
||||
assert errors == [], f"concurrent access raised: {errors[:3]}"
|
||||
assert len(store) <= 64, "FIFO cap was breached under contention"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user