mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 03:17:23 +08:00
fix(server): bound metadata envelope and clean up on cancel paths
Addresses review feedback on the per-prompt metadata envelope: - Sanitize at the boundary: reject envelopes larger than 16 keys, keys over 64 chars, values over 256 chars, or anything that isn't a flat ``dict[str, str]``. Logs a warning so abuse is observable. Stops a malicious client from inflating broadcast volume by stamping a 10 MB metadata blob onto every WS event. - Cap the in-memory store at 4096 concurrent envelopes with FIFO eviction. Acts as a backstop if any cleanup hook is skipped. - Drop envelopes when prompts are cancelled before reaching the worker: ``PromptQueue.wipe_queue`` and ``delete_queue_item`` now call ``server.unregister_prompt_metadata`` for every removed item. - Drop envelopes on hard execution failures: the worker now wraps ``e.execute()`` in ``try/finally``, so an uncaught exception in execution no longer leaks the envelope. - Guard the WS reconnect handler: only include ``prompt_id`` in the ``executing`` payload when ``last_prompt_id`` is set, so clients with strict schemas (zod ``prompt_id: zJobId``) don't reject the message with a null id. - Extract a ``PromptMetadataStore`` class that owns the dict and the bounds, so ``PromptServer`` becomes a thin delegating layer and the full register/inject/unregister cycle (plus FIFO eviction and sanitization) is unit-tested without torch. 44 tests passing; ruff clean on all touched files.
This commit is contained in:
parent
74cfcaa318
commit
fd89498eac
@ -1,48 +1,121 @@
|
||||
"""Per-prompt metadata envelope shared between submission and outbound events.
|
||||
|
||||
The metadata envelope is an opaque dict (e.g. ``{"workflow_id": ...}``)
|
||||
attached to a prompt at submission and injected by the server into every
|
||||
outbound execution event that carries a ``prompt_id``. It lets consumers
|
||||
scope state by tags they care about (workflow, trace, tenant) without the
|
||||
execution layer ever needing to know those tags exist.
|
||||
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
|
||||
``{"workflow_id": ...}``) attached to a prompt at submission and injected
|
||||
by the server into every outbound execution event that carries a
|
||||
``prompt_id``. It lets consumers scope state by tags they care about
|
||||
(workflow, trace, tenant) without the execution layer ever needing to
|
||||
know those tags exist.
|
||||
|
||||
Two pure functions live here; ``PromptServer`` owns the per-prompt map and
|
||||
wires them into the submission and send paths.
|
||||
This module is intentionally pure — no imports from ``server`` or
|
||||
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
|
||||
instance and the helpers can be unit-tested without the rest of the app.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Bounds. The envelope is forwarded to every WebSocket client connected to
|
||||
# the server on every execution event for the prompt — bounding key count,
|
||||
# key length, value length, and refusing nested structures keeps a
|
||||
# malicious or buggy client from inflating the broadcast volume.
|
||||
MAX_ENVELOPE_KEYS = 16
|
||||
MAX_ENVELOPE_KEY_LEN = 64
|
||||
MAX_ENVELOPE_VALUE_LEN = 256
|
||||
|
||||
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
|
||||
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
|
||||
# entry goes first.
|
||||
DEFAULT_STORE_CAPACITY = 4096
|
||||
|
||||
|
||||
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
|
||||
"""Validate and copy a candidate envelope.
|
||||
|
||||
Enforces the ``dict[str, str]`` contract that downstream consumers
|
||||
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
|
||||
|
||||
- must be a non-empty ``dict``
|
||||
- at most ``MAX_ENVELOPE_KEYS`` entries
|
||||
- every key and value must be a ``str``
|
||||
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
|
||||
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
|
||||
|
||||
Returns a defensive shallow copy on success, ``None`` on any
|
||||
violation. Logs a warning on violation so abuse is visible.
|
||||
"""
|
||||
if not isinstance(envelope, dict) or not envelope:
|
||||
return None
|
||||
if len(envelope) > MAX_ENVELOPE_KEYS:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: %d keys exceeds limit %d",
|
||||
len(envelope), MAX_ENVELOPE_KEYS,
|
||||
)
|
||||
return None
|
||||
sanitized: dict[str, str] = {}
|
||||
for key, value in envelope.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
|
||||
type(key).__name__, type(value).__name__,
|
||||
)
|
||||
return None
|
||||
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: key or value exceeds length limit",
|
||||
)
|
||||
return None
|
||||
sanitized[key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
|
||||
"""Pull the per-prompt metadata envelope out of a submitted prompt's
|
||||
``extra_data``.
|
||||
|
||||
Two sources, in order:
|
||||
|
||||
1. Explicit ``extra_data["metadata"]`` dict — preferred path, accepted
|
||||
as-is (copied so later mutations on the caller's dict don't leak).
|
||||
1. Explicit ``extra_data["metadata"]`` — sanitized via
|
||||
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
|
||||
rejected (a warning is logged) rather than truncated, so the
|
||||
contract stays strict at the boundary.
|
||||
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward-
|
||||
compatibility fallback. Frontends that already stamp the workflow
|
||||
id into ``extra_pnginfo`` keep working without changes; the
|
||||
synthesized envelope is ``{"workflow_id": <id>}``.
|
||||
id into ``extra_pnginfo`` keep working; the synthesized envelope
|
||||
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
|
||||
remains observable.
|
||||
|
||||
Returns ``None`` when neither source yields a usable envelope.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
|
||||
metadata = extra_data.get("metadata")
|
||||
if isinstance(metadata, dict) and metadata:
|
||||
return dict(metadata)
|
||||
if "metadata" in extra_data:
|
||||
sanitized = _sanitize_envelope(extra_data["metadata"])
|
||||
if sanitized is not None:
|
||||
return sanitized
|
||||
# Explicit metadata was supplied but rejected — do not fall
|
||||
# through to the legacy path; the caller asked for something
|
||||
# specific and got it wrong.
|
||||
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
|
||||
return None
|
||||
|
||||
extra_pnginfo = extra_data.get("extra_pnginfo")
|
||||
if isinstance(extra_pnginfo, dict):
|
||||
workflow = extra_pnginfo.get("workflow")
|
||||
if isinstance(workflow, dict):
|
||||
workflow_id = workflow.get("id")
|
||||
if isinstance(workflow_id, str) and workflow_id:
|
||||
if (
|
||||
isinstance(workflow_id, str)
|
||||
and workflow_id
|
||||
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
|
||||
):
|
||||
logging.debug(
|
||||
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
|
||||
)
|
||||
return {"workflow_id": workflow_id}
|
||||
|
||||
return None
|
||||
@ -55,28 +128,25 @@ def inject_envelope(
|
||||
"""Return ``data`` with a per-prompt ``metadata`` envelope attached.
|
||||
|
||||
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
|
||||
expected to return the registered envelope or ``None``. This indirection
|
||||
keeps the function pure and avoids depending on any specific storage.
|
||||
expected to return the registered envelope or ``None``. This keeps
|
||||
the function pure and avoids depending on any specific storage.
|
||||
|
||||
Two payload shapes are handled:
|
||||
|
||||
- **dict** carrying ``prompt_id``. A shallow copy is returned with a
|
||||
``metadata`` key set to the envelope. The caller's dict is not
|
||||
mutated.
|
||||
``metadata`` key set to the envelope.
|
||||
- **(preview_image, metadata_dict) tuple** — the format used by
|
||||
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
|
||||
the binary preview is passed through by reference.
|
||||
|
||||
The function is a no-op for:
|
||||
|
||||
- payloads without a ``prompt_id``,
|
||||
- payloads already declaring their own ``metadata`` field
|
||||
(callers can opt out by setting it explicitly),
|
||||
- prompts with no registered envelope,
|
||||
- any other payload shape (raw bytes, ``None``, etc.).
|
||||
No-op for payloads without a ``prompt_id``, payloads already
|
||||
declaring their own ``metadata`` field, prompts with no registered
|
||||
envelope, or any other payload shape.
|
||||
"""
|
||||
def inject(d: dict) -> dict:
|
||||
if not isinstance(d, dict) or "metadata" in d:
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
if "metadata" in d:
|
||||
return d
|
||||
prompt_id = d.get("prompt_id")
|
||||
if not prompt_id:
|
||||
@ -94,3 +164,37 @@ def inject_envelope(
|
||||
return data
|
||||
return (data[0], injected)
|
||||
return data
|
||||
|
||||
|
||||
class PromptMetadataStore:
|
||||
"""Bounded ``prompt_id -> envelope`` map.
|
||||
|
||||
Owned by ``PromptServer``. Populated at submission, drained when the
|
||||
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
||||
backstop: if any cleanup hook is ever skipped, the store sheds the
|
||||
oldest entry instead of growing without bound.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
||||
self._envelopes: dict[str, dict] = {}
|
||||
self._capacity = capacity
|
||||
|
||||
def register(self, prompt_id: str, extra_data: Any) -> None:
|
||||
envelope = extract_envelope_from_extra_data(extra_data)
|
||||
if envelope is None:
|
||||
return
|
||||
if len(self._envelopes) >= self._capacity:
|
||||
self._envelopes.pop(next(iter(self._envelopes)))
|
||||
self._envelopes[prompt_id] = envelope
|
||||
|
||||
def unregister(self, prompt_id: str) -> None:
|
||||
self._envelopes.pop(prompt_id, None)
|
||||
|
||||
def inject(self, data: Any) -> Any:
|
||||
return inject_envelope(data, self._envelopes.get)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._envelopes)
|
||||
|
||||
def __contains__(self, prompt_id: str) -> bool:
|
||||
return prompt_id in self._envelopes
|
||||
|
||||
@ -1296,7 +1296,10 @@ class PromptQueue:
|
||||
|
||||
def wipe_queue(self):
|
||||
with self.mutex:
|
||||
dropped_prompt_ids = [item[1] for item in self.queue]
|
||||
self.queue = []
|
||||
for prompt_id in dropped_prompt_ids:
|
||||
self.server.unregister_prompt_metadata(prompt_id)
|
||||
self.server.queue_updated()
|
||||
|
||||
def delete_queue_item(self, function):
|
||||
@ -1306,8 +1309,9 @@ class PromptQueue:
|
||||
if len(self.queue) == 1:
|
||||
self.wipe_queue()
|
||||
else:
|
||||
self.queue.pop(x)
|
||||
deleted = self.queue.pop(x)
|
||||
heapq.heapify(self.queue)
|
||||
self.server.unregister_prompt_metadata(deleted[1])
|
||||
self.server.queue_updated()
|
||||
return True
|
||||
return False
|
||||
|
||||
31
main.py
31
main.py
@ -318,21 +318,26 @@ def prompt_worker(q, server_instance):
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
|
||||
server_instance.unregister_prompt_metadata(prompt_id)
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Always drop the metadata envelope. If e.execute() raises
|
||||
# hard before its own error handling kicks in, the
|
||||
# registered envelope would otherwise leak for the
|
||||
# lifetime of the process.
|
||||
server_instance.unregister_prompt_metadata(prompt_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
34
server.py
34
server.py
@ -44,10 +44,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from app.prompt_metadata import (
|
||||
extract_envelope_from_extra_data,
|
||||
inject_envelope,
|
||||
)
|
||||
from app.prompt_metadata import PromptMetadataStore
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -257,10 +254,10 @@ class PromptServer():
|
||||
self.last_prompt_id = None
|
||||
self.client_id = None
|
||||
|
||||
# prompt_id -> metadata envelope captured at submission and injected
|
||||
# into outbound execution events. Keeps the workflow scope (and any
|
||||
# other client-supplied tags) out of the execution layer.
|
||||
self._prompt_metadata: dict[str, dict] = {}
|
||||
# Bounded prompt_id -> envelope store. Populated at submission,
|
||||
# drained on completion/cancel. Keeps workflow scope (and other
|
||||
# client-supplied tags) out of the execution layer.
|
||||
self._prompt_metadata = PromptMetadataStore()
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@ -285,10 +282,12 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", self._inject_prompt_metadata({
|
||||
"node": self.last_node_id,
|
||||
"prompt_id": self.last_prompt_id,
|
||||
}), sid)
|
||||
payload: dict = {"node": self.last_node_id}
|
||||
if self.last_prompt_id is not None:
|
||||
payload["prompt_id"] = self.last_prompt_id
|
||||
await self.send(
|
||||
"executing", self._inject_prompt_metadata(payload), sid
|
||||
)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -1233,17 +1232,16 @@ class PromptServer():
|
||||
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
|
||||
"""Record per-prompt metadata for injection into outbound execution
|
||||
events. Called at submission, before the prompt is queued."""
|
||||
envelope = extract_envelope_from_extra_data(extra_data)
|
||||
if envelope is not None:
|
||||
self._prompt_metadata[prompt_id] = envelope
|
||||
self._prompt_metadata.register(prompt_id, extra_data)
|
||||
|
||||
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
||||
"""Drop the per-prompt metadata envelope. Called after the prompt
|
||||
has finished executing and its terminal events have been queued."""
|
||||
self._prompt_metadata.pop(prompt_id, None)
|
||||
has finished executing and its terminal events have been queued,
|
||||
or when the prompt is cancelled before reaching the worker."""
|
||||
self._prompt_metadata.unregister(prompt_id)
|
||||
|
||||
def _inject_prompt_metadata(self, data):
|
||||
return inject_envelope(data, self._prompt_metadata.get)
|
||||
return self._prompt_metadata.inject(data)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
data = self._inject_prompt_metadata(data)
|
||||
|
||||
@ -1,12 +1,19 @@
|
||||
"""Unit tests for the pure metadata-envelope helpers in
|
||||
``app.prompt_metadata``. These cover the two functions that PromptServer
|
||||
wires into submission (``extract_envelope_from_extra_data``) and into the
|
||||
send chokepoint (``inject_envelope``).
|
||||
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
|
||||
|
||||
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
|
||||
``inject_envelope``) and the ``PromptMetadataStore`` integration class
|
||||
that ``PromptServer`` owns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.prompt_metadata import (
|
||||
MAX_ENVELOPE_KEYS,
|
||||
MAX_ENVELOPE_KEY_LEN,
|
||||
MAX_ENVELOPE_VALUE_LEN,
|
||||
PromptMetadataStore,
|
||||
extract_envelope_from_extra_data,
|
||||
inject_envelope,
|
||||
)
|
||||
@ -42,10 +49,10 @@ class TestExtractEnvelopeFromExtraData:
|
||||
is None
|
||||
)
|
||||
|
||||
def test_rejects_non_string_or_empty_workflow_id(self):
|
||||
for bad in ["", 123, None, [], {}]:
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
|
||||
def test_rejects_non_string_or_empty_workflow_id(self, bad):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_rejects_non_dict_inputs_at_each_level(self):
|
||||
assert extract_envelope_from_extra_data(None) is None
|
||||
@ -73,6 +80,7 @@ class TestExtractEnvelopeFromExtraData:
|
||||
def test_returned_envelope_is_copy_not_reference(self):
|
||||
original = {"workflow_id": "wf-1"}
|
||||
result = extract_envelope_from_extra_data({"metadata": original})
|
||||
assert result is not None
|
||||
result["new_key"] = "x"
|
||||
assert "new_key" not in original
|
||||
|
||||
@ -86,10 +94,76 @@ class TestExtractEnvelopeFromExtraData:
|
||||
}
|
||||
|
||||
|
||||
class TestEnvelopeSanitization:
|
||||
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
|
||||
envelope is dropped (and a warning is logged) rather than truncated,
|
||||
so the boundary stays strict."""
|
||||
|
||||
def test_rejects_too_many_keys(self, caplog):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
|
||||
with caplog.at_level("WARNING"):
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
assert any("exceeds limit" in r.message for r in caplog.records)
|
||||
|
||||
def test_accepts_max_keys_exactly(self):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_rejects_non_string_keys(self, caplog):
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"metadata": {42: "v"}})
|
||||
is None
|
||||
)
|
||||
assert any("non-string" in r.message for r in caplog.records)
|
||||
|
||||
def test_rejects_non_string_values(self, caplog):
|
||||
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data(
|
||||
{"metadata": {"k": bad_value}}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_rejects_oversized_key(self):
|
||||
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_rejects_oversized_value(self):
|
||||
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_accepts_max_lengths_exactly(self):
|
||||
envelope = {
|
||||
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
|
||||
}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_oversized_workflow_id_in_pnginfo_rejected(self):
|
||||
"""The legacy synthesized path also respects the value bound."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {
|
||||
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
}
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_invalid_explicit_metadata_does_not_fall_through(self):
|
||||
"""An explicit but invalid metadata dict means the caller asked
|
||||
for something specific and got it wrong; the synthesized
|
||||
fallback must not silently substitute."""
|
||||
extra_data = {
|
||||
"metadata": {"k": 42}, # non-string value
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
|
||||
class TestInjectEnvelope:
|
||||
@staticmethod
|
||||
def _lookup(table):
|
||||
"""Build an envelope_lookup callable backed by a dict."""
|
||||
return table.get
|
||||
|
||||
def test_injects_envelope_on_dict_with_known_prompt_id(self):
|
||||
@ -108,16 +182,15 @@ class TestInjectEnvelope:
|
||||
def test_passthrough_when_payload_lacks_prompt_id(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
data = {"status": "ok"}
|
||||
assert inject_envelope(data, lookup) is data
|
||||
assert inject_envelope(data, lookup) == data
|
||||
|
||||
def test_passthrough_when_payload_already_has_metadata(self):
|
||||
"""If a caller has already set a ``metadata`` field (e.g. for
|
||||
opt-out or pre-augmented payloads), the function must not
|
||||
overwrite it."""
|
||||
"""If a caller has already set a ``metadata`` field, the
|
||||
function must not overwrite it."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
|
||||
data = {"prompt_id": "p1", "metadata": {"workflow_id": "wf-caller"}}
|
||||
result = inject_envelope(data, lookup)
|
||||
assert result is data
|
||||
assert result == data
|
||||
assert result["metadata"] == {"workflow_id": "wf-caller"}
|
||||
|
||||
def test_does_not_mutate_input_dict(self):
|
||||
@ -153,29 +226,124 @@ class TestInjectEnvelope:
|
||||
def test_preview_tuple_passthrough_when_inner_already_has_metadata(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-injected"}})
|
||||
preview_image = ("PNG", object(), 256)
|
||||
inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": 1}}
|
||||
inner = {"node_id": "5", "prompt_id": "p1", "metadata": {"x": "1"}}
|
||||
result = inject_envelope((preview_image, inner), lookup)
|
||||
assert result == (preview_image, inner)
|
||||
|
||||
def test_non_dict_non_tuple_payloads_passthrough(self):
|
||||
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
|
||||
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
assert inject_envelope(b"raw-bytes", lookup) == b"raw-bytes"
|
||||
assert inject_envelope(None, lookup) is None
|
||||
assert inject_envelope(42, lookup) == 42
|
||||
assert inject_envelope(payload, lookup) == payload
|
||||
|
||||
def test_tuple_of_wrong_arity_passthrough(self):
|
||||
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is special-
|
||||
cased. Other tuples must not be touched."""
|
||||
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
|
||||
special-cased. Other tuples must not be touched."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
triple = (1, {"prompt_id": "p1"}, 3)
|
||||
assert inject_envelope(triple, lookup) is triple
|
||||
|
||||
def test_envelope_lookup_called_at_send_time(self):
|
||||
"""The lookup runs each time the function is called, so a producer
|
||||
and consumer that share a backing dict observe the current value."""
|
||||
def test_envelope_lookup_called_per_invocation(self):
|
||||
"""The lookup runs each time the function is called, so changes
|
||||
to the backing store are immediately visible."""
|
||||
store = {"p1": {"workflow_id": "wf-1"}}
|
||||
first = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
store["p1"] = {"workflow_id": "wf-2"}
|
||||
second = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
del store["p1"]
|
||||
third = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
assert first["metadata"] == {"workflow_id": "wf-1"}
|
||||
assert second["metadata"] == {"workflow_id": "wf-2"}
|
||||
assert "metadata" not in third
|
||||
|
||||
|
||||
class TestPromptMetadataStore:
|
||||
"""End-to-end wiring tests that exercise the full register/inject/
|
||||
unregister cycle the way ``PromptServer`` does."""
|
||||
|
||||
def test_register_inject_unregister_cycle(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
|
||||
)
|
||||
injected = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert injected == {
|
||||
"node": "5",
|
||||
"prompt_id": "p1",
|
||||
"metadata": {"workflow_id": "wf-1"},
|
||||
}
|
||||
store.unregister("p1")
|
||||
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert "metadata" not in passthrough
|
||||
|
||||
def test_register_with_no_derivable_envelope_is_noop(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {})
|
||||
assert "p1" not in store
|
||||
assert store.inject({"prompt_id": "p1"}) == {"prompt_id": "p1"}
|
||||
|
||||
def test_register_with_oversized_envelope_is_noop(self):
|
||||
"""Sanitization rejection means nothing is registered — the
|
||||
store stays empty and inject is a passthrough."""
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1",
|
||||
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
|
||||
)
|
||||
assert "p1" not in store
|
||||
|
||||
def test_unregister_unknown_prompt_is_silent(self):
|
||||
store = PromptMetadataStore()
|
||||
store.unregister("does-not-exist")
|
||||
|
||||
def test_fifo_eviction_when_capacity_exceeded(self):
|
||||
"""If cleanup hooks are ever bypassed, the store must shed the
|
||||
oldest entry rather than grow without bound."""
|
||||
store = PromptMetadataStore(capacity=3)
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
|
||||
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
|
||||
assert len(store) == 3
|
||||
|
||||
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
|
||||
assert len(store) == 3
|
||||
assert "p1" not in store
|
||||
assert "p4" in store
|
||||
|
||||
# The newer entries are still injectable.
|
||||
assert store.inject({"prompt_id": "p4"})["metadata"] == {
|
||||
"workflow_id": "wf-4"
|
||||
}
|
||||
# The evicted one is gone.
|
||||
assert "metadata" not in store.inject({"prompt_id": "p1"})
|
||||
|
||||
def test_register_after_unregister_does_not_count_against_capacity(self):
|
||||
"""Normal lifecycle: register, unregister, register many — the
|
||||
store should not silently evict valid entries because of stale
|
||||
accounting."""
|
||||
store = PromptMetadataStore(capacity=2)
|
||||
for i in range(10):
|
||||
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
|
||||
store.unregister(f"p{i}")
|
||||
assert len(store) == 0
|
||||
|
||||
def test_re_register_overwrites(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
|
||||
assert store.inject({"prompt_id": "p1"})["metadata"] == {
|
||||
"workflow_id": "wf-2"
|
||||
}
|
||||
|
||||
def test_inject_with_no_registrations_is_passthrough(self):
|
||||
store = PromptMetadataStore()
|
||||
data = {"prompt_id": "p1", "node": "5"}
|
||||
assert store.inject(data) == data
|
||||
|
||||
def test_inject_into_preview_tuple(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
|
||||
assert result == (b"image-bytes", {
|
||||
"prompt_id": "p1",
|
||||
"metadata": {"workflow_id": "wf-1"},
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user