fix tests

This commit is contained in:
doctorpangloss 2025-07-25 15:01:30 -07:00
parent 76b32c8e6e
commit a3ae6e74d2
6 changed files with 105 additions and 127 deletions

View File

@ -16,6 +16,7 @@ from ..api.components.schema.prompt_request import PromptRequest
from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
from ..api.schemas import immutabledict from ..api.schemas import immutabledict
from ..component_model.file_output_path import file_output_path from ..component_model.file_output_path import file_output_path
from ..component_model.outputs_types import OutputsDict
class AsyncRemoteComfyClient: class AsyncRemoteComfyClient:
@ -122,9 +123,9 @@ class AsyncRemoteComfyClient:
else: else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]: async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
""" """
Uses the comfyui UI API calls to retrieve a list of paths of output files Uses the comfyui UI API calls to retrieve the outputs dictionary
:param prompt: :param prompt:
:return: :return:
""" """
@ -158,17 +159,4 @@ class AsyncRemoteComfyClient:
# images have filename, subfolder, type keys # images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it # todo: use the OpenAPI spec for this when I get around to updating it
outputs_by_node_id = history_json[prompt_id].outputs return history_json[prompt_id].outputs
res: dict[str, List[Path]] = {}
for node_id, output in outputs_by_node_id.items():
if 'images' in output:
images = []
image_dicts: List[dict] = output['images']
for image_file_output_dict in image_dicts:
image_file_output_dict = defaultdict(None, image_file_output_dict)
filename = image_file_output_dict['filename']
subfolder = image_file_output_dict['subfolder']
type = image_file_output_dict['type']
images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
res[node_id] = images
return res

View File

@ -24,7 +24,7 @@ async def test_completes_prompt_with_ui(comfy_background_server):
random_seed = random.randint(1, 4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt) result_dict = await client.queue_prompt_ui(prompt)
# should contain one output # should contain one output
assert len(result_dict) == 1 assert len(result_dict) == 1

View File

@ -6,7 +6,9 @@ import subprocess
import sys import sys
import time import time
import urllib import urllib
from typing import Tuple, List from contextvars import ContextVar
from multiprocessing import Process
from typing import List, Any, Generator
import pytest import pytest
import requests import requests
@ -24,12 +26,12 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
def run_server(server_arguments: Configuration): def run_server(server_arguments: Configuration):
from comfy.cmd.main import main from comfy.cmd.main import _start_comfyui
from comfy.cli_args import args from comfy.cli_args import args
import asyncio import asyncio
for arg, value in server_arguments.items(): for arg, value in server_arguments.items():
args[arg] = value args[arg] = value
asyncio.run(main()) asyncio.run(_start_comfyui())
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
@ -134,9 +136,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
@pytest.fixture(scope="module", autouse=False) @pytest.fixture(scope="module", autouse=False)
def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiprocessing.Process]: def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]:
tmp_path = tmp_path_factory.mktemp("comfy_background_server") tmp_path = tmp_path_factory.mktemp("comfy_background_server")
import torch
# Start server # Start server
configuration = Configuration() configuration = Configuration()
@ -144,6 +145,10 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce
configuration.output_directory = str(tmp_path) configuration.output_directory = str(tmp_path)
configuration.input_directory = str(tmp_path) configuration.input_directory = str(tmp_path)
yield from comfy_background_server_from_config(configuration)
def comfy_background_server_from_config(configuration):
server_process = multiprocessing.Process(target=run_server, args=(configuration,)) server_process = multiprocessing.Process(target=run_server, args=(configuration,))
server_process.start() server_process.start()
# wait for http url to be ready # wait for http url to be ready
@ -161,6 +166,7 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce
raise Exception("Failed to start background server") raise Exception("Failed to start background server")
yield configuration, server_process yield configuration, server_process
server_process.terminate() server_process.terminate()
import torch
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -232,3 +238,13 @@ def use_temporary_input_directory(tmp_path: pathlib.Path):
folder_paths.set_input_directory(tmp_path) folder_paths.set_input_directory(tmp_path)
yield tmp_path yield tmp_path
folder_paths.set_input_directory(orig_dir) folder_paths.set_input_directory(orig_dir)
current_test_name = ContextVar('current_test_name', default=None)
@pytest.fixture(autouse=True)
def set_test_name(request):
token = current_test_name.set(request.node.name)
yield
current_test_name.reset(token)

View File

@ -1,59 +1,44 @@
import pytest
import time import time
import torch
import urllib.error import urllib.error
import numpy as np from typing import Any, AsyncGenerator
import subprocess
import numpy as np
import pytest
import torch
from pytest import fixture from pytest import fixture
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from comfy_execution.graph_utils import GraphBuilder from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient from tests.inference.test_execution import ComfyClient
from .test_execution import ComfyClient, _ProgressHandler
from ..conftest import comfy_background_server_from_config
@pytest.mark.execution @pytest.mark.execution
class TestAsyncNodes: class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[ # Initialize server and client
(False, 0), @fixture(scope="class", params=[
(True, 0), # (lru_size)
(True, 100), (0,),
(100,),
]) ])
def _server(self, args_pytest, request): async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]:
pargs = [ from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
'python','main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
]
use_lru, lru_size = request.param
if use_lru:
pargs += ['--cache-lru', str(lru_size)]
# Running server with args: pargs
p = subprocess.Popen(pargs)
yield
p.kill()
torch.cuda.empty_cache()
@fixture(scope="class", autouse=True) lru_size, = request.param
def shared_client(self, args_pytest, _server): configuration = default_configuration()
client = ComfyClient() configuration.cache_lru = lru_size
n_tries = 5 progress_handler = _ProgressHandler()
for i in range(n_tries): with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
time.sleep(4) async with Comfy(configuration, progress_handler=progress_handler) as embedded_client:
try: yield ComfyClient(embedded_client, progress_handler)
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
except ConnectionRefusedError:
# Retrying...
pass
else:
break
yield client
del client
torch.cuda.empty_cache()
@fixture @fixture
def client(self, shared_client, request): async def client(self, shared_client: ComfyClient, request, set_test_name):
shared_client.set_test_name(f"async_nodes[{request.node.name}]")
yield shared_client yield shared_client
@fixture @fixture
@ -62,14 +47,14 @@ class TestAsyncNodes:
# Happy Path Tests # Happy Path Tests
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): async def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that a basic async node executes correctly.""" """Test that a basic async node executes correctly."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
output = g.node("SaveImage", images=sleep_node.out(0)) output = g.node("SaveImage", images=sleep_node.out(0))
result = client.run(g) result = await client.run(g)
# Verify execution completed # Verify execution completed
assert result.did_run(sleep_node), "Async sleep node should have executed" assert result.did_run(sleep_node), "Async sleep node should have executed"
@ -80,7 +65,7 @@ class TestAsyncNodes:
assert len(result_images) == 1, "Should have 1 image" assert len(result_images) == 1, "Should have 1 image"
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel.""" """Test that multiple async nodes execute in parallel."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -96,7 +81,7 @@ class TestAsyncNodes:
_output3 = g.node("PreviewImage", images=sleep3.out(0)) _output3 = g.node("PreviewImage", images=sleep3.out(0))
start_time = time.time() start_time = time.time()
result = client.run(g) result = await client.run(g)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Should take ~0.5s (max duration) not 1.2s (sum of durations) # Should take ~0.5s (max duration) not 1.2s (sum of durations)
@ -105,7 +90,7 @@ class TestAsyncNodes:
# Verify all nodes executed # Verify all nodes executed
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): async def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with proper dependency handling.""" """Test async nodes with proper dependency handling."""
g = builder g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -119,7 +104,7 @@ class TestAsyncNodes:
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
output = g.node("SaveImage", images=average.out(0)) output = g.node("SaveImage", images=average.out(0))
result = client.run(g) result = await client.run(g)
# Verify execution order # Verify execution order
assert result.did_run(sleep1) and result.did_run(sleep2) assert result.did_run(sleep1) and result.did_run(sleep2)
@ -130,7 +115,7 @@ class TestAsyncNodes:
avg_value = np.array(result_images[0]).mean() avg_value = np.array(result_images[0]).mean()
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): async def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
"""Test async VALIDATE_INPUTS function.""" """Test async VALIDATE_INPUTS function."""
g = builder g = builder
# Create a test node with async validation # Create a test node with async validation
@ -138,15 +123,15 @@ class TestAsyncNodes:
g.node("SaveImage", images=validation_node.out(0)) g.node("SaveImage", images=validation_node.out(0))
# Should pass validation # Should pass validation
result = client.run(g) result = await client.run(g)
assert result.did_run(validation_node) assert result.did_run(validation_node)
# Test validation failure # Test validation failure
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold
with pytest.raises(urllib.error.HTTPError): with pytest.raises(ValueError):
client.run(g) await client.run(g)
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation.""" """Test async nodes with lazy evaluation."""
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -162,7 +147,7 @@ class TestAsyncNodes:
g.node("SaveImage", images=lazy_mix.out(0)) g.node("SaveImage", images=lazy_mix.out(0))
start_time = time.time() start_time = time.time()
result = client.run(g) result = await client.run(g)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Should only execute sleep1, not sleep2 # Should only execute sleep1, not sleep2
@ -170,22 +155,22 @@ class TestAsyncNodes:
assert result.did_run(sleep1), "Sleep1 should have executed" assert result.did_run(sleep1), "Sleep1 should have executed"
assert not result.did_run(sleep2), "Sleep2 should have been skipped" assert not result.did_run(sleep2), "Sleep2 should have been skipped"
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): async def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
"""Test async check_lazy_status function.""" """Test async check_lazy_status function."""
g = builder g = builder
# Create a node with async check_lazy_status # Create a node with async check_lazy_status
lazy_node = g.node("TestAsyncLazyCheck", lazy_node = g.node("TestAsyncLazyCheck",
input1="value1", input1="value1",
input2="value2", input2="value2",
condition=True) condition=True)
g.node("SaveImage", images=lazy_node.out(0)) g.node("SaveImage", images=lazy_node.out(0))
result = client.run(g) result = await client.run(g)
assert result.did_run(lazy_node) assert result.did_run(lazy_node)
# Error Handling Tests # Error Handling Tests
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): async def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async execution errors are properly handled.""" """Test that async execution errors are properly handled."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -194,25 +179,23 @@ class TestAsyncNodes:
g.node("SaveImage", images=error_node.out(0)) g.node("SaveImage", images=error_node.out(0))
try: try:
client.run(g) await client.run(g)
assert False, "Should have raised an error" assert False, "Should have raised an error"
except Exception as e: except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): async def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
"""Test async validation error handling.""" """Test async validation error handling."""
g = builder g = builder
# Node with async validation that will fail # Node with async validation that will fail
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
g.node("SaveImage", images=validation_node.out(0)) g.node("SaveImage", images=validation_node.out(0))
with pytest.raises(urllib.error.HTTPError) as exc_info: with pytest.raises(ValueError) as exc_info:
client.run(g) await client.run(g)
# Verify it's a validation error
assert exc_info.value.code == 400
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): async def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling of async operations that timeout.""" """Test handling of async operations that timeout."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -221,12 +204,12 @@ class TestAsyncNodes:
g.node("SaveImage", images=timeout_node.out(0)) g.node("SaveImage", images=timeout_node.out(0))
try: try:
client.run(g) await client.run(g)
assert False, "Should have raised a timeout error" assert False, "Should have raised a timeout error"
except Exception as e: except Exception as e:
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): async def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
"""Test that workflow can recover after async errors.""" """Test that workflow can recover after async errors."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -236,7 +219,7 @@ class TestAsyncNodes:
g.node("SaveImage", images=error_node.out(0)) g.node("SaveImage", images=error_node.out(0))
try: try:
client.run(g) await client.run(g)
except Exception: except Exception:
pass # Expected pass # Expected
@ -246,10 +229,10 @@ class TestAsyncNodes:
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
g2.node("SaveImage", images=sleep_node.out(0)) g2.node("SaveImage", images=sleep_node.out(0))
result = client.run(g2) result = await client.run(g2)
assert result.did_run(sleep_node), "Should be able to run after error" assert result.did_run(sleep_node), "Should be able to run after error"
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): async def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test handling when sync node errors while async node is executing.""" """Test handling when sync node errors while async node is executing."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -265,7 +248,7 @@ class TestAsyncNodes:
g.node("PreviewImage", images=error_node.out(0)) g.node("PreviewImage", images=error_node.out(0))
try: try:
client.run(g) await client.run(g)
assert False, "Should have raised an error" assert False, "Should have raised an error"
except Exception as e: except Exception as e:
# Verify the sync error was caught even though async was running # Verify the sync error was caught even though async was running
@ -273,7 +256,7 @@ class TestAsyncNodes:
# Edge Cases # Edge Cases
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): async def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with execution blockers.""" """Test async nodes with execution blockers."""
g = builder g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -299,11 +282,11 @@ class TestAsyncNodes:
output = g.node("PreviewImage", images=blocker.out(0)) output = g.node("PreviewImage", images=blocker.out(0))
result = client.run(g) result = await client.run(g)
images = result.get_images(output) images = result.get_images(output)
assert len(images) == 1, "Should have blocked second image" assert len(images) == 1, "Should have blocked second image"
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached.""" """Test that async nodes are properly cached."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -311,18 +294,18 @@ class TestAsyncNodes:
g.node("SaveImage", images=sleep_node.out(0)) g.node("SaveImage", images=sleep_node.out(0))
# First run # First run
result1 = client.run(g) result1 = await client.run(g)
assert result1.did_run(sleep_node), "Should run first time" assert result1.did_run(sleep_node), "Should run first time"
# Second run - should be cached # Second run - should be cached
start_time = time.time() start_time = time.time()
result2 = client.run(g) result2 = await client.run(g)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
assert not result2.did_run(sleep_node), "Should be cached" assert not result2.did_run(sleep_node), "Should be cached"
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts.""" """Test async nodes within dynamically generated prompts."""
g = builder g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -330,21 +313,21 @@ class TestAsyncNodes:
# Node that generates async nodes dynamically # Node that generates async nodes dynamically
dynamic_async = g.node("TestDynamicAsyncGeneration", dynamic_async = g.node("TestDynamicAsyncGeneration",
image1=image1.out(0), image1=image1.out(0),
image2=image2.out(0), image2=image2.out(0),
num_async_nodes=3, num_async_nodes=3,
sleep_duration=0.2) sleep_duration=0.2)
g.node("SaveImage", images=dynamic_async.out(0)) g.node("SaveImage", images=dynamic_async.out(0))
start_time = time.time() start_time = time.time()
result = client.run(g) result = await client.run(g)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# Should execute async nodes in parallel within dynamic prompt # Should execute async nodes in parallel within dynamic prompt
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
assert result.did_run(dynamic_async) assert result.did_run(dynamic_async)
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): async def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async resources are properly cleaned up.""" """Test that async resources are properly cleaned up."""
g = builder g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -353,31 +336,31 @@ class TestAsyncNodes:
resource_nodes = [] resource_nodes = []
for i in range(5): for i in range(5):
node = g.node("TestAsyncResourceUser", node = g.node("TestAsyncResourceUser",
value=image.out(0), value=image.out(0),
resource_id=f"resource_{i}", resource_id=f"resource_{i}",
duration=0.1) duration=0.1)
resource_nodes.append(node) resource_nodes.append(node)
g.node("PreviewImage", images=node.out(0)) g.node("PreviewImage", images=node.out(0))
result = client.run(g) result = await client.run(g)
# Verify all nodes executed # Verify all nodes executed
for node in resource_nodes: for node in resource_nodes:
assert result.did_run(node) assert result.did_run(node)
# Run again to ensure resources were cleaned up # Run again to ensure resources were cleaned up
result2 = client.run(g) result2 = await client.run(g)
# Should be cached but not error due to resource conflicts # Should be cached but not error due to resource conflicts
for node in resource_nodes: for node in resource_nodes:
assert not result2.did_run(node), "Should be cached" assert not result2.did_run(node), "Should be cached"
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): async def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
"""Test cancellation of async operations.""" """Test cancellation of async operations."""
# This would require implementing cancellation in the client # This would require implementing cancellation in the client
# For now, we'll test that long-running async operations can be interrupted # For now, we'll test that long-running async operations can be interrupted
pass # TODO: Implement when cancellation API is available pass # TODO: Implement when cancellation API is available
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): async def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test workflows with both sync and async nodes.""" """Test workflows with both sync and async nodes."""
g = builder g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -396,7 +379,7 @@ class TestAsyncNodes:
output = g.node("SaveImage", images=async_op2.out(0)) output = g.node("SaveImage", images=async_op2.out(0))
result = client.run(g) result = await client.run(g)
# Verify all nodes executed in correct order # Verify all nodes executed in correct order
assert result.did_run(sync_op1) assert result.did_run(sync_op1)

View File

@ -1,7 +1,6 @@
import logging import logging
import time import time
import uuid import uuid
from contextvars import ContextVar
from typing import Dict, Optional from typing import Dict, Optional
import numpy import numpy
@ -17,15 +16,7 @@ from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes from comfy.execution_context import context_add_custom_nodes
from comfy_execution.graph_utils import GraphBuilder, Node from comfy_execution.graph_utils import GraphBuilder, Node
from comfy.nodes.package_typing import ExportedNodes from comfy.nodes.package_typing import ExportedNodes
from tests.conftest import current_test_name
current_test_name = ContextVar('current_test_name', default=None)
@pytest.fixture(autouse=True)
def set_test_name(request):
token = current_test_name.set(request.node.name)
yield
current_test_name.reset(token)
class RunResult: class RunResult:

View File

@ -1,5 +1,5 @@
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker from comfy_execution.graph import ExecutionBlocker
from .tools import VariantSupport from .tools import VariantSupport
NUM_FLOW_SOCKETS = 5 NUM_FLOW_SOCKETS = 5