fix(server): serialize PromptMetadataStore access with a lock

Addresses comfyanonymous's review nit on PR #13905. The store is touched
from three threads — the aiohttp event loop (``register`` via
``post_prompt``), the worker thread (``unregister`` via the
``prompt_worker`` try/finally and ``execution_error`` paths), and any
thread that fires ``send_sync`` (``inject``). Individual ``dict``
operations are GIL-atomic but ``register``'s ``len -> pop -> setitem``
and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose
interleaving without a lock is racy. A single ``threading.Lock`` keeps
the FIFO cap honest and snapshots the envelope under the lock before
the spread runs.

Adds a stress test that runs concurrent register/unregister/inject calls
for 100 ms across five threads and asserts that no exception escapes and
that the capacity bound holds.
This commit is contained in:
Deep Mehta 2026-05-14 21:26:06 -07:00
parent fc9820ebb9
commit 63784baed5
2 changed files with 91 additions and 7 deletions

View File

@ -15,6 +15,7 @@ instance and the helpers can be unit-tested without the rest of the app.
from __future__ import annotations
import logging
import threading
from typing import Any, Callable, Optional
@ -177,28 +178,49 @@ class PromptMetadataStore:
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
backstop: if any cleanup hook is ever skipped, the store sheds the
oldest entry instead of growing without bound.
Access is serialized through a ``threading.Lock``. ``register`` runs
on the aiohttp event-loop thread, ``unregister`` runs on the
``prompt_worker`` thread, and ``inject`` runs on whichever thread
fires ``send_sync`` (event loop, worker, asset seeder). Individual
``dict`` ops are GIL-atomic, but ``register``'s
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
are multi-step compounds whose interleaving without a lock is
racy. The lock is uncontended in steady state (sub-microsecond
critical sections) so the cost is negligible.
"""
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
    """Create an empty store that keeps at most ``capacity`` envelopes."""
    # One lock guards every read and write of ``_envelopes``.
    self._lock = threading.Lock()
    self._capacity = capacity
    # prompt_id -> envelope; ``register`` relies on dict insertion order
    # to find the oldest entry when it has to evict.
    self._envelopes: dict[str, dict] = {}
def register(self, prompt_id: str, extra_data: Any) -> None:
    """Store the envelope carried by ``extra_data`` under ``prompt_id``.

    Does nothing when ``extract_envelope_from_extra_data`` returns
    ``None``. When the store is already at capacity and ``prompt_id`` is
    not yet present, the oldest entry is evicted first so the cap holds.
    Re-registering an id that is already stored replaces its envelope in
    place and evicts nothing, because the map does not grow.
    """
    # Extraction touches no shared state, so it stays outside the lock.
    envelope = extract_envelope_from_extra_data(extra_data)
    if envelope is None:
        return
    with self._lock:
        # The size check, eviction and insert must be one atomic step;
        # otherwise two concurrent registers can both pass the check and
        # push the map past its capacity.
        is_new = prompt_id not in self._envelopes
        if is_new and len(self._envelopes) >= self._capacity:
            # dicts iterate in insertion order: the first key is the oldest.
            self._envelopes.pop(next(iter(self._envelopes)))
        self._envelopes[prompt_id] = envelope
def unregister(self, prompt_id: str) -> None:
self._envelopes.pop(prompt_id, None)
with self._lock:
self._envelopes.pop(prompt_id, None)
def inject(self, data: Any) -> Any:
    """Return ``inject_envelope(data, lookup)`` using a lock-guarded lookup.

    The lookup handed to ``inject_envelope`` reads the map while holding
    the store's lock, so it never runs in the middle of a concurrent
    ``register`` or ``unregister``.
    """
    # The lookup returns the stored envelope object itself, not a copy,
    # and whatever ``inject_envelope`` does with it happens after the
    # lock is released. That is safe as long as envelopes are never
    # mutated in place: the methods of this class only replace or remove
    # whole entries.
    def locked_lookup(prompt_id: str) -> Optional[dict]:
        with self._lock:
            return self._envelopes.get(prompt_id)

    return inject_envelope(data, locked_lookup)
def __len__(self) -> int:
return len(self._envelopes)
with self._lock:
return len(self._envelopes)
def __contains__(self, prompt_id: str) -> bool:
return prompt_id in self._envelopes
with self._lock:
return prompt_id in self._envelopes

View File

@ -348,3 +348,65 @@ class TestPromptMetadataStore:
"prompt_id": "p1",
"workflow_id": "wf-1",
})
def test_concurrent_access_does_not_corrupt_or_raise(self):
    """Smoke test for the store's lock.

    ``register`` is called from the aiohttp event-loop thread,
    ``unregister`` from the worker thread, and ``inject`` fires on every
    ``send_sync`` from whichever thread emits the event. Run all three
    concurrently and assert that no exception escapes, that every worker
    thread terminates, and that the FIFO cap is never observed to be
    exceeded, either during the burst or after it.
    """
    import threading
    import time

    capacity = 64
    store = PromptMetadataStore(capacity=capacity)
    stop = threading.Event()
    errors: list[BaseException] = []

    def hammer(op):
        # Call ``op`` with an increasing counter until told to stop and
        # record anything it raises so the main thread can report it.
        i = 0
        try:
            while not stop.is_set():
                op(i)
                i += 1
        except BaseException as e:
            errors.append(e)

    def do_register(i):
        store.register(
            f"p{i % 100}",
            {"metadata": {"workflow_id": f"wf-{i}"}},
        )

    def do_unregister(i):
        store.unregister(f"p{i % 100}")

    def do_inject(i):
        store.inject({"prompt_id": f"p{i % 100}", "node": "5"})

    # Two registrars, one canceller, two injectors. Daemon threads so a
    # worker stuck on the lock cannot keep the interpreter alive.
    ops = [do_register, do_register, do_unregister, do_inject, do_inject]
    threads = [
        threading.Thread(target=hammer, args=(op,), daemon=True)
        for op in ops
    ]

    peak = 0
    try:
        for t in threads:
            t.start()
        # Brief burst: long enough to interleave many ops, short enough
        # not to slow CI. Sample the size while the workers run so a
        # transient breach of the cap is caught too.
        deadline = time.monotonic() + 0.1
        while time.monotonic() < deadline:
            peak = max(peak, len(store))
            time.sleep(0.001)
    finally:
        # Always release the workers, even if start-up or sampling raised.
        stop.set()

    for t in threads:
        t.join(timeout=2.0)
    # ``join`` returns silently on timeout, so check liveness explicitly.
    hung = [t.name for t in threads if t.is_alive()]
    assert hung == [], f"worker threads did not stop: {hung}"
    assert errors == [], f"concurrent access raised: {errors[:3]}"
    assert max(peak, len(store)) <= capacity, "FIFO cap was breached under contention"