mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
171 lines
7.2 KiB
Python
171 lines
7.2 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import List
|
|
|
|
import jwt
|
|
import pytest
|
|
import requests
|
|
from testcontainers.rabbitmq import RabbitMqContainer
|
|
|
|
from comfy.client.aio_client import AsyncRemoteComfyClient
|
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
|
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|
from comfy.component_model.make_mutable import make_mutable
|
|
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
|
|
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
|
from comfy.distributed.server_stub import ServerStub
|
|
|
|
# Point testcontainers at localhost; this fixes issues with running the
# RabbitMqContainer from testcontainers on Windows.
os.environ.update({"TC_HOST": "localhost"})
|
|
|
|
|
|
def create_test_prompt() -> QueueItem:
    """
    Build a QueueItem wrapping a minimal SDXL-with-refiner workflow.

    The workflow uses a single inference step and a single refiner step. It is made mutable and
    passed through ``validate_prompt``; element 2 of the validation result is stored in the last
    slot of the queue tuple.

    Returns:
        QueueItem: queue number 0, a fresh random UUID string as the prompt id, empty extra data,
        and no completion future.
    """
    # Imported here rather than at module level; presumably to defer comfy.cmd.execution's
    # import-time work until a test actually needs it — confirm.
    from comfy.cmd.execution import validate_prompt

    workflow = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
    mutable_prompt = make_mutable(workflow)
    # NOTE(review): index 2 is presumably the list of outputs to execute — confirm against validate_prompt.
    validated_outputs = validate_prompt(mutable_prompt)[2]
    prompt_id = str(uuid.uuid4())
    queue_tuple: QueueTuple = (0, prompt_id, mutable_prompt, {}, validated_outputs)
    return QueueItem(queue_tuple, None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
    """Round-trip an unsigned ("none" algorithm) JWT and check the ``sub`` claim survives."""
    subject = str(uuid.uuid4())
    encoded = jwt.encode({"sub": subject}, None, algorithm="none")
    decoded = jwt.decode(encoded, None, algorithms=["none"], options={"verify_signature": False})
    assert decoded["sub"] == subject
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
    """
    Submit a prompt through a caller-only DistributedPromptQueue while a DistributedPromptWorker
    is attached to the same RabbitMQ broker, and check a successful TaskInvocation comes back.
    """
    # there are lots of side effects from importing that we have to deal with

    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        async with DistributedPromptWorker(connection_uri=connection_uri):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

            # now submit some jobs
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
            await distributed_queue.init()
            try:
                queue_item = create_test_prompt()
                res: TaskInvocation = await distributed_queue.put_async(queue_item)
                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"
            finally:
                # close even when an assertion above fails, so the queue's connection is not leaked
                await distributed_queue.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
    """
    Round-trip a prompt between two DistributedPromptQueue instances in this process.

    ``frontend`` (caller) puts the prompt on the broker. ``worker`` (callee) takes it on a separate
    thread, runs it with an EmbeddedComfyClient and reports it done. That must resolve
    ``test_prompt.completed`` on the frontend side with a TaskInvocation.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                test_prompt = create_test_prompt()
                test_prompt.completed = asyncio.Future()

                frontend.put(test_prompt)

                async def in_thread():
                    incoming, incoming_prompt_id = worker.get()
                    assert incoming is not None
                    incoming_named = NamedQueueTuple(incoming)
                    assert incoming_named.prompt_id == incoming_prompt_id
                    async with EmbeddedComfyClient() as embedded_comfy_client:
                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
                                                                            incoming_named.prompt_id)
                    worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))

                # start a worker thread; it runs in_thread on its own event loop via asyncio.run
                thread_pool = ThreadPoolExecutor(max_workers=1)
                try:
                    worker_done = asyncio.get_running_loop().run_in_executor(thread_pool, lambda: asyncio.run(in_thread()))
                    # this was completed over the comfyui queue interface, so it should be a task invocation.
                    # The worker thread is awaited too, so an exception raised there (e.g. a failed assert)
                    # fails the test instead of leaving it waiting on test_prompt.completed forever.
                    frontend_pov_result, _ = await asyncio.gather(test_prompt.completed, worker_done)
                finally:
                    # wait=False: never block the event loop on a worker thread that may still be running
                    thread_pool.shutdown(wait=False)

                frontend_pov_result: TaskInvocation
                assert frontend_pov_result is not None
                assert frontend_pov_result.item_id == test_prompt.prompt_id
                assert frontend_pov_result.outputs is not None
                assert len(frontend_pov_result.outputs) == 1
                assert frontend_pov_result.status is not None
                assert frontend_pov_result.status.status_str == "success"
|
|
|
|
|
|
async def _wait_for_http_ok(server_address: str, timeout: float = 60.0, poll_interval: float = 1.0) -> None:
    """
    Poll ``server_address`` with HTTP GET until it answers 200 or ``timeout`` seconds elapse.

    Raises:
        TimeoutError: if no 200 response is received within ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # bounded request so a half-started server cannot stall the poll indefinitely
            response = requests.get(server_address, timeout=5)
            if response.status_code == 200:
                return
        except requests.exceptions.ConnectionError:
            # requests wraps a refused connection in ConnectionError: the server is not listening yet
            pass
        except Exception as exc:
            logging.warning("unexpected error while waiting for %s", server_address, exc_info=exc)
        # asyncio.sleep (not time.sleep) so the event loop keeps running between polls
        await asyncio.sleep(poll_interval)
    raise TimeoutError(f"{server_address} did not respond with HTTP 200 within {timeout} seconds")


@pytest.mark.asyncio
async def test_frontend_backend_workers():
    """
    Start a ``comfyui`` distributed-queue frontend and a ``comfyui-worker`` backend as subprocesses
    sharing a RabbitMQ broker, then run a workflow through the frontend's HTTP API and check that
    image bytes come back.
    """
    frontend_port = 9001
    processes_to_close: List[subprocess.Popen] = []
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        try:
            params = rabbitmq.get_connection_params()
            connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

            frontend_command = [
                "comfyui",
                "--listen=0.0.0.0",
                f"--port={frontend_port}",
                "--cpu",
                "--distributed-queue-frontend",
                f"--distributed-queue-connection-uri={connection_uri}",
            ]

            processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
            backend_command = [
                "comfyui-worker",
                "--port=9002",
                f"--distributed-queue-connection-uri={connection_uri}",
            ]

            processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
            # must target the same port the frontend was started with; it listens on all interfaces
            server_address = f"http://{get_lan_ip()}:{frontend_port}"
            await _wait_for_http_ok(server_address, timeout=60)

            client = AsyncRemoteComfyClient(server_address=server_address)
            prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
            png_image_bytes = await client.queue_prompt(prompt)
            assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
        finally:
            for process in processes_to_close:
                process.terminate()
            # reap the children so they don't linger; escalate to kill if terminate is ignored
            for process in processes_to_close:
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
|
|
|
|
|
|
def get_lan_ip():
    """
    Finds the host's IP address on the LAN it's connected to.

    A UDP socket is "connected" to a public address (Google's public DNS server). For UDP this
    only selects a route, so no connection is actually made. The local address the OS picked for
    that route is then read back.

    Returns:
        str: The IP address of the host on the LAN, or ``"127.0.0.1"`` when the host has no
        route to the probe address (for example an offline CI sandbox).
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except OSError:
        # no usable route (e.g. "Network is unreachable"); loopback still reaches local servers
        return "127.0.0.1"
|