mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Fix custom model paths config paths, tweak tests
This commit is contained in:
parent
4a3feee1a2
commit
6af812f9a8
@ -177,7 +177,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths_config:
|
||||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
for config_path in args.extra_model_paths_config:
|
||||||
load_extra_path_config(config_path)
|
load_extra_path_config(config_path)
|
||||||
|
|
||||||
if args.create_directories:
|
if args.create_directories:
|
||||||
|
|||||||
@ -173,7 +173,8 @@ def comfy_background_server_from_config(configuration):
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def skip_timing_checks(pytestconfig):
|
def skip_timing_checks(pytestconfig):
|
||||||
"""Fixture that returns whether timing checks should be skipped."""
|
"""Fixture that returns whether timing checks should be skipped."""
|
||||||
return pytestconfig.getoption("--skip-timing-checks")
|
# todo: in the LTS, we don't need to skip timing checks, everything just works
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def pytest_collection_modifyitems(items):
|
def pytest_collection_modifyitems(items):
|
||||||
|
|||||||
0
tests/execution/__init__.py
Normal file
0
tests/execution/__init__.py
Normal file
0
tests/execution/base_dir/__init__.py
Normal file
0
tests/execution/base_dir/__init__.py
Normal file
7
tests/execution/base_dir/custom_nodes/stubs/__init__.py
Normal file
7
tests/execution/base_dir/custom_nodes/stubs/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from .stubs import StubImage, StubConstantImage, StubMask
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"StubImage": StubImage,
|
||||||
|
"StubConstantImage": StubConstantImage,
|
||||||
|
"StubMask": StubMask
|
||||||
|
}
|
||||||
129
tests/execution/base_dir/custom_nodes/stubs/stubs.py
Normal file
129
tests/execution/base_dir/custom_nodes/stubs/stubs.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class StubImage:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"content": (['WHITE', 'BLACK', 'NOISE'],),
|
||||||
|
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "stub_image"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_image(self, content, height, width, batch_size):
|
||||||
|
if content == "WHITE":
|
||||||
|
return (torch.ones(batch_size, height, width, 3),)
|
||||||
|
elif content == "BLACK":
|
||||||
|
return (torch.zeros(batch_size, height, width, 3),)
|
||||||
|
elif content == "NOISE":
|
||||||
|
return (torch.rand(batch_size, height, width, 3),)
|
||||||
|
|
||||||
|
class StubConstantImage:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "stub_constant_image"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_constant_image(self, value, height, width, batch_size):
|
||||||
|
return (torch.ones(batch_size, height, width, 3) * value,)
|
||||||
|
|
||||||
|
class StubMask:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
FUNCTION = "stub_mask"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_mask(self, value, height, width, batch_size):
|
||||||
|
return (torch.ones(batch_size, height, width) * value,)
|
||||||
|
|
||||||
|
class StubInt:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("INT",)
|
||||||
|
FUNCTION = "stub_int"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_int(self, value):
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
class StubFloat:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("FLOAT",)
|
||||||
|
FUNCTION = "stub_float"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Stub Nodes"
|
||||||
|
|
||||||
|
def stub_float(self, value):
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||||
|
"StubImage": StubImage,
|
||||||
|
"StubConstantImage": StubConstantImage,
|
||||||
|
"StubMask": StubMask,
|
||||||
|
"StubInt": StubInt,
|
||||||
|
"StubFloat": StubFloat,
|
||||||
|
}
|
||||||
|
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"StubImage": "Stub Image",
|
||||||
|
"StubConstantImage": "Stub Constant Image",
|
||||||
|
"StubMask": "Stub Mask",
|
||||||
|
"StubInt": "Stub Int",
|
||||||
|
"StubFloat": "Stub Float",
|
||||||
|
}
|
||||||
@ -10,7 +10,7 @@ from comfy.client.embedded_comfy_client import Comfy
|
|||||||
from comfy.execution_context import context_add_custom_nodes
|
from comfy.execution_context import context_add_custom_nodes
|
||||||
from comfy.nodes.package_typing import ExportedNodes
|
from comfy.nodes.package_typing import ExportedNodes
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from tests.execution.test_execution import run_warmup
|
from .test_execution import run_warmup
|
||||||
from .test_execution import ComfyClient, _ProgressHandler
|
from .test_execution import ComfyClient, _ProgressHandler
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ class TestAsyncNodes:
|
|||||||
(100,),
|
(100,),
|
||||||
])
|
])
|
||||||
async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
|
async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
|
||||||
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
lru_size, = request.param
|
lru_size, = request.param
|
||||||
configuration = default_configuration()
|
configuration = default_configuration()
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import urllib.request
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
@ -16,7 +18,7 @@ from comfy.distributed.server_stub import ServerStub
|
|||||||
from comfy.execution_context import context_add_custom_nodes
|
from comfy.execution_context import context_add_custom_nodes
|
||||||
from comfy.nodes.package_typing import ExportedNodes
|
from comfy.nodes.package_typing import ExportedNodes
|
||||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||||
from tests.conftest import current_test_name
|
from ..conftest import current_test_name
|
||||||
|
|
||||||
|
|
||||||
async def run_warmup(client, prefix="warmup"):
|
async def run_warmup(client, prefix="warmup"):
|
||||||
@ -123,7 +125,7 @@ class TestExecution:
|
|||||||
(100,),
|
(100,),
|
||||||
])
|
])
|
||||||
async def client(self, request) -> ComfyClient:
|
async def client(self, request) -> ComfyClient:
|
||||||
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
lru_size, = request.param
|
lru_size, = request.param
|
||||||
configuration = default_configuration()
|
configuration = default_configuration()
|
||||||
|
|||||||
@ -1,14 +1,25 @@
|
|||||||
"""Test that progress updates are properly isolated between WebSocket clients."""
|
"""Test that progress updates are properly isolated between WebSocket clients."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from io import BytesIO
|
||||||
|
from multiprocessing import Process
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
import websocket
|
import websocket
|
||||||
from typing import List, Dict, Any
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
from typing import List, Dict, Any, Generator
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy.cli_args_types import Configuration
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from tests.execution.test_execution import ComfyClient
|
from .test_execution import ComfyClient, RunResult
|
||||||
|
from ..conftest import comfy_background_server_from_config
|
||||||
|
|
||||||
|
|
||||||
class ProgressTracker:
|
class ProgressTracker:
|
||||||
@ -42,19 +53,100 @@ class ProgressTracker:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class IsolatedClient(ComfyClient):
|
class IsolatedClient():
|
||||||
"""Extended ComfyClient that tracks all WebSocket messages."""
|
"""Extended ComfyClient that tracks all WebSocket messages."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
self.test_name = ""
|
||||||
self.progress_tracker = None
|
self.progress_tracker = None
|
||||||
self.all_messages: List[Dict[str, Any]] = []
|
self.all_messages: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt, partial_execution_targets=None):
|
||||||
|
p = {"prompt": prompt, "client_id": self.client_id}
|
||||||
|
if partial_execution_targets is not None:
|
||||||
|
p["partial_execution_targets"] = partial_execution_targets
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||||
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
|
|
||||||
|
def get_image(self, filename, subfolder, folder_type):
|
||||||
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||||
|
url_values = urllib.parse.urlencode(data)
|
||||||
|
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
|
||||||
|
return response.read()
|
||||||
|
|
||||||
|
def get_history(self, prompt_id):
|
||||||
|
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
|
||||||
|
return json.loads(response.read())
|
||||||
|
|
||||||
|
def get_all_history(self, max_items=None, offset=None):
|
||||||
|
url = "http://{}/history".format(self.server_address)
|
||||||
|
params = {}
|
||||||
|
if max_items is not None:
|
||||||
|
params["max_items"] = max_items
|
||||||
|
if offset is not None:
|
||||||
|
params["offset"] = offset
|
||||||
|
|
||||||
|
if params:
|
||||||
|
url_values = urllib.parse.urlencode(params)
|
||||||
|
url = "{}?{}".format(url, url_values)
|
||||||
|
|
||||||
|
with urllib.request.urlopen(url) as response:
|
||||||
|
return json.loads(response.read())
|
||||||
|
|
||||||
|
def set_test_name(self, name):
|
||||||
|
self.test_name = name
|
||||||
|
|
||||||
|
def run(self, graph, partial_execution_targets=None):
|
||||||
|
prompt = graph.finalize()
|
||||||
|
for node in graph.nodes.values():
|
||||||
|
if node.class_type == 'SaveImage':
|
||||||
|
node.inputs['filename_prefix'] = self.test_name
|
||||||
|
|
||||||
|
prompt_id = self.queue_prompt(prompt, partial_execution_targets)['prompt_id']
|
||||||
|
result = RunResult(prompt_id)
|
||||||
|
while True:
|
||||||
|
out = self.ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data['prompt_id'] != prompt_id:
|
||||||
|
continue
|
||||||
|
if data['node'] is None:
|
||||||
|
break
|
||||||
|
result.runs[data['node']] = True
|
||||||
|
elif message['type'] == 'execution_error':
|
||||||
|
raise Exception(message['data'])
|
||||||
|
elif message['type'] == 'execution_cached':
|
||||||
|
if message['data']['prompt_id'] == prompt_id:
|
||||||
|
cached_nodes = message['data'].get('nodes', [])
|
||||||
|
for node_id in cached_nodes:
|
||||||
|
result.cached[node_id] = True
|
||||||
|
|
||||||
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
|
for node_id in history['outputs']:
|
||||||
|
node_output = history['outputs'][node_id]
|
||||||
|
result.outputs[node_id] = node_output
|
||||||
|
images_output = []
|
||||||
|
if 'images' in node_output:
|
||||||
|
for image in node_output['images']:
|
||||||
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
|
image_obj = Image.open(BytesIO(image_data))
|
||||||
|
images_output.append(image_obj)
|
||||||
|
node_output['image_objects'] = images_output
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def connect(self, listen='127.0.0.1', port=8188, client_id=None):
|
def connect(self, listen='127.0.0.1', port=8188, client_id=None):
|
||||||
"""Connect with a specific client_id and set up message tracking."""
|
"""Connect with a specific client_id and set up message tracking."""
|
||||||
if client_id is None:
|
if client_id is None:
|
||||||
client_id = str(uuid.uuid4())
|
client_id = str(uuid.uuid4())
|
||||||
super().connect(listen, port, client_id)
|
self.client_id = client_id
|
||||||
|
self.server_address = f"{listen}:{port}"
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||||
|
self.ws = ws
|
||||||
self.progress_tracker = ProgressTracker(client_id)
|
self.progress_tracker = ProgressTracker(client_id)
|
||||||
|
|
||||||
def listen_for_messages(self, duration: float = 5.0):
|
def listen_for_messages(self, duration: float = 5.0):
|
||||||
@ -83,21 +175,39 @@ class IsolatedClient(ComfyClient):
|
|||||||
class TestProgressIsolation:
|
class TestProgressIsolation:
|
||||||
"""Test suite for verifying progress update isolation between clients."""
|
"""Test suite for verifying progress update isolation between clients."""
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def args_pytest(self, tmp_path_factory):
|
||||||
|
yield {
|
||||||
|
"listen": "localhost",
|
||||||
|
"port": 19090,
|
||||||
|
"output_dir": tmp_path_factory.mktemp("comfy_background_server")
|
||||||
|
}
|
||||||
|
|
||||||
@pytest.fixture(scope="class", autouse=True)
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
def _server(self, args_pytest):
|
def _server(self, args_pytest, tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
|
||||||
"""Start the ComfyUI server for testing."""
|
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
||||||
import subprocess
|
# Start server
|
||||||
pargs = [
|
|
||||||
'python', 'main.py',
|
configuration = Configuration()
|
||||||
'--output-directory', args_pytest["output_dir"],
|
configuration.listen = args_pytest["listen"]
|
||||||
'--listen', args_pytest["listen"],
|
configuration.port = args_pytest["port"]
|
||||||
'--port', str(args_pytest["port"]),
|
configuration.cpu = True
|
||||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
configuration.output_directory = str(args_pytest["output_dir"])
|
||||||
'--cpu',
|
configuration.input_directory = str(tmp_path)
|
||||||
]
|
from importlib.resources import files
|
||||||
p = subprocess.Popen(pargs)
|
|
||||||
yield
|
extra_nodes_path = str(files(f"{__package__}.base_dir"))
|
||||||
p.kill()
|
extra_nodes = f"""
|
||||||
|
testing_pack:
|
||||||
|
base_path: {extra_nodes_path}
|
||||||
|
custom_nodes: custom_nodes
|
||||||
|
"""
|
||||||
|
yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
|
||||||
|
with open(yaml_path, mode="wt") as f:
|
||||||
|
f.write(extra_nodes)
|
||||||
|
configuration.extra_model_paths_config = [str(yaml_path)]
|
||||||
|
|
||||||
|
yield from comfy_background_server_from_config(configuration)
|
||||||
|
|
||||||
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
|
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
|
||||||
"""Start client with connection retries."""
|
"""Start client with connection retries."""
|
||||||
@ -111,10 +221,10 @@ class TestProgressIsolation:
|
|||||||
return client
|
return client
|
||||||
except ConnectionRefusedError as e:
|
except ConnectionRefusedError as e:
|
||||||
print(e) # noqa: T201
|
print(e) # noqa: T201
|
||||||
print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
|
print(f"({i + 1}/{n_tries}) Retrying...") # noqa: T201
|
||||||
raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
|
raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
|
||||||
|
|
||||||
def test_progress_isolation_between_clients(self, args_pytest):
|
async def test_progress_isolation_between_clients(self, args_pytest):
|
||||||
"""Test that progress updates are isolated between different clients."""
|
"""Test that progress updates are isolated between different clients."""
|
||||||
listen = args_pytest["listen"]
|
listen = args_pytest["listen"]
|
||||||
port = args_pytest["port"]
|
port = args_pytest["port"]
|
||||||
@ -230,4 +340,3 @@ class TestProgressIsolation:
|
|||||||
finally:
|
finally:
|
||||||
if hasattr(client, 'ws'):
|
if hasattr(client, 'ws'):
|
||||||
client.ws.close()
|
client.ws.close()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user