diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index 683ccb612..cdf16e0f1 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -16,6 +16,7 @@ from ..api.components.schema.prompt_request import PromptRequest from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.schemas import immutabledict from ..component_model.file_output_path import file_output_path +from ..component_model.outputs_types import OutputsDict class AsyncRemoteComfyClient: @@ -122,9 +123,9 @@ class AsyncRemoteComfyClient: else: raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") - async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]: + async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict: """ - Uses the comfyui UI API calls to retrieve a list of paths of output files + Uses the comfyui UI API calls to retrieve the outputs dictionary :param prompt: :return: """ @@ -158,17 +159,4 @@ class AsyncRemoteComfyClient: # images have filename, subfolder, type keys # todo: use the OpenAPI spec for this when I get around to updating it - outputs_by_node_id = history_json[prompt_id].outputs - res: dict[str, List[Path]] = {} - for node_id, output in outputs_by_node_id.items(): - if 'images' in output: - images = [] - image_dicts: List[dict] = output['images'] - for image_file_output_dict in image_dicts: - image_file_output_dict = defaultdict(None, image_file_output_dict) - filename = image_file_output_dict['filename'] - subfolder = image_file_output_dict['subfolder'] - type = image_file_output_dict['type'] - images.append(Path(file_output_path(filename, subfolder=subfolder, type=type))) - res[node_id] = images - return res + return history_json[prompt_id].outputs diff --git a/tests/asyncio/test_asyncio_remote_client.py b/tests/asyncio/test_asyncio_remote_client.py index 5ba7684dd..4999dbaa9 100644 --- a/tests/asyncio/test_asyncio_remote_client.py +++ b/tests/asyncio/test_asyncio_remote_client.py @@ -24,7 +24,7 @@ async def test_completes_prompt_with_ui(comfy_background_server): random_seed = random.randint(1, 4294967295) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) result_dict = await client.queue_prompt_ui(prompt) - # should contain one output + # should contain one output assert len(result_dict) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 5686ab565..552813bea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,9 @@ import subprocess import sys import time import urllib -from typing import Tuple, List +from contextvars import ContextVar +from multiprocessing import Process +from typing import List, Any, Generator import pytest import requests @@ -24,12 +26,12 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1) def run_server(server_arguments: Configuration): - from comfy.cmd.main import main + from comfy.cmd.main import _start_comfyui from comfy.cli_args import args import asyncio for arg, value in server_arguments.items(): args[arg] = value - asyncio.run(main()) + asyncio.run(_start_comfyui()) @pytest.fixture(scope="function", autouse=False) @@ -134,9 +136,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers @pytest.fixture(scope="module", autouse=False) -def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiprocessing.Process]: +def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration, Process], Any, None]: tmp_path = tmp_path_factory.mktemp("comfy_background_server") - import torch # Start server configuration = Configuration() @@ -144,6 +145,10 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce configuration.output_directory = str(tmp_path) configuration.input_directory = str(tmp_path) + yield from comfy_background_server_from_config(configuration) + + +def comfy_background_server_from_config(configuration): server_process = multiprocessing.Process(target=run_server, args=(configuration,)) server_process.start() # wait for http url to be ready @@ -161,6 +166,7 @@ def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiproce raise Exception("Failed to start background server") yield configuration, server_process server_process.terminate() + import torch torch.cuda.empty_cache() @@ -232,3 +238,13 @@ def use_temporary_input_directory(tmp_path: pathlib.Path): folder_paths.set_input_directory(tmp_path) yield tmp_path folder_paths.set_input_directory(orig_dir) + + +current_test_name = ContextVar('current_test_name', default=None) + + +@pytest.fixture(autouse=True) +def set_test_name(request): + token = current_test_name.set(request.node.name) + yield + current_test_name.reset(token) diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py index b243bbca9..b64d8b7a6 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/inference/test_async_nodes.py @@ -1,59 +1,44 @@ -import pytest import time -import torch import urllib.error -import numpy as np -import subprocess +from typing import Any, AsyncGenerator +import numpy as np +import pytest +import torch from pytest import fixture + +from comfy.cli_args import default_configuration +from comfy.cli_args_types import Configuration +from comfy.client.embedded_comfy_client import Comfy +from comfy.execution_context import context_add_custom_nodes +from comfy.nodes.package_typing import ExportedNodes from comfy_execution.graph_utils import GraphBuilder from tests.inference.test_execution import ComfyClient +from .test_execution import ComfyClient, _ProgressHandler +from ..conftest import comfy_background_server_from_config @pytest.mark.execution class TestAsyncNodes: - @fixture(scope="class", autouse=True, params=[ - (False, 0), - (True, 0), - (True, 100), + # Initialize server and client + @fixture(scope="class", params=[ + # (lru_size) + (0,), + (100,), ]) - def _server(self, args_pytest, request): - pargs = [ - 'python','main.py', - '--output-directory', args_pytest["output_dir"], - '--listen', args_pytest["listen"], - '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', - ] - use_lru, lru_size = request.param - if use_lru: - pargs += ['--cache-lru', str(lru_size)] - # Running server with args: pargs - p = subprocess.Popen(pargs) - yield - p.kill() - torch.cuda.empty_cache() - - @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, _server): - client = ComfyClient() - n_tries = 5 - for i in range(n_tries): - time.sleep(4) - try: - client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) - except ConnectionRefusedError: - # Retrying... - pass - else: - break - yield client - del client - torch.cuda.empty_cache() + async def shared_client(self, request) -> AsyncGenerator[ComfyClient, Any]: + from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + lru_size, = request.param + configuration = default_configuration() + configuration.cache_lru = lru_size + progress_handler = _ProgressHandler() + with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)): + async with Comfy(configuration, progress_handler=progress_handler) as embedded_client: + yield ComfyClient(embedded_client, progress_handler) + @fixture - def client(self, shared_client, request): - shared_client.set_test_name(f"async_nodes[{request.node.name}]") + async def client(self, shared_client: ComfyClient, request, set_test_name): yield shared_client @fixture @@ -62,14 +47,14 @@ class TestAsyncNodes: # Happy Path Tests - def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): + async def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): """Test that a basic async node executes correctly.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) output = g.node("SaveImage", images=sleep_node.out(0)) - result = client.run(g) + result = await client.run(g) # Verify execution completed assert result.did_run(sleep_node), "Async sleep node should have executed" @@ -80,7 +65,7 @@ class TestAsyncNodes: assert len(result_images) == 1, "Should have 1 image" assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" - def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): + async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): """Test that multiple async nodes execute in parallel.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -96,7 +81,7 @@ class TestAsyncNodes: _output3 = g.node("PreviewImage", images=sleep3.out(0)) start_time = time.time() - result = client.run(g) + result = await client.run(g) elapsed_time = time.time() - start_time # Should take ~0.5s (max duration) not 1.2s (sum of durations) @@ -105,7 +90,7 @@ class TestAsyncNodes: # Verify all nodes executed assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) - def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with proper dependency handling.""" g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -119,7 +104,7 @@ class TestAsyncNodes: average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) output = g.node("SaveImage", images=average.out(0)) - result = client.run(g) + result = await client.run(g) # Verify execution order assert result.did_run(sleep1) and result.did_run(sleep2) @@ -130,7 +115,7 @@ class TestAsyncNodes: avg_value = np.array(result_images[0]).mean() assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" - def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): """Test async VALIDATE_INPUTS function.""" g = builder # Create a test node with async validation @@ -138,15 +123,15 @@ class TestAsyncNodes: g.node("SaveImage", images=validation_node.out(0)) # Should pass validation - result = client.run(g) + result = await client.run(g) assert result.did_run(validation_node) # Test validation failure validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold - with pytest.raises(urllib.error.HTTPError): - client.run(g) + with pytest.raises(ValueError): + await client.run(g) - def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with lazy evaluation.""" g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -162,7 +147,7 @@ class TestAsyncNodes: g.node("SaveImage", images=lazy_mix.out(0)) start_time = time.time() - result = client.run(g) + result = await client.run(g) elapsed_time = time.time() - start_time # Should only execute sleep1, not sleep2 @@ -170,22 +155,22 @@ class TestAsyncNodes: assert result.did_run(sleep1), "Sleep1 should have executed" assert not result.did_run(sleep2), "Sleep2 should have been skipped" - def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): """Test async check_lazy_status function.""" g = builder # Create a node with async check_lazy_status lazy_node = g.node("TestAsyncLazyCheck", - input1="value1", - input2="value2", - condition=True) + input1="value1", + input2="value2", + condition=True) g.node("SaveImage", images=lazy_node.out(0)) - result = client.run(g) + result = await client.run(g) assert result.did_run(lazy_node) # Error Handling Tests - def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): """Test that async execution errors are properly handled.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -194,25 +179,23 @@ class TestAsyncNodes: g.node("SaveImage", images=error_node.out(0)) try: - client.run(g) + await client.run(g) assert False, "Should have raised an error" except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" - def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): """Test async validation error handling.""" g = builder # Node with async validation that will fail validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) g.node("SaveImage", images=validation_node.out(0)) - with pytest.raises(urllib.error.HTTPError) as exc_info: - client.run(g) - # Verify it's a validation error - assert exc_info.value.code == 400 + with pytest.raises(ValueError) as exc_info: + await client.run(g) - def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): """Test handling of async operations that timeout.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -221,12 +204,12 @@ class TestAsyncNodes: g.node("SaveImage", images=timeout_node.out(0)) try: - client.run(g) + await client.run(g) assert False, "Should have raised a timeout error" except Exception as e: assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" - def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): + async def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): """Test that workflow can recover after async errors.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -236,7 +219,7 @@ class TestAsyncNodes: g.node("SaveImage", images=error_node.out(0)) try: - client.run(g) + await client.run(g) except Exception: pass # Expected @@ -246,10 +229,10 @@ class TestAsyncNodes: sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) g2.node("SaveImage", images=sleep_node.out(0)) - result = client.run(g2) + result = await client.run(g2) assert result.did_run(sleep_node), "Should be able to run after error" - def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): + async def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): """Test handling when sync node errors while async node is executing.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -265,7 +248,7 @@ class TestAsyncNodes: g.node("PreviewImage", images=error_node.out(0)) try: - client.run(g) + await client.run(g) assert False, "Should have raised an error" except Exception as e: # Verify the sync error was caught even though async was running @@ -273,7 +256,7 @@ class TestAsyncNodes: # Edge Cases - def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with execution blockers.""" g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -299,11 +282,11 @@ class TestAsyncNodes: output = g.node("PreviewImage", images=blocker.out(0)) - result = client.run(g) + result = await client.run(g) images = result.get_images(output) assert len(images) == 1, "Should have blocked second image" - def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): """Test that async nodes are properly cached.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -311,18 +294,18 @@ class TestAsyncNodes: g.node("SaveImage", images=sleep_node.out(0)) # First run - result1 = client.run(g) + result1 = await client.run(g) assert result1.did_run(sleep_node), "Should run first time" # Second run - should be cached start_time = time.time() - result2 = client.run(g) + result2 = await client.run(g) elapsed_time = time.time() - start_time assert not result2.did_run(sleep_node), "Should be cached" assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" - def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes within dynamically generated prompts.""" g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -330,21 +313,21 @@ class TestAsyncNodes: # Node that generates async nodes dynamically dynamic_async = g.node("TestDynamicAsyncGeneration", - image1=image1.out(0), - image2=image2.out(0), - num_async_nodes=3, - sleep_duration=0.2) + image1=image1.out(0), + image2=image2.out(0), + num_async_nodes=3, + sleep_duration=0.2) g.node("SaveImage", images=dynamic_async.out(0)) start_time = time.time() - result = client.run(g) + result = await client.run(g) elapsed_time = time.time() - start_time # Should execute async nodes in parallel within dynamic prompt assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" assert result.did_run(dynamic_async) - def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): """Test that async resources are properly cleaned up.""" g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -353,31 +336,31 @@ class TestAsyncNodes: resource_nodes = [] for i in range(5): node = g.node("TestAsyncResourceUser", - value=image.out(0), - resource_id=f"resource_{i}", - duration=0.1) + value=image.out(0), + resource_id=f"resource_{i}", + duration=0.1) resource_nodes.append(node) g.node("PreviewImage", images=node.out(0)) - result = client.run(g) + result = await client.run(g) # Verify all nodes executed for node in resource_nodes: assert result.did_run(node) # Run again to ensure resources were cleaned up - result2 = client.run(g) + result2 = await client.run(g) # Should be cached but not error due to resource conflicts for node in resource_nodes: assert not result2.did_run(node), "Should be cached" - def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): + async def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): """Test cancellation of async operations.""" # This would require implementing cancellation in the client # For now, we'll test that long-running async operations can be interrupted pass # TODO: Implement when cancellation API is available - def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): + async def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): """Test workflows with both sync and async nodes.""" g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -396,7 +379,7 @@ class TestAsyncNodes: output = g.node("SaveImage", images=async_op2.out(0)) - result = client.run(g) + result = await client.run(g) # Verify all nodes executed in correct order assert result.did_run(sync_op1) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index b3055f129..260b53a5a 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -1,7 +1,6 @@ import logging import time import uuid -from contextvars import ContextVar from typing import Dict, Optional import numpy @@ -17,15 +16,7 @@ from comfy.distributed.server_stub import ServerStub from comfy.execution_context import context_add_custom_nodes from comfy_execution.graph_utils import GraphBuilder, Node from comfy.nodes.package_typing import ExportedNodes - -current_test_name = ContextVar('current_test_name', default=None) - - -@pytest.fixture(autouse=True) -def set_test_name(request): - token = current_test_name.set(request.node.name) - yield - current_test_name.reset(token) +from tests.conftest import current_test_name class RunResult: diff --git a/tests/inference/testing_pack/flow_control.py b/tests/inference/testing_pack/flow_control.py index f5d495f38..ba943be60 100644 --- a/tests/inference/testing_pack/flow_control.py +++ b/tests/inference/testing_pack/flow_control.py @@ -1,5 +1,5 @@ from comfy_execution.graph_utils import GraphBuilder, is_link -from comfy.graph import ExecutionBlocker +from comfy_execution.graph import ExecutionBlocker from .tools import VariantSupport NUM_FLOW_SOCKETS = 5