mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
- Default sid to self.client_id when not explicitly provided, matching every other WS message dispatch (executing, executed, progress_state, etc.) - Previously sid=None caused broadcast to all connected clients - Format signature per ruff, remove redundant comments - Add unit tests for routing, legacy format, and new prompt_id format Amp-Thread-ID: https://ampcode.com/threads/T-019ca3ce-c530-75dd-8d68-349e745a022e
208 lines
7.3 KiB
Python
208 lines
7.3 KiB
Python
"""Tests for send_progress_text routing and binary wire format.

These tests verify that:

1. sid defaults to client_id (unicast) instead of None (broadcast).
2. The legacy binary format is sent when prompt_id is absent or the client
   does not support the feature flag.
3. The new binary format, which includes prompt_id, is sent when the client
   supports the feature flag.
"""

import struct
from comfy_api import feature_flags
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers – replicate the packing logic so we can assert on the wire format
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _unpack_legacy(message: bytes):
|
||
"""Unpack a legacy progress_text binary message -> (node_id, text)."""
|
||
offset = 0
|
||
node_id_len = struct.unpack_from(">I", message, offset)[0]
|
||
offset += 4
|
||
node_id = message[offset : offset + node_id_len].decode("utf-8")
|
||
offset += node_id_len
|
||
text = message[offset:].decode("utf-8")
|
||
return node_id, text
|
||
|
||
|
||
def _unpack_with_prompt_id(message: bytes):
|
||
"""Unpack new format -> (prompt_id, node_id, text)."""
|
||
offset = 0
|
||
prompt_id_len = struct.unpack_from(">I", message, offset)[0]
|
||
offset += 4
|
||
prompt_id = message[offset : offset + prompt_id_len].decode("utf-8")
|
||
offset += prompt_id_len
|
||
node_id_len = struct.unpack_from(">I", message, offset)[0]
|
||
offset += 4
|
||
node_id = message[offset : offset + node_id_len].decode("utf-8")
|
||
offset += node_id_len
|
||
text = message[offset:].decode("utf-8")
|
||
return prompt_id, node_id, text
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Minimal stub that mirrors send_progress_text logic from server.py
|
||
# We can't import server.py directly (it pulls in torch via nodes.py),
|
||
# so we replicate the method body here. If the implementation changes,
|
||
# these tests should be updated in tandem.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class _StubServer:
|
||
"""Stub that captures send_sync calls and runs the real packing logic."""
|
||
|
||
def __init__(self, client_id=None, sockets_metadata=None):
|
||
self.client_id = client_id
|
||
self.sockets_metadata = sockets_metadata or {}
|
||
self.sent = [] # list of (event, data, sid)
|
||
|
||
def send_sync(self, event, data, sid=None):
|
||
self.sent.append((event, data, sid))
|
||
|
||
def send_progress_text(self, text, node_id, prompt_id=None, sid=None):
|
||
if isinstance(text, str):
|
||
text = text.encode("utf-8")
|
||
node_id_bytes = str(node_id).encode("utf-8")
|
||
|
||
target_sid = sid if sid is not None else self.client_id
|
||
|
||
if prompt_id and feature_flags.supports_feature(
|
||
self.sockets_metadata, target_sid, "supports_progress_text_metadata"
|
||
):
|
||
prompt_id_bytes = prompt_id.encode("utf-8")
|
||
message = (
|
||
struct.pack(">I", len(prompt_id_bytes))
|
||
+ prompt_id_bytes
|
||
+ struct.pack(">I", len(node_id_bytes))
|
||
+ node_id_bytes
|
||
+ text
|
||
)
|
||
else:
|
||
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||
|
||
self.send_sync(3, message, target_sid) # 3 == BinaryEventTypes.TEXT
|
||
|
||
|
||
# ===========================================================================
# Routing tests
# ===========================================================================


class TestSendProgressTextRouting:
    """Check which sid a progress_text message is addressed to.

    An explicit ``sid`` argument wins; without one the stub falls back to
    ``client_id``; with neither, ``sid`` stays ``None`` (broadcast).
    """

    @staticmethod
    def _first_sid(server):
        # The sid is the third element of each captured (event, data, sid).
        return server.sent[0][2]

    def test_defaults_to_client_id_when_sid_not_provided(self):
        server = _StubServer(client_id="active-client-123")
        server.send_progress_text("hello", "node1")
        assert self._first_sid(server) == "active-client-123"

    def test_explicit_sid_overrides_client_id(self):
        server = _StubServer(client_id="active-client-123")
        server.send_progress_text("hello", "node1", sid="explicit-sid")
        assert self._first_sid(server) == "explicit-sid"

    def test_broadcasts_when_no_client_id_and_no_sid(self):
        server = _StubServer(client_id=None)
        server.send_progress_text("hello", "node1")
        assert self._first_sid(server) is None


# ===========================================================================
# Legacy format tests
# ===========================================================================


class TestSendProgressTextLegacyFormat:
    """Verify legacy binary format: [4B node_id_len][node_id][text]."""

    def test_legacy_format_no_prompt_id(self):
        server = _StubServer(client_id="c1")
        server.send_progress_text("some text", "node-42")

        _, data, _ = server.sent[0]
        node_id, text = _unpack_legacy(data)
        assert node_id == "node-42"
        assert text == "some text"

    def test_legacy_format_when_client_unsupported(self):
        # The client is known but does not advertise the metadata flag, so
        # prompt_id must be left out of the frame.
        server = _StubServer(
            client_id="c1",
            sockets_metadata={"c1": {"feature_flags": {}}},
        )
        server.send_progress_text("text", "node1", prompt_id="prompt-abc")

        _, data, _ = server.sent[0]
        node_id, text = _unpack_legacy(data)
        assert node_id == "node1"
        assert text == "text"

    def test_legacy_format_when_client_has_no_metadata(self):
        # prompt_id is given, but the target sid has no sockets_metadata entry.
        server = _StubServer(client_id="c1")
        server.send_progress_text("text", "node1", prompt_id="prompt-abc")

        _, data, _ = server.sent[0]
        assert _unpack_legacy(data) == ("node1", "text")

    def test_legacy_format_when_prompt_id_empty(self):
        # An empty prompt_id is falsy, so even a capable client gets the
        # legacy frame.
        server = _StubServer(
            client_id="c1",
            sockets_metadata={
                "c1": {"feature_flags": {"supports_progress_text_metadata": True}}
            },
        )
        server.send_progress_text("text", "node1", prompt_id="")

        _, data, _ = server.sent[0]
        assert _unpack_legacy(data) == ("node1", "text")

    def test_bytes_input_preserved(self):
        server = _StubServer(client_id="c1")
        server.send_progress_text(b"raw bytes", "node1")

        _, data, _ = server.sent[0]
        node_id, text = _unpack_legacy(data)
        assert node_id == "node1"
        assert text == "raw bytes"

    def test_length_prefix_counts_utf8_bytes(self):
        # The prefix must be the encoded byte length, not the character
        # count, or a multi-byte node_id would bleed into the text.
        server = _StubServer(client_id="c1")
        server.send_progress_text("héllo ✓", "nœud")

        _, data, _ = server.sent[0]
        assert struct.unpack_from(">I", data)[0] == len("nœud".encode("utf-8"))
        assert _unpack_legacy(data) == ("nœud", "héllo ✓")

    def test_non_string_node_id_is_stringified(self):
        server = _StubServer(client_id="c1")
        server.send_progress_text("t", 42)

        _, data, _ = server.sent[0]
        assert _unpack_legacy(data) == ("42", "t")


# ===========================================================================
# New format tests
# ===========================================================================


class TestSendProgressTextNewFormat:
    """Verify new format: [4B prompt_id_len][prompt_id][4B node_id_len][node_id][text]."""

    def test_includes_prompt_id_when_supported(self):
        server = _StubServer(
            client_id="c1",
            sockets_metadata={
                "c1": {"feature_flags": {"supports_progress_text_metadata": True}}
            },
        )
        server.send_progress_text("progress!", "node-7", prompt_id="prompt-xyz")

        _, data, _ = server.sent[0]
        prompt_id, node_id, text = _unpack_with_prompt_id(data)
        assert prompt_id == "prompt-xyz"
        assert node_id == "node-7"
        assert text == "progress!"

    def test_new_format_with_explicit_sid(self):
        server = _StubServer(
            client_id=None,
            sockets_metadata={
                "my-sid": {"feature_flags": {"supports_progress_text_metadata": True}}
            },
        )
        server.send_progress_text("txt", "n1", prompt_id="p1", sid="my-sid")

        _, data, sid = server.sent[0]
        assert sid == "my-sid"
        prompt_id, node_id, text = _unpack_with_prompt_id(data)
        assert prompt_id == "p1"
        assert node_id == "n1"
        assert text == "txt"

    def test_explicit_sid_capabilities_take_precedence(self):
        # The flag lookup must use the resolved target sid: client_id supports
        # the new format, but the explicit sid does not, so the legacy frame
        # is sent to the explicit sid.
        server = _StubServer(
            client_id="c1",
            sockets_metadata={
                "c1": {"feature_flags": {"supports_progress_text_metadata": True}},
                "other": {"feature_flags": {}},
            },
        )
        server.send_progress_text("txt", "n1", prompt_id="p1", sid="other")

        _, data, sid = server.sent[0]
        assert sid == "other"
        assert _unpack_legacy(data) == ("n1", "txt")

    def test_new_format_with_multibyte_prompt_id(self):
        # Both length prefixes must count UTF-8 bytes, not characters.
        server = _StubServer(
            client_id="c1",
            sockets_metadata={
                "c1": {"feature_flags": {"supports_progress_text_metadata": True}}
            },
        )
        server.send_progress_text("✓ done", "nœud", prompt_id="prompt-é")

        _, data, _ = server.sent[0]
        assert _unpack_with_prompt_id(data) == ("prompt-é", "nœud", "✓ done")


# ===========================================================================
# Feature flag tests
# ===========================================================================


class TestProgressTextFeatureFlag:
    """Check that the server advertises supports_progress_text_metadata."""

    def test_flag_in_server_features(self):
        flag_name = "supports_progress_text_metadata"
        advertised = feature_flags.get_server_features()
        # Check membership first so a missing key fails as an assertion
        # rather than a KeyError.
        assert flag_name in advertised
        assert advertised[flag_name] is True