mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
98 lines
4.8 KiB
Python
98 lines
4.8 KiB
Python
import asyncio
|
|
import os
|
|
import uuid
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
import jwt
|
|
import pytest
|
|
|
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
|
from comfy.distributed.server_stub import ServerStub
|
|
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|
from comfy.component_model.make_mutable import make_mutable
|
|
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
|
|
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
|
from testcontainers.rabbitmq import RabbitMqContainer
|
|
|
|
# Works around issues with running the testcontainers RabbitMqContainer on Windows
|
|
os.environ["TC_HOST"] = "localhost"
|
|
|
|
|
|
def create_test_prompt() -> QueueItem:
    """Build a queue item around a minimal SDXL-with-refiner workflow.

    The workflow is requested with the prompt text ``"test"``, one inference
    step and one refiner step, passed through ``make_mutable`` and then
    through ``validate_prompt``. The item is identified by a freshly
    generated UUID string.

    Returns:
        A ``QueueItem`` built from the 5-element queue tuple, with ``None``
        as its second constructor argument.
    """
    # Function-scope import, kept local exactly as the original did.
    from comfy.cmd.execution import validate_prompt

    workflow = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
    mutable_prompt = make_mutable(workflow)
    validation = validate_prompt(mutable_prompt)
    prompt_id = str(uuid.uuid4())

    # Only the third element of the validation result is carried along.
    # NOTE(review): presumably the list of outputs to execute - confirm
    # against validate_prompt.
    return QueueItem((0, prompt_id, mutable_prompt, {}, validation[2]), None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
    """Round-trip a token using the "none" algorithm and check its ``sub`` claim."""
    subject = str(uuid.uuid4())
    claims = {"sub": subject}

    # Encode without a key, then decode with signature verification disabled.
    encoded = jwt.encode(claims, None, algorithm="none")
    decoded = jwt.decode(
        encoded,
        None,
        algorithms=["none"],
        options={"verify_signature": False},
    )

    assert decoded["sub"] == subject
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
    """Submit one prompt through a caller-side queue to a distributed worker.

    Starts a RabbitMQ container and a ``DistributedPromptWorker`` connected
    to it, then creates a caller-only ``DistributedPromptQueue``, submits the
    prompt from ``create_test_prompt`` and checks the returned
    ``TaskInvocation``: matching item id, exactly one output and a
    ``"success"`` status string.

    The queue is closed in a ``finally`` clause so that a failed submission
    or assertion does not leave it open.
    """
    # there are lots of side effects from importing that we have to deal with

    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        # Built once so the worker and the queue are guaranteed to share it.
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        async with DistributedPromptWorker(connection_uri=connection_uri):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

            # now submit some jobs
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
                                                       connection_uri=connection_uri)
            await distributed_queue.init()
            try:
                queue_item = create_test_prompt()
                res: TaskInvocation = await distributed_queue.put_async(queue_item)
                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"
            finally:
                # Always release the queue, even when an assertion above fails.
                await distributed_queue.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
    """Run a caller queue and a callee queue against one broker in this process.

    A caller-only ``DistributedPromptQueue`` (``frontend``) enqueues a prompt;
    a callee-only queue (``worker``) is drained on a separate thread, which
    executes the prompt with ``EmbeddedComfyClient`` and reports the result
    through ``worker.task_done``. The test then checks the ``TaskInvocation``
    delivered to the frontend via ``test_prompt.completed``.

    The worker thread's outcome is tracked so that an exception or failed
    assertion inside it fails the test instead of being dropped (which would
    leave the test waiting on ``test_prompt.completed`` indefinitely), and the
    thread pool is shut down when the test body finishes.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
                                          connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
                                              connection_uri=connection_uri) as worker:
                test_prompt = create_test_prompt()
                test_prompt.completed = asyncio.Future()

                frontend.put(test_prompt)

                # start a worker thread
                thread_pool = ThreadPoolExecutor(max_workers=1)

                async def in_thread():
                    incoming, incoming_prompt_id = worker.get()
                    assert incoming is not None
                    incoming_named = NamedQueueTuple(incoming)
                    assert incoming_named.prompt_id == incoming_prompt_id
                    async with EmbeddedComfyClient() as embedded_comfy_client:
                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
                                                                           incoming_named.prompt_id)
                    worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))

                try:
                    # Wrap the executor future so the event loop can await it
                    # without blocking.
                    worker_done = asyncio.wrap_future(thread_pool.submit(lambda: asyncio.run(in_thread())))

                    # Wake up on whichever finishes first: if the worker
                    # thread died, re-raise its error here rather than
                    # waiting forever for a result that will never arrive.
                    await asyncio.wait({test_prompt.completed, worker_done},
                                       return_when=asyncio.FIRST_COMPLETED)
                    if worker_done.done():
                        worker_done.result()

                    # this was completed over the comfyui queue interface, so it should be a task invocation
                    frontend_pov_result: TaskInvocation = await test_prompt.completed
                    assert frontend_pov_result is not None
                    assert frontend_pov_result.item_id == test_prompt.prompt_id
                    assert frontend_pov_result.outputs is not None
                    assert len(frontend_pov_result.outputs) == 1
                    assert frontend_pov_result.status is not None

                    # Surface any late failure from the worker thread and make
                    # sure it has finished before the queues are torn down.
                    await worker_done
                finally:
                    # wait=False: never block the event loop thread here.
                    thread_pool.shutdown(wait=False)
|