ComfyUI/tests/distributed/test_distributed_queue.py

240 lines
11 KiB
Python

import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Callable
import jwt
import pytest
from aiohttp import ClientSession
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.executor_types import Executor
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem:
    """Build a fresh ``QueueItem`` for a one-step SDXL-with-refiner workflow.

    The workflow is made mutable, run through ``validate_prompt`` and wrapped
    in a queue tuple under a newly generated UUID.

    Returns:
        A ``QueueItem`` whose second constructor argument is ``None``.
    """
    # kept as a function-scope import, exactly like the original helper
    from comfy.cmd.execution import validate_prompt

    workflow = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
    prompt = make_mutable(workflow)
    validation = validate_prompt(prompt)
    # NOTE(review): element 2 of the validation result is presumably the list
    # of outputs to execute -- confirm against validate_prompt.
    queue_tuple: QueueTuple = (0, str(uuid.uuid4()), prompt, {}, validation[2])
    return QueueItem(queue_tuple, None)
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
    """Round-trip a token through the "none" algorithm and check its subject."""
    subject = str(uuid.uuid4())
    claims = {"sub": subject}
    encoded = jwt.encode(claims, None, algorithm="none")
    decoded = jwt.decode(
        encoded,
        None,
        algorithms=["none"],
        options={"verify_signature": False},
    )
    assert decoded["sub"] == subject
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
    """Submit one prompt through a caller-side queue to a single worker.

    A RabbitMQ container is started, a ``DistributedPromptWorker`` is attached
    to it with a one-worker executor built by ``executor_factory``, and a
    caller-only ``DistributedPromptQueue`` submits a test prompt. The returned
    invocation must carry the prompt's id, exactly one output and a
    ``"success"`` status.

    Args:
        executor_factory: Callable that builds the executor handed to the
            worker; invoked as ``executor_factory(max_workers=1)``.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        # worker and caller must talk to the same broker, so build the URI once
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        async with DistributedPromptWorker(connection_uri=connection_uri, executor=executor_factory(max_workers=1)):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

            # now submit some jobs
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
            await distributed_queue.init()
            try:
                queue_item = create_test_prompt()
                res: TaskInvocation = await distributed_queue.put_async(queue_item)
                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"
            finally:
                # close the queue even when submission or an assertion fails,
                # so a failing run does not leak the connection
                await distributed_queue.close()
@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
    """Run a caller queue and a callee queue in one process over RabbitMQ.

    The frontend queue receives a prompt via ``put``; a helper thread pulls it
    from the worker queue with ``get``, executes it with an
    ``EmbeddedComfyClient`` and reports it through ``task_done``. The frontend
    then observes the result on ``test_prompt.completed``.

    A failure inside the helper thread is re-raised here instead of being
    dropped, so the test fails rather than waiting forever on a result that
    will never arrive.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                test_prompt = create_test_prompt()
                test_prompt.completed = asyncio.Future()

                frontend.put(test_prompt)

                async def in_thread():
                    # runs on its own event loop inside the helper thread
                    incoming, incoming_prompt_id = worker.get()
                    assert incoming is not None
                    incoming_named = NamedQueueTuple(incoming)
                    assert incoming_named.prompt_id == incoming_prompt_id
                    async with EmbeddedComfyClient() as embedded_comfy_client:
                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
                                                                           incoming_named.prompt_id)
                    worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))

                # start a worker thread
                thread_pool = ThreadPoolExecutor(max_workers=1)
                try:
                    # wrap the thread's future so this loop can await it without blocking
                    worker_done = asyncio.wrap_future(thread_pool.submit(lambda: asyncio.run(in_thread())))
                    finished, _ = await asyncio.wait({test_prompt.completed, worker_done}, return_when=asyncio.FIRST_COMPLETED)
                    if worker_done in finished:
                        # re-raises whatever the helper thread raised (e.g. a failed assert)
                        worker_done.result()

                    # this was completed over the comfyui queue interface, so it should be a task invocation
                    frontend_pov_result: TaskInvocation = await test_prompt.completed
                    # let the helper thread finish before the queues are torn down
                    await worker_done
                finally:
                    # non-blocking: must not stall the event loop this test runs on
                    thread_pool.shutdown(wait=False)

                assert frontend_pov_result is not None
                assert frontend_pov_result.item_id == test_prompt.prompt_id
                assert frontend_pov_result.outputs is not None
                assert len(frontend_pov_result.outputs) == 1
                assert frontend_pov_result.status is not None
@pytest.mark.asyncio
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
    """Queue a workflow through the remote client and expect image bytes back.

    The fixture value is used as the server address. After the prompt returns,
    the reported queue length must be zero and the payload must be larger
    than 1000 bytes.
    """
    remote = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
    workflow = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)

    image_bytes = await remote.queue_prompt(workflow)
    remaining_in_queue = await remote.len_queue()

    assert remaining_in_queue == 0
    assert len(image_bytes) > 1000, "expected an image, but got nothing"
@pytest.mark.asyncio
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
    """Queueing a workflow that names an unknown refiner checkpoint must raise."""
    remote = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
    invalid_workflow = sdxl_workflow_with_refiner(
        "test",
        inference_steps=1,
        refiner_steps=1,
        sdxl_refiner_checkpoint_name="unknown.safetensors",
    )
    with pytest.raises(Exception):
        await remote.queue_prompt(invalid_workflow)
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0) -> bool:
    """Poll ``url`` until it answers with HTTP 200 or the attempts run out.

    Args:
        url: Address that receives the GET requests.
        max_retries: Maximum number of GET attempts. Zero or a negative value
            performs no request at all.
        retry_delay: Seconds to sleep between two consecutive attempts.

    Returns:
        ``True`` as soon as one attempt yields status 200, otherwise ``False``.
    """
    if max_retries <= 0:
        # nothing to attempt, so do not open an HTTP session at all
        return False

    import logging  # function-scope import: only used for the diagnostic below

    logger = logging.getLogger(__name__)
    async with ClientSession() as session:
        for attempt in range(max_retries):
            if attempt > 0:
                # pause between attempts only; sleeping after the final
                # failure would just delay the False result
                await asyncio.sleep(retry_delay)
            try:
                async with session.get(url, timeout=1) as response:
                    if response.status == 200:
                        return True
            except Exception:
                # best effort: any request failure counts as "not healthy yet",
                # but keep a trace of it instead of dropping it silently
                logger.debug("health check attempt %d/%d for %s failed", attempt + 1, max_retries, url, exc_info=True)
    return False
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
async def test_basic_queue_worker_with_health_check(executor_factory):
    """A worker started with a health-check port must answer on ``/health``.

    Args:
        executor_factory: Callable that builds the worker's executor; invoked
            as ``executor_factory(max_workers=1)``.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        amqp_port = rabbitmq.get_connection_params().port
        health_check_port = 9090
        worker_options = dict(
            connection_uri=f"amqp://guest:guest@127.0.0.1:{amqp_port}",
            health_check_port=health_check_port,
            executor=executor_factory(max_workers=1),
        )
        async with DistributedPromptWorker(**worker_options) as _worker:
            is_healthy = await check_health(f"http://localhost:{health_check_port}/health")
            assert is_healthy, "Health check server did not start properly"
@pytest.mark.asyncio
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
    """Queue a prompt without waiting, poll for its result, then fetch the image.

    The prompt endpoint is polled up to ``max_attempts`` times. A 200 response
    carries the result; a 204 response is treated as "not ready yet"; request
    errors are printed and retried. Only the polling is retried: once a result
    has been received, validation failures fail the test immediately with
    their own message instead of being retried and reported as a timeout.
    """
    # Create the client using the server address from the fixture
    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)

    # Create a test prompt
    prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)

    # Queue the prompt
    task_id = await client.queue_and_forget_prompt_api(prompt)
    assert task_id is not None, "Failed to get a valid task ID"

    # Poll for the result
    max_attempts = 60  # Increase max attempts for integration test
    poll_interval = 1  # Increase poll interval for integration test
    result_url = f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}"
    result = None
    have_result = False  # separate flag: the decoded body itself could be None
    for _ in range(max_attempts):
        try:
            response = await client.session.get(result_url)
            if response.status == 200:
                result = await response.json()
                have_result = True
            elif response.status != 204:
                # anything other than "done" (200) or "not ready" (204)
                response.raise_for_status()
        except Exception as e:
            # the try covers the HTTP round trip only, so no assertion can be
            # swallowed here
            print(f"Error while polling: {e}")
        if have_result:
            break
        await asyncio.sleep(poll_interval)
    else:
        pytest.fail("Failed to get a 200 response with valid data within the timeout period")

    assert result is not None, "Received empty result"

    # Find the first output node with images
    output_node = next((node for node in result.values() if 'images' in node), None)
    assert output_node is not None, "No output node with images found"
    assert len(output_node['images']) > 0, "No images in output node"
    assert 'filename' in output_node['images'][0], "No filename in image output"
    assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
    assert 'type' in output_node['images'][0], "No type in image output"

    # Check if we can access the image
    image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
    image_response = await client.session.get(image_url)
    assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
class TestWorker(DistributedPromptWorker):
    """Worker that records the ``prompt_id`` of every request it completes.

    Used by the tests to check which worker handled which workflow.
    """

    # The "Test" name prefix makes pytest try to collect this helper as a test
    # class and warn because it defines __init__; opt out of collection.
    __test__ = False

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base worker and start with no records."""
        super().__init__(*args, **kwargs)
        # prompt ids seen by on_will_complete_work_item
        self.processed_workflows: set[str] = set()

    async def on_will_complete_work_item(self, request: dict):
        """Record the request's ``prompt_id``, then defer to the base class.

        Args:
            request: Work item description; ``'unknown'`` is recorded when it
                has no ``'prompt_id'`` key.
        """
        workflow_id = request.get('prompt_id', 'unknown')
        self.processed_workflows.add(workflow_id)
        await super().on_will_complete_work_item(request)
@pytest.mark.asyncio
async def test_two_workers_distinct_requests():
    """Two workers on one broker must each process exactly one of two prompts.

    Two ``TestWorker`` instances (health-check ports 9090 and 9091) are
    started against one RabbitMQ container, two prompts are submitted
    concurrently through a caller-only queue, and each worker's record of
    processed workflows is inspected afterwards. Workers and queue are closed
    even when startup or submission fails.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        workers: list[TestWorker] = []
        queue = None  # set only once init() succeeded, so finally closes only what was opened
        try:
            # Start two test workers
            for i in range(2):
                worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
                await worker.init()
                workers.append(worker)

            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
            new_queue = DistributedPromptQueue(is_callee=False, is_caller=True, connection_uri=connection_uri)
            await new_queue.init()
            queue = new_queue

            # Submit two prompts
            task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
            task2 = asyncio.create_task(queue.put_async(create_test_prompt()))

            # Wait for tasks to complete
            await asyncio.gather(task1, task2)
        finally:
            # Clean up: workers first, then the queue
            for worker in workers:
                await worker.close()
            if queue is not None:
                await queue.close()

        # Assert that each worker processed exactly one distinct workflow
        all_workflows = set()
        for worker in workers:
            assert len(worker.processed_workflows) == 1, f"Worker processed {len(worker.processed_workflows)} workflows instead of 1"
            all_workflows.update(worker.processed_workflows)

        assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"