fix(server): scope prompt metadata to active prompt_id and validate at submission
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

Address adversarial-review findings on FE-745 metadata propagation:

- send_sync previously spread active_prompt_metadata onto every dict
  payload, contaminating unrelated status/queue broadcasts with the
  running prompt's workflow_id. Change the slot to (prompt_id, metadata)
  and only inject when payload.prompt_id matches the active prompt_id.
  Same condition applied to the WS reconnect catch-up frame.

- post_prompt now validates extra_data.metadata at the submission
  boundary: flat dict[str,str], max 16 keys, 64-char keys, 256-char
  values, and reserved server-side keys (prompt_id, node, output, etc.)
  are rejected with 400. Removes the broadcast-amplification vector
  where a client could submit arbitrarily large metadata and force it
  onto every WS frame.

- Extract validate_client_metadata + caps into app/prompt_metadata.py so
  tests can import without pulling server.py's import-time side effects.

- Expand tests-unit/server_test/test_prompt_metadata.py from 12 to 47:
  add TestStatusBroadcastsAreNotContaminated for prompt_id-scoping and
  TestValidateClientMetadata for the new submission-boundary checks
  (including parametrized reserved-key rejection).
This commit is contained in:
dante01yoon 2026-05-20 19:01:38 +09:00
parent dfc901078e
commit db9c8cc2fd
4 changed files with 212 additions and 29 deletions

44
app/prompt_metadata.py Normal file
View File

@ -0,0 +1,44 @@
"""Validation for client-supplied per-prompt metadata (extra_data.metadata)."""
from typing import Optional
MAX_METADATA_KEYS = 16
MAX_METADATA_KEY_LEN = 64
MAX_METADATA_VALUE_LEN = 256
# Server-emitted top-level fields on prompt-scoped WebSocket events.
# Client-supplied metadata may not shadow these — payload-wins-on-conflict
# only protects keys present in each individual frame, so reserve them
# at the submission boundary as defense in depth.
RESERVED_METADATA_KEYS = frozenset({
"prompt_id", "node", "display_node", "output", "nodes", "node_id",
"node_type", "executed", "exception_message", "exception_type",
"traceback", "current_inputs", "current_outputs", "timestamp",
"sid", "status", "prompt", "value", "max",
})
def validate_client_metadata(raw) -> tuple[Optional[dict], Optional[str]]:
"""Return ``(cleaned_metadata, error_message)``.
A missing field (``None``) is treated as empty metadata. Anything else
must be a flat ``dict[str, str]`` within the size caps and free of
reserved keys.
"""
if raw is None:
return {}, None
if not isinstance(raw, dict):
return None, "extra_data.metadata must be an object"
if len(raw) > MAX_METADATA_KEYS:
return None, f"extra_data.metadata exceeds {MAX_METADATA_KEYS} keys"
cleaned: dict = {}
for key, value in raw.items():
if not isinstance(key, str) or not key or len(key) > MAX_METADATA_KEY_LEN:
return None, f"metadata key must be a non-empty string up to {MAX_METADATA_KEY_LEN} chars"
if key in RESERVED_METADATA_KEYS:
return None, f"metadata key '{key}' is reserved"
if not isinstance(value, str) or len(value) > MAX_METADATA_VALUE_LEN:
return None, f"metadata value for '{key}' must be a string up to {MAX_METADATA_VALUE_LEN} chars"
cleaned[key] = value
return cleaned, None

View File

@ -318,7 +318,7 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None
server_instance.active_prompt_metadata = metadata
server_instance.active_prompt_metadata = (prompt_id, metadata) if metadata else None
asset_seeder.pause()
try:

View File

@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import validate_client_metadata
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -252,8 +253,10 @@ class PromptServer():
self.last_node_id = None
self.client_id = None
# Opaque tag dict pinned by main.py around each prompt; send_sync spreads it.
self.active_prompt_metadata: Optional[dict] = None
# (prompt_id, opaque tag dict) pinned by main.py around each prompt.
# send_sync only spreads the dict onto payloads whose prompt_id matches,
# so concurrent queue/status broadcasts are not contaminated.
self.active_prompt_metadata: Optional[tuple[str, dict]] = None
self.on_prompt_handlers = []
@ -282,8 +285,11 @@ class PromptServer():
last_prompt_id = getattr(self, "last_prompt_id", None)
if last_prompt_id:
payload["prompt_id"] = last_prompt_id
if self.active_prompt_metadata:
payload = {**self.active_prompt_metadata, **payload}
slot = self.active_prompt_metadata
if slot is not None:
active_prompt_id, meta = slot
if meta and payload.get("prompt_id") == active_prompt_id:
payload = {**meta, **payload}
await self.send("executing", payload, sid)
# Flag to track if we've received the first message
@ -965,7 +971,12 @@ class PromptServer():
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
raw_metadata = extra_data.pop("metadata", None)
client_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
client_metadata, meta_error = validate_client_metadata(raw_metadata)
if meta_error is not None:
return web.json_response(
{"error": {"type": "invalid_metadata", "message": meta_error}},
status=400,
)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@ -1228,9 +1239,11 @@ class PromptServer():
await send_socket_catch_exception(self.sockets[sid].send_json, message)
def send_sync(self, event, data, sid=None):
meta = self.active_prompt_metadata
if meta and isinstance(data, dict):
data = {**meta, **data}
slot = self.active_prompt_metadata
if slot is not None and isinstance(data, dict):
active_prompt_id, meta = slot
if meta and data.get("prompt_id") == active_prompt_id:
data = {**meta, **data}
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))

View File

@ -4,6 +4,13 @@ from unittest.mock import MagicMock
import pytest
from app.prompt_metadata import (
MAX_METADATA_KEY_LEN,
MAX_METADATA_KEYS,
MAX_METADATA_VALUE_LEN,
RESERVED_METADATA_KEYS,
validate_client_metadata,
)
from comfy_execution.jobs import extract_workflow_id
@ -30,7 +37,13 @@ class TestExtractWorkflowId:
class _FakeServer:
"""Minimal PromptServer stand-in mirroring send_sync verbatim."""
"""Minimal PromptServer stand-in mirroring send_sync verbatim.
``active_prompt_metadata`` is ``Optional[tuple[str, dict]]`` the
``prompt_id`` it belongs to plus the opaque dict. send_sync only merges
when the outgoing payload's ``prompt_id`` matches the active one, so
unrelated queue/status broadcasts are not contaminated.
"""
def __init__(self):
self.active_prompt_metadata = None
@ -43,9 +56,11 @@ class _FakeServer:
self.messages.put_nowait = MagicMock()
def send_sync(self, event, data, sid=None):
meta = self.active_prompt_metadata
if meta and isinstance(data, dict):
data = {**meta, **data}
slot = self.active_prompt_metadata
if slot is not None and isinstance(data, dict):
active_prompt_id, meta = slot
if meta and data.get("prompt_id") == active_prompt_id:
data = {**meta, **data}
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)
)
@ -58,7 +73,7 @@ def server():
class TestSendSyncMerge:
def test_spreads_active_metadata_onto_dict_payload(self, server):
server.active_prompt_metadata = {"workflow_id": "wf-1"}
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync(
"executing", {"node": "n1", "prompt_id": "p1"}, "client-1"
@ -82,7 +97,7 @@ class TestSendSyncMerge:
assert data == {"node": "n1", "prompt_id": "p1"}
def test_passthrough_when_metadata_is_empty_dict(self, server):
server.active_prompt_metadata = {}
server.active_prompt_metadata = ("p1", {})
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
@ -90,17 +105,19 @@ class TestSendSyncMerge:
assert data == {"node": "n1", "prompt_id": "p1"}
def test_event_payload_wins_on_key_conflict(self, server):
server.active_prompt_metadata = {"workflow_id": "wf-1", "prompt_id": "from-meta"}
server.active_prompt_metadata = (
"p1",
{"workflow_id": "wf-1", "prompt_id": "from-meta"},
)
server.send_sync("executing", {"node": "n1", "prompt_id": "from-frame"}, "c1")
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}, "c1")
_, data, _ = server.captured[0]
assert data["prompt_id"] == "from-frame"
assert data["prompt_id"] == "p1"
assert data["workflow_id"] == "wf-1"
def test_non_dict_payload_passes_through_untouched(self, server):
# BinaryEventTypes.TEXT byte frames must not be merged.
server.active_prompt_metadata = {"workflow_id": "wf-1"}
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1")
@ -108,8 +125,7 @@ class TestSendSyncMerge:
assert data == b"\x00\x00\x00\x03foobar"
def test_terminal_executing_frame_includes_metadata(self, server):
# Slot is cleared after this send in main.py so the reset still carries metadata (#13684 race).
server.active_prompt_metadata = {"workflow_id": "wf-1"}
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync(
"executing", {"node": None, "prompt_id": "p1"}, "client-1"
@ -123,11 +139,10 @@ class TestSendSyncMerge:
}
def test_opaque_dict_supports_arbitrary_keys(self, server):
server.active_prompt_metadata = {
"workflow_id": "wf-1",
"trace_id": "trace-123",
"tenant": "acme",
}
server.active_prompt_metadata = (
"p1",
{"workflow_id": "wf-1", "trace_id": "trace-123", "tenant": "acme"},
)
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
@ -137,17 +152,54 @@ class TestSendSyncMerge:
assert data["tenant"] == "acme"
class TestStatusBroadcastsAreNotContaminated:
"""Regression tests for the contamination bug:
``send_sync`` previously spread metadata onto any dict payload, so a
status broadcast fired while a prompt was running picked up that
prompt's metadata even though it had nothing to do with that prompt.
"""
def test_status_payload_without_prompt_id_is_untouched(self, server):
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 1}}})
_, data, _ = server.captured[0]
assert data == {"status": {"exec_info": {"queue_remaining": 1}}}
assert "workflow_id" not in data
def test_payload_for_different_prompt_is_untouched(self, server):
# Active prompt is p-running; we send a frame for p-other (e.g. another
# client's queued item). The merge must not leak across prompts.
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("executing", {"node": "n1", "prompt_id": "p-other"})
_, data, _ = server.captured[0]
assert data == {"node": "n1", "prompt_id": "p-other"}
assert "workflow_id" not in data
def test_queue_updated_frame_during_active_prompt_is_clean(self, server):
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 0}}})
_, data, _ = server.captured[0]
assert "workflow_id" not in data
class TestWorkerSerializationIsolatesMetadata:
def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
# Prompt A
server.active_prompt_metadata = {"workflow_id": "wf-AAA"}
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-AAA"})
server.send_sync("execution_start", {"prompt_id": "P-shared"})
server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"})
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
server.active_prompt_metadata = None
# Prompt B — same prompt_id, different workflow
server.active_prompt_metadata = {"workflow_id": "wf-BBB"}
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-BBB"})
server.send_sync("execution_start", {"prompt_id": "P-shared"})
server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"})
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
@ -160,3 +212,77 @@ class TestWorkerSerializationIsolatesMetadata:
assert all(f["workflow_id"] == "wf-AAA" for f in a_frames)
assert all(f["workflow_id"] == "wf-BBB" for f in b_frames)
assert all(f["prompt_id"] == "P-shared" for f in frames)
class TestValidateClientMetadata:
def test_none_returns_empty_dict(self):
cleaned, error = validate_client_metadata(None)
assert cleaned == {}
assert error is None
def test_flat_string_dict_is_accepted(self):
cleaned, error = validate_client_metadata(
{"workflow_id": "wf-1", "trace_id": "trace-abc"}
)
assert cleaned == {"workflow_id": "wf-1", "trace_id": "trace-abc"}
assert error is None
def test_non_dict_is_rejected(self):
_, error = validate_client_metadata("not a dict")
assert error is not None
assert "object" in error
def test_list_is_rejected(self):
_, error = validate_client_metadata([("workflow_id", "wf-1")])
assert error is not None
def test_nested_dict_value_is_rejected(self):
_, error = validate_client_metadata({"workflow": {"id": "wf-1"}})
assert error is not None
assert "string" in error
def test_non_string_value_is_rejected(self):
_, error = validate_client_metadata({"workflow_id": 42})
assert error is not None
def test_non_string_key_is_rejected(self):
_, error = validate_client_metadata({123: "wf-1"})
assert error is not None
def test_empty_key_is_rejected(self):
_, error = validate_client_metadata({"": "wf-1"})
assert error is not None
def test_key_exceeding_limit_is_rejected(self):
_, error = validate_client_metadata({"k" * (MAX_METADATA_KEY_LEN + 1): "v"})
assert error is not None
assert str(MAX_METADATA_KEY_LEN) in error
def test_value_exceeding_limit_is_rejected(self):
_, error = validate_client_metadata({"workflow_id": "v" * (MAX_METADATA_VALUE_LEN + 1)})
assert error is not None
assert str(MAX_METADATA_VALUE_LEN) in error
def test_too_many_keys_is_rejected(self):
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS + 1)}
_, error = validate_client_metadata(raw)
assert error is not None
assert str(MAX_METADATA_KEYS) in error
def test_max_size_dict_is_accepted(self):
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS)}
cleaned, error = validate_client_metadata(raw)
assert error is None
assert len(cleaned) == MAX_METADATA_KEYS
def test_max_length_strings_are_accepted(self):
raw = {"k" * MAX_METADATA_KEY_LEN: "v" * MAX_METADATA_VALUE_LEN}
cleaned, error = validate_client_metadata(raw)
assert error is None
assert cleaned == raw
@pytest.mark.parametrize("reserved_key", sorted(RESERVED_METADATA_KEYS))
def test_reserved_keys_are_rejected(self, reserved_key):
_, error = validate_client_metadata({reserved_key: "anything"})
assert error is not None
assert reserved_key in error