fix(server): bound metadata envelope and clean up on cancel paths

Addresses review feedback on the per-prompt metadata envelope:

- Sanitize at the boundary: reject envelopes larger than 16 keys, keys
  over 64 chars, values over 256 chars, or anything that isn't a flat
  ``dict[str, str]``. Logs a warning so abuse is observable. Stops a
  malicious client from inflating broadcast volume by stamping a 10 MB
  metadata blob onto every WS event.
- Cap the in-memory store at 4096 concurrent envelopes with FIFO
  eviction. Acts as a backstop if any cleanup hook is skipped.
- Drop envelopes when prompts are cancelled before reaching the worker:
  ``PromptQueue.wipe_queue`` and ``delete_queue_item`` now call
  ``server.unregister_prompt_metadata`` for every removed item.
- Drop envelopes on hard execution failures: the worker now wraps
  ``e.execute()`` in ``try/finally``, so an uncaught exception in
  execution no longer leaks the envelope.
- Guard the WS reconnect handler: only include ``prompt_id`` in the
  ``executing`` payload when ``last_prompt_id`` is set, so clients
  with strict schemas (zod ``prompt_id: zJobId``) don't reject the
  message with a null id.
- Extract a ``PromptMetadataStore`` class that owns the dict and the
  bounds, so ``PromptServer`` becomes a thin delegating layer and the
  full register/inject/unregister cycle (plus FIFO eviction and
  sanitization) is unit-tested without torch.

44 tests passing; ruff clean on all touched files.
This commit is contained in:
Deep Mehta 2026-05-14 21:03:38 -07:00
parent 74cfcaa318
commit fd89498eac
5 changed files with 362 additions and 83 deletions

View File

@ -1,48 +1,121 @@
"""Per-prompt metadata envelope shared between submission and outbound events. """Per-prompt metadata envelope shared between submission and outbound events.
The metadata envelope is an opaque dict (e.g. ``{"workflow_id": ...}``) The metadata envelope is a small flat ``dict[str, str]`` (e.g.
attached to a prompt at submission and injected by the server into every ``{"workflow_id": ...}``) attached to a prompt at submission and injected
outbound execution event that carries a ``prompt_id``. It lets consumers by the server into every outbound execution event that carries a
scope state by tags they care about (workflow, trace, tenant) without the ``prompt_id``. It lets consumers scope state by tags they care about
execution layer ever needing to know those tags exist. (workflow, trace, tenant) without the execution layer ever needing to
know those tags exist.
Two pure functions live here; ``PromptServer`` owns the per-prompt map and This module is intentionally pure no imports from ``server`` or
wires them into the submission and send paths. ``execution`` so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
""" """
from __future__ import annotations from __future__ import annotations
import logging
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
# Bounds. The envelope is forwarded to every WebSocket client connected to
# the server on every execution event for the prompt — bounding key count,
# key length, value length, and refusing nested structures keeps a
# malicious or buggy client from inflating the broadcast volume.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
# entry goes first.
DEFAULT_STORE_CAPACITY = 4096
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
"""Validate and copy a candidate envelope.
Enforces the ``dict[str, str]`` contract that downstream consumers
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
- must be a non-empty ``dict``
- at most ``MAX_ENVELOPE_KEYS`` entries
- every key and value must be a ``str``
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
Returns a defensive shallow copy on success, ``None`` on any
violation. Logs a warning on violation so abuse is visible.
"""
if not isinstance(envelope, dict) or not envelope:
return None
if len(envelope) > MAX_ENVELOPE_KEYS:
logging.warning(
"prompt metadata envelope rejected: %d keys exceeds limit %d",
len(envelope), MAX_ENVELOPE_KEYS,
)
return None
sanitized: dict[str, str] = {}
for key, value in envelope.items():
if not isinstance(key, str) or not isinstance(value, str):
logging.warning(
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
type(key).__name__, type(value).__name__,
)
return None
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
logging.warning(
"prompt metadata envelope rejected: key or value exceeds length limit",
)
return None
sanitized[key] = value
return sanitized
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]: def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
"""Pull the per-prompt metadata envelope out of a submitted prompt's """Pull the per-prompt metadata envelope out of a submitted prompt's
``extra_data``. ``extra_data``.
Two sources, in order: Two sources, in order:
1. Explicit ``extra_data["metadata"]`` dict preferred path, accepted 1. Explicit ``extra_data["metadata"]`` sanitized via
as-is (copied so later mutations on the caller's dict don't leak). ``_sanitize_envelope``. Oversized or wrong-typed envelopes are
rejected (a warning is logged) rather than truncated, so the
contract stays strict at the boundary.
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` backward- 2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` backward-
compatibility fallback. Frontends that already stamp the workflow compatibility fallback. Frontends that already stamp the workflow
id into ``extra_pnginfo`` keep working without changes; the id into ``extra_pnginfo`` keep working; the synthesized envelope
synthesized envelope is ``{"workflow_id": <id>}``. is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
remains observable.
Returns ``None`` when neither source yields a usable envelope. Returns ``None`` when neither source yields a usable envelope.
""" """
if not isinstance(extra_data, dict): if not isinstance(extra_data, dict):
return None return None
metadata = extra_data.get("metadata") if "metadata" in extra_data:
if isinstance(metadata, dict) and metadata: sanitized = _sanitize_envelope(extra_data["metadata"])
return dict(metadata) if sanitized is not None:
return sanitized
# Explicit metadata was supplied but rejected — do not fall
# through to the legacy path; the caller asked for something
# specific and got it wrong.
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
return None
extra_pnginfo = extra_data.get("extra_pnginfo") extra_pnginfo = extra_data.get("extra_pnginfo")
if isinstance(extra_pnginfo, dict): if isinstance(extra_pnginfo, dict):
workflow = extra_pnginfo.get("workflow") workflow = extra_pnginfo.get("workflow")
if isinstance(workflow, dict): if isinstance(workflow, dict):
workflow_id = workflow.get("id") workflow_id = workflow.get("id")
if isinstance(workflow_id, str) and workflow_id: if (
isinstance(workflow_id, str)
and workflow_id
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
):
logging.debug(
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
)
return {"workflow_id": workflow_id} return {"workflow_id": workflow_id}
return None return None
@ -55,28 +128,25 @@ def inject_envelope(
"""Return ``data`` with a per-prompt ``metadata`` envelope attached. """Return ``data`` with a per-prompt ``metadata`` envelope attached.
``envelope_lookup`` is called with the payload's ``prompt_id`` and is ``envelope_lookup`` is called with the payload's ``prompt_id`` and is
expected to return the registered envelope or ``None``. This indirection expected to return the registered envelope or ``None``. This keeps
keeps the function pure and avoids depending on any specific storage. the function pure and avoids depending on any specific storage.
Two payload shapes are handled: Two payload shapes are handled:
- **dict** carrying ``prompt_id``. A shallow copy is returned with a - **dict** carrying ``prompt_id``. A shallow copy is returned with a
``metadata`` key set to the envelope. The caller's dict is not ``metadata`` key set to the envelope.
mutated.
- **(preview_image, metadata_dict) tuple** the format used by - **(preview_image, metadata_dict) tuple** the format used by
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented; ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
the binary preview is passed through by reference. the binary preview is passed through by reference.
The function is a no-op for: No-op for payloads without a ``prompt_id``, payloads already
declaring their own ``metadata`` field, prompts with no registered
- payloads without a ``prompt_id``, envelope, or any other payload shape.
- payloads already declaring their own ``metadata`` field
(callers can opt out by setting it explicitly),
- prompts with no registered envelope,
- any other payload shape (raw bytes, ``None``, etc.).
""" """
def inject(d: dict) -> dict: def inject(d: dict) -> dict:
if not isinstance(d, dict) or "metadata" in d: if not isinstance(d, dict):
return d
if "metadata" in d:
return d return d
prompt_id = d.get("prompt_id") prompt_id = d.get("prompt_id")
if not prompt_id: if not prompt_id:
@ -94,3 +164,37 @@ def inject_envelope(
return data return data
return (data[0], injected) return (data[0], injected)
return data return data
class PromptMetadataStore:
"""Bounded ``prompt_id -> envelope`` map.
Owned by ``PromptServer``. Populated at submission, drained when the
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
backstop: if any cleanup hook is ever skipped, the store sheds the
oldest entry instead of growing without bound.
"""
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
self._envelopes: dict[str, dict] = {}
self._capacity = capacity
def register(self, prompt_id: str, extra_data: Any) -> None:
envelope = extract_envelope_from_extra_data(extra_data)
if envelope is None:
return
if len(self._envelopes) >= self._capacity:
self._envelopes.pop(next(iter(self._envelopes)))
self._envelopes[prompt_id] = envelope
def unregister(self, prompt_id: str) -> None:
self._envelopes.pop(prompt_id, None)
def inject(self, data: Any) -> Any:
return inject_envelope(data, self._envelopes.get)
def __len__(self) -> int:
return len(self._envelopes)
def __contains__(self, prompt_id: str) -> bool:
return prompt_id in self._envelopes

View File

@ -1296,7 +1296,10 @@ class PromptQueue:
def wipe_queue(self): def wipe_queue(self):
with self.mutex: with self.mutex:
dropped_prompt_ids = [item[1] for item in self.queue]
self.queue = [] self.queue = []
for prompt_id in dropped_prompt_ids:
self.server.unregister_prompt_metadata(prompt_id)
self.server.queue_updated() self.server.queue_updated()
def delete_queue_item(self, function): def delete_queue_item(self, function):
@ -1306,8 +1309,9 @@ class PromptQueue:
if len(self.queue) == 1: if len(self.queue) == 1:
self.wipe_queue() self.wipe_queue()
else: else:
self.queue.pop(x) deleted = self.queue.pop(x)
heapq.heapify(self.queue) heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(deleted[1])
self.server.queue_updated() self.server.queue_updated()
return True return True
return False return False

31
main.py
View File

@ -318,21 +318,26 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k] extra_data[k] = sensitive[k]
asset_seeder.pause() asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4]) try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id, q.task_done(item_id,
e.history_result, e.history_result,
status=execution.PromptQueue.ExecutionStatus( status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error', status_str='success' if e.success else 'error',
completed=e.success, completed=e.success,
messages=e.status_messages), process_item=remove_sensitive) messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None: if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
server_instance.unregister_prompt_metadata(prompt_id) # Always drop the metadata envelope. If e.execute() raises
# hard before its own error handling kicks in, the
# registered envelope would otherwise leak for the
# lifetime of the process.
server_instance.unregister_prompt_metadata(prompt_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time

View File

@ -44,10 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import ( from app.prompt_metadata import PromptMetadataStore
extract_envelope_from_extra_data,
inject_envelope,
)
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes from protocol import BinaryEventTypes
@ -257,10 +254,10 @@ class PromptServer():
self.last_prompt_id = None self.last_prompt_id = None
self.client_id = None self.client_id = None
# prompt_id -> metadata envelope captured at submission and injected # Bounded prompt_id -> envelope store. Populated at submission,
# into outbound execution events. Keeps the workflow scope (and any # drained on completion/cancel. Keeps workflow scope (and other
# other client-supplied tags) out of the execution layer. # client-supplied tags) out of the execution layer.
self._prompt_metadata: dict[str, dict] = {} self._prompt_metadata = PromptMetadataStore()
self.on_prompt_handlers = [] self.on_prompt_handlers = []
@ -285,10 +282,12 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node # On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None: if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", self._inject_prompt_metadata({ payload: dict = {"node": self.last_node_id}
"node": self.last_node_id, if self.last_prompt_id is not None:
"prompt_id": self.last_prompt_id, payload["prompt_id"] = self.last_prompt_id
}), sid) await self.send(
"executing", self._inject_prompt_metadata(payload), sid
)
# Flag to track if we've received the first message # Flag to track if we've received the first message
first_message = True first_message = True
@ -1233,17 +1232,16 @@ class PromptServer():
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None: def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
"""Record per-prompt metadata for injection into outbound execution """Record per-prompt metadata for injection into outbound execution
events. Called at submission, before the prompt is queued.""" events. Called at submission, before the prompt is queued."""
envelope = extract_envelope_from_extra_data(extra_data) self._prompt_metadata.register(prompt_id, extra_data)
if envelope is not None:
self._prompt_metadata[prompt_id] = envelope
def unregister_prompt_metadata(self, prompt_id: str) -> None: def unregister_prompt_metadata(self, prompt_id: str) -> None:
"""Drop the per-prompt metadata envelope. Called after the prompt """Drop the per-prompt metadata envelope. Called after the prompt
has finished executing and its terminal events have been queued.""" has finished executing and its terminal events have been queued,
self._prompt_metadata.pop(prompt_id, None) or when the prompt is cancelled before reaching the worker."""
self._prompt_metadata.unregister(prompt_id)
def _inject_prompt_metadata(self, data): def _inject_prompt_metadata(self, data):
return inject_envelope(data, self._prompt_metadata.get) return self._prompt_metadata.inject(data)
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
data = self._inject_prompt_metadata(data) data = self._inject_prompt_metadata(data)

View File

@ -1,12 +1,19 @@
"""Unit tests for the pure metadata-envelope helpers in """Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
``app.prompt_metadata``. These cover the two functions that PromptServer
wires into submission (``extract_envelope_from_extra_data``) and into the Covers the two pure helpers (``extract_envelope_from_extra_data`` and
send chokepoint (``inject_envelope``). ``inject_envelope``) and the ``PromptMetadataStore`` integration class
that ``PromptServer`` owns.
""" """
from __future__ import annotations from __future__ import annotations
import pytest
from app.prompt_metadata import ( from app.prompt_metadata import (
MAX_ENVELOPE_KEYS,
MAX_ENVELOPE_KEY_LEN,
MAX_ENVELOPE_VALUE_LEN,
PromptMetadataStore,
extract_envelope_from_extra_data, extract_envelope_from_extra_data,
inject_envelope, inject_envelope,
) )
@ -42,10 +49,10 @@ class TestExtractEnvelopeFromExtraData:
is None is None
) )
def test_rejects_non_string_or_empty_workflow_id(self): @pytest.mark.parametrize("bad", ["", 123, None, [], {}])
for bad in ["", 123, None, [], {}]: def test_rejects_non_string_or_empty_workflow_id(self, bad):
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}} extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
assert extract_envelope_from_extra_data(extra_data) is None assert extract_envelope_from_extra_data(extra_data) is None
def test_rejects_non_dict_inputs_at_each_level(self): def test_rejects_non_dict_inputs_at_each_level(self):
assert extract_envelope_from_extra_data(None) is None assert extract_envelope_from_extra_data(None) is None
@ -73,6 +80,7 @@ class TestExtractEnvelopeFromExtraData:
def test_returned_envelope_is_copy_not_reference(self): def test_returned_envelope_is_copy_not_reference(self):
original = {"workflow_id": "wf-1"} original = {"workflow_id": "wf-1"}
result = extract_envelope_from_extra_data({"metadata": original}) result = extract_envelope_from_extra_data({"metadata": original})
assert result is not None
result["new_key"] = "x" result["new_key"] = "x"
assert "new_key" not in original assert "new_key" not in original
@ -86,10 +94,76 @@ class TestExtractEnvelopeFromExtraData:
} }
class TestEnvelopeSanitization:
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
envelope is dropped (and a warning is logged) rather than truncated,
so the boundary stays strict."""
def test_rejects_too_many_keys(self, caplog):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
with caplog.at_level("WARNING"):
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
assert any("exceeds limit" in r.message for r in caplog.records)
def test_accepts_max_keys_exactly(self):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_rejects_non_string_keys(self, caplog):
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data({"metadata": {42: "v"}})
is None
)
assert any("non-string" in r.message for r in caplog.records)
def test_rejects_non_string_values(self, caplog):
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data(
{"metadata": {"k": bad_value}}
)
is None
)
def test_rejects_oversized_key(self):
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_rejects_oversized_value(self):
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_accepts_max_lengths_exactly(self):
envelope = {
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_oversized_workflow_id_in_pnginfo_rejected(self):
"""The legacy synthesized path also respects the value bound."""
extra_data = {
"extra_pnginfo": {
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
}
}
assert extract_envelope_from_extra_data(extra_data) is None
def test_invalid_explicit_metadata_does_not_fall_through(self):
"""An explicit but invalid metadata dict means the caller asked
for something specific and got it wrong; the synthesized
fallback must not silently substitute."""
extra_data = {
"metadata": {"k": 42}, # non-string value
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) is None
class TestInjectEnvelope: class TestInjectEnvelope:
@staticmethod @staticmethod
def _lookup(table): def _lookup(table):
"""Build an envelope_lookup callable backed by a dict."""
return table.get return table.get
def test_injects_envelope_on_dict_with_known_prompt_id(self): def test_injects_envelope_on_dict_with_known_prompt_id(self):
@ -108,16 +182,15 @@ class TestInjectEnvelope:
def test_passthrough_when_payload_lacks_prompt_id(self): def test_passthrough_when_payload_lacks_prompt_id(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
data = {"status": "ok"} data = {"status": "ok"}
assert inject_envelope(data, lookup) is data assert inject_envelope(data, lookup) == data
def test_passthrough_when_payload_already_has_metadata(self): def test_passthrough_when_payload_already_has_metadata(self):
"""If a caller has already set a ``metadata`` field (e.g. for """If a caller has already set a ``metadata`` field, the
opt-out or pre-augmented payloads), the function must not function must not overwrite it."""
overwrite it."""
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}} data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}}
result = inject_envelope(data, lookup) result = inject_envelope(data, lookup)
assert result is data assert result == data
assert result["metadata"] == {"workflow_id": "wf-caller"} assert result["metadata"] == {"workflow_id": "wf-caller"}
def test_does_not_mutate_input_dict(self): def test_does_not_mutate_input_dict(self):
@ -153,29 +226,124 @@ class TestInjectEnvelope:
def test_preview_tuple_passthrough_when_inner_already_has_metadata(self): def test_preview_tuple_passthrough_when_inner_already_has_metadata(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}}) lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
preview_image = ("PNG", object(), 256) preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}} inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": "1"}}
result = inject_envelope((preview_image, inner), lookup) result = inject_envelope((preview_image, inner), lookup)
assert result == (preview_image, inner) assert result == (preview_image, inner)
def test_non_dict_non_tuple_payloads_passthrough(self): @pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
assert inject_envelope(b"raw-bytes", lookup) == b"raw-bytes" assert inject_envelope(payload, lookup) == payload
assert inject_envelope(None, lookup) is None
assert inject_envelope(42, lookup) == 42
def test_tuple_of_wrong_arity_passthrough(self): def test_tuple_of_wrong_arity_passthrough(self):
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is special- """Only the 2-tuple ``(preview, metadata_dict)`` shape is
cased. Other tuples must not be touched.""" special-cased. Other tuples must not be touched."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}}) lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
triple = (1, {"prompt_id": "p1"}, 3) triple = (1, {"prompt_id": "p1"}, 3)
assert inject_envelope(triple, lookup) is triple assert inject_envelope(triple, lookup) is triple
def test_envelope_lookup_called_at_send_time(self): def test_envelope_lookup_called_per_invocation(self):
"""The lookup runs each time the function is called, so a producer """The lookup runs each time the function is called, so changes
and consumer that share a backing dict observe the current value.""" to the backing store are immediately visible."""
store = {"p1": {"workflow_id": "wf-1"}} store = {"p1": {"workflow_id": "wf-1"}}
first = inject_envelope({"prompt_id": "p1"}, store.get) first = inject_envelope({"prompt_id": "p1"}, store.get)
store["p1"] = {"workflow_id": "wf-2"} store["p1"] = {"workflow_id": "wf-2"}
second = inject_envelope({"prompt_id": "p1"}, store.get) second = inject_envelope({"prompt_id": "p1"}, store.get)
del store["p1"]
third = inject_envelope({"prompt_id": "p1"}, store.get)
assert first["metadata"] == {"workflow_id": "wf-1"} assert first["metadata"] == {"workflow_id": "wf-1"}
assert second["metadata"] == {"workflow_id": "wf-2"} assert second["metadata"] == {"workflow_id": "wf-2"}
assert "metadata" not in third
class TestPromptMetadataStore:
"""End-to-end wiring tests that exercise the full register/inject/
unregister cycle the way ``PromptServer`` does."""
def test_register_inject_unregister_cycle(self):
store = PromptMetadataStore()
store.register(
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
)
injected = store.inject({"node": "5", "prompt_id": "p1"})
assert injected == {
"node": "5",
"prompt_id": "p1",
"metadata": {"workflow_id": "wf-1"},
}
store.unregister("p1")
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
assert "metadata" not in passthrough
def test_register_with_no_derivable_envelope_is_noop(self):
store = PromptMetadataStore()
store.register("p1", {})
assert "p1" not in store
assert store.inject({"prompt_id": "p1"}) == {"prompt_id": "p1"}
def test_register_with_oversized_envelope_is_noop(self):
"""Sanitization rejection means nothing is registered — the
store stays empty and inject is a passthrough."""
store = PromptMetadataStore()
store.register(
"p1",
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
)
assert "p1" not in store
def test_unregister_unknown_prompt_is_silent(self):
store = PromptMetadataStore()
store.unregister("does-not-exist")
def test_fifo_eviction_when_capacity_exceeded(self):
"""If cleanup hooks are ever bypassed, the store must shed the
oldest entry rather than grow without bound."""
store = PromptMetadataStore(capacity=3)
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
assert len(store) == 3
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
assert len(store) == 3
assert "p1" not in store
assert "p4" in store
# The newer entries are still injectable.
assert store.inject({"prompt_id": "p4"})["metadata"] == {
"workflow_id": "wf-4"
}
# The evicted one is gone.
assert "metadata" not in store.inject({"prompt_id": "p1"})
def test_register_after_unregister_does_not_count_against_capacity(self):
"""Normal lifecycle: register, unregister, register many — the
store should not silently evict valid entries because of stale
accounting."""
store = PromptMetadataStore(capacity=2)
for i in range(10):
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
store.unregister(f"p{i}")
assert len(store) == 0
def test_re_register_overwrites(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
assert store.inject({"prompt_id": "p1"})["metadata"] == {
"workflow_id": "wf-2"
}
def test_inject_with_no_registrations_is_passthrough(self):
store = PromptMetadataStore()
data = {"prompt_id": "p1", "node": "5"}
assert store.inject(data) == data
def test_inject_into_preview_tuple(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
assert result == (b"image-bytes", {
"prompt_id": "p1",
"metadata": {"workflow_id": "wf-1"},
})