ComfyUI/app/prompt_metadata.py
Deep Mehta 63784baed5 fix(server): serialize PromptMetadataStore access with a lock
Addresses comfyanonymous's review nit on PR #13905. The store is touched
from three threads — the aiohttp event loop (``register`` via
``post_prompt``), the worker thread (``unregister`` via the
``prompt_worker`` try/finally and ``execution_error`` paths), and any
thread that fires ``send_sync`` (``inject``). Individual ``dict``
operations are GIL-atomic but ``register``'s ``len -> pop -> setitem``
and ``inject``'s ``get -> {**a, **b}`` are multi-step compounds whose
interleaving without a lock is racy. A single ``threading.Lock`` keeps
the FIFO cap honest and snapshots the envelope under the lock before
the spread runs.

Adds a stress-test that runs concurrent register/unregister/inject for
100 ms across five threads and asserts no exception escapes and the
capacity bound is held.
2026-05-14 21:26:06 -07:00

227 lines
8.8 KiB
Python

"""Per-prompt metadata envelope shared between submission and outbound events.

The metadata envelope is a small flat ``dict[str, str]`` (e.g.
``{"workflow_id": ...}``) attached to a prompt at submission and injected
by the server into every outbound execution event that carries a
``prompt_id``. It lets consumers scope state by tags they care about
(workflow, trace, tenant) without the execution layer ever needing to
know those tags exist.

This module is intentionally pure — no imports from ``server`` or
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Callable, Optional
# --- Envelope bounds --------------------------------------------------------
# Each execution event for a prompt carries that prompt's envelope to every
# connected WebSocket client. Capping the key count, key length and value
# length, and accepting only flat string values, stops a buggy or malicious
# client from inflating broadcast traffic.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256

# --- Store bound ------------------------------------------------------------
# Upper limit on how many prompt envelopes can be registered at the same time.
# It is a safety net in case a cleanup hook is skipped. Eviction is FIFO, so
# the oldest (most likely stale) entry is dropped first.
DEFAULT_STORE_CAPACITY = 4096
def _sanitize_envelope(envelope: Any) -> Optional[dict]:
    """Check a candidate envelope against the ``dict[str, str]`` contract.

    Downstream consumers (cloud projections, frontend zod schemas, OpenAPI
    docs) depend on this shape, so a candidate is accepted only if:

    - it is a ``dict`` with at least one entry;
    - it has no more than ``MAX_ENVELOPE_KEYS`` entries;
    - every key and every value is a ``str``;
    - every key is at most ``MAX_ENVELOPE_KEY_LEN`` characters long;
    - every value is at most ``MAX_ENVELOPE_VALUE_LEN`` characters long.

    Returns:
        A new shallow-copied ``dict`` when every rule holds, else ``None``.
        Key-count, type and length violations log a warning so abuse is
        visible. A non-dict or empty candidate is rejected silently.
    """
    # Anything that is not a populated dict simply means "no envelope".
    if not (isinstance(envelope, dict) and envelope):
        return None

    key_count = len(envelope)
    if key_count > MAX_ENVELOPE_KEYS:
        logging.warning(
            "prompt metadata envelope rejected: %d keys exceeds limit %d",
            key_count, MAX_ENVELOPE_KEYS,
        )
        return None

    result: dict[str, str] = {}
    for k, v in envelope.items():
        # Nested or non-string data is refused outright rather than coerced.
        if not (isinstance(k, str) and isinstance(v, str)):
            logging.warning(
                "prompt metadata envelope rejected: non-string key/value (%s=%s)",
                type(k).__name__, type(v).__name__,
            )
            return None
        too_long = len(k) > MAX_ENVELOPE_KEY_LEN or len(v) > MAX_ENVELOPE_VALUE_LEN
        if too_long:
            logging.warning(
                "prompt metadata envelope rejected: key or value exceeds length limit",
            )
            return None
        result[k] = v
    return result
def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
    """Return the metadata envelope carried by a prompt's ``extra_data``.

    The envelope is looked for in two places, in this order:

    1. ``extra_data["metadata"]``, validated by ``_sanitize_envelope``.
       A non-empty dict that fails validation is rejected outright
       (``None`` is returned and the legacy source is *not* consulted).
       A ``None``, an empty dict, or any other non-dict value is treated
       as absent, and the lookup moves on to source 2.
    2. ``extra_data["extra_pnginfo"]["workflow"]["id"]``. This is a
       backward-compatibility path for frontends that only stamp the
       workflow id into ``extra_pnginfo``. A non-empty string id of at
       most ``MAX_ENVELOPE_VALUE_LEN`` characters yields
       ``{"workflow_id": <id>}``, and a debug log records the fallback.

    Args:
        extra_data: The submitted ``extra_data``. Non-dict input yields
            ``None``.

    Returns:
        The envelope dict, or ``None`` when no usable envelope is found.
    """
    if not isinstance(extra_data, dict):
        return None

    if "metadata" in extra_data:
        explicit = extra_data["metadata"]
        candidate = _sanitize_envelope(explicit)
        if candidate is not None:
            return candidate
        # The caller asked for specific metadata and got it wrong. Honour
        # the rejection instead of silently substituting the legacy value.
        if isinstance(explicit, dict) and explicit:
            return None

    # Legacy fallback: walk extra_pnginfo -> workflow -> id. Each level
    # must be a dict for the walk to continue.
    pnginfo = extra_data.get("extra_pnginfo")
    workflow = pnginfo.get("workflow") if isinstance(pnginfo, dict) else None
    workflow_id = workflow.get("id") if isinstance(workflow, dict) else None

    if not (
        isinstance(workflow_id, str)
        and 0 < len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
    ):
        return None

    logging.debug(
        "prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
    )
    return {"workflow_id": workflow_id}
def inject_envelope(
    data: Any,
    envelope_lookup: Callable[[str], Optional[dict]],
) -> Any:
    """Merge the per-prompt envelope into an outbound event payload.

    ``envelope_lookup`` receives the payload's ``prompt_id`` and returns the
    registered envelope, or ``None``. Taking it as a callable keeps this
    function independent of any particular storage.

    The envelope's keys are spread onto the top level of the payload, so
    consumers read them directly (e.g. ``event.workflow_id``) with no extra
    nesting. Payload keys take precedence on collision
    (``{**envelope, **payload}``), so an envelope cannot shadow
    server-emitted fields such as ``prompt_id`` or ``node``.

    Supported payload shapes:

    - a ``dict`` carrying a truthy ``prompt_id``: a new shallow-copied
      dict with the envelope merged in is returned;
    - a 2-tuple ``(preview_image, metadata_dict)``, as used by
      ``PREVIEW_IMAGE_WITH_METADATA``: only the dict half is augmented,
      and the preview object is passed through by reference.

    When nothing is merged (no ``prompt_id``, no registered envelope, or
    any other shape), the original ``data`` object itself is returned.
    """

    def merged(payload: dict) -> dict:
        # Returns ``payload`` itself when nothing is merged, so callers can
        # detect the no-op by identity.
        pid = payload.get("prompt_id")
        if pid:
            envelope = envelope_lookup(pid)
            if envelope is not None:
                # Payload last: server-emitted fields win on collision.
                return {**envelope, **payload}
        return payload

    if isinstance(data, dict):
        return merged(data)

    is_preview_pair = (
        isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict)
    )
    if not is_preview_pair:
        return data

    preview, meta = data
    augmented = merged(meta)
    return data if augmented is meta else (preview, augmented)
class PromptMetadataStore:
    """Bounded ``prompt_id -> envelope`` map with serialized access.

    Owned by ``PromptServer``. Populated at submission, drained when the
    prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
    backstop: if any cleanup hook is ever skipped, the store sheds the
    oldest entry instead of growing without bound.

    Access is serialized through a ``threading.Lock``. ``register`` runs
    on the aiohttp event-loop thread, ``unregister`` runs on the
    ``prompt_worker`` thread, and ``inject`` runs on whichever thread
    fires ``send_sync`` (event loop, worker, asset seeder). Individual
    ``dict`` ops are GIL-atomic, but ``register``'s evict-then-insert
    sequence is a multi-step compound whose interleaving without a lock
    is racy. Each critical section is only a few dict operations, so the
    lock is cheap.
    """

    def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
        # Plain dicts preserve insertion order, so the first key is always
        # the oldest entry; ``register`` relies on that for FIFO eviction.
        self._envelopes: dict[str, dict] = {}
        # Maximum number of stored envelopes. A value <= 0 stores nothing.
        self._capacity = capacity
        self._lock = threading.Lock()

    def register(self, prompt_id: str, extra_data: Any) -> None:
        """Record the envelope found in ``extra_data`` for ``prompt_id``.

        No-op when ``extra_data`` yields no usable envelope. Re-registering
        an id replaces its envelope, moves it to the newest position, and
        never evicts another entry. A new id evicts the oldest entries
        until there is room under the capacity.
        """
        # Extraction/validation is pure, so keep it outside the lock.
        envelope = extract_envelope_from_extra_data(extra_data)
        if envelope is None:
            return
        with self._lock:
            # Drop any previous entry for this id first. Overwriting an
            # existing key does not grow the map, so it must not push out
            # an unrelated live entry.
            self._envelopes.pop(prompt_id, None)
            # The emptiness check stops ``next(iter(...))`` from raising
            # StopIteration when the capacity is non-positive.
            while self._envelopes and len(self._envelopes) >= self._capacity:
                self._envelopes.pop(next(iter(self._envelopes)))
            if len(self._envelopes) < self._capacity:
                self._envelopes[prompt_id] = envelope

    def unregister(self, prompt_id: str) -> None:
        """Forget the envelope for ``prompt_id``. Unknown ids are ignored."""
        with self._lock:
            self._envelopes.pop(prompt_id, None)

    def inject(self, data: Any) -> Any:
        """Return ``data`` with its prompt's registered envelope merged in.

        Delegates to ``inject_envelope``, using this store as the lookup.
        """

        def locked_lookup(prompt_id: str) -> Optional[dict]:
            # Only the lookup needs the lock. Stored envelopes are fresh
            # dicts that nothing mutates after ``register`` inserts them,
            # so the returned reference can be merged outside the lock.
            with self._lock:
                return self._envelopes.get(prompt_id)

        return inject_envelope(data, locked_lookup)

    def __len__(self) -> int:
        with self._lock:
            return len(self._envelopes)

    def __contains__(self, prompt_id: str) -> bool:
        with self._lock:
            return prompt_id in self._envelopes