From 74cfcaa3181cbe82d5f160fe6f8d62e354810657 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 14 May 2026 20:47:00 -0700 Subject: [PATCH 1/4] feat(server): per-prompt metadata envelope on websocket events Replaces the workflow_id-on-every-event approach (#13684, reverted in #13901) with a generic metadata envelope captured at submission and injected at the server-side send chokepoint. - POST /prompt accepts an opaque ``extra_data.metadata`` dict (falls back to synthesizing ``{"workflow_id": ...}`` from ``extra_pnginfo.workflow.id`` so existing frontends keep working). - ``PromptServer`` owns a ``prompt_id -> metadata`` map populated at submission, drained when the prompt finishes. ``send_sync`` injects the envelope into any outbound payload that carries a ``prompt_id``, including the ``(preview_image, metadata_dict)`` tuple used by ``PREVIEW_IMAGE_WITH_METADATA``. WS reconnect path carries it too. - Pure helpers live in ``app/prompt_metadata.py`` so the execution layer never depends on workflow concepts and the helpers can be unit-tested without torch. Execution layer (``execution.py``, ``comfy_execution/*``) and the jobs API are unchanged. Backward compatible: existing fields and shapes are preserved, only an additional ``metadata`` field is attached when present. --- app/prompt_metadata.py | 96 +++++++++++ main.py | 2 + server.py | 32 +++- tests-unit/app_test/test_prompt_metadata.py | 181 ++++++++++++++++++++ 4 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 app/prompt_metadata.py create mode 100644 tests-unit/app_test/test_prompt_metadata.py diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py new file mode 100644 index 000000000..43e387769 --- /dev/null +++ b/app/prompt_metadata.py @@ -0,0 +1,96 @@ +"""Per-prompt metadata envelope shared between submission and outbound events. + +The metadata envelope is an opaque dict (e.g. 
``{"workflow_id": ...}``) +attached to a prompt at submission and injected by the server into every +outbound execution event that carries a ``prompt_id``. It lets consumers +scope state by tags they care about (workflow, trace, tenant) without the +execution layer ever needing to know those tags exist. + +Two pure functions live here; ``PromptServer`` owns the per-prompt map and +wires them into the submission and send paths. +""" + +from __future__ import annotations + +from typing import Any, Callable, Optional + + +def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]: + """Pull the per-prompt metadata envelope out of a submitted prompt's + ``extra_data``. + + Two sources, in order: + + 1. Explicit ``extra_data["metadata"]`` dict — preferred path, accepted + as-is (copied so later mutations on the caller's dict don't leak). + 2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward- + compatibility fallback. Frontends that already stamp the workflow + id into ``extra_pnginfo`` keep working without changes; the + synthesized envelope is ``{"workflow_id": }``. + + Returns ``None`` when neither source yields a usable envelope. + """ + if not isinstance(extra_data, dict): + return None + + metadata = extra_data.get("metadata") + if isinstance(metadata, dict) and metadata: + return dict(metadata) + + extra_pnginfo = extra_data.get("extra_pnginfo") + if isinstance(extra_pnginfo, dict): + workflow = extra_pnginfo.get("workflow") + if isinstance(workflow, dict): + workflow_id = workflow.get("id") + if isinstance(workflow_id, str) and workflow_id: + return {"workflow_id": workflow_id} + + return None + + +def inject_envelope( + data: Any, + envelope_lookup: Callable[[str], Optional[dict]], +) -> Any: + """Return ``data`` with a per-prompt ``metadata`` envelope attached. + + ``envelope_lookup`` is called with the payload's ``prompt_id`` and is + expected to return the registered envelope or ``None``. 
This indirection + keeps the function pure and avoids depending on any specific storage. + + Two payload shapes are handled: + + - **dict** carrying ``prompt_id``. A shallow copy is returned with a + ``metadata`` key set to the envelope. The caller's dict is not + mutated. + - **(preview_image, metadata_dict) tuple** — the format used by + ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; + the binary preview is passed through by reference. + + The function is a no-op for: + + - payloads without a ``prompt_id``, + - payloads already declaring their own ``metadata`` field + (callers can opt out by setting it explicitly), + - prompts with no registered envelope, + - any other payload shape (raw bytes, ``None``, etc.). + """ + def inject(d: dict) -> dict: + if not isinstance(d, dict) or "metadata" in d: + return d + prompt_id = d.get("prompt_id") + if not prompt_id: + return d + envelope = envelope_lookup(prompt_id) + if envelope is None: + return d + return {**d, "metadata": envelope} + + if isinstance(data, dict): + return inject(data) + if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict): + injected = inject(data[1]) + if injected is data[1]: + return data + return (data[0], injected) + return data diff --git a/main.py b/main.py index a6fdaf43c..f0293ae0a 100644 --- a/main.py +++ b/main.py @@ -332,6 +332,8 @@ def prompt_worker(q, server_instance): if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + server_instance.unregister_prompt_metadata(prompt_id) + current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index 44470b904..b4876c8fa 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,10 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager from 
app.node_replace_manager import NodeReplaceManager +from app.prompt_metadata import ( + extract_envelope_from_extra_data, + inject_envelope, +) from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -250,8 +254,14 @@ class PromptServer(): routes = web.RouteTableDef() self.routes = routes self.last_node_id = None + self.last_prompt_id = None self.client_id = None + # prompt_id -> metadata envelope captured at submission and injected + # into outbound execution events. Keeps the workflow scope (and any + # other client-supplied tags) out of the execution layer. + self._prompt_metadata: dict[str, dict] = {} + self.on_prompt_handlers = [] @routes.get('/ws') @@ -275,7 +285,10 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + await self.send("executing", self._inject_prompt_metadata({ + "node": self.last_node_id, + "prompt_id": self.last_prompt_id, + }), sid) # Flag to track if we've received the first message first_message = True @@ -955,6 +968,7 @@ class PromptServer(): if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds + self.register_prompt_metadata(prompt_id, extra_data) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1216,7 +1230,23 @@ class PromptServer(): elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) + def register_prompt_metadata(self, prompt_id: str, extra_data) -> None: + """Record 
per-prompt metadata for injection into outbound execution + events. Called at submission, before the prompt is queued.""" + envelope = extract_envelope_from_extra_data(extra_data) + if envelope is not None: + self._prompt_metadata[prompt_id] = envelope + + def unregister_prompt_metadata(self, prompt_id: str) -> None: + """Drop the per-prompt metadata envelope. Called after the prompt + has finished executing and its terminal events have been queued.""" + self._prompt_metadata.pop(prompt_id, None) + + def _inject_prompt_metadata(self, data): + return inject_envelope(data, self._prompt_metadata.get) + def send_sync(self, event, data, sid=None): + data = self._inject_prompt_metadata(data) self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py new file mode 100644 index 000000000..184ce22e5 --- /dev/null +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -0,0 +1,181 @@ +"""Unit tests for the pure metadata-envelope helpers in +``app.prompt_metadata``. These cover the two functions that PromptServer +wires into submission (``extract_envelope_from_extra_data``) and into the +send chokepoint (``inject_envelope``). 
+""" + +from __future__ import annotations + +from app.prompt_metadata import ( + extract_envelope_from_extra_data, + inject_envelope, +) + + +class TestExtractEnvelopeFromExtraData: + def test_explicit_metadata_dict_is_used_as_is(self): + extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-1", + "trace_id": "t-9", + } + + def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self): + extra_data = { + "metadata": {"workflow_id": "explicit"}, + "extra_pnginfo": {"workflow": {"id": "fallback"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "explicit" + } + + def test_falls_back_to_extra_pnginfo_workflow_id(self): + extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + def test_returns_none_when_no_metadata_and_no_workflow_id(self): + assert extract_envelope_from_extra_data({}) is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}}) + is None + ) + + def test_rejects_non_string_or_empty_workflow_id(self): + for bad in ["", 123, None, [], {}]: + extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} + assert extract_envelope_from_extra_data(extra_data) is None + + def test_rejects_non_dict_inputs_at_each_level(self): + assert extract_envelope_from_extra_data(None) is None + assert extract_envelope_from_extra_data("not-a-dict") is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"}) + is None + ) + assert ( + extract_envelope_from_extra_data( + {"extra_pnginfo": {"workflow": "not-a-dict"}} + ) + is None + ) + + def test_empty_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": {}, + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" 
+ } + + def test_returned_envelope_is_copy_not_reference(self): + original = {"workflow_id": "wf-1"} + result = extract_envelope_from_extra_data({"metadata": original}) + result["new_key"] = "x" + assert "new_key" not in original + + def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": "not-a-dict", + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + +class TestInjectEnvelope: + @staticmethod + def _lookup(table): + """Build an envelope_lookup callable backed by a dict.""" + return table.get + + def test_injects_envelope_on_dict_with_known_prompt_id(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == { + "node": "5", + "prompt_id": "p1", + "metadata": {"workflow_id": "wf-1"}, + } + + def test_passthrough_when_prompt_id_not_registered(self): + lookup = self._lookup({}) + data = {"node": "5", "prompt_id": "unknown"} + assert inject_envelope(data, lookup) == data + + def test_passthrough_when_payload_lacks_prompt_id(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + data = {"status": "ok"} + assert inject_envelope(data, lookup) is data + + def test_passthrough_when_payload_already_has_metadata(self): + """If a caller has already set a ``metadata`` field (e.g. 
for + opt-out or pre-augmented payloads), the function must not + overwrite it.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) + data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}} + result = inject_envelope(data, lookup) + assert result is data + assert result["metadata"] == {"workflow_id": "wf-caller"} + + def test_does_not_mutate_input_dict(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + original = {"node": "5", "prompt_id": "p1"} + inject_envelope(original, lookup) + assert "metadata" not in original + + def test_injects_into_inner_dict_of_preview_metadata_tuple(self): + """``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as + ``(preview_image, metadata_dict)``; the inner dict is the only + place the envelope can attach.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "p1"} + result = inject_envelope((preview_image, inner), lookup) + assert isinstance(result, tuple) + assert result[0] is preview_image + assert result[1] == { + "node_id": "5", + "prompt_id": "p1", + "metadata": {"workflow_id": "wf-1"}, + } + assert "metadata" not in inner + + def test_preview_tuple_passthrough_when_no_envelope_registered(self): + lookup = self._lookup({}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "unknown"} + result = inject_envelope((preview_image, inner), lookup) + assert result == (preview_image, inner) + + def test_preview_tuple_passthrough_when_inner_already_has_metadata(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}} + result = inject_envelope((preview_image, inner), lookup) + assert result == (preview_image, inner) + + def test_non_dict_non_tuple_payloads_passthrough(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + assert 
inject_envelope(b"raw-bytes", lookup) == b"raw-bytes" + assert inject_envelope(None, lookup) is None + assert inject_envelope(42, lookup) == 42 + + def test_tuple_of_wrong_arity_passthrough(self): + """Only the 2-tuple ``(preview, metadata_dict)`` shape is special- + cased. Other tuples must not be touched.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + triple = (1, {"prompt_id": "p1"}, 3) + assert inject_envelope(triple, lookup) is triple + + def test_envelope_lookup_called_at_send_time(self): + """The lookup runs each time the function is called, so a producer + and consumer that share a backing dict observe the current value.""" + store = {"p1": {"workflow_id": "wf-1"}} + first = inject_envelope({"prompt_id": "p1"}, store.get) + store["p1"] = {"workflow_id": "wf-2"} + second = inject_envelope({"prompt_id": "p1"}, store.get) + assert first["metadata"] == {"workflow_id": "wf-1"} + assert second["metadata"] == {"workflow_id": "wf-2"} From fd89498eaceee735bd10b11b03997632815e44a4 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 14 May 2026 21:03:38 -0700 Subject: [PATCH 2/4] fix(server): bound metadata envelope and clean up on cancel paths Addresses review feedback on the per-prompt metadata envelope: - Sanitize at the boundary: reject envelopes larger than 16 keys, keys over 64 chars, values over 256 chars, or anything that isn't a flat ``dict[str, str]``. Logs a warning so abuse is observable. Stops a malicious client from inflating broadcast volume by stamping a 10 MB metadata blob onto every WS event. - Cap the in-memory store at 4096 concurrent envelopes with FIFO eviction. Acts as a backstop if any cleanup hook is skipped. - Drop envelopes when prompts are cancelled before reaching the worker: ``PromptQueue.wipe_queue`` and ``delete_queue_item`` now call ``server.unregister_prompt_metadata`` for every removed item. 
- Drop envelopes on hard execution failures: the worker now wraps ``e.execute()`` in ``try/finally``, so an uncaught exception in execution no longer leaks the envelope. - Guard the WS reconnect handler: only include ``prompt_id`` in the ``executing`` payload when ``last_prompt_id`` is set, so clients with strict schemas (zod ``prompt_id: zJobId``) don't reject the message with a null id. - Extract a ``PromptMetadataStore`` class that owns the dict and the bounds, so ``PromptServer`` becomes a thin delegating layer and the full register/inject/unregister cycle (plus FIFO eviction and sanitization) is unit-tested without torch. 44 tests passing; ruff clean on all touched files. --- app/prompt_metadata.py | 158 +++++++++++--- execution.py | 6 +- main.py | 31 +-- server.py | 34 ++- tests-unit/app_test/test_prompt_metadata.py | 216 +++++++++++++++++--- 5 files changed, 362 insertions(+), 83 deletions(-) diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py index 43e387769..361566249 100644 --- a/app/prompt_metadata.py +++ b/app/prompt_metadata.py @@ -1,48 +1,121 @@ """Per-prompt metadata envelope shared between submission and outbound events. -The metadata envelope is an opaque dict (e.g. ``{"workflow_id": ...}``) -attached to a prompt at submission and injected by the server into every -outbound execution event that carries a ``prompt_id``. It lets consumers -scope state by tags they care about (workflow, trace, tenant) without the -execution layer ever needing to know those tags exist. +The metadata envelope is a small flat ``dict[str, str]`` (e.g. +``{"workflow_id": ...}``) attached to a prompt at submission and injected +by the server into every outbound execution event that carries a +``prompt_id``. It lets consumers scope state by tags they care about +(workflow, trace, tenant) without the execution layer ever needing to +know those tags exist. 
-Two pure functions live here; ``PromptServer`` owns the per-prompt map and -wires them into the submission and send paths. +This module is intentionally pure — no imports from ``server`` or +``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore`` +instance and the helpers can be unit-tested without the rest of the app. """ from __future__ import annotations +import logging from typing import Any, Callable, Optional +# Bounds. The envelope is forwarded to every WebSocket client connected to +# the server on every execution event for the prompt — bounding key count, +# key length, value length, and refusing nested structures keeps a +# malicious or buggy client from inflating the broadcast volume. +MAX_ENVELOPE_KEYS = 16 +MAX_ENVELOPE_KEY_LEN = 64 +MAX_ENVELOPE_VALUE_LEN = 256 + +# Cap on concurrently registered prompt envelopes. Acts as a backstop if +# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale +# entry goes first. +DEFAULT_STORE_CAPACITY = 4096 + + +def _sanitize_envelope(envelope: Any) -> Optional[dict]: + """Validate and copy a candidate envelope. + + Enforces the ``dict[str, str]`` contract that downstream consumers + (cloud projections, frontend zod schemas, OpenAPI docs) rely on: + + - must be a non-empty ``dict`` + - at most ``MAX_ENVELOPE_KEYS`` entries + - every key and value must be a ``str`` + - keys at most ``MAX_ENVELOPE_KEY_LEN`` chars + - values at most ``MAX_ENVELOPE_VALUE_LEN`` chars + + Returns a defensive shallow copy on success, ``None`` on any + violation. Logs a warning on violation so abuse is visible. 
+ """ + if not isinstance(envelope, dict) or not envelope: + return None + if len(envelope) > MAX_ENVELOPE_KEYS: + logging.warning( + "prompt metadata envelope rejected: %d keys exceeds limit %d", + len(envelope), MAX_ENVELOPE_KEYS, + ) + return None + sanitized: dict[str, str] = {} + for key, value in envelope.items(): + if not isinstance(key, str) or not isinstance(value, str): + logging.warning( + "prompt metadata envelope rejected: non-string key/value (%s=%s)", + type(key).__name__, type(value).__name__, + ) + return None + if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN: + logging.warning( + "prompt metadata envelope rejected: key or value exceeds length limit", + ) + return None + sanitized[key] = value + return sanitized + + def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]: """Pull the per-prompt metadata envelope out of a submitted prompt's ``extra_data``. Two sources, in order: - 1. Explicit ``extra_data["metadata"]`` dict — preferred path, accepted - as-is (copied so later mutations on the caller's dict don't leak). + 1. Explicit ``extra_data["metadata"]`` — sanitized via + ``_sanitize_envelope``. Oversized or wrong-typed envelopes are + rejected (a warning is logged) rather than truncated, so the + contract stays strict at the boundary. 2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward- compatibility fallback. Frontends that already stamp the workflow - id into ``extra_pnginfo`` keep working without changes; the - synthesized envelope is ``{"workflow_id": }``. + id into ``extra_pnginfo`` keep working; the synthesized envelope + is ``{"workflow_id": }``. A debug log fires so the legacy path + remains observable. Returns ``None`` when neither source yields a usable envelope. 
""" if not isinstance(extra_data, dict): return None - metadata = extra_data.get("metadata") - if isinstance(metadata, dict) and metadata: - return dict(metadata) + if "metadata" in extra_data: + sanitized = _sanitize_envelope(extra_data["metadata"]) + if sanitized is not None: + return sanitized + # Explicit metadata was supplied but rejected — do not fall + # through to the legacy path; the caller asked for something + # specific and got it wrong. + if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]: + return None extra_pnginfo = extra_data.get("extra_pnginfo") if isinstance(extra_pnginfo, dict): workflow = extra_pnginfo.get("workflow") if isinstance(workflow, dict): workflow_id = workflow.get("id") - if isinstance(workflow_id, str) and workflow_id: + if ( + isinstance(workflow_id, str) + and workflow_id + and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN + ): + logging.debug( + "prompt metadata envelope synthesized from extra_pnginfo.workflow.id" + ) return {"workflow_id": workflow_id} return None @@ -55,28 +128,25 @@ def inject_envelope( """Return ``data`` with a per-prompt ``metadata`` envelope attached. ``envelope_lookup`` is called with the payload's ``prompt_id`` and is - expected to return the registered envelope or ``None``. This indirection - keeps the function pure and avoids depending on any specific storage. + expected to return the registered envelope or ``None``. This keeps + the function pure and avoids depending on any specific storage. Two payload shapes are handled: - **dict** carrying ``prompt_id``. A shallow copy is returned with a - ``metadata`` key set to the envelope. The caller's dict is not - mutated. + ``metadata`` key set to the envelope. - **(preview_image, metadata_dict) tuple** — the format used by ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; the binary preview is passed through by reference. 
- The function is a no-op for: - - - payloads without a ``prompt_id``, - - payloads already declaring their own ``metadata`` field - (callers can opt out by setting it explicitly), - - prompts with no registered envelope, - - any other payload shape (raw bytes, ``None``, etc.). + No-op for payloads without a ``prompt_id``, payloads already + declaring their own ``metadata`` field, prompts with no registered + envelope, or any other payload shape. """ def inject(d: dict) -> dict: - if not isinstance(d, dict) or "metadata" in d: + if not isinstance(d, dict): + return d + if "metadata" in d: return d prompt_id = d.get("prompt_id") if not prompt_id: @@ -94,3 +164,37 @@ def inject_envelope( return data return (data[0], injected) return data + + +class PromptMetadataStore: + """Bounded ``prompt_id -> envelope`` map. + + Owned by ``PromptServer``. Populated at submission, drained when the + prompt finishes, wiped on queue cancel/delete. The FIFO cap is a + backstop: if any cleanup hook is ever skipped, the store sheds the + oldest entry instead of growing without bound. 
+ """ + + def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY): + self._envelopes: dict[str, dict] = {} + self._capacity = capacity + + def register(self, prompt_id: str, extra_data: Any) -> None: + envelope = extract_envelope_from_extra_data(extra_data) + if envelope is None: + return + if len(self._envelopes) >= self._capacity: + self._envelopes.pop(next(iter(self._envelopes))) + self._envelopes[prompt_id] = envelope + + def unregister(self, prompt_id: str) -> None: + self._envelopes.pop(prompt_id, None) + + def inject(self, data: Any) -> Any: + return inject_envelope(data, self._envelopes.get) + + def __len__(self) -> int: + return len(self._envelopes) + + def __contains__(self, prompt_id: str) -> bool: + return prompt_id in self._envelopes diff --git a/execution.py b/execution.py index f37d0360d..970a40010 100644 --- a/execution.py +++ b/execution.py @@ -1296,7 +1296,10 @@ class PromptQueue: def wipe_queue(self): with self.mutex: + dropped_prompt_ids = [item[1] for item in self.queue] self.queue = [] + for prompt_id in dropped_prompt_ids: + self.server.unregister_prompt_metadata(prompt_id) self.server.queue_updated() def delete_queue_item(self, function): @@ -1306,8 +1309,9 @@ class PromptQueue: if len(self.queue) == 1: self.wipe_queue() else: - self.queue.pop(x) + deleted = self.queue.pop(x) heapq.heapify(self.queue) + self.server.unregister_prompt_metadata(deleted[1]) self.server.queue_updated() return True return False diff --git a/main.py b/main.py index f0293ae0a..3c4b1844f 100644 --- a/main.py +++ b/main.py @@ -318,21 +318,26 @@ def prompt_worker(q, server_instance): extra_data[k] = sensitive[k] asset_seeder.pause() - e.execute(item[2], prompt_id, extra_data, item[4]) + try: + e.execute(item[2], prompt_id, extra_data, item[4]) - need_gc = True + need_gc = True - remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done(item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str='success' if e.success else 
'error', - completed=e.success, - messages=e.status_messages), process_item=remove_sensitive) - if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) - - server_instance.unregister_prompt_metadata(prompt_id) + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] + q.task_done(item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages), process_item=remove_sensitive) + if server_instance.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + finally: + # Always drop the metadata envelope. If e.execute() raises + # hard before its own error handling kicks in, the + # registered envelope would otherwise leak for the + # lifetime of the process. + server_instance.unregister_prompt_metadata(prompt_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index b4876c8fa..28a44426e 100644 --- a/server.py +++ b/server.py @@ -44,10 +44,7 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager from app.node_replace_manager import NodeReplaceManager -from app.prompt_metadata import ( - extract_envelope_from_extra_data, - inject_envelope, -) +from app.prompt_metadata import PromptMetadataStore from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -257,10 +254,10 @@ class PromptServer(): self.last_prompt_id = None self.client_id = None - # prompt_id -> metadata envelope captured at submission and injected - # into outbound execution events. 
Keeps the workflow scope (and any - # other client-supplied tags) out of the execution layer. - self._prompt_metadata: dict[str, dict] = {} + # Bounded prompt_id -> envelope store. Populated at submission, + # drained on completion/cancel. Keeps workflow scope (and other + # client-supplied tags) out of the execution layer. + self._prompt_metadata = PromptMetadataStore() self.on_prompt_handlers = [] @@ -285,10 +282,12 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", self._inject_prompt_metadata({ - "node": self.last_node_id, - "prompt_id": self.last_prompt_id, - }), sid) + payload: dict = {"node": self.last_node_id} + if self.last_prompt_id is not None: + payload["prompt_id"] = self.last_prompt_id + await self.send( + "executing", self._inject_prompt_metadata(payload), sid + ) # Flag to track if we've received the first message first_message = True @@ -1233,17 +1232,16 @@ class PromptServer(): def register_prompt_metadata(self, prompt_id: str, extra_data) -> None: """Record per-prompt metadata for injection into outbound execution events. Called at submission, before the prompt is queued.""" - envelope = extract_envelope_from_extra_data(extra_data) - if envelope is not None: - self._prompt_metadata[prompt_id] = envelope + self._prompt_metadata.register(prompt_id, extra_data) def unregister_prompt_metadata(self, prompt_id: str) -> None: """Drop the per-prompt metadata envelope. 
Called after the prompt - has finished executing and its terminal events have been queued.""" - self._prompt_metadata.pop(prompt_id, None) + has finished executing and its terminal events have been queued, + or when the prompt is cancelled before reaching the worker.""" + self._prompt_metadata.unregister(prompt_id) def _inject_prompt_metadata(self, data): - return inject_envelope(data, self._prompt_metadata.get) + return self._prompt_metadata.inject(data) def send_sync(self, event, data, sid=None): data = self._inject_prompt_metadata(data) diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py index 184ce22e5..e0d2f9749 100644 --- a/tests-unit/app_test/test_prompt_metadata.py +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -1,12 +1,19 @@ -"""Unit tests for the pure metadata-envelope helpers in -``app.prompt_metadata``. These cover the two functions that PromptServer -wires into submission (``extract_envelope_from_extra_data``) and into the -send chokepoint (``inject_envelope``). +"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``. + +Covers the two pure helpers (``extract_envelope_from_extra_data`` and +``inject_envelope``) and the ``PromptMetadataStore`` integration class +that ``PromptServer`` owns. 
""" from __future__ import annotations +import pytest + from app.prompt_metadata import ( + MAX_ENVELOPE_KEYS, + MAX_ENVELOPE_KEY_LEN, + MAX_ENVELOPE_VALUE_LEN, + PromptMetadataStore, extract_envelope_from_extra_data, inject_envelope, ) @@ -42,10 +49,10 @@ class TestExtractEnvelopeFromExtraData: is None ) - def test_rejects_non_string_or_empty_workflow_id(self): - for bad in ["", 123, None, [], {}]: - extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} - assert extract_envelope_from_extra_data(extra_data) is None + @pytest.mark.parametrize("bad", ["", 123, None, [], {}]) + def test_rejects_non_string_or_empty_workflow_id(self, bad): + extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} + assert extract_envelope_from_extra_data(extra_data) is None def test_rejects_non_dict_inputs_at_each_level(self): assert extract_envelope_from_extra_data(None) is None @@ -73,6 +80,7 @@ class TestExtractEnvelopeFromExtraData: def test_returned_envelope_is_copy_not_reference(self): original = {"workflow_id": "wf-1"} result = extract_envelope_from_extra_data({"metadata": original}) + assert result is not None result["new_key"] = "x" assert "new_key" not in original @@ -86,10 +94,76 @@ class TestExtractEnvelopeFromExtraData: } +class TestEnvelopeSanitization: + """The wire contract is ``dict[str, str]`` with bounded size. 
A bad + envelope is dropped (and a warning is logged) rather than truncated, + so the boundary stays strict.""" + + def test_rejects_too_many_keys(self, caplog): + envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)} + with caplog.at_level("WARNING"): + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + assert any("exceeds limit" in r.message for r in caplog.records) + + def test_accepts_max_keys_exactly(self): + envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)} + assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope + + def test_rejects_non_string_keys(self, caplog): + with caplog.at_level("WARNING"): + assert ( + extract_envelope_from_extra_data({"metadata": {42: "v"}}) + is None + ) + assert any("non-string" in r.message for r in caplog.records) + + def test_rejects_non_string_values(self, caplog): + for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]: + with caplog.at_level("WARNING"): + assert ( + extract_envelope_from_extra_data( + {"metadata": {"k": bad_value}} + ) + is None + ) + + def test_rejects_oversized_key(self): + envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"} + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + + def test_rejects_oversized_value(self): + envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)} + assert extract_envelope_from_extra_data({"metadata": envelope}) is None + + def test_accepts_max_lengths_exactly(self): + envelope = { + "x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN + } + assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope + + def test_oversized_workflow_id_in_pnginfo_rejected(self): + """The legacy synthesized path also respects the value bound.""" + extra_data = { + "extra_pnginfo": { + "workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)} + } + } + assert extract_envelope_from_extra_data(extra_data) is None + + def test_invalid_explicit_metadata_does_not_fall_through(self): + """An 
explicit but invalid metadata dict means the caller asked + for something specific and got it wrong; the synthesized + fallback must not silently substitute.""" + extra_data = { + "metadata": {"k": 42}, # non-string value + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) is None + + class TestInjectEnvelope: @staticmethod def _lookup(table): - """Build an envelope_lookup callable backed by a dict.""" return table.get def test_injects_envelope_on_dict_with_known_prompt_id(self): @@ -108,16 +182,15 @@ class TestInjectEnvelope: def test_passthrough_when_payload_lacks_prompt_id(self): lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) data = {"status": "ok"} - assert inject_envelope(data, lookup) is data + assert inject_envelope(data, lookup) == data def test_passthrough_when_payload_already_has_metadata(self): - """If a caller has already set a ``metadata`` field (e.g. for - opt-out or pre-augmented payloads), the function must not - overwrite it.""" + """If a caller has already set a ``metadata`` field, the + function must not overwrite it.""" lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}} result = inject_envelope(data, lookup) - assert result is data + assert result == data assert result["metadata"] == {"workflow_id": "wf-caller"} def test_does_not_mutate_input_dict(self): @@ -153,29 +226,124 @@ class TestInjectEnvelope: def test_preview_tuple_passthrough_when_inner_already_has_metadata(self): lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) preview_image = ("PNG", object(), 256) - inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}} + inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": "1"}} result = inject_envelope((preview_image, inner), lookup) assert result == (preview_image, inner) - def test_non_dict_non_tuple_payloads_passthrough(self): + @pytest.mark.parametrize("payload", 
[b"raw-bytes", None, 42]) + def test_non_dict_non_tuple_payloads_passthrough(self, payload): lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) - assert inject_envelope(b"raw-bytes", lookup) == b"raw-bytes" - assert inject_envelope(None, lookup) is None - assert inject_envelope(42, lookup) == 42 + assert inject_envelope(payload, lookup) == payload def test_tuple_of_wrong_arity_passthrough(self): - """Only the 2-tuple ``(preview, metadata_dict)`` shape is special- - cased. Other tuples must not be touched.""" + """Only the 2-tuple ``(preview, metadata_dict)`` shape is + special-cased. Other tuples must not be touched.""" lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) triple = (1, {"prompt_id": "p1"}, 3) assert inject_envelope(triple, lookup) is triple - def test_envelope_lookup_called_at_send_time(self): - """The lookup runs each time the function is called, so a producer - and consumer that share a backing dict observe the current value.""" + def test_envelope_lookup_called_per_invocation(self): + """The lookup runs each time the function is called, so changes + to the backing store are immediately visible.""" store = {"p1": {"workflow_id": "wf-1"}} first = inject_envelope({"prompt_id": "p1"}, store.get) store["p1"] = {"workflow_id": "wf-2"} second = inject_envelope({"prompt_id": "p1"}, store.get) + del store["p1"] + third = inject_envelope({"prompt_id": "p1"}, store.get) assert first["metadata"] == {"workflow_id": "wf-1"} assert second["metadata"] == {"workflow_id": "wf-2"} + assert "metadata" not in third + + +class TestPromptMetadataStore: + """End-to-end wiring tests that exercise the full register/inject/ + unregister cycle the way ``PromptServer`` does.""" + + def test_register_inject_unregister_cycle(self): + store = PromptMetadataStore() + store.register( + "p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}} + ) + injected = store.inject({"node": "5", "prompt_id": "p1"}) + assert injected == { + "node": "5", + "prompt_id": "p1", + "metadata": 
{"workflow_id": "wf-1"}, + } + store.unregister("p1") + passthrough = store.inject({"node": "5", "prompt_id": "p1"}) + assert "metadata" not in passthrough + + def test_register_with_no_derivable_envelope_is_noop(self): + store = PromptMetadataStore() + store.register("p1", {}) + assert "p1" not in store + assert store.inject({"prompt_id": "p1"}) == {"prompt_id": "p1"} + + def test_register_with_oversized_envelope_is_noop(self): + """Sanitization rejection means nothing is registered — the + store stays empty and inject is a passthrough.""" + store = PromptMetadataStore() + store.register( + "p1", + {"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}}, + ) + assert "p1" not in store + + def test_unregister_unknown_prompt_is_silent(self): + store = PromptMetadataStore() + store.unregister("does-not-exist") + + def test_fifo_eviction_when_capacity_exceeded(self): + """If cleanup hooks are ever bypassed, the store must shed the + oldest entry rather than grow without bound.""" + store = PromptMetadataStore(capacity=3) + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + store.register("p2", {"metadata": {"workflow_id": "wf-2"}}) + store.register("p3", {"metadata": {"workflow_id": "wf-3"}}) + assert len(store) == 3 + + store.register("p4", {"metadata": {"workflow_id": "wf-4"}}) + assert len(store) == 3 + assert "p1" not in store + assert "p4" in store + + # The newer entries are still injectable. + assert store.inject({"prompt_id": "p4"})["metadata"] == { + "workflow_id": "wf-4" + } + # The evicted one is gone. 
+ assert "metadata" not in store.inject({"prompt_id": "p1"}) + + def test_register_after_unregister_does_not_count_against_capacity(self): + """Normal lifecycle: register, unregister, register many — the + store should not silently evict valid entries because of stale + accounting.""" + store = PromptMetadataStore(capacity=2) + for i in range(10): + store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}}) + store.unregister(f"p{i}") + assert len(store) == 0 + + def test_re_register_overwrites(self): + store = PromptMetadataStore() + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + store.register("p1", {"metadata": {"workflow_id": "wf-2"}}) + assert store.inject({"prompt_id": "p1"})["metadata"] == { + "workflow_id": "wf-2" + } + + def test_inject_with_no_registrations_is_passthrough(self): + store = PromptMetadataStore() + data = {"prompt_id": "p1", "node": "5"} + assert store.inject(data) == data + + def test_inject_into_preview_tuple(self): + store = PromptMetadataStore() + store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) + result = store.inject((b"image-bytes", {"prompt_id": "p1"})) + assert result == (b"image-bytes", { + "prompt_id": "p1", + "metadata": {"workflow_id": "wf-1"}, + }) From fc9820ebb914ec6b1ca686f64d805204e5b1644f Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 14 May 2026 21:18:43 -0700 Subject: [PATCH 3/4] refactor(server): spread envelope keys onto payload at top level MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the wire shape from nested ``metadata: {workflow_id: ...}`` to spreading the envelope's keys directly onto each event payload. The contract on the websocket is now identical to the prior workflow-id-on- events work — consumers read ``event.workflow_id`` directly — but the core executor still has no concept of workflow scope; the envelope is captured at submission and decorated at the server transport layer. 
Server-emitted fields always win on collision (``{**envelope, **d}``): a misbehaving client cannot shadow ``prompt_id``, ``node``, etc. by stamping the same key in their submission envelope. --- app/prompt_metadata.py | 22 ++++--- tests-unit/app_test/test_prompt_metadata.py | 71 +++++++++++---------- 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py index 361566249..821a516f2 100644 --- a/app/prompt_metadata.py +++ b/app/prompt_metadata.py @@ -125,36 +125,40 @@ def inject_envelope( data: Any, envelope_lookup: Callable[[str], Optional[dict]], ) -> Any: - """Return ``data`` with a per-prompt ``metadata`` envelope attached. + """Return ``data`` with the per-prompt envelope's keys spread onto it. ``envelope_lookup`` is called with the payload's ``prompt_id`` and is expected to return the registered envelope or ``None``. This keeps the function pure and avoids depending on any specific storage. + The envelope's keys are merged onto the payload at the top level so + consumers can read them directly (e.g. ``event.workflow_id``) — + matching the wire shape of the prior workflow-id-on-events work and + avoiding an extra nesting hop for clients. Server-emitted fields on + the payload always win on collision (``{**envelope, **d}``); a + misbehaving client cannot shadow ``prompt_id``, ``node``, etc. + Two payload shapes are handled: - - **dict** carrying ``prompt_id``. A shallow copy is returned with a - ``metadata`` key set to the envelope. + - **dict** carrying ``prompt_id``. A shallow copy is returned with + the envelope's keys merged onto it. - **(preview_image, metadata_dict) tuple** — the format used by ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; the binary preview is passed through by reference. - No-op for payloads without a ``prompt_id``, payloads already - declaring their own ``metadata`` field, prompts with no registered - envelope, or any other payload shape. 
+ No-op for payloads without a ``prompt_id``, prompts with no + registered envelope, or any other payload shape. """ def inject(d: dict) -> dict: if not isinstance(d, dict): return d - if "metadata" in d: - return d prompt_id = d.get("prompt_id") if not prompt_id: return d envelope = envelope_lookup(prompt_id) if envelope is None: return d - return {**d, "metadata": envelope} + return {**envelope, **d} if isinstance(data, dict): return inject(data) diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py index e0d2f9749..8d88d788e 100644 --- a/tests-unit/app_test/test_prompt_metadata.py +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -166,12 +166,15 @@ class TestInjectEnvelope: def _lookup(table): return table.get - def test_injects_envelope_on_dict_with_known_prompt_id(self): - lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + def test_spreads_envelope_keys_onto_payload(self): + """Envelope keys are merged at the top level so consumers can + read them directly (e.g. 
``event.workflow_id``).""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}}) assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == { "node": "5", "prompt_id": "p1", - "metadata": {"workflow_id": "wf-1"}, + "workflow_id": "wf-1", + "trace_id": "t-9", } def test_passthrough_when_prompt_id_not_registered(self): @@ -184,20 +187,28 @@ class TestInjectEnvelope: data = {"status": "ok"} assert inject_envelope(data, lookup) == data - def test_passthrough_when_payload_already_has_metadata(self): - """If a caller has already set a ``metadata`` field, the - function must not overwrite it.""" - lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) - data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}} - result = inject_envelope(data, lookup) - assert result == data - assert result["metadata"] == {"workflow_id": "wf-caller"} + def test_server_keys_win_on_collision_with_envelope(self): + """A misbehaving client cannot shadow server-emitted fields by + stamping the same key in their submission envelope.""" + lookup = self._lookup({ + "p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"} + }) + result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup) + assert result["prompt_id"] == "p1" + assert result["node"] == "5" + assert result["workflow_id"] == "wf-1" def test_does_not_mutate_input_dict(self): lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) original = {"node": "5", "prompt_id": "p1"} inject_envelope(original, lookup) - assert "metadata" not in original + assert "workflow_id" not in original + + def test_does_not_mutate_envelope_dict(self): + envelope = {"workflow_id": "wf-1"} + lookup = self._lookup({"p1": envelope}) + inject_envelope({"prompt_id": "p1", "node": "5"}, lookup) + assert envelope == {"workflow_id": "wf-1"} def test_injects_into_inner_dict_of_preview_metadata_tuple(self): """``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as @@ -212,9 +223,9 @@ class 
TestInjectEnvelope: assert result[1] == { "node_id": "5", "prompt_id": "p1", - "metadata": {"workflow_id": "wf-1"}, + "workflow_id": "wf-1", } - assert "metadata" not in inner + assert "workflow_id" not in inner def test_preview_tuple_passthrough_when_no_envelope_registered(self): lookup = self._lookup({}) @@ -223,13 +234,6 @@ class TestInjectEnvelope: result = inject_envelope((preview_image, inner), lookup) assert result == (preview_image, inner) - def test_preview_tuple_passthrough_when_inner_already_has_metadata(self): - lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) - preview_image = ("PNG", object(), 256) - inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": "1"}} - result = inject_envelope((preview_image, inner), lookup) - assert result == (preview_image, inner) - @pytest.mark.parametrize("payload", [b"raw-bytes", None, 42]) def test_non_dict_non_tuple_payloads_passthrough(self, payload): lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) @@ -251,9 +255,9 @@ class TestInjectEnvelope: second = inject_envelope({"prompt_id": "p1"}, store.get) del store["p1"] third = inject_envelope({"prompt_id": "p1"}, store.get) - assert first["metadata"] == {"workflow_id": "wf-1"} - assert second["metadata"] == {"workflow_id": "wf-2"} - assert "metadata" not in third + assert first["workflow_id"] == "wf-1" + assert second["workflow_id"] == "wf-2" + assert "workflow_id" not in third class TestPromptMetadataStore: @@ -269,17 +273,18 @@ class TestPromptMetadataStore: assert injected == { "node": "5", "prompt_id": "p1", - "metadata": {"workflow_id": "wf-1"}, + "workflow_id": "wf-1", } store.unregister("p1") passthrough = store.inject({"node": "5", "prompt_id": "p1"}) - assert "metadata" not in passthrough + assert "workflow_id" not in passthrough def test_register_with_no_derivable_envelope_is_noop(self): store = PromptMetadataStore() store.register("p1", {}) assert "p1" not in store - assert store.inject({"prompt_id": "p1"}) == {"prompt_id": "p1"} + 
data = {"prompt_id": "p1"} + assert store.inject(data) == data def test_register_with_oversized_envelope_is_noop(self): """Sanitization rejection means nothing is registered — the @@ -310,11 +315,9 @@ class TestPromptMetadataStore: assert "p4" in store # The newer entries are still injectable. - assert store.inject({"prompt_id": "p4"})["metadata"] == { - "workflow_id": "wf-4" - } + assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4" # The evicted one is gone. - assert "metadata" not in store.inject({"prompt_id": "p1"}) + assert "workflow_id" not in store.inject({"prompt_id": "p1"}) def test_register_after_unregister_does_not_count_against_capacity(self): """Normal lifecycle: register, unregister, register many — the @@ -330,9 +333,7 @@ class TestPromptMetadataStore: store = PromptMetadataStore() store.register("p1", {"metadata": {"workflow_id": "wf-1"}}) store.register("p1", {"metadata": {"workflow_id": "wf-2"}}) - assert store.inject({"prompt_id": "p1"})["metadata"] == { - "workflow_id": "wf-2" - } + assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2" def test_inject_with_no_registrations_is_passthrough(self): store = PromptMetadataStore() @@ -345,5 +346,5 @@ class TestPromptMetadataStore: result = store.inject((b"image-bytes", {"prompt_id": "p1"})) assert result == (b"image-bytes", { "prompt_id": "p1", - "metadata": {"workflow_id": "wf-1"}, + "workflow_id": "wf-1", }) From 63784baed56a2d17bbc6e05042b9d7cdf258745c Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 14 May 2026 21:26:06 -0700 Subject: [PATCH 4/4] fix(server): serialize PromptMetadataStore access with a lock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses comfyanonymous's review nit on PR #13905. 
The store is touched from three threads — the aiohttp event loop (``register`` via ``post_prompt``), the worker thread (``unregister`` via the ``prompt_worker`` try/finally and ``execution_error`` paths), and any thread that fires ``send_sync`` (``inject``). Individual ``dict`` operations are GIL-atomic but ``register``'s ``len -> pop -> setitem`` and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose interleaving without a lock is racy. A single ``threading.Lock`` keeps the FIFO cap honest and snapshots the envelope under the lock before the spread runs. Adds a stress-test that runs concurrent register/unregister/inject for 100 ms across five threads and asserts no exception escapes and the capacity bound is held. --- app/prompt_metadata.py | 36 +++++++++--- tests-unit/app_test/test_prompt_metadata.py | 62 +++++++++++++++++++++ 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py index 821a516f2..7d6c29db1 100644 --- a/app/prompt_metadata.py +++ b/app/prompt_metadata.py @@ -15,6 +15,7 @@ instance and the helpers can be unit-tested without the rest of the app. from __future__ import annotations import logging +import threading from typing import Any, Callable, Optional @@ -177,28 +178,49 @@ class PromptMetadataStore: prompt finishes, wiped on queue cancel/delete. The FIFO cap is a backstop: if any cleanup hook is ever skipped, the store sheds the oldest entry instead of growing without bound. + + Access is serialized through a ``threading.Lock``. ``register`` runs + on the aiohttp event-loop thread, ``unregister`` runs on the + ``prompt_worker`` thread, and ``inject`` runs on whichever thread + fires ``send_sync`` (event loop, worker, asset seeder). Individual + ``dict`` ops are GIL-atomic, but ``register``'s + ``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}`` + are multi-step compounds whose interleaving without a lock is + racy. 
The lock is uncontended in steady state (sub-microsecond + critical sections) so the cost is negligible. """ def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY): self._envelopes: dict[str, dict] = {} self._capacity = capacity + self._lock = threading.Lock() def register(self, prompt_id: str, extra_data: Any) -> None: envelope = extract_envelope_from_extra_data(extra_data) if envelope is None: return - if len(self._envelopes) >= self._capacity: - self._envelopes.pop(next(iter(self._envelopes))) - self._envelopes[prompt_id] = envelope + with self._lock: + if len(self._envelopes) >= self._capacity: + self._envelopes.pop(next(iter(self._envelopes))) + self._envelopes[prompt_id] = envelope def unregister(self, prompt_id: str) -> None: - self._envelopes.pop(prompt_id, None) + with self._lock: + self._envelopes.pop(prompt_id, None) def inject(self, data: Any) -> Any: - return inject_envelope(data, self._envelopes.get) + # Snapshot the envelope under the lock so the spread in + # ``inject_envelope`` runs against a consistent view even if a + # concurrent ``register``/``unregister`` is mutating the map. + def locked_lookup(prompt_id: str) -> Optional[dict]: + with self._lock: + return self._envelopes.get(prompt_id) + return inject_envelope(data, locked_lookup) def __len__(self) -> int: - return len(self._envelopes) + with self._lock: + return len(self._envelopes) def __contains__(self, prompt_id: str) -> bool: - return prompt_id in self._envelopes + with self._lock: + return prompt_id in self._envelopes diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py index 8d88d788e..b5241dd5d 100644 --- a/tests-unit/app_test/test_prompt_metadata.py +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -348,3 +348,65 @@ class TestPromptMetadataStore: "prompt_id": "p1", "workflow_id": "wf-1", }) + + def test_concurrent_access_does_not_corrupt_or_raise(self): + """Smoke test for the store's lock. 
``register`` is called from + the aiohttp event-loop thread, ``unregister`` from the worker + thread, and ``inject`` fires on every ``send_sync`` from + whichever thread emits the event. Run all three concurrently + and assert no exception escapes and the store stays internally + consistent (the FIFO cap is never exceeded).""" + import threading + + store = PromptMetadataStore(capacity=64) + stop = threading.Event() + errors: list[BaseException] = [] + + def registrar(): + i = 0 + try: + while not stop.is_set(): + store.register( + f"p{i % 100}", + {"metadata": {"workflow_id": f"wf-{i}"}}, + ) + i += 1 + except BaseException as e: + errors.append(e) + + def canceller(): + i = 0 + try: + while not stop.is_set(): + store.unregister(f"p{i % 100}") + i += 1 + except BaseException as e: + errors.append(e) + + def injector(): + i = 0 + try: + while not stop.is_set(): + store.inject({"prompt_id": f"p{i % 100}", "node": "5"}) + i += 1 + except BaseException as e: + errors.append(e) + + threads = [ + threading.Thread(target=registrar), + threading.Thread(target=registrar), + threading.Thread(target=canceller), + threading.Thread(target=injector), + threading.Thread(target=injector), + ] + for t in threads: + t.start() + # Brief burst — long enough to interleave many ops, short enough + # not to slow CI. + threading.Event().wait(0.1) + stop.set() + for t in threads: + t.join(timeout=2.0) + + assert errors == [], f"concurrent access raised: {errors[:3]}" + assert len(store) <= 64, "FIFO cap was breached under contention"