diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 6dab40b11..1e018be35 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -28,4 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . \ No newline at end of file + pip install . + - name: Run distributed tests + run: | + pytest tests/distributed \ No newline at end of file diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index ff06a17a8..38523aac1 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -23,6 +23,7 @@ class ExecutingMessage(TypedDict): node: str | None prompt_id: NotRequired[str] output: NotRequired[dict] + sid: NotRequired[str] class ProgressMessage(TypedDict): @@ -30,6 +31,7 @@ class ProgressMessage(TypedDict): max: float prompt_id: Optional[str] node: Optional[str] + sid: NotRequired[str] ExecutedMessage: TypeAlias = ExecutingMessage diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 49b0054ab..11f73fcfe 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -25,7 +25,9 @@ def _get_name(queue_name: str, user_id: str) -> str: class DistributedExecutorToClientProgress(ExecutorToClientProgress): - def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True): + def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=False): + if receive_all_progress_notifications: + raise NotImplementedError("this release does not yet support sending all progress notifications from the hook over the network") self._rpc = rpc self._queue_name = queue_name self._loop = loop @@ -33,6 +35,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): self.client_id = None self.node_id = None self.last_node_id = None + self.last_prompt_id = None if receive_all_progress_notifications: hijack_progress(self) diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 4a8c8b721..994061b5e 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -92,7 +92,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: # load from entrypoints for entry_point in entry_points().select(group='comfyui.custom_nodes'): # Load the module associated with the current entry point - module = entry_point.load() + try: + module = entry_point.load() + except ModuleNotFoundError as module_not_found_error: + logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) + continue # Ensure that what we've loaded is indeed a module if isinstance(module, types.ModuleType): diff --git a/tests/distributed/test_asyncio_remote_client.py b/tests/distributed/test_asyncio_remote_client.py index d0aab0650..6382857fe 100644 --- a/tests/distributed/test_asyncio_remote_client.py +++ b/tests/distributed/test_asyncio_remote_client.py @@ -8,7 +8,8 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner @pytest.mark.asyncio async def test_completes_prompt(comfy_background_server): client = AsyncRemoteComfyClient() - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + random_seed = random.randint(1,4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) png_image_bytes = await client.queue_prompt(prompt) assert len(png_image_bytes) > 1000 @@ -16,7 +17,8 @@ async def test_completes_prompt(comfy_background_server): @pytest.mark.asyncio async def test_completes_prompt_with_ui(comfy_background_server): client = AsyncRemoteComfyClient() - prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + random_seed = random.randint(1,4294967295) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1) result_dict = await client.queue_prompt_ui(prompt) # should contain one output assert len(result_dict) == 1 diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 3801c3ef3..1251cddf0 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -1,11 +1,20 @@ import asyncio +import logging import os +import sys +import time import uuid +import subprocess +import socket + from concurrent.futures import ThreadPoolExecutor +from typing import List import jwt import pytest +import requests +from comfy.client.aio_client import AsyncRemoteComfyClient from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.distributed.server_stub import ServerStub from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner @@ -95,3 +104,69 @@ async def test_distributed_prompt_queues_same_process(): assert frontend_pov_result.outputs is not None assert len(frontend_pov_result.outputs) == 1 assert frontend_pov_result.status is not None + + +@pytest.mark.asyncio +async def test_frontend_backend_workers(): + processes_to_close: List[subprocess.Popen] = [] + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + try: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + frontend_command = [ + "comfyui", + "--listen=0.0.0.0", + "--cpu", + "--distributed-queue-frontend", + f"--distributed-queue-connection-uri={connection_uri}", + ] + + processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr)) + backend_command = [ + "comfyui-worker", + f"--distributed-queue-connection-uri={connection_uri}", + ] + + processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) + server_address = f"http://{get_lan_ip()}:8188" + start_time = time.time() + while time.time() - start_time < 60: + try: + response = requests.get(server_address) + if response.status_code == 200: + break + except ConnectionRefusedError: + pass + except Exception as exc: + logging.warning("", exc_info=exc) + time.sleep(1) + + client = AsyncRemoteComfyClient(server_address=server_address) + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + png_image_bytes = await client.queue_prompt(prompt) + assert len(png_image_bytes) > 1000 + finally: + for process in processes_to_close: + process.terminate() + + +def get_lan_ip(): + """ + Finds the host's IP address on the LAN it's connected to. + + Returns: + str: The IP address of the host on the LAN. + """ + # Create a dummy socket + s = None + try: + # Connect to a dummy address (Here, Google's public DNS server) + # The actual connection is not made, but this allows finding out the LAN IP + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + finally: + if s is not None: + s.close() + return ip