This commit is contained in:
Deep Mehta 2026-05-17 21:22:34 -07:00 committed by GitHub
commit 3b51306917
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 690 additions and 13 deletions

226
app/prompt_metadata.py Normal file
View File

@ -0,0 +1,226 @@
"""Per-prompt metadata envelope shared between submission and outbound events.
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
``{"workflow_id": ...}``) attached to a prompt at submission and injected
by the server into every outbound execution event that carries a
``prompt_id``. It lets consumers scope state by tags they care about
(workflow, trace, tenant) without the execution layer ever needing to
know those tags exist.
This module is intentionally pure no imports from ``server`` or
``execution`` so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Callable, Optional
# Bounds. The envelope is forwarded to every WebSocket client connected to
# the server on every execution event for the prompt — bounding key count,
# key length, value length, and refusing nested structures keeps a
# malicious or buggy client from inflating the broadcast volume.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
# entry goes first.
DEFAULT_STORE_CAPACITY = 4096
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
"""Validate and copy a candidate envelope.
Enforces the ``dict[str, str]`` contract that downstream consumers
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
- must be a non-empty ``dict``
- at most ``MAX_ENVELOPE_KEYS`` entries
- every key and value must be a ``str``
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
Returns a defensive shallow copy on success, ``None`` on any
violation. Logs a warning on violation so abuse is visible.
"""
if not isinstance(envelope, dict) or not envelope:
return None
if len(envelope) > MAX_ENVELOPE_KEYS:
logging.warning(
"prompt metadata envelope rejected: %d keys exceeds limit %d",
len(envelope), MAX_ENVELOPE_KEYS,
)
return None
sanitized: dict[str, str] = {}
for key, value in envelope.items():
if not isinstance(key, str) or not isinstance(value, str):
logging.warning(
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
type(key).__name__, type(value).__name__,
)
return None
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
logging.warning(
"prompt metadata envelope rejected: key or value exceeds length limit",
)
return None
sanitized[key] = value
return sanitized
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
"""Pull the per-prompt metadata envelope out of a submitted prompt's
``extra_data``.
Two sources, in order:
1. Explicit ``extra_data["metadata"]`` sanitized via
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
rejected (a warning is logged) rather than truncated, so the
contract stays strict at the boundary.
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` backward-
compatibility fallback. Frontends that already stamp the workflow
id into ``extra_pnginfo`` keep working; the synthesized envelope
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
remains observable.
Returns ``None`` when neither source yields a usable envelope.
"""
if not isinstance(extra_data, dict):
return None
if "metadata" in extra_data:
sanitized = _sanitize_envelope(extra_data["metadata"])
if sanitized is not None:
return sanitized
# Explicit metadata was supplied but rejected — do not fall
# through to the legacy path; the caller asked for something
# specific and got it wrong.
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
return None
extra_pnginfo = extra_data.get("extra_pnginfo")
if isinstance(extra_pnginfo, dict):
workflow = extra_pnginfo.get("workflow")
if isinstance(workflow, dict):
workflow_id = workflow.get("id")
if (
isinstance(workflow_id, str)
and workflow_id
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
):
logging.debug(
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
)
return {"workflow_id": workflow_id}
return None
def inject_envelope(
data: Any,
envelope_lookup: Callable[[str], Optional[dict]],
) -> Any:
"""Return ``data`` with the per-prompt envelope's keys spread onto it.
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
expected to return the registered envelope or ``None``. This keeps
the function pure and avoids depending on any specific storage.
The envelope's keys are merged onto the payload at the top level so
consumers can read them directly (e.g. ``event.workflow_id``)
matching the wire shape of the prior workflow-id-on-events work and
avoiding an extra nesting hop for clients. Server-emitted fields on
the payload always win on collision (``{**envelope, **d}``); a
misbehaving client cannot shadow ``prompt_id``, ``node``, etc.
Two payload shapes are handled:
- **dict** carrying ``prompt_id``. A shallow copy is returned with
the envelope's keys merged onto it.
- **(preview_image, metadata_dict) tuple** the format used by
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
the binary preview is passed through by reference.
No-op for payloads without a ``prompt_id``, prompts with no
registered envelope, or any other payload shape.
"""
def inject(d: dict) -> dict:
if not isinstance(d, dict):
return d
prompt_id = d.get("prompt_id")
if not prompt_id:
return d
envelope = envelope_lookup(prompt_id)
if envelope is None:
return d
return {**envelope, **d}
if isinstance(data, dict):
return inject(data)
if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict):
injected = inject(data[1])
if injected is data[1]:
return data
return (data[0], injected)
return data
class PromptMetadataStore:
"""Bounded ``prompt_id -> envelope`` map.
Owned by ``PromptServer``. Populated at submission, drained when the
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
backstop: if any cleanup hook is ever skipped, the store sheds the
oldest entry instead of growing without bound.
Access is serialized through a ``threading.Lock``. ``register`` runs
on the aiohttp event-loop thread, ``unregister`` runs on the
``prompt_worker`` thread, and ``inject`` runs on whichever thread
fires ``send_sync`` (event loop, worker, asset seeder). Individual
``dict`` ops are GIL-atomic, but ``register``'s
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
are multi-step compounds whose interleaving without a lock is
racy. The lock is uncontended in steady state (sub-microsecond
critical sections) so the cost is negligible.
"""
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
self._envelopes: dict[str, dict] = {}
self._capacity = capacity
self._lock = threading.Lock()
def register(self, prompt_id: str, extra_data: Any) -> None:
envelope = extract_envelope_from_extra_data(extra_data)
if envelope is None:
return
with self._lock:
if len(self._envelopes) >= self._capacity:
self._envelopes.pop(next(iter(self._envelopes)))
self._envelopes[prompt_id] = envelope
def unregister(self, prompt_id: str) -> None:
with self._lock:
self._envelopes.pop(prompt_id, None)
def inject(self, data: Any) -> Any:
# Snapshot the envelope under the lock so the spread in
# ``inject_envelope`` runs against a consistent view even if a
# concurrent ``register``/``unregister`` is mutating the map.
def locked_lookup(prompt_id: str) -> Optional[dict]:
with self._lock:
return self._envelopes.get(prompt_id)
return inject_envelope(data, locked_lookup)
def __len__(self) -> int:
with self._lock:
return len(self._envelopes)
def __contains__(self, prompt_id: str) -> bool:
with self._lock:
return prompt_id in self._envelopes

View File

@ -1296,7 +1296,10 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
dropped_prompt_ids = [item[1] for item in self.queue]
self.queue = []
for prompt_id in dropped_prompt_ids:
self.server.unregister_prompt_metadata(prompt_id)
self.server.queue_updated()
def delete_queue_item(self, function):
@ -1306,8 +1309,9 @@ class PromptQueue:
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
deleted = self.queue.pop(x)
heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(deleted[1])
self.server.queue_updated()
return True
return False

29
main.py
View File

@ -318,19 +318,26 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
# Always drop the metadata envelope. If e.execute() raises
# hard before its own error handling kicks in, the
# registered envelope would otherwise leak for the
# lifetime of the process.
server_instance.unregister_prompt_metadata(prompt_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time

View File

@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import PromptMetadataStore
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -250,8 +251,14 @@ class PromptServer():
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
self.last_prompt_id = None
self.client_id = None
# Bounded prompt_id -> envelope store. Populated at submission,
# drained on completion/cancel. Keeps workflow scope (and other
# client-supplied tags) out of the execution layer.
self._prompt_metadata = PromptMetadataStore()
self.on_prompt_handlers = []
@routes.get('/ws')
@ -275,7 +282,12 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
payload: dict = {"node": self.last_node_id}
if self.last_prompt_id is not None:
payload["prompt_id"] = self.last_prompt_id
await self.send(
"executing", self._inject_prompt_metadata(payload), sid
)
# Flag to track if we've received the first message
first_message = True
@ -955,6 +967,7 @@ class PromptServer():
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.register_prompt_metadata(prompt_id, extra_data)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@ -1216,7 +1229,22 @@ class PromptServer():
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
"""Record per-prompt metadata for injection into outbound execution
events. Called at submission, before the prompt is queued."""
self._prompt_metadata.register(prompt_id, extra_data)
def unregister_prompt_metadata(self, prompt_id: str) -> None:
"""Drop the per-prompt metadata envelope. Called after the prompt
has finished executing and its terminal events have been queued,
or when the prompt is cancelled before reaching the worker."""
self._prompt_metadata.unregister(prompt_id)
def _inject_prompt_metadata(self, data):
return self._prompt_metadata.inject(data)
def send_sync(self, event, data, sid=None):
data = self._inject_prompt_metadata(data)
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))

View File

@ -0,0 +1,412 @@
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
``inject_envelope``) and the ``PromptMetadataStore`` integration class
that ``PromptServer`` owns.
"""
from __future__ import annotations
import pytest
from app.prompt_metadata import (
MAX_ENVELOPE_KEYS,
MAX_ENVELOPE_KEY_LEN,
MAX_ENVELOPE_VALUE_LEN,
PromptMetadataStore,
extract_envelope_from_extra_data,
inject_envelope,
)
class TestExtractEnvelopeFromExtraData:
def test_explicit_metadata_dict_is_used_as_is(self):
extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-1",
"trace_id": "t-9",
}
def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self):
extra_data = {
"metadata": {"workflow_id": "explicit"},
"extra_pnginfo": {"workflow": {"id": "fallback"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "explicit"
}
def test_falls_back_to_extra_pnginfo_workflow_id(self):
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
def test_returns_none_when_no_metadata_and_no_workflow_id(self):
assert extract_envelope_from_extra_data({}) is None
assert (
extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}})
is None
)
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
def test_rejects_non_string_or_empty_workflow_id(self, bad):
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
assert extract_envelope_from_extra_data(extra_data) is None
def test_rejects_non_dict_inputs_at_each_level(self):
assert extract_envelope_from_extra_data(None) is None
assert extract_envelope_from_extra_data("not-a-dict") is None
assert (
extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"})
is None
)
assert (
extract_envelope_from_extra_data(
{"extra_pnginfo": {"workflow": "not-a-dict"}}
)
is None
)
def test_empty_explicit_metadata_falls_through_to_workflow_id(self):
extra_data = {
"metadata": {},
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
def test_returned_envelope_is_copy_not_reference(self):
original = {"workflow_id": "wf-1"}
result = extract_envelope_from_extra_data({"metadata": original})
assert result is not None
result["new_key"] = "x"
assert "new_key" not in original
def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self):
extra_data = {
"metadata": "not-a-dict",
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) == {
"workflow_id": "wf-legacy"
}
class TestEnvelopeSanitization:
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
envelope is dropped (and a warning is logged) rather than truncated,
so the boundary stays strict."""
def test_rejects_too_many_keys(self, caplog):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
with caplog.at_level("WARNING"):
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
assert any("exceeds limit" in r.message for r in caplog.records)
def test_accepts_max_keys_exactly(self):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_rejects_non_string_keys(self, caplog):
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data({"metadata": {42: "v"}})
is None
)
assert any("non-string" in r.message for r in caplog.records)
def test_rejects_non_string_values(self, caplog):
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data(
{"metadata": {"k": bad_value}}
)
is None
)
def test_rejects_oversized_key(self):
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_rejects_oversized_value(self):
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_accepts_max_lengths_exactly(self):
envelope = {
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_oversized_workflow_id_in_pnginfo_rejected(self):
"""The legacy synthesized path also respects the value bound."""
extra_data = {
"extra_pnginfo": {
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
}
}
assert extract_envelope_from_extra_data(extra_data) is None
def test_invalid_explicit_metadata_does_not_fall_through(self):
"""An explicit but invalid metadata dict means the caller asked
for something specific and got it wrong; the synthesized
fallback must not silently substitute."""
extra_data = {
"metadata": {"k": 42}, # non-string value
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) is None
class TestInjectEnvelope:
@staticmethod
def _lookup(table):
return table.get
def test_spreads_envelope_keys_onto_payload(self):
"""Envelope keys are merged at the top level so consumers can
read them directly (e.g. ``event.workflow_id``)."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}})
assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == {
"node": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
"trace_id": "t-9",
}
def test_passthrough_when_prompt_id_not_registered(self):
lookup = self._lookup({})
data = {"node": "5", "prompt_id": "unknown"}
assert inject_envelope(data, lookup) == data
def test_passthrough_when_payload_lacks_prompt_id(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
data = {"status": "ok"}
assert inject_envelope(data, lookup) == data
def test_server_keys_win_on_collision_with_envelope(self):
"""A misbehaving client cannot shadow server-emitted fields by
stamping the same key in their submission envelope."""
lookup = self._lookup({
"p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"}
})
result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
assert result["prompt_id"] == "p1"
assert result["node"] == "5"
assert result["workflow_id"] == "wf-1"
def test_does_not_mutate_input_dict(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
original = {"node": "5", "prompt_id": "p1"}
inject_envelope(original, lookup)
assert "workflow_id" not in original
def test_does_not_mutate_envelope_dict(self):
envelope = {"workflow_id": "wf-1"}
lookup = self._lookup({"p1": envelope})
inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
assert envelope == {"workflow_id": "wf-1"}
def test_injects_into_inner_dict_of_preview_metadata_tuple(self):
"""``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as
``(preview_image, metadata_dict)``; the inner dict is the only
place the envelope can attach."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "p1"}
result = inject_envelope((preview_image, inner), lookup)
assert isinstance(result, tuple)
assert result[0] is preview_image
assert result[1] == {
"node_id": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
}
assert "workflow_id" not in inner
def test_preview_tuple_passthrough_when_no_envelope_registered(self):
lookup = self._lookup({})
preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "unknown"}
result = inject_envelope((preview_image, inner), lookup)
assert result == (preview_image, inner)
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
assert inject_envelope(payload, lookup) == payload
def test_tuple_of_wrong_arity_passthrough(self):
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
special-cased. Other tuples must not be touched."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
triple = (1, {"prompt_id": "p1"}, 3)
assert inject_envelope(triple, lookup) is triple
def test_envelope_lookup_called_per_invocation(self):
"""The lookup runs each time the function is called, so changes
to the backing store are immediately visible."""
store = {"p1": {"workflow_id": "wf-1"}}
first = inject_envelope({"prompt_id": "p1"}, store.get)
store["p1"] = {"workflow_id": "wf-2"}
second = inject_envelope({"prompt_id": "p1"}, store.get)
del store["p1"]
third = inject_envelope({"prompt_id": "p1"}, store.get)
assert first["workflow_id"] == "wf-1"
assert second["workflow_id"] == "wf-2"
assert "workflow_id" not in third
class TestPromptMetadataStore:
"""End-to-end wiring tests that exercise the full register/inject/
unregister cycle the way ``PromptServer`` does."""
def test_register_inject_unregister_cycle(self):
store = PromptMetadataStore()
store.register(
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
)
injected = store.inject({"node": "5", "prompt_id": "p1"})
assert injected == {
"node": "5",
"prompt_id": "p1",
"workflow_id": "wf-1",
}
store.unregister("p1")
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
assert "workflow_id" not in passthrough
def test_register_with_no_derivable_envelope_is_noop(self):
store = PromptMetadataStore()
store.register("p1", {})
assert "p1" not in store
data = {"prompt_id": "p1"}
assert store.inject(data) == data
def test_register_with_oversized_envelope_is_noop(self):
"""Sanitization rejection means nothing is registered — the
store stays empty and inject is a passthrough."""
store = PromptMetadataStore()
store.register(
"p1",
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
)
assert "p1" not in store
def test_unregister_unknown_prompt_is_silent(self):
store = PromptMetadataStore()
store.unregister("does-not-exist")
def test_fifo_eviction_when_capacity_exceeded(self):
"""If cleanup hooks are ever bypassed, the store must shed the
oldest entry rather than grow without bound."""
store = PromptMetadataStore(capacity=3)
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
assert len(store) == 3
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
assert len(store) == 3
assert "p1" not in store
assert "p4" in store
# The newer entries are still injectable.
assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4"
# The evicted one is gone.
assert "workflow_id" not in store.inject({"prompt_id": "p1"})
def test_register_after_unregister_does_not_count_against_capacity(self):
"""Normal lifecycle: register, unregister, register many — the
store should not silently evict valid entries because of stale
accounting."""
store = PromptMetadataStore(capacity=2)
for i in range(10):
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
store.unregister(f"p{i}")
assert len(store) == 0
def test_re_register_overwrites(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2"
def test_inject_with_no_registrations_is_passthrough(self):
store = PromptMetadataStore()
data = {"prompt_id": "p1", "node": "5"}
assert store.inject(data) == data
def test_inject_into_preview_tuple(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
assert result == (b"image-bytes", {
"prompt_id": "p1",
"workflow_id": "wf-1",
})
def test_concurrent_access_does_not_corrupt_or_raise(self):
"""Smoke test for the store's lock. ``register`` is called from
the aiohttp event-loop thread, ``unregister`` from the worker
thread, and ``inject`` fires on every ``send_sync`` from
whichever thread emits the event. Run all three concurrently
and assert no exception escapes and the store stays internally
consistent (the FIFO cap is never exceeded)."""
import threading
store = PromptMetadataStore(capacity=64)
stop = threading.Event()
errors: list[BaseException] = []
def registrar():
i = 0
try:
while not stop.is_set():
store.register(
f"p{i % 100}",
{"metadata": {"workflow_id": f"wf-{i}"}},
)
i += 1
except BaseException as e:
errors.append(e)
def canceller():
i = 0
try:
while not stop.is_set():
store.unregister(f"p{i % 100}")
i += 1
except BaseException as e:
errors.append(e)
def injector():
i = 0
try:
while not stop.is_set():
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
i += 1
except BaseException as e:
errors.append(e)
threads = [
threading.Thread(target=registrar),
threading.Thread(target=registrar),
threading.Thread(target=canceller),
threading.Thread(target=injector),
threading.Thread(target=injector),
]
for t in threads:
t.start()
# Brief burst — long enough to interleave many ops, short enough
# not to slow CI.
threading.Event().wait(0.1)
stop.set()
for t in threads:
t.join(timeout=2.0)
assert errors == [], f"concurrent access raised: {errors[:3]}"
assert len(store) <= 64, "FIFO cap was breached under contention"