Compare commits

...

2 Commits

Author SHA1 Message Date
dante01yoon
db9c8cc2fd fix(server): scope prompt metadata to active prompt_id and validate at submission
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Address adversarial-review findings on FE-745 metadata propagation:

- send_sync previously spread active_prompt_metadata onto every dict
  payload, contaminating unrelated status/queue broadcasts with the
  running prompt's workflow_id. Change the slot to (prompt_id, metadata)
  and only inject when payload.prompt_id matches the active prompt_id.
  Same condition applied to the WS reconnect catch-up frame.

- post_prompt now validates extra_data.metadata at the submission
  boundary: flat dict[str,str], max 16 keys, 64-char keys, 256-char
  values, and reserved server-side keys (prompt_id, node, output, etc.)
  are rejected with 400. Removes the broadcast-amplification vector
  where a client could submit arbitrarily large metadata and force it
  onto every WS frame.

- Extract validate_client_metadata + caps into app/prompt_metadata.py so
  tests can import without pulling server.py's import-time side effects.

- Expand tests-unit/server_test/test_prompt_metadata.py from 12 to 47:
  add TestStatusBroadcastsAreNotContaminated for prompt_id-scoping and
  TestValidateClientMetadata for the new submission-boundary checks
  (including parametrized reserved-key rejection).
2026-05-20 19:01:38 +09:00
dante01yoon
dfc901078e feat(server): propagate opaque per-prompt metadata on WebSocket frames (FE-745)
server.py builds an opaque dict at submission time from extra_data and pins
it on PromptServer.active_prompt_metadata while main.py's worker drives the
prompt. send_sync spreads the dict's key/value pairs onto outgoing payloads
so frames carry whatever tags the submission attached (today: workflow_id).

The mechanism is intentionally untyped — the transport layer doesn't know
what workflow_id means or treat any key specially. Adding a new propagated
field requires only a one-line addition in post_prompt; execution.py and
comfy_execution/progress.py are not touched.

execution.py changes: 0 lines.
2026-05-20 18:00:48 +09:00
5 changed files with 397 additions and 15 deletions

44
app/prompt_metadata.py Normal file
View File

@ -0,0 +1,44 @@
"""Validation for client-supplied per-prompt metadata (extra_data.metadata)."""
from typing import Optional
MAX_METADATA_KEYS = 16
MAX_METADATA_KEY_LEN = 64
MAX_METADATA_VALUE_LEN = 256
# Server-emitted top-level fields on prompt-scoped WebSocket events.
# Client-supplied metadata may not shadow these — payload-wins-on-conflict
# only protects keys present in each individual frame, so reserve them
# at the submission boundary as defense in depth.
RESERVED_METADATA_KEYS = frozenset({
"prompt_id", "node", "display_node", "output", "nodes", "node_id",
"node_type", "executed", "exception_message", "exception_type",
"traceback", "current_inputs", "current_outputs", "timestamp",
"sid", "status", "prompt", "value", "max",
})
def validate_client_metadata(raw) -> tuple[Optional[dict], Optional[str]]:
"""Return ``(cleaned_metadata, error_message)``.
A missing field (``None``) is treated as empty metadata. Anything else
must be a flat ``dict[str, str]`` within the size caps and free of
reserved keys.
"""
if raw is None:
return {}, None
if not isinstance(raw, dict):
return None, "extra_data.metadata must be an object"
if len(raw) > MAX_METADATA_KEYS:
return None, f"extra_data.metadata exceeds {MAX_METADATA_KEYS} keys"
cleaned: dict = {}
for key, value in raw.items():
if not isinstance(key, str) or not key or len(key) > MAX_METADATA_KEY_LEN:
return None, f"metadata key must be a non-empty string up to {MAX_METADATA_KEY_LEN} chars"
if key in RESERVED_METADATA_KEYS:
return None, f"metadata key '{key}' is reserved"
if not isinstance(value, str) or len(value) > MAX_METADATA_VALUE_LEN:
return None, f"metadata value for '{key}' must be a string up to {MAX_METADATA_VALUE_LEN} chars"
cleaned[key] = value
return cleaned, None

View File

@ -93,6 +93,22 @@ def _create_text_preview(value: str) -> dict:
}
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
"""Return extra_data["extra_pnginfo"]["workflow"]["id"] when it is a non-empty string."""
if not isinstance(extra_data, dict):
return None
extra_pnginfo = extra_data.get('extra_pnginfo')
if not isinstance(extra_pnginfo, dict):
return None
workflow = extra_pnginfo.get('workflow')
if not isinstance(workflow, dict):
return None
workflow_id = workflow.get('id')
if isinstance(workflow_id, str) and workflow_id:
return workflow_id
return None
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@ -100,8 +116,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
workflow_id = extract_workflow_id(extra_data)
return create_time, workflow_id

30
main.py
View File

@ -317,20 +317,28 @@ def prompt_worker(q, server_instance):
for k in sensitive:
extra_data[k] = sensitive[k]
metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None
server_instance.active_prompt_metadata = (prompt_id, metadata) if metadata else None
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
# Drop sensitive (index 5) and metadata (index 6); history keeps a 5-tuple.
remove_sensitive = lambda prompt: prompt[:5]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
# Clear after the terminal send so that frame still carries metadata.
server_instance.active_prompt_metadata = None
current_time = time.perf_counter()
execution_time = current_time - execution_start_time

View File

@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import validate_client_metadata
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -252,6 +253,11 @@ class PromptServer():
self.last_node_id = None
self.client_id = None
# (prompt_id, opaque tag dict) pinned by main.py around each prompt.
# send_sync only spreads the dict onto payloads whose prompt_id matches,
# so concurrent queue/status broadcasts are not contaminated.
self.active_prompt_metadata: Optional[tuple[str, dict]] = None
self.on_prompt_handlers = []
@routes.get('/ws')
@ -275,7 +281,16 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
payload = {"node": self.last_node_id}
last_prompt_id = getattr(self, "last_prompt_id", None)
if last_prompt_id:
payload["prompt_id"] = last_prompt_id
slot = self.active_prompt_metadata
if slot is not None:
active_prompt_id, meta = slot
if meta and payload.get("prompt_id") == active_prompt_id:
payload = {**meta, **payload}
await self.send("executing", payload, sid)
# Flag to track if we've received the first message
first_message = True
@ -955,7 +970,14 @@ class PromptServer():
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
raw_metadata = extra_data.pop("metadata", None)
client_metadata, meta_error = validate_client_metadata(raw_metadata)
if meta_error is not None:
return web.json_response(
{"error": {"type": "invalid_metadata", "message": meta_error}},
status=400,
)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
@ -1217,6 +1239,11 @@ class PromptServer():
await send_socket_catch_exception(self.sockets[sid].send_json, message)
def send_sync(self, event, data, sid=None):
slot = self.active_prompt_metadata
if slot is not None and isinstance(data, dict):
active_prompt_id, meta = slot
if meta and data.get("prompt_id") == active_prompt_id:
data = {**meta, **data}
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))

View File

@ -0,0 +1,288 @@
"""Tests for the opaque per-prompt metadata mechanism on PromptServer."""
from unittest.mock import MagicMock
import pytest
from app.prompt_metadata import (
MAX_METADATA_KEY_LEN,
MAX_METADATA_KEYS,
MAX_METADATA_VALUE_LEN,
RESERVED_METADATA_KEYS,
validate_client_metadata,
)
from comfy_execution.jobs import extract_workflow_id
class TestExtractWorkflowId:
def test_returns_id_when_present(self):
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) == "wf-1"
def test_returns_none_when_missing(self):
assert extract_workflow_id({}) is None
assert extract_workflow_id({"extra_pnginfo": {}}) is None
assert extract_workflow_id({"extra_pnginfo": {"workflow": {}}}) is None
def test_returns_none_for_empty_or_wrong_type(self):
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": ""}}}) is None
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": 42}}}) is None
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": None}}}) is None
def test_returns_none_for_non_dict_input(self):
assert extract_workflow_id(None) is None
assert extract_workflow_id("not a dict") is None
assert extract_workflow_id({"extra_pnginfo": "not a dict"}) is None
assert extract_workflow_id({"extra_pnginfo": {"workflow": "not a dict"}}) is None
class _FakeServer:
"""Minimal PromptServer stand-in mirroring send_sync verbatim.
``active_prompt_metadata`` is ``Optional[tuple[str, dict]]`` the
``prompt_id`` it belongs to plus the opaque dict. send_sync only merges
when the outgoing payload's ``prompt_id`` matches the active one, so
unrelated queue/status broadcasts are not contaminated.
"""
def __init__(self):
self.active_prompt_metadata = None
self.captured = []
self.loop = MagicMock()
self.loop.call_soon_threadsafe.side_effect = (
lambda fn, msg: self.captured.append(msg)
)
self.messages = MagicMock()
self.messages.put_nowait = MagicMock()
def send_sync(self, event, data, sid=None):
slot = self.active_prompt_metadata
if slot is not None and isinstance(data, dict):
active_prompt_id, meta = slot
if meta and data.get("prompt_id") == active_prompt_id:
data = {**meta, **data}
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)
)
@pytest.fixture
def server():
return _FakeServer()
class TestSendSyncMerge:
def test_spreads_active_metadata_onto_dict_payload(self, server):
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync(
"executing", {"node": "n1", "prompt_id": "p1"}, "client-1"
)
event, data, sid = server.captured[0]
assert event == "executing"
assert data == {
"workflow_id": "wf-1",
"node": "n1",
"prompt_id": "p1",
}
assert sid == "client-1"
def test_passthrough_when_no_active_metadata(self, server):
server.active_prompt_metadata = None
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
_, data, _ = server.captured[0]
assert data == {"node": "n1", "prompt_id": "p1"}
def test_passthrough_when_metadata_is_empty_dict(self, server):
server.active_prompt_metadata = ("p1", {})
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
_, data, _ = server.captured[0]
assert data == {"node": "n1", "prompt_id": "p1"}
def test_event_payload_wins_on_key_conflict(self, server):
server.active_prompt_metadata = (
"p1",
{"workflow_id": "wf-1", "prompt_id": "from-meta"},
)
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}, "c1")
_, data, _ = server.captured[0]
assert data["prompt_id"] == "p1"
assert data["workflow_id"] == "wf-1"
def test_non_dict_payload_passes_through_untouched(self, server):
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1")
_, data, _ = server.captured[0]
assert data == b"\x00\x00\x00\x03foobar"
def test_terminal_executing_frame_includes_metadata(self, server):
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
server.send_sync(
"executing", {"node": None, "prompt_id": "p1"}, "client-1"
)
_, data, _ = server.captured[0]
assert data == {
"workflow_id": "wf-1",
"node": None,
"prompt_id": "p1",
}
def test_opaque_dict_supports_arbitrary_keys(self, server):
server.active_prompt_metadata = (
"p1",
{"workflow_id": "wf-1", "trace_id": "trace-123", "tenant": "acme"},
)
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
_, data, _ = server.captured[0]
assert data["workflow_id"] == "wf-1"
assert data["trace_id"] == "trace-123"
assert data["tenant"] == "acme"
class TestStatusBroadcastsAreNotContaminated:
"""Regression tests for the contamination bug:
``send_sync`` previously spread metadata onto any dict payload, so a
status broadcast fired while a prompt was running picked up that
prompt's metadata even though it had nothing to do with that prompt.
"""
def test_status_payload_without_prompt_id_is_untouched(self, server):
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 1}}})
_, data, _ = server.captured[0]
assert data == {"status": {"exec_info": {"queue_remaining": 1}}}
assert "workflow_id" not in data
def test_payload_for_different_prompt_is_untouched(self, server):
# Active prompt is p-running; we send a frame for p-other (e.g. another
# client's queued item). The merge must not leak across prompts.
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("executing", {"node": "n1", "prompt_id": "p-other"})
_, data, _ = server.captured[0]
assert data == {"node": "n1", "prompt_id": "p-other"}
assert "workflow_id" not in data
def test_queue_updated_frame_during_active_prompt_is_clean(self, server):
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 0}}})
_, data, _ = server.captured[0]
assert "workflow_id" not in data
class TestWorkerSerializationIsolatesMetadata:
def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
# Prompt A
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-AAA"})
server.send_sync("execution_start", {"prompt_id": "P-shared"})
server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"})
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
server.active_prompt_metadata = None
# Prompt B — same prompt_id, different workflow
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-BBB"})
server.send_sync("execution_start", {"prompt_id": "P-shared"})
server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"})
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
server.active_prompt_metadata = None
frames = [d for (_, d, _) in server.captured]
a_frames = frames[:3]
b_frames = frames[3:]
assert all(f["workflow_id"] == "wf-AAA" for f in a_frames)
assert all(f["workflow_id"] == "wf-BBB" for f in b_frames)
assert all(f["prompt_id"] == "P-shared" for f in frames)
class TestValidateClientMetadata:
def test_none_returns_empty_dict(self):
cleaned, error = validate_client_metadata(None)
assert cleaned == {}
assert error is None
def test_flat_string_dict_is_accepted(self):
cleaned, error = validate_client_metadata(
{"workflow_id": "wf-1", "trace_id": "trace-abc"}
)
assert cleaned == {"workflow_id": "wf-1", "trace_id": "trace-abc"}
assert error is None
def test_non_dict_is_rejected(self):
_, error = validate_client_metadata("not a dict")
assert error is not None
assert "object" in error
def test_list_is_rejected(self):
_, error = validate_client_metadata([("workflow_id", "wf-1")])
assert error is not None
def test_nested_dict_value_is_rejected(self):
_, error = validate_client_metadata({"workflow": {"id": "wf-1"}})
assert error is not None
assert "string" in error
def test_non_string_value_is_rejected(self):
_, error = validate_client_metadata({"workflow_id": 42})
assert error is not None
def test_non_string_key_is_rejected(self):
_, error = validate_client_metadata({123: "wf-1"})
assert error is not None
def test_empty_key_is_rejected(self):
_, error = validate_client_metadata({"": "wf-1"})
assert error is not None
def test_key_exceeding_limit_is_rejected(self):
_, error = validate_client_metadata({"k" * (MAX_METADATA_KEY_LEN + 1): "v"})
assert error is not None
assert str(MAX_METADATA_KEY_LEN) in error
def test_value_exceeding_limit_is_rejected(self):
_, error = validate_client_metadata({"workflow_id": "v" * (MAX_METADATA_VALUE_LEN + 1)})
assert error is not None
assert str(MAX_METADATA_VALUE_LEN) in error
def test_too_many_keys_is_rejected(self):
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS + 1)}
_, error = validate_client_metadata(raw)
assert error is not None
assert str(MAX_METADATA_KEYS) in error
def test_max_size_dict_is_accepted(self):
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS)}
cleaned, error = validate_client_metadata(raw)
assert error is None
assert len(cleaned) == MAX_METADATA_KEYS
def test_max_length_strings_are_accepted(self):
raw = {"k" * MAX_METADATA_KEY_LEN: "v" * MAX_METADATA_VALUE_LEN}
cleaned, error = validate_client_metadata(raw)
assert error is None
assert cleaned == raw
@pytest.mark.parametrize("reserved_key", sorted(RESERVED_METADATA_KEYS))
def test_reserved_keys_are_rejected(self, reserved_key):
_, error = validate_client_metadata({reserved_key: "anything"})
assert error is not None
assert reserved_key in error