Fix custom model paths config paths, tweak tests

This commit is contained in:
doctorpangloss 2025-09-23 11:01:48 -07:00
parent 4a3feee1a2
commit 6af812f9a8
9 changed files with 276 additions and 28 deletions

View File

@ -177,7 +177,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
for config_path in args.extra_model_paths_config:
load_extra_path_config(config_path)
if args.create_directories:

View File

@ -173,7 +173,8 @@ def comfy_background_server_from_config(configuration):
@pytest.fixture(scope="session")
def skip_timing_checks(pytestconfig):
"""Fixture that returns whether timing checks should be skipped."""
return pytestconfig.getoption("--skip-timing-checks")
# todo: in the LTS, we don't need to skip timing checks, everything just works
return False
def pytest_collection_modifyitems(items):

View File

View File

View File

@ -0,0 +1,7 @@
from .stubs import StubImage, StubConstantImage, StubMask
NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask
}

View File

@ -0,0 +1,129 @@
import torch
class StubImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"content": (['WHITE', 'BLACK', 'NOISE'],),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_image"
CATEGORY = "Testing/Stub Nodes"
def stub_image(self, content, height, width, batch_size):
if content == "WHITE":
return (torch.ones(batch_size, height, width, 3),)
elif content == "BLACK":
return (torch.zeros(batch_size, height, width, 3),)
elif content == "NOISE":
return (torch.rand(batch_size, height, width, 3),)
class StubConstantImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_constant_image"
CATEGORY = "Testing/Stub Nodes"
def stub_constant_image(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width, 3) * value,)
class StubMask:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "stub_mask"
CATEGORY = "Testing/Stub Nodes"
def stub_mask(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width) * value,)
class StubInt:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
},
}
RETURN_TYPES = ("INT",)
FUNCTION = "stub_int"
CATEGORY = "Testing/Stub Nodes"
def stub_int(self, value):
return (value,)
class StubFloat:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "stub_float"
CATEGORY = "Testing/Stub Nodes"
def stub_float(self, value):
return (value,)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
"StubConstantImage": "Stub Constant Image",
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
}

View File

@ -10,7 +10,7 @@ from comfy.client.embedded_comfy_client import Comfy
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import run_warmup
from .test_execution import run_warmup
from .test_execution import ComfyClient, _ProgressHandler
@ -23,7 +23,7 @@ class TestAsyncNodes:
(100,),
])
async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
lru_size, = request.param
configuration = default_configuration()

View File

@ -1,5 +1,7 @@
import json
import logging
import time
import urllib.request
import uuid
from typing import Dict, Optional
@ -16,7 +18,7 @@ from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from comfy_execution.graph_utils import GraphBuilder, Node
from tests.conftest import current_test_name
from ..conftest import current_test_name
async def run_warmup(client, prefix="warmup"):
@ -123,7 +125,7 @@ class TestExecution:
(100,),
])
async def client(self, request) -> ComfyClient:
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
lru_size, = request.param
configuration = default_configuration()

View File

@ -1,14 +1,25 @@
"""Test that progress updates are properly isolated between WebSocket clients."""
import json
from io import BytesIO
from multiprocessing import Process
from pathlib import Path
import pytest
import time
import threading
import uuid
import websocket
from typing import List, Dict, Any
import urllib.parse
import urllib.request
from typing import List, Dict, Any, Generator
from PIL import Image
from comfy.cli_args_types import Configuration
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient
from .test_execution import ComfyClient, RunResult
from ..conftest import comfy_background_server_from_config
class ProgressTracker:
@ -42,19 +53,100 @@ class ProgressTracker:
return False
class IsolatedClient(ComfyClient):
class IsolatedClient():
"""Extended ComfyClient that tracks all WebSocket messages."""
def __init__(self):
super().__init__()
self.test_name = ""
self.progress_tracker = None
self.all_messages: List[Dict[str, Any]] = []
def queue_prompt(self, prompt, partial_execution_targets=None):
p = {"prompt": prompt, "client_id": self.client_id}
if partial_execution_targets is not None:
p["partial_execution_targets"] = partial_execution_targets
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(self, filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
return response.read()
def get_history(self, prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
return json.loads(response.read())
def get_all_history(self, max_items=None, offset=None):
url = "http://{}/history".format(self.server_address)
params = {}
if max_items is not None:
params["max_items"] = max_items
if offset is not None:
params["offset"] = offset
if params:
url_values = urllib.parse.urlencode(params)
url = "{}?{}".format(url, url_values)
with urllib.request.urlopen(url) as response:
return json.loads(response.read())
def set_test_name(self, name):
self.test_name = name
def run(self, graph, partial_execution_targets=None):
prompt = graph.finalize()
for node in graph.nodes.values():
if node.class_type == 'SaveImage':
node.inputs['filename_prefix'] = self.test_name
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
result = RunResult(prompt_id)
while True:
out = self.ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['prompt_id'] != prompt_id:
continue
if data['node'] is None:
break
result.runs[data['node']] = True
elif message['type'] == 'execution_error':
raise Exception(message['data'])
elif message['type'] == 'execution_cached':
if message['data']['prompt_id'] == prompt_id:
cached_nodes = message['data'].get('nodes', [])
for node_id in cached_nodes:
result.cached[node_id] = True
history = self.get_history(prompt_id)[prompt_id]
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
result.outputs[node_id] = node_output
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
image_obj = Image.open(BytesIO(image_data))
images_output.append(image_obj)
node_output['image_objects'] = images_output
return result
def connect(self, listen='127.0.0.1', port=8188, client_id=None):
"""Connect with a specific client_id and set up message tracking."""
if client_id is None:
client_id = str(uuid.uuid4())
super().connect(listen, port, client_id)
self.client_id = client_id
self.server_address = f"{listen}:{port}"
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
self.ws = ws
self.progress_tracker = ProgressTracker(client_id)
def listen_for_messages(self, duration: float = 5.0):
@ -83,21 +175,39 @@ class IsolatedClient(ComfyClient):
class TestProgressIsolation:
"""Test suite for verifying progress update isolation between clients."""
@pytest.fixture(scope="class")
def args_pytest(self, tmp_path_factory):
yield {
"listen": "localhost",
"port": 19090,
"output_dir": tmp_path_factory.mktemp("comfy_background_server")
}
@pytest.fixture(scope="class", autouse=True)
def _server(self, args_pytest):
"""Start the ComfyUI server for testing."""
import subprocess
pargs = [
'python', 'main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
p = subprocess.Popen(pargs)
yield
p.kill()
def _server(self, args_pytest, tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
# Start server
configuration = Configuration()
configuration.listen = args_pytest["listen"]
configuration.port = args_pytest["port"]
configuration.cpu = True
configuration.output_directory = str(args_pytest["output_dir"])
configuration.input_directory = str(tmp_path)
from importlib.resources import files
extra_nodes_path = str(files(f"{__package__}.base_dir"))
extra_nodes = f"""
testing_pack:
base_path: {extra_nodes_path}
custom_nodes: custom_nodes
"""
yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
with open(yaml_path, mode="wt") as f:
f.write(extra_nodes)
configuration.extra_model_paths_config = [str(yaml_path)]
yield from comfy_background_server_from_config(configuration)
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
"""Start client with connection retries."""
@ -111,10 +221,10 @@ class TestProgressIsolation:
return client
except ConnectionRefusedError as e:
print(e) # noqa: T201
print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
print(f"({i + 1}/{n_tries}) Retrying...") # noqa: T201
raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
def test_progress_isolation_between_clients(self, args_pytest):
async def test_progress_isolation_between_clients(self, args_pytest):
"""Test that progress updates are isolated between different clients."""
listen = args_pytest["listen"]
port = args_pytest["port"]
@ -230,4 +340,3 @@ class TestProgressIsolation:
finally:
if hasattr(client, 'ws'):
client.ws.close()