mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fix custom model paths config paths, tweak tests
parent 4a3feee1a2 · commit 6af812f9a8
@@ -177,7 +177,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
        pass

    if args.extra_model_paths_config:
-        for config_path in itertools.chain(*args.extra_model_paths_config):
+        for config_path in args.extra_model_paths_config:
            load_extra_path_config(config_path)

    if args.create_directories:
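The dropped itertools.chain suggests --extra-model-paths-config is now collected as a flat list of paths rather than a list of lists. A minimal sketch of the two argparse collection styles this distinction hinges on (the option and attribute names mirror the diff; everything else is illustrative):

```python
import argparse
import itertools

# Style 1: action="append" + nargs="+" yields a list of lists,
# e.g. [["a.yaml", "b.yaml"], ["c.yaml"]], which needs itertools.chain to flatten.
nested = argparse.ArgumentParser()
nested.add_argument("--extra-model-paths-config", action="append", nargs="+")
ns = nested.parse_args(["--extra-model-paths-config", "a.yaml", "b.yaml",
                        "--extra-model-paths-config", "c.yaml"])
assert list(itertools.chain(*ns.extra_model_paths_config)) == ["a.yaml", "b.yaml", "c.yaml"]

# Style 2: plain action="append" yields a flat list, e.g. ["a.yaml", "c.yaml"],
# which matches the direct iteration introduced by this commit.
flat = argparse.ArgumentParser()
flat.add_argument("--extra-model-paths-config", action="append")
ns = flat.parse_args(["--extra-model-paths-config", "a.yaml",
                      "--extra-model-paths-config", "c.yaml"])
assert ns.extra_model_paths_config == ["a.yaml", "c.yaml"]
```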
@@ -173,7 +173,8 @@ def comfy_background_server_from_config(configuration):
@pytest.fixture(scope="session")
def skip_timing_checks(pytestconfig):
    """Fixture that returns whether timing checks should be skipped."""
-    return pytestconfig.getoption("--skip-timing-checks")
+    # todo: in the LTS, we don't need to skip timing checks, everything just works
+    return False


def pytest_collection_modifyitems(items):
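For context, timing-sensitive tests typically gate their elapsed-time assertions on this fixture; with it hard-wired to False, those checks always run. A hedged sketch of the consumption pattern (the workload and threshold are illustrative, not from the repository):

```python
import time

def test_timing_sensitive_step(skip_timing_checks):
    start = time.perf_counter()
    sum(range(10_000))  # stand-in for the cached re-run being timed
    elapsed = time.perf_counter() - start
    # The elapsed-time assertion is bypassed only when the fixture says so.
    assert skip_timing_checks or elapsed < 1.0
```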
0    tests/execution/__init__.py    Normal file
0    tests/execution/base_dir/__init__.py    Normal file
7    tests/execution/base_dir/custom_nodes/stubs/__init__.py    Normal file
@@ -0,0 +1,7 @@
from .stubs import StubImage, StubConstantImage, StubMask

NODE_CLASS_MAPPINGS = {
    "StubImage": StubImage,
    "StubConstantImage": StubConstantImage,
    "StubMask": StubMask
}
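NODE_CLASS_MAPPINGS is the conventional ComfyUI entry point for a custom-node package: the loader imports the package and merges this dictionary into its node registry. A rough sketch of that discovery step, assuming a simple loader (the function name and path handling are illustrative, not ComfyUI's actual loader):

```python
import importlib
import sys
from pathlib import Path

def collect_node_mappings(custom_nodes_dir: str) -> dict:
    """Import each package under custom_nodes_dir and merge its NODE_CLASS_MAPPINGS."""
    mappings = {}
    sys.path.insert(0, custom_nodes_dir)  # make the stub packages importable
    for pkg in sorted(Path(custom_nodes_dir).iterdir()):
        if not (pkg / "__init__.py").is_file():
            continue
        module = importlib.import_module(pkg.name)
        mappings.update(getattr(module, "NODE_CLASS_MAPPINGS", {}))
    return mappings
```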
129    tests/execution/base_dir/custom_nodes/stubs/stubs.py    Normal file
@@ -0,0 +1,129 @@
import torch

class StubImage:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "content": (['WHITE', 'BLACK', 'NOISE'],),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stub_image"

    CATEGORY = "Testing/Stub Nodes"

    def stub_image(self, content, height, width, batch_size):
        if content == "WHITE":
            return (torch.ones(batch_size, height, width, 3),)
        elif content == "BLACK":
            return (torch.zeros(batch_size, height, width, 3),)
        elif content == "NOISE":
            return (torch.rand(batch_size, height, width, 3),)


class StubConstantImage:
    def __init__(self):
        pass
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stub_constant_image"

    CATEGORY = "Testing/Stub Nodes"

    def stub_constant_image(self, value, height, width, batch_size):
        return (torch.ones(batch_size, height, width, 3) * value,)


class StubMask:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "stub_mask"

    CATEGORY = "Testing/Stub Nodes"

    def stub_mask(self, value, height, width, batch_size):
        return (torch.ones(batch_size, height, width) * value,)


class StubInt:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
            },
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "stub_int"

    CATEGORY = "Testing/Stub Nodes"

    def stub_int(self, value):
        return (value,)


class StubFloat:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "stub_float"

    CATEGORY = "Testing/Stub Nodes"

    def stub_float(self, value):
        return (value,)


TEST_STUB_NODE_CLASS_MAPPINGS = {
    "StubImage": StubImage,
    "StubConstantImage": StubConstantImage,
    "StubMask": StubMask,
    "StubInt": StubInt,
    "StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
    "StubImage": "Stub Image",
    "StubConstantImage": "Stub Constant Image",
    "StubMask": "Stub Mask",
    "StubInt": "Stub Int",
    "StubFloat": "Stub Float",
}
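These stubs give the execution tests deterministic, model-free sources. A hedged sketch of how a test might wire one into a workflow with GraphBuilder (the inputs are illustrative; the node()/finalize() API is assumed from its use elsewhere in these tests):

```python
from comfy_execution.graph_utils import GraphBuilder

g = GraphBuilder()
# A 64x64 noise image and a full mask, both produced without loading any model.
image = g.node("StubImage", content="NOISE", height=64, width=64, batch_size=1)
mask = g.node("StubMask", value=1.0, height=64, width=64, batch_size=1)
# Their outputs would normally feed a downstream sink node, omitted here.
prompt = g.finalize()  # the prompt dict a test client would submit to the server
```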
@@ -10,7 +10,7 @@ from comfy.client.embedded_comfy_client import Comfy
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from comfy_execution.graph_utils import GraphBuilder
-from tests.execution.test_execution import run_warmup
+from .test_execution import run_warmup
from .test_execution import ComfyClient, _ProgressHandler


@@ -23,7 +23,7 @@ class TestAsyncNodes:
        (100,),
    ])
    async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
-        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+        from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

        lru_size, = request.param
        configuration = default_configuration()
@@ -1,5 +1,7 @@
import json
import logging
import time
import urllib.request
import uuid
from typing import Dict, Optional

@@ -16,7 +18,7 @@ from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from comfy_execution.graph_utils import GraphBuilder, Node
-from tests.conftest import current_test_name
+from ..conftest import current_test_name


async def run_warmup(client, prefix="warmup"):
@@ -123,7 +125,7 @@ class TestExecution:
        (100,),
    ])
    async def client(self, request) -> ComfyClient:
-        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+        from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

        lru_size, = request.param
        configuration = default_configuration()
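These import rewrites pair with the empty __init__.py files added above: tests/execution (and tests/execution/base_dir) become importable packages, so the test modules can use package-relative imports instead of relying on the repository root being on sys.path. In shorthand (module paths are from this diff):

```python
# Inside tests/execution/test_execution.py after this commit:
from ..conftest import current_test_name   # tests/conftest.py, one package up
# Inside tests/execution/test_async_nodes.py:
from .test_execution import run_warmup     # sibling module in the same package
```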
@@ -1,14 +1,25 @@
"""Test that progress updates are properly isolated between WebSocket clients."""

import json
from io import BytesIO
from multiprocessing import Process
from pathlib import Path

import pytest
import time
import threading
import uuid
import websocket
-from typing import List, Dict, Any
import urllib.parse
import urllib.request
+from typing import List, Dict, Any, Generator

from PIL import Image

from comfy.cli_args_types import Configuration
from comfy_execution.graph_utils import GraphBuilder
-from tests.execution.test_execution import ComfyClient
+from .test_execution import ComfyClient, RunResult
from ..conftest import comfy_background_server_from_config


class ProgressTracker:
@@ -42,19 +53,100 @@ class ProgressTracker:
        return False


-class IsolatedClient(ComfyClient):
+class IsolatedClient():
    """Extended ComfyClient that tracks all WebSocket messages."""

    def __init__(self):
        super().__init__()
        self.test_name = ""
        self.progress_tracker = None
        self.all_messages: List[Dict[str, Any]] = []

    def queue_prompt(self, prompt, partial_execution_targets=None):
        p = {"prompt": prompt, "client_id": self.client_id}
        if partial_execution_targets is not None:
            p["partial_execution_targets"] = partial_execution_targets
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        return json.loads(urllib.request.urlopen(req).read())

    def get_image(self, filename, subfolder, folder_type):
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
            return response.read()

    def get_history(self, prompt_id):
        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.read())

    def get_all_history(self, max_items=None, offset=None):
        url = "http://{}/history".format(self.server_address)
        params = {}
        if max_items is not None:
            params["max_items"] = max_items
        if offset is not None:
            params["offset"] = offset

        if params:
            url_values = urllib.parse.urlencode(params)
            url = "{}?{}".format(url, url_values)

        with urllib.request.urlopen(url) as response:
            return json.loads(response.read())

    def set_test_name(self, name):
        self.test_name = name

    def run(self, graph, partial_execution_targets=None):
        prompt = graph.finalize()
        for node in graph.nodes.values():
            if node.class_type == 'SaveImage':
                node.inputs['filename_prefix'] = self.test_name

        prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
        result = RunResult(prompt_id)
        while True:
            out = self.ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data['prompt_id'] != prompt_id:
                        continue
                    if data['node'] is None:
                        break
                    result.runs[data['node']] = True
                elif message['type'] == 'execution_error':
                    raise Exception(message['data'])
                elif message['type'] == 'execution_cached':
                    if message['data']['prompt_id'] == prompt_id:
                        cached_nodes = message['data'].get('nodes', [])
                        for node_id in cached_nodes:
                            result.cached[node_id] = True

        history = self.get_history(prompt_id)[prompt_id]
        for node_id in history['outputs']:
            node_output = history['outputs'][node_id]
            result.outputs[node_id] = node_output
            images_output = []
            if 'images' in node_output:
                for image in node_output['images']:
                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
                    image_obj = Image.open(BytesIO(image_data))
                    images_output.append(image_obj)
                node_output['image_objects'] = images_output

        return result

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect with a specific client_id and set up message tracking."""
        if client_id is None:
            client_id = str(uuid.uuid4())
-        super().connect(listen, port, client_id)
        self.client_id = client_id
        self.server_address = f"{listen}:{port}"
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
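A hedged sketch of how a test drives this client end to end; the connect(), set_test_name(), and run() signatures are taken from the code above, while the server address and graph contents are illustrative:

```python
client = IsolatedClient()
client.connect(listen="localhost", port=19090)   # opens ws://<host>:<port>/ws?clientId=<uuid>
client.set_test_name("progress_isolation_demo")

g = GraphBuilder()
g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
result = client.run(g)   # POSTs to /prompt, then waits on the websocket until the prompt finishes
assert result.runs       # every executed node id is recorded in result.runs
```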
@@ -83,21 +175,39 @@ class IsolatedClient(ComfyClient):
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    @pytest.fixture(scope="class")
    def args_pytest(self, tmp_path_factory):
        yield {
            "listen": "localhost",
            "port": 19090,
            "output_dir": tmp_path_factory.mktemp("comfy_background_server")
        }

    @pytest.fixture(scope="class", autouse=True)
-    def _server(self, args_pytest):
-        """Start the ComfyUI server for testing."""
-        import subprocess
-        pargs = [
-            'python', 'main.py',
-            '--output-directory', args_pytest["output_dir"],
-            '--listen', args_pytest["listen"],
-            '--port', str(args_pytest["port"]),
-            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
-            '--cpu',
-        ]
-        p = subprocess.Popen(pargs)
-        yield
-        p.kill()
+    def _server(self, args_pytest, tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
+        tmp_path = tmp_path_factory.mktemp("comfy_background_server")
+        # Start server
+
+        configuration = Configuration()
+        configuration.listen = args_pytest["listen"]
+        configuration.port = args_pytest["port"]
+        configuration.cpu = True
+        configuration.output_directory = str(args_pytest["output_dir"])
+        configuration.input_directory = str(tmp_path)
+        from importlib.resources import files
+
+        extra_nodes_path = str(files(f"{__package__}.base_dir"))
+        extra_nodes = f"""
+testing_pack:
+  base_path: {extra_nodes_path}
+  custom_nodes: custom_nodes
+"""
+        yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
+        with open(yaml_path, mode="wt") as f:
+            f.write(extra_nodes)
+        configuration.extra_model_paths_config = [str(yaml_path)]
+
+        yield from comfy_background_server_from_config(configuration)

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Start client with connection retries."""
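The rewritten fixture generates its extra-model-paths file at runtime instead of pointing at a static tests/execution/extra_model_paths.yaml. With the package path resolved, the written extra_nodes.yaml should look roughly like this (the absolute base_path is illustrative):

```yaml
testing_pack:
  base_path: /path/to/checkout/tests/execution/base_dir   # resolved via importlib.resources
  custom_nodes: custom_nodes   # i.e. base_dir/custom_nodes, where the stubs package lives
```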
@@ -111,10 +221,10 @@ class TestProgressIsolation:
                return client
            except ConnectionRefusedError as e:
                print(e)  # noqa: T201
-                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
+                print(f"({i + 1}/{n_tries}) Retrying...")  # noqa: T201
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")

-    def test_progress_isolation_between_clients(self, args_pytest):
+    async def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]
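Turning the test into a coroutine presumably relies on pytest-asyncio (or an asyncio_mode = "auto" setting) so that pytest awaits it. If auto mode were not configured, the explicit marker form would look like this (shown for illustration only; this diff does not include the project's pytest configuration):

```python
import pytest

class TestProgressIsolation:
    @pytest.mark.asyncio  # not needed when asyncio_mode = "auto" is set in pytest config
    async def test_progress_isolation_between_clients(self, args_pytest):
        ...
```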
@@ -230,4 +340,3 @@ class TestProgressIsolation:
            finally:
                if hasattr(client, 'ws'):
                    client.ws.close()
-