fix(server): bound metadata envelope and clean up on cancel paths

Addresses review feedback on the per-prompt metadata envelope:

- Sanitize at the boundary: reject envelopes larger than 16 keys, keys
  over 64 chars, values over 256 chars, or anything that isn't a flat
  ``dict[str, str]``. Logs a warning so abuse is observable. Stops a
  malicious client from inflating broadcast volume by stamping a 10 MB
  metadata blob onto every WS event.
- Cap the in-memory store at 4096 concurrent envelopes with FIFO
  eviction. Acts as a backstop if any cleanup hook is skipped.
- Drop envelopes when prompts are cancelled before reaching the worker:
  ``PromptQueue.wipe_queue`` and ``delete_queue_item`` now call
  ``server.unregister_prompt_metadata`` for every removed item.
- Drop envelopes on hard execution failures: the worker now wraps
  ``e.execute()`` in ``try/finally``, so an uncaught exception in
  execution no longer leaks the envelope.
- Guard the WS reconnect handler: only include ``prompt_id`` in the
  ``executing`` payload when ``last_prompt_id`` is set, so clients
  with strict schemas (zod ``prompt_id: zJobId``) don't reject the
  message with a null id.
- Extract a ``PromptMetadataStore`` class that owns the dict and the
  bounds, so ``PromptServer`` becomes a thin delegating layer and the
  full register/inject/unregister cycle (plus FIFO eviction and
  sanitization) is unit-tested without torch.

44 tests passing; ruff clean on all touched files.
This commit is contained in:
Deep Mehta 2026-05-14 21:03:38 -07:00
parent 74cfcaa318
commit fd89498eac
5 changed files with 362 additions and 83 deletions

View File

@ -1,48 +1,121 @@
"""Per-prompt metadata envelope shared between submission and outbound events.
The metadata envelope is an opaque dict (e.g. ``{"workflow_id": ...}``)
attached to a prompt at submission and injected by the server into every
outbound execution event that carries a ``prompt_id``. It lets consumers
scope state by tags they care about (workflow, trace, tenant) without the
execution layer ever needing to know those tags exist.
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
``{"workflow_id": ...}``) attached to a prompt at submission and injected
by the server into every outbound execution event that carries a
``prompt_id``. It lets consumers scope state by tags they care about
(workflow, trace, tenant) without the execution layer ever needing to
know those tags exist.
Two pure functions live here; ``PromptServer`` owns the per-prompt map and
wires them into the submission and send paths.
This module is intentionally pure no imports from ``server`` or
``execution`` so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
"""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional
# Bounds. The envelope is forwarded to every WebSocket client connected to
# the server on every execution event for the prompt — bounding key count,
# key length, value length, and refusing nested structures keeps a
# malicious or buggy client from inflating the broadcast volume.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
# entry goes first.
DEFAULT_STORE_CAPACITY = 4096
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
"""Validate and copy a candidate envelope.
Enforces the ``dict[str, str]`` contract that downstream consumers
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
- must be a non-empty ``dict``
- at most ``MAX_ENVELOPE_KEYS`` entries
- every key and value must be a ``str``
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
Returns a defensive shallow copy on success, ``None`` on any
violation. Logs a warning on violation so abuse is visible.
"""
if not isinstance(envelope, dict) or not envelope:
return None
if len(envelope) > MAX_ENVELOPE_KEYS:
logging.warning(
"prompt metadata envelope rejected: %d keys exceeds limit %d",
len(envelope), MAX_ENVELOPE_KEYS,
)
return None
sanitized: dict[str, str] = {}
for key, value in envelope.items():
if not isinstance(key, str) or not isinstance(value, str):
logging.warning(
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
type(key).__name__, type(value).__name__,
)
return None
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
logging.warning(
"prompt metadata envelope rejected: key or value exceeds length limit",
)
return None
sanitized[key] = value
return sanitized
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
"""Pull the per-prompt metadata envelope out of a submitted prompt's
``extra_data``.
Two sources, in order:
1. Explicit ``extra_data["metadata"]`` dict preferred path, accepted
as-is (copied so later mutations on the caller's dict don't leak).
1. Explicit ``extra_data["metadata"]`` sanitized via
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
rejected (a warning is logged) rather than truncated, so the
contract stays strict at the boundary.
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` backward-
compatibility fallback. Frontends that already stamp the workflow
id into ``extra_pnginfo`` keep working without changes; the
synthesized envelope is ``{"workflow_id": <id>}``.
id into ``extra_pnginfo`` keep working; the synthesized envelope
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
remains observable.
Returns ``None`` when neither source yields a usable envelope.
"""
if not isinstance(extra_data, dict):
return None
metadata = extra_data.get("metadata")
if isinstance(metadata, dict) and metadata:
return dict(metadata)
if "metadata" in extra_data:
sanitized = _sanitize_envelope(extra_data["metadata"])
if sanitized is not None:
return sanitized
# Explicit metadata was supplied but rejected — do not fall
# through to the legacy path; the caller asked for something
# specific and got it wrong.
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
return None
extra_pnginfo = extra_data.get("extra_pnginfo")
if isinstance(extra_pnginfo, dict):
workflow = extra_pnginfo.get("workflow")
if isinstance(workflow, dict):
workflow_id = workflow.get("id")
if isinstance(workflow_id, str) and workflow_id:
if (
isinstance(workflow_id, str)
and workflow_id
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
):
logging.debug(
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
)
return {"workflow_id": workflow_id}
return None
@ -55,28 +128,25 @@ def inject_envelope(
"""Return ``data`` with a per-prompt ``metadata`` envelope attached.
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
expected to return the registered envelope or ``None``. This indirection
keeps the function pure and avoids depending on any specific storage.
expected to return the registered envelope or ``None``. This keeps
the function pure and avoids depending on any specific storage.
Two payload shapes are handled:
- **dict** carrying ``prompt_id``. A shallow copy is returned with a
``metadata`` key set to the envelope. The caller's dict is not
mutated.
``metadata`` key set to the envelope.
- **(preview_image, metadata_dict) tuple** the format used by
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
the binary preview is passed through by reference.
The function is a no-op for:
- payloads without a ``prompt_id``,
- payloads already declaring their own ``metadata`` field
(callers can opt out by setting it explicitly),
- prompts with no registered envelope,
- any other payload shape (raw bytes, ``None``, etc.).
No-op for payloads without a ``prompt_id``, payloads already
declaring their own ``metadata`` field, prompts with no registered
envelope, or any other payload shape.
"""
def inject(d: dict) -> dict:
if not isinstance(d, dict) or "metadata" in d:
if not isinstance(d, dict):
return d
if "metadata" in d:
return d
prompt_id = d.get("prompt_id")
if not prompt_id:
@ -94,3 +164,37 @@ def inject_envelope(
return data
return (data[0], injected)
return data
class PromptMetadataStore:
"""Bounded ``prompt_id -> envelope`` map.
Owned by ``PromptServer``. Populated at submission, drained when the
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
backstop: if any cleanup hook is ever skipped, the store sheds the
oldest entry instead of growing without bound.
"""
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
self._envelopes: dict[str, dict] = {}
self._capacity = capacity
def register(self, prompt_id: str, extra_data: Any) -> None:
envelope = extract_envelope_from_extra_data(extra_data)
if envelope is None:
return
if len(self._envelopes) >= self._capacity:
self._envelopes.pop(next(iter(self._envelopes)))
self._envelopes[prompt_id] = envelope
def unregister(self, prompt_id: str) -> None:
self._envelopes.pop(prompt_id, None)
def inject(self, data: Any) -> Any:
return inject_envelope(data, self._envelopes.get)
def __len__(self) -> int:
return len(self._envelopes)
def __contains__(self, prompt_id: str) -> bool:
return prompt_id in self._envelopes

View File

@ -1296,7 +1296,10 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
dropped_prompt_ids = [item[1] for item in self.queue]
self.queue = []
for prompt_id in dropped_prompt_ids:
self.server.unregister_prompt_metadata(prompt_id)
self.server.queue_updated()
def delete_queue_item(self, function):
@ -1306,8 +1309,9 @@ class PromptQueue:
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
deleted = self.queue.pop(x)
heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(deleted[1])
self.server.queue_updated()
return True
return False

31
main.py
View File

@ -318,21 +318,26 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
server_instance.unregister_prompt_metadata(prompt_id)
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
# Always drop the metadata envelope. If e.execute() raises
# hard before its own error handling kicks in, the
# registered envelope would otherwise leak for the
# lifetime of the process.
server_instance.unregister_prompt_metadata(prompt_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time

View File

@ -44,10 +44,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from app.prompt_metadata import (
extract_envelope_from_extra_data,
inject_envelope,
)
from app.prompt_metadata import PromptMetadataStore
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@ -257,10 +254,10 @@ class PromptServer():
self.last_prompt_id = None
self.client_id = None
# prompt_id -> metadata envelope captured at submission and injected
# into outbound execution events. Keeps the workflow scope (and any
# other client-supplied tags) out of the execution layer.
self._prompt_metadata: dict[str, dict] = {}
# Bounded prompt_id -> envelope store. Populated at submission,
# drained on completion/cancel. Keeps workflow scope (and other
# client-supplied tags) out of the execution layer.
self._prompt_metadata = PromptMetadataStore()
self.on_prompt_handlers = []
@ -285,10 +282,12 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", self._inject_prompt_metadata({
"node": self.last_node_id,
"prompt_id": self.last_prompt_id,
}), sid)
payload: dict = {"node": self.last_node_id}
if self.last_prompt_id is not None:
payload["prompt_id"] = self.last_prompt_id
await self.send(
"executing", self._inject_prompt_metadata(payload), sid
)
# Flag to track if we've received the first message
first_message = True
@ -1233,17 +1232,16 @@ class PromptServer():
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
"""Record per-prompt metadata for injection into outbound execution
events. Called at submission, before the prompt is queued."""
envelope = extract_envelope_from_extra_data(extra_data)
if envelope is not None:
self._prompt_metadata[prompt_id] = envelope
self._prompt_metadata.register(prompt_id, extra_data)
def unregister_prompt_metadata(self, prompt_id: str) -> None:
"""Drop the per-prompt metadata envelope. Called after the prompt
has finished executing and its terminal events have been queued."""
self._prompt_metadata.pop(prompt_id, None)
has finished executing and its terminal events have been queued,
or when the prompt is cancelled before reaching the worker."""
self._prompt_metadata.unregister(prompt_id)
def _inject_prompt_metadata(self, data):
return inject_envelope(data, self._prompt_metadata.get)
return self._prompt_metadata.inject(data)
def send_sync(self, event, data, sid=None):
data = self._inject_prompt_metadata(data)

View File

@ -1,12 +1,19 @@
"""Unit tests for the pure metadata-envelope helpers in
``app.prompt_metadata``. These cover the two functions that PromptServer
wires into submission (``extract_envelope_from_extra_data``) and into the
send chokepoint (``inject_envelope``).
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
``inject_envelope``) and the ``PromptMetadataStore`` integration class
that ``PromptServer`` owns.
"""
from __future__ import annotations
import pytest
from app.prompt_metadata import (
MAX_ENVELOPE_KEYS,
MAX_ENVELOPE_KEY_LEN,
MAX_ENVELOPE_VALUE_LEN,
PromptMetadataStore,
extract_envelope_from_extra_data,
inject_envelope,
)
@ -42,10 +49,10 @@ class TestExtractEnvelopeFromExtraData:
is None
)
def test_rejects_non_string_or_empty_workflow_id(self):
for bad in ["", 123, None, [], {}]:
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
assert extract_envelope_from_extra_data(extra_data) is None
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
def test_rejects_non_string_or_empty_workflow_id(self, bad):
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
assert extract_envelope_from_extra_data(extra_data) is None
def test_rejects_non_dict_inputs_at_each_level(self):
assert extract_envelope_from_extra_data(None) is None
@ -73,6 +80,7 @@ class TestExtractEnvelopeFromExtraData:
def test_returned_envelope_is_copy_not_reference(self):
original = {"workflow_id": "wf-1"}
result = extract_envelope_from_extra_data({"metadata": original})
assert result is not None
result["new_key"] = "x"
assert "new_key" not in original
@ -86,10 +94,76 @@ class TestExtractEnvelopeFromExtraData:
}
class TestEnvelopeSanitization:
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
envelope is dropped (and a warning is logged) rather than truncated,
so the boundary stays strict."""
def test_rejects_too_many_keys(self, caplog):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
with caplog.at_level("WARNING"):
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
assert any("exceeds limit" in r.message for r in caplog.records)
def test_accepts_max_keys_exactly(self):
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_rejects_non_string_keys(self, caplog):
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data({"metadata": {42: "v"}})
is None
)
assert any("non-string" in r.message for r in caplog.records)
def test_rejects_non_string_values(self, caplog):
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
with caplog.at_level("WARNING"):
assert (
extract_envelope_from_extra_data(
{"metadata": {"k": bad_value}}
)
is None
)
def test_rejects_oversized_key(self):
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_rejects_oversized_value(self):
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
def test_accepts_max_lengths_exactly(self):
envelope = {
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
}
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
def test_oversized_workflow_id_in_pnginfo_rejected(self):
"""The legacy synthesized path also respects the value bound."""
extra_data = {
"extra_pnginfo": {
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
}
}
assert extract_envelope_from_extra_data(extra_data) is None
def test_invalid_explicit_metadata_does_not_fall_through(self):
"""An explicit but invalid metadata dict means the caller asked
for something specific and got it wrong; the synthesized
fallback must not silently substitute."""
extra_data = {
"metadata": {"k": 42}, # non-string value
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
}
assert extract_envelope_from_extra_data(extra_data) is None
class TestInjectEnvelope:
@staticmethod
def _lookup(table):
"""Build an envelope_lookup callable backed by a dict."""
return table.get
def test_injects_envelope_on_dict_with_known_prompt_id(self):
@ -108,16 +182,15 @@ class TestInjectEnvelope:
def test_passthrough_when_payload_lacks_prompt_id(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
data = {"status": "ok"}
assert inject_envelope(data, lookup) is data
assert inject_envelope(data, lookup) == data
def test_passthrough_when_payload_already_has_metadata(self):
"""If a caller has already set a ``metadata`` field (e.g. for
opt-out or pre-augmented payloads), the function must not
overwrite it."""
"""If a caller has already set a ``metadata`` field, the
function must not overwrite it."""
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}}
result = inject_envelope(data, lookup)
assert result is data
assert result == data
assert result["metadata"] == {"workflow_id": "wf-caller"}
def test_does_not_mutate_input_dict(self):
@ -153,29 +226,124 @@ class TestInjectEnvelope:
def test_preview_tuple_passthrough_when_inner_already_has_metadata(self):
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
preview_image = ("PNG", object(), 256)
inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}}
inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": "1"}}
result = inject_envelope((preview_image, inner), lookup)
assert result == (preview_image, inner)
def test_non_dict_non_tuple_payloads_passthrough(self):
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
assert inject_envelope(b"raw-bytes", lookup) == b"raw-bytes"
assert inject_envelope(None, lookup) is None
assert inject_envelope(42, lookup) == 42
assert inject_envelope(payload, lookup) == payload
def test_tuple_of_wrong_arity_passthrough(self):
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is special-
cased. Other tuples must not be touched."""
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
special-cased. Other tuples must not be touched."""
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
triple = (1, {"prompt_id": "p1"}, 3)
assert inject_envelope(triple, lookup) is triple
def test_envelope_lookup_called_at_send_time(self):
"""The lookup runs each time the function is called, so a producer
and consumer that share a backing dict observe the current value."""
def test_envelope_lookup_called_per_invocation(self):
"""The lookup runs each time the function is called, so changes
to the backing store are immediately visible."""
store = {"p1": {"workflow_id": "wf-1"}}
first = inject_envelope({"prompt_id": "p1"}, store.get)
store["p1"] = {"workflow_id": "wf-2"}
second = inject_envelope({"prompt_id": "p1"}, store.get)
del store["p1"]
third = inject_envelope({"prompt_id": "p1"}, store.get)
assert first["metadata"] == {"workflow_id": "wf-1"}
assert second["metadata"] == {"workflow_id": "wf-2"}
assert "metadata" not in third
class TestPromptMetadataStore:
"""End-to-end wiring tests that exercise the full register/inject/
unregister cycle the way ``PromptServer`` does."""
def test_register_inject_unregister_cycle(self):
store = PromptMetadataStore()
store.register(
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
)
injected = store.inject({"node": "5", "prompt_id": "p1"})
assert injected == {
"node": "5",
"prompt_id": "p1",
"metadata": {"workflow_id": "wf-1"},
}
store.unregister("p1")
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
assert "metadata" not in passthrough
def test_register_with_no_derivable_envelope_is_noop(self):
store = PromptMetadataStore()
store.register("p1", {})
assert "p1" not in store
assert store.inject({"prompt_id": "p1"}) == {"prompt_id": "p1"}
def test_register_with_oversized_envelope_is_noop(self):
"""Sanitization rejection means nothing is registered — the
store stays empty and inject is a passthrough."""
store = PromptMetadataStore()
store.register(
"p1",
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
)
assert "p1" not in store
def test_unregister_unknown_prompt_is_silent(self):
store = PromptMetadataStore()
store.unregister("does-not-exist")
def test_fifo_eviction_when_capacity_exceeded(self):
"""If cleanup hooks are ever bypassed, the store must shed the
oldest entry rather than grow without bound."""
store = PromptMetadataStore(capacity=3)
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
assert len(store) == 3
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
assert len(store) == 3
assert "p1" not in store
assert "p4" in store
# The newer entries are still injectable.
assert store.inject({"prompt_id": "p4"})["metadata"] == {
"workflow_id": "wf-4"
}
# The evicted one is gone.
assert "metadata" not in store.inject({"prompt_id": "p1"})
def test_register_after_unregister_does_not_count_against_capacity(self):
"""Normal lifecycle: register, unregister, register many — the
store should not silently evict valid entries because of stale
accounting."""
store = PromptMetadataStore(capacity=2)
for i in range(10):
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
store.unregister(f"p{i}")
assert len(store) == 0
def test_re_register_overwrites(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
assert store.inject({"prompt_id": "p1"})["metadata"] == {
"workflow_id": "wf-2"
}
def test_inject_with_no_registrations_is_passthrough(self):
store = PromptMetadataStore()
data = {"prompt_id": "p1", "node": "5"}
assert store.inject(data) == data
def test_inject_into_preview_tuple(self):
store = PromptMetadataStore()
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
assert result == (b"image-bytes", {
"prompt_id": "p1",
"metadata": {"workflow_id": "wf-1"},
})