"""Per-prompt metadata envelope shared between submission and outbound events.

The metadata envelope is a small flat ``dict[str, str]`` (e.g.
``{"workflow_id": ...}``) attached to a prompt at submission and injected by
the server into every outbound execution event that carries a ``prompt_id``.
It lets consumers scope state by tags they care about (workflow, trace,
tenant) without the execution layer ever needing to know those tags exist.

This module is intentionally pure — no imports from ``server`` or
``execution`` — so ``PromptServer`` can own a ``PromptMetadataStore``
instance and the helpers can be unit-tested without the rest of the app.
"""

from __future__ import annotations

import logging
from typing import Any, Callable, Optional

# Bounds. The envelope is forwarded to every WebSocket client connected to
# the server on every execution event for the prompt — bounding key count,
# key length, value length, and refusing nested structures keeps a
# malicious or buggy client from inflating the broadcast volume.
MAX_ENVELOPE_KEYS = 16
MAX_ENVELOPE_KEY_LEN = 64
MAX_ENVELOPE_VALUE_LEN = 256

# Cap on concurrently registered prompt envelopes. Acts as a backstop if
# the cleanup hook is ever bypassed; FIFO eviction so the oldest stale
# entry goes first.
DEFAULT_STORE_CAPACITY = 4096


def _sanitize_envelope(envelope: Any) -> Optional[dict]:
    """Validate and copy a candidate envelope.

    Enforces the ``dict[str, str]`` contract that downstream consumers
    (cloud projections, frontend zod schemas, OpenAPI docs) rely on:

    - must be a non-empty ``dict``
    - at most ``MAX_ENVELOPE_KEYS`` entries
    - every key and value must be a ``str``
    - keys at most ``MAX_ENVELOPE_KEY_LEN`` chars
    - values at most ``MAX_ENVELOPE_VALUE_LEN`` chars

    Returns a defensive shallow copy on success, ``None`` on any violation.
    Size, type and length violations log a warning so abuse is visible; a
    non-dict or empty input is rejected silently.
    """
    if not isinstance(envelope, dict) or not envelope:
        return None

    if len(envelope) > MAX_ENVELOPE_KEYS:
        logging.warning(
            "prompt metadata envelope rejected: %d keys exceeds limit %d",
            len(envelope),
            MAX_ENVELOPE_KEYS,
        )
        return None

    sanitized: dict[str, str] = {}
    for key, value in envelope.items():
        if not isinstance(key, str) or not isinstance(value, str):
            # Only the type names are logged, never the offending values.
            logging.warning(
                "prompt metadata envelope rejected: non-string key/value (%s=%s)",
                type(key).__name__,
                type(value).__name__,
            )
            return None
        if len(key) > MAX_ENVELOPE_KEY_LEN or len(value) > MAX_ENVELOPE_VALUE_LEN:
            logging.warning(
                "prompt metadata envelope rejected: key or value exceeds length limit",
            )
            return None
        sanitized[key] = value
    return sanitized


def extract_envelope_from_extra_data(extra_data: Any) -> Optional[dict]:
    """Pull the per-prompt metadata envelope out of a prompt's ``extra_data``.

    Two sources, in order:

    1. Explicit ``extra_data["metadata"]`` — sanitized via
       ``_sanitize_envelope``. Oversized or wrong-typed envelopes are
       rejected (a warning is logged) rather than truncated, so the
       contract stays strict at the boundary.
    2. ``extra_data["extra_pnginfo"]["workflow"]["id"]`` — backward-
       compatibility fallback. Frontends that already stamp the workflow
       id into ``extra_pnginfo`` keep working; the synthesized envelope is
       ``{"workflow_id": <id>}``. A debug log fires so the legacy path
       remains observable.

    Returns ``None`` when neither source yields a usable envelope.
    """
    if not isinstance(extra_data, dict):
        return None

    if "metadata" in extra_data:
        sanitized = _sanitize_envelope(extra_data["metadata"])
        if sanitized is not None:
            return sanitized
        # Explicit metadata was supplied but rejected — do not fall
        # through to the legacy path; the caller asked for something
        # specific and got it wrong. A non-dict or empty ``metadata``
        # value is treated as "not supplied" and still falls through.
        if isinstance(extra_data["metadata"], dict) and extra_data["metadata"]:
            return None

    extra_pnginfo = extra_data.get("extra_pnginfo")
    if isinstance(extra_pnginfo, dict):
        workflow = extra_pnginfo.get("workflow")
        if isinstance(workflow, dict):
            workflow_id = workflow.get("id")
            if (
                isinstance(workflow_id, str)
                and workflow_id
                and len(workflow_id) <= MAX_ENVELOPE_VALUE_LEN
            ):
                logging.debug(
                    "prompt metadata envelope synthesized from extra_pnginfo.workflow.id"
                )
                return {"workflow_id": workflow_id}
    return None


def inject_envelope(
    data: Any,
    envelope_lookup: Callable[[str], Optional[dict]],
) -> Any:
    """Return ``data`` with a per-prompt ``metadata`` envelope attached.

    ``envelope_lookup`` is called with the payload's ``prompt_id`` and is
    expected to return the registered envelope or ``None``. This keeps the
    function pure and avoids depending on any specific storage.

    Two payload shapes are handled:

    - **dict** carrying ``prompt_id``. A shallow copy is returned with a
      ``metadata`` key set to a copy of the envelope.
    - **(preview_image, metadata_dict) tuple** — the format used by
      ``PREVIEW_IMAGE_WITH_METADATA``. Only the inner dict is augmented;
      the binary preview is passed through by reference.

    No-op (the very same object is returned) for payloads without a
    ``prompt_id``, payloads already declaring their own ``metadata`` field,
    prompts with no registered envelope, or any other payload shape.

    The attached envelope is always a fresh copy, so a consumer that
    mutates ``payload["metadata"]`` cannot corrupt the registered envelope
    that later events for the same prompt will carry.
    """

    def _attach(payload: dict) -> dict:
        # A payload that already declares its own metadata wins.
        if "metadata" in payload:
            return payload
        prompt_id = payload.get("prompt_id")
        if not prompt_id:
            return payload
        envelope = envelope_lookup(prompt_id)
        if envelope is None:
            return payload
        # Copy the envelope so the stored instance is never shared with
        # (and mutable by) downstream consumers.
        return {**payload, "metadata": dict(envelope)}

    if isinstance(data, dict):
        return _attach(data)

    if isinstance(data, tuple) and len(data) == 2 and isinstance(data[1], dict):
        augmented = _attach(data[1])
        # Identity check: hand back the original tuple when nothing changed.
        if augmented is data[1]:
            return data
        return (data[0], augmented)

    return data


class PromptMetadataStore:
    """Bounded ``prompt_id -> envelope`` map.

    Owned by ``PromptServer``. Populated at submission, drained when the
    prompt finishes, wiped on queue cancel/delete. The FIFO cap is a
    backstop: if any cleanup hook is ever skipped, the store sheds the
    oldest entry instead of growing without bound.
    """

    def __init__(self, capacity: int = DEFAULT_STORE_CAPACITY):
        # Plain dicts preserve insertion order, which is what provides the
        # FIFO eviction order.
        self._envelopes: dict[str, dict] = {}
        self._capacity = capacity

    def register(self, prompt_id: str, extra_data: Any) -> None:
        """Extract the envelope from ``extra_data`` and store it.

        Does nothing when ``extra_data`` yields no usable envelope, or when
        the store was created with a non-positive capacity. Re-registering
        an id that is already present overwrites its envelope in place and
        never evicts another entry.
        """
        envelope = extract_envelope_from_extra_data(extra_data)
        if envelope is None:
            return
        if self._capacity <= 0:
            # Nothing can be retained; also avoids popping from an empty map.
            return
        if prompt_id not in self._envelopes:
            # Only a genuinely new key grows the map, so only then evict.
            while len(self._envelopes) >= self._capacity:
                del self._envelopes[next(iter(self._envelopes))]
        self._envelopes[prompt_id] = envelope

    def unregister(self, prompt_id: str) -> None:
        """Forget the envelope for ``prompt_id``; unknown ids are ignored."""
        self._envelopes.pop(prompt_id, None)

    def inject(self, data: Any) -> Any:
        """Attach the registered envelope to ``data`` via ``inject_envelope``."""
        return inject_envelope(data, self._envelopes.get)

    def __len__(self) -> int:
        return len(self._envelopes)

    def __contains__(self, prompt_id: str) -> bool:
        return prompt_id in self._envelopes