ComfyUI/tests-unit/app_test/test_prompt_metadata.py
Deep Mehta 63784baed5 fix(server): serialize PromptMetadataStore access with a lock
Addresses comfyanonymous's review nit on PR #13905. The store is touched
from three threads — the aiohttp event loop (``register`` via
``post_prompt``), the worker thread (``unregister`` via the
``prompt_worker`` try/finally and ``execution_error`` paths), and any
thread that fires ``send_sync`` (``inject``). Individual ``dict``
operations are GIL-atomic but ``register``'s ``len -> pop -> setitem``
and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose
interleaving without a lock is racy. A single ``threading.Lock`` keeps
the FIFO cap honest, and ``inject`` snapshots the envelope under that
lock before the spread runs.

Adds a stress test that runs concurrent register/unregister/inject for
100 ms across five threads and asserts that no exception escapes and
that the capacity bound holds.
2026-05-14 21:26:06 -07:00

413 lines
16 KiB
Python

"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
``inject_envelope``) and the ``PromptMetadataStore`` integration class
that ``PromptServer`` owns.
"""
from __future__ import annotations
import pytest
from app.prompt_metadata import (
MAX_ENVELOPE_KEYS,
MAX_ENVELOPE_KEY_LEN,
MAX_ENVELOPE_VALUE_LEN,
PromptMetadataStore,
extract_envelope_from_extra_data,
inject_envelope,
)
class TestExtractEnvelopeFromExtraData:
    """Behaviour of ``extract_envelope_from_extra_data``.

    An explicit ``metadata`` dict is preferred; otherwise an envelope is
    synthesized from ``extra_pnginfo.workflow.id``; anything unusable
    yields ``None``.
    """

    def test_explicit_metadata_dict_is_used_as_is(self):
        payload = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}}
        expected = {"workflow_id": "wf-1", "trace_id": "t-9"}
        assert extract_envelope_from_extra_data(payload) == expected

    def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self):
        payload = {
            "metadata": {"workflow_id": "explicit"},
            "extra_pnginfo": {"workflow": {"id": "fallback"}},
        }
        expected = {"workflow_id": "explicit"}
        assert extract_envelope_from_extra_data(payload) == expected

    def test_falls_back_to_extra_pnginfo_workflow_id(self):
        payload = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}}
        expected = {"workflow_id": "wf-legacy"}
        assert extract_envelope_from_extra_data(payload) == expected

    def test_returns_none_when_no_metadata_and_no_workflow_id(self):
        # Neither an explicit envelope nor a usable workflow id is present.
        for payload in ({}, {"extra_pnginfo": {"workflow": {}}}):
            assert extract_envelope_from_extra_data(payload) is None

    @pytest.mark.parametrize("bad", ["", 123, None, [], {}])
    def test_rejects_non_string_or_empty_workflow_id(self, bad):
        payload = {"extra_pnginfo": {"workflow": {"id": bad}}}
        assert extract_envelope_from_extra_data(payload) is None

    def test_rejects_non_dict_inputs_at_each_level(self):
        # Each entry breaks the expected nesting at a different depth.
        malformed_inputs = (
            None,
            "not-a-dict",
            {"extra_pnginfo": "not-a-dict"},
            {"extra_pnginfo": {"workflow": "not-a-dict"}},
        )
        for malformed in malformed_inputs:
            assert extract_envelope_from_extra_data(malformed) is None

    def test_empty_explicit_metadata_falls_through_to_workflow_id(self):
        payload = {
            "metadata": {},
            "extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
        }
        expected = {"workflow_id": "wf-legacy"}
        assert extract_envelope_from_extra_data(payload) == expected

    def test_returned_envelope_is_copy_not_reference(self):
        source = {"workflow_id": "wf-1"}
        envelope = extract_envelope_from_extra_data({"metadata": source})
        assert envelope is not None
        # Mutating the returned envelope must not leak into the caller's dict.
        envelope["new_key"] = "x"
        assert "new_key" not in source

    def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self):
        payload = {
            "metadata": "not-a-dict",
            "extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
        }
        expected = {"workflow_id": "wf-legacy"}
        assert extract_envelope_from_extra_data(payload) == expected
class TestEnvelopeSanitization:
    """The wire contract is ``dict[str, str]`` with bounded size. A bad
    envelope is dropped (and a warning is logged) rather than truncated,
    so the boundary stays strict."""

    def test_rejects_too_many_keys(self, caplog):
        envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
        with caplog.at_level("WARNING"):
            assert extract_envelope_from_extra_data({"metadata": envelope}) is None
        assert any("exceeds limit" in r.message for r in caplog.records)

    def test_accepts_max_keys_exactly(self):
        envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
        assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope

    def test_rejects_non_string_keys(self, caplog):
        with caplog.at_level("WARNING"):
            assert (
                extract_envelope_from_extra_data({"metadata": {42: "v"}})
                is None
            )
        assert any("non-string" in r.message for r in caplog.records)

    @pytest.mark.parametrize(
        "bad_value",
        [42, None, ["x"], {"nested": "dict"}, b"bytes"],
        ids=["int", "none", "list", "dict", "bytes"],
    )
    def test_rejects_non_string_values(self, caplog, bad_value):
        # One case per value so a regression names the offending type
        # instead of stopping at the first failure inside a shared loop.
        with caplog.at_level("WARNING"):
            assert (
                extract_envelope_from_extra_data(
                    {"metadata": {"k": bad_value}}
                )
                is None
            )
        # The class contract is "dropped and a warning is logged"; the
        # original requested ``caplog`` but never checked it.
        assert caplog.records, f"no warning logged for value {bad_value!r}"

    def test_rejects_oversized_key(self):
        envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
        assert extract_envelope_from_extra_data({"metadata": envelope}) is None

    def test_rejects_oversized_value(self):
        envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
        assert extract_envelope_from_extra_data({"metadata": envelope}) is None

    def test_accepts_max_lengths_exactly(self):
        envelope = {
            "x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
        }
        assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope

    def test_oversized_workflow_id_in_pnginfo_rejected(self):
        """The legacy synthesized path also respects the value bound."""
        extra_data = {
            "extra_pnginfo": {
                "workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
            }
        }
        assert extract_envelope_from_extra_data(extra_data) is None

    def test_invalid_explicit_metadata_does_not_fall_through(self):
        """An explicit but invalid metadata dict means the caller asked
        for something specific and got it wrong; the synthesized
        fallback must not silently substitute."""
        extra_data = {
            "metadata": {"k": 42},  # non-string value
            "extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
        }
        assert extract_envelope_from_extra_data(extra_data) is None
class TestInjectEnvelope:
    """``inject_envelope(payload, lookup)`` merges the envelope returned by
    ``lookup(prompt_id)`` into the payload, leaving both inputs untouched.
    A plain ``dict.get`` stands in for the lookup callable throughout."""

    def test_spreads_envelope_keys_onto_payload(self):
        """Envelope keys are merged at the top level so consumers can
        read them directly (e.g. ``event.workflow_id``)."""
        registry = {"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}}
        merged = inject_envelope({"node": "5", "prompt_id": "p1"}, registry.get)
        assert merged == {
            "node": "5",
            "prompt_id": "p1",
            "workflow_id": "wf-1",
            "trace_id": "t-9",
        }

    def test_passthrough_when_prompt_id_not_registered(self):
        payload = {"node": "5", "prompt_id": "unknown"}
        assert inject_envelope(payload, {}.get) == payload

    def test_passthrough_when_payload_lacks_prompt_id(self):
        registry = {"p1": {"workflow_id": "wf-1"}}
        payload = {"status": "ok"}
        assert inject_envelope(payload, registry.get) == payload

    def test_server_keys_win_on_collision_with_envelope(self):
        """A misbehaving client cannot shadow server-emitted fields by
        stamping the same key in their submission envelope."""
        registry = {
            "p1": {
                "prompt_id": "client-claimed",
                "node": "spoofed",
                "workflow_id": "wf-1",
            },
        }
        merged = inject_envelope({"prompt_id": "p1", "node": "5"}, registry.get)
        assert merged["prompt_id"] == "p1"
        assert merged["node"] == "5"
        assert merged["workflow_id"] == "wf-1"

    def test_does_not_mutate_input_dict(self):
        registry = {"p1": {"workflow_id": "wf-1"}}
        payload = {"node": "5", "prompt_id": "p1"}
        inject_envelope(payload, registry.get)
        assert "workflow_id" not in payload

    def test_does_not_mutate_envelope_dict(self):
        envelope = {"workflow_id": "wf-1"}
        registry = {"p1": envelope}
        inject_envelope({"prompt_id": "p1", "node": "5"}, registry.get)
        assert envelope == {"workflow_id": "wf-1"}

    def test_injects_into_inner_dict_of_preview_metadata_tuple(self):
        """``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as
        ``(preview_image, metadata_dict)``; the inner dict is the only
        place the envelope can attach."""
        registry = {"p1": {"workflow_id": "wf-1"}}
        image = ("PNG", object(), 256)
        metadata = {"node_id": "5", "prompt_id": "p1"}
        out = inject_envelope((image, metadata), registry.get)
        assert isinstance(out, tuple)
        # The image element is forwarded by identity, not copied.
        assert out[0] is image
        assert out[1] == {
            "node_id": "5",
            "prompt_id": "p1",
            "workflow_id": "wf-1",
        }
        assert "workflow_id" not in metadata

    def test_preview_tuple_passthrough_when_no_envelope_registered(self):
        image = ("PNG", object(), 256)
        metadata = {"node_id": "5", "prompt_id": "unknown"}
        out = inject_envelope((image, metadata), {}.get)
        assert out == (image, metadata)

    @pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
    def test_non_dict_non_tuple_payloads_passthrough(self, payload):
        registry = {"p1": {"workflow_id": "wf-1"}}
        assert inject_envelope(payload, registry.get) == payload

    def test_tuple_of_wrong_arity_passthrough(self):
        """Only the 2-tuple ``(preview, metadata_dict)`` shape is
        special-cased. Other tuples must not be touched."""
        registry = {"p1": {"workflow_id": "wf-1"}}
        three_items = (1, {"prompt_id": "p1"}, 3)
        assert inject_envelope(three_items, registry.get) is three_items

    def test_envelope_lookup_called_per_invocation(self):
        """The lookup runs each time the function is called, so changes
        to the backing store are immediately visible."""
        backing = {"p1": {"workflow_id": "wf-1"}}
        before_update = inject_envelope({"prompt_id": "p1"}, backing.get)
        backing["p1"] = {"workflow_id": "wf-2"}
        after_update = inject_envelope({"prompt_id": "p1"}, backing.get)
        del backing["p1"]
        after_delete = inject_envelope({"prompt_id": "p1"}, backing.get)
        assert before_update["workflow_id"] == "wf-1"
        assert after_update["workflow_id"] == "wf-2"
        assert "workflow_id" not in after_delete
class TestPromptMetadataStore:
    """End-to-end wiring tests that exercise the full register/inject/
    unregister cycle the way ``PromptServer`` does."""

    def test_register_inject_unregister_cycle(self):
        store = PromptMetadataStore()
        store.register(
            "p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
        )
        injected = store.inject({"node": "5", "prompt_id": "p1"})
        assert injected == {
            "node": "5",
            "prompt_id": "p1",
            "workflow_id": "wf-1",
        }
        store.unregister("p1")
        passthrough = store.inject({"node": "5", "prompt_id": "p1"})
        assert "workflow_id" not in passthrough

    def test_register_with_no_derivable_envelope_is_noop(self):
        store = PromptMetadataStore()
        store.register("p1", {})
        assert "p1" not in store
        data = {"prompt_id": "p1"}
        assert store.inject(data) == data

    def test_register_with_oversized_envelope_is_noop(self):
        """Sanitization rejection means nothing is registered — the
        store stays empty and inject is a passthrough."""
        store = PromptMetadataStore()
        store.register(
            "p1",
            {"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
        )
        assert "p1" not in store
        assert len(store) == 0
        # The docstring also promises a passthrough; pin it explicitly.
        data = {"prompt_id": "p1"}
        assert store.inject(data) == data

    def test_unregister_unknown_prompt_is_silent(self):
        store = PromptMetadataStore()
        store.unregister("does-not-exist")  # must not raise
        assert len(store) == 0

    def test_fifo_eviction_when_capacity_exceeded(self):
        """If cleanup hooks are ever bypassed, the store must shed the
        oldest entry rather than grow without bound."""
        store = PromptMetadataStore(capacity=3)
        store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
        store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
        store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
        assert len(store) == 3
        store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
        assert len(store) == 3
        assert "p1" not in store
        assert "p4" in store
        # The newer entries are still injectable.
        assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4"
        # The evicted one is gone.
        assert "workflow_id" not in store.inject({"prompt_id": "p1"})

    def test_register_after_unregister_does_not_count_against_capacity(self):
        """Normal lifecycle: register, unregister, register many — the
        store should not silently evict valid entries because of stale
        accounting."""
        store = PromptMetadataStore(capacity=2)
        for i in range(10):
            store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
            store.unregister(f"p{i}")
        assert len(store) == 0
        # Fill exactly to capacity after the churn. If unregistered entries
        # were still being counted, one of these would be evicted.
        store.register("a", {"metadata": {"workflow_id": "wf-a"}})
        store.register("b", {"metadata": {"workflow_id": "wf-b"}})
        assert "a" in store
        assert "b" in store
        assert len(store) == 2

    def test_re_register_overwrites(self):
        store = PromptMetadataStore()
        store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
        store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
        assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2"
        # Overwriting must replace the entry, not add a second one.
        assert len(store) == 1

    def test_inject_with_no_registrations_is_passthrough(self):
        store = PromptMetadataStore()
        data = {"prompt_id": "p1", "node": "5"}
        assert store.inject(data) == data

    def test_inject_into_preview_tuple(self):
        store = PromptMetadataStore()
        store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
        result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
        assert result == (b"image-bytes", {
            "prompt_id": "p1",
            "workflow_id": "wf-1",
        })

    def test_concurrent_access_does_not_corrupt_or_raise(self):
        """Smoke test for the store's lock. ``register`` is called from
        the aiohttp event-loop thread, ``unregister`` from the worker
        thread, and ``inject`` fires on every ``send_sync`` from
        whichever thread emits the event. Run all three concurrently
        and assert that no exception escapes, no worker hangs (e.g. a
        deadlock), injected payloads stay well-formed, and the FIFO cap
        is never exceeded."""
        import threading

        capacity = 64
        store = PromptMetadataStore(capacity=capacity)
        stop = threading.Event()
        errors: list[BaseException] = []

        def run(op):
            # Drive ``op(i)`` in a tight loop until ``stop`` is set. Anything
            # that escapes is recorded (not swallowed) so the main thread
            # can fail the test with it.
            i = 0
            try:
                while not stop.is_set():
                    op(i)
                    i += 1
            except BaseException as e:
                errors.append(e)

        def do_register(i):
            store.register(
                f"p{i % 100}",
                {"metadata": {"workflow_id": f"wf-{i}"}},
            )

        def do_unregister(i):
            store.unregister(f"p{i % 100}")

        def do_inject(i):
            prompt_id = f"p{i % 100}"
            result = store.inject({"prompt_id": prompt_id, "node": "5"})
            # Server-emitted keys must survive whatever the interleaving,
            # and the only envelope key ever registered is ``workflow_id``.
            assert result["prompt_id"] == prompt_id
            assert result["node"] == "5"
            assert set(result) <= {"prompt_id", "node", "workflow_id"}

        ops = [do_register, do_register, do_unregister, do_inject, do_inject]
        # Daemon threads so a deadlocked worker cannot block interpreter exit.
        threads = [
            threading.Thread(
                target=run, args=(op,), name=f"{op.__name__}-{n}", daemon=True
            )
            for n, op in enumerate(ops)
        ]
        for t in threads:
            t.start()
        # Brief burst — long enough to interleave many ops, short enough
        # not to slow CI. ``stop`` is not yet set, so this waits the full
        # interval.
        stop.wait(0.1)
        stop.set()
        for t in threads:
            t.join(timeout=2.0)

        # ``join`` returns silently on timeout; check liveness explicitly so
        # a deadlock fails the test instead of passing unnoticed.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"worker threads did not finish (possible deadlock): {hung}"
        assert errors == [], f"concurrent access raised: {errors[:3]}"
        assert len(store) <= capacity, "FIFO cap was breached under contention"