fix tests

This commit is contained in:
doctorpangloss 2025-06-17 18:38:42 -07:00
parent 42dd7a59e3
commit 1d29c97266
8 changed files with 123 additions and 106 deletions

View File

@ -40,9 +40,6 @@ class UserManager():
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user:
if os.path.isfile(self.get_users_file()):

View File

@ -1,14 +1,14 @@
import asyncio
import uuid
from asyncio import AbstractEventLoop
from collections import defaultdict
import aiohttp
import asyncio
import uuid
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from pathlib import Path
from typing import Optional, List
from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict
@ -33,17 +33,25 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop()
self._session: aiohttp.ClientSession | None = None
try:
if asyncio.get_event_loop() is not None:
self._ensure_session()
except RuntimeError as no_running_event_loop:
pass
def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
return self._session
async def __aenter__(self):
"""Allows the client to be used in an 'async with' block."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Closes the session when exiting an 'async with' block."""
await self.close()
async def close(self):
"""Closes the underlying aiohttp.ClientSession."""
if self._session and not self._session.closed:
await self._session.close()
@property
def session(self) -> aiohttp.ClientSession:
return self._ensure_session()

View File

@ -40,17 +40,6 @@ from ..tracing_compatibility import ProgressSpanSampler
from ..tracing_compatibility import patch_spanbuilder_set_channel
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
# Manually call the _init_dll_path method to ensure that the system path is searched for FFMPEG.
# Calling torchaudio._extension.utils._init_dll_path does not work because it is initializing the torchadio module prematurely or something.
# See: https://github.com/pytorch/audio/issues/3789
if sys.platform == "win32":
for path in os.environ.get("PATH", "").split(os.pathsep):
if os.path.exists(path):
try:
os.add_dll_directory(path)
except Exception:
pass
this_logger = logging.getLogger(__name__)
options.enable_args_parsing()
@ -66,6 +55,10 @@ warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplic
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
log_msg_to_filter = "NOTE: Redirects are currently not supported in Windows or MacOs."
logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter(
lambda record: log_msg_to_filter not in record.getMessage()
)
from ..cli_args import args

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple
from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
@ -71,15 +69,26 @@ class ExtraData(TypedDict):
token: NotRequired[str]
@dataclass
class NamedQueueTuple:
class NamedQueueTuple(dict):
"""
A wrapper class for a queue tuple, the object that is given to executors.
Attributes:
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
"""
queue_tuple: QueueTuple
__slots__ = ('queue_tuple',)
def __init__(self, queue_tuple: QueueTuple):
# Initialize the dictionary superclass with the data we want to serialize.
super().__init__(
priority=queue_tuple[0],
prompt_id=queue_tuple[1],
prompt=queue_tuple[2],
extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None,
good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None
)
# Store the original tuple in a slot, making it invisible to json.dumps.
self.queue_tuple = queue_tuple
@property
def priority(self) -> float:
@ -95,20 +104,17 @@ class NamedQueueTuple:
@property
def extra_data(self) -> Optional[ExtraData]:
if len(self.queue_tuple) > 2:
if len(self.queue_tuple) > 3:
return self.queue_tuple[3]
else:
return None
return None
@property
def good_outputs(self) -> Optional[List[str]]:
if len(self.queue_tuple) > 3:
if len(self.queue_tuple) > 4:
return self.queue_tuple[4]
else:
return None
return None
@dataclass
class QueueItem(NamedQueueTuple):
"""
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
@ -118,10 +124,18 @@ class QueueItem(NamedQueueTuple):
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
or a dictionary of outputs
"""
completed: asyncio.Future[TaskInvocation | dict] | None
__slots__ = ('completed',)
def __init__(self, queue_tuple: QueueTuple, completed: asyncio.Future[TaskInvocation | dict] | None):
# Initialize the parent, which sets up the dictionary representation.
super().__init__(queue_tuple=queue_tuple)
# Store the future in a slot so it won't be serialized.
self.completed = completed
def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0]
if not isinstance(other, QueueItem):
return NotImplemented
return self.priority < other.priority
class BinaryEventTypes(Enum):

View File

@ -18,6 +18,7 @@ from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus
logger = logging.getLogger(__name__)
class DistributedPromptWorker:
"""
@ -62,12 +63,12 @@ class DistributedPromptWorker:
site = web.TCPSite(runner, port=self._health_check_port)
await site.start()
self._health_check_site = site
logging.info(f"health check server started on port {self._health_check_port}")
logger.info(f"health check server started on port {self._health_check_port}")
except OSError as e:
if e.errno == 98:
logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
logger.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
else:
logging.error(f"failed to start health check server with error {str(e)}, starting anyway")
logger.error(f"failed to start health check server with error {str(e)}, starting anyway")
@tracer.start_as_current_span("Do Work Item")
async def _do_work_item(self, request: dict) -> dict:
@ -117,7 +118,7 @@ class DistributedPromptWorker:
try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
logger.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)

View File

@ -11,29 +11,29 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
@pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt)
# should contain one output
async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt)
# should contain one output
assert len(result_dict) == 1
@pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server):
client = AsyncRemoteComfyClient()
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
result = await client.queue_prompt_api(prompt)
async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
result = await client.queue_prompt_api(prompt)
assert len(result.urls) == 2
for url_str in result.urls:
url: URL = parse(url_str)

View File

@ -1,18 +1,25 @@
import sys
import time
import logging
import multiprocessing
import os
import pathlib
import socket
import subprocess
import sys
import time
import urllib
from typing import Tuple, List
import pytest
import requests
import socket
import subprocess
import urllib
from testcontainers.rabbitmq import RabbitMqContainer
from typing import Tuple, List
from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
@ -95,7 +102,6 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
executor_factory = request.param
processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

View File

@ -103,21 +103,20 @@ async def test_distributed_prompt_queues_same_process():
@pytest.mark.asyncio
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
len_queue_after = await client.len_queue()
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
len_queue_after = await client.len_queue()
assert len_queue_after == 0
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
@pytest.mark.asyncio
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
with pytest.raises(Exception):
await client.queue_prompt(prompt)
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
with pytest.raises(Exception):
await client.queue_prompt(prompt)
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
@ -151,48 +150,47 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
@pytest.mark.asyncio
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
# Create the client using the server address from the fixture
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a test prompt
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
# Create a test prompt
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
# Queue the prompt
task_id = await client.queue_and_forget_prompt_api(prompt)
# Queue the prompt
task_id = await client.queue_and_forget_prompt_api(prompt)
assert task_id is not None, "Failed to get a valid task ID"
assert task_id is not None, "Failed to get a valid task ID"
# Poll for the result
max_attempts = 60 # Increase max attempts for integration test
poll_interval = 1 # Increase poll interval for integration test
for _ in range(max_attempts):
try:
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
if response.status == 200:
result = await response.json()
assert result is not None, "Received empty result"
# Poll for the result
max_attempts = 60 # Increase max attempts for integration test
poll_interval = 1 # Increase poll interval for integration test
for _ in range(max_attempts):
try:
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
if response.status == 200:
result = await response.json()
assert result is not None, "Received empty result"
# Find the first output node with images
output_node = next((node for node in result.values() if 'images' in node), None)
assert output_node is not None, "No output node with images found"
# Find the first output node with images
output_node = next((node for node in result.values() if 'images' in node), None)
assert output_node is not None, "No output node with images found"
assert len(output_node['images']) > 0, "No images in output node"
assert 'filename' in output_node['images'][0], "No filename in image output"
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
assert 'type' in output_node['images'][0], "No type in image output"
assert len(output_node['images']) > 0, "No images in output node"
assert 'filename' in output_node['images'][0], "No filename in image output"
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
assert 'type' in output_node['images'][0], "No type in image output"
# Check if we can access the image
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
image_response = await client.session.get(image_url)
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
# Check if we can access the image
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
image_response = await client.session.get(image_url)
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
return # Test passed
elif response.status == 204:
return # Test passed
elif response.status == 204:
await asyncio.sleep(poll_interval)
else:
response.raise_for_status()
except _:
await asyncio.sleep(poll_interval)
else:
response.raise_for_status()
except Exception as e:
print(f"Error while polling: {e}")
await asyncio.sleep(poll_interval)
pytest.fail("Failed to get a 200 response with valid data within the timeout period")