diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py index 821a516f2..7d6c29db1 100644 --- a/app/prompt_metadata.py +++ b/app/prompt_metadata.py @@ -15,6 +15,7 @@ instance and the helpers can be unit-tested without the rest of the app. from __future__ import annotations import logging +import threading from typing import Any, Callable, Optional @@ -177,28 +178,49 @@ class PromptMetadataStore: prompt finishes, wiped on queue cancel/delete. The FIFO cap is a backstop: if any cleanup hook is ever skipped, the store sheds the oldest entry instead of growing without bound. + + Access is serialized through a ``threading.Lock``. ``register`` runs + on the aiohttp event-loop thread, ``unregister`` runs on the + ``prompt_worker`` thread, and ``inject`` runs on whichever thread + fires ``send_sync`` (event loop, worker, asset seeder). Individual + ``dict`` ops are GIL-atomic, but ``register``'s + ``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}`` + are multi-step compounds whose interleaving without a lock is + racy. The lock is uncontended in steady state (sub-microsecond + critical sections) so the cost is negligible. """ def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY): self._envelopes: dict[str, dict] = {} self._capacity = capacity + self._lock = threading.Lock() def register(self, prompt_id: str, extra_data: Any) -> None: envelope = extract_envelope_from_extra_data(extra_data) if envelope is None: return - if len(self._envelopes) >= self._capacity: - self._envelopes.pop(next(iter(self._envelopes))) - self._envelopes[prompt_id] = envelope + with self._lock: + if len(self._envelopes) >= self._capacity: + self._envelopes.pop(next(iter(self._envelopes))) + self._envelopes[prompt_id] = envelope def unregister(self, prompt_id: str) -> None: - self._envelopes.pop(prompt_id, None) + with self._lock: + self._envelopes.pop(prompt_id, None) def inject(self, data: Any) -> Any: - return inject_envelope(data, self._envelopes.get) + # Snapshot the envelope under the lock so the spread in + # ``inject_envelope`` runs against a consistent view even if a + # concurrent ``register``/``unregister`` is mutating the map. + def locked_lookup(prompt_id: str) -> Optional[dict]: + with self._lock: + return self._envelopes.get(prompt_id) + return inject_envelope(data, locked_lookup) def __len__(self) -> int: - return len(self._envelopes) + with self._lock: + return len(self._envelopes) def __contains__(self, prompt_id: str) -> bool: - return prompt_id in self._envelopes + with self._lock: + return prompt_id in self._envelopes diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py index 8d88d788e..b5241dd5d 100644 --- a/tests-unit/app_test/test_prompt_metadata.py +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -348,3 +348,65 @@ class TestPromptMetadataStore: "prompt_id": "p1", "workflow_id": "wf-1", }) + + def test_concurrent_access_does_not_corrupt_or_raise(self): + """Smoke test for the store's lock. ``register`` is called from + the aiohttp event-loop thread, ``unregister`` from the worker + thread, and ``inject`` fires on every ``send_sync`` from + whichever thread emits the event. Run all three concurrently + and assert no exception escapes and the store stays internally + consistent (the FIFO cap is never exceeded).""" + import threading + + store = PromptMetadataStore(capacity=64) + stop = threading.Event() + errors: list[BaseException] = [] + + def registrar(): + i = 0 + try: + while not stop.is_set(): + store.register( + f"p{i % 100}", + {"metadata": {"workflow_id": f"wf-{i}"}}, + ) + i += 1 + except BaseException as e: + errors.append(e) + + def canceller(): + i = 0 + try: + while not stop.is_set(): + store.unregister(f"p{i % 100}") + i += 1 + except BaseException as e: + errors.append(e) + + def injector(): + i = 0 + try: + while not stop.is_set(): + store.inject({"prompt_id": f"p{i % 100}", "node": "5"}) + i += 1 + except BaseException as e: + errors.append(e) + + threads = [ + threading.Thread(target=registrar), + threading.Thread(target=registrar), + threading.Thread(target=canceller), + threading.Thread(target=injector), + threading.Thread(target=injector), + ] + for t in threads: + t.start() + # Brief burst — long enough to interleave many ops, short enough + # not to slow CI. + threading.Event().wait(0.1) + stop.set() + for t in threads: + t.join(timeout=2.0) + + assert errors == [], f"concurrent access raised: {errors[:3]}" + assert len(store) <= 64, "FIFO cap was breached under contention"