ComfyUI/tests/distributed/test_distributed_queue.py

171 lines
7.2 KiB
Python

import asyncio
import logging
import os
import socket
import subprocess
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import List
import jwt
import pytest
import requests
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.server_stub import ServerStub
# Point testcontainers at localhost (TC_HOST); this works around problems running
# the testcontainers RabbitMqContainer on Windows.
os.environ.update(TC_HOST="localhost")
def create_test_prompt() -> QueueItem:
    """
    Build a QueueItem wrapping a cheap SDXL-with-refiner workflow.

    The workflow uses a single inference step and a single refiner step. It is
    passed through ``validate_prompt``, and element 2 of the validation result
    becomes the last field of the queue tuple (presumably the list of outputs to
    execute -- confirm against comfy.cmd.execution).

    Returns:
        QueueItem: built from ``(0, <random uuid4 string>, prompt, {}, <validated outputs>)``
        with ``None`` as the second constructor argument (presumably the ``completed``
        future; callers in this file assign ``.completed`` themselves when needed).
    """
    # Kept local, presumably to defer comfy import side effects until a prompt is needed.
    from comfy.cmd.execution import validate_prompt

    workflow = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
    prompt = make_mutable(workflow)
    outputs_to_execute = validate_prompt(prompt)[2]
    queue_tuple: QueueTuple = (0, str(uuid.uuid4()), prompt, {}, outputs_to_execute)
    return QueueItem(queue_tuple, None)
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
    """Round-trip an unsigned ("none" algorithm) JWT and check the subject claim survives."""
    subject = str(uuid.uuid4())
    encoded = jwt.encode({"sub": subject}, None, algorithm="none")
    decoded = jwt.decode(
        encoded,
        None,
        algorithms=["none"],
        options={"verify_signature": False},
    )
    assert decoded["sub"] == subject
@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
    """
    Submit one prompt through a caller-only DistributedPromptQueue and check that a
    DistributedPromptWorker attached to the same RabbitMQ broker executes it successfully.
    """
    # there are lots of side effects from importing that we have to deal with
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        async with DistributedPromptWorker(connection_uri=connection_uri):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

            # now submit some jobs
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
            await distributed_queue.init()
            try:
                queue_item = create_test_prompt()
                res: TaskInvocation = await distributed_queue.put_async(queue_item)
                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"
            finally:
                # Close even when an assertion fails so the broker connection is not leaked.
                await distributed_queue.close()
@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
    """
    Exercise a caller-side and a callee-side DistributedPromptQueue in one process.

    The frontend (caller) enqueues a prompt. A worker thread takes it from the callee
    queue, executes it with an EmbeddedComfyClient and reports the result through
    ``task_done``. That should resolve the prompt's ``completed`` future on the
    frontend side.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                test_prompt = create_test_prompt()
                test_prompt.completed = asyncio.Future()
                frontend.put(test_prompt)

                async def in_thread():
                    incoming, incoming_prompt_id = worker.get()
                    assert incoming is not None
                    incoming_named = NamedQueueTuple(incoming)
                    assert incoming_named.prompt_id == incoming_prompt_id
                    async with EmbeddedComfyClient() as embedded_comfy_client:
                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
                                                                           incoming_named.prompt_id)
                    worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))

                # start a worker thread; leaving the `with` shuts the executor down so its
                # thread does not outlive the test
                with ThreadPoolExecutor(max_workers=1) as thread_pool:
                    # Await the thread's future alongside `completed`. An exception in the
                    # worker (e.g. a failed assertion) then fails the test instead of being
                    # dropped, which used to leave the test hanging on `completed` forever.
                    worker_done = asyncio.get_running_loop().run_in_executor(
                        thread_pool, lambda: asyncio.run(in_thread()))
                    # this was completed over the comfyui queue interface, so it should be a task invocation
                    frontend_pov_result: TaskInvocation
                    frontend_pov_result, _ = await asyncio.gather(test_prompt.completed, worker_done)

                assert frontend_pov_result is not None
                assert frontend_pov_result.item_id == test_prompt.prompt_id
                assert frontend_pov_result.outputs is not None
                assert len(frontend_pov_result.outputs) == 1
                assert frontend_pov_result.status is not None
@pytest.mark.asyncio
async def test_frontend_backend_workers():
    """
    End-to-end check of the distributed deployment.

    A ``comfyui`` frontend process and a ``comfyui-worker`` backend process are started
    and connected through RabbitMQ. An SDXL+refiner prompt is submitted over HTTP to the
    frontend, and the test expects a non-trivial image back.
    """
    frontend_port = 9001
    backend_port = 9002
    startup_timeout_s = 60
    processes_to_close: List[subprocess.Popen] = []
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        try:
            params = rabbitmq.get_connection_params()
            connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
            frontend_command = [
                "comfyui",
                "--listen=0.0.0.0",
                f"--port={frontend_port}",
                "--cpu",
                "--distributed-queue-frontend",
                f"--distributed-queue-connection-uri={connection_uri}",
            ]
            processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
            backend_command = [
                "comfyui-worker",
                f"--port={backend_port}",
                f"--distributed-queue-connection-uri={connection_uri}",
            ]
            processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))

            # Poll the port the frontend was actually told to listen on; this previously
            # polled 8188 while the frontend was started with --port=9001.
            server_address = f"http://{get_lan_ip()}:{frontend_port}"
            start_time = time.monotonic()
            while time.monotonic() - start_time < startup_timeout_s:
                exited = [process.args for process in processes_to_close if process.poll() is not None]
                if exited:
                    pytest.fail(f"process(es) exited before the frontend became ready: {exited}")
                try:
                    response = requests.get(server_address, timeout=5)
                    if response.status_code == 200:
                        break
                except (ConnectionRefusedError, requests.exceptions.ConnectionError):
                    # requests wraps refused connections in its own ConnectionError;
                    # the frontend is simply not accepting connections yet.
                    pass
                except Exception as exc:
                    logging.warning("unexpected error while waiting for %s", server_address, exc_info=exc)
                await asyncio.sleep(1)
            else:
                pytest.fail(f"frontend at {server_address} did not become ready within {startup_timeout_s}s")

            client = AsyncRemoteComfyClient(server_address=server_address)
            prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
            png_image_bytes = await client.queue_prompt(prompt)
            assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
        finally:
            for process in processes_to_close:
                process.terminate()
            # Reap the children so they are not left running (or as zombies) after the test.
            for process in processes_to_close:
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
def get_lan_ip():
    """
    Return the IPv4 address this host uses on its LAN.

    A UDP socket is "connected" to a public address (Google's DNS server, 8.8.8.8).
    No actual connection is made; this only lets the OS choose the outbound
    interface. The socket's local address is then that interface's IP.

    Returns:
        str: The IP address of the host on the LAN.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect(("8.8.8.8", 80))
        local_ip = probe.getsockname()[0]
    return local_ip