Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 14:20:49 +08:00

commit a3ae6e74d2
parent 76b32c8e6e

fix tests
@@ -16,6 +16,7 @@ from ..api.components.schema.prompt_request import PromptRequest
 from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
 from ..api.schemas import immutabledict
 from ..component_model.file_output_path import file_output_path
+from ..component_model.outputs_types import OutputsDict


 class AsyncRemoteComfyClient:

@@ -122,9 +123,9 @@ class AsyncRemoteComfyClient:
         else:
             raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")

-    async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]:
+    async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
         """
-        Uses the comfyui UI API calls to retrieve a list of paths of output files
+        Uses the comfyui UI API calls to retrieve the outputs dictionary
         :param prompt:
         :return:
         """

@@ -158,17 +159,4 @@ class AsyncRemoteComfyClient:

         # images have filename, subfolder, type keys
         # todo: use the OpenAPI spec for this when I get around to updating it
-        outputs_by_node_id = history_json[prompt_id].outputs
-        res: dict[str, List[Path]] = {}
-        for node_id, output in outputs_by_node_id.items():
-            if 'images' in output:
-                images = []
-                image_dicts: List[dict] = output['images']
-                for image_file_output_dict in image_dicts:
-                    image_file_output_dict = defaultdict(None, image_file_output_dict)
-                    filename = image_file_output_dict['filename']
-                    subfolder = image_file_output_dict['subfolder']
-                    type = image_file_output_dict['type']
-                    images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
-                res[node_id] = images
-        return res
+        return history_json[prompt_id].outputs
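Note: `queue_prompt_ui` now returns the raw `OutputsDict`, keyed by node id, instead of pre-resolving image paths. A caller that still wants the old `dict[str, List[Path]]` shape can rebuild it from the logic this commit removed; a minimal sketch (the absolute import of `file_output_path` is an assumption, since the client module only shows the relative import):

    from collections import defaultdict
    from pathlib import Path
    from typing import List

    # assumed absolute location of the helper imported above as ..component_model.file_output_path
    from comfy.component_model.file_output_path import file_output_path


    def image_paths_from_outputs(outputs) -> dict[str, List[Path]]:
        """Rebuild the old dict-of-paths view from an OutputsDict."""
        res: dict[str, List[Path]] = {}
        for node_id, output in outputs.items():
            if 'images' in output:
                images = []
                for image_dict in output['images']:
                    # images have filename, subfolder, type keys
                    image_dict = defaultdict(None, image_dict)
                    images.append(Path(file_output_path(image_dict['filename'],
                                                        subfolder=image_dict['subfolder'],
                                                        type=image_dict['type'])))
                res[node_id] = images
        return res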
@@ -24,7 +24,7 @@ async def test_completes_prompt_with_ui(comfy_background_server):
     random_seed = random.randint(1, 4294967295)
     prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
     result_dict = await client.queue_prompt_ui(prompt)
     # should contain one output
     assert len(result_dict) == 1

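The UI test's assertion carries over unchanged: with a single image-saving output node in this workflow, both the old paths dict and the new `OutputsDict` contain exactly one node-id key, so `len(result_dict) == 1` still means one node produced outputs.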
@@ -6,7 +6,9 @@ import subprocess
 import sys
 import time
 import urllib
-from typing import Tuple, List
+from contextvars import ContextVar
+from multiprocessing import Process
+from typing import List, Any, Generator

 import pytest
 import requests

@@ -24,12 +26,12 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)


 def run_server(server_arguments: Configuration):
-    from comfy.cmd.main import main
+    from comfy.cmd.main import _start_comfyui
     from comfy.cli_args import args
     import asyncio
     for arg, value in server_arguments.items():
         args[arg] = value
-    asyncio.run(main())
+    asyncio.run(_start_comfyui())


 @pytest.fixture(scope="function", autouse=False)

@@ -134,9 +136,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers


 @pytest.fixture(scope="module", autouse=False)
-def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiprocessing.Process]:
+def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
     tmp_path = tmp_path_factory.mktemp("comfy_background_server")
-    import torch
     # Start server

     configuration = Configuration()

@@ -144,6 +145,10 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce
     configuration.output_directory = str(tmp_path)
     configuration.input_directory = str(tmp_path)

+    yield from comfy_background_server_from_config(configuration)
+
+
+def comfy_background_server_from_config(configuration):
     server_process = multiprocessing.Process(target=run_server, args=(configuration,))
     server_process.start()
     # wait for http url to be ready
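Splitting the fixture body into `comfy_background_server_from_config` lets other modules start a background server from their own `Configuration` while keeping startup, readiness polling, and teardown in one place (the async-nodes tests below import it from `..conftest`). A sketch of reusing it, with an illustrative fixture name:

    import pytest

    from comfy.cli_args_types import Configuration
    from tests.conftest import comfy_background_server_from_config


    @pytest.fixture(scope="module")
    def my_background_server(tmp_path_factory):  # hypothetical fixture
        configuration = Configuration()
        configuration.output_directory = str(tmp_path_factory.mktemp("outputs"))
        # the helper starts the process, polls readiness, yields, then terminates
        yield from comfy_background_server_from_config(configuration)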
@@ -161,6 +166,7 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce
         raise Exception("Failed to start background server")
     yield configuration, server_process
     server_process.terminate()
+    import torch
     torch.cuda.empty_cache()

@@ -232,3 +238,13 @@ def use_temporary_input_directory(tmp_path: pathlib.Path):
     folder_paths.set_input_directory(tmp_path)
     yield tmp_path
     folder_paths.set_input_directory(orig_dir)
+
+
+current_test_name = ContextVar('current_test_name', default=None)
+
+
+@pytest.fixture(autouse=True)
+def set_test_name(request):
+    token = current_test_name.set(request.node.name)
+    yield
+    current_test_name.reset(token)
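Because `set_test_name` is autouse, anything running inside a test can read the active test's name from the `ContextVar` without threading it through arguments; a minimal sketch:

    from tests.conftest import current_test_name


    def output_prefix() -> str:  # hypothetical helper
        # the ContextVar default is None outside of a test
        return current_test_name.get() or "unknown"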
@@ -1,59 +1,44 @@
-import pytest
 import time
-import torch
 import urllib.error
-import numpy as np
-import subprocess
+from typing import Any, AsyncGenerator

+import numpy as np
+import pytest
+import torch
 from pytest import fixture

+from comfy.cli_args import default_configuration
+from comfy.cli_args_types import Configuration
+from comfy.client.embedded_comfy_client import Comfy
+from comfy.execution_context import context_add_custom_nodes
+from comfy.nodes.package_typing import ExportedNodes
 from comfy_execution.graph_utils import GraphBuilder
 from tests.inference.test_execution import ComfyClient
+from .test_execution import ComfyClient, _ProgressHandler
+from ..conftest import comfy_background_server_from_config


 @pytest.mark.execution
 class TestAsyncNodes:
-    @fixture(scope="class", autouse=True, params=[
-        (False, 0),
-        (True, 0),
-        (True, 100),
+    # Initialize server and client
+    @fixture(scope="class", params=[
+        # (lru_size)
+        (0,),
+        (100,),
     ])
-    def _server(self, args_pytest, request):
-        pargs = [
-            'python', 'main.py',
-            '--output-directory', args_pytest["output_dir"],
-            '--listen', args_pytest["listen"],
-            '--port', str(args_pytest["port"]),
-            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
-        ]
-        use_lru, lru_size = request.param
-        if use_lru:
-            pargs += ['--cache-lru', str(lru_size)]
-        # Running server with args: pargs
-        p = subprocess.Popen(pargs)
-        yield
-        p.kill()
-        torch.cuda.empty_cache()
-
-    @fixture(scope="class", autouse=True)
-    def shared_client(self, args_pytest, _server):
-        client = ComfyClient()
-        n_tries = 5
-        for i in range(n_tries):
-            time.sleep(4)
-            try:
-                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
-            except ConnectionRefusedError:
-                # Retrying...
-                pass
-            else:
-                break
-        yield client
-        del client
-        torch.cuda.empty_cache()
+    async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
+        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+        lru_size, = request.param
+        configuration = default_configuration()
+        configuration.cache_lru = lru_size
+        progress_handler = _ProgressHandler()
+        with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
+            async with Comfy(configuration, progress_handler=progress_handler) as embedded_client:
+                yield ComfyClient(embedded_client, progress_handler)

     @fixture
-    def client(self, shared_client, request):
-        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
+    async def client(self, shared_client: ComfyClient, request, set_test_name):
         yield shared_client

     @fixture
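The subprocess-launched `_server` fixture is replaced by an in-process `Comfy` embedded client: the fixtures become async generators, the parametrization shrinks to just `lru_size` (the no-LRU case is `(0,)`), and per-test naming now comes from the autouse `set_test_name` fixture rather than `shared_client.set_test_name(...)`. Every test below then takes this shape; a sketch, assuming pytest is configured to collect `async def` tests (e.g. via pytest-asyncio or anyio):

    async def test_example_shape(client: ComfyClient, builder: GraphBuilder):  # hypothetical test
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        g.node("SaveImage", images=image.out(0))
        result = await client.run(g)  # in-process await, no HTTP round trip
        assert result.did_run(image)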
@@ -62,14 +47,14 @@ class TestAsyncNodes:

     # Happy Path Tests

-    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
         """Test that a basic async node executes correctly."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
         output = g.node("SaveImage", images=sleep_node.out(0))

-        result = client.run(g)
+        result = await client.run(g)

         # Verify execution completed
         assert result.did_run(sleep_node), "Async sleep node should have executed"

@@ -80,7 +65,7 @@ class TestAsyncNodes:
         assert len(result_images) == 1, "Should have 1 image"
         assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"

-    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
         """Test that multiple async nodes execute in parallel."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -96,7 +81,7 @@ class TestAsyncNodes:
         _output3 = g.node("PreviewImage", images=sleep3.out(0))

         start_time = time.time()
-        result = client.run(g)
+        result = await client.run(g)
         elapsed_time = time.time() - start_time

         # Should take ~0.5s (max duration) not 1.2s (sum of durations)

@@ -105,7 +90,7 @@ class TestAsyncNodes:
         # Verify all nodes executed
         assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

-    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes with proper dependency handling."""
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -119,7 +104,7 @@ class TestAsyncNodes:
         average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
         output = g.node("SaveImage", images=average.out(0))

-        result = client.run(g)
+        result = await client.run(g)

         # Verify execution order
         assert result.did_run(sleep1) and result.did_run(sleep2)

@@ -130,7 +115,7 @@ class TestAsyncNodes:
         avg_value = np.array(result_images[0]).mean()
         assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

-    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
         """Test async VALIDATE_INPUTS function."""
         g = builder
         # Create a test node with async validation

@@ -138,15 +123,15 @@ class TestAsyncNodes:
         g.node("SaveImage", images=validation_node.out(0))

         # Should pass validation
-        result = client.run(g)
+        result = await client.run(g)
         assert result.did_run(validation_node)

         # Test validation failure
         validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
-        with pytest.raises(urllib.error.HTTPError):
-            client.run(g)
+        with pytest.raises(ValueError):
+            await client.run(g)

-    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes with lazy evaluation."""
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
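With the embedded client, a failed validation surfaces as an ordinary Python `ValueError` raised into the test instead of an `urllib.error.HTTPError` returned by a separate server process; this is also why the `exc_info.value.code == 400` assertion disappears later in this file, since there is no HTTP status left to inspect.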
@@ -162,7 +147,7 @@ class TestAsyncNodes:
         g.node("SaveImage", images=lazy_mix.out(0))

         start_time = time.time()
-        result = client.run(g)
+        result = await client.run(g)
         elapsed_time = time.time() - start_time

         # Should only execute sleep1, not sleep2

@@ -170,22 +155,22 @@ class TestAsyncNodes:
         assert result.did_run(sleep1), "Sleep1 should have executed"
         assert not result.did_run(sleep2), "Sleep2 should have been skipped"

-    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
         """Test async check_lazy_status function."""
         g = builder
         # Create a node with async check_lazy_status
         lazy_node = g.node("TestAsyncLazyCheck",
                            input1="value1",
                            input2="value2",
                            condition=True)
         g.node("SaveImage", images=lazy_node.out(0))

-        result = client.run(g)
+        result = await client.run(g)
         assert result.did_run(lazy_node)

     # Error Handling Tests

-    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
         """Test that async execution errors are properly handled."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -194,25 +179,23 @@ class TestAsyncNodes:
         g.node("SaveImage", images=error_node.out(0))

         try:
-            client.run(g)
+            await client.run(g)
             assert False, "Should have raised an error"
         except Exception as e:
             assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
             assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"

-    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
         """Test async validation error handling."""
         g = builder
         # Node with async validation that will fail
         validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
         g.node("SaveImage", images=validation_node.out(0))

-        with pytest.raises(urllib.error.HTTPError) as exc_info:
-            client.run(g)
-        # Verify it's a validation error
-        assert exc_info.value.code == 400
+        with pytest.raises(ValueError) as exc_info:
+            await client.run(g)

-    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
         """Test handling of async operations that timeout."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -221,12 +204,12 @@ class TestAsyncNodes:
         g.node("SaveImage", images=timeout_node.out(0))

         try:
-            client.run(g)
+            await client.run(g)
             assert False, "Should have raised a timeout error"
         except Exception as e:
             assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"

-    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
         """Test that workflow can recover after async errors."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -236,7 +219,7 @@ class TestAsyncNodes:
         g.node("SaveImage", images=error_node.out(0))

         try:
-            client.run(g)
+            await client.run(g)
         except Exception:
             pass  # Expected
@@ -246,10 +229,10 @@ class TestAsyncNodes:
         sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
         g2.node("SaveImage", images=sleep_node.out(0))

-        result = client.run(g2)
+        result = await client.run(g2)
         assert result.did_run(sleep_node), "Should be able to run after error"

-    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
         """Test handling when sync node errors while async node is executing."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -265,7 +248,7 @@ class TestAsyncNodes:
         g.node("PreviewImage", images=error_node.out(0))

         try:
-            client.run(g)
+            await client.run(g)
             assert False, "Should have raised an error"
         except Exception as e:
             # Verify the sync error was caught even though async was running

@@ -273,7 +256,7 @@ class TestAsyncNodes:

     # Edge Cases

-    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes with execution blockers."""
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -299,11 +282,11 @@ class TestAsyncNodes:

         output = g.node("PreviewImage", images=blocker.out(0))

-        result = client.run(g)
+        result = await client.run(g)
         images = result.get_images(output)
         assert len(images) == 1, "Should have blocked second image"

-    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
         """Test that async nodes are properly cached."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -311,18 +294,18 @@ class TestAsyncNodes:
         g.node("SaveImage", images=sleep_node.out(0))

         # First run
-        result1 = client.run(g)
+        result1 = await client.run(g)
         assert result1.did_run(sleep_node), "Should run first time"

         # Second run - should be cached
         start_time = time.time()
-        result2 = client.run(g)
+        result2 = await client.run(g)
         elapsed_time = time.time() - start_time

         assert not result2.did_run(sleep_node), "Should be cached"
         assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

-    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
         """Test async nodes within dynamically generated prompts."""
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -330,21 +313,21 @@ class TestAsyncNodes:

         # Node that generates async nodes dynamically
         dynamic_async = g.node("TestDynamicAsyncGeneration",
                                image1=image1.out(0),
                                image2=image2.out(0),
                                num_async_nodes=3,
                                sleep_duration=0.2)
         g.node("SaveImage", images=dynamic_async.out(0))

         start_time = time.time()
-        result = client.run(g)
+        result = await client.run(g)
         elapsed_time = time.time() - start_time

         # Should execute async nodes in parallel within dynamic prompt
         assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
         assert result.did_run(dynamic_async)

-    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
         """Test that async resources are properly cleaned up."""
         g = builder
         image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -353,31 +336,31 @@ class TestAsyncNodes:
         resource_nodes = []
         for i in range(5):
             node = g.node("TestAsyncResourceUser",
                           value=image.out(0),
                           resource_id=f"resource_{i}",
                           duration=0.1)
             resource_nodes.append(node)
             g.node("PreviewImage", images=node.out(0))

-        result = client.run(g)
+        result = await client.run(g)

         # Verify all nodes executed
         for node in resource_nodes:
             assert result.did_run(node)

         # Run again to ensure resources were cleaned up
-        result2 = client.run(g)
+        result2 = await client.run(g)
         # Should be cached but not error due to resource conflicts
         for node in resource_nodes:
             assert not result2.did_run(node), "Should be cached"

-    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
         """Test cancellation of async operations."""
         # This would require implementing cancellation in the client
         # For now, we'll test that long-running async operations can be interrupted
         pass  # TODO: Implement when cancellation API is available

-    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
         """Test workflows with both sync and async nodes."""
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

@@ -396,7 +379,7 @@ class TestAsyncNodes:

         output = g.node("SaveImage", images=async_op2.out(0))

-        result = client.run(g)
+        result = await client.run(g)

         # Verify all nodes executed in correct order
         assert result.did_run(sync_op1)
@@ -1,7 +1,6 @@
 import logging
 import time
 import uuid
-from contextvars import ContextVar
 from typing import Dict, Optional

 import numpy

@@ -17,15 +16,7 @@ from comfy.distributed.server_stub import ServerStub
 from comfy.execution_context import context_add_custom_nodes
 from comfy_execution.graph_utils import GraphBuilder, Node
 from comfy.nodes.package_typing import ExportedNodes
+from tests.conftest import current_test_name

-current_test_name = ContextVar('current_test_name', default=None)
-
-
-@pytest.fixture(autouse=True)
-def set_test_name(request):
-    token = current_test_name.set(request.node.name)
-    yield
-    current_test_name.reset(token)
-

 class RunResult:
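Hoisting `current_test_name` into `tests/conftest.py` leaves a single shared `ContextVar`: had `test_execution.py` kept its own copy, the autouse fixture would set one variable while code importing the other would only ever see the default `None`.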
@@ -1,5 +1,5 @@
 from comfy_execution.graph_utils import GraphBuilder, is_link
-from comfy.graph import ExecutionBlocker
+from comfy_execution.graph import ExecutionBlocker
 from .tools import VariantSupport

 NUM_FLOW_SOCKETS = 5
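The import moves to `comfy_execution.graph`, the same package that already supplies `GraphBuilder` via `comfy_execution.graph_utils`. For reference, a hedged sketch of how these testing nodes typically use `ExecutionBlocker` (my reading of its semantics; not part of this diff):

    from comfy_execution.graph import ExecutionBlocker

    def gated(value, allow: bool):  # hypothetical helper
        # ExecutionBlocker(None) silently blocks downstream nodes on this link;
        # passing a string instead of None surfaces that message as an error
        return value if allow else ExecutionBlocker(None)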