From 1d29c97266faff4b60b6cea6fec9a72373534748 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 17 Jun 2025 18:38:42 -0700 Subject: [PATCH] fix tests --- comfy/app/user_manager.py | 3 - comfy/client/aio_client.py | 30 ++++--- comfy/cmd/main_pre.py | 15 +--- comfy/component_model/queue_types.py | 42 ++++++---- .../distributed/distributed_prompt_worker.py | 9 +- tests/asyncio/test_asyncio_remote_client.py | 26 +++--- tests/conftest.py | 22 +++-- tests/distributed/test_distributed_queue.py | 82 +++++++++---------- 8 files changed, 123 insertions(+), 106 deletions(-) diff --git a/comfy/app/user_manager.py b/comfy/app/user_manager.py index d1b8c5e91..5fdaf0bbd 100644 --- a/comfy/app/user_manager.py +++ b/comfy/app/user_manager.py @@ -40,9 +40,6 @@ class UserManager(): self.settings = AppSettings(self) if not os.path.exists(user_directory): os.makedirs(user_directory, exist_ok=True) - if not args.multi_user: - logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******") - logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") if args.multi_user: if os.path.isfile(self.get_users_file()): diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index 2e3f33c88..683ccb612 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -1,14 +1,14 @@ -import asyncio -import uuid from asyncio import AbstractEventLoop from collections import defaultdict + +import aiohttp +import asyncio +import uuid +from aiohttp import WSMessage, ClientResponse, ClientTimeout from pathlib import Path from typing import Optional, List from urllib.parse import urlparse, urljoin -import aiohttp -from aiohttp import WSMessage, ClientResponse, ClientTimeout - from .client_types import V1QueuePromptResponse from ..api.api_client import JSONEncoder from ..api.components.schema.prompt import PromptDict @@ -33,17 +33,25 @@ class AsyncRemoteComfyClient: f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}") self.loop = loop or asyncio.get_event_loop() self._session: aiohttp.ClientSession | None = None - try: - if asyncio.get_event_loop() is not None: - self._ensure_session() - except RuntimeError as no_running_event_loop: - pass def _ensure_session(self) -> aiohttp.ClientSession: - if self._session is None: + if self._session is None or self._session.closed: self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0)) return self._session + async def __aenter__(self): + """Allows the client to be used in an 'async with' block.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Closes the session when exiting an 'async with' block.""" + await self.close() + + async def close(self): + """Closes the underlying aiohttp.ClientSession.""" + if self._session and not self._session.closed: + await self._session.close() + @property def session(self) -> aiohttp.ClientSession: return self._ensure_session() diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 5d7861bbf..3bb4f713d 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -40,17 +40,6 @@ from ..tracing_compatibility import ProgressSpanSampler from ..tracing_compatibility import patch_spanbuilder_set_channel from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor -# Manually call the _init_dll_path method to ensure that the system path is searched for FFMPEG. -# Calling torchaudio._extension.utils._init_dll_path does not work because it is initializing the torchadio module prematurely or something. -# See: https://github.com/pytorch/audio/issues/3789 -if sys.platform == "win32": - for path in os.environ.get("PATH", "").split(os.pathsep): - if os.path.exists(path): - try: - os.add_dll_directory(path) - except Exception: - pass - this_logger = logging.getLogger(__name__) options.enable_args_parsing() @@ -66,6 +55,10 @@ warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplic warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning) warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support") warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.") +log_msg_to_filter = "NOTE: Redirects are currently not supported in Windows or MacOs." +logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter( + lambda record: log_msg_to_filter not in record.getMessage() +) from ..cli_args import args diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py index 217df738c..640a62a47 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass from enum import Enum from typing import NamedTuple, Optional, List, Literal, Sequence from typing import Tuple - from typing_extensions import NotRequired, TypedDict from .outputs_types import OutputsDict @@ -71,15 +69,26 @@ class ExtraData(TypedDict): token: NotRequired[str] -@dataclass -class NamedQueueTuple: +class NamedQueueTuple(dict): """ A wrapper class for a queue tuple, the object that is given to executors. Attributes: queue_tuple (QueueTuple): the corresponding queued workflow and other related data """ - queue_tuple: QueueTuple + __slots__ = ('queue_tuple',) + + def __init__(self, queue_tuple: QueueTuple): + # Initialize the dictionary superclass with the data we want to serialize. + super().__init__( + priority=queue_tuple[0], + prompt_id=queue_tuple[1], + prompt=queue_tuple[2], + extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None, + good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None + ) + # Store the original tuple in a slot, making it invisible to json.dumps. + self.queue_tuple = queue_tuple @property def priority(self) -> float: @@ -95,20 +104,17 @@ class NamedQueueTuple: @property def extra_data(self) -> Optional[ExtraData]: - if len(self.queue_tuple) > 2: + if len(self.queue_tuple) > 3: return self.queue_tuple[3] - else: - return None + return None @property def good_outputs(self) -> Optional[List[str]]: - if len(self.queue_tuple) > 3: + if len(self.queue_tuple) > 4: return self.queue_tuple[4] - else: - return None + return None -@dataclass class QueueItem(NamedQueueTuple): """ An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done @@ -118,10 +124,18 @@ class QueueItem(NamedQueueTuple): completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method) or a dictionary of outputs """ - completed: asyncio.Future[TaskInvocation | dict] | None + __slots__ = ('completed',) + + def __init__(self, queue_tuple: QueueTuple, completed: asyncio.Future[TaskInvocation | dict] | None): + # Initialize the parent, which sets up the dictionary representation. + super().__init__(queue_tuple=queue_tuple) + # Store the future in a slot so it won't be serialized. + self.completed = completed def __lt__(self, other: QueueItem): - return self.queue_tuple[0] < other.queue_tuple[0] + if not isinstance(other, QueueItem): + return NotImplemented + return self.priority < other.priority class BinaryEventTypes(Enum): diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index a5e52754c..c7ad179fb 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -18,6 +18,7 @@ from ..client.embedded_comfy_client import Comfy from ..cmd.main_pre import tracer from ..component_model.queue_types import ExecutionStatus +logger = logging.getLogger(__name__) class DistributedPromptWorker: """ @@ -62,12 +63,12 @@ class DistributedPromptWorker: site = web.TCPSite(runner, port=self._health_check_port) await site.start() self._health_check_site = site - logging.info(f"health check server started on port {self._health_check_port}") + logger.info(f"health check server started on port {self._health_check_port}") except OSError as e: if e.errno == 98: - logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway") + logger.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway") else: - logging.error(f"failed to start health check server with error {str(e)}, starting anyway") + logger.error(f"failed to start health check server with error {str(e)}, starting anyway") @tracer.start_as_current_span("Do Work Item") async def _do_work_item(self, request: dict) -> dict: @@ -117,7 +118,7 @@ class DistributedPromptWorker: try: self._connection = await connect_robust(self._connection_uri, loop=self._loop) except AMQPConnectionError as connection_error: - logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error) + logger.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error) raise connection_error self._channel = await self._connection.channel() await self._channel.set_qos(prefetch_count=1) diff --git a/tests/asyncio/test_asyncio_remote_client.py b/tests/asyncio/test_asyncio_remote_client.py index 8c54feaa1..5ba7684dd 100644 --- a/tests/asyncio/test_asyncio_remote_client.py +++ b/tests/asyncio/test_asyncio_remote_client.py @@ -11,29 +11,29 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner @pytest.mark.asyncio async def test_completes_prompt(comfy_background_server): - client = AsyncRemoteComfyClient() - random_seed = random.randint(1, 4294967295) - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) - png_image_bytes = await client.queue_prompt(prompt) + async with AsyncRemoteComfyClient() as client: + random_seed = random.randint(1, 4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) + png_image_bytes = await client.queue_prompt(prompt) assert len(png_image_bytes) > 1000 @pytest.mark.asyncio async def test_completes_prompt_with_ui(comfy_background_server): - client = AsyncRemoteComfyClient() - random_seed = random.randint(1, 4294967295) - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) - result_dict = await client.queue_prompt_ui(prompt) - # should contain one output + async with AsyncRemoteComfyClient() as client: + random_seed = random.randint(1, 4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) + result_dict = await client.queue_prompt_ui(prompt) + # should contain one output assert len(result_dict) == 1 @pytest.mark.asyncio async def test_completes_prompt_with_image_urls(comfy_background_server): - client = AsyncRemoteComfyClient() - random_seed = random.randint(1, 4294967295) - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl") - result = await client.queue_prompt_api(prompt) + async with AsyncRemoteComfyClient() as client: + random_seed = random.randint(1, 4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl") + result = await client.queue_prompt_api(prompt) assert len(result.urls) == 2 for url_str in result.urls: url: URL = parse(url_str) diff --git a/tests/conftest.py b/tests/conftest.py index 3a20e8584..fdcf57430 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,25 @@ +import sys +import time + +import logging import multiprocessing import os import pathlib -import socket -import subprocess -import sys -import time -import urllib -from typing import Tuple, List - import pytest import requests +import socket +import subprocess +import urllib +from testcontainers.rabbitmq import RabbitMqContainer +from typing import Tuple, List from comfy.cli_args_types import Configuration +logging.getLogger("pika").setLevel(logging.CRITICAL + 1) +logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1) +logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING) +logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING) + # fixes issues with running the testcontainers rabbitmqcontainer on Windows os.environ["TC_HOST"] = "localhost" @@ -95,7 +102,6 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers executor_factory = request.param processes_to_close: List[subprocess.Popen] = [] - from testcontainers.rabbitmq import RabbitMqContainer with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 0d3083cdf..927acebf3 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -103,21 +103,20 @@ async def test_distributed_prompt_queues_same_process(): @pytest.mark.asyncio async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq): - client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) - png_image_bytes = await client.queue_prompt(prompt) - len_queue_after = await client.len_queue() + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + png_image_bytes = await client.queue_prompt(prompt) + len_queue_after = await client.len_queue() assert len_queue_after == 0 assert len(png_image_bytes) > 1000, "expected an image, but got nothing" @pytest.mark.asyncio async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq): - client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) - - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors") - with pytest.raises(Exception): - await client.queue_prompt(prompt) + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors") + with pytest.raises(Exception): + await client.queue_prompt(prompt) async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0): @@ -151,48 +150,47 @@ async def test_basic_queue_worker_with_health_check(executor_factory): @pytest.mark.asyncio async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq): # Create the client using the server address from the fixture - client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: - # Create a test prompt - prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1) + # Create a test prompt + prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1) - # Queue the prompt - task_id = await client.queue_and_forget_prompt_api(prompt) + # Queue the prompt + task_id = await client.queue_and_forget_prompt_api(prompt) - assert task_id is not None, "Failed to get a valid task ID" + assert task_id is not None, "Failed to get a valid task ID" - # Poll for the result - max_attempts = 60 # Increase max attempts for integration test - poll_interval = 1 # Increase poll interval for integration test - for _ in range(max_attempts): - try: - response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}") - if response.status == 200: - result = await response.json() - assert result is not None, "Received empty result" + # Poll for the result + max_attempts = 60 # Increase max attempts for integration test + poll_interval = 1 # Increase poll interval for integration test + for _ in range(max_attempts): + try: + response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}") + if response.status == 200: + result = await response.json() + assert result is not None, "Received empty result" - # Find the first output node with images - output_node = next((node for node in result.values() if 'images' in node), None) - assert output_node is not None, "No output node with images found" + # Find the first output node with images + output_node = next((node for node in result.values() if 'images' in node), None) + assert output_node is not None, "No output node with images found" - assert len(output_node['images']) > 0, "No images in output node" - assert 'filename' in output_node['images'][0], "No filename in image output" - assert 'subfolder' in output_node['images'][0], "No subfolder in image output" - assert 'type' in output_node['images'][0], "No type in image output" + assert len(output_node['images']) > 0, "No images in output node" + assert 'filename' in output_node['images'][0], "No filename in image output" + assert 'subfolder' in output_node['images'][0], "No subfolder in image output" + assert 'type' in output_node['images'][0], "No type in image output" - # Check if we can access the image - image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}" - image_response = await client.session.get(image_url) - assert image_response.status == 200, f"Failed to retrieve image from {image_url}" + # Check if we can access the image + image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}" + image_response = await client.session.get(image_url) + assert image_response.status == 200, f"Failed to retrieve image from {image_url}" - return # Test passed - elif response.status == 204: + return # Test passed + elif response.status == 204: + await asyncio.sleep(poll_interval) + else: + response.raise_for_status() + except _: await asyncio.sleep(poll_interval) - else: - response.raise_for_status() - except Exception as e: - print(f"Error while polling: {e}") - await asyncio.sleep(poll_interval) pytest.fail("Failed to get a 200 response with valid data within the timeout period")