mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
Address adversarial-review findings on FE-745 metadata propagation: - send_sync previously spread active_prompt_metadata onto every dict payload, contaminating unrelated status/queue broadcasts with the running prompt's workflow_id. Change the slot to (prompt_id, metadata) and only inject when payload.prompt_id matches the active prompt_id. Same condition applied to the WS reconnect catch-up frame. - post_prompt now validates extra_data.metadata at the submission boundary: flat dict[str,str], max 16 keys, 64-char keys, 256-char values, and reserved server-side keys (prompt_id, node, output, etc.) are rejected with 400. Removes the broadcast-amplification vector where a client could submit arbitrarily large metadata and force it onto every WS frame. - Extract validate_client_metadata + caps into app/prompt_metadata.py so tests can import without pulling server.py's import-time side effects. - Expand tests-unit/server_test/test_prompt_metadata.py from 12 to 47: add TestStatusBroadcastsAreNotContaminated for prompt_id-scoping and TestValidateClientMetadata for the new submission-boundary checks (including parametrized reserved-key rejection).
289 lines
11 KiB
Python
289 lines
11 KiB
Python
"""Tests for the opaque per-prompt metadata mechanism on PromptServer."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.prompt_metadata import (
|
|
MAX_METADATA_KEY_LEN,
|
|
MAX_METADATA_KEYS,
|
|
MAX_METADATA_VALUE_LEN,
|
|
RESERVED_METADATA_KEYS,
|
|
validate_client_metadata,
|
|
)
|
|
from comfy_execution.jobs import extract_workflow_id
|
|
|
|
|
|
class TestExtractWorkflowId:
|
|
|
|
def test_returns_id_when_present(self):
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) == "wf-1"
|
|
|
|
def test_returns_none_when_missing(self):
|
|
assert extract_workflow_id({}) is None
|
|
assert extract_workflow_id({"extra_pnginfo": {}}) is None
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": {}}}) is None
|
|
|
|
def test_returns_none_for_empty_or_wrong_type(self):
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": ""}}}) is None
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": 42}}}) is None
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": None}}}) is None
|
|
|
|
def test_returns_none_for_non_dict_input(self):
|
|
assert extract_workflow_id(None) is None
|
|
assert extract_workflow_id("not a dict") is None
|
|
assert extract_workflow_id({"extra_pnginfo": "not a dict"}) is None
|
|
assert extract_workflow_id({"extra_pnginfo": {"workflow": "not a dict"}}) is None
|
|
|
|
|
|
class _FakeServer:
|
|
"""Minimal PromptServer stand-in mirroring send_sync verbatim.
|
|
|
|
``active_prompt_metadata`` is ``Optional[tuple[str, dict]]`` — the
|
|
``prompt_id`` it belongs to plus the opaque dict. send_sync only merges
|
|
when the outgoing payload's ``prompt_id`` matches the active one, so
|
|
unrelated queue/status broadcasts are not contaminated.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.active_prompt_metadata = None
|
|
self.captured = []
|
|
self.loop = MagicMock()
|
|
self.loop.call_soon_threadsafe.side_effect = (
|
|
lambda fn, msg: self.captured.append(msg)
|
|
)
|
|
self.messages = MagicMock()
|
|
self.messages.put_nowait = MagicMock()
|
|
|
|
def send_sync(self, event, data, sid=None):
|
|
slot = self.active_prompt_metadata
|
|
if slot is not None and isinstance(data, dict):
|
|
active_prompt_id, meta = slot
|
|
if meta and data.get("prompt_id") == active_prompt_id:
|
|
data = {**meta, **data}
|
|
self.loop.call_soon_threadsafe(
|
|
self.messages.put_nowait, (event, data, sid)
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def server():
|
|
return _FakeServer()
|
|
|
|
|
|
class TestSendSyncMerge:
|
|
def test_spreads_active_metadata_onto_dict_payload(self, server):
|
|
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync(
|
|
"executing", {"node": "n1", "prompt_id": "p1"}, "client-1"
|
|
)
|
|
|
|
event, data, sid = server.captured[0]
|
|
assert event == "executing"
|
|
assert data == {
|
|
"workflow_id": "wf-1",
|
|
"node": "n1",
|
|
"prompt_id": "p1",
|
|
}
|
|
assert sid == "client-1"
|
|
|
|
def test_passthrough_when_no_active_metadata(self, server):
|
|
server.active_prompt_metadata = None
|
|
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == {"node": "n1", "prompt_id": "p1"}
|
|
|
|
def test_passthrough_when_metadata_is_empty_dict(self, server):
|
|
server.active_prompt_metadata = ("p1", {})
|
|
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == {"node": "n1", "prompt_id": "p1"}
|
|
|
|
def test_event_payload_wins_on_key_conflict(self, server):
|
|
server.active_prompt_metadata = (
|
|
"p1",
|
|
{"workflow_id": "wf-1", "prompt_id": "from-meta"},
|
|
)
|
|
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}, "c1")
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data["prompt_id"] == "p1"
|
|
assert data["workflow_id"] == "wf-1"
|
|
|
|
def test_non_dict_payload_passes_through_untouched(self, server):
|
|
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1")
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == b"\x00\x00\x00\x03foobar"
|
|
|
|
def test_terminal_executing_frame_includes_metadata(self, server):
|
|
server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync(
|
|
"executing", {"node": None, "prompt_id": "p1"}, "client-1"
|
|
)
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == {
|
|
"workflow_id": "wf-1",
|
|
"node": None,
|
|
"prompt_id": "p1",
|
|
}
|
|
|
|
def test_opaque_dict_supports_arbitrary_keys(self, server):
|
|
server.active_prompt_metadata = (
|
|
"p1",
|
|
{"workflow_id": "wf-1", "trace_id": "trace-123", "tenant": "acme"},
|
|
)
|
|
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data["workflow_id"] == "wf-1"
|
|
assert data["trace_id"] == "trace-123"
|
|
assert data["tenant"] == "acme"
|
|
|
|
|
|
class TestStatusBroadcastsAreNotContaminated:
|
|
"""Regression tests for the contamination bug:
|
|
|
|
``send_sync`` previously spread metadata onto any dict payload, so a
|
|
status broadcast fired while a prompt was running picked up that
|
|
prompt's metadata even though it had nothing to do with that prompt.
|
|
"""
|
|
|
|
def test_status_payload_without_prompt_id_is_untouched(self, server):
|
|
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 1}}})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == {"status": {"exec_info": {"queue_remaining": 1}}}
|
|
assert "workflow_id" not in data
|
|
|
|
def test_payload_for_different_prompt_is_untouched(self, server):
|
|
# Active prompt is p-running; we send a frame for p-other (e.g. another
|
|
# client's queued item). The merge must not leak across prompts.
|
|
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "p-other"})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert data == {"node": "n1", "prompt_id": "p-other"}
|
|
assert "workflow_id" not in data
|
|
|
|
def test_queue_updated_frame_during_active_prompt_is_clean(self, server):
|
|
server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
|
|
|
|
server.send_sync("status", {"status": {"exec_info": {"queue_remaining": 0}}})
|
|
|
|
_, data, _ = server.captured[0]
|
|
assert "workflow_id" not in data
|
|
|
|
|
|
class TestWorkerSerializationIsolatesMetadata:
|
|
def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
|
|
# Prompt A
|
|
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-AAA"})
|
|
server.send_sync("execution_start", {"prompt_id": "P-shared"})
|
|
server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"})
|
|
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
|
|
server.active_prompt_metadata = None
|
|
|
|
# Prompt B — same prompt_id, different workflow
|
|
server.active_prompt_metadata = ("P-shared", {"workflow_id": "wf-BBB"})
|
|
server.send_sync("execution_start", {"prompt_id": "P-shared"})
|
|
server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"})
|
|
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
|
|
server.active_prompt_metadata = None
|
|
|
|
frames = [d for (_, d, _) in server.captured]
|
|
a_frames = frames[:3]
|
|
b_frames = frames[3:]
|
|
|
|
assert all(f["workflow_id"] == "wf-AAA" for f in a_frames)
|
|
assert all(f["workflow_id"] == "wf-BBB" for f in b_frames)
|
|
assert all(f["prompt_id"] == "P-shared" for f in frames)
|
|
|
|
|
|
class TestValidateClientMetadata:
|
|
def test_none_returns_empty_dict(self):
|
|
cleaned, error = validate_client_metadata(None)
|
|
assert cleaned == {}
|
|
assert error is None
|
|
|
|
def test_flat_string_dict_is_accepted(self):
|
|
cleaned, error = validate_client_metadata(
|
|
{"workflow_id": "wf-1", "trace_id": "trace-abc"}
|
|
)
|
|
assert cleaned == {"workflow_id": "wf-1", "trace_id": "trace-abc"}
|
|
assert error is None
|
|
|
|
def test_non_dict_is_rejected(self):
|
|
_, error = validate_client_metadata("not a dict")
|
|
assert error is not None
|
|
assert "object" in error
|
|
|
|
def test_list_is_rejected(self):
|
|
_, error = validate_client_metadata([("workflow_id", "wf-1")])
|
|
assert error is not None
|
|
|
|
def test_nested_dict_value_is_rejected(self):
|
|
_, error = validate_client_metadata({"workflow": {"id": "wf-1"}})
|
|
assert error is not None
|
|
assert "string" in error
|
|
|
|
def test_non_string_value_is_rejected(self):
|
|
_, error = validate_client_metadata({"workflow_id": 42})
|
|
assert error is not None
|
|
|
|
def test_non_string_key_is_rejected(self):
|
|
_, error = validate_client_metadata({123: "wf-1"})
|
|
assert error is not None
|
|
|
|
def test_empty_key_is_rejected(self):
|
|
_, error = validate_client_metadata({"": "wf-1"})
|
|
assert error is not None
|
|
|
|
def test_key_exceeding_limit_is_rejected(self):
|
|
_, error = validate_client_metadata({"k" * (MAX_METADATA_KEY_LEN + 1): "v"})
|
|
assert error is not None
|
|
assert str(MAX_METADATA_KEY_LEN) in error
|
|
|
|
def test_value_exceeding_limit_is_rejected(self):
|
|
_, error = validate_client_metadata({"workflow_id": "v" * (MAX_METADATA_VALUE_LEN + 1)})
|
|
assert error is not None
|
|
assert str(MAX_METADATA_VALUE_LEN) in error
|
|
|
|
def test_too_many_keys_is_rejected(self):
|
|
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS + 1)}
|
|
_, error = validate_client_metadata(raw)
|
|
assert error is not None
|
|
assert str(MAX_METADATA_KEYS) in error
|
|
|
|
def test_max_size_dict_is_accepted(self):
|
|
raw = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS)}
|
|
cleaned, error = validate_client_metadata(raw)
|
|
assert error is None
|
|
assert len(cleaned) == MAX_METADATA_KEYS
|
|
|
|
def test_max_length_strings_are_accepted(self):
|
|
raw = {"k" * MAX_METADATA_KEY_LEN: "v" * MAX_METADATA_VALUE_LEN}
|
|
cleaned, error = validate_client_metadata(raw)
|
|
assert error is None
|
|
assert cleaned == raw
|
|
|
|
@pytest.mark.parametrize("reserved_key", sorted(RESERVED_METADATA_KEYS))
|
|
def test_reserved_keys_are_rejected(self, reserved_key):
|
|
_, error = validate_client_metadata({reserved_key: "anything"})
|
|
assert error is not None
|
|
assert reserved_key in error
|