fix tests

This commit is contained in:
doctorpangloss 2025-06-17 18:38:42 -07:00
parent 42dd7a59e3
commit 1d29c97266
8 changed files with 123 additions and 106 deletions

View File

@ -40,9 +40,6 @@ class UserManager():
self.settings = AppSettings(self) self.settings = AppSettings(self)
if not os.path.exists(user_directory): if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True) os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
if args.multi_user: if args.multi_user:
if os.path.isfile(self.get_users_file()): if os.path.isfile(self.get_users_file()):

View File

@ -1,14 +1,14 @@
import asyncio
import uuid
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from collections import defaultdict from collections import defaultdict
import aiohttp
import asyncio
import uuid
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from pathlib import Path from pathlib import Path
from typing import Optional, List from typing import Optional, List
from urllib.parse import urlparse, urljoin from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
@ -33,17 +33,25 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}") f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
self._session: aiohttp.ClientSession | None = None self._session: aiohttp.ClientSession | None = None
try:
if asyncio.get_event_loop() is not None:
self._ensure_session()
except RuntimeError as no_running_event_loop:
pass
def _ensure_session(self) -> aiohttp.ClientSession: def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None: if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0)) self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
return self._session return self._session
async def __aenter__(self):
"""Allows the client to be used in an 'async with' block."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Closes the session when exiting an 'async with' block."""
await self.close()
async def close(self):
"""Closes the underlying aiohttp.ClientSession."""
if self._session and not self._session.closed:
await self._session.close()
@property @property
def session(self) -> aiohttp.ClientSession: def session(self) -> aiohttp.ClientSession:
return self._ensure_session() return self._ensure_session()

View File

@ -40,17 +40,6 @@ from ..tracing_compatibility import ProgressSpanSampler
from ..tracing_compatibility import patch_spanbuilder_set_channel from ..tracing_compatibility import patch_spanbuilder_set_channel
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
# Manually call the _init_dll_path method to ensure that the system path is searched for FFMPEG.
# Calling torchaudio._extension.utils._init_dll_path does not work because it is initializing the torchadio module prematurely or something.
# See: https://github.com/pytorch/audio/issues/3789
if sys.platform == "win32":
for path in os.environ.get("PATH", "").split(os.pathsep):
if os.path.exists(path):
try:
os.add_dll_directory(path)
except Exception:
pass
this_logger = logging.getLogger(__name__) this_logger = logging.getLogger(__name__)
options.enable_args_parsing() options.enable_args_parsing()
@ -66,6 +55,10 @@ warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplic
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning) warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support") warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.") warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
log_msg_to_filter = "NOTE: Redirects are currently not supported in Windows or MacOs."
logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter(
lambda record: log_msg_to_filter not in record.getMessage()
)
from ..cli_args import args from ..cli_args import args

View File

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple from typing import Tuple
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict from .outputs_types import OutputsDict
@ -71,15 +69,26 @@ class ExtraData(TypedDict):
token: NotRequired[str] token: NotRequired[str]
@dataclass class NamedQueueTuple(dict):
class NamedQueueTuple:
""" """
A wrapper class for a queue tuple, the object that is given to executors. A wrapper class for a queue tuple, the object that is given to executors.
Attributes: Attributes:
queue_tuple (QueueTuple): the corresponding queued workflow and other related data queue_tuple (QueueTuple): the corresponding queued workflow and other related data
""" """
queue_tuple: QueueTuple __slots__ = ('queue_tuple',)
def __init__(self, queue_tuple: QueueTuple):
# Initialize the dictionary superclass with the data we want to serialize.
super().__init__(
priority=queue_tuple[0],
prompt_id=queue_tuple[1],
prompt=queue_tuple[2],
extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None,
good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None
)
# Store the original tuple in a slot, making it invisible to json.dumps.
self.queue_tuple = queue_tuple
@property @property
def priority(self) -> float: def priority(self) -> float:
@ -95,20 +104,17 @@ class NamedQueueTuple:
@property @property
def extra_data(self) -> Optional[ExtraData]: def extra_data(self) -> Optional[ExtraData]:
if len(self.queue_tuple) > 2: if len(self.queue_tuple) > 3:
return self.queue_tuple[3] return self.queue_tuple[3]
else: return None
return None
@property @property
def good_outputs(self) -> Optional[List[str]]: def good_outputs(self) -> Optional[List[str]]:
if len(self.queue_tuple) > 3: if len(self.queue_tuple) > 4:
return self.queue_tuple[4] return self.queue_tuple[4]
else: return None
return None
@dataclass
class QueueItem(NamedQueueTuple): class QueueItem(NamedQueueTuple):
""" """
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
@ -118,10 +124,18 @@ class QueueItem(NamedQueueTuple):
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method) completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
or a dictionary of outputs or a dictionary of outputs
""" """
completed: asyncio.Future[TaskInvocation | dict] | None __slots__ = ('completed',)
def __init__(self, queue_tuple: QueueTuple, completed: asyncio.Future[TaskInvocation | dict] | None):
# Initialize the parent, which sets up the dictionary representation.
super().__init__(queue_tuple=queue_tuple)
# Store the future in a slot so it won't be serialized.
self.completed = completed
def __lt__(self, other: QueueItem): def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0] if not isinstance(other, QueueItem):
return NotImplemented
return self.priority < other.priority
class BinaryEventTypes(Enum): class BinaryEventTypes(Enum):

View File

@ -18,6 +18,7 @@ from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus from ..component_model.queue_types import ExecutionStatus
logger = logging.getLogger(__name__)
class DistributedPromptWorker: class DistributedPromptWorker:
""" """
@ -62,12 +63,12 @@ class DistributedPromptWorker:
site = web.TCPSite(runner, port=self._health_check_port) site = web.TCPSite(runner, port=self._health_check_port)
await site.start() await site.start()
self._health_check_site = site self._health_check_site = site
logging.info(f"health check server started on port {self._health_check_port}") logger.info(f"health check server started on port {self._health_check_port}")
except OSError as e: except OSError as e:
if e.errno == 98: if e.errno == 98:
logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway") logger.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
else: else:
logging.error(f"failed to start health check server with error {str(e)}, starting anyway") logger.error(f"failed to start health check server with error {str(e)}, starting anyway")
@tracer.start_as_current_span("Do Work Item") @tracer.start_as_current_span("Do Work Item")
async def _do_work_item(self, request: dict) -> dict: async def _do_work_item(self, request: dict) -> dict:
@ -117,7 +118,7 @@ class DistributedPromptWorker:
try: try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop) self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error: except AMQPConnectionError as connection_error:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error) logger.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error raise connection_error
self._channel = await self._connection.channel() self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1) await self._channel.set_qos(prefetch_count=1)

View File

@ -11,29 +11,29 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server): async def test_completes_prompt(comfy_background_server):
client = AsyncRemoteComfyClient() async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt) png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000 assert len(png_image_bytes) > 1000
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server): async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient() async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt) result_dict = await client.queue_prompt_ui(prompt)
# should contain one output # should contain one output
assert len(result_dict) == 1 assert len(result_dict) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server): async def test_completes_prompt_with_image_urls(comfy_background_server):
client = AsyncRemoteComfyClient() async with AsyncRemoteComfyClient() as client:
random_seed = random.randint(1, 4294967295) random_seed = random.randint(1, 4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl") prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
result = await client.queue_prompt_api(prompt) result = await client.queue_prompt_api(prompt)
assert len(result.urls) == 2 assert len(result.urls) == 2
for url_str in result.urls: for url_str in result.urls:
url: URL = parse(url_str) url: URL = parse(url_str)

View File

@ -1,18 +1,25 @@
import sys
import time
import logging
import multiprocessing import multiprocessing
import os import os
import pathlib import pathlib
import socket
import subprocess
import sys
import time
import urllib
from typing import Tuple, List
import pytest import pytest
import requests import requests
import socket
import subprocess
import urllib
from testcontainers.rabbitmq import RabbitMqContainer
from typing import Tuple, List
from comfy.cli_args_types import Configuration from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
# fixes issues with running the testcontainers rabbitmqcontainer on Windows # fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost" os.environ["TC_HOST"] = "localhost"
@ -95,7 +102,6 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
executor_factory = request.param executor_factory = request.param
processes_to_close: List[subprocess.Popen] = [] processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

View File

@ -103,21 +103,20 @@ async def test_distributed_prompt_queues_same_process():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq): async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt) png_image_bytes = await client.queue_prompt(prompt)
len_queue_after = await client.len_queue() len_queue_after = await client.len_queue()
assert len_queue_after == 0 assert len_queue_after == 0
assert len(png_image_bytes) > 1000, "expected an image, but got nothing" assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq): async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors") with pytest.raises(Exception):
with pytest.raises(Exception): await client.queue_prompt(prompt)
await client.queue_prompt(prompt)
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0): async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
@ -151,48 +150,47 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq): async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
# Create the client using the server address from the fixture # Create the client using the server address from the fixture
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a test prompt # Create a test prompt
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1) prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
# Queue the prompt # Queue the prompt
task_id = await client.queue_and_forget_prompt_api(prompt) task_id = await client.queue_and_forget_prompt_api(prompt)
assert task_id is not None, "Failed to get a valid task ID" assert task_id is not None, "Failed to get a valid task ID"
# Poll for the result # Poll for the result
max_attempts = 60 # Increase max attempts for integration test max_attempts = 60 # Increase max attempts for integration test
poll_interval = 1 # Increase poll interval for integration test poll_interval = 1 # Increase poll interval for integration test
for _ in range(max_attempts): for _ in range(max_attempts):
try: try:
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}") response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
if response.status == 200: if response.status == 200:
result = await response.json() result = await response.json()
assert result is not None, "Received empty result" assert result is not None, "Received empty result"
# Find the first output node with images # Find the first output node with images
output_node = next((node for node in result.values() if 'images' in node), None) output_node = next((node for node in result.values() if 'images' in node), None)
assert output_node is not None, "No output node with images found" assert output_node is not None, "No output node with images found"
assert len(output_node['images']) > 0, "No images in output node" assert len(output_node['images']) > 0, "No images in output node"
assert 'filename' in output_node['images'][0], "No filename in image output" assert 'filename' in output_node['images'][0], "No filename in image output"
assert 'subfolder' in output_node['images'][0], "No subfolder in image output" assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
assert 'type' in output_node['images'][0], "No type in image output" assert 'type' in output_node['images'][0], "No type in image output"
# Check if we can access the image # Check if we can access the image
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}" image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
image_response = await client.session.get(image_url) image_response = await client.session.get(image_url)
assert image_response.status == 200, f"Failed to retrieve image from {image_url}" assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
return # Test passed return # Test passed
elif response.status == 204: elif response.status == 204:
await asyncio.sleep(poll_interval)
else:
response.raise_for_status()
except _:
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
else:
response.raise_for_status()
except Exception as e:
print(f"Error while polling: {e}")
await asyncio.sleep(poll_interval)
pytest.fail("Failed to get a 200 response with valid data within the timeout period") pytest.fail("Failed to get a 200 response with valid data within the timeout period")