mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-01 20:07:37 +08:00
fix(server): serialize PromptMetadataStore access with a lock
Addresses comfyanonymous's review nit on PR #13905. The store is touched from three threads — the aiohttp event loop (``register`` via ``post_prompt``), the worker thread (``unregister`` via the ``prompt_worker`` try/finally and ``execution_error`` paths), and any thread that fires ``send_sync`` (``inject``). Individual ``dict`` operations are GIL-atomic but ``register``'s ``len -> pop -> setitem`` and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose interleaving without a lock is racy. A single ``threading.Lock`` keeps the FIFO cap honest and snapshots the envelope under the lock before the spread runs. Adds a stress-test that runs concurrent register/unregister/inject for 100 ms across five threads and asserts no exception escapes and the capacity bound is held.
This commit is contained in:
parent
fc9820ebb9
commit
63784baed5
@ -15,6 +15,7 @@ instance and the helpers can be unit-tested without the rest of the app.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
@ -177,28 +178,49 @@ class PromptMetadataStore:
|
|||||||
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
||||||
backstop: if any cleanup hook is ever skipped, the store sheds the
|
backstop: if any cleanup hook is ever skipped, the store sheds the
|
||||||
oldest entry instead of growing without bound.
|
oldest entry instead of growing without bound.
|
||||||
|
|
||||||
|
Access is serialized through a ``threading.Lock``. ``register`` runs
|
||||||
|
on the aiohttp event-loop thread, ``unregister`` runs on the
|
||||||
|
``prompt_worker`` thread, and ``inject`` runs on whichever thread
|
||||||
|
fires ``send_sync`` (event loop, worker, asset seeder). Individual
|
||||||
|
``dict`` ops are GIL-atomic, but ``register``'s
|
||||||
|
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
|
||||||
|
are multi-step compounds whose interleaving without a lock is
|
||||||
|
racy. The lock is uncontended in steady state (sub-microsecond
|
||||||
|
critical sections) so the cost is negligible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
||||||
self._envelopes: dict[str, dict] = {}
|
self._envelopes: dict[str, dict] = {}
|
||||||
self._capacity = capacity
|
self._capacity = capacity
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def register(self, prompt_id: str, extra_data: Any) -> None:
|
def register(self, prompt_id: str, extra_data: Any) -> None:
|
||||||
envelope = extract_envelope_from_extra_data(extra_data)
|
envelope = extract_envelope_from_extra_data(extra_data)
|
||||||
if envelope is None:
|
if envelope is None:
|
||||||
return
|
return
|
||||||
if len(self._envelopes) >= self._capacity:
|
with self._lock:
|
||||||
self._envelopes.pop(next(iter(self._envelopes)))
|
if len(self._envelopes) >= self._capacity:
|
||||||
self._envelopes[prompt_id] = envelope
|
self._envelopes.pop(next(iter(self._envelopes)))
|
||||||
|
self._envelopes[prompt_id] = envelope
|
||||||
|
|
||||||
def unregister(self, prompt_id: str) -> None:
|
def unregister(self, prompt_id: str) -> None:
|
||||||
self._envelopes.pop(prompt_id, None)
|
with self._lock:
|
||||||
|
self._envelopes.pop(prompt_id, None)
|
||||||
|
|
||||||
def inject(self, data: Any) -> Any:
|
def inject(self, data: Any) -> Any:
|
||||||
return inject_envelope(data, self._envelopes.get)
|
# Snapshot the envelope under the lock so the spread in
|
||||||
|
# ``inject_envelope`` runs against a consistent view even if a
|
||||||
|
# concurrent ``register``/``unregister`` is mutating the map.
|
||||||
|
def locked_lookup(prompt_id: str) -> Optional[dict]:
|
||||||
|
with self._lock:
|
||||||
|
return self._envelopes.get(prompt_id)
|
||||||
|
return inject_envelope(data, locked_lookup)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._envelopes)
|
with self._lock:
|
||||||
|
return len(self._envelopes)
|
||||||
|
|
||||||
def __contains__(self, prompt_id: str) -> bool:
|
def __contains__(self, prompt_id: str) -> bool:
|
||||||
return prompt_id in self._envelopes
|
with self._lock:
|
||||||
|
return prompt_id in self._envelopes
|
||||||
|
|||||||
@ -348,3 +348,65 @@ class TestPromptMetadataStore:
|
|||||||
"prompt_id": "p1",
|
"prompt_id": "p1",
|
||||||
"workflow_id": "wf-1",
|
"workflow_id": "wf-1",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_concurrent_access_does_not_corrupt_or_raise(self):
|
||||||
|
"""Smoke test for the store's lock. ``register`` is called from
|
||||||
|
the aiohttp event-loop thread, ``unregister`` from the worker
|
||||||
|
thread, and ``inject`` fires on every ``send_sync`` from
|
||||||
|
whichever thread emits the event. Run all three concurrently
|
||||||
|
and assert no exception escapes and the store stays internally
|
||||||
|
consistent (the FIFO cap is never exceeded)."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
store = PromptMetadataStore(capacity=64)
|
||||||
|
stop = threading.Event()
|
||||||
|
errors: list[BaseException] = []
|
||||||
|
|
||||||
|
def registrar():
|
||||||
|
i = 0
|
||||||
|
try:
|
||||||
|
while not stop.is_set():
|
||||||
|
store.register(
|
||||||
|
f"p{i % 100}",
|
||||||
|
{"metadata": {"workflow_id": f"wf-{i}"}},
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
except BaseException as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
def canceller():
|
||||||
|
i = 0
|
||||||
|
try:
|
||||||
|
while not stop.is_set():
|
||||||
|
store.unregister(f"p{i % 100}")
|
||||||
|
i += 1
|
||||||
|
except BaseException as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
def injector():
|
||||||
|
i = 0
|
||||||
|
try:
|
||||||
|
while not stop.is_set():
|
||||||
|
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
|
||||||
|
i += 1
|
||||||
|
except BaseException as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=registrar),
|
||||||
|
threading.Thread(target=registrar),
|
||||||
|
threading.Thread(target=canceller),
|
||||||
|
threading.Thread(target=injector),
|
||||||
|
threading.Thread(target=injector),
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
# Brief burst — long enough to interleave many ops, short enough
|
||||||
|
# not to slow CI.
|
||||||
|
threading.Event().wait(0.1)
|
||||||
|
stop.set()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=2.0)
|
||||||
|
|
||||||
|
assert errors == [], f"concurrent access raised: {errors[:3]}"
|
||||||
|
assert len(store) <= 64, "FIFO cap was breached under contention"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user