From dfc901078e8dc8ca4106f0b0058264a029c2313d Mon Sep 17 00:00:00 2001 From: dante01yoon Date: Wed, 20 May 2026 16:32:20 +0900 Subject: [PATCH] feat(server): propagate opaque per-prompt metadata on WebSocket frames (FE-745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit server.py builds an opaque dict at submission time from extra_data and pins it on PromptServer.active_prompt_metadata while main.py's worker drives the prompt. send_sync spreads the dict's key/value pairs onto outgoing payloads so frames carry whatever tags the submission attached (today: workflow_id). The mechanism is intentionally untyped — the transport layer doesn't know what workflow_id means or treat any key specially. Adding a new propagated field requires only a one-line addition in post_prompt; execution.py and comfy_execution/progress.py are not touched. execution.py changes: 0 lines. --- comfy_execution/jobs.py | 19 +- main.py | 30 ++-- server.py | 18 +- .../server_test/test_prompt_metadata.py | 162 ++++++++++++++++++ 4 files changed, 214 insertions(+), 15 deletions(-) create mode 100644 tests-unit/server_test/test_prompt_metadata.py diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..c3d21b8e9 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -93,6 +93,22 @@ def _create_text_preview(value: str) -> dict: } +def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]: + """Return extra_data["extra_pnginfo"]["workflow"]["id"] when it is a non-empty string.""" + if not isinstance(extra_data, dict): + return None + extra_pnginfo = extra_data.get('extra_pnginfo') + if not isinstance(extra_pnginfo, dict): + return None + workflow = extra_pnginfo.get('workflow') + if not isinstance(workflow, dict): + return None + workflow_id = workflow.get('id') + if isinstance(workflow_id, str) and workflow_id: + return workflow_id + return None + + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -100,8 +116,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str tuple: (create_time, workflow_id) """ create_time = extra_data.get('create_time') - extra_pnginfo = extra_data.get('extra_pnginfo', {}) - workflow_id = extra_pnginfo.get('workflow', {}).get('id') + workflow_id = extract_workflow_id(extra_data) return create_time, workflow_id diff --git a/main.py b/main.py index a6fdaf43c..f029bf3bd 100644 --- a/main.py +++ b/main.py @@ -317,20 +317,28 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None + server_instance.active_prompt_metadata = metadata + asset_seeder.pause() - e.execute(item[2], prompt_id, extra_data, item[4]) + try: + e.execute(item[2], prompt_id, extra_data, item[4]) - need_gc = True + need_gc = True - remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done(item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str='success' if e.success else 'error', - completed=e.success, - messages=e.status_messages), process_item=remove_sensitive) - if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + # Drop sensitive (index 5) and metadata (index 6); history keeps a 5-tuple. + remove_sensitive = lambda prompt: prompt[:5] + q.task_done(item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages), process_item=remove_sensitive) + if server_instance.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + finally: + # Clear after the terminal send so that frame still carries metadata. + server_instance.active_prompt_metadata = None current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index 44470b904..bfe97f761 100644 --- a/server.py +++ b/server.py @@ -252,6 +252,9 @@ class PromptServer(): self.last_node_id = None self.client_id = None + # Opaque tag dict pinned by main.py around each prompt; send_sync spreads it. + self.active_prompt_metadata: Optional[dict] = None + self.on_prompt_handlers = [] @routes.get('/ws') @@ -275,7 +278,13 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + payload = {"node": self.last_node_id} + last_prompt_id = getattr(self, "last_prompt_id", None) + if last_prompt_id: + payload["prompt_id"] = last_prompt_id + if self.active_prompt_metadata: + payload = {**self.active_prompt_metadata, **payload} + await self.send("executing", payload, sid) # Flag to track if we've received the first message first_message = True @@ -955,7 +964,9 @@ class PromptServer(): if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) + raw_metadata = extra_data.pop("metadata", None) + client_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: @@ -1217,6 +1228,9 @@ class PromptServer(): await send_socket_catch_exception(self.sockets[sid].send_json, message) def send_sync(self, event, data, sid=None): + meta = self.active_prompt_metadata + if meta and isinstance(data, dict): + data = {**meta, **data} self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) diff --git a/tests-unit/server_test/test_prompt_metadata.py b/tests-unit/server_test/test_prompt_metadata.py new file mode 100644 index 000000000..5f3ff997f --- /dev/null +++ b/tests-unit/server_test/test_prompt_metadata.py @@ -0,0 +1,162 @@ +"""Tests for the opaque per-prompt metadata mechanism on PromptServer.""" + +from unittest.mock import MagicMock + +import pytest + +from comfy_execution.jobs import extract_workflow_id + + +class TestExtractWorkflowId: + + def test_returns_id_when_present(self): + assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) == "wf-1" + + def test_returns_none_when_missing(self): + assert extract_workflow_id({}) is None + assert extract_workflow_id({"extra_pnginfo": {}}) is None + assert extract_workflow_id({"extra_pnginfo": {"workflow": {}}}) is None + + def test_returns_none_for_empty_or_wrong_type(self): + assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": ""}}}) is None + assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": 42}}}) is None + assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": None}}}) is None + + def test_returns_none_for_non_dict_input(self): + assert extract_workflow_id(None) is None + assert extract_workflow_id("not a dict") is None + assert extract_workflow_id({"extra_pnginfo": "not a dict"}) is None + assert extract_workflow_id({"extra_pnginfo": {"workflow": "not a dict"}}) is None + + +class _FakeServer: + """Minimal PromptServer stand-in mirroring send_sync verbatim.""" + + def __init__(self): + self.active_prompt_metadata = None + self.captured = [] + self.loop = MagicMock() + self.loop.call_soon_threadsafe.side_effect = ( + lambda fn, msg: self.captured.append(msg) + ) + self.messages = MagicMock() + self.messages.put_nowait = MagicMock() + + def send_sync(self, event, data, sid=None): + meta = self.active_prompt_metadata + if meta and isinstance(data, dict): + data = {**meta, **data} + self.loop.call_soon_threadsafe( + self.messages.put_nowait, (event, data, sid) + ) + + +@pytest.fixture +def server(): + return _FakeServer() + + +class TestSendSyncMerge: + def test_spreads_active_metadata_onto_dict_payload(self, server): + server.active_prompt_metadata = {"workflow_id": "wf-1"} + + server.send_sync( + "executing", {"node": "n1", "prompt_id": "p1"}, "client-1" + ) + + event, data, sid = server.captured[0] + assert event == "executing" + assert data == { + "workflow_id": "wf-1", + "node": "n1", + "prompt_id": "p1", + } + assert sid == "client-1" + + def test_passthrough_when_no_active_metadata(self, server): + server.active_prompt_metadata = None + + server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}) + + _, data, _ = server.captured[0] + assert data == {"node": "n1", "prompt_id": "p1"} + + def test_passthrough_when_metadata_is_empty_dict(self, server): + server.active_prompt_metadata = {} + + server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}) + + _, data, _ = server.captured[0] + assert data == {"node": "n1", "prompt_id": "p1"} + + def test_event_payload_wins_on_key_conflict(self, server): + server.active_prompt_metadata = {"workflow_id": "wf-1", "prompt_id": "from-meta"} + + server.send_sync("executing", {"node": "n1", "prompt_id": "from-frame"}, "c1") + + _, data, _ = server.captured[0] + assert data["prompt_id"] == "from-frame" + assert data["workflow_id"] == "wf-1" + + def test_non_dict_payload_passes_through_untouched(self, server): + # BinaryEventTypes.TEXT byte frames must not be merged. + server.active_prompt_metadata = {"workflow_id": "wf-1"} + + server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1") + + _, data, _ = server.captured[0] + assert data == b"\x00\x00\x00\x03foobar" + + def test_terminal_executing_frame_includes_metadata(self, server): + # Slot is cleared after this send in main.py so the reset still carries metadata (#13684 race). + server.active_prompt_metadata = {"workflow_id": "wf-1"} + + server.send_sync( + "executing", {"node": None, "prompt_id": "p1"}, "client-1" + ) + + _, data, _ = server.captured[0] + assert data == { + "workflow_id": "wf-1", + "node": None, + "prompt_id": "p1", + } + + def test_opaque_dict_supports_arbitrary_keys(self, server): + server.active_prompt_metadata = { + "workflow_id": "wf-1", + "trace_id": "trace-123", + "tenant": "acme", + } + + server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}) + + _, data, _ = server.captured[0] + assert data["workflow_id"] == "wf-1" + assert data["trace_id"] == "trace-123" + assert data["tenant"] == "acme" + + +class TestWorkerSerializationIsolatesMetadata: + def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server): + # Prompt A + server.active_prompt_metadata = {"workflow_id": "wf-AAA"} + server.send_sync("execution_start", {"prompt_id": "P-shared"}) + server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"}) + server.send_sync("executing", {"node": None, "prompt_id": "P-shared"}) + server.active_prompt_metadata = None + + # Prompt B — same prompt_id, different workflow + server.active_prompt_metadata = {"workflow_id": "wf-BBB"} + server.send_sync("execution_start", {"prompt_id": "P-shared"}) + server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"}) + server.send_sync("executing", {"node": None, "prompt_id": "P-shared"}) + server.active_prompt_metadata = None + + frames = [d for (_, d, _) in server.captured] + a_frames = frames[:3] + b_frames = frames[3:] + + assert all(f["workflow_id"] == "wf-AAA" for f in a_frames) + assert all(f["workflow_id"] == "wf-BBB" for f in b_frames) + assert all(f["prompt_id"] == "P-shared" for f in frames)