diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 57126fa4a..2aae7ec50 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -193,6 +193,8 @@ class HiddenInputTypeDict(TypedDict): """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" dynprompt: NotRequired[Literal["DYNPROMPT"]] """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + prompt_id: NotRequired[Literal["PROMPT_ID"]] + """PROMPT_ID is the unique identifier of the current prompt/job being executed. Useful for associating progress updates with specific jobs.""" class InputTypeDict(TypedDict): diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 9f6918315..6a3595af9 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -12,6 +12,7 @@ from comfy.cli_args import args # Default server capabilities SERVER_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, + "supports_progress_text_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 7ca8f4e0c..1ce6ea51e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1282,9 +1282,16 @@ class V3Data(TypedDict): 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: + """Holds hidden input values resolved during node execution. + + Hidden inputs are special values automatically provided by the execution + engine (e.g., node ID, prompt data, authentication tokens) rather than + being connected by the user in the graph. + """ def __init__(self, unique_id: str, prompt: Any, extra_pnginfo: Any, dynprompt: Any, - auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + auth_token_comfy_org: str, api_key_comfy_org: str, + prompt_id: str = None, **kwargs): self.unique_id = unique_id """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" self.prompt = prompt @@ -1297,6 +1304,8 @@ class HiddenHolder: """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" self.api_key_comfy_org = api_key_comfy_org """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + self.prompt_id = prompt_id + """PROMPT_ID is the unique identifier of the current prompt/job being executed.""" def __getattr__(self, key: str): '''If hidden variable not found, return None.''' @@ -1304,6 +1313,14 @@ class HiddenHolder: @classmethod def from_dict(cls, d: dict | None): + """Create a HiddenHolder from a dictionary of hidden input values. + + Args: + d: Dictionary mapping Hidden enum values to their resolved values. + + Returns: + A new HiddenHolder instance with values populated from the dict. + """ if d is None: d = {} return cls( @@ -1313,6 +1330,7 @@ class HiddenHolder: dynprompt=d.get(Hidden.dynprompt, None), auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + prompt_id=d.get(Hidden.prompt_id, None), ) @classmethod @@ -1335,6 +1353,8 @@ class Hidden(str, Enum): """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" api_key_comfy_org = "API_KEY_COMFY_ORG" """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + prompt_id = "PROMPT_ID" + """PROMPT_ID is the unique identifier of the current prompt/job being executed. Useful for associating progress updates with specific jobs.""" @dataclass diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..1ce6019a2 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -17,6 +17,7 @@ from pydantic import BaseModel from comfy import utils from comfy_api.latest import IO +from comfy_execution.utils import get_executing_context from server import PromptServer from . import request_logger @@ -440,6 +441,17 @@ def _display_text( status: str | int | None = None, price: float | None = None, ) -> None: + """Send a progress text message to the client for display on a node. + + Assembles status, price, and text lines, then sends them via WebSocket. + Automatically retrieves the current prompt_id from the execution context. + + Args: + node_cls: The ComfyNode class sending the progress text. + text: Optional text content to display. + status: Optional status string or code to display. + price: Optional price in dollars to display as credits. + """ display_lines: list[str] = [] if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") @@ -450,7 +462,9 @@ def _display_text( if text is not None: display_lines.append(text) if display_lines: - PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + ctx = get_executing_context() + prompt_id = ctx.prompt_id if ctx is not None else None + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls), prompt_id=prompt_id) def _display_time_progress( diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4c57bb5cb..fa60512cc 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -566,7 +566,7 @@ class GetImageSize(IO.ComfyNode): IO.Int.Output(display_name="height"), IO.Int.Output(display_name="batch_size"), ], - hidden=[IO.Hidden.unique_id], + hidden=[IO.Hidden.unique_id, IO.Hidden.prompt_id], ) @classmethod @@ -577,7 +577,7 @@ class GetImageSize(IO.ComfyNode): # Send progress text to display size on the node if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id, prompt_id=cls.hidden.prompt_id) return IO.NodeOutput(width, height, batch_size) diff --git a/execution.py b/execution.py index 1a6c3429c..061fe25f8 100644 --- a/execution.py +++ b/execution.py @@ -150,7 +150,7 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}, prompt_id=None): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} hidden_inputs_v3 = {} @@ -197,6 +197,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + if io.Hidden.prompt_id.name in hidden: + hidden_inputs_v3[io.Hidden.prompt_id] = prompt_id else: if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -213,6 +215,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + if h[x] == "PROMPT_ID": + input_data_all[x] = [prompt_id] v3_data["hidden_inputs"] = hidden_inputs_v3 return input_data_all, missing_keys, v3_data @@ -470,7 +474,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data, prompt_id=prompt_id) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) diff --git a/server.py b/server.py index 85a8964be..4efb709a5 100644 --- a/server.py +++ b/server.py @@ -1230,13 +1230,47 @@ class PromptServer(): return json_data def send_progress_text( - self, text: Union[bytes, bytearray, str], node_id: str, sid=None + self, + text: Union[bytes, bytearray, str], + node_id: str, + prompt_id: Optional[str] = None, + sid=None, ): + """Send a progress text message to the client via WebSocket. + + Encodes the text as a binary message with length-prefixed node_id. When + the client supports the ``supports_progress_text_metadata`` feature flag, + the prompt_id is always prepended as a length-prefixed field (empty string + when None) to ensure consistent binary framing. + + Args: + text: The progress text content to send. + node_id: The unique identifier of the node sending the progress. + prompt_id: Optional prompt/job identifier to associate the message with. + sid: Optional session ID to target a specific client. + """ if isinstance(text, str): text = text.encode("utf-8") node_id_bytes = str(node_id).encode("utf-8") - # Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes - message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text + # Auto-resolve sid to the currently executing client + target_sid = sid if sid is not None else self.client_id - self.send_sync(BinaryEventTypes.TEXT, message, sid) + # When client supports the new format, always send + # [prompt_id_len][prompt_id][node_id_len][node_id][text] + # even when prompt_id is None (encoded as zero-length string) + if feature_flags.supports_feature( + self.sockets_metadata, target_sid, "supports_progress_text_metadata" + ): + prompt_id_bytes = (prompt_id or "").encode("utf-8") + message = ( + struct.pack(">I", len(prompt_id_bytes)) + + prompt_id_bytes + + struct.pack(">I", len(node_id_bytes)) + + node_id_bytes + + text + ) + else: + message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text + + self.send_sync(BinaryEventTypes.TEXT, message, target_sid)