Improve distributed testing

This commit is contained in:
doctorpangloss 2024-03-20 20:43:21 -07:00
parent 3f4049c5f4
commit 59cf9e5d93
6 changed files with 94 additions and 5 deletions

View File

@ -28,4 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .
pip install .
- name: Run distributed tests
run: |
pytest tests/distributed

View File

@ -23,6 +23,7 @@ class ExecutingMessage(TypedDict):
node: str | None
prompt_id: NotRequired[str]
output: NotRequired[dict]
sid: NotRequired[str]
class ProgressMessage(TypedDict):
@ -30,6 +31,7 @@ class ProgressMessage(TypedDict):
max: float
prompt_id: Optional[str]
node: Optional[str]
sid: NotRequired[str]
ExecutedMessage: TypeAlias = ExecutingMessage

View File

@ -25,7 +25,9 @@ def _get_name(queue_name: str, user_id: str) -> str:
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=False):
if receive_all_progress_notifications:
raise NotImplementedError("this release does not yet support sending all progress notifications from the hook over the network")
self._rpc = rpc
self._queue_name = queue_name
self._loop = loop
@ -33,6 +35,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
self.client_id = None
self.node_id = None
self.last_node_id = None
self.last_prompt_id = None
if receive_all_progress_notifications:
hijack_progress(self)

View File

@ -92,7 +92,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
# load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point
module = entry_point.load()
try:
module = entry_point.load()
except ModuleNotFoundError as module_not_found_error:
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue
# Ensure that what we've loaded is indeed a module
if isinstance(module, types.ModuleType):

View File

@ -8,7 +8,8 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@pytest.mark.asyncio
async def test_completes_prompt(comfy_background_server):
client = AsyncRemoteComfyClient()
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
random_seed = random.randint(1,4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
@ -16,7 +17,8 @@ async def test_completes_prompt(comfy_background_server):
@pytest.mark.asyncio
async def test_completes_prompt_with_ui(comfy_background_server):
client = AsyncRemoteComfyClient()
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
random_seed = random.randint(1,4294967295)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
result_dict = await client.queue_prompt_ui(prompt)
# should contain one output
assert len(result_dict) == 1

View File

@ -1,11 +1,20 @@
import asyncio
import logging
import os
import sys
import time
import uuid
import subprocess
import socket
from concurrent.futures import ThreadPoolExecutor
from typing import List
import jwt
import pytest
import requests
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.server_stub import ServerStub
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@ -95,3 +104,69 @@ async def test_distributed_prompt_queues_same_process():
assert frontend_pov_result.outputs is not None
assert len(frontend_pov_result.outputs) == 1
assert frontend_pov_result.status is not None
@pytest.mark.asyncio
async def test_frontend_backend_workers():
processes_to_close: List[subprocess.Popen] = []
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
try:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
frontend_command = [
"comfyui",
"--listen=0.0.0.0",
"--cpu",
"--distributed-queue-frontend",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
backend_command = [
"comfyui-worker",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
server_address = f"http://{get_lan_ip()}:8188"
start_time = time.time()
while time.time() - start_time < 60:
try:
response = requests.get(server_address)
if response.status_code == 200:
break
except ConnectionRefusedError:
pass
except Exception as exc:
logging.warning("", exc_info=exc)
time.sleep(1)
client = AsyncRemoteComfyClient(server_address=server_address)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
finally:
for process in processes_to_close:
process.terminate()
def get_lan_ip():
"""
Finds the host's IP address on the LAN it's connected to.
Returns:
str: The IP address of the host on the LAN.
"""
# Create a dummy socket
s = None
try:
# Connect to a dummy address (Here, Google's public DNS server)
# The actual connection is not made, but this allows finding out the LAN IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
finally:
if s is not None:
s.close()
return ip