import asyncio
import os
import uuid
from concurrent.futures import ThreadPoolExecutor

import jwt
import pytest

from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.server_stub import ServerStub
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from testcontainers.rabbitmq import RabbitMqContainer

# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"


def _amqp_uri(port) -> str:
    """Return the guest-credential AMQP URI for a broker listening on 127.0.0.1:*port*."""
    return f"amqp://guest:guest@127.0.0.1:{port}"


def create_test_prompt() -> QueueItem:
    """Build a validated one-step SDXL+refiner prompt wrapped in a ``QueueItem``.

    The prompt id is a fresh UUID4 string. The second ``QueueItem`` argument is
    passed as ``None`` (presumably the completion future — the same-process test
    assigns ``.completed`` itself).
    """
    # imported lazily: importing comfy.cmd.execution has side effects
    from comfy.cmd.execution import validate_prompt

    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
    validation_tuple = validate_prompt(prompt)
    item_id = str(uuid.uuid4())
    # NOTE(review): presumably (number, prompt_id, prompt, extra_data, outputs_to_execute) — see QueueTuple
    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
    return QueueItem(queue_tuple, None)


@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
    """An unsigned ("none" algorithm) JWT round-trips its ``sub`` claim."""
    client_id = str(uuid.uuid4())
    user_token_str = jwt.encode({"sub": client_id}, None, algorithm="none")
    user_token = jwt.decode(user_token_str, None, algorithms=["none"], options={"verify_signature": False})
    assert user_token["sub"] == client_id


@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
    """Submit one prompt through a caller-side queue and get a successful result from a worker."""
    # there are lots of side effects from importing that we have to deal with
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = _amqp_uri(params.port)
        async with DistributedPromptWorker(connection_uri=connection_uri):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

            # now submit some jobs
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
            await distributed_queue.init()
            try:
                queue_item = create_test_prompt()
                res: TaskInvocation = await distributed_queue.put_async(queue_item)
                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"
            finally:
                # close even when an assertion fails so the connection is not leaked
                await distributed_queue.close()


@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
    """A caller queue and a callee queue in the same process exchange one prompt.

    The callee side is driven from a worker thread that runs its own event loop
    and executes the prompt with an ``EmbeddedComfyClient``.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = _amqp_uri(params.port)

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                loop = asyncio.get_running_loop()
                test_prompt = create_test_prompt()
                test_prompt.completed = loop.create_future()

                frontend.put(test_prompt)

                async def in_thread():
                    incoming, incoming_prompt_id = worker.get()
                    assert incoming is not None
                    incoming_named = NamedQueueTuple(incoming)
                    assert incoming_named.prompt_id == incoming_prompt_id
                    async with EmbeddedComfyClient() as embedded_comfy_client:
                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt, incoming_named.prompt_id)
                        worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))

                # start a worker thread that runs in_thread on its own event loop
                thread_pool = ThreadPoolExecutor(max_workers=1)
                try:
                    worker_done = loop.run_in_executor(thread_pool, lambda: asyncio.run(in_thread()))
                    # gather propagates any exception raised in the worker thread instead of
                    # waiting forever on a completion future that would never be resolved
                    frontend_pov_result, _ = await asyncio.gather(test_prompt.completed, worker_done)
                finally:
                    # don't block the event loop if the worker thread is still stuck
                    thread_pool.shutdown(wait=False)

                # this was completed over the comfyui queue interface, so it should be a task invocation
                frontend_pov_result: TaskInvocation
                assert frontend_pov_result is not None
                assert frontend_pov_result.item_id == test_prompt.prompt_id
                assert frontend_pov_result.outputs is not None
                assert len(frontend_pov_result.outputs) == 1
                assert frontend_pov_result.status is not None