mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-18 22:09:38 +08:00
Address adversarial review: stack registry + tighten TEXT routing
Two issues raised against the per-prompt metadata registry: 1. Client-supplied prompt_id can collide (post_prompt accepts the id verbatim). With a flat dict-keyed registry, the second submission clobbered the first and a single unregister could erase metadata still needed by the other prompt. Now stored as a LIFO stack per prompt_id — most recent registration wins on merge, unregister pops one entry, the key is dropped only when the stack drains. 2. BinaryEventTypes.TEXT (send_progress_text) bypasses the metadata merge because the payload is bytes, and the wire format has no prompt_id / workflow_id field. The merge can't fix this without a wire-format change + frontend feature flag, which is out of scope for FE-745. Inside scope: default the sid to PromptServer.client_id so other clients no longer silently receive untagged text frames. Cross-tab isolation inside a single client still depends on the wire-format follow-up. - comfy_execution/metadata.py: registry is dict[str, list[PromptMetadata]]; merge_prompt_metadata reads stack[-1]; new resolve_progress_text_sid helper extracted so the routing default is unit-testable without the full server import chain. - server.py: register_prompt_metadata appends to the stack; unregister_prompt_metadata pops; get_prompt_metadata returns a copy of the top entry; send_progress_text routes through resolve_progress_text_sid. - tests: collision LIFO behavior, sid resolution default, and the existing merge tests updated to the stack shape. 16 new assertions in this file, 104/104 pass overall.
This commit is contained in:
parent
5396b4fe67
commit
b0c05af67f
@ -30,6 +30,21 @@ def build_prompt_metadata(extra_data: Optional[dict]) -> PromptMetadata:
|
|||||||
return meta
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_progress_text_sid(sid, default_sid):
|
||||||
|
"""Pick the recipient for a ``send_progress_text`` binary frame.
|
||||||
|
|
||||||
|
Returns ``default_sid`` (typically ``PromptServer.client_id`` — the client
|
||||||
|
that submitted the active prompt) when the caller didn't pin a specific
|
||||||
|
socket. This narrows the audience for text status updates from "every
|
||||||
|
connected client" to "the client running this prompt", matching the
|
||||||
|
cross-client isolation other execution events already have.
|
||||||
|
|
||||||
|
Splitting this out keeps the unit test independent of the full ``server``
|
||||||
|
import chain.
|
||||||
|
"""
|
||||||
|
return default_sid if sid is None else sid
|
||||||
|
|
||||||
|
|
||||||
def merge_prompt_metadata(
|
def merge_prompt_metadata(
|
||||||
registry: dict,
|
registry: dict,
|
||||||
lock: threading.Lock,
|
lock: threading.Lock,
|
||||||
@ -38,6 +53,10 @@ def merge_prompt_metadata(
|
|||||||
"""Return ``data`` with the registered metadata for its ``prompt_id`` merged
|
"""Return ``data`` with the registered metadata for its ``prompt_id`` merged
|
||||||
top-level. The event payload wins on conflict, and non-dict payloads (e.g.
|
top-level. The event payload wins on conflict, and non-dict payloads (e.g.
|
||||||
the binary preview tuple) pass through untouched.
|
the binary preview tuple) pass through untouched.
|
||||||
|
|
||||||
|
The registry is a stack per ``prompt_id`` (``dict[str, list[PromptMetadata]]``)
|
||||||
|
so duplicate submissions of the same ``prompt_id`` don't clobber each
|
||||||
|
other's metadata; the most recently registered entry wins.
|
||||||
"""
|
"""
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
return data
|
return data
|
||||||
@ -45,7 +64,8 @@ def merge_prompt_metadata(
|
|||||||
if not prompt_id:
|
if not prompt_id:
|
||||||
return data
|
return data
|
||||||
with lock:
|
with lock:
|
||||||
meta = registry.get(prompt_id)
|
stack = registry.get(prompt_id)
|
||||||
|
meta = stack[-1] if stack else None
|
||||||
if not meta:
|
if not meta:
|
||||||
return data
|
return data
|
||||||
return {**meta, **data}
|
return {**meta, **data}
|
||||||
|
|||||||
37
server.py
37
server.py
@ -10,7 +10,12 @@ import folder_paths
|
|||||||
import execution
|
import execution
|
||||||
import threading
|
import threading
|
||||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||||
from comfy_execution.metadata import PromptMetadata, build_prompt_metadata, merge_prompt_metadata
|
from comfy_execution.metadata import (
|
||||||
|
PromptMetadata,
|
||||||
|
build_prompt_metadata,
|
||||||
|
merge_prompt_metadata,
|
||||||
|
resolve_progress_text_sid,
|
||||||
|
)
|
||||||
import uuid
|
import uuid
|
||||||
import urllib
|
import urllib
|
||||||
import json
|
import json
|
||||||
@ -254,7 +259,7 @@ class PromptServer():
|
|||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
self.client_id = None
|
self.client_id = None
|
||||||
|
|
||||||
self.prompt_metadata: dict[str, PromptMetadata] = {}
|
self.prompt_metadata: dict[str, list[PromptMetadata]] = {}
|
||||||
self._prompt_metadata_lock = threading.Lock()
|
self._prompt_metadata_lock = threading.Lock()
|
||||||
|
|
||||||
self.on_prompt_handlers = []
|
self.on_prompt_handlers = []
|
||||||
@ -1234,20 +1239,30 @@ class PromptServer():
|
|||||||
the execution thread and can be merged onto outbound WebSocket payloads
|
the execution thread and can be merged onto outbound WebSocket payloads
|
||||||
in ``send_sync`` without coupling the execution layer to workflow-level
|
in ``send_sync`` without coupling the execution layer to workflow-level
|
||||||
concepts.
|
concepts.
|
||||||
|
|
||||||
|
Stacked per ``prompt_id`` so a client retrying or colliding with the
|
||||||
|
same id doesn't have its metadata clobbered or, worse, removed by the
|
||||||
|
other prompt's unregister.
|
||||||
"""
|
"""
|
||||||
meta = build_prompt_metadata(extra_data)
|
meta = build_prompt_metadata(extra_data)
|
||||||
if not meta:
|
if not meta:
|
||||||
return
|
return
|
||||||
with self._prompt_metadata_lock:
|
with self._prompt_metadata_lock:
|
||||||
self.prompt_metadata[prompt_id] = meta
|
self.prompt_metadata.setdefault(prompt_id, []).append(meta)
|
||||||
|
|
||||||
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
||||||
with self._prompt_metadata_lock:
|
with self._prompt_metadata_lock:
|
||||||
self.prompt_metadata.pop(prompt_id, None)
|
stack = self.prompt_metadata.get(prompt_id)
|
||||||
|
if not stack:
|
||||||
|
return
|
||||||
|
stack.pop()
|
||||||
|
if not stack:
|
||||||
|
self.prompt_metadata.pop(prompt_id, None)
|
||||||
|
|
||||||
def get_prompt_metadata(self, prompt_id: str) -> PromptMetadata:
|
def get_prompt_metadata(self, prompt_id: str) -> PromptMetadata:
|
||||||
with self._prompt_metadata_lock:
|
with self._prompt_metadata_lock:
|
||||||
return dict(self.prompt_metadata.get(prompt_id, {}))
|
stack = self.prompt_metadata.get(prompt_id)
|
||||||
|
return dict(stack[-1]) if stack else {}
|
||||||
|
|
||||||
def send_sync(self, event, data, sid=None):
|
def send_sync(self, event, data, sid=None):
|
||||||
data = merge_prompt_metadata(self.prompt_metadata, self._prompt_metadata_lock, data)
|
data = merge_prompt_metadata(self.prompt_metadata, self._prompt_metadata_lock, data)
|
||||||
@ -1319,7 +1334,10 @@ class PromptServer():
|
|||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
def send_progress_text(
|
def send_progress_text(
|
||||||
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
|
self,
|
||||||
|
text: Union[bytes, bytearray, str],
|
||||||
|
node_id: str,
|
||||||
|
sid=None,
|
||||||
):
|
):
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
@ -1328,4 +1346,11 @@ class PromptServer():
|
|||||||
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
||||||
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||||||
|
|
||||||
|
# Default routing to the active prompt's client so other clients don't
|
||||||
|
# silently receive untagged text frames. The binary wire format does
|
||||||
|
# not yet carry prompt_id/workflow_id, so cross-tab filtering inside a
|
||||||
|
# single client still depends on a follow-up wire-format change with a
|
||||||
|
# feature flag.
|
||||||
|
sid = resolve_progress_text_sid(sid, self.client_id)
|
||||||
|
|
||||||
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import pytest
|
|||||||
from comfy_execution.metadata import (
|
from comfy_execution.metadata import (
|
||||||
build_prompt_metadata,
|
build_prompt_metadata,
|
||||||
merge_prompt_metadata,
|
merge_prompt_metadata,
|
||||||
|
resolve_progress_text_sid,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ class TestMergeMetadata:
|
|||||||
(the preview tuple) pass through untouched."""
|
(the preview tuple) pass through untouched."""
|
||||||
|
|
||||||
def test_merges_workflow_id_when_prompt_id_known(self, registry, lock):
|
def test_merges_workflow_id_when_prompt_id_known(self, registry, lock):
|
||||||
registry["p1"] = {"workflow_id": "wf-1"}
|
registry["p1"] = [{"workflow_id": "wf-1"}]
|
||||||
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "p1"})
|
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "p1"})
|
||||||
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
|
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||||
|
|
||||||
@ -65,23 +66,83 @@ class TestMergeMetadata:
|
|||||||
assert merged == {"node": "n1", "prompt_id": "missing"}
|
assert merged == {"node": "n1", "prompt_id": "missing"}
|
||||||
|
|
||||||
def test_passthrough_when_no_prompt_id(self, registry, lock):
|
def test_passthrough_when_no_prompt_id(self, registry, lock):
|
||||||
registry["p1"] = {"workflow_id": "wf-1"}
|
registry["p1"] = [{"workflow_id": "wf-1"}]
|
||||||
merged = merge_prompt_metadata(registry, lock, {"status": {"queue_remaining": 0}})
|
merged = merge_prompt_metadata(registry, lock, {"status": {"queue_remaining": 0}})
|
||||||
assert merged == {"status": {"queue_remaining": 0}}
|
assert merged == {"status": {"queue_remaining": 0}}
|
||||||
|
|
||||||
def test_passthrough_for_non_dict_payload(self, registry, lock):
|
def test_passthrough_for_non_dict_payload(self, registry, lock):
|
||||||
registry["p1"] = {"workflow_id": "wf-1"}
|
registry["p1"] = [{"workflow_id": "wf-1"}]
|
||||||
binary = (b"image-bytes", {"prompt_id": "p1"})
|
binary = (b"image-bytes", {"prompt_id": "p1"})
|
||||||
assert merge_prompt_metadata(registry, lock, binary) is binary
|
assert merge_prompt_metadata(registry, lock, binary) is binary
|
||||||
|
|
||||||
def test_event_payload_wins_over_registered_metadata(self, registry, lock):
|
def test_event_payload_wins_over_registered_metadata(self, registry, lock):
|
||||||
registry["p1"] = {"workflow_id": "wf-registered"}
|
registry["p1"] = [{"workflow_id": "wf-registered"}]
|
||||||
merged = merge_prompt_metadata(
|
merged = merge_prompt_metadata(
|
||||||
registry, lock, {"prompt_id": "p1", "workflow_id": "wf-caller"}
|
registry, lock, {"prompt_id": "p1", "workflow_id": "wf-caller"}
|
||||||
)
|
)
|
||||||
assert merged["workflow_id"] == "wf-caller"
|
assert merged["workflow_id"] == "wf-caller"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressTextSidResolution:
|
||||||
|
"""``BinaryEventTypes.TEXT`` frames don't yet carry ``prompt_id`` /
|
||||||
|
``workflow_id`` in their wire shape, so cross-client routing has to happen
|
||||||
|
at the ``sid`` level instead of via the metadata merge. The default sid
|
||||||
|
pins the broadcast to the active prompt's client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_explicit_sid_passes_through(self):
|
||||||
|
assert resolve_progress_text_sid("client-explicit", "client-active") == "client-explicit"
|
||||||
|
|
||||||
|
def test_none_sid_defaults_to_active_client(self):
|
||||||
|
assert resolve_progress_text_sid(None, "client-active") == "client-active"
|
||||||
|
|
||||||
|
def test_none_sid_with_no_active_client_stays_none(self):
|
||||||
|
# No active prompt means there is no sensible recipient — fall through
|
||||||
|
# to broadcast rather than fabricate a target.
|
||||||
|
assert resolve_progress_text_sid(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptIdCollision:
|
||||||
|
"""Two prompts may be submitted with the same ``prompt_id`` (client retry,
|
||||||
|
forced custom id, partner-integration deduplication, etc.). With a flat
|
||||||
|
dict-keyed registry the second registration would clobber the first and a
|
||||||
|
single ``unregister`` call would erase metadata still needed by the other
|
||||||
|
prompt. The stack-based registry resolves both cases."""
|
||||||
|
|
||||||
|
def test_duplicate_register_does_not_clobber_prior_entry(self, registry, lock):
|
||||||
|
# Caller B clobbers A in the merge view (last-wins), but A's metadata
|
||||||
|
# is still in the stack and reappears after B unregisters.
|
||||||
|
registry.setdefault("p1", []).append({"workflow_id": "wf-A"})
|
||||||
|
registry["p1"].append({"workflow_id": "wf-B"})
|
||||||
|
|
||||||
|
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"})
|
||||||
|
assert merged["workflow_id"] == "wf-B"
|
||||||
|
|
||||||
|
registry["p1"].pop()
|
||||||
|
|
||||||
|
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"})
|
||||||
|
assert merged["workflow_id"] == "wf-A"
|
||||||
|
|
||||||
|
def test_single_unregister_does_not_drop_concurrent_submission(self, registry, lock):
|
||||||
|
registry.setdefault("p1", []).append({"workflow_id": "wf-A"})
|
||||||
|
registry["p1"].append({"workflow_id": "wf-B"})
|
||||||
|
|
||||||
|
# Only one of the two prompts finished — pop once.
|
||||||
|
registry["p1"].pop()
|
||||||
|
|
||||||
|
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"})
|
||||||
|
assert "workflow_id" in merged
|
||||||
|
|
||||||
|
def test_full_drain_clears_registry(self, registry, lock):
|
||||||
|
registry.setdefault("p1", []).append({"workflow_id": "wf-A"})
|
||||||
|
registry["p1"].append({"workflow_id": "wf-B"})
|
||||||
|
registry["p1"].pop()
|
||||||
|
registry["p1"].pop()
|
||||||
|
|
||||||
|
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"})
|
||||||
|
assert "workflow_id" not in merged
|
||||||
|
|
||||||
|
|
||||||
class TestRaceRegressionForTerminalExecutingFrame:
|
class TestRaceRegressionForTerminalExecutingFrame:
|
||||||
"""Regression for the PR #13684 finally-clear race.
|
"""Regression for the PR #13684 finally-clear race.
|
||||||
|
|
||||||
@ -95,11 +156,13 @@ class TestRaceRegressionForTerminalExecutingFrame:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
|
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
|
||||||
registry["p1"] = {"workflow_id": "wf-1"}
|
registry["p1"] = [{"workflow_id": "wf-1"}]
|
||||||
|
|
||||||
# main.py emits the terminal frame BEFORE unregistering.
|
# main.py emits the terminal frame BEFORE unregistering.
|
||||||
terminal = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
|
terminal = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
|
||||||
registry.pop("p1", None) # main.py's finally: unregister_prompt_metadata
|
registry["p1"].pop()
|
||||||
|
if not registry["p1"]:
|
||||||
|
registry.pop("p1", None)
|
||||||
|
|
||||||
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
|
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user