mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-20 06:49:37 +08:00
Compare commits
2 Commits
8d9ea888ae
...
db9c8cc2fd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db9c8cc2fd | ||
|
|
dfc901078e |
44
app/prompt_metadata.py
Normal file
44
app/prompt_metadata.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Validation for client-supplied per-prompt metadata (extra_data.metadata)."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Size caps enforced on client-supplied metadata.
MAX_METADATA_KEYS = 16  # maximum number of entries in the metadata object
MAX_METADATA_KEY_LEN = 64  # maximum characters in a single key
MAX_METADATA_VALUE_LEN = 256  # maximum characters in a single value

# Top-level field names the server itself emits on prompt-scoped WebSocket
# events. Client metadata is not allowed to shadow them: the
# payload-wins-on-conflict merge only protects keys that happen to be present
# in a given frame, so these names are additionally rejected at submission
# time as defense in depth.
RESERVED_METADATA_KEYS = frozenset((
    "prompt_id",
    "node",
    "display_node",
    "output",
    "nodes",
    "node_id",
    "node_type",
    "executed",
    "exception_message",
    "exception_type",
    "traceback",
    "current_inputs",
    "current_outputs",
    "timestamp",
    "sid",
    "status",
    "prompt",
    "value",
    "max",
))
|
||||
|
||||
|
||||
def validate_client_metadata(raw) -> tuple[Optional[dict], Optional[str]]:
    """Validate client-supplied metadata and return ``(cleaned, error)``.

    Exactly one element of the returned pair is meaningful:

    * on success, ``cleaned`` is a new dict holding the accepted entries and
      ``error`` is ``None``;
    * on failure, ``cleaned`` is ``None`` and ``error`` is a human-readable
      message describing the first violation found.

    ``None`` input (field absent) is accepted and yields an empty dict. Any
    other input has to be a flat mapping of string keys to string values
    that respects the module's size caps and avoids the reserved key names.
    """
    if raw is None:
        return {}, None

    if not isinstance(raw, dict):
        return None, "extra_data.metadata must be an object"

    if len(raw) > MAX_METADATA_KEYS:
        return None, f"extra_data.metadata exceeds {MAX_METADATA_KEYS} keys"

    accepted: dict = {}
    for name, text in raw.items():
        # Keys: non-empty strings within the length cap.
        key_ok = (
            isinstance(name, str)
            and bool(name)
            and len(name) <= MAX_METADATA_KEY_LEN
        )
        if not key_ok:
            return None, f"metadata key must be a non-empty string up to {MAX_METADATA_KEY_LEN} chars"

        # Keys that collide with server-emitted field names are refused.
        if name in RESERVED_METADATA_KEYS:
            return None, f"metadata key '{name}' is reserved"

        # Values: strings within the length cap (empty string is allowed).
        value_ok = isinstance(text, str) and len(text) <= MAX_METADATA_VALUE_LEN
        if not value_ok:
            return None, f"metadata value for '{name}' must be a string up to {MAX_METADATA_VALUE_LEN} chars"

        accepted[name] = text

    return accepted, None
|
||||
@ -93,6 +93,22 @@ def _create_text_preview(value: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
    """Look up ``extra_data["extra_pnginfo"]["workflow"]["id"]`` defensively.

    Returns the id only when every level on the path is a dict and the id
    itself is a non-empty string; in every other case returns ``None``.
    """
    current = extra_data
    # Walk down the two intermediate levels, bailing out as soon as a level
    # is not a dict.
    for field in ('extra_pnginfo', 'workflow'):
        if not isinstance(current, dict):
            return None
        current = current.get(field)

    if not isinstance(current, dict):
        return None

    candidate = current.get('id')
    return candidate if isinstance(candidate, str) and candidate else None
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
@ -100,8 +116,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
|
||||
30
main.py
30
main.py
@ -317,20 +317,28 @@ def prompt_worker(q, server_instance):
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None
|
||||
server_instance.active_prompt_metadata = (prompt_id, metadata) if metadata else None
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
# Drop sensitive (index 5) and metadata (index 6); history keeps a 5-tuple.
|
||||
remove_sensitive = lambda prompt: prompt[:5]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Clear after the terminal send so that frame still carries metadata.
|
||||
server_instance.active_prompt_metadata = None
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
31
server.py
31
server.py
@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from app.prompt_metadata import validate_client_metadata
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -252,6 +253,11 @@ class PromptServer():
|
||||
self.last_node_id = None
|
||||
self.client_id = None
|
||||
|
||||
# (prompt_id, opaque tag dict) pinned by main.py around each prompt.
|
||||
# send_sync only spreads the dict onto payloads whose prompt_id matches,
|
||||
# so concurrent queue/status broadcasts are not contaminated.
|
||||
self.active_prompt_metadata: Optional[tuple[str, dict]] = None
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
@ -275,7 +281,16 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
payload = {"node": self.last_node_id}
|
||||
last_prompt_id = getattr(self, "last_prompt_id", None)
|
||||
if last_prompt_id:
|
||||
payload["prompt_id"] = last_prompt_id
|
||||
slot = self.active_prompt_metadata
|
||||
if slot is not None:
|
||||
active_prompt_id, meta = slot
|
||||
if meta and payload.get("prompt_id") == active_prompt_id:
|
||||
payload = {**meta, **payload}
|
||||
await self.send("executing", payload, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -955,7 +970,14 @@ class PromptServer():
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
raw_metadata = extra_data.pop("metadata", None)
|
||||
client_metadata, meta_error = validate_client_metadata(raw_metadata)
|
||||
if meta_error is not None:
|
||||
return web.json_response(
|
||||
{"error": {"type": "invalid_metadata", "message": meta_error}},
|
||||
status=400,
|
||||
)
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
else:
|
||||
@ -1217,6 +1239,11 @@ class PromptServer():
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
    """Queue ``(event, data, sid)`` onto ``self.messages`` via the event loop.

    When ``self.active_prompt_metadata`` holds a ``(prompt_id, metadata)``
    pair with non-empty metadata, and ``data`` is a dict whose ``prompt_id``
    equals the pinned one, the metadata is merged into a copy of the
    payload. Keys already present in ``data`` take precedence over metadata
    keys. All other payloads are forwarded unchanged.
    """
    pinned = self.active_prompt_metadata
    if isinstance(data, dict) and pinned is not None:
        pinned_prompt_id, tags = pinned
        if tags and data.get("prompt_id") == pinned_prompt_id:
            # Start from the metadata, then overlay the payload so that
            # payload keys win on conflict.
            merged = dict(tags)
            merged.update(data)
            data = merged

    message = (event, data, sid)
    self.loop.call_soon_threadsafe(self.messages.put_nowait, message)
|
||||
|
||||
|
||||
288
tests-unit/server_test/test_prompt_metadata.py
Normal file
288
tests-unit/server_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,288 @@
|
||||
"""Tests for the opaque per-prompt metadata mechanism on PromptServer."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.prompt_metadata import (
|
||||
MAX_METADATA_KEY_LEN,
|
||||
MAX_METADATA_KEYS,
|
||||
MAX_METADATA_VALUE_LEN,
|
||||
RESERVED_METADATA_KEYS,
|
||||
validate_client_metadata,
|
||||
)
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
|
||||
|
||||
class TestExtractWorkflowId:
    """``extract_workflow_id`` on well-formed, incomplete and malformed input."""

    def test_returns_id_when_present(self):
        payload = {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
        assert extract_workflow_id(payload) == "wf-1"

    def test_returns_none_when_missing(self):
        # Each payload stops one level short of the ``id`` field.
        incomplete = (
            {},
            {"extra_pnginfo": {}},
            {"extra_pnginfo": {"workflow": {}}},
        )
        for payload in incomplete:
            assert extract_workflow_id(payload) is None

    def test_returns_none_for_empty_or_wrong_type(self):
        # The path exists, but the id is not a non-empty string.
        for bad_id in ("", 42, None):
            payload = {"extra_pnginfo": {"workflow": {"id": bad_id}}}
            assert extract_workflow_id(payload) is None

    def test_returns_none_for_non_dict_input(self):
        # A non-dict at any level of the path short-circuits to None.
        malformed = (
            None,
            "not a dict",
            {"extra_pnginfo": "not a dict"},
            {"extra_pnginfo": {"workflow": "not a dict"}},
        )
        for payload in malformed:
            assert extract_workflow_id(payload) is None
|
||||
|
||||
|
||||
class _FakeServer:
    """Lightweight PromptServer substitute exposing only ``send_sync``.

    ``active_prompt_metadata`` is either ``None`` or a ``(prompt_id, dict)``
    pair: the prompt the metadata belongs to, plus the opaque dict itself.
    ``send_sync`` merges the dict into an outgoing payload only when that
    payload's ``prompt_id`` equals the active one, which keeps unrelated
    queue/status broadcasts free of the running prompt's metadata.

    Every message handed to the (mocked) event loop is appended to
    ``captured`` as an ``(event, data, sid)`` tuple.
    """

    def __init__(self):
        self.active_prompt_metadata = None
        self.captured = []

        self.messages = MagicMock()
        self.messages.put_nowait = MagicMock()

        def record(fn, msg):
            # Capture the message instead of scheduling ``fn`` on a loop.
            self.captured.append(msg)

        self.loop = MagicMock()
        self.loop.call_soon_threadsafe.side_effect = record

    def send_sync(self, event, data, sid=None):
        slot = self.active_prompt_metadata
        if isinstance(data, dict) and slot is not None:
            active_prompt_id, meta = slot
            if meta and data.get("prompt_id") == active_prompt_id:
                # Payload keys override metadata keys on conflict.
                data = {**meta, **data}
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid)
        )
|
||||
|
||||
|
||||
@pytest.fixture
def server():
    """Provide a freshly constructed ``_FakeServer`` to the requesting test."""
    fake_server = _FakeServer()
    return fake_server
|
||||
|
||||
|
||||
class TestSendSyncMerge:
    """How ``send_sync`` folds the active prompt's metadata into payloads."""

    def test_spreads_active_metadata_onto_dict_payload(self, server):
        server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
        payload = {"node": "n1", "prompt_id": "p1"}

        server.send_sync("executing", payload, "client-1")

        event, data, sid = server.captured[0]
        assert event == "executing"
        assert data == {"workflow_id": "wf-1", "node": "n1", "prompt_id": "p1"}
        assert sid == "client-1"

    def test_passthrough_when_no_active_metadata(self, server):
        server.active_prompt_metadata = None
        payload = {"node": "n1", "prompt_id": "p1"}

        server.send_sync("executing", payload)

        sent = server.captured[0][1]
        assert sent == {"node": "n1", "prompt_id": "p1"}

    def test_passthrough_when_metadata_is_empty_dict(self, server):
        server.active_prompt_metadata = ("p1", {})
        payload = {"node": "n1", "prompt_id": "p1"}

        server.send_sync("executing", payload)

        sent = server.captured[0][1]
        assert sent == {"node": "n1", "prompt_id": "p1"}

    def test_event_payload_wins_on_key_conflict(self, server):
        # The metadata carries its own ``prompt_id``; the payload's must win.
        conflicting = {"workflow_id": "wf-1", "prompt_id": "from-meta"}
        server.active_prompt_metadata = ("p1", conflicting)

        server.send_sync("executing", {"node": "n1", "prompt_id": "p1"}, "c1")

        sent = server.captured[0][1]
        assert sent["prompt_id"] == "p1"
        assert sent["workflow_id"] == "wf-1"

    def test_non_dict_payload_passes_through_untouched(self, server):
        server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
        raw = b"\x00\x00\x00\x03foobar"

        server.send_sync("text", raw, "c1")

        sent = server.captured[0][1]
        assert sent == b"\x00\x00\x00\x03foobar"

    def test_terminal_executing_frame_includes_metadata(self, server):
        server.active_prompt_metadata = ("p1", {"workflow_id": "wf-1"})
        terminal = {"node": None, "prompt_id": "p1"}

        server.send_sync("executing", terminal, "client-1")

        sent = server.captured[0][1]
        assert sent == {"workflow_id": "wf-1", "node": None, "prompt_id": "p1"}

    def test_opaque_dict_supports_arbitrary_keys(self, server):
        tags = {"workflow_id": "wf-1", "trace_id": "trace-123", "tenant": "acme"}
        server.active_prompt_metadata = ("p1", tags)

        server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})

        sent = server.captured[0][1]
        assert sent["workflow_id"] == "wf-1"
        assert sent["trace_id"] == "trace-123"
        assert sent["tenant"] == "acme"
|
||||
|
||||
|
||||
class TestStatusBroadcastsAreNotContaminated:
    """Regression coverage for metadata leaking into unrelated frames.

    An earlier ``send_sync`` merged metadata into every dict payload, so a
    status broadcast emitted while a prompt was executing inherited that
    prompt's metadata despite being unrelated to it.
    """

    def test_status_payload_without_prompt_id_is_untouched(self, server):
        server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
        status = {"status": {"exec_info": {"queue_remaining": 1}}}

        server.send_sync("status", status)

        sent = server.captured[0][1]
        assert sent == {"status": {"exec_info": {"queue_remaining": 1}}}
        assert "workflow_id" not in sent

    def test_payload_for_different_prompt_is_untouched(self, server):
        # p-running is active, but the frame belongs to p-other (for example
        # another client's queued item); metadata must not cross prompts.
        server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
        foreign = {"node": "n1", "prompt_id": "p-other"}

        server.send_sync("executing", foreign)

        sent = server.captured[0][1]
        assert sent == {"node": "n1", "prompt_id": "p-other"}
        assert "workflow_id" not in sent

    def test_queue_updated_frame_during_active_prompt_is_clean(self, server):
        server.active_prompt_metadata = ("p-running", {"workflow_id": "wf-1"})
        status = {"status": {"exec_info": {"queue_remaining": 0}}}

        server.send_sync("status", status)

        sent = server.captured[0][1]
        assert "workflow_id" not in sent
|
||||
|
||||
|
||||
class TestWorkerSerializationIsolatesMetadata:
    """Back-to-back prompts each see only their own pinned metadata."""

    def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
        shared_id = "P-shared"
        # Two sequential runs reuse the same prompt_id but carry different
        # workflow ids; each run pins its metadata, emits three frames, then
        # clears the slot.
        runs = (("wf-AAA", "n1"), ("wf-BBB", "n2"))
        for workflow_id, node in runs:
            server.active_prompt_metadata = (shared_id, {"workflow_id": workflow_id})
            server.send_sync("execution_start", {"prompt_id": shared_id})
            server.send_sync("executing", {"node": node, "prompt_id": shared_id})
            server.send_sync("executing", {"node": None, "prompt_id": shared_id})
            server.active_prompt_metadata = None

        payloads = [entry[1] for entry in server.captured]
        first_run, second_run = payloads[:3], payloads[3:]

        assert all(p["workflow_id"] == "wf-AAA" for p in first_run)
        assert all(p["workflow_id"] == "wf-BBB" for p in second_run)
        assert all(p["prompt_id"] == "P-shared" for p in payloads)
|
||||
|
||||
|
||||
class TestValidateClientMetadata:
    """Acceptance and rejection rules of ``validate_client_metadata``."""

    def test_none_returns_empty_dict(self):
        result = validate_client_metadata(None)
        assert result[0] == {}
        assert result[1] is None

    def test_flat_string_dict_is_accepted(self):
        submitted = {"workflow_id": "wf-1", "trace_id": "trace-abc"}
        cleaned, error = validate_client_metadata(submitted)
        assert cleaned == {"workflow_id": "wf-1", "trace_id": "trace-abc"}
        assert error is None

    def test_non_dict_is_rejected(self):
        message = validate_client_metadata("not a dict")[1]
        assert message is not None
        assert "object" in message

    def test_list_is_rejected(self):
        pairs = [("workflow_id", "wf-1")]
        message = validate_client_metadata(pairs)[1]
        assert message is not None

    def test_nested_dict_value_is_rejected(self):
        nested = {"workflow": {"id": "wf-1"}}
        message = validate_client_metadata(nested)[1]
        assert message is not None
        assert "string" in message

    def test_non_string_value_is_rejected(self):
        message = validate_client_metadata({"workflow_id": 42})[1]
        assert message is not None

    def test_non_string_key_is_rejected(self):
        message = validate_client_metadata({123: "wf-1"})[1]
        assert message is not None

    def test_empty_key_is_rejected(self):
        message = validate_client_metadata({"": "wf-1"})[1]
        assert message is not None

    def test_key_exceeding_limit_is_rejected(self):
        too_long_key = "k" * (MAX_METADATA_KEY_LEN + 1)
        message = validate_client_metadata({too_long_key: "v"})[1]
        assert message is not None
        assert str(MAX_METADATA_KEY_LEN) in message

    def test_value_exceeding_limit_is_rejected(self):
        too_long_value = "v" * (MAX_METADATA_VALUE_LEN + 1)
        message = validate_client_metadata({"workflow_id": too_long_value})[1]
        assert message is not None
        assert str(MAX_METADATA_VALUE_LEN) in message

    def test_too_many_keys_is_rejected(self):
        # One entry more than the cap.
        oversized = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS + 1)}
        message = validate_client_metadata(oversized)[1]
        assert message is not None
        assert str(MAX_METADATA_KEYS) in message

    def test_max_size_dict_is_accepted(self):
        # Exactly at the cap.
        at_limit = {f"k{i}": "v" for i in range(MAX_METADATA_KEYS)}
        cleaned, error = validate_client_metadata(at_limit)
        assert error is None
        assert len(cleaned) == MAX_METADATA_KEYS

    def test_max_length_strings_are_accepted(self):
        longest_key = "k" * MAX_METADATA_KEY_LEN
        longest_value = "v" * MAX_METADATA_VALUE_LEN
        raw = {longest_key: longest_value}
        cleaned, error = validate_client_metadata(raw)
        assert error is None
        assert cleaned == raw

    @pytest.mark.parametrize("reserved_key", sorted(RESERVED_METADATA_KEYS))
    def test_reserved_keys_are_rejected(self, reserved_key):
        message = validate_client_metadata({reserved_key: "anything"})[1]
        assert message is not None
        assert reserved_key in message
|
||||
Loading…
Reference in New Issue
Block a user