mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
feat(server): propagate opaque per-prompt metadata on WebSocket frames (FE-745)
server.py builds an opaque dict at submission time from extra_data and pins it on PromptServer.active_prompt_metadata while main.py's worker drives the prompt. send_sync spreads the dict's key/value pairs onto outgoing payloads so frames carry whatever tags the submission attached (today: workflow_id). The mechanism is intentionally untyped — the transport layer doesn't know what workflow_id means or treat any key specially. Adding a new propagated field requires only a one-line addition in post_prompt; execution.py and comfy_execution/progress.py are not touched. execution.py changes: 0 lines.
This commit is contained in:
parent
6b61918a16
commit
dfc901078e
@ -93,6 +93,22 @@ def _create_text_preview(value: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||
"""Return extra_data["extra_pnginfo"]["workflow"]["id"] when it is a non-empty string."""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||
if not isinstance(extra_pnginfo, dict):
|
||||
return None
|
||||
workflow = extra_pnginfo.get('workflow')
|
||||
if not isinstance(workflow, dict):
|
||||
return None
|
||||
workflow_id = workflow.get('id')
|
||||
if isinstance(workflow_id, str) and workflow_id:
|
||||
return workflow_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
@ -100,8 +116,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
|
||||
30
main.py
30
main.py
@ -317,20 +317,28 @@ def prompt_worker(q, server_instance):
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
metadata = item[6] if len(item) > 6 and isinstance(item[6], dict) else None
|
||||
server_instance.active_prompt_metadata = metadata
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
# Drop sensitive (index 5) and metadata (index 6); history keeps a 5-tuple.
|
||||
remove_sensitive = lambda prompt: prompt[:5]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Clear after the terminal send so that frame still carries metadata.
|
||||
server_instance.active_prompt_metadata = None
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
18
server.py
18
server.py
@ -252,6 +252,9 @@ class PromptServer():
|
||||
self.last_node_id = None
|
||||
self.client_id = None
|
||||
|
||||
# Opaque tag dict pinned by main.py around each prompt; send_sync spreads it.
|
||||
self.active_prompt_metadata: Optional[dict] = None
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
@ -275,7 +278,13 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
payload = {"node": self.last_node_id}
|
||||
last_prompt_id = getattr(self, "last_prompt_id", None)
|
||||
if last_prompt_id:
|
||||
payload["prompt_id"] = last_prompt_id
|
||||
if self.active_prompt_metadata:
|
||||
payload = {**self.active_prompt_metadata, **payload}
|
||||
await self.send("executing", payload, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -955,7 +964,9 @@ class PromptServer():
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
raw_metadata = extra_data.pop("metadata", None)
|
||||
client_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive, client_metadata))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
else:
|
||||
@ -1217,6 +1228,9 @@ class PromptServer():
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
meta = self.active_prompt_metadata
|
||||
if meta and isinstance(data, dict):
|
||||
data = {**meta, **data}
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.messages.put_nowait, (event, data, sid))
|
||||
|
||||
|
||||
162
tests-unit/server_test/test_prompt_metadata.py
Normal file
162
tests-unit/server_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""Tests for the opaque per-prompt metadata mechanism on PromptServer."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
|
||||
|
||||
class TestExtractWorkflowId:
|
||||
|
||||
def test_returns_id_when_present(self):
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) == "wf-1"
|
||||
|
||||
def test_returns_none_when_missing(self):
|
||||
assert extract_workflow_id({}) is None
|
||||
assert extract_workflow_id({"extra_pnginfo": {}}) is None
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": {}}}) is None
|
||||
|
||||
def test_returns_none_for_empty_or_wrong_type(self):
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": ""}}}) is None
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": 42}}}) is None
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": None}}}) is None
|
||||
|
||||
def test_returns_none_for_non_dict_input(self):
|
||||
assert extract_workflow_id(None) is None
|
||||
assert extract_workflow_id("not a dict") is None
|
||||
assert extract_workflow_id({"extra_pnginfo": "not a dict"}) is None
|
||||
assert extract_workflow_id({"extra_pnginfo": {"workflow": "not a dict"}}) is None
|
||||
|
||||
|
||||
class _FakeServer:
|
||||
"""Minimal PromptServer stand-in mirroring send_sync verbatim."""
|
||||
|
||||
def __init__(self):
|
||||
self.active_prompt_metadata = None
|
||||
self.captured = []
|
||||
self.loop = MagicMock()
|
||||
self.loop.call_soon_threadsafe.side_effect = (
|
||||
lambda fn, msg: self.captured.append(msg)
|
||||
)
|
||||
self.messages = MagicMock()
|
||||
self.messages.put_nowait = MagicMock()
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
meta = self.active_prompt_metadata
|
||||
if meta and isinstance(data, dict):
|
||||
data = {**meta, **data}
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.messages.put_nowait, (event, data, sid)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
return _FakeServer()
|
||||
|
||||
|
||||
class TestSendSyncMerge:
|
||||
def test_spreads_active_metadata_onto_dict_payload(self, server):
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-1"}
|
||||
|
||||
server.send_sync(
|
||||
"executing", {"node": "n1", "prompt_id": "p1"}, "client-1"
|
||||
)
|
||||
|
||||
event, data, sid = server.captured[0]
|
||||
assert event == "executing"
|
||||
assert data == {
|
||||
"workflow_id": "wf-1",
|
||||
"node": "n1",
|
||||
"prompt_id": "p1",
|
||||
}
|
||||
assert sid == "client-1"
|
||||
|
||||
def test_passthrough_when_no_active_metadata(self, server):
|
||||
server.active_prompt_metadata = None
|
||||
|
||||
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data == {"node": "n1", "prompt_id": "p1"}
|
||||
|
||||
def test_passthrough_when_metadata_is_empty_dict(self, server):
|
||||
server.active_prompt_metadata = {}
|
||||
|
||||
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data == {"node": "n1", "prompt_id": "p1"}
|
||||
|
||||
def test_event_payload_wins_on_key_conflict(self, server):
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-1", "prompt_id": "from-meta"}
|
||||
|
||||
server.send_sync("executing", {"node": "n1", "prompt_id": "from-frame"}, "c1")
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data["prompt_id"] == "from-frame"
|
||||
assert data["workflow_id"] == "wf-1"
|
||||
|
||||
def test_non_dict_payload_passes_through_untouched(self, server):
|
||||
# BinaryEventTypes.TEXT byte frames must not be merged.
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-1"}
|
||||
|
||||
server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1")
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data == b"\x00\x00\x00\x03foobar"
|
||||
|
||||
def test_terminal_executing_frame_includes_metadata(self, server):
|
||||
# Slot is cleared after this send in main.py so the reset still carries metadata (#13684 race).
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-1"}
|
||||
|
||||
server.send_sync(
|
||||
"executing", {"node": None, "prompt_id": "p1"}, "client-1"
|
||||
)
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data == {
|
||||
"workflow_id": "wf-1",
|
||||
"node": None,
|
||||
"prompt_id": "p1",
|
||||
}
|
||||
|
||||
def test_opaque_dict_supports_arbitrary_keys(self, server):
|
||||
server.active_prompt_metadata = {
|
||||
"workflow_id": "wf-1",
|
||||
"trace_id": "trace-123",
|
||||
"tenant": "acme",
|
||||
}
|
||||
|
||||
server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})
|
||||
|
||||
_, data, _ = server.captured[0]
|
||||
assert data["workflow_id"] == "wf-1"
|
||||
assert data["trace_id"] == "trace-123"
|
||||
assert data["tenant"] == "acme"
|
||||
|
||||
|
||||
class TestWorkerSerializationIsolatesMetadata:
|
||||
def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
|
||||
# Prompt A
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-AAA"}
|
||||
server.send_sync("execution_start", {"prompt_id": "P-shared"})
|
||||
server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"})
|
||||
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
|
||||
server.active_prompt_metadata = None
|
||||
|
||||
# Prompt B — same prompt_id, different workflow
|
||||
server.active_prompt_metadata = {"workflow_id": "wf-BBB"}
|
||||
server.send_sync("execution_start", {"prompt_id": "P-shared"})
|
||||
server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"})
|
||||
server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
|
||||
server.active_prompt_metadata = None
|
||||
|
||||
frames = [d for (_, d, _) in server.captured]
|
||||
a_frames = frames[:3]
|
||||
b_frames = frames[3:]
|
||||
|
||||
assert all(f["workflow_id"] == "wf-AAA" for f in a_frames)
|
||||
assert all(f["workflow_id"] == "wf-BBB" for f in b_frames)
|
||||
assert all(f["prompt_id"] == "P-shared" for f in frames)
|
||||
Loading…
Reference in New Issue
Block a user