Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-31 03:17:23 +08:00).
server.py builds an opaque dict at submission time from extra_data and pins it on PromptServer.active_prompt_metadata while main.py's worker drives the prompt. send_sync spreads the dict's key/value pairs onto outgoing payloads so frames carry whatever tags the submission attached (today: workflow_id). The mechanism is intentionally untyped — the transport layer doesn't know what workflow_id means or treat any key specially. Adding a new propagated field requires only a one-line addition in post_prompt; execution.py and comfy_execution/progress.py are not touched. execution.py changes: 0 lines.
163 lines
5.8 KiB
Python
163 lines
5.8 KiB
Python
"""Tests for the opaque per-prompt metadata mechanism on PromptServer."""

# Standard library.
from unittest.mock import MagicMock

# Third-party test framework.
import pytest

# Project code under test.
from comfy_execution.jobs import extract_workflow_id


class TestExtractWorkflowId:
    """Pin the expected results of ``extract_workflow_id`` for various inputs.

    The id is looked up under ``extra_pnginfo -> workflow -> id``; every
    case where that path is absent or malformed is expected to give ``None``.
    """

    def test_returns_id_when_present(self):
        """A string id at the full nested path is returned unchanged."""
        assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) == "wf-1"

    def test_returns_none_when_missing(self):
        """Any missing level of the nested path is expected to give ``None``."""
        assert extract_workflow_id({}) is None
        assert extract_workflow_id({"extra_pnginfo": {}}) is None
        assert extract_workflow_id({"extra_pnginfo": {"workflow": {}}}) is None

    def test_returns_none_for_empty_or_wrong_type(self):
        """An empty string, an int, or ``None`` as the id is expected to give ``None``."""
        assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": ""}}}) is None
        assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": 42}}}) is None
        assert extract_workflow_id({"extra_pnginfo": {"workflow": {"id": None}}}) is None

    def test_returns_none_for_non_dict_input(self):
        """A non-dict value at any level is expected to give ``None``, not an error."""
        assert extract_workflow_id(None) is None
        assert extract_workflow_id("not a dict") is None
        assert extract_workflow_id({"extra_pnginfo": "not a dict"}) is None
        assert extract_workflow_id({"extra_pnginfo": {"workflow": "not a dict"}}) is None


class _FakeServer:
    """Minimal PromptServer stand-in mirroring send_sync verbatim.

    Nothing is actually scheduled on an event loop: every message handed to
    ``loop.call_soon_threadsafe`` is appended to ``captured`` so tests can
    inspect the ``(event, data, sid)`` tuples that would have been queued.
    """

    def __init__(self):
        # Metadata dict spread onto outgoing dict payloads; None means no merge.
        self.active_prompt_metadata = None
        # Ordered record of every (event, data, sid) tuple passed to send_sync.
        self.captured = []
        self.loop = MagicMock()
        # Record the message instead of scheduling the callback.
        self.loop.call_soon_threadsafe.side_effect = (
            lambda fn, msg: self.captured.append(msg)
        )
        self.messages = MagicMock()
        self.messages.put_nowait = MagicMock()

    def send_sync(self, event, data, sid=None):
        """Record an outgoing message, merging active metadata into dict payloads.

        Args:
            event: Event name placed first in the recorded tuple.
            data: Payload. When it is a dict and ``active_prompt_metadata`` is
                truthy, a new dict is built from the metadata followed by the
                payload, so payload keys win on conflict. Any other payload
                (for example ``bytes``) is forwarded unchanged.
            sid: Optional recipient id, forwarded as-is.
        """
        meta = self.active_prompt_metadata
        if meta and isinstance(data, dict):
            # Payload is spread last so its keys override the metadata's.
            data = {**meta, **data}
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid)
        )


@pytest.fixture
def server():
    """Provide a fresh ``_FakeServer`` to each test that requests it."""
    return _FakeServer()


class TestSendSyncMerge:
    """Check how ``send_sync`` merges ``active_prompt_metadata`` into payloads."""

    def test_spreads_active_metadata_onto_dict_payload(self, server):
        """Metadata keys are added to a dict payload; event and sid are unchanged."""
        server.active_prompt_metadata = {"workflow_id": "wf-1"}

        server.send_sync(
            "executing", {"node": "n1", "prompt_id": "p1"}, "client-1"
        )

        event, data, sid = server.captured[0]
        assert event == "executing"
        assert data == {
            "workflow_id": "wf-1",
            "node": "n1",
            "prompt_id": "p1",
        }
        assert sid == "client-1"

    def test_passthrough_when_no_active_metadata(self, server):
        """With metadata set to ``None`` the payload is forwarded as sent."""
        server.active_prompt_metadata = None

        server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})

        _, data, _ = server.captured[0]
        assert data == {"node": "n1", "prompt_id": "p1"}

    def test_passthrough_when_metadata_is_empty_dict(self, server):
        """An empty metadata dict adds nothing to the payload."""
        server.active_prompt_metadata = {}

        server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})

        _, data, _ = server.captured[0]
        assert data == {"node": "n1", "prompt_id": "p1"}

    def test_event_payload_wins_on_key_conflict(self, server):
        """When both define a key, the payload's value is kept."""
        server.active_prompt_metadata = {"workflow_id": "wf-1", "prompt_id": "from-meta"}

        server.send_sync("executing", {"node": "n1", "prompt_id": "from-frame"}, "c1")

        _, data, _ = server.captured[0]
        assert data["prompt_id"] == "from-frame"
        assert data["workflow_id"] == "wf-1"

    def test_non_dict_payload_passes_through_untouched(self, server):
        """A bytes payload is forwarded unchanged even when metadata is set."""
        # BinaryEventTypes.TEXT byte frames must not be merged.
        server.active_prompt_metadata = {"workflow_id": "wf-1"}

        server.send_sync("text", b"\x00\x00\x00\x03foobar", "c1")

        _, data, _ = server.captured[0]
        assert data == b"\x00\x00\x00\x03foobar"

    def test_terminal_executing_frame_includes_metadata(self, server):
        """The final ``node: None`` executing frame still carries the metadata."""
        # Slot is cleared after this send in main.py so the reset still carries metadata (#13684 race).
        server.active_prompt_metadata = {"workflow_id": "wf-1"}

        server.send_sync(
            "executing", {"node": None, "prompt_id": "p1"}, "client-1"
        )

        _, data, _ = server.captured[0]
        assert data == {
            "workflow_id": "wf-1",
            "node": None,
            "prompt_id": "p1",
        }

    def test_opaque_dict_supports_arbitrary_keys(self, server):
        """Every metadata key is propagated, not only ``workflow_id``."""
        server.active_prompt_metadata = {
            "workflow_id": "wf-1",
            "trace_id": "trace-123",
            "tenant": "acme",
        }

        server.send_sync("executing", {"node": "n1", "prompt_id": "p1"})

        _, data, _ = server.captured[0]
        assert data["workflow_id"] == "wf-1"
        assert data["trace_id"] == "trace-123"
        assert data["tenant"] == "acme"


class TestWorkerSerializationIsolatesMetadata:
    """Check that sequentially processed prompts each get their own metadata."""

    def test_two_prompts_sharing_prompt_id_get_correct_metadata(self, server):
        """Frames are tagged with the metadata active at send time.

        Two prompts reuse the same ``prompt_id`` but set different
        ``workflow_id`` metadata; the slot is reset to ``None`` between them.
        """
        # Prompt A
        server.active_prompt_metadata = {"workflow_id": "wf-AAA"}
        server.send_sync("execution_start", {"prompt_id": "P-shared"})
        server.send_sync("executing", {"node": "n1", "prompt_id": "P-shared"})
        server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
        server.active_prompt_metadata = None

        # Prompt B — same prompt_id, different workflow
        server.active_prompt_metadata = {"workflow_id": "wf-BBB"}
        server.send_sync("execution_start", {"prompt_id": "P-shared"})
        server.send_sync("executing", {"node": "n2", "prompt_id": "P-shared"})
        server.send_sync("executing", {"node": None, "prompt_id": "P-shared"})
        server.active_prompt_metadata = None

        frames = [d for (_, d, _) in server.captured]
        # Guard against the all(...) checks below passing vacuously on an
        # empty slice if fewer frames than expected were captured.
        assert len(frames) == 6
        a_frames = frames[:3]
        b_frames = frames[3:]

        assert all(f["workflow_id"] == "wf-AAA" for f in a_frames)
        assert all(f["workflow_id"] == "wf-BBB" for f in b_frames)
        assert all(f["prompt_id"] == "P-shared" for f in frames)