diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py new file mode 100644 index 000000000..7d6c29db1 --- /dev/null +++ b/app/prompt_metadata.py @@ -0,0 +1,226 @@ +"""Per-prompt metadata envelope shared between submission and outbound events. + +The metadata envelope is a small flat ``dict[str, str]`` (e.g. +``{"workflow_id": ...}``) attached to a prompt at submission and injected +by the server into every outbound execution event that carries a +``prompt_id``. It lets consumers scope state by tags they care about +(workflow, trace, tenant) without the execution layer ever needing to +know those tags exist. + +This module is intentionally pure — no imports from ``server`` or +``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore`` +instance and the helpers can be unit-tested without the rest of the app. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any, Callable, Optional + + +# Bounds. The envelope is forwarded to every WebSocket client connected to +# the server on every execution event for the prompt — bounding key count, +# key length, value length, and refusing nested structures keeps a +# malicious or buggy client from inflating the broadcast volume. +MAX_ENVELOPE_KEYS = 16 +MAX_ENVELOPE_KEY_LEN = 64 +MAX_ENVELOPE_VALUE_LEN = 256 + +# Cap on concurrently registered prompt envelopes. Acts as a backstop if +# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale +# entry goes first. +DEFAULT_STORE_CAPACITY = 4096 + + +def _sanitize_envelope(envelope: Any) -> Optional[dict]: + """Validate and copy a candidate envelope. + + Enforces the ``dict[str, str]`` contract that downstream consumers + (cloud projections, frontend zod schemas, OpenAPI docs) rely on: + + - must be a non-empty ``dict`` + - at most ``MAX_ENVELOPE_KEYS`` entries + - every key and value must be a ``str`` + - keys at most ``MAX_ENVELOPE_KEY_LEN`` chars + - values at most ``MAX_ENVELOPE_VALUE_LEN`` chars + + Returns a defensive shallow copy on success, ``None`` on any + violation. Logs a warning on violation so abuse is visible. + """ + if not isinstance(envelope, dict) or not envelope: + return None + if len(envelope) > MAX_ENVELOPE_KEYS: + logging.warning( + "prompt metadata envelope rejected: %d keys exceeds limit %d", + len(envelope), MAX_ENVELOPE_KEYS, + ) + return None + sanitized: dict[str, str] = {} + for key, value in envelope.items(): + if not isinstance(key, str) or not isinstance(value, str): + logging.warning( + "prompt metadata envelope rejected: non-string key/value (%s=%s)", + type(key).__name__, type(value).__name__, + ) + return None + if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN: + logging.warning( + "prompt metadata envelope rejected: key or value exceeds length limit", + ) + return None + sanitized[key] = value + return sanitized + + +def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]: + """Pull the per-prompt metadata envelope out of a submitted prompt's + ``extra_data``. + + Two sources, in order: + + 1. Explicit ``extra_data["metadata"]`` — sanitized via + ``_sanitize_envelope``. Oversized or wrong-typed envelopes are + rejected (a warning is logged) rather than truncated, so the + contract stays strict at the boundary. + 2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward- + compatibility fallback. Frontends that already stamp the workflow + id into ``extra_pnginfo`` keep working; the synthesized envelope + is ``{"workflow_id": }``. A debug log fires so the legacy path + remains observable. + + Returns ``None`` when neither source yields a usable envelope. + """ + if not isinstance(extra_data, dict): + return None + + if "metadata" in extra_data: + sanitized = _sanitize_envelope(extra_data["metadata"]) + if sanitized is not None: + return sanitized + # Explicit metadata was supplied but rejected — do not fall + # through to the legacy path; the caller asked for something + # specific and got it wrong. + if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]: + return None + + extra_pnginfo = extra_data.get("extra_pnginfo") + if isinstance(extra_pnginfo, dict): + workflow = extra_pnginfo.get("workflow") + if isinstance(workflow, dict): + workflow_id = workflow.get("id") + if ( + isinstance(workflow_id, str) + and workflow_id + and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN + ): + logging.debug( + "prompt metadata envelope synthesized from extra_pnginfo.workflow.id" + ) + return {"workflow_id": workflow_id} + + return None + + +def inject_envelope( + data: Any, + envelope_lookup: Callable[[str], Optional[dict]], +) -> Any: + """Return ``data`` with the per-prompt envelope's keys spread onto it. + + ``envelope_lookup`` is called with the payload's ``prompt_id`` and is + expected to return the registered envelope or ``None``. This keeps + the function pure and avoids depending on any specific storage. + + The envelope's keys are merged onto the payload at the top level so + consumers can read them directly (e.g. ``event.workflow_id``) — + matching the wire shape of the prior workflow-id-on-events work and + avoiding an extra nesting hop for clients. Server-emitted fields on + the payload always win on collision (``{**envelope, **d}``); a + misbehaving client cannot shadow ``prompt_id``, ``node``, etc. + + Two payload shapes are handled: + + - **dict** carrying ``prompt_id``. A shallow copy is returned with + the envelope's keys merged onto it. + - **(preview_image, metadata_dict) tuple** — the format used by + ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; + the binary preview is passed through by reference. + + No-op for payloads without a ``prompt_id``, prompts with no + registered envelope, or any other payload shape. + """ + def inject(d: dict) -> dict: + if not isinstance(d, dict): + return d + prompt_id = d.get("prompt_id") + if not prompt_id: + return d + envelope = envelope_lookup(prompt_id) + if envelope is None: + return d + return {**envelope, **d} + + if isinstance(data, dict): + return inject(data) + if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict): + injected = inject(data[1]) + if injected is data[1]: + return data + return (data[0], injected) + return data + + +class PromptMetadataStore: + """Bounded ``prompt_id -> envelope`` map. + + Owned by ``PromptServer``. Populated at submission, drained when the + prompt finishes, wiped on queue cancel/delete. The FIFO cap is a + backstop: if any cleanup hook is ever skipped, the store sheds the + oldest entry instead of growing without bound. + + Access is serialized through a ``threading.Lock``. ``register`` runs + on the aiohttp event-loop thread, ``unregister`` runs on the + ``prompt_worker`` thread, and ``inject`` runs on whichever thread + fires ``send_sync`` (event loop, worker, asset seeder). Individual + ``dict`` ops are GIL-atomic, but ``register``'s + ``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}`` + are multi-step compounds whose interleaving without a lock is + racy. The lock is uncontended in steady state (sub-microsecond + critical sections) so the cost is negligible. + """ + + def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY): + self._envelopes: dict[str, dict] = {} + self._capacity = capacity + self._lock = threading.Lock() + + def register(self, prompt_id: str, extra_data: Any) -> None: + envelope = extract_envelope_from_extra_data(extra_data) + if envelope is None: + return + with self._lock: + if len(self._envelopes) >= self._capacity: + self._envelopes.pop(next(iter(self._envelopes))) + self._envelopes[prompt_id] = envelope + + def unregister(self, prompt_id: str) -> None: + with self._lock: + self._envelopes.pop(prompt_id, None) + + def inject(self, data: Any) -> Any: + # Snapshot the envelope under the lock so the spread in + # ``inject_envelope`` runs against a consistent view even if a + # concurrent ``register``/``unregister`` is mutating the map. + def locked_lookup(prompt_id: str) -> Optional[dict]: + with self._lock: + return self._envelopes.get(prompt_id) + return inject_envelope(data, locked_lookup) + + def __len__(self) -> int: + with self._lock: + return len(self._envelopes) + + def __contains__(self, prompt_id: str) -> bool: + with self._lock: + return prompt_id in self._envelopes diff --git a/execution.py b/execution.py index 4c7de2e84..b3ec4b170 100644 --- a/execution.py +++ b/execution.py @@ -1296,7 +1296,10 @@ class PromptQueue: def wipe_queue(self): with self.mutex: + dropped_prompt_ids = [item[1] for item in self.queue] self.queue = [] + for prompt_id in dropped_prompt_ids: + self.server.unregister_prompt_metadata(prompt_id) self.server.queue_updated() def delete_queue_item(self, function): @@ -1306,8 +1309,9 @@ class PromptQueue: if len(self.queue) == 1: self.wipe_queue() else: - self.queue.pop(x) + deleted = self.queue.pop(x) heapq.heapify(self.queue) + self.server.unregister_prompt_metadata(deleted[1]) self.server.queue_updated() return True return False diff --git a/main.py b/main.py index a6fdaf43c..3c4b1844f 100644 --- a/main.py +++ b/main.py @@ -318,19 +318,26 @@ def prompt_worker(q, server_instance): extra_data[k] = sensitive[k] asset_seeder.pause() - e.execute(item[2], prompt_id, extra_data, item[4]) + try: + e.execute(item[2], prompt_id, extra_data, item[4]) - need_gc = True + need_gc = True - remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done(item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str='success' if e.success else 'error', - completed=e.success, - messages=e.status_messages), process_item=remove_sensitive) - if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] + q.task_done(item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages), process_item=remove_sensitive) + if server_instance.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + finally: + # Always drop the metadata envelope. If e.execute() raises + # hard before its own error handling kicks in, the + # registered envelope would otherwise leak for the + # lifetime of the process. + server_instance.unregister_prompt_metadata(prompt_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index 44470b904..28a44426e 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager from app.node_replace_manager import NodeReplaceManager +from app.prompt_metadata import PromptMetadataStore from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -250,8 +251,14 @@ class PromptServer(): routes = web.RouteTableDef() self.routes = routes self.last_node_id = None + self.last_prompt_id = None self.client_id = None + # Bounded prompt_id -> envelope store. Populated at submission, + # drained on completion/cancel. Keeps workflow scope (and other + # client-supplied tags) out of the execution layer. + self._prompt_metadata = PromptMetadataStore() + self.on_prompt_handlers = [] @routes.get('/ws') @@ -275,7 +282,12 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + payload: dict = {"node": self.last_node_id} + if self.last_prompt_id is not None: + payload["prompt_id"] = self.last_prompt_id + await self.send( + "executing", self._inject_prompt_metadata(payload), sid + ) # Flag to track if we've received the first message first_message = True @@ -955,6 +967,7 @@ class PromptServer(): if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds + self.register_prompt_metadata(prompt_id, extra_data) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1216,7 +1229,22 @@ class PromptServer(): elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) + def register_prompt_metadata(self, prompt_id: str, extra_data) -> None: + """Record per-prompt metadata for injection into outbound execution + events. Called at submission, before the prompt is queued.""" + self._prompt_metadata.register(prompt_id, extra_data) + + def unregister_prompt_metadata(self, prompt_id: str) -> None: + """Drop the per-prompt metadata envelope. Called after the prompt + has finished executing and its terminal events have been queued, + or when the prompt is cancelled before reaching the worker.""" + self._prompt_metadata.unregister(prompt_id) + + def _inject_prompt_metadata(self, data): + return self._prompt_metadata.inject(data) + def send_sync(self, event, data, sid=None): + data = self._inject_prompt_metadata(data) self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py new file mode 100644 index 000000000..b5241dd5d --- /dev/null +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -0,0 +1,412 @@ +"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``. + +Covers the two pure helpers (``extract_envelope_from_extra_data`` and +``inject_envelope``) and the ``PromptMetadataStore`` integration class +that ``PromptServer`` owns. +""" + +from __future__ import annotations + +import pytest + +from app.prompt_metadata import ( + MAX_ENVELOPE_KEYS, + MAX_ENVELOPE_KEY_LEN, + MAX_ENVELOPE_VALUE_LEN, + PromptMetadataStore, + extract_envelope_from_extra_data, + inject_envelope, +) + + +class TestExtractEnvelopeFromExtraData: + def test_explicit_metadata_dict_is_used_as_is(self): + extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-1", + "trace_id": "t-9", + } + + def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self): + extra_data = { + "metadata": {"workflow_id": "explicit"}, + "extra_pnginfo": {"workflow": {"id": "fallback"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "explicit" + } + + def test_falls_back_to_extra_pnginfo_workflow_id(self): + extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + def test_returns_none_when_no_metadata_and_no_workflow_id(self): + assert extract_envelope_from_extra_data({}) is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}}) + is None + ) + + @pytest.mark.parametrize("bad", ["", 123, None, [], {}]) + def test_rejects_non_string_or_empty_workflow_id(self, bad): + extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} + assert extract_envelope_from_extra_data(extra_data) is None + + def test_rejects_non_dict_inputs_at_each_level(self): + assert extract_envelope_from_extra_data(None) is None + assert extract_envelope_from_extra_data("not-a-dict") is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"}) + is None + ) + assert ( + extract_envelope_from_extra_data( + {"extra_pnginfo": {"workflow": "not-a-dict"}} + ) + is None + ) + + def test_empty_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": {}, + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + def test_returned_envelope_is_copy_not_reference(self): + original = {"workflow_id": "wf-1"} + result = extract_envelope_from_extra_data({"metadata": original}) + assert result is not None + result["new_key"] = "x" + assert "new_key" not in original + + def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": "not-a-dict", + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + +class TestEnvelopeSanitization: + """The wire contract is ``dict[str, str]`` with bounded size. A bad + envelope is dropped (and a warning is logged) rather than truncated, + so the boundary stays strict.""" + + def test_rejects_too_many_keys(self, caplog): + envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)} + with caplog.at_level("WARNING"): + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + assert any("exceeds limit" in r.message for r in caplog.records) + + def test_accepts_max_keys_exactly(self): + envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)} + assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope + + def test_rejects_non_string_keys(self, caplog): + with caplog.at_level("WARNING"): + assert ( + extract_envelope_from_extra_data({"metadata": {42: "v"}}) + is None + ) + assert any("non-string" in r.message for r in caplog.records) + + def test_rejects_non_string_values(self, caplog): + for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]: + with caplog.at_level("WARNING"): + assert ( + extract_envelope_from_extra_data( + {"metadata": {"k": bad_value}} + ) + is None + ) + + def test_rejects_oversized_key(self): + envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"} + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + + def test_rejects_oversized_value(self): + envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)} + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + + def test_accepts_max_lengths_exactly(self): + envelope = { + "x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN + } + assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope + + def test_oversized_workflow_id_in_pnginfo_rejected(self): + """The legacy synthesized path also respects the value bound.""" + extra_data = { + "extra_pnginfo": { + "workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)} + } + } + assert extract_envelope_from_extra_data(extra_data) is None + + def test_invalid_explicit_metadata_does_not_fall_through(self): + """An explicit but invalid metadata dict means the caller asked + for something specific and got it wrong; the synthesized + fallback must not silently substitute.""" + extra_data = { + "metadata": {"k": 42}, # non-string value + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) is None + + +class TestInjectEnvelope: + @staticmethod + def _lookup(table): + return table.get + + def test_spreads_envelope_keys_onto_payload(self): + """Envelope keys are merged at the top level so consumers can + read them directly (e.g. ``event.workflow_id``).""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}}) + assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == { + "node": "5", + "prompt_id": "p1", + "workflow_id": "wf-1", + "trace_id": "t-9", + } + + def test_passthrough_when_prompt_id_not_registered(self): + lookup = self._lookup({}) + data = {"node": "5", "prompt_id": "unknown"} + assert inject_envelope(data, lookup) == data + + def test_passthrough_when_payload_lacks_prompt_id(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + data = {"status": "ok"} + assert inject_envelope(data, lookup) == data + + def test_server_keys_win_on_collision_with_envelope(self): + """A misbehaving client cannot shadow server-emitted fields by + stamping the same key in their submission envelope.""" + lookup = self._lookup({ + "p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"} + }) + result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup) + assert result["prompt_id"] == "p1" + assert result["node"] == "5" + assert result["workflow_id"] == "wf-1" + + def test_does_not_mutate_input_dict(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + original = {"node": "5", "prompt_id": "p1"} + inject_envelope(original, lookup) + assert "workflow_id" not in original + + def test_does_not_mutate_envelope_dict(self): + envelope = {"workflow_id": "wf-1"} + lookup = self._lookup({"p1": envelope}) + inject_envelope({"prompt_id": "p1", "node": "5"}, lookup) + assert envelope == {"workflow_id": "wf-1"} + + def test_injects_into_inner_dict_of_preview_metadata_tuple(self): + """``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as + ``(preview_image, metadata_dict)``; the inner dict is the only + place the envelope can attach.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "p1"} + result = inject_envelope((preview_image, inner), lookup) + assert isinstance(result, tuple) + assert result[0] is preview_image + assert result[1] == { + "node_id": "5", + "prompt_id": "p1", + "workflow_id": "wf-1", + } + assert "workflow_id" not in inner + + def test_preview_tuple_passthrough_when_no_envelope_registered(self): + lookup = self._lookup({}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "unknown"} + result = inject_envelope((preview_image, inner), lookup) + assert result == (preview_image, inner) + + @pytest.mark.parametrize("payload", [b"raw-bytes", None, 42]) + def test_non_dict_non_tuple_payloads_passthrough(self, payload): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + assert inject_envelope(payload, lookup) == payload + + def test_tuple_of_wrong_arity_passthrough(self): + """Only the 2-tuple ``(preview, metadata_dict)`` shape is + special-cased. Other tuples must not be touched.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + triple = (1, {"prompt_id": "p1"}, 3) + assert inject_envelope(triple, lookup) is triple + + def test_envelope_lookup_called_per_invocation(self): + """The lookup runs each time the function is called, so changes + to the backing store are immediately visible.""" + store = {"p1": {"workflow_id": "wf-1"}} + first = inject_envelope({"prompt_id": "p1"}, store.get) + store["p1"] = {"workflow_id": "wf-2"} + second = inject_envelope({"prompt_id": "p1"}, store.get) + del store["p1"] + third = inject_envelope({"prompt_id": "p1"}, store.get) + assert first["workflow_id"] == "wf-1" + assert second["workflow_id"] == "wf-2" + assert "workflow_id" not in third + + +class TestPromptMetadataStore: + """End-to-end wiring tests that exercise the full register/inject/ + unregister cycle the way ``PromptServer`` does.""" + + def test_register_inject_unregister_cycle(self): + store = PromptMetadataStore() + store.register( + "p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}} + ) + injected = store.inject({"node": "5", "prompt_id": "p1"}) + assert injected == { + "node": "5", + "prompt_id": "p1", + "workflow_id": "wf-1", + } + store.unregister("p1") + passthrough = store.inject({"node": "5", "prompt_id": "p1"}) + assert "workflow_id" not in passthrough + + def test_register_with_no_derivable_envelope_is_noop(self): + store = PromptMetadataStore() + store.register("p1", {}) + assert "p1" not in store + data = {"prompt_id": "p1"} + assert store.inject(data) == data + + def test_register_with_oversized_envelope_is_noop(self): + """Sanitization rejection means nothing is registered — the + store stays empty and inject is a passthrough.""" + store = PromptMetadataStore() + store.register( + "p1", + {"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}}, + ) + assert "p1" not in store + + def test_unregister_unknown_prompt_is_silent(self): + store = PromptMetadataStore() + store.unregister("does-not-exist") + + def test_fifo_eviction_when_capacity_exceeded(self): + """If cleanup hooks are ever bypassed, the store must shed the + oldest entry rather than grow without bound.""" + store = PromptMetadataStore(capacity=3) + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + store.register("p2", {"metadata": {"workflow_id": "wf-2"}}) + store.register("p3", {"metadata": {"workflow_id": "wf-3"}}) + assert len(store) == 3 + + store.register("p4", {"metadata": {"workflow_id": "wf-4"}}) + assert len(store) == 3 + assert "p1" not in store + assert "p4" in store + + # The newer entries are still injectable. + assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4" + # The evicted one is gone. + assert "workflow_id" not in store.inject({"prompt_id": "p1"}) + + def test_register_after_unregister_does_not_count_against_capacity(self): + """Normal lifecycle: register, unregister, register many — the + store should not silently evict valid entries because of stale + accounting.""" + store = PromptMetadataStore(capacity=2) + for i in range(10): + store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}}) + store.unregister(f"p{i}") + assert len(store) == 0 + + def test_re_register_overwrites(self): + store = PromptMetadataStore() + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + store.register("p1", {"metadata": {"workflow_id": "wf-2"}}) + assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2" + + def test_inject_with_no_registrations_is_passthrough(self): + store = PromptMetadataStore() + data = {"prompt_id": "p1", "node": "5"} + assert store.inject(data) == data + + def test_inject_into_preview_tuple(self): + store = PromptMetadataStore() + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + result = store.inject((b"image-bytes", {"prompt_id": "p1"})) + assert result == (b"image-bytes", { + "prompt_id": "p1", + "workflow_id": "wf-1", + }) + + def test_concurrent_access_does_not_corrupt_or_raise(self): + """Smoke test for the store's lock. ``register`` is called from + the aiohttp event-loop thread, ``unregister`` from the worker + thread, and ``inject`` fires on every ``send_sync`` from + whichever thread emits the event. Run all three concurrently + and assert no exception escapes and the store stays internally + consistent (the FIFO cap is never exceeded).""" + import threading + + store = PromptMetadataStore(capacity=64) + stop = threading.Event() + errors: list[BaseException] = [] + + def registrar(): + i = 0 + try: + while not stop.is_set(): + store.register( + f"p{i % 100}", + {"metadata": {"workflow_id": f"wf-{i}"}}, + ) + i += 1 + except BaseException as e: + errors.append(e) + + def canceller(): + i = 0 + try: + while not stop.is_set(): + store.unregister(f"p{i % 100}") + i += 1 + except BaseException as e: + errors.append(e) + + def injector(): + i = 0 + try: + while not stop.is_set(): + store.inject({"prompt_id": f"p{i % 100}", "node": "5"}) + i += 1 + except BaseException as e: + errors.append(e) + + threads = [ + threading.Thread(target=registrar), + threading.Thread(target=registrar), + threading.Thread(target=canceller), + threading.Thread(target=injector), + threading.Thread(target=injector), + ] + for t in threads: + t.start() + # Brief burst — long enough to interleave many ops, short enough + # not to slow CI. + threading.Event().wait(0.1) + stop.set() + for t in threads: + t.join(timeout=2.0) + + assert errors == [], f"concurrent access raised: {errors[:3]}" + assert len(store) <= 64, "FIFO cap was breached under contention"