mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 23:30:16 +08:00
fix tests
This commit is contained in:
parent
42dd7a59e3
commit
1d29c97266
@ -40,9 +40,6 @@ class UserManager():
|
|||||||
self.settings = AppSettings(self)
|
self.settings = AppSettings(self)
|
||||||
if not os.path.exists(user_directory):
|
if not os.path.exists(user_directory):
|
||||||
os.makedirs(user_directory, exist_ok=True)
|
os.makedirs(user_directory, exist_ok=True)
|
||||||
if not args.multi_user:
|
|
||||||
logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
|
|
||||||
logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
|
|
||||||
|
|
||||||
if args.multi_user:
|
if args.multi_user:
|
||||||
if os.path.isfile(self.get_users_file()):
|
if os.path.isfile(self.get_users_file()):
|
||||||
|
|||||||
@ -1,14 +1,14 @@
|
|||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from aiohttp import WSMessage, ClientResponse, ClientTimeout
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from urllib.parse import urlparse, urljoin
|
from urllib.parse import urlparse, urljoin
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp import WSMessage, ClientResponse, ClientTimeout
|
|
||||||
|
|
||||||
from .client_types import V1QueuePromptResponse
|
from .client_types import V1QueuePromptResponse
|
||||||
from ..api.api_client import JSONEncoder
|
from ..api.api_client import JSONEncoder
|
||||||
from ..api.components.schema.prompt import PromptDict
|
from ..api.components.schema.prompt import PromptDict
|
||||||
@ -33,17 +33,25 @@ class AsyncRemoteComfyClient:
|
|||||||
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
|
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
self._session: aiohttp.ClientSession | None = None
|
self._session: aiohttp.ClientSession | None = None
|
||||||
try:
|
|
||||||
if asyncio.get_event_loop() is not None:
|
|
||||||
self._ensure_session()
|
|
||||||
except RuntimeError as no_running_event_loop:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _ensure_session(self) -> aiohttp.ClientSession:
|
def _ensure_session(self) -> aiohttp.ClientSession:
|
||||||
if self._session is None:
|
if self._session is None or self._session.closed:
|
||||||
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
|
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Allows the client to be used in an 'async with' block."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Closes the session when exiting an 'async with' block."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Closes the underlying aiohttp.ClientSession."""
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self) -> aiohttp.ClientSession:
|
def session(self) -> aiohttp.ClientSession:
|
||||||
return self._ensure_session()
|
return self._ensure_session()
|
||||||
|
|||||||
@ -40,17 +40,6 @@ from ..tracing_compatibility import ProgressSpanSampler
|
|||||||
from ..tracing_compatibility import patch_spanbuilder_set_channel
|
from ..tracing_compatibility import patch_spanbuilder_set_channel
|
||||||
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
|
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
|
||||||
|
|
||||||
# Manually call the _init_dll_path method to ensure that the system path is searched for FFMPEG.
|
|
||||||
# Calling torchaudio._extension.utils._init_dll_path does not work because it is initializing the torchadio module prematurely or something.
|
|
||||||
# See: https://github.com/pytorch/audio/issues/3789
|
|
||||||
if sys.platform == "win32":
|
|
||||||
for path in os.environ.get("PATH", "").split(os.pathsep):
|
|
||||||
if os.path.exists(path):
|
|
||||||
try:
|
|
||||||
os.add_dll_directory(path)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
this_logger = logging.getLogger(__name__)
|
this_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
options.enable_args_parsing()
|
options.enable_args_parsing()
|
||||||
@ -66,6 +55,10 @@ warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplic
|
|||||||
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
|
||||||
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
|
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
|
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
|
||||||
|
log_msg_to_filter = "NOTE: Redirects are currently not supported in Windows or MacOs."
|
||||||
|
logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").addFilter(
|
||||||
|
lambda record: log_msg_to_filter not in record.getMessage()
|
||||||
|
)
|
||||||
|
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import NamedTuple, Optional, List, Literal, Sequence
|
from typing import NamedTuple, Optional, List, Literal, Sequence
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from .outputs_types import OutputsDict
|
from .outputs_types import OutputsDict
|
||||||
@ -71,15 +69,26 @@ class ExtraData(TypedDict):
|
|||||||
token: NotRequired[str]
|
token: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class NamedQueueTuple(dict):
|
||||||
class NamedQueueTuple:
|
|
||||||
"""
|
"""
|
||||||
A wrapper class for a queue tuple, the object that is given to executors.
|
A wrapper class for a queue tuple, the object that is given to executors.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
|
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
|
||||||
"""
|
"""
|
||||||
queue_tuple: QueueTuple
|
__slots__ = ('queue_tuple',)
|
||||||
|
|
||||||
|
def __init__(self, queue_tuple: QueueTuple):
|
||||||
|
# Initialize the dictionary superclass with the data we want to serialize.
|
||||||
|
super().__init__(
|
||||||
|
priority=queue_tuple[0],
|
||||||
|
prompt_id=queue_tuple[1],
|
||||||
|
prompt=queue_tuple[2],
|
||||||
|
extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None,
|
||||||
|
good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None
|
||||||
|
)
|
||||||
|
# Store the original tuple in a slot, making it invisible to json.dumps.
|
||||||
|
self.queue_tuple = queue_tuple
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self) -> float:
|
def priority(self) -> float:
|
||||||
@ -95,20 +104,17 @@ class NamedQueueTuple:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_data(self) -> Optional[ExtraData]:
|
def extra_data(self) -> Optional[ExtraData]:
|
||||||
if len(self.queue_tuple) > 2:
|
if len(self.queue_tuple) > 3:
|
||||||
return self.queue_tuple[3]
|
return self.queue_tuple[3]
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def good_outputs(self) -> Optional[List[str]]:
|
def good_outputs(self) -> Optional[List[str]]:
|
||||||
if len(self.queue_tuple) > 3:
|
if len(self.queue_tuple) > 4:
|
||||||
return self.queue_tuple[4]
|
return self.queue_tuple[4]
|
||||||
else:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class QueueItem(NamedQueueTuple):
|
class QueueItem(NamedQueueTuple):
|
||||||
"""
|
"""
|
||||||
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
|
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
|
||||||
@ -118,10 +124,18 @@ class QueueItem(NamedQueueTuple):
|
|||||||
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
||||||
or a dictionary of outputs
|
or a dictionary of outputs
|
||||||
"""
|
"""
|
||||||
completed: asyncio.Future[TaskInvocation | dict] | None
|
__slots__ = ('completed',)
|
||||||
|
|
||||||
|
def __init__(self, queue_tuple: QueueTuple, completed: asyncio.Future[TaskInvocation | dict] | None):
|
||||||
|
# Initialize the parent, which sets up the dictionary representation.
|
||||||
|
super().__init__(queue_tuple=queue_tuple)
|
||||||
|
# Store the future in a slot so it won't be serialized.
|
||||||
|
self.completed = completed
|
||||||
|
|
||||||
def __lt__(self, other: QueueItem):
|
def __lt__(self, other: QueueItem):
|
||||||
return self.queue_tuple[0] < other.queue_tuple[0]
|
if not isinstance(other, QueueItem):
|
||||||
|
return NotImplemented
|
||||||
|
return self.priority < other.priority
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes(Enum):
|
class BinaryEventTypes(Enum):
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from ..client.embedded_comfy_client import Comfy
|
|||||||
from ..cmd.main_pre import tracer
|
from ..cmd.main_pre import tracer
|
||||||
from ..component_model.queue_types import ExecutionStatus
|
from ..component_model.queue_types import ExecutionStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DistributedPromptWorker:
|
class DistributedPromptWorker:
|
||||||
"""
|
"""
|
||||||
@ -62,12 +63,12 @@ class DistributedPromptWorker:
|
|||||||
site = web.TCPSite(runner, port=self._health_check_port)
|
site = web.TCPSite(runner, port=self._health_check_port)
|
||||||
await site.start()
|
await site.start()
|
||||||
self._health_check_site = site
|
self._health_check_site = site
|
||||||
logging.info(f"health check server started on port {self._health_check_port}")
|
logger.info(f"health check server started on port {self._health_check_port}")
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
if e.errno == 98:
|
if e.errno == 98:
|
||||||
logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
|
logger.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
|
||||||
else:
|
else:
|
||||||
logging.error(f"failed to start health check server with error {str(e)}, starting anyway")
|
logger.error(f"failed to start health check server with error {str(e)}, starting anyway")
|
||||||
|
|
||||||
@tracer.start_as_current_span("Do Work Item")
|
@tracer.start_as_current_span("Do Work Item")
|
||||||
async def _do_work_item(self, request: dict) -> dict:
|
async def _do_work_item(self, request: dict) -> dict:
|
||||||
@ -117,7 +118,7 @@ class DistributedPromptWorker:
|
|||||||
try:
|
try:
|
||||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||||
except AMQPConnectionError as connection_error:
|
except AMQPConnectionError as connection_error:
|
||||||
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
|
logger.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
|
||||||
raise connection_error
|
raise connection_error
|
||||||
self._channel = await self._connection.channel()
|
self._channel = await self._connection.channel()
|
||||||
await self._channel.set_qos(prefetch_count=1)
|
await self._channel.set_qos(prefetch_count=1)
|
||||||
|
|||||||
@ -11,29 +11,29 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completes_prompt(comfy_background_server):
|
async def test_completes_prompt(comfy_background_server):
|
||||||
client = AsyncRemoteComfyClient()
|
async with AsyncRemoteComfyClient() as client:
|
||||||
random_seed = random.randint(1, 4294967295)
|
random_seed = random.randint(1, 4294967295)
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||||
png_image_bytes = await client.queue_prompt(prompt)
|
png_image_bytes = await client.queue_prompt(prompt)
|
||||||
assert len(png_image_bytes) > 1000
|
assert len(png_image_bytes) > 1000
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completes_prompt_with_ui(comfy_background_server):
|
async def test_completes_prompt_with_ui(comfy_background_server):
|
||||||
client = AsyncRemoteComfyClient()
|
async with AsyncRemoteComfyClient() as client:
|
||||||
random_seed = random.randint(1, 4294967295)
|
random_seed = random.randint(1, 4294967295)
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||||
result_dict = await client.queue_prompt_ui(prompt)
|
result_dict = await client.queue_prompt_ui(prompt)
|
||||||
# should contain one output
|
# should contain one output
|
||||||
assert len(result_dict) == 1
|
assert len(result_dict) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completes_prompt_with_image_urls(comfy_background_server):
|
async def test_completes_prompt_with_image_urls(comfy_background_server):
|
||||||
client = AsyncRemoteComfyClient()
|
async with AsyncRemoteComfyClient() as client:
|
||||||
random_seed = random.randint(1, 4294967295)
|
random_seed = random.randint(1, 4294967295)
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1, filename_prefix="subdirtest/sdxl")
|
||||||
result = await client.queue_prompt_api(prompt)
|
result = await client.queue_prompt_api(prompt)
|
||||||
assert len(result.urls) == 2
|
assert len(result.urls) == 2
|
||||||
for url_str in result.urls:
|
for url_str in result.urls:
|
||||||
url: URL = parse(url_str)
|
url: URL = parse(url_str)
|
||||||
|
|||||||
@ -1,18 +1,25 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import socket
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import urllib
|
|
||||||
from typing import Tuple, List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import urllib
|
||||||
|
from testcontainers.rabbitmq import RabbitMqContainer
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
from comfy.cli_args_types import Configuration
|
from comfy.cli_args_types import Configuration
|
||||||
|
|
||||||
|
logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
|
||||||
|
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
|
||||||
|
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
|
||||||
|
|
||||||
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
||||||
os.environ["TC_HOST"] = "localhost"
|
os.environ["TC_HOST"] = "localhost"
|
||||||
|
|
||||||
@ -95,7 +102,6 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
|||||||
executor_factory = request.param
|
executor_factory = request.param
|
||||||
processes_to_close: List[subprocess.Popen] = []
|
processes_to_close: List[subprocess.Popen] = []
|
||||||
|
|
||||||
from testcontainers.rabbitmq import RabbitMqContainer
|
|
||||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||||
params = rabbitmq.get_connection_params()
|
params = rabbitmq.get_connection_params()
|
||||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
|||||||
@ -103,21 +103,20 @@ async def test_distributed_prompt_queues_same_process():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
|
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
|
||||||
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||||
png_image_bytes = await client.queue_prompt(prompt)
|
png_image_bytes = await client.queue_prompt(prompt)
|
||||||
len_queue_after = await client.len_queue()
|
len_queue_after = await client.len_queue()
|
||||||
assert len_queue_after == 0
|
assert len_queue_after == 0
|
||||||
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
|
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
|
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
|
||||||
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
|
with pytest.raises(Exception):
|
||||||
with pytest.raises(Exception):
|
await client.queue_prompt(prompt)
|
||||||
await client.queue_prompt(prompt)
|
|
||||||
|
|
||||||
|
|
||||||
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
|
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
|
||||||
@ -151,48 +150,47 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
|
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
|
||||||
# Create the client using the server address from the fixture
|
# Create the client using the server address from the fixture
|
||||||
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
|
|
||||||
# Create a test prompt
|
# Create a test prompt
|
||||||
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
|
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
|
||||||
|
|
||||||
# Queue the prompt
|
# Queue the prompt
|
||||||
task_id = await client.queue_and_forget_prompt_api(prompt)
|
task_id = await client.queue_and_forget_prompt_api(prompt)
|
||||||
|
|
||||||
assert task_id is not None, "Failed to get a valid task ID"
|
assert task_id is not None, "Failed to get a valid task ID"
|
||||||
|
|
||||||
# Poll for the result
|
# Poll for the result
|
||||||
max_attempts = 60 # Increase max attempts for integration test
|
max_attempts = 60 # Increase max attempts for integration test
|
||||||
poll_interval = 1 # Increase poll interval for integration test
|
poll_interval = 1 # Increase poll interval for integration test
|
||||||
for _ in range(max_attempts):
|
for _ in range(max_attempts):
|
||||||
try:
|
try:
|
||||||
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
|
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
assert result is not None, "Received empty result"
|
assert result is not None, "Received empty result"
|
||||||
|
|
||||||
# Find the first output node with images
|
# Find the first output node with images
|
||||||
output_node = next((node for node in result.values() if 'images' in node), None)
|
output_node = next((node for node in result.values() if 'images' in node), None)
|
||||||
assert output_node is not None, "No output node with images found"
|
assert output_node is not None, "No output node with images found"
|
||||||
|
|
||||||
assert len(output_node['images']) > 0, "No images in output node"
|
assert len(output_node['images']) > 0, "No images in output node"
|
||||||
assert 'filename' in output_node['images'][0], "No filename in image output"
|
assert 'filename' in output_node['images'][0], "No filename in image output"
|
||||||
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
|
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
|
||||||
assert 'type' in output_node['images'][0], "No type in image output"
|
assert 'type' in output_node['images'][0], "No type in image output"
|
||||||
|
|
||||||
# Check if we can access the image
|
# Check if we can access the image
|
||||||
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
|
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
|
||||||
image_response = await client.session.get(image_url)
|
image_response = await client.session.get(image_url)
|
||||||
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
|
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
|
||||||
|
|
||||||
return # Test passed
|
return # Test passed
|
||||||
elif response.status == 204:
|
elif response.status == 204:
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
else:
|
||||||
|
response.raise_for_status()
|
||||||
|
except _:
|
||||||
await asyncio.sleep(poll_interval)
|
await asyncio.sleep(poll_interval)
|
||||||
else:
|
|
||||||
response.raise_for_status()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error while polling: {e}")
|
|
||||||
await asyncio.sleep(poll_interval)
|
|
||||||
|
|
||||||
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
|
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user