mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 14:29:33 +08:00
Merge 85a12d0a83 into 6b61918a16
This commit is contained in:
commit
d3a95d2bc9
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||||
|
"""Extract the workflow id from a prompt's ``extra_data``.
|
||||||
|
|
||||||
|
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
|
||||||
|
when a prompt is queued. Any value that is not a non-empty string is treated as
|
||||||
|
missing so callers can rely on the return being either ``None`` or a string.
|
||||||
|
"""
|
||||||
|
if not isinstance(extra_data, dict):
|
||||||
|
return None
|
||||||
|
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||||
|
if not isinstance(extra_pnginfo, dict):
|
||||||
|
return None
|
||||||
|
workflow = extra_pnginfo.get('workflow')
|
||||||
|
if not isinstance(workflow, dict):
|
||||||
|
return None
|
||||||
|
workflow_id = workflow.get('id')
|
||||||
|
if isinstance(workflow_id, str) and workflow_id:
|
||||||
|
return workflow_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||||
"""Extract create_time and workflow_id from extra_data.
|
"""Extract create_time and workflow_id from extra_data.
|
||||||
|
|
||||||
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
|||||||
tuple: (create_time, workflow_id)
|
tuple: (create_time, workflow_id)
|
||||||
"""
|
"""
|
||||||
create_time = extra_data.get('create_time')
|
create_time = extra_data.get('create_time')
|
||||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
workflow_id = extract_workflow_id(extra_data)
|
||||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
|
||||||
return create_time, workflow_id
|
return create_time, workflow_id
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
86
comfy_execution/metadata.py
Normal file
86
comfy_execution/metadata.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""Per-prompt metadata propagated alongside execution WebSocket events.
|
||||||
|
|
||||||
|
The execution layer (``execution.py``) is intentionally kept agnostic of
|
||||||
|
workflow-level concepts. ``PromptServer`` registers metadata at submission time
|
||||||
|
and merges it onto outgoing WebSocket payloads in ``send_sync``. Today only
|
||||||
|
``workflow_id`` is propagated; the structure is a ``TypedDict`` so additional
|
||||||
|
keys can be added without churn at the call sites.
|
||||||
|
|
||||||
|
Identity model — registry is keyed by an internal monotonic token, NOT by
|
||||||
|
``prompt_id``. ``post_prompt`` accepts a client-supplied ``prompt_id``
|
||||||
|
verbatim, so two prompts can be queued with the same id and a registry keyed
|
||||||
|
only by ``prompt_id`` would misattribute events when queue order differs from
|
||||||
|
registration order. ``main.py``'s queue worker pins the active token on the
|
||||||
|
server for the duration of one prompt's execution and ``send_sync`` reads that
|
||||||
|
token — so each prompt's events get its own registered metadata regardless of
|
||||||
|
``prompt_id`` collisions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
|
from comfy_execution.jobs import extract_workflow_id
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_METADATA_TOKEN_KEY = "__prompt_metadata_token"
|
||||||
|
|
||||||
|
|
||||||
|
class PromptMetadata(TypedDict, total=False):
|
||||||
|
workflow_id: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt_metadata(extra_data: Optional[dict]) -> PromptMetadata:
|
||||||
|
"""Build a ``PromptMetadata`` snapshot from a prompt's ``extra_data``.
|
||||||
|
|
||||||
|
Returns an empty dict when no recognized metadata is present so callers can
|
||||||
|
skip registering anything in that case.
|
||||||
|
"""
|
||||||
|
meta: PromptMetadata = {}
|
||||||
|
workflow_id = extract_workflow_id(extra_data)
|
||||||
|
if workflow_id is not None:
|
||||||
|
meta["workflow_id"] = workflow_id
|
||||||
|
return meta
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_progress_text_sid(sid, default_sid):
|
||||||
|
"""Pick the recipient for a ``send_progress_text`` binary frame.
|
||||||
|
|
||||||
|
Returns ``default_sid`` (typically ``PromptServer.client_id`` — the client
|
||||||
|
that submitted the active prompt) when the caller didn't pin a specific
|
||||||
|
socket. This narrows the audience for text status updates from "every
|
||||||
|
connected client" to "the client running this prompt", matching the
|
||||||
|
cross-client isolation other execution events already have.
|
||||||
|
|
||||||
|
Splitting this out keeps the unit test independent of the full ``server``
|
||||||
|
import chain.
|
||||||
|
"""
|
||||||
|
return default_sid if sid is None else sid
|
||||||
|
|
||||||
|
|
||||||
|
def merge_prompt_metadata(
|
||||||
|
registry: dict,
|
||||||
|
lock: threading.Lock,
|
||||||
|
active_token: Optional[int],
|
||||||
|
data,
|
||||||
|
):
|
||||||
|
"""Return ``data`` with the metadata for the currently-active token merged
|
||||||
|
top-level. The event payload wins on conflict, and non-dict payloads (e.g.
|
||||||
|
the binary preview tuple) pass through untouched.
|
||||||
|
|
||||||
|
The active token is set by ``main.py``'s queue worker around each
|
||||||
|
``e.execute(...)``. The merge happens only when a token is currently active
|
||||||
|
*and* the payload carries a ``prompt_id`` — using ``prompt_id`` purely as a
|
||||||
|
marker that this is an execution event meant to receive metadata, not as
|
||||||
|
the lookup key. This makes the merge immune to ``prompt_id`` collisions.
|
||||||
|
"""
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
if not data.get("prompt_id"):
|
||||||
|
return data
|
||||||
|
if active_token is None:
|
||||||
|
return data
|
||||||
|
with lock:
|
||||||
|
meta = registry.get(active_token)
|
||||||
|
if not meta:
|
||||||
|
return data
|
||||||
|
return {**meta, **data}
|
||||||
@ -159,11 +159,19 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
def set_registry(self, registry: "ProgressRegistry"):
|
def set_registry(self, registry: "ProgressRegistry"):
|
||||||
self.registry = registry
|
self.registry = registry
|
||||||
|
|
||||||
|
def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]:
|
||||||
|
get_meta = getattr(self.server_instance, "get_active_prompt_metadata", None)
|
||||||
|
if get_meta is None:
|
||||||
|
return None
|
||||||
|
return get_meta().get("workflow_id")
|
||||||
|
|
||||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||||
"""Send the current progress state to the client"""
|
"""Send the current progress state to the client"""
|
||||||
if self.server_instance is None:
|
if self.server_instance is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
workflow_id = self._lookup_workflow_id(prompt_id)
|
||||||
|
|
||||||
# Only send info for non-pending nodes
|
# Only send info for non-pending nodes
|
||||||
active_nodes = {
|
active_nodes = {
|
||||||
node_id: {
|
node_id: {
|
||||||
@ -172,6 +180,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
"state": state["state"].value,
|
"state": state["state"].value,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": workflow_id,
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||||
@ -181,7 +190,10 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Send a combined progress_state message with all node states
|
# Send a combined progress_state message with all node states
|
||||||
# Include client_id to ensure message is only sent to the initiating client
|
# Include client_id to ensure message is only sent to the initiating client.
|
||||||
|
# The outer ``workflow_id`` is merged in by ``PromptServer.send_sync`` via
|
||||||
|
# the per-prompt metadata registry; the nested copy on each node entry
|
||||||
|
# mirrors the wire shape consumed by the frontend.
|
||||||
self.server_instance.send_sync(
|
self.server_instance.send_sync(
|
||||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||||
)
|
)
|
||||||
@ -215,6 +227,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
metadata = {
|
metadata = {
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self._lookup_workflow_id(prompt_id),
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||||
node_id
|
node_id
|
||||||
),
|
),
|
||||||
|
|||||||
13
execution.py
13
execution.py
@ -36,6 +36,7 @@ from comfy_execution.graph import (
|
|||||||
get_input_info,
|
get_input_info,
|
||||||
)
|
)
|
||||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
|
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
@ -1294,20 +1295,32 @@ class PromptQueue:
|
|||||||
with self.mutex:
|
with self.mutex:
|
||||||
return len(self.queue) + len(self.currently_running)
|
return len(self.queue) + len(self.currently_running)
|
||||||
|
|
||||||
|
def _extract_metadata_token(self, item):
|
||||||
|
# Queue item shape: (number, prompt_id, prompt, extra_data, outputs, sensitive)
|
||||||
|
extra_data = item[3] if len(item) > 3 and isinstance(item[3], dict) else None
|
||||||
|
if not extra_data:
|
||||||
|
return None
|
||||||
|
return extra_data.get(PROMPT_METADATA_TOKEN_KEY)
|
||||||
|
|
||||||
def wipe_queue(self):
|
def wipe_queue(self):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
|
cancelled_tokens = [self._extract_metadata_token(item) for item in self.queue]
|
||||||
self.queue = []
|
self.queue = []
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
for token in cancelled_tokens:
|
||||||
|
self.server.unregister_prompt_metadata(token)
|
||||||
|
|
||||||
def delete_queue_item(self, function):
|
def delete_queue_item(self, function):
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
for x in range(len(self.queue)):
|
for x in range(len(self.queue)):
|
||||||
if function(self.queue[x]):
|
if function(self.queue[x]):
|
||||||
|
cancelled_token = self._extract_metadata_token(self.queue[x])
|
||||||
if len(self.queue) == 1:
|
if len(self.queue) == 1:
|
||||||
self.wipe_queue()
|
self.wipe_queue()
|
||||||
else:
|
else:
|
||||||
self.queue.pop(x)
|
self.queue.pop(x)
|
||||||
heapq.heapify(self.queue)
|
heapq.heapify(self.queue)
|
||||||
|
self.server.unregister_prompt_metadata(cancelled_token)
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
37
main.py
37
main.py
@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types
|
|||||||
import faulthandler
|
import faulthandler
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
|
||||||
from comfy_execution.progress import get_progress_state
|
from comfy_execution.progress import get_progress_state
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
@ -317,20 +318,34 @@ def prompt_worker(q, server_instance):
|
|||||||
for k in sensitive:
|
for k in sensitive:
|
||||||
extra_data[k] = sensitive[k]
|
extra_data[k] = sensitive[k]
|
||||||
|
|
||||||
|
# Pin the metadata token registered for this exact prompt while it
|
||||||
|
# runs so ``send_sync`` can decorate its frames with the right
|
||||||
|
# ``workflow_id`` even if another submission shares the same
|
||||||
|
# ``prompt_id``.
|
||||||
|
metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None)
|
||||||
|
server_instance.active_prompt_metadata_token = metadata_token
|
||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
try:
|
||||||
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
q.task_done(item_id,
|
q.task_done(item_id,
|
||||||
e.history_result,
|
e.history_result,
|
||||||
status=execution.PromptQueue.ExecutionStatus(
|
status=execution.PromptQueue.ExecutionStatus(
|
||||||
status_str='success' if e.success else 'error',
|
status_str='success' if e.success else 'error',
|
||||||
completed=e.success,
|
completed=e.success,
|
||||||
messages=e.status_messages), process_item=remove_sensitive)
|
messages=e.status_messages), process_item=remove_sensitive)
|
||||||
if server_instance.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||||
|
finally:
|
||||||
|
# Drop the per-prompt metadata only AFTER the terminal "executing"
|
||||||
|
# send so the registered workflow_id is merged onto that frame.
|
||||||
|
# This is what eliminates the #13684 finally-clear race.
|
||||||
|
server_instance.active_prompt_metadata_token = None
|
||||||
|
server_instance.unregister_prompt_metadata(metadata_token)
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
execution_time = current_time - execution_start_time
|
execution_time = current_time - execution_start_time
|
||||||
|
|||||||
82
server.py
82
server.py
@ -8,7 +8,15 @@ import time
|
|||||||
import nodes
|
import nodes
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import execution
|
import execution
|
||||||
|
import threading
|
||||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||||
|
from comfy_execution.metadata import (
|
||||||
|
PROMPT_METADATA_TOKEN_KEY,
|
||||||
|
PromptMetadata,
|
||||||
|
build_prompt_metadata,
|
||||||
|
merge_prompt_metadata,
|
||||||
|
resolve_progress_text_sid,
|
||||||
|
)
|
||||||
import uuid
|
import uuid
|
||||||
import urllib
|
import urllib
|
||||||
import json
|
import json
|
||||||
@ -252,6 +260,15 @@ class PromptServer():
|
|||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
self.client_id = None
|
self.client_id = None
|
||||||
|
|
||||||
|
# Keyed by an internal monotonic token rather than by ``prompt_id``
|
||||||
|
# because clients can supply the same ``prompt_id`` on retry/dedupe.
|
||||||
|
self.prompt_metadata: dict[int, PromptMetadata] = {}
|
||||||
|
self._prompt_metadata_lock = threading.Lock()
|
||||||
|
self._prompt_metadata_counter = 0
|
||||||
|
# Set by ``main.py``'s queue worker for the duration of one prompt's
|
||||||
|
# execution; read by ``send_sync`` to merge the right metadata.
|
||||||
|
self.active_prompt_metadata_token: Optional[int] = None
|
||||||
|
|
||||||
self.on_prompt_handlers = []
|
self.on_prompt_handlers = []
|
||||||
|
|
||||||
@routes.get('/ws')
|
@routes.get('/ws')
|
||||||
@ -275,7 +292,12 @@ class PromptServer():
|
|||||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||||
# On reconnect if we are the currently executing client send the current node
|
# On reconnect if we are the currently executing client send the current node
|
||||||
if self.client_id == sid and self.last_node_id is not None:
|
if self.client_id == sid and self.last_node_id is not None:
|
||||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
last_prompt_id = getattr(self, "last_prompt_id", None)
|
||||||
|
payload: dict = {"node": self.last_node_id}
|
||||||
|
if last_prompt_id:
|
||||||
|
payload["prompt_id"] = last_prompt_id
|
||||||
|
payload.update(self.get_active_prompt_metadata())
|
||||||
|
await self.send("executing", payload, sid)
|
||||||
|
|
||||||
# Flag to track if we've received the first message
|
# Flag to track if we've received the first message
|
||||||
first_message = True
|
first_message = True
|
||||||
@ -955,6 +977,9 @@ class PromptServer():
|
|||||||
if sensitive_val in extra_data:
|
if sensitive_val in extra_data:
|
||||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||||
|
token = self.register_prompt_metadata(extra_data)
|
||||||
|
if token is not None:
|
||||||
|
extra_data[PROMPT_METADATA_TOKEN_KEY] = token
|
||||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||||
return web.json_response(response)
|
return web.json_response(response)
|
||||||
@ -1216,7 +1241,50 @@ class PromptServer():
|
|||||||
elif sid in self.sockets:
|
elif sid in self.sockets:
|
||||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||||
|
|
||||||
|
def register_prompt_metadata(self, extra_data: dict) -> Optional[int]:
|
||||||
|
"""Capture per-prompt metadata at submission time and return a token
|
||||||
|
identifying this registration.
|
||||||
|
|
||||||
|
Returns ``None`` when there is no recognized metadata, signalling that
|
||||||
|
no token needs to be threaded through the queue. Otherwise the token
|
||||||
|
must be stored on the queue item (typically via
|
||||||
|
:data:`PROMPT_METADATA_TOKEN_KEY` in ``extra_data``) and pinned on the
|
||||||
|
server as ``active_prompt_metadata_token`` while the prompt runs, so
|
||||||
|
the merge in ``send_sync`` picks up this prompt's metadata even when
|
||||||
|
another prompt is registered under the same ``prompt_id``.
|
||||||
|
"""
|
||||||
|
meta = build_prompt_metadata(extra_data)
|
||||||
|
if not meta:
|
||||||
|
return None
|
||||||
|
with self._prompt_metadata_lock:
|
||||||
|
self._prompt_metadata_counter += 1
|
||||||
|
token = self._prompt_metadata_counter
|
||||||
|
self.prompt_metadata[token] = meta
|
||||||
|
return token
|
||||||
|
|
||||||
|
def unregister_prompt_metadata(self, token: Optional[int]) -> None:
|
||||||
|
if token is None:
|
||||||
|
return
|
||||||
|
with self._prompt_metadata_lock:
|
||||||
|
self.prompt_metadata.pop(token, None)
|
||||||
|
|
||||||
|
def get_prompt_metadata_by_token(self, token: Optional[int]) -> PromptMetadata:
|
||||||
|
if token is None:
|
||||||
|
return {}
|
||||||
|
with self._prompt_metadata_lock:
|
||||||
|
return dict(self.prompt_metadata.get(token, {}))
|
||||||
|
|
||||||
|
def get_active_prompt_metadata(self) -> PromptMetadata:
|
||||||
|
"""Snapshot of the metadata for the currently-executing prompt."""
|
||||||
|
return self.get_prompt_metadata_by_token(self.active_prompt_metadata_token)
|
||||||
|
|
||||||
def send_sync(self, event, data, sid=None):
|
def send_sync(self, event, data, sid=None):
|
||||||
|
data = merge_prompt_metadata(
|
||||||
|
self.prompt_metadata,
|
||||||
|
self._prompt_metadata_lock,
|
||||||
|
self.active_prompt_metadata_token,
|
||||||
|
data,
|
||||||
|
)
|
||||||
self.loop.call_soon_threadsafe(
|
self.loop.call_soon_threadsafe(
|
||||||
self.messages.put_nowait, (event, data, sid))
|
self.messages.put_nowait, (event, data, sid))
|
||||||
|
|
||||||
@ -1285,7 +1353,10 @@ class PromptServer():
|
|||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
def send_progress_text(
|
def send_progress_text(
|
||||||
self, text: Union[bytes, bytearray, str], node_id: str, sid=None
|
self,
|
||||||
|
text: Union[bytes, bytearray, str],
|
||||||
|
node_id: str,
|
||||||
|
sid=None,
|
||||||
):
|
):
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
@ -1294,4 +1365,11 @@ class PromptServer():
|
|||||||
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
# Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes
|
||||||
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text
|
||||||
|
|
||||||
|
# Default routing to the active prompt's client so other clients don't
|
||||||
|
# silently receive untagged text frames. The binary wire format does
|
||||||
|
# not yet carry prompt_id/workflow_id, so cross-tab filtering inside a
|
||||||
|
# single client still depends on a follow-up wire-format change with a
|
||||||
|
# feature flag.
|
||||||
|
sid = resolve_progress_text_sid(sid, self.client_id)
|
||||||
|
|
||||||
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
self.send_sync(BinaryEventTypes.TEXT, message, sid)
|
||||||
|
|||||||
182
tests-unit/server_test/test_prompt_metadata.py
Normal file
182
tests-unit/server_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
"""Tests for the per-prompt metadata registry used to propagate ``workflow_id``
|
||||||
|
through WebSocket events without coupling ``execution.py`` to workflow-level
|
||||||
|
concepts.
|
||||||
|
|
||||||
|
The registry is keyed by an internal monotonic token (NOT by ``prompt_id``)
|
||||||
|
because ``post_prompt`` accepts a client-supplied ``prompt_id`` verbatim and
|
||||||
|
two prompts can share an id. ``main.py``'s queue worker pins the active token
|
||||||
|
on the server around each ``e.execute(...)`` and the merge in ``send_sync``
|
||||||
|
reads that pinned token, so each prompt's events get its own metadata
|
||||||
|
regardless of ``prompt_id`` collisions or queue-vs-stack ordering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy_execution.metadata import (
|
||||||
|
build_prompt_metadata,
|
||||||
|
merge_prompt_metadata,
|
||||||
|
resolve_progress_text_sid,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def registry():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lock():
|
||||||
|
return threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildPromptMetadata:
|
||||||
|
def test_returns_workflow_id_when_present(self):
|
||||||
|
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
|
||||||
|
assert build_prompt_metadata(extra_data) == {"workflow_id": "wf-1"}
|
||||||
|
|
||||||
|
def test_empty_when_workflow_id_missing(self):
|
||||||
|
assert build_prompt_metadata({}) == {}
|
||||||
|
assert build_prompt_metadata({"extra_pnginfo": {}}) == {}
|
||||||
|
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {}}}) == {}
|
||||||
|
|
||||||
|
def test_empty_when_workflow_id_not_a_non_empty_string(self):
|
||||||
|
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": ""}}}) == {}
|
||||||
|
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": 42}}}) == {}
|
||||||
|
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": None}}}) == {}
|
||||||
|
|
||||||
|
def test_empty_on_non_dict_input(self):
|
||||||
|
assert build_prompt_metadata(None) == {}
|
||||||
|
assert build_prompt_metadata("not a dict") == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeMetadata:
|
||||||
|
"""``merge_prompt_metadata`` decorates execution events with the metadata
|
||||||
|
for the currently-active token. Event payload fields win on conflict,
|
||||||
|
binary payloads pass through, and the merge is gated on a ``prompt_id``
|
||||||
|
marker to avoid decorating server-status events like ``status`` /
|
||||||
|
``queue_updated``."""
|
||||||
|
|
||||||
|
def test_merges_for_active_token_when_payload_has_prompt_id(self, registry, lock):
|
||||||
|
registry[42] = {"workflow_id": "wf-1"}
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 42, {"node": "n1", "prompt_id": "p1"})
|
||||||
|
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||||
|
|
||||||
|
def test_passthrough_when_no_active_token(self, registry, lock):
|
||||||
|
registry[42] = {"workflow_id": "wf-1"}
|
||||||
|
merged = merge_prompt_metadata(registry, lock, None, {"node": "n1", "prompt_id": "p1"})
|
||||||
|
assert merged == {"node": "n1", "prompt_id": "p1"}
|
||||||
|
|
||||||
|
def test_passthrough_when_active_token_unknown(self, registry, lock):
|
||||||
|
# Token was unregistered already (or never registered) — merge is a no-op.
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 99, {"prompt_id": "p1"})
|
||||||
|
assert merged == {"prompt_id": "p1"}
|
||||||
|
|
||||||
|
def test_passthrough_when_no_prompt_id(self, registry, lock):
|
||||||
|
# Server-status frames (status, queue_updated, etc.) carry no prompt_id
|
||||||
|
# and must not be decorated.
|
||||||
|
registry[42] = {"workflow_id": "wf-1"}
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 42, {"status": {"queue_remaining": 0}})
|
||||||
|
assert merged == {"status": {"queue_remaining": 0}}
|
||||||
|
|
||||||
|
def test_passthrough_for_non_dict_payload(self, registry, lock):
|
||||||
|
registry[42] = {"workflow_id": "wf-1"}
|
||||||
|
binary = (b"image-bytes", {"prompt_id": "p1"})
|
||||||
|
assert merge_prompt_metadata(registry, lock, 42, binary) is binary
|
||||||
|
|
||||||
|
def test_event_payload_wins_over_registered_metadata(self, registry, lock):
|
||||||
|
registry[42] = {"workflow_id": "wf-registered"}
|
||||||
|
merged = merge_prompt_metadata(
|
||||||
|
registry, lock, 42, {"prompt_id": "p1", "workflow_id": "wf-caller"}
|
||||||
|
)
|
||||||
|
assert merged["workflow_id"] == "wf-caller"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressTextSidResolution:
|
||||||
|
"""``BinaryEventTypes.TEXT`` frames don't yet carry ``prompt_id`` /
|
||||||
|
``workflow_id`` in their wire shape, so cross-client routing has to happen
|
||||||
|
at the ``sid`` level. The default sid pins the broadcast to the active
|
||||||
|
prompt's client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_explicit_sid_passes_through(self):
|
||||||
|
assert resolve_progress_text_sid("client-explicit", "client-active") == "client-explicit"
|
||||||
|
|
||||||
|
def test_none_sid_defaults_to_active_client(self):
|
||||||
|
assert resolve_progress_text_sid(None, "client-active") == "client-active"
|
||||||
|
|
||||||
|
def test_none_sid_with_no_active_client_stays_none(self):
|
||||||
|
assert resolve_progress_text_sid(None, None) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptIdCollisionWithTokens:
|
||||||
|
"""Two prompts can be queued with the same client-supplied ``prompt_id``.
|
||||||
|
With a registry keyed by ``prompt_id`` the second registration would
|
||||||
|
overwrite the first or be erased by the first's unregister. The token model
|
||||||
|
makes each registration independent."""
|
||||||
|
|
||||||
|
def test_two_submissions_get_distinct_tokens_and_each_merges_correctly(self, registry, lock):
|
||||||
|
# Two submissions of the same prompt_id with different workflow_ids.
|
||||||
|
registry[1] = {"workflow_id": "wf-A"} # token from submission #1
|
||||||
|
registry[2] = {"workflow_id": "wf-B"} # token from submission #2
|
||||||
|
|
||||||
|
# Worker is currently running submission #1.
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P", "node": "x"})
|
||||||
|
assert merged["workflow_id"] == "wf-A"
|
||||||
|
|
||||||
|
# Worker switches to submission #2 (queue ordering, retry, whatever).
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P", "node": "y"})
|
||||||
|
assert merged["workflow_id"] == "wf-B"
|
||||||
|
|
||||||
|
def test_unregister_by_token_does_not_drop_concurrent_submission(self, registry, lock):
|
||||||
|
registry[1] = {"workflow_id": "wf-A"}
|
||||||
|
registry[2] = {"workflow_id": "wf-B"}
|
||||||
|
|
||||||
|
# Submission #1 finishes — drop its token only.
|
||||||
|
registry.pop(1, None)
|
||||||
|
|
||||||
|
# Submission #2 still has its metadata.
|
||||||
|
merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P"})
|
||||||
|
assert merged["workflow_id"] == "wf-B"
|
||||||
|
|
||||||
|
def test_execution_order_independent_of_registration_order(self, registry, lock):
|
||||||
|
"""Regression for the LIFO-stack failure mode: queue executes #1 first
|
||||||
|
but the previous stack design would have made the merge pick the
|
||||||
|
latest-registered (#2) metadata. Token model is immune."""
|
||||||
|
registry[1] = {"workflow_id": "wf-A"}
|
||||||
|
registry[2] = {"workflow_id": "wf-B"}
|
||||||
|
|
||||||
|
# Even though #2 was registered after #1, executing #1 still sees wf-A.
|
||||||
|
merged_first_running = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P"})
|
||||||
|
assert merged_first_running["workflow_id"] == "wf-A"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRaceRegressionForTerminalExecutingFrame:
|
||||||
|
"""Regression for the PR #13684 finally-clear race.
|
||||||
|
|
||||||
|
Executor's ``finally`` previously cleared the workflow_id source, so the
|
||||||
|
post-completion terminal frame shipped ``workflow_id=None``. With the
|
||||||
|
token model, the active token stays pinned until ``main.py`` clears it
|
||||||
|
*after* the terminal send.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
|
||||||
|
registry[7] = {"workflow_id": "wf-1"}
|
||||||
|
active_token = 7
|
||||||
|
|
||||||
|
# main.py emits the terminal frame BEFORE clearing the active token.
|
||||||
|
terminal = merge_prompt_metadata(
|
||||||
|
registry, lock, active_token, {"node": None, "prompt_id": "p1"}
|
||||||
|
)
|
||||||
|
# main.py's finally: clear active token + unregister.
|
||||||
|
active_token = None
|
||||||
|
registry.pop(7, None)
|
||||||
|
|
||||||
|
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||||
|
|
||||||
|
# After cleanup, straggler events get no metadata.
|
||||||
|
straggler = merge_prompt_metadata(
|
||||||
|
registry, lock, active_token, {"node": None, "prompt_id": "p1"}
|
||||||
|
)
|
||||||
|
assert "workflow_id" not in straggler
|
||||||
164
tests-unit/server_test/test_prompt_metadata_e2e.py
Normal file
164
tests-unit/server_test/test_prompt_metadata_e2e.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
"""End-to-end checks of the prompt-metadata propagation.
|
||||||
|
|
||||||
|
These tests instantiate the real ``PromptServer`` (no heavy mocking) and drive
|
||||||
|
the same call sequence ``main.py`` would: register metadata, pin the token,
|
||||||
|
emit a bunch of execution-shaped frames through ``send_sync``, then clear the
|
||||||
|
token + unregister. We then drain the server's message queue and inspect the
|
||||||
|
exact frames that would have been written to the WebSocket — proving the wire
|
||||||
|
shape and the ``prompt_id`` collision invariants without needing a browser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Ensure the repo root is the first thing on sys.path so ``utils`` resolves to
|
||||||
|
# the ComfyUI package, not a site-packages namespace package. ``utils`` must
|
||||||
|
# load before ``app.frontend_management`` imports it.
|
||||||
|
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
if REPO_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, REPO_ROOT)
|
||||||
|
import utils # noqa: F401 -- pin the package before later imports
|
||||||
|
|
||||||
|
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
|
||||||
|
|
||||||
|
server_module = pytest.importorskip("server")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt_server():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
server = server_module.PromptServer(loop)
|
||||||
|
yield server
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _drain(server) -> list:
|
||||||
|
"""Empty the server's internal asyncio.Queue and return every (event, data, sid)."""
|
||||||
|
frames = []
|
||||||
|
queue = server.messages
|
||||||
|
while not queue.empty():
|
||||||
|
frames.append(queue.get_nowait())
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def _run_pending_callbacks(server):
|
||||||
|
"""``send_sync`` schedules ``messages.put_nowait`` via ``loop.call_soon_threadsafe``.
|
||||||
|
Pump the loop a single iteration so those callbacks land before we drain."""
|
||||||
|
server.loop.call_soon(server.loop.stop)
|
||||||
|
server.loop.run_forever()
|
||||||
|
|
||||||
|
|
||||||
|
def _simulate_prompt_lifecycle(server, extra_data: dict, prompt_id: str):
|
||||||
|
"""Replicates what ``main.py``'s queue worker does around ``e.execute(...)``
|
||||||
|
plus a representative sample of the events ``execution.py`` emits."""
|
||||||
|
token = server.register_prompt_metadata(extra_data)
|
||||||
|
if token is not None:
|
||||||
|
extra_data[PROMPT_METADATA_TOKEN_KEY] = token
|
||||||
|
|
||||||
|
# Worker picks up the item, pins the token.
|
||||||
|
server.active_prompt_metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None)
|
||||||
|
client_id = "client-A"
|
||||||
|
server.client_id = client_id
|
||||||
|
|
||||||
|
# A representative sample of the eight execution events listed in PR #13684.
|
||||||
|
server.send_sync("execution_start", {"prompt_id": prompt_id}, client_id)
|
||||||
|
server.send_sync("execution_cached", {"nodes": [], "prompt_id": prompt_id}, client_id)
|
||||||
|
server.send_sync("executing", {"node": "n1", "display_node": "n1", "prompt_id": prompt_id}, client_id)
|
||||||
|
server.send_sync(
|
||||||
|
"executed",
|
||||||
|
{"node": "n1", "display_node": "n1", "output": {}, "prompt_id": prompt_id},
|
||||||
|
client_id,
|
||||||
|
)
|
||||||
|
server.send_sync("execution_success", {"prompt_id": prompt_id}, client_id)
|
||||||
|
# The terminal frame ``main.py`` itself emits, just before clearing state.
|
||||||
|
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, client_id)
|
||||||
|
|
||||||
|
# Worker finally clause: clear token then unregister.
|
||||||
|
server.active_prompt_metadata_token = None
|
||||||
|
server.unregister_prompt_metadata(token)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEndWorkflowIdOnAllExecutionEvents:
|
||||||
|
def test_every_execution_event_carries_workflow_id(self, prompt_server):
|
||||||
|
_simulate_prompt_lifecycle(
|
||||||
|
prompt_server,
|
||||||
|
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-xyz"}}},
|
||||||
|
prompt_id="p-1",
|
||||||
|
)
|
||||||
|
_run_pending_callbacks(prompt_server)
|
||||||
|
frames = _drain(prompt_server)
|
||||||
|
|
||||||
|
assert frames, "no frames emitted"
|
||||||
|
# Every frame with a prompt_id payload must carry workflow_id top-level.
|
||||||
|
for event, data, _sid in frames:
|
||||||
|
if isinstance(data, dict) and data.get("prompt_id"):
|
||||||
|
assert data.get("workflow_id") == "wf-xyz", (event, data)
|
||||||
|
|
||||||
|
# And specifically verify the terminal frame — the #13684 race victim.
|
||||||
|
terminal = [
|
||||||
|
(e, d) for e, d, _ in frames
|
||||||
|
if e == "executing" and isinstance(d, dict) and d.get("node") is None
|
||||||
|
]
|
||||||
|
assert terminal, "no terminal executing frame emitted"
|
||||||
|
_, terminal_payload = terminal[-1]
|
||||||
|
assert terminal_payload["workflow_id"] == "wf-xyz"
|
||||||
|
|
||||||
|
def test_status_frame_is_not_decorated(self, prompt_server):
|
||||||
|
prompt_server.register_prompt_metadata({"extra_pnginfo": {"workflow": {"id": "wf-1"}}})
|
||||||
|
prompt_server.active_prompt_metadata_token = 1
|
||||||
|
|
||||||
|
# A status frame has no prompt_id and must remain untouched.
|
||||||
|
prompt_server.send_sync("status", {"status": {"queue_remaining": 0}}, None)
|
||||||
|
_run_pending_callbacks(prompt_server)
|
||||||
|
frames = _drain(prompt_server)
|
||||||
|
|
||||||
|
assert any(e == "status" for e, _, _ in frames)
|
||||||
|
for event, data, _ in frames:
|
||||||
|
if event == "status":
|
||||||
|
assert "workflow_id" not in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEndPromptIdCollision:
|
||||||
|
"""Drive two prompts with the same client-supplied ``prompt_id`` but
|
||||||
|
different ``workflow_id`` values. The token model must keep their event
|
||||||
|
streams attributed correctly."""
|
||||||
|
|
||||||
|
def test_same_prompt_id_two_workflows_each_stream_attributed_correctly(self, prompt_server):
|
||||||
|
_simulate_prompt_lifecycle(
|
||||||
|
prompt_server,
|
||||||
|
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-A"}}},
|
||||||
|
prompt_id="P-shared",
|
||||||
|
)
|
||||||
|
_simulate_prompt_lifecycle(
|
||||||
|
prompt_server,
|
||||||
|
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-B"}}},
|
||||||
|
prompt_id="P-shared",
|
||||||
|
)
|
||||||
|
_run_pending_callbacks(prompt_server)
|
||||||
|
frames = _drain(prompt_server)
|
||||||
|
|
||||||
|
# Partition the frames by lifecycle (execution_start..terminal-executing).
|
||||||
|
runs: list[list[dict]] = [[]]
|
||||||
|
for event, data, _ in frames:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
runs[-1].append({"event": event, **data})
|
||||||
|
if event == "executing" and data.get("node") is None:
|
||||||
|
runs.append([])
|
||||||
|
# The trailing empty bucket from the final split is fine.
|
||||||
|
runs = [r for r in runs if r]
|
||||||
|
assert len(runs) == 2, runs
|
||||||
|
|
||||||
|
run_a, run_b = runs
|
||||||
|
for entry in run_a:
|
||||||
|
if entry.get("prompt_id"):
|
||||||
|
assert entry.get("workflow_id") == "wf-A", entry
|
||||||
|
for entry in run_b:
|
||||||
|
if entry.get("prompt_id"):
|
||||||
|
assert entry.get("workflow_id") == "wf-B", entry
|
||||||
Loading…
Reference in New Issue
Block a user