mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 07:12:30 +08:00
Address review: clear stale workflow_id, expand reconnect payload, harden tests
- Clear self.server.last_workflow_id and self.workflow_id in the PromptExecutor finally block so a progress callback racing with teardown can no longer attach the previous run's workflow_id to a later 'progress' event. - Include prompt_id and last_workflow_id in the reconnect 'executing' message in server.py so reconnecting clients can recover both workflow- and prompt-scoped execution state, matching the regular 'executing' payload. - Add an AST-based static guard that walks execution.py, main.py, and comfy_execution/progress.py and asserts every dict literal carrying prompt_id also carries workflow_id. Also add a unit test covering PREVIEW_IMAGE_WITH_METADATA metadata. Together these regression-test every emitter (execution_start/success/error/interrupted/cached, executing, executed, progress, progress_state, preview metadata) without requiring a GPU-backed import of execution.py.
This commit is contained in:
parent
2205341279
commit
1f0b13705d
@ -816,6 +816,8 @@ class PromptExecutor:
|
|||||||
finally:
|
finally:
|
||||||
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||||
self._notify_prompt_lifecycle("end", prompt_id)
|
self._notify_prompt_lifecycle("end", prompt_id)
|
||||||
|
self.server.last_workflow_id = None
|
||||||
|
self.workflow_id = None
|
||||||
|
|
||||||
|
|
||||||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||||
|
|||||||
@ -274,7 +274,11 @@ class PromptServer():
|
|||||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||||
# On reconnect if we are the currently executing client send the current node
|
# On reconnect if we are the currently executing client send the current node
|
||||||
if self.client_id == sid and self.last_node_id is not None:
|
if self.client_id == sid and self.last_node_id is not None:
|
||||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
await self.send("executing", {
|
||||||
|
"node": self.last_node_id,
|
||||||
|
"prompt_id": getattr(self, "last_prompt_id", None),
|
||||||
|
"workflow_id": getattr(self, "last_workflow_id", None),
|
||||||
|
}, sid)
|
||||||
|
|
||||||
# Flag to track if we've received the first message
|
# Flag to track if we've received the first message
|
||||||
first_message = True
|
first_message = True
|
||||||
|
|||||||
@ -109,3 +109,113 @@ class TestResetProgressState:
|
|||||||
def test_reset_default_workflow_id_none(self):
|
def test_reset_default_workflow_id_none(self):
|
||||||
reset_progress_state("prompt-2", _DummyDynPrompt())
|
reset_progress_state("prompt-2", _DummyDynPrompt())
|
||||||
assert get_progress_state().workflow_id is None
|
assert get_progress_state().workflow_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionMessagePayloadsContainWorkflowId:
|
||||||
|
"""Static-analysis guard ensuring every WebSocket message payload that
|
||||||
|
carries `prompt_id` also carries `workflow_id`. This is a regression net
|
||||||
|
for future refactors of execution.py / main.py / progress.py and avoids
|
||||||
|
the GPU/torch dependency of importing `execution.py` directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _emitting_dicts(source: str):
|
||||||
|
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
tree = ast.parse(source)
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if not isinstance(node, ast.Dict):
|
||||||
|
continue
|
||||||
|
keys = [
|
||||||
|
k.value
|
||||||
|
for k in node.keys
|
||||||
|
if isinstance(k, ast.Constant) and isinstance(k.value, str)
|
||||||
|
]
|
||||||
|
if "prompt_id" in keys:
|
||||||
|
yield node, keys
|
||||||
|
|
||||||
|
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
source = Path(file_path).read_text()
|
||||||
|
offenders = []
|
||||||
|
for node, keys in self._emitting_dicts(source):
|
||||||
|
if "workflow_id" not in keys:
|
||||||
|
offenders.append((node.lineno, keys))
|
||||||
|
assert not offenders, (
|
||||||
|
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_execution_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
|
||||||
|
|
||||||
|
def test_main_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
|
||||||
|
|
||||||
|
def test_progress_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewImageMetadataPayload:
|
||||||
|
"""Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""
|
||||||
|
|
||||||
|
def test_preview_metadata_includes_workflow_id(self):
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy_execution.progress import (
|
||||||
|
NodeState,
|
||||||
|
ProgressRegistry,
|
||||||
|
WebUIProgressHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
class _DynPrompt:
|
||||||
|
def get_display_node_id(self, n):
|
||||||
|
return n
|
||||||
|
|
||||||
|
def get_parent_node_id(self, n):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_real_node_id(self, n):
|
||||||
|
return n
|
||||||
|
|
||||||
|
server = MagicMock()
|
||||||
|
server.client_id = "cid"
|
||||||
|
server.sockets_metadata = {}
|
||||||
|
|
||||||
|
registry = ProgressRegistry(
|
||||||
|
prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1"
|
||||||
|
)
|
||||||
|
handler = WebUIProgressHandler(server)
|
||||||
|
handler.set_registry(registry)
|
||||||
|
|
||||||
|
image = ("PNG", Image.new("RGB", (1, 1)), None)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"comfy_execution.progress.feature_flags.supports_feature",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
handler.update_handler(
|
||||||
|
node_id="n1",
|
||||||
|
value=1.0,
|
||||||
|
max_value=1.0,
|
||||||
|
state={
|
||||||
|
"state": NodeState.Running,
|
||||||
|
"value": 1.0,
|
||||||
|
"max": 1.0,
|
||||||
|
},
|
||||||
|
prompt_id="p1",
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
preview_calls = [
|
||||||
|
c
|
||||||
|
for c in server.send_sync.call_args_list
|
||||||
|
if c.args[0] != "progress_state"
|
||||||
|
]
|
||||||
|
assert len(preview_calls) == 1
|
||||||
|
_, payload, _ = preview_calls[0].args
|
||||||
|
_, metadata = payload
|
||||||
|
assert metadata["prompt_id"] == "p1"
|
||||||
|
assert metadata["workflow_id"] == "wf-1"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user