mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00

fix tests

This commit is contained in:
parent 42dd7a59e3
commit 1d29c97266

@@ -40,9 +40,6 @@ class UserManager():
        self.settings = AppSettings(self)
        if not os.path.exists(user_directory):
            os.makedirs(user_directory, exist_ok=True)
            if not args.multi_user:
                logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(self.get_users_file()):

@@ -1,14 +1,14 @@
import asyncio
import uuid
from asyncio import AbstractEventLoop
from collections import defaultdict

import aiohttp
import asyncio
import uuid
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from pathlib import Path
from typing import Optional, List
from urllib.parse import urlparse, urljoin

import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout

from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict

@@ -33,17 +33,25 @@ class AsyncRemoteComfyClient:
            f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
        self.loop = loop or asyncio.get_event_loop()
        self._session: aiohttp.ClientSession | None = None
        try:
            if asyncio.get_event_loop() is not None:
                self._ensure_session()
        except RuntimeError as no_running_event_loop:
            pass

    def _ensure_session(self) -> aiohttp.ClientSession:
        if self._session is None:
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
        return self._session

    async def __aenter__(self):
        """Allows the client to be used in an 'async with' block."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Closes the session when exiting an 'async with' block."""
        await self.close()

    async def close(self):
        """Closes the underlying aiohttp.ClientSession."""
        if self._session and not self._session.closed:
            await self._session.close()

    @property
    def session(self) -> aiohttp.ClientSession:
        return self._ensure_session()

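Note: the hunk above makes the client's aiohttp session lazy, re-creatable after close, and usable in an 'async with' block. A minimal sketch of the same lifecycle pattern under those assumptions; the names LazySessionClient and fetch_status are illustrative, not from this repository:

import asyncio

import aiohttp
from aiohttp import ClientTimeout


class LazySessionClient:
    """Illustrative client whose session is created on first use and
    re-created transparently if a previous one was closed."""

    def __init__(self) -> None:
        self._session: aiohttp.ClientSession | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        # Re-create the session if it was never made or has been closed,
        # e.g. by an earlier 'async with' block on the same instance.
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
        return self._session

    async def __aenter__(self) -> "LazySessionClient":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    async def close(self) -> None:
        if self._session is not None and not self._session.closed:
            await self._session.close()

    async def fetch_status(self, url: str) -> int:
        # Any method can rely on _ensure_session to get a live session.
        async with self._ensure_session().get(url) as response:
            return response.status


async def main() -> None:
    async with LazySessionClient() as client:
        print(await client.fetch_status("https://example.com"))

asyncio.run(main())
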
@@ -40,17 +40,6 @@ from ..tracing_compatibility import ProgressSpanSampler
from ..tracing_compatibility import patch_spanbuilder_set_channel
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor

# Manually call the _init_dll_path method to ensure that the system path is searched for FFMPEG.
# Calling torchaudio._extension.utils._init_dll_path does not work, because it initializes the torchaudio module prematurely.
# See: https://github.com/pytorch/audio/issues/3789
if sys.platform == "win32":
    for path in os.environ.get("PATH", "").split(os.pathsep):
        if os.path.exists(path):
            try:
                os.add_dll_directory(path)
            except Exception:
                pass

this_logger = logging.getLogger(__name__)

options.enable_args_parsing()

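For reference, the workaround removed above can be packaged as a standalone helper; this is a sketch under the same assumptions, and the helper name register_path_dll_directories is ours, not the repository's:

import os
import sys


def register_path_dll_directories() -> None:
    """On Windows, register every existing PATH entry as a DLL search
    directory so native dependencies such as FFMPEG can resolve; mirrors
    the workaround for https://github.com/pytorch/audio/issues/3789."""
    if sys.platform != "win32":
        return
    for path in os.environ.get("PATH", "").split(os.pathsep):
        if os.path.exists(path):
            try:
                os.add_dll_directory(path)
            except Exception:
                # Relative or otherwise unusable entries are skipped.
                pass
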
@@ -66,6 +55,10 @@ warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplic
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
log_msg_to_filter = "NOTE: Redirects are currently not supported in Windows or MacOs."
logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter(
    lambda record: log_msg_to_filter not in record.getMessage()
)

from ..cli_args import args

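The addFilter call above drops one specific record from one logger while leaving everything else intact. A self-contained sketch of the technique; the logger name and messages here are made up for illustration:

import logging

noisy_logger = logging.getLogger("example.noisy")  # hypothetical logger
msg_to_filter = "this exact message should be hidden"  # hypothetical message

# A filter callable that returns False drops the record; anything else passes.
noisy_logger.addFilter(lambda record: msg_to_filter not in record.getMessage())

logging.basicConfig(level=logging.INFO)
noisy_logger.info("this exact message should be hidden")  # suppressed
noisy_logger.info("other messages still appear")  # logged
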
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple

from typing_extensions import NotRequired, TypedDict

from .outputs_types import OutputsDict

@@ -71,15 +69,26 @@ class ExtraData(TypedDict):
    token: NotRequired[str]


@dataclass
class NamedQueueTuple:
class NamedQueueTuple(dict):
    """
    A wrapper class for a queue tuple, the object that is given to executors.

    Attributes:
        queue_tuple (QueueTuple): the corresponding queued workflow and other related data
    """
    queue_tuple: QueueTuple
    __slots__ = ('queue_tuple',)

    def __init__(self, queue_tuple: QueueTuple):
        # Initialize the dictionary superclass with the data we want to serialize.
        super().__init__(
            priority=queue_tuple[0],
            prompt_id=queue_tuple[1],
            prompt=queue_tuple[2],
            extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None,
            good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None
        )
        # Store the original tuple in a slot, making it invisible to json.dumps.
        self.queue_tuple = queue_tuple

    @property
    def priority(self) -> float:

@@ -95,20 +104,17 @@ class NamedQueueTuple:

    @property
    def extra_data(self) -> Optional[ExtraData]:
        if len(self.queue_tuple) > 2:
        if len(self.queue_tuple) > 3:
            return self.queue_tuple[3]
        else:
            return None
        return None

    @property
    def good_outputs(self) -> Optional[List[str]]:
        if len(self.queue_tuple) > 3:
        if len(self.queue_tuple) > 4:
            return self.queue_tuple[4]
        else:
            return None
        return None


@dataclass
class QueueItem(NamedQueueTuple):
    """
    An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
@@ -118,10 +124,18 @@ class QueueItem(NamedQueueTuple):
        completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
            or a dictionary of outputs
    """
    completed: asyncio.Future[TaskInvocation | dict] | None
    __slots__ = ('completed',)

    def __init__(self, queue_tuple: QueueTuple, completed: asyncio.Future[TaskInvocation | dict] | None):
        # Initialize the parent, which sets up the dictionary representation.
        super().__init__(queue_tuple=queue_tuple)
        # Store the future in a slot so it won't be serialized.
        self.completed = completed

    def __lt__(self, other: QueueItem):
        return self.queue_tuple[0] < other.queue_tuple[0]
        if not isinstance(other, QueueItem):
            return NotImplemented
        return self.priority < other.priority


class BinaryEventTypes(Enum):

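The comments in these hunks describe the serialization trick the commit relies on: subclass dict so json.dumps sees only the mapping, and park non-serializable attributes in __slots__, where the encoder cannot reach them; __lt__ returns NotImplemented for foreign types so ordered containers stay well-behaved. A minimal sketch of both ideas; SerializableTask is an illustrative name, not a class from this repository:

import json


class SerializableTask(dict):
    # Slotted attributes live outside the dict payload, so they are
    # invisible to json.dumps; only the dict contents get serialized.
    __slots__ = ('raw',)

    def __init__(self, priority: float, prompt_id: str, raw: tuple):
        super().__init__(priority=priority, prompt_id=prompt_id)
        self.raw = raw  # not serialized

    def __lt__(self, other):
        # NotImplemented lets Python fall back to the reflected comparison
        # instead of failing outright on unrelated types.
        if not isinstance(other, SerializableTask):
            return NotImplemented
        return self['priority'] < other['priority']


a = SerializableTask(0.0, "a", raw=(0.0, "a"))
b = SerializableTask(1.0, "b", raw=(1.0, "b"))
print(json.dumps(a))  # {"priority": 0.0, "prompt_id": "a"}; 'raw' is absent
print(sorted([b, a])[0]['prompt_id'])  # 'a', ordered via __lt__
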
@@ -18,6 +18,7 @@ from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus

logger = logging.getLogger(__name__)

class DistributedPromptWorker:
    """
@@ -62,12 +63,12 @@ class DistributedPromptWorker:
            site = web.TCPSite(runner, port=self._health_check_port)
            await site.start()
            self._health_check_site = site
            logging.info(f"health check server started on port {self._health_check_port}")
            logger.info(f"health check server started on port {self._health_check_port}")
        except OSError as e:
            if e.errno == 98:
                logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
                logger.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
            else:
                logging.error(f"failed to start health check server with error {str(e)}, starting anyway")
                logger.error(f"failed to start health check server with error {str(e)}, starting anyway")

    @tracer.start_as_current_span("Do Work Item")
    async def _do_work_item(self, request: dict) -> dict:

@@ -117,7 +118,7 @@ class DistributedPromptWorker:
        try:
            self._connection = await connect_robust(self._connection_uri, loop=self._loop)
        except AMQPConnectionError as connection_error:
            logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
            logger.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
            raise connection_error
        self._channel = await self._connection.channel()
        await self._channel.set_qos(prefetch_count=1)

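These hunks replace calls on the root logging module with the module-level logger declared above. A short sketch of the difference; the logger name is illustrative:

import logging

logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")

# A named logger attributes records to its module and can be filtered or
# re-leveled per module; bare logging.info() goes through the root logger.
logger = logging.getLogger("comfy.distributed.worker")  # usually __name__

logger.info("attributed to this module")  # comfy.distributed.worker: ...
logging.info("attributed to root")        # root: ...
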
@@ -11,29 +11,29 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner

@pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server):
    client = AsyncRemoteComfyClient()
    random_seed = random.randint(1, 4294967295)
    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
    png_image_bytes = await client.queue_prompt(prompt)
    async with AsyncRemoteComfyClient() as client:
        random_seed = random.randint(1, 4294967295)
        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
        png_image_bytes = await client.queue_prompt(prompt)
    assert len(png_image_bytes) > 1000


@pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server):
    client = AsyncRemoteComfyClient()
    random_seed = random.randint(1, 4294967295)
    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
    result_dict = await client.queue_prompt_ui(prompt)
    # should contain one output
    async with AsyncRemoteComfyClient() as client:
        random_seed = random.randint(1, 4294967295)
        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
        result_dict = await client.queue_prompt_ui(prompt)
        # should contain one output
    assert len(result_dict) == 1


@pytest.mark.asyncio
async def test_completes_prompt_with_image_urls(comfy_background_server):
    client = AsyncRemoteComfyClient()
    random_seed = random.randint(1, 4294967295)
    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
    result = await client.queue_prompt_api(prompt)
    async with AsyncRemoteComfyClient() as client:
        random_seed = random.randint(1, 4294967295)
        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
        result = await client.queue_prompt_api(prompt)
    assert len(result.urls) == 2
    for url_str in result.urls:
        url: URL = parse(url_str)

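The tests above now construct the client with async with, so the session is closed even when an assertion fails mid-test. A tiny sketch of why that holds; Resource is a stand-in, not a repository class:

import asyncio


class Resource:
    async def __aenter__(self):
        print("opened")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("closed")  # runs even when the body raises


async def demo() -> None:
    try:
        async with Resource():
            assert False, "simulated test failure"
    except AssertionError:
        pass

asyncio.run(demo())  # prints "opened" then "closed"
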
@@ -1,18 +1,25 @@
import sys
import time

import logging
import multiprocessing
import os
import pathlib
import socket
import subprocess
import sys
import time
import urllib
from typing import Tuple, List

import pytest
import requests
import socket
import subprocess
import urllib
from testcontainers.rabbitmq import RabbitMqContainer
from typing import Tuple, List

from comfy.cli_args_types import Configuration

logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)

# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"

@@ -95,7 +102,6 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
    executor_factory = request.param
    processes_to_close: List[subprocess.Popen] = []

    from testcontainers.rabbitmq import RabbitMqContainer
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

@@ -103,21 +103,20 @@ async def test_distributed_prompt_queues_same_process():

@pytest.mark.asyncio
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
    png_image_bytes = await client.queue_prompt(prompt)
    len_queue_after = await client.len_queue()
    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
        png_image_bytes = await client.queue_prompt(prompt)
        len_queue_after = await client.len_queue()
    assert len_queue_after == 0
    assert len(png_image_bytes) > 1000, "expected an image, but got nothing"


@pytest.mark.asyncio
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)

    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
    with pytest.raises(Exception):
        await client.queue_prompt(prompt)
    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
        with pytest.raises(Exception):
            await client.queue_prompt(prompt)


async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):

@@ -151,48 +150,47 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
@pytest.mark.asyncio
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
    # Create the client using the server address from the fixture
    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:

    # Create a test prompt
    prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
        # Create a test prompt
        prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)

    # Queue the prompt
    task_id = await client.queue_and_forget_prompt_api(prompt)
        # Queue the prompt
        task_id = await client.queue_and_forget_prompt_api(prompt)

    assert task_id is not None, "Failed to get a valid task ID"
        assert task_id is not None, "Failed to get a valid task ID"

    # Poll for the result
    max_attempts = 60  # Increase max attempts for integration test
    poll_interval = 1  # Increase poll interval for integration test
    for _ in range(max_attempts):
        try:
            response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
            if response.status == 200:
                result = await response.json()
                assert result is not None, "Received empty result"
        # Poll for the result
        max_attempts = 60  # Increase max attempts for integration test
        poll_interval = 1  # Increase poll interval for integration test
        for _ in range(max_attempts):
            try:
                response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
                if response.status == 200:
                    result = await response.json()
                    assert result is not None, "Received empty result"

                # Find the first output node with images
                output_node = next((node for node in result.values() if 'images' in node), None)
                assert output_node is not None, "No output node with images found"
                    # Find the first output node with images
                    output_node = next((node for node in result.values() if 'images' in node), None)
                    assert output_node is not None, "No output node with images found"

                assert len(output_node['images']) > 0, "No images in output node"
                assert 'filename' in output_node['images'][0], "No filename in image output"
                assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
                assert 'type' in output_node['images'][0], "No type in image output"
                    assert len(output_node['images']) > 0, "No images in output node"
                    assert 'filename' in output_node['images'][0], "No filename in image output"
                    assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
                    assert 'type' in output_node['images'][0], "No type in image output"

                # Check if we can access the image
                image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
                image_response = await client.session.get(image_url)
                assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
                    # Check if we can access the image
                    image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
                    image_response = await client.session.get(image_url)
                    assert image_response.status == 200, f"Failed to retrieve image from {image_url}"

                return  # Test passed
            elif response.status == 204:
                    return  # Test passed
                elif response.status == 204:
                    await asyncio.sleep(poll_interval)
            else:
                response.raise_for_status()
        except _:
            await asyncio.sleep(poll_interval)
                else:
                    response.raise_for_status()
            except Exception as e:
                print(f"Error while polling: {e}")
                await asyncio.sleep(poll_interval)

    pytest.fail("Failed to get a 200 response with valid data within the timeout period")

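The rewritten polling loop above retries on any exception in place of the invalid except _: clause it removes. Condensed into a reusable sketch; the function is illustrative and assumes a plain aiohttp session rather than the repository's client:

import asyncio

import aiohttp


async def poll_until_ready(url: str, max_attempts: int = 60, poll_interval: float = 1.0) -> dict:
    # Poll until the endpoint returns 200 with a JSON body; 204 means
    # "not ready yet"; other statuses raise via raise_for_status and are
    # retried just like transient connection errors.
    async with aiohttp.ClientSession() as session:
        for _ in range(max_attempts):
            try:
                async with session.get(url) as response:
                    if response.status == 200:
                        return await response.json()
                    elif response.status == 204:
                        await asyncio.sleep(poll_interval)
                    else:
                        response.raise_for_status()
            except Exception as e:
                print(f"Error while polling: {e}")
                await asyncio.sleep(poll_interval)
    raise TimeoutError(f"{url} did not return a result within {max_attempts * poll_interval} seconds")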