diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..24dd1ffd0 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict: } +def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]: + """Extract the workflow id from a prompt's ``extra_data``. + + The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]`` + when a prompt is queued. Any value that is not a non-empty string is treated as + missing so callers can rely on the return being either ``None`` or a string. + """ + if not isinstance(extra_data, dict): + return None + extra_pnginfo = extra_data.get('extra_pnginfo') + if not isinstance(extra_pnginfo, dict): + return None + workflow = extra_pnginfo.get('workflow') + if not isinstance(workflow, dict): + return None + workflow_id = workflow.get('id') + if isinstance(workflow_id, str) and workflow_id: + return workflow_id + return None + + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str tuple: (create_time, workflow_id) """ create_time = extra_data.get('create_time') - extra_pnginfo = extra_data.get('extra_pnginfo', {}) - workflow_id = extra_pnginfo.get('workflow', {}).get('id') + workflow_id = extract_workflow_id(extra_data) return create_time, workflow_id diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a3350..b6d3bd3e4 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler): if self.server_instance is None: return + workflow_id = self.registry.workflow_id if self.registry else None + # Only send info for non-pending nodes active_nodes = { node_id: { @@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler): "state": state["state"].value, "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": workflow_id, "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), @@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler): # Send a combined progress_state message with all node states # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( - "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id + "progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id ) @override @@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler): metadata = { "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": self.registry.workflow_id if self.registry else None, "display_node_id": self.registry.dynprompt.get_display_node_id( node_id ), @@ -240,9 +244,10 @@ class ProgressRegistry: Registry that maintains node progress state and notifies registered handlers. """ - def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None): self.prompt_id = prompt_id self.dynprompt = dynprompt + self.workflow_id = workflow_id self.nodes: Dict[str, NodeProgressState] = {} self.handlers: Dict[str, ProgressHandler] = {} @@ -322,7 +327,7 @@ class ProgressRegistry: # Global registry instance global_progress_registry: ProgressRegistry | None = None -def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None: global global_progress_registry # Reset existing handlers if registry exists @@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: global_progress_registry.reset_handlers() # Create new registry - global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id) def add_progress_handler(handler: ProgressHandler) -> None: diff --git a/execution.py b/execution.py index f37d0360d..ff8240588 100644 --- a/execution.py +++ b/execution.py @@ -38,6 +38,7 @@ from comfy_execution.graph import ( from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.jobs import extract_workflow_id from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io @@ -417,15 +418,15 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) -def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): +def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs): if server.client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -435,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) + _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs) get_progress_state().finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -483,7 +484,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) obj = await caches.objects.get(unique_id) if obj is None: @@ -513,6 +514,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if block.message is not None: mes = { "prompt_id": prompt_id, + "workflow_id": workflow_id, "node_id": unique_id, "node_type": class_type, "executed": list(executed), @@ -561,7 +563,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "output": output_ui } if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id) if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -658,6 +660,7 @@ class PromptExecutor: self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True + self.workflow_id = None def add_message(self, event, data: dict, broadcast: bool): data = { @@ -677,6 +680,7 @@ class PromptExecutor: if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "workflow_id": self.workflow_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), @@ -685,6 +689,7 @@ class PromptExecutor: else: mes = { "prompt_id": prompt_id, + "workflow_id": self.workflow_id, "node_id": node_id, "node_type": class_type, "executed": list(executed), @@ -723,7 +728,9 @@ class PromptExecutor: self.server.client_id = None self.status_messages = [] - self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + self.workflow_id = extract_workflow_id(extra_data) + self.server.last_workflow_id = self.workflow_id + self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) @@ -733,7 +740,7 @@ class PromptExecutor: try: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) + reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id) add_progress_handler(WebUIProgressHandler(self.server)) is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: @@ -751,7 +758,7 @@ class PromptExecutor: comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, + { "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results @@ -769,7 +776,7 @@ class PromptExecutor: break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -793,8 +800,8 @@ class PromptExecutor: cached = await self.caches.outputs.get(node_id) if cached is not None: display_node_id = dynamic_prompt.get_display_node_id(node_id) - _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs) - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs) + self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False) ui_outputs = {} meta_outputs = {} @@ -811,6 +818,8 @@ class PromptExecutor: finally: comfy.memory_management.set_ram_cache_release_state(None, 0) self._notify_prompt_lifecycle("end", prompt_id) + self.server.last_workflow_id = None + self.workflow_id = None async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): diff --git a/main.py b/main.py index a6fdaf43c..3ac8395b1 100644 --- a/main.py +++ b/main.py @@ -29,6 +29,7 @@ import logging import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context +from comfy_execution.jobs import extract_workflow_id from comfy_api import feature_flags from app.database.db import init_db, dependencies_available @@ -317,6 +318,12 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + # Capture the workflow id for this prompt before execution: the + # executor clears server.last_workflow_id in its finally block, so + # reading it after e.execute() returns would emit workflow_id=None + # on the terminal "executing" reset below. + workflow_id = extract_workflow_id(extra_data) + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) @@ -330,7 +337,7 @@ def prompt_worker(q, server_instance): completed=e.success, messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": workflow_id}, server_instance.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -393,7 +400,7 @@ def hijack_progress(server_instance): prompt_id = server_instance.last_prompt_id if node_id is None: node_id = server_instance.last_node_id - progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id} get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) diff --git a/server.py b/server.py index 2f3b438bb..08eea1160 100644 --- a/server.py +++ b/server.py @@ -275,7 +275,11 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + await self.send("executing", { + "node": self.last_node_id, + "prompt_id": getattr(self, "last_prompt_id", None), + "workflow_id": getattr(self, "last_workflow_id", None), + }, sid) # Flag to track if we've received the first message first_message = True diff --git a/tests-unit/execution_test/test_workflow_id_in_ws_messages.py b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py new file mode 100644 index 000000000..cf1ff71e9 --- /dev/null +++ b/tests-unit/execution_test/test_workflow_id_in_ws_messages.py @@ -0,0 +1,297 @@ +"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads +emitted by the progress handler and the prompt executor. + +Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a +prompt; we propagate that as `workflow_id` on every execution event so a +multi-tab UI can scope progress state by workflow even when terminal +WebSocket frames are dropped. +""" + +from unittest.mock import MagicMock + +import pytest + +from comfy_execution.progress import ( + NodeState, + ProgressRegistry, + WebUIProgressHandler, + reset_progress_state, + get_progress_state, +) + + +class _DummyDynPrompt: + def get_display_node_id(self, node_id): + return node_id + + def get_parent_node_id(self, node_id): + return None + + def get_real_node_id(self, node_id): + return node_id + + +@pytest.fixture +def server(): + s = MagicMock() + s.client_id = "client-1" + return s + + +def _registry(workflow_id): + return ProgressRegistry( + prompt_id="prompt-1", + dynprompt=_DummyDynPrompt(), + workflow_id=workflow_id, + ) + + +class TestProgressStatePayload: + def test_progress_state_includes_workflow_id(self, server): + registry = _registry("wf-abc") + registry.nodes["n1"] = { + "state": NodeState.Running, + "value": 1.0, + "max": 5.0, + } + + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + handler._send_progress_state("prompt-1", registry.nodes) + + server.send_sync.assert_called_once() + event, payload, sid = server.send_sync.call_args.args + assert event == "progress_state" + assert payload["prompt_id"] == "prompt-1" + assert payload["workflow_id"] == "wf-abc" + assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc" + assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1" + assert sid == "client-1" + + def test_progress_state_workflow_id_none_when_missing(self, server): + registry = _registry(None) + registry.nodes["n1"] = { + "state": NodeState.Running, + "value": 0.5, + "max": 1.0, + } + + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + handler._send_progress_state("prompt-1", registry.nodes) + + _, payload, _ = server.send_sync.call_args.args + assert payload["workflow_id"] is None + assert payload["nodes"]["n1"]["workflow_id"] is None + + +class TestProgressRegistryConstruction: + def test_workflow_id_default_is_none(self): + registry = ProgressRegistry( + prompt_id="prompt-1", dynprompt=_DummyDynPrompt() + ) + assert registry.workflow_id is None + + def test_workflow_id_stored_on_registry(self): + registry = ProgressRegistry( + prompt_id="prompt-1", + dynprompt=_DummyDynPrompt(), + workflow_id="wf-xyz", + ) + assert registry.workflow_id == "wf-xyz" + + +class TestResetProgressState: + def test_reset_threads_workflow_id(self): + reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456") + assert get_progress_state().workflow_id == "wf-456" + + def test_reset_default_workflow_id_none(self): + reset_progress_state("prompt-2", _DummyDynPrompt()) + assert get_progress_state().workflow_id is None + + +class TestExecutionMessagePayloadsContainWorkflowId: + """Static-analysis guard ensuring every WebSocket message payload that + carries `prompt_id` also carries `workflow_id`. This is a regression net + for future refactors of execution.py / main.py / progress.py and avoids + the GPU/torch dependency of importing `execution.py` directly. + """ + + @staticmethod + def _emitting_dicts(source: str): + """Yield every dict literal in `source` that contains a 'prompt_id' key.""" + import ast + + tree = ast.parse(source) + for node in ast.walk(tree): + if not isinstance(node, ast.Dict): + continue + keys = [ + k.value + for k in node.keys + if isinstance(k, ast.Constant) and isinstance(k.value, str) + ] + if "prompt_id" in keys: + yield node, keys + + def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str): + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[2] + source = (repo_root / file_path).read_text() + offenders = [] + for node, keys in self._emitting_dicts(source): + if "workflow_id" not in keys: + offenders.append((node.lineno, keys)) + assert not offenders, ( + f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}" + ) + + def test_execution_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("execution.py") + + def test_main_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("main.py") + + def test_progress_py_payloads_include_workflow_id(self): + self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py") + + +class TestPreviewImageMetadataPayload: + """Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id.""" + + def test_preview_metadata_includes_workflow_id(self): + from unittest.mock import MagicMock, patch + from PIL import Image + + from comfy_execution.progress import ( + NodeState, + ProgressRegistry, + WebUIProgressHandler, + ) + + class _DynPrompt: + def get_display_node_id(self, n): + return n + + def get_parent_node_id(self, n): + return None + + def get_real_node_id(self, n): + return n + + server = MagicMock() + server.client_id = "cid" + server.sockets_metadata = {} + + registry = ProgressRegistry( + prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1" + ) + handler = WebUIProgressHandler(server) + handler.set_registry(registry) + + image = ("PNG", Image.new("RGB", (1, 1)), None) + + with patch( + "comfy_execution.progress.feature_flags.supports_feature", + return_value=True, + ): + handler.update_handler( + node_id="n1", + value=1.0, + max_value=1.0, + state={ + "state": NodeState.Running, + "value": 1.0, + "max": 1.0, + }, + prompt_id="p1", + image=image, + ) + + preview_calls = [ + c + for c in server.send_sync.call_args_list + if c.args[0] != "progress_state" + ] + assert len(preview_calls) == 1 + _, payload, _ = preview_calls[0].args + _, metadata = payload + assert metadata["prompt_id"] == "p1" + assert metadata["workflow_id"] == "wf-1" + + + +class TestTerminalExecutingResetInMainPy: + """Regression test for the main.py prompt_worker terminal 'executing' reset. + + The executor clears server.last_workflow_id in its finally block, so + main.py must capture the workflow id *before* calling e.execute() and use + that local value, not read server.last_workflow_id afterwards. + + Rather than importing main.py (which triggers torch CUDA init in this + environment), we statically assert the contract via AST: somewhere + between the `extra_data = item[3].copy()` line and the + `e.execute(item[2], ...)` call, the function must extract workflow_id + from extra_data into a local, and the subsequent send_sync("executing", + ...) must reference that local rather than server.last_workflow_id. + """ + + def test_terminal_executing_uses_locally_captured_workflow_id(self): + import ast + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[2] + source = (repo_root / "main.py").read_text() + tree = ast.parse(source) + + worker = next( + ( + n + for n in ast.walk(tree) + if isinstance(n, ast.FunctionDef) and n.name == "prompt_worker" + ), + None, + ) + assert worker is not None, "prompt_worker function not found in main.py" + + worker_src = ast.get_source_segment(source, worker) or "" + + assert "extract_workflow_id(extra_data)" in worker_src, ( + "main.py:prompt_worker must capture workflow_id locally from extra_data " + "before calling e.execute() (the executor clears server.last_workflow_id " + "in finally)." + ) + + matched_terminal_executing_send = False + for node in ast.walk(worker): + if not isinstance(node, ast.Call): + continue + func = node.func + if not ( + isinstance(func, ast.Attribute) + and func.attr == "send_sync" + and node.args + and isinstance(node.args[0], ast.Constant) + and node.args[0].value == "executing" + and len(node.args) >= 2 + and isinstance(node.args[1], ast.Dict) + ): + continue + matched_terminal_executing_send = True + payload = node.args[1] + for key, value in zip(payload.keys, payload.values): + if isinstance(key, ast.Constant) and key.value == "workflow_id": + rendered = ast.unparse(value) + assert "last_workflow_id" not in rendered, ( + "main.py terminal 'executing' must not read " + "server.last_workflow_id; the executor clears it in its " + "finally block. Use a locally captured workflow_id instead." + ) + + assert matched_terminal_executing_send, ( + "main.py:prompt_worker no longer has an inline " + 'send_sync("executing", {...}) payload; update this regression test ' + "so it still verifies the terminal workflow_id source." + ) diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 814af5c13..6afa6cd9c 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -10,9 +10,44 @@ from comfy_execution.jobs import ( get_outputs_summary, apply_sorting, has_3d_extension, + extract_workflow_id, ) +class TestExtractWorkflowId: + """Unit tests for extract_workflow_id().""" + + def test_returns_id_from_extra_pnginfo(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123' + + def test_missing_extra_data_returns_none(self): + assert extract_workflow_id(None) is None + + def test_non_dict_extra_data_returns_none(self): + assert extract_workflow_id('not-a-dict') is None + + def test_missing_extra_pnginfo_returns_none(self): + assert extract_workflow_id({}) is None + + def test_missing_workflow_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {}}) is None + + def test_missing_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None + + def test_empty_string_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None + + def test_non_string_id_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None + + def test_non_dict_workflow_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None + + def test_non_dict_extra_pnginfo_returns_none(self): + assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None + + class TestJobStatus: """Test JobStatus constants."""