From db9c8cc2fd0b99b7bea16bca79f5d1162e916876 Mon Sep 17 00:00:00 2001 From: dante01yoon Date: Wed, 20 May 2026 19:01:38 +0900 Subject: [PATCH] fix(server): scope prompt metadata to active prompt_id and validate at submission Address adversarial-review findings on FE-745 metadata propagation: - send_sync previously spread active_prompt_metadata onto every dict payload, contaminating unrelated status/queue broadcasts with the running prompt's workflow_id. Change the slot to (prompt_id, metadata) and only inject when payload.prompt_id matches the active prompt_id. Same condition applied to the WS reconnect catch-up frame. - post_prompt now validates extra_data.metadata at the submission boundary: flat dict[str,str], max 16 keys, 64-char keys, 256-char values, and reserved server-side keys (prompt_id, node, output, etc.) are rejected with 400. Removes the broadcast-amplification vector where a client could submit arbitrarily large metadata and force it onto every WS frame. - Extract validate_client_metadata + caps into app/prompt_metadata.py so tests can import without pulling server.py's import-time side effects. - Expand tests-unit/server_test/test_prompt_metadata.py from 12 to 47: add TestStatusBroadcastsAreNotContaminated for prompt_id-scoping and TestValidateClientMetadata for the new submission-boundary checks (including parametrized reserved-key rejection). --- app/prompt_metadata.py | 44 +++++ main.py | 2 +- server.py | 29 ++- .../server_test/test_prompt_metadata.py | 166 +++++++++++++++--- 4 files changed, 212 insertions(+), 29 deletions(-) create mode 100644 app/prompt_metadata.py diff --git a/app/prompt_metadata.py b/app/prompt_metadata.py new file mode 100644 index 000000000..5d0ef36b2 --- /dev/null +++ b/app/prompt_metadata.py @@ -0,0 +1,44 @@ +"""Validation for client-supplied per-prompt metadata (extra_data.metadata).""" + +from typing import Optional + + +MAX_METADATA_KEYS = 16 +MAX_METADATA_KEY_LEN = 64 +MAX_METADATA_VALUE_LEN = 256 + +# Server-emitted top-level fields on prompt-scoped WebSocket events. +# Client-supplied metadata may not shadow these — payload-wins-on-conflict +# only protects keys present in each individual frame, so reserve them +# at the submission boundary as defense in depth. +RESERVED_METADATA_KEYS = frozenset({ + "prompt_id", "node", "display_node", "output", "nodes", "node_id", + "node_type", "executed", "exception_message", "exception_type", + "traceback", "current_inputs", "current_outputs", "timestamp", + "sid", "status", "prompt", "value", "max", +}) + + +def validate_client_metadata(raw) -> tuple[Optional[dict], Optional[str]]: + """Return ``(cleaned_metadata, error_message)``. + + A missing field (``None``) is treated as empty metadata. Anything else + must be a flat ``dict[str, str]`` within the size caps and free of + reserved keys. + """ + if raw is None: + return {}, None + if not isinstance(raw, dict): + return None, "extra_data.metadata must be an object" + if len(raw) > MAX_METADATA_KEYS: + return None, f"extra_data.metadata exceeds {MAX_METADATA_KEYS} keys" + cleaned: dict = {} + for key, value in raw.items(): + if not isinstance(key, str) or not key or len(key) > MAX_METADATA_KEY_LEN: + return None, f"metadata key must be a non-empty string up to {MAX_METADATA_KEY_LEN} chars" + if key in RESERVED_METADATA_KEYS: + return None, f"metadata key '{key}' is reserved" + if not isinstance(value, str) or len(value) > MAX_METADATA_VALUE_LEN: + return None, f"metadata value for '{key}' must be a string up to {MAX_METADATA_VALUE_LEN} chars" + cleaned[key] = value + return cleaned, None diff --git a/main.py b/main.py index f029bf3bd..161142d8d 100644 --- a/main.py +++ b/main.py @@ -318,7 +318,7 @@ def prompt_worker(q, server_instance): extra_data[k] = sensitive[k] metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None - server_instance.active_prompt_metadata = metadata + server_instance.active_prompt_metadata = (prompt_id, metadata) if metadata else None asset_seeder.pause() try: diff --git a/server.py b/server.py index bfe97f761..9e31dd4b6 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager from app.node_replace_manager import NodeReplaceManager +from app.prompt_metadata import validate_client_metadata from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -252,8 +253,10 @@ class PromptServer(): self.last_node_id = None self.client_id = None - # Opaque tag dict pinned by main.py around each prompt; send_sync spreads it. - self.active_prompt_metadata: Optional[dict] = None + # (prompt_id, opaque tag dict) pinned by main.py around each prompt. + # send_sync only spreads the dict onto payloads whose prompt_id matches, + # so concurrent queue/status broadcasts are not contaminated. + self.active_prompt_metadata: Optional[tuple[str, dict]] = None self.on_prompt_handlers = [] @@ -282,8 +285,11 @@ class PromptServer(): last_prompt_id = getattr(self, "last_prompt_id", None) if last_prompt_id: payload["prompt_id"] = last_prompt_id - if self.active_prompt_metadata: - payload = {**self.active_prompt_metadata, **payload} + slot = self.active_prompt_metadata + if slot is not None: + active_prompt_id, meta = slot + if meta and payload.get("prompt_id") == active_prompt_id: + payload = {**meta, **payload} await self.send("executing", payload, sid) # Flag to track if we've received the first message @@ -965,7 +971,12 @@ class PromptServer(): sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds raw_metadata = extra_data.pop("metadata", None) - client_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} + client_metadata, meta_error = validate_client_metadata(raw_metadata) + if meta_error is not None: + return web.json_response( + {"error": {"type": "invalid_metadata", "message": meta_error}}, + status=400, + ) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1228,9 +1239,11 @@ class PromptServer(): await send_socket_catch_exception(self.sockets[sid].send_json, message) def send_sync(self, event, data, sid=None): - meta = self.active_prompt_metadata - if meta and isinstance(data, dict): - data = {**meta, **data} + slot = self.active_prompt_metadata + if slot is not None and isinstance(data, dict): + active_prompt_id, meta = slot + if meta and data.get("prompt_id") == active_prompt_id: + data = {**meta, **data} self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) diff --git a/tests-unit/server_test/test_prompt_metadata.py b/tests-unit/server_test/test_prompt_metadata.py index 5f3ff997f..04b9f5499 100644 --- a/tests-unit/server_test/test_prompt_metadata.py +++ b/tests-unit/server_test/test_prompt_metadata.py @@ -4,6 +4,13 @@ from unittest.mock import MagicMock import pytest +from app.prompt_metadata import ( + MAX_METADATA_KEY_LEN, + MAX_METADATA_KEYS, + MAX_METADATA_VALUE_LEN, + RESERVED_METADATA_KEYS, + validate_client_metadata, +) from comfy_execution.jobs import extract_workflow_id @@ -30,7 +37,13 @@ class TestExtractWorkflowId: class _FakeServer: - """Minimal PromptServer stand-in mirroring send_sync verbatim.""" + """Minimal PromptServer stand-in mirroring send_sync verbatim. + + ``active_prompt_metadata`` is ``Optional[tuple[str, dict]]`` — the + ``prompt_id`` it belongs to plus the opaque dict. send_sync only merges + when the outgoing payload's ``prompt_id`` matches the active one, so + unrelated queue/status broadcasts are not contaminated. + """ def __init__(self): self.active_prompt_metadata = None @@ -43,9 +56,11 @@ class _FakeServer: self.messages.put_nowait = MagicMock() def send_sync(self, event, data, sid=None): - meta = self.active_prompt_metadata - if meta and isinstance(data, dict): - data = {**meta, **data} + slot = self.active_prompt_metadata + if slot is not None and isinstance(data, dict): + active_prompt_id, meta = slot + if meta and data.get("prompt_id") == active_prompt_id: + data = {**meta, **data} self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid) ) @@ -58,7 +73,7 @@ def server(): class TestSendSyncMerge: def test_spreads_active_metadata_onto_dict_payload(self, server): - server.active_prompt_metadata = {"workflow_id": "wf-1"} + server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"}) server.send_sync( "executing", {"node": "n1", "prompt_id": "p1"}, "client-1" @@ -82,7 +97,7 @@ class TestSendSyncMerge: assert data == {"node": "n1", "prompt_id": "p1"} def test_passthrough_when_metadata_is_empty_dict(self, server): - server.active_prompt_metadata = {} + server.active_prompt_metadata = ("p1", {}) server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}) @@ -90,17 +105,19 @@ class TestSendSyncMerge: assert data == {"node": "n1", "prompt_id": "p1"} def test_event_payload_wins_on_key_conflict(self, server): - server.active_prompt_metadata = {"workflow_id": "wf-1", "prompt_id": "from-meta"} + server.active_prompt_metadata = ( + "p1", + {"workflow_id": "wf-1", "prompt_id": "from-meta"}, + ) - server.send_sync("executing", {"node": "n1", "prompt_id": "from-frame"}, "c1") + server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}, "c1") _, data, _ = server.captured[0] - assert data["prompt_id"] == "from-frame" + assert data["prompt_id"] == "p1" assert data["workflow_id"] == "wf-1" def test_non_dict_payload_passes_through_untouched(self, server): - # BinaryEventTypes.TEXT byte frames must not be merged. - server.active_prompt_metadata = {"workflow_id": "wf-1"} + server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"}) server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1") @@ -108,8 +125,7 @@ class TestSendSyncMerge: assert data == b"\x00\x00\x00\x03foobar" def test_terminal_executing_frame_includes_metadata(self, server): - # Slot is cleared after this send in main.py so the reset still carries metadata (#13684 race). - server.active_prompt_metadata = {"workflow_id": "wf-1"} + server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"}) server.send_sync( "executing", {"node": None, "prompt_id": "p1"}, "client-1" @@ -123,11 +139,10 @@ class TestSendSyncMerge: } def test_opaque_dict_supports_arbitrary_keys(self, server): - server.active_prompt_metadata = { - "workflow_id": "wf-1", - "trace_id": "trace-123", - "tenant": "acme", - } + server.active_prompt_metadata = ( + "p1", + {"workflow_id": "wf-1", "trace_id": "trace-123", "tenant": "acme"}, + ) server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}) @@ -137,17 +152,54 @@ class TestSendSyncMerge: assert data["tenant"] == "acme" +class TestStatusBroadcastsAreNotContaminated: + """Regression tests for the contamination bug: + + ``send_sync`` previously spread metadata onto any dict payload, so a + status broadcast fired while a prompt was running picked up that + prompt's metadata even though it had nothing to do with that prompt. + """ + + def test_status_payload_without_prompt_id_is_untouched(self, server): + server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"}) + + server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 1}}}) + + _, data, _ = server.captured[0] + assert data == {"status": {"exec_info": {"queue_remaining": 1}}} + assert "workflow_id" not in data + + def test_payload_for_different_prompt_is_untouched(self, server): + # Active prompt is p-running; we send a frame for p-other (e.g. another + # client's queued item). The merge must not leak across prompts. + server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"}) + + server.send_sync("executing", {"node": "n1", "prompt_id": "p-other"}) + + _, data, _ = server.captured[0] + assert data == {"node": "n1", "prompt_id": "p-other"} + assert "workflow_id" not in data + + def test_queue_updated_frame_during_active_prompt_is_clean(self, server): + server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"}) + + server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 0}}}) + + _, data, _ = server.captured[0] + assert "workflow_id" not in data + + class TestWorkerSerializationIsolatesMetadata: def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server): # Prompt A - server.active_prompt_metadata = {"workflow_id": "wf-AAA"} + server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-AAA"}) server.send_sync("execution_start", {"prompt_id": "P-shared"}) server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"}) server.send_sync("executing", {"node": None, "prompt_id": "P-shared"}) server.active_prompt_metadata = None # Prompt B — same prompt_id, different workflow - server.active_prompt_metadata = {"workflow_id": "wf-BBB"} + server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-BBB"}) server.send_sync("execution_start", {"prompt_id": "P-shared"}) server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"}) server.send_sync("executing", {"node": None, "prompt_id": "P-shared"}) @@ -160,3 +212,77 @@ class TestWorkerSerializationIsolatesMetadata: assert all(f["workflow_id"] == "wf-AAA" for f in a_frames) assert all(f["workflow_id"] == "wf-BBB" for f in b_frames) assert all(f["prompt_id"] == "P-shared" for f in frames) + + +class TestValidateClientMetadata: + def test_none_returns_empty_dict(self): + cleaned, error = validate_client_metadata(None) + assert cleaned == {} + assert error is None + + def test_flat_string_dict_is_accepted(self): + cleaned, error = validate_client_metadata( + {"workflow_id": "wf-1", "trace_id": "trace-abc"} + ) + assert cleaned == {"workflow_id": "wf-1", "trace_id": "trace-abc"} + assert error is None + + def test_non_dict_is_rejected(self): + _, error = validate_client_metadata("not a dict") + assert error is not None + assert "object" in error + + def test_list_is_rejected(self): + _, error = validate_client_metadata([("workflow_id", "wf-1")]) + assert error is not None + + def test_nested_dict_value_is_rejected(self): + _, error = validate_client_metadata({"workflow": {"id": "wf-1"}}) + assert error is not None + assert "string" in error + + def test_non_string_value_is_rejected(self): + _, error = validate_client_metadata({"workflow_id": 42}) + assert error is not None + + def test_non_string_key_is_rejected(self): + _, error = validate_client_metadata({123: "wf-1"}) + assert error is not None + + def test_empty_key_is_rejected(self): + _, error = validate_client_metadata({"": "wf-1"}) + assert error is not None + + def test_key_exceeding_limit_is_rejected(self): + _, error = validate_client_metadata({"k" * (MAX_METADATA_KEY_LEN + 1): "v"}) + assert error is not None + assert str(MAX_METADATA_KEY_LEN) in error + + def test_value_exceeding_limit_is_rejected(self): + _, error = validate_client_metadata({"workflow_id": "v" * (MAX_METADATA_VALUE_LEN + 1)}) + assert error is not None + assert str(MAX_METADATA_VALUE_LEN) in error + + def test_too_many_keys_is_rejected(self): + raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS + 1)} + _, error = validate_client_metadata(raw) + assert error is not None + assert str(MAX_METADATA_KEYS) in error + + def test_max_size_dict_is_accepted(self): + raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS)} + cleaned, error = validate_client_metadata(raw) + assert error is None + assert len(cleaned) == MAX_METADATA_KEYS + + def test_max_length_strings_are_accepted(self): + raw = {"k" * MAX_METADATA_KEY_LEN: "v" * MAX_METADATA_VALUE_LEN} + cleaned, error = validate_client_metadata(raw) + assert error is None + assert cleaned == raw + + @pytest.mark.parametrize("reserved_key", sorted(RESERVED_METADATA_KEYS)) + def test_reserved_keys_are_rejected(self, reserved_key): + _, error = validate_client_metadata({reserved_key: "anything"}) + assert error is not None + assert reserved_key in error