From 74cfcaa3181cbe82d5f160fe6f8d62e354810657 Mon Sep 17 00:00:00 2001 From: Deep Mehta Date: Thu, 14 May 2026 20:47:00 -0700 Subject: [PATCH] feat(server): per-prompt metadata envelope on websocket events Replaces the workflow_id-on-every-event approach (#13684, reverted in #13901) with a generic metadata envelope captured at submission and injected at the server-side send chokepoint. - POST /prompt accepts an opaque ``extra_data.metadata`` dict (falls back to synthesizing ``{"workflow_id": workflow.id}`` from ``extra_pnginfo.workflow.id`` so existing frontends keep working). - ``PromptServer`` owns a ``prompt_id -> metadata`` map populated at submission, drained when the prompt finishes. ``send_sync`` injects the envelope into any outbound payload that carries a ``prompt_id``, including the ``(preview_image, metadata_dict)`` tuple used by ``PREVIEW_IMAGE_WITH_METADATA``. WS reconnect path carries it too. - Pure helpers live in ``app/prompt_metadata.py`` so the execution layer never depends on workflow concepts and the helpers can be unit-tested without torch. Execution layer (``execution.py``, ``comfy_execution/*``) and the jobs API are unchanged. Backward compatible: existing fields and shapes are preserved; only an additional ``metadata`` field is attached when present. --- app/prompt_metadata.py | 96 +++++++++++ main.py | 2 + server.py | 32 +++- tests-unit/app_test/test_prompt_metadata.py | 181 ++++++++++++++++++++ 4 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 app/prompt_metadata.py create mode 100644 tests-unit/app_test/test_prompt_metadata.py diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py new file mode 100644 index 000000000..43e387769 --- /dev/null +++ b/app/prompt_metadata.py @@ -0,0 +1,96 @@ +"""Per-prompt metadata envelope shared between submission and outbound events. + +The metadata envelope is an opaque dict (e.g. 
``{"workflow_id": ...}``) +attached to a prompt at submission and injected by the server into every +outbound execution event that carries a ``prompt_id``. It lets consumers +scope state by tags they care about (workflow, trace, tenant) without the +execution layer ever needing to know those tags exist. + +Two pure functions live here; ``PromptServer`` owns the per-prompt map and +wires them into the submission and send paths. +""" + +from __future__ import annotations + +from typing import Any, Callable, Optional + + +def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]: + """Pull the per-prompt metadata envelope out of a submitted prompt's + ``extra_data``. + + Two sources, in order: + + 1. Explicit non-empty ``extra_data["metadata"]`` dict — preferred path, accepted + as-is (shallow-copied so later mutations on the caller's dict don't leak). + 2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward- + compatibility fallback. Frontends that already stamp the workflow + id into ``extra_pnginfo`` keep working without changes; the + synthesized envelope is ``{"workflow_id": workflow["id"]}`` (non-empty string ids only). + + Returns ``None`` when neither source yields a usable envelope. + """ + if not isinstance(extra_data, dict): + return None + + metadata = extra_data.get("metadata") + if isinstance(metadata, dict) and metadata: + return dict(metadata) + + extra_pnginfo = extra_data.get("extra_pnginfo") + if isinstance(extra_pnginfo, dict): + workflow = extra_pnginfo.get("workflow") + if isinstance(workflow, dict): + workflow_id = workflow.get("id") + if isinstance(workflow_id, str) and workflow_id: + return {"workflow_id": workflow_id} + + return None + + +def inject_envelope( + data: Any, + envelope_lookup: Callable[[str], Optional[dict]], +) -> Any: + """Return ``data`` with a per-prompt ``metadata`` envelope attached. + + ``envelope_lookup`` is called with the payload's ``prompt_id`` and is + expected to return the registered envelope or ``None``. 
This indirection + keeps the function pure and avoids depending on any specific storage. + + Two payload shapes are handled: + + - **dict** carrying ``prompt_id``. A shallow copy is returned with a + ``metadata`` key set to the envelope. The caller's dict is not + mutated. + - **(preview_image, metadata_dict) tuple** — the format used by + ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; + the binary preview is passed through by reference. + + The function is a no-op for: + + - payloads without a ``prompt_id``, + - payloads already declaring their own ``metadata`` field + (callers can opt out by setting it explicitly), + - prompts with no registered envelope, + - any other payload shape (raw bytes, ``None``, etc.). + """ + def inject(d: dict) -> dict: + if not isinstance(d, dict) or "metadata" in d: + return d + prompt_id = d.get("prompt_id") + if not prompt_id: + return d + envelope = envelope_lookup(prompt_id) + if envelope is None: + return d + return {**d, "metadata": envelope} + + if isinstance(data, dict): + return inject(data) + if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict): + injected = inject(data[1]) + if injected is data[1]: + return data + return (data[0], injected) + return data diff --git a/main.py b/main.py index a6fdaf43c..f0293ae0a 100644 --- a/main.py +++ b/main.py @@ -332,6 +332,8 @@ def prompt_worker(q, server_instance): if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + server_instance.unregister_prompt_metadata(prompt_id) + current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index 44470b904..b4876c8fa 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,10 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager from 
app.node_replace_manager import NodeReplaceManager +from app.prompt_metadata import ( + extract_envelope_from_extra_data, + inject_envelope, +) from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -250,8 +254,14 @@ class PromptServer(): routes = web.RouteTableDef() self.routes = routes self.last_node_id = None + self.last_prompt_id = None self.client_id = None + # prompt_id -> metadata envelope captured at submission and injected + # into outbound execution events. Keeps the workflow scope (and any + # other client-supplied tags) out of the execution layer. + self._prompt_metadata: dict[str, dict] = {} + self.on_prompt_handlers = [] @routes.get('/ws') @@ -275,7 +285,10 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + await self.send("executing", self._inject_prompt_metadata({ + "node": self.last_node_id, + "prompt_id": self.last_prompt_id, + }), sid) # Flag to track if we've received the first message first_message = True @@ -955,6 +968,7 @@ class PromptServer(): if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds + self.register_prompt_metadata(prompt_id, extra_data) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1216,7 +1230,23 @@ class PromptServer(): elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) + def register_prompt_metadata(self, prompt_id: str, extra_data) -> None: + """Record 
per-prompt metadata for injection into outbound execution + events. Called at submission, before the prompt is queued.""" + envelope = extract_envelope_from_extra_data(extra_data) + if envelope is not None: + self._prompt_metadata[prompt_id] = envelope + + def unregister_prompt_metadata(self, prompt_id: str) -> None: + """Drop the per-prompt metadata envelope. Called after the prompt + has finished executing and its terminal events have been queued.""" + self._prompt_metadata.pop(prompt_id, None) + + def _inject_prompt_metadata(self, data): + return inject_envelope(data, self._prompt_metadata.get) + def send_sync(self, event, data, sid=None): + data = self._inject_prompt_metadata(data) self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) diff --git a/tests-unit/app_test/test_prompt_metadata.py b/tests-unit/app_test/test_prompt_metadata.py new file mode 100644 index 000000000..184ce22e5 --- /dev/null +++ b/tests-unit/app_test/test_prompt_metadata.py @@ -0,0 +1,181 @@ +"""Unit tests for the pure metadata-envelope helpers in +``app.prompt_metadata``. These cover the two functions that PromptServer +wires into submission (``extract_envelope_from_extra_data``) and into the +send chokepoint (``inject_envelope``). 
+""" + +from __future__ import annotations + +from app.prompt_metadata import ( + extract_envelope_from_extra_data, + inject_envelope, +) + + +class TestExtractEnvelopeFromExtraData: + def test_explicit_metadata_dict_is_used_as_is(self): + extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-1", + "trace_id": "t-9", + } + + def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self): + extra_data = { + "metadata": {"workflow_id": "explicit"}, + "extra_pnginfo": {"workflow": {"id": "fallback"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "explicit" + } + + def test_falls_back_to_extra_pnginfo_workflow_id(self): + extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}} + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + def test_returns_none_when_no_metadata_and_no_workflow_id(self): + assert extract_envelope_from_extra_data({}) is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}}) + is None + ) + + def test_rejects_non_string_or_empty_workflow_id(self): + for bad in ["", 123, None, [], {}]: + extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} + assert extract_envelope_from_extra_data(extra_data) is None + + def test_rejects_non_dict_inputs_at_each_level(self): + assert extract_envelope_from_extra_data(None) is None + assert extract_envelope_from_extra_data("not-a-dict") is None + assert ( + extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"}) + is None + ) + assert ( + extract_envelope_from_extra_data( + {"extra_pnginfo": {"workflow": "not-a-dict"}} + ) + is None + ) + + def test_empty_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": {}, + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" 
+ } + + def test_returned_envelope_is_copy_not_reference(self): + original = {"workflow_id": "wf-1"} + result = extract_envelope_from_extra_data({"metadata": original}) + result["new_key"] = "x" + assert "new_key" not in original + + def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self): + extra_data = { + "metadata": "not-a-dict", + "extra_pnginfo": {"workflow": {"id": "wf-legacy"}}, + } + assert extract_envelope_from_extra_data(extra_data) == { + "workflow_id": "wf-legacy" + } + + +class TestInjectEnvelope: + @staticmethod + def _lookup(table): + """Build an envelope_lookup callable backed by a dict.""" + return table.get + + def test_injects_envelope_on_dict_with_known_prompt_id(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == { + "node": "5", + "prompt_id": "p1", + "metadata": {"workflow_id": "wf-1"}, + } + + def test_passthrough_when_prompt_id_not_registered(self): + lookup = self._lookup({}) + data = {"node": "5", "prompt_id": "unknown"} + assert inject_envelope(data, lookup) == data + + def test_passthrough_when_payload_lacks_prompt_id(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + data = {"status": "ok"} + assert inject_envelope(data, lookup) is data + + def test_passthrough_when_payload_already_has_metadata(self): + """If a caller has already set a ``metadata`` field (e.g. 
for + opt-out or pre-augmented payloads), the function must not + overwrite it.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) + data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}} + result = inject_envelope(data, lookup) + assert result is data + assert result["metadata"] == {"workflow_id": "wf-caller"} + + def test_does_not_mutate_input_dict(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + original = {"node": "5", "prompt_id": "p1"} + inject_envelope(original, lookup) + assert "metadata" not in original + + def test_injects_into_inner_dict_of_preview_metadata_tuple(self): + """``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as + ``(preview_image, metadata_dict)``; the inner dict is the only + place the envelope can attach.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "p1"} + result = inject_envelope((preview_image, inner), lookup) + assert isinstance(result, tuple) + assert result[0] is preview_image + assert result[1] == { + "node_id": "5", + "prompt_id": "p1", + "metadata": {"workflow_id": "wf-1"}, + } + assert "metadata" not in inner + + def test_preview_tuple_passthrough_when_no_envelope_registered(self): + lookup = self._lookup({}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "unknown"} + result = inject_envelope((preview_image, inner), lookup) + assert result == (preview_image, inner) + + def test_preview_tuple_passthrough_when_inner_already_has_metadata(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) + preview_image = ("PNG", object(), 256) + inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}} + result = inject_envelope((preview_image, inner), lookup) + assert result == (preview_image, inner) + + def test_non_dict_non_tuple_payloads_passthrough(self): + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + assert 
inject_envelope(b"raw-bytes", lookup) == b"raw-bytes" + assert inject_envelope(None, lookup) is None + assert inject_envelope(42, lookup) == 42 + + def test_tuple_of_wrong_arity_passthrough(self): + """Only the 2-tuple ``(preview, metadata_dict)`` shape is special- + cased. Other tuples must not be touched.""" + lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) + triple = (1, {"prompt_id": "p1"}, 3) + assert inject_envelope(triple, lookup) is triple + + def test_envelope_lookup_called_at_send_time(self): + """The lookup runs each time the function is called, so a producer + and consumer that share a backing dict observe the current value.""" + store = {"p1": {"workflow_id": "wf-1"}} + first = inject_envelope({"prompt_id": "p1"}, store.get) + store["p1"] = {"workflow_id": "wf-2"} + second = inject_envelope({"prompt_id": "p1"}, store.get) + assert first["metadata"] == {"workflow_id": "wf-1"} + assert second["metadata"] == {"workflow_id": "wf-2"}