mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 16:57:29 +08:00
Merge 63784baed5 into aeadb7acaa
This commit is contained in:
commit
3b51306917
226
app/prompt_metadata.py
Normal file
226
app/prompt_metadata.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Per-prompt metadata envelope shared between submission and outbound events.
|
||||
|
||||
The metadata envelope is a small flat ``dict[str, str]`` (e.g.
|
||||
``{"workflow_id": ...}``) attached to a prompt at submission and injected
|
||||
by the server into every outbound execution event that carries a
|
||||
``prompt_id``. It lets consumers scope state by tags they care about
|
||||
(workflow, trace, tenant) without the execution layer ever needing to
|
||||
know those tags exist.
|
||||
|
||||
This module is intentionally pure — no imports from ``server`` or
|
||||
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
|
||||
instance and the helpers can be unit-tested without the rest of the app.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Bounds. The envelope is forwarded to every WebSocket client connected to
|
||||
# the server on every execution event for the prompt — bounding key count,
|
||||
# key length, value length, and refusing nested structures keeps a
|
||||
# malicious or buggy client from inflating the broadcast volume.
|
||||
MAX_ENVELOPE_KEYS = 16
|
||||
MAX_ENVELOPE_KEY_LEN = 64
|
||||
MAX_ENVELOPE_VALUE_LEN = 256
|
||||
|
||||
# Cap on concurrently registered prompt envelopes. Acts as a backstop if
|
||||
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
|
||||
# entry goes first.
|
||||
DEFAULT_STORE_CAPACITY = 4096
|
||||
|
||||
|
||||
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
|
||||
"""Validate and copy a candidate envelope.
|
||||
|
||||
Enforces the ``dict[str, str]`` contract that downstream consumers
|
||||
(cloud projections, frontend zod schemas, OpenAPI docs) rely on:
|
||||
|
||||
- must be a non-empty ``dict``
|
||||
- at most ``MAX_ENVELOPE_KEYS`` entries
|
||||
- every key and value must be a ``str``
|
||||
- keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
|
||||
- values at most ``MAX_ENVELOPE_VALUE_LEN`` chars
|
||||
|
||||
Returns a defensive shallow copy on success, ``None`` on any
|
||||
violation. Logs a warning on violation so abuse is visible.
|
||||
"""
|
||||
if not isinstance(envelope, dict) or not envelope:
|
||||
return None
|
||||
if len(envelope) > MAX_ENVELOPE_KEYS:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: %d keys exceeds limit %d",
|
||||
len(envelope), MAX_ENVELOPE_KEYS,
|
||||
)
|
||||
return None
|
||||
sanitized: dict[str, str] = {}
|
||||
for key, value in envelope.items():
|
||||
if not isinstance(key, str) or not isinstance(value, str):
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: non-string key/value (%s=%s)",
|
||||
type(key).__name__, type(value).__name__,
|
||||
)
|
||||
return None
|
||||
if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
|
||||
logging.warning(
|
||||
"prompt metadata envelope rejected: key or value exceeds length limit",
|
||||
)
|
||||
return None
|
||||
sanitized[key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
|
||||
"""Pull the per-prompt metadata envelope out of a submitted prompt's
|
||||
``extra_data``.
|
||||
|
||||
Two sources, in order:
|
||||
|
||||
1. Explicit ``extra_data["metadata"]`` — sanitized via
|
||||
``_sanitize_envelope``. Oversized or wrong-typed envelopes are
|
||||
rejected (a warning is logged) rather than truncated, so the
|
||||
contract stays strict at the boundary.
|
||||
2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward-
|
||||
compatibility fallback. Frontends that already stamp the workflow
|
||||
id into ``extra_pnginfo`` keep working; the synthesized envelope
|
||||
is ``{"workflow_id": <id>}``. A debug log fires so the legacy path
|
||||
remains observable.
|
||||
|
||||
Returns ``None`` when neither source yields a usable envelope.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
|
||||
if "metadata" in extra_data:
|
||||
sanitized = _sanitize_envelope(extra_data["metadata"])
|
||||
if sanitized is not None:
|
||||
return sanitized
|
||||
# Explicit metadata was supplied but rejected — do not fall
|
||||
# through to the legacy path; the caller asked for something
|
||||
# specific and got it wrong.
|
||||
if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
|
||||
return None
|
||||
|
||||
extra_pnginfo = extra_data.get("extra_pnginfo")
|
||||
if isinstance(extra_pnginfo, dict):
|
||||
workflow = extra_pnginfo.get("workflow")
|
||||
if isinstance(workflow, dict):
|
||||
workflow_id = workflow.get("id")
|
||||
if (
|
||||
isinstance(workflow_id, str)
|
||||
and workflow_id
|
||||
and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
|
||||
):
|
||||
logging.debug(
|
||||
"prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
|
||||
)
|
||||
return {"workflow_id": workflow_id}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def inject_envelope(
|
||||
data: Any,
|
||||
envelope_lookup: Callable[[str], Optional[dict]],
|
||||
) -> Any:
|
||||
"""Return ``data`` with the per-prompt envelope's keys spread onto it.
|
||||
|
||||
``envelope_lookup`` is called with the payload's ``prompt_id`` and is
|
||||
expected to return the registered envelope or ``None``. This keeps
|
||||
the function pure and avoids depending on any specific storage.
|
||||
|
||||
The envelope's keys are merged onto the payload at the top level so
|
||||
consumers can read them directly (e.g. ``event.workflow_id``) —
|
||||
matching the wire shape of the prior workflow-id-on-events work and
|
||||
avoiding an extra nesting hop for clients. Server-emitted fields on
|
||||
the payload always win on collision (``{**envelope, **d}``); a
|
||||
misbehaving client cannot shadow ``prompt_id``, ``node``, etc.
|
||||
|
||||
Two payload shapes are handled:
|
||||
|
||||
- **dict** carrying ``prompt_id``. A shallow copy is returned with
|
||||
the envelope's keys merged onto it.
|
||||
- **(preview_image, metadata_dict) tuple** — the format used by
|
||||
``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
|
||||
the binary preview is passed through by reference.
|
||||
|
||||
No-op for payloads without a ``prompt_id``, prompts with no
|
||||
registered envelope, or any other payload shape.
|
||||
"""
|
||||
def inject(d: dict) -> dict:
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
prompt_id = d.get("prompt_id")
|
||||
if not prompt_id:
|
||||
return d
|
||||
envelope = envelope_lookup(prompt_id)
|
||||
if envelope is None:
|
||||
return d
|
||||
return {**envelope, **d}
|
||||
|
||||
if isinstance(data, dict):
|
||||
return inject(data)
|
||||
if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict):
|
||||
injected = inject(data[1])
|
||||
if injected is data[1]:
|
||||
return data
|
||||
return (data[0], injected)
|
||||
return data
|
||||
|
||||
|
||||
class PromptMetadataStore:
|
||||
"""Bounded ``prompt_id -> envelope`` map.
|
||||
|
||||
Owned by ``PromptServer``. Populated at submission, drained when the
|
||||
prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
|
||||
backstop: if any cleanup hook is ever skipped, the store sheds the
|
||||
oldest entry instead of growing without bound.
|
||||
|
||||
Access is serialized through a ``threading.Lock``. ``register`` runs
|
||||
on the aiohttp event-loop thread, ``unregister`` runs on the
|
||||
``prompt_worker`` thread, and ``inject`` runs on whichever thread
|
||||
fires ``send_sync`` (event loop, worker, asset seeder). Individual
|
||||
``dict`` ops are GIL-atomic, but ``register``'s
|
||||
``len() -> pop -> __setitem__`` and ``inject``'s ``get -> {**a, **b}``
|
||||
are multi-step compounds whose interleaving without a lock is
|
||||
racy. The lock is uncontended in steady state (sub-microsecond
|
||||
critical sections) so the cost is negligible.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
|
||||
self._envelopes: dict[str, dict] = {}
|
||||
self._capacity = capacity
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, prompt_id: str, extra_data: Any) -> None:
|
||||
envelope = extract_envelope_from_extra_data(extra_data)
|
||||
if envelope is None:
|
||||
return
|
||||
with self._lock:
|
||||
if len(self._envelopes) >= self._capacity:
|
||||
self._envelopes.pop(next(iter(self._envelopes)))
|
||||
self._envelopes[prompt_id] = envelope
|
||||
|
||||
def unregister(self, prompt_id: str) -> None:
|
||||
with self._lock:
|
||||
self._envelopes.pop(prompt_id, None)
|
||||
|
||||
def inject(self, data: Any) -> Any:
|
||||
# Snapshot the envelope under the lock so the spread in
|
||||
# ``inject_envelope`` runs against a consistent view even if a
|
||||
# concurrent ``register``/``unregister`` is mutating the map.
|
||||
def locked_lookup(prompt_id: str) -> Optional[dict]:
|
||||
with self._lock:
|
||||
return self._envelopes.get(prompt_id)
|
||||
return inject_envelope(data, locked_lookup)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._envelopes)
|
||||
|
||||
def __contains__(self, prompt_id: str) -> bool:
|
||||
with self._lock:
|
||||
return prompt_id in self._envelopes
|
||||
@ -1296,7 +1296,10 @@ class PromptQueue:
|
||||
|
||||
def wipe_queue(self):
|
||||
with self.mutex:
|
||||
dropped_prompt_ids = [item[1] for item in self.queue]
|
||||
self.queue = []
|
||||
for prompt_id in dropped_prompt_ids:
|
||||
self.server.unregister_prompt_metadata(prompt_id)
|
||||
self.server.queue_updated()
|
||||
|
||||
def delete_queue_item(self, function):
|
||||
@ -1306,8 +1309,9 @@ class PromptQueue:
|
||||
if len(self.queue) == 1:
|
||||
self.wipe_queue()
|
||||
else:
|
||||
self.queue.pop(x)
|
||||
deleted = self.queue.pop(x)
|
||||
heapq.heapify(self.queue)
|
||||
self.server.unregister_prompt_metadata(deleted[1])
|
||||
self.server.queue_updated()
|
||||
return True
|
||||
return False
|
||||
|
||||
29
main.py
29
main.py
@ -318,19 +318,26 @@ def prompt_worker(q, server_instance):
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Always drop the metadata envelope. If e.execute() raises
|
||||
# hard before its own error handling kicks in, the
|
||||
# registered envelope would otherwise leak for the
|
||||
# lifetime of the process.
|
||||
server_instance.unregister_prompt_metadata(prompt_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
30
server.py
30
server.py
@ -44,6 +44,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from app.prompt_metadata import PromptMetadataStore
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -250,8 +251,14 @@ class PromptServer():
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
self.client_id = None
|
||||
|
||||
# Bounded prompt_id -> envelope store. Populated at submission,
|
||||
# drained on completion/cancel. Keeps workflow scope (and other
|
||||
# client-supplied tags) out of the execution layer.
|
||||
self._prompt_metadata = PromptMetadataStore()
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
@ -275,7 +282,12 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
payload: dict = {"node": self.last_node_id}
|
||||
if self.last_prompt_id is not None:
|
||||
payload["prompt_id"] = self.last_prompt_id
|
||||
await self.send(
|
||||
"executing", self._inject_prompt_metadata(payload), sid
|
||||
)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -955,6 +967,7 @@ class PromptServer():
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.register_prompt_metadata(prompt_id, extra_data)
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
@ -1216,7 +1229,22 @@ class PromptServer():
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
def register_prompt_metadata(self, prompt_id: str, extra_data) -> None:
|
||||
"""Record per-prompt metadata for injection into outbound execution
|
||||
events. Called at submission, before the prompt is queued."""
|
||||
self._prompt_metadata.register(prompt_id, extra_data)
|
||||
|
||||
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
||||
"""Drop the per-prompt metadata envelope. Called after the prompt
|
||||
has finished executing and its terminal events have been queued,
|
||||
or when the prompt is cancelled before reaching the worker."""
|
||||
self._prompt_metadata.unregister(prompt_id)
|
||||
|
||||
def _inject_prompt_metadata(self, data):
|
||||
return self._prompt_metadata.inject(data)
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
data = self._inject_prompt_metadata(data)
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.messages.put_nowait, (event, data, sid))
|
||||
|
||||
|
||||
412
tests-unit/app_test/test_prompt_metadata.py
Normal file
412
tests-unit/app_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,412 @@
|
||||
"""Unit tests for the metadata-envelope module in ``app.prompt_metadata``.
|
||||
|
||||
Covers the two pure helpers (``extract_envelope_from_extra_data`` and
|
||||
``inject_envelope``) and the ``PromptMetadataStore`` integration class
|
||||
that ``PromptServer`` owns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.prompt_metadata import (
|
||||
MAX_ENVELOPE_KEYS,
|
||||
MAX_ENVELOPE_KEY_LEN,
|
||||
MAX_ENVELOPE_VALUE_LEN,
|
||||
PromptMetadataStore,
|
||||
extract_envelope_from_extra_data,
|
||||
inject_envelope,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractEnvelopeFromExtraData:
|
||||
def test_explicit_metadata_dict_is_used_as_is(self):
|
||||
extra_data = {"metadata": {"workflow_id": "wf-1", "trace_id": "t-9"}}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-1",
|
||||
"trace_id": "t-9",
|
||||
}
|
||||
|
||||
def test_explicit_metadata_takes_precedence_over_extra_pnginfo(self):
|
||||
extra_data = {
|
||||
"metadata": {"workflow_id": "explicit"},
|
||||
"extra_pnginfo": {"workflow": {"id": "fallback"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "explicit"
|
||||
}
|
||||
|
||||
def test_falls_back_to_extra_pnginfo_workflow_id(self):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-legacy"}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
def test_returns_none_when_no_metadata_and_no_workflow_id(self):
|
||||
assert extract_envelope_from_extra_data({}) is None
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"extra_pnginfo": {"workflow": {}}})
|
||||
is None
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("bad", ["", 123, None, [], {}])
|
||||
def test_rejects_non_string_or_empty_workflow_id(self, bad):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": bad}}}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_rejects_non_dict_inputs_at_each_level(self):
|
||||
assert extract_envelope_from_extra_data(None) is None
|
||||
assert extract_envelope_from_extra_data("not-a-dict") is None
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"extra_pnginfo": "not-a-dict"})
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
extract_envelope_from_extra_data(
|
||||
{"extra_pnginfo": {"workflow": "not-a-dict"}}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_empty_explicit_metadata_falls_through_to_workflow_id(self):
|
||||
extra_data = {
|
||||
"metadata": {},
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
def test_returned_envelope_is_copy_not_reference(self):
|
||||
original = {"workflow_id": "wf-1"}
|
||||
result = extract_envelope_from_extra_data({"metadata": original})
|
||||
assert result is not None
|
||||
result["new_key"] = "x"
|
||||
assert "new_key" not in original
|
||||
|
||||
def test_non_dict_explicit_metadata_falls_through_to_workflow_id(self):
|
||||
extra_data = {
|
||||
"metadata": "not-a-dict",
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) == {
|
||||
"workflow_id": "wf-legacy"
|
||||
}
|
||||
|
||||
|
||||
class TestEnvelopeSanitization:
|
||||
"""The wire contract is ``dict[str, str]`` with bounded size. A bad
|
||||
envelope is dropped (and a warning is logged) rather than truncated,
|
||||
so the boundary stays strict."""
|
||||
|
||||
def test_rejects_too_many_keys(self, caplog):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}
|
||||
with caplog.at_level("WARNING"):
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
assert any("exceeds limit" in r.message for r in caplog.records)
|
||||
|
||||
def test_accepts_max_keys_exactly(self):
|
||||
envelope = {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_rejects_non_string_keys(self, caplog):
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data({"metadata": {42: "v"}})
|
||||
is None
|
||||
)
|
||||
assert any("non-string" in r.message for r in caplog.records)
|
||||
|
||||
def test_rejects_non_string_values(self, caplog):
|
||||
for bad_value in [42, None, ["x"], {"nested": "dict"}, b"bytes"]:
|
||||
with caplog.at_level("WARNING"):
|
||||
assert (
|
||||
extract_envelope_from_extra_data(
|
||||
{"metadata": {"k": bad_value}}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_rejects_oversized_key(self):
|
||||
envelope = {"x" * (MAX_ENVELOPE_KEY_LEN + 1): "v"}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_rejects_oversized_value(self):
|
||||
envelope = {"k": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) is None
|
||||
|
||||
def test_accepts_max_lengths_exactly(self):
|
||||
envelope = {
|
||||
"x" * MAX_ENVELOPE_KEY_LEN: "y" * MAX_ENVELOPE_VALUE_LEN
|
||||
}
|
||||
assert extract_envelope_from_extra_data({"metadata": envelope}) == envelope
|
||||
|
||||
def test_oversized_workflow_id_in_pnginfo_rejected(self):
|
||||
"""The legacy synthesized path also respects the value bound."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {
|
||||
"workflow": {"id": "x" * (MAX_ENVELOPE_VALUE_LEN + 1)}
|
||||
}
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
def test_invalid_explicit_metadata_does_not_fall_through(self):
|
||||
"""An explicit but invalid metadata dict means the caller asked
|
||||
for something specific and got it wrong; the synthesized
|
||||
fallback must not silently substitute."""
|
||||
extra_data = {
|
||||
"metadata": {"k": 42}, # non-string value
|
||||
"extra_pnginfo": {"workflow": {"id": "wf-legacy"}},
|
||||
}
|
||||
assert extract_envelope_from_extra_data(extra_data) is None
|
||||
|
||||
|
||||
class TestInjectEnvelope:
|
||||
@staticmethod
|
||||
def _lookup(table):
|
||||
return table.get
|
||||
|
||||
def test_spreads_envelope_keys_onto_payload(self):
|
||||
"""Envelope keys are merged at the top level so consumers can
|
||||
read them directly (e.g. ``event.workflow_id``)."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1", "trace_id": "t-9"}})
|
||||
assert inject_envelope({"node": "5", "prompt_id": "p1"}, lookup) == {
|
||||
"node": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
"trace_id": "t-9",
|
||||
}
|
||||
|
||||
def test_passthrough_when_prompt_id_not_registered(self):
|
||||
lookup = self._lookup({})
|
||||
data = {"node": "5", "prompt_id": "unknown"}
|
||||
assert inject_envelope(data, lookup) == data
|
||||
|
||||
def test_passthrough_when_payload_lacks_prompt_id(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
data = {"status": "ok"}
|
||||
assert inject_envelope(data, lookup) == data
|
||||
|
||||
def test_server_keys_win_on_collision_with_envelope(self):
|
||||
"""A misbehaving client cannot shadow server-emitted fields by
|
||||
stamping the same key in their submission envelope."""
|
||||
lookup = self._lookup({
|
||||
"p1": {"prompt_id": "client-claimed", "node": "spoofed", "workflow_id": "wf-1"}
|
||||
})
|
||||
result = inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
|
||||
assert result["prompt_id"] == "p1"
|
||||
assert result["node"] == "5"
|
||||
assert result["workflow_id"] == "wf-1"
|
||||
|
||||
def test_does_not_mutate_input_dict(self):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
original = {"node": "5", "prompt_id": "p1"}
|
||||
inject_envelope(original, lookup)
|
||||
assert "workflow_id" not in original
|
||||
|
||||
def test_does_not_mutate_envelope_dict(self):
|
||||
envelope = {"workflow_id": "wf-1"}
|
||||
lookup = self._lookup({"p1": envelope})
|
||||
inject_envelope({"prompt_id": "p1", "node": "5"}, lookup)
|
||||
assert envelope == {"workflow_id": "wf-1"}
|
||||
|
||||
def test_injects_into_inner_dict_of_preview_metadata_tuple(self):
|
||||
"""``PREVIEW_IMAGE_WITH_METADATA`` payloads arrive as
|
||||
``(preview_image, metadata_dict)``; the inner dict is the only
|
||||
place the envelope can attach."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
preview_image = ("PNG", object(), 256)
|
||||
inner = {"node_id": "5", "prompt_id": "p1"}
|
||||
result = inject_envelope((preview_image, inner), lookup)
|
||||
assert isinstance(result, tuple)
|
||||
assert result[0] is preview_image
|
||||
assert result[1] == {
|
||||
"node_id": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
}
|
||||
assert "workflow_id" not in inner
|
||||
|
||||
def test_preview_tuple_passthrough_when_no_envelope_registered(self):
|
||||
lookup = self._lookup({})
|
||||
preview_image = ("PNG", object(), 256)
|
||||
inner = {"node_id": "5", "prompt_id": "unknown"}
|
||||
result = inject_envelope((preview_image, inner), lookup)
|
||||
assert result == (preview_image, inner)
|
||||
|
||||
@pytest.mark.parametrize("payload", [b"raw-bytes", None, 42])
|
||||
def test_non_dict_non_tuple_payloads_passthrough(self, payload):
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
assert inject_envelope(payload, lookup) == payload
|
||||
|
||||
def test_tuple_of_wrong_arity_passthrough(self):
|
||||
"""Only the 2-tuple ``(preview, metadata_dict)`` shape is
|
||||
special-cased. Other tuples must not be touched."""
|
||||
lookup = self._lookup({"p1": {"workflow_id": "wf-1"}})
|
||||
triple = (1, {"prompt_id": "p1"}, 3)
|
||||
assert inject_envelope(triple, lookup) is triple
|
||||
|
||||
def test_envelope_lookup_called_per_invocation(self):
|
||||
"""The lookup runs each time the function is called, so changes
|
||||
to the backing store are immediately visible."""
|
||||
store = {"p1": {"workflow_id": "wf-1"}}
|
||||
first = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
store["p1"] = {"workflow_id": "wf-2"}
|
||||
second = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
del store["p1"]
|
||||
third = inject_envelope({"prompt_id": "p1"}, store.get)
|
||||
assert first["workflow_id"] == "wf-1"
|
||||
assert second["workflow_id"] == "wf-2"
|
||||
assert "workflow_id" not in third
|
||||
|
||||
|
||||
class TestPromptMetadataStore:
|
||||
"""End-to-end wiring tests that exercise the full register/inject/
|
||||
unregister cycle the way ``PromptServer`` does."""
|
||||
|
||||
def test_register_inject_unregister_cycle(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1", {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
|
||||
)
|
||||
injected = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert injected == {
|
||||
"node": "5",
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
}
|
||||
store.unregister("p1")
|
||||
passthrough = store.inject({"node": "5", "prompt_id": "p1"})
|
||||
assert "workflow_id" not in passthrough
|
||||
|
||||
def test_register_with_no_derivable_envelope_is_noop(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {})
|
||||
assert "p1" not in store
|
||||
data = {"prompt_id": "p1"}
|
||||
assert store.inject(data) == data
|
||||
|
||||
def test_register_with_oversized_envelope_is_noop(self):
|
||||
"""Sanitization rejection means nothing is registered — the
|
||||
store stays empty and inject is a passthrough."""
|
||||
store = PromptMetadataStore()
|
||||
store.register(
|
||||
"p1",
|
||||
{"metadata": {f"k{i}": "v" for i in range(MAX_ENVELOPE_KEYS + 1)}},
|
||||
)
|
||||
assert "p1" not in store
|
||||
|
||||
def test_unregister_unknown_prompt_is_silent(self):
|
||||
store = PromptMetadataStore()
|
||||
store.unregister("does-not-exist")
|
||||
|
||||
def test_fifo_eviction_when_capacity_exceeded(self):
|
||||
"""If cleanup hooks are ever bypassed, the store must shed the
|
||||
oldest entry rather than grow without bound."""
|
||||
store = PromptMetadataStore(capacity=3)
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p2", {"metadata": {"workflow_id": "wf-2"}})
|
||||
store.register("p3", {"metadata": {"workflow_id": "wf-3"}})
|
||||
assert len(store) == 3
|
||||
|
||||
store.register("p4", {"metadata": {"workflow_id": "wf-4"}})
|
||||
assert len(store) == 3
|
||||
assert "p1" not in store
|
||||
assert "p4" in store
|
||||
|
||||
# The newer entries are still injectable.
|
||||
assert store.inject({"prompt_id": "p4"})["workflow_id"] == "wf-4"
|
||||
# The evicted one is gone.
|
||||
assert "workflow_id" not in store.inject({"prompt_id": "p1"})
|
||||
|
||||
def test_register_after_unregister_does_not_count_against_capacity(self):
|
||||
"""Normal lifecycle: register, unregister, register many — the
|
||||
store should not silently evict valid entries because of stale
|
||||
accounting."""
|
||||
store = PromptMetadataStore(capacity=2)
|
||||
for i in range(10):
|
||||
store.register(f"p{i}", {"metadata": {"workflow_id": f"wf-{i}"}})
|
||||
store.unregister(f"p{i}")
|
||||
assert len(store) == 0
|
||||
|
||||
def test_re_register_overwrites(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-2"}})
|
||||
assert store.inject({"prompt_id": "p1"})["workflow_id"] == "wf-2"
|
||||
|
||||
def test_inject_with_no_registrations_is_passthrough(self):
|
||||
store = PromptMetadataStore()
|
||||
data = {"prompt_id": "p1", "node": "5"}
|
||||
assert store.inject(data) == data
|
||||
|
||||
def test_inject_into_preview_tuple(self):
|
||||
store = PromptMetadataStore()
|
||||
store.register("p1", {"metadata": {"workflow_id": "wf-1"}})
|
||||
result = store.inject((b"image-bytes", {"prompt_id": "p1"}))
|
||||
assert result == (b"image-bytes", {
|
||||
"prompt_id": "p1",
|
||||
"workflow_id": "wf-1",
|
||||
})
|
||||
|
||||
def test_concurrent_access_does_not_corrupt_or_raise(self):
|
||||
"""Smoke test for the store's lock. ``register`` is called from
|
||||
the aiohttp event-loop thread, ``unregister`` from the worker
|
||||
thread, and ``inject`` fires on every ``send_sync`` from
|
||||
whichever thread emits the event. Run all three concurrently
|
||||
and assert no exception escapes and the store stays internally
|
||||
consistent (the FIFO cap is never exceeded)."""
|
||||
import threading
|
||||
|
||||
store = PromptMetadataStore(capacity=64)
|
||||
stop = threading.Event()
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def registrar():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.register(
|
||||
f"p{i % 100}",
|
||||
{"metadata": {"workflow_id": f"wf-{i}"}},
|
||||
)
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def canceller():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.unregister(f"p{i % 100}")
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
def injector():
|
||||
i = 0
|
||||
try:
|
||||
while not stop.is_set():
|
||||
store.inject({"prompt_id": f"p{i % 100}", "node": "5"})
|
||||
i += 1
|
||||
except BaseException as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=registrar),
|
||||
threading.Thread(target=canceller),
|
||||
threading.Thread(target=injector),
|
||||
threading.Thread(target=injector),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
# Brief burst — long enough to interleave many ops, short enough
|
||||
# not to slow CI.
|
||||
threading.Event().wait(0.1)
|
||||
stop.set()
|
||||
for t in threads:
|
||||
t.join(timeout=2.0)
|
||||
|
||||
assert errors == [], f"concurrent access raised: {errors[:3]}"
|
||||
assert len(store) <= 64, "FIFO cap was breached under contention"
|
||||
Loading…
Reference in New Issue
Block a user