From 6af812f9a89cc9921622f8d34c47c21559dfffef Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 23 Sep 2025 11:01:48 -0700
Subject: [PATCH] Fix custom model paths config paths, tweak tests

---
 comfy/cmd/main.py                             |   2 +-
 tests/conftest.py                             |   3 +-
 tests/execution/__init__.py                   |   0
 tests/execution/base_dir/__init__.py          |   0
 .../base_dir/custom_nodes/stubs/__init__.py   |   7 +
 .../base_dir/custom_nodes/stubs/stubs.py      | 129 +++++++++++++++
 tests/execution/test_async_nodes.py           |   4 +-
 tests/execution/test_execution.py             |   6 +-
 tests/execution/test_progress_isolation.py    | 153 +++++++++++++++---
 9 files changed, 276 insertions(+), 28 deletions(-)
 create mode 100644 tests/execution/__init__.py
 create mode 100644 tests/execution/base_dir/__init__.py
 create mode 100644 tests/execution/base_dir/custom_nodes/stubs/__init__.py
 create mode 100644 tests/execution/base_dir/custom_nodes/stubs/stubs.py

diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index 04b02f9f0..c5748964e 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -177,7 +177,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
             pass
 
     if args.extra_model_paths_config:
-        for config_path in itertools.chain(*args.extra_model_paths_config):
+        for config_path in args.extra_model_paths_config:
             load_extra_path_config(config_path)
 
     if args.create_directories:
diff --git a/tests/conftest.py b/tests/conftest.py
index 5a39b1ddc..668ea947c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -173,7 +173,8 @@ def comfy_background_server_from_config(configuration):
 @pytest.fixture(scope="session")
 def skip_timing_checks(pytestconfig):
     """Fixture that returns whether timing checks should be skipped."""
-    return pytestconfig.getoption("--skip-timing-checks")
+    # todo: in the LTS, we don't need to skip timing checks, everything just works
+    return False
 
 
 def pytest_collection_modifyitems(items):
diff --git a/tests/execution/__init__.py b/tests/execution/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/execution/base_dir/__init__.py b/tests/execution/base_dir/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/execution/base_dir/custom_nodes/stubs/__init__.py b/tests/execution/base_dir/custom_nodes/stubs/__init__.py
new file mode 100644
index 000000000..806a84b18
--- /dev/null
+++ b/tests/execution/base_dir/custom_nodes/stubs/__init__.py
@@ -0,0 +1,7 @@
+from .stubs import StubImage, StubConstantImage, StubMask
+
+NODE_CLASS_MAPPINGS = {
+    "StubImage": StubImage,
+    "StubConstantImage": StubConstantImage,
+    "StubMask": StubMask
+}
diff --git a/tests/execution/base_dir/custom_nodes/stubs/stubs.py b/tests/execution/base_dir/custom_nodes/stubs/stubs.py
new file mode 100644
index 000000000..a1df87529
--- /dev/null
+++ b/tests/execution/base_dir/custom_nodes/stubs/stubs.py
@@ -0,0 +1,129 @@
+import torch
+
+class StubImage:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "content": (['WHITE', 'BLACK', 'NOISE'],),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stub_image"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_image(self, content, height, width, batch_size):
+        if content == "WHITE":
+            return (torch.ones(batch_size, height, width, 3),)
+        elif content == "BLACK":
+            return (torch.zeros(batch_size, height, width, 3),)
+        elif content == "NOISE":
+            return (torch.rand(batch_size, height, width, 3),)
+
+class StubConstantImage:
+    def __init__(self):
+        pass
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stub_constant_image"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_constant_image(self, value, height, width, batch_size):
+        return (torch.ones(batch_size, height, width, 3) * value,)
+
+class StubMask:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "stub_mask"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_mask(self, value, height, width, batch_size):
+        return (torch.ones(batch_size, height, width) * value,)
+
+class StubInt:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("INT",)
+    FUNCTION = "stub_int"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_int(self, value):
+        return (value,)
+
+class StubFloat:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
+            },
+        }
+
+    RETURN_TYPES = ("FLOAT",)
+    FUNCTION = "stub_float"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_float(self, value):
+        return (value,)
+
+TEST_STUB_NODE_CLASS_MAPPINGS = {
+    "StubImage": StubImage,
+    "StubConstantImage": StubConstantImage,
+    "StubMask": StubMask,
+    "StubInt": StubInt,
+    "StubFloat": StubFloat,
+}
+TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
+    "StubImage": "Stub Image",
+    "StubConstantImage": "Stub Constant Image",
+    "StubMask": "Stub Mask",
+    "StubInt": "Stub Int",
+    "StubFloat": "Stub Float",
+}
diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py
index aecbbb34d..ce77ff0bc 100644
--- a/tests/execution/test_async_nodes.py
+++ b/tests/execution/test_async_nodes.py
@@ -10,7 +10,7 @@ from comfy.client.embedded_comfy_client import Comfy
 from comfy.execution_context import context_add_custom_nodes
 from comfy.nodes.package_typing import ExportedNodes
 from comfy_execution.graph_utils import GraphBuilder
-from tests.execution.test_execution import run_warmup
+from .test_execution import run_warmup
 from .test_execution import ComfyClient, _ProgressHandler
 
 
@@ -23,7 +23,7 @@ class TestAsyncNodes:
         (100,),
     ])
     async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
-        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+        from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
         lru_size, = request.param
 
         configuration = default_configuration()
diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py
index 5ec426e6d..8c2e1725a 100644
--- a/tests/execution/test_execution.py
+++ b/tests/execution/test_execution.py
@@ -1,5 +1,7 @@
+import json
 import logging
 import time
+import urllib.request
 import uuid
 from typing import Dict, Optional
 
@@ -16,7 +18,7 @@ from comfy.distributed.server_stub import ServerStub
 from comfy.execution_context import context_add_custom_nodes
 from comfy.nodes.package_typing import ExportedNodes
 from comfy_execution.graph_utils import GraphBuilder, Node
-from tests.conftest import current_test_name
+from ..conftest import current_test_name
 
 
 async def run_warmup(client, prefix="warmup"):
@@ -123,7 +125,7 @@ class TestExecution:
         (100,),
     ])
     async def client(self, request) -> ComfyClient:
-        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+        from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
         lru_size, = request.param
 
         configuration = default_configuration()
diff --git a/tests/execution/test_progress_isolation.py b/tests/execution/test_progress_isolation.py
index 93dc0d41b..d09a9fe9f 100644
--- a/tests/execution/test_progress_isolation.py
+++ b/tests/execution/test_progress_isolation.py
@@ -1,14 +1,25 @@
 """Test that progress updates are properly isolated between WebSocket clients."""
 
 import json
+from io import BytesIO
+from multiprocessing import Process
+from pathlib import Path
+
 import pytest
 import time
 import threading
 import uuid
 import websocket
-from typing import List, Dict, Any
+import urllib.parse
+import urllib.request
+from typing import List, Dict, Any, Generator
+
+from PIL import Image
+
+from comfy.cli_args_types import Configuration
 from comfy_execution.graph_utils import GraphBuilder
-from tests.execution.test_execution import ComfyClient
+from .test_execution import ComfyClient, RunResult
+from ..conftest import comfy_background_server_from_config
 
 
 class ProgressTracker:
@@ -42,19 +53,100 @@ class ProgressTracker:
         return False
 
 
-class IsolatedClient(ComfyClient):
+class IsolatedClient():
     """Extended ComfyClient that tracks all WebSocket messages."""
 
     def __init__(self):
-        super().__init__()
+        self.test_name = ""
         self.progress_tracker = None
         self.all_messages: List[Dict[str, Any]] = []
 
+    def queue_prompt(self, prompt, partial_execution_targets=None):
+        p = {"prompt": prompt, "client_id": self.client_id}
+        if partial_execution_targets is not None:
+            p["partial_execution_targets"] = partial_execution_targets
+        data = json.dumps(p).encode('utf-8')
+        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
+        return json.loads(urllib.request.urlopen(req).read())
+
+    def get_image(self, filename, subfolder, folder_type):
+        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
+        url_values = urllib.parse.urlencode(data)
+        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
+            return response.read()
+
+    def get_history(self, prompt_id):
+        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
+            return json.loads(response.read())
+
+    def get_all_history(self, max_items=None, offset=None):
+        url = "http://{}/history".format(self.server_address)
+        params = {}
+        if max_items is not None:
+            params["max_items"] = max_items
+        if offset is not None:
+            params["offset"] = offset
+
+        if params:
+            url_values = urllib.parse.urlencode(params)
+            url = "{}?{}".format(url, url_values)
+
+        with urllib.request.urlopen(url) as response:
+            return json.loads(response.read())
+
+    def set_test_name(self, name):
+        self.test_name = name
+
+    def run(self, graph, partial_execution_targets=None):
+        prompt = graph.finalize()
+        for node in graph.nodes.values():
+            if node.class_type == 'SaveImage':
+                node.inputs['filename_prefix'] = self.test_name
+
+        prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
+        result = RunResult(prompt_id)
+        while True:
+            out = self.ws.recv()
+            if isinstance(out, str):
+                message = json.loads(out)
+                if message['type'] == 'executing':
+                    data = message['data']
+                    if data['prompt_id'] != prompt_id:
+                        continue
+                    if data['node'] is None:
+                        break
+                    result.runs[data['node']] = True
+                elif message['type'] == 'execution_error':
+                    raise Exception(message['data'])
+                elif message['type'] == 'execution_cached':
+                    if message['data']['prompt_id'] == prompt_id:
+                        cached_nodes = message['data'].get('nodes', [])
+                        for node_id in cached_nodes:
+                            result.cached[node_id] = True
+
+        history = self.get_history(prompt_id)[prompt_id]
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            result.outputs[node_id] = node_output
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    image_obj = Image.open(BytesIO(image_data))
+                    images_output.append(image_obj)
+                node_output['image_objects'] = images_output
+
+        return result
+
     def connect(self, listen='127.0.0.1', port=8188, client_id=None):
         """Connect with a specific client_id and set up message tracking."""
         if client_id is None:
             client_id = str(uuid.uuid4())
-        super().connect(listen, port, client_id)
+        self.client_id = client_id
+        self.server_address = f"{listen}:{port}"
+        ws = websocket.WebSocket()
+        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
+        self.ws = ws
         self.progress_tracker = ProgressTracker(client_id)
 
     def listen_for_messages(self, duration: float = 5.0):
@@ -83,21 +175,39 @@ class IsolatedClient(ComfyClient):
 class TestProgressIsolation:
     """Test suite for verifying progress update isolation between clients."""
 
+    @pytest.fixture(scope="class")
+    def args_pytest(self, tmp_path_factory):
+        yield {
+            "listen": "localhost",
+            "port": 19090,
+            "output_dir": tmp_path_factory.mktemp("comfy_background_server")
+        }
+
     @pytest.fixture(scope="class", autouse=True)
-    def _server(self, args_pytest):
-        """Start the ComfyUI server for testing."""
-        import subprocess
-        pargs = [
-            'python', 'main.py',
-            '--output-directory', args_pytest["output_dir"],
-            '--listen', args_pytest["listen"],
-            '--port', str(args_pytest["port"]),
-            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
-            '--cpu',
-        ]
-        p = subprocess.Popen(pargs)
-        yield
-        p.kill()
+    def _server(self, args_pytest, tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
+        tmp_path = tmp_path_factory.mktemp("comfy_background_server")
+        # Start server
+
+        configuration = Configuration()
+        configuration.listen = args_pytest["listen"]
+        configuration.port = args_pytest["port"]
+        configuration.cpu = True
+        configuration.output_directory = str(args_pytest["output_dir"])
+        configuration.input_directory = str(tmp_path)
+        from importlib.resources import files
+
+        extra_nodes_path = str(files(f"{__package__}.base_dir"))
+        extra_nodes = f"""
+testing_pack:
+  base_path: {extra_nodes_path}
+  custom_nodes: custom_nodes
+        """
+        yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
+        with open(yaml_path, mode="wt") as f:
+            f.write(extra_nodes)
+        configuration.extra_model_paths_config = [str(yaml_path)]
+
+        yield from comfy_background_server_from_config(configuration)
 
     def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
         """Start client with connection retries."""
@@ -111,10 +221,10 @@ class TestProgressIsolation:
                 return client
             except ConnectionRefusedError as e:
                 print(e) # noqa: T201
-                print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
+                print(f"({i + 1}/{n_tries}) Retrying...") # noqa: T201
         raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
 
-    def test_progress_isolation_between_clients(self, args_pytest):
+    async def test_progress_isolation_between_clients(self, args_pytest):
         """Test that progress updates are isolated between different clients."""
         listen = args_pytest["listen"]
         port = args_pytest["port"]
@@ -230,4 +340,3 @@
         finally:
             if hasattr(client, 'ws'):
                 client.ws.close()
-