mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Improve distributed testing
This commit is contained in:
parent
3f4049c5f4
commit
59cf9e5d93
5
.github/workflows/test-build.yml
vendored
5
.github/workflows/test-build.yml
vendored
@ -28,4 +28,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .
|
||||
pip install .
|
||||
- name: Run distributed tests
|
||||
run: |
|
||||
pytest tests/distributed
|
||||
@ -23,6 +23,7 @@ class ExecutingMessage(TypedDict):
|
||||
node: str | None
|
||||
prompt_id: NotRequired[str]
|
||||
output: NotRequired[dict]
|
||||
sid: NotRequired[str]
|
||||
|
||||
|
||||
class ProgressMessage(TypedDict):
|
||||
@ -30,6 +31,7 @@ class ProgressMessage(TypedDict):
|
||||
max: float
|
||||
prompt_id: Optional[str]
|
||||
node: Optional[str]
|
||||
sid: NotRequired[str]
|
||||
|
||||
|
||||
ExecutedMessage: TypeAlias = ExecutingMessage
|
||||
|
||||
@ -25,7 +25,9 @@ def _get_name(queue_name: str, user_id: str) -> str:
|
||||
|
||||
|
||||
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
|
||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=False):
|
||||
if receive_all_progress_notifications:
|
||||
raise NotImplementedError("this release does not yet support sending all progress notifications from the hook over the network")
|
||||
self._rpc = rpc
|
||||
self._queue_name = queue_name
|
||||
self._loop = loop
|
||||
@ -33,6 +35,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
self.client_id = None
|
||||
self.node_id = None
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
if receive_all_progress_notifications:
|
||||
hijack_progress(self)
|
||||
|
||||
|
||||
@ -92,7 +92,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
||||
# load from entrypoints
|
||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||
# Load the module associated with the current entry point
|
||||
module = entry_point.load()
|
||||
try:
|
||||
module = entry_point.load()
|
||||
except ModuleNotFoundError as module_not_found_error:
|
||||
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||
continue
|
||||
|
||||
# Ensure that what we've loaded is indeed a module
|
||||
if isinstance(module, types.ModuleType):
|
||||
|
||||
@ -8,7 +8,8 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||
@pytest.mark.asyncio
|
||||
async def test_completes_prompt(comfy_background_server):
|
||||
client = AsyncRemoteComfyClient()
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||
random_seed = random.randint(1,4294967295)
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||
png_image_bytes = await client.queue_prompt(prompt)
|
||||
assert len(png_image_bytes) > 1000
|
||||
|
||||
@ -16,7 +17,8 @@ async def test_completes_prompt(comfy_background_server):
|
||||
@pytest.mark.asyncio
|
||||
async def test_completes_prompt_with_ui(comfy_background_server):
|
||||
client = AsyncRemoteComfyClient()
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||
random_seed = random.randint(1,4294967295)
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||
result_dict = await client.queue_prompt_ui(prompt)
|
||||
# should contain one output
|
||||
assert len(result_dict) == 1
|
||||
|
||||
@ -1,11 +1,20 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import subprocess
|
||||
import socket
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from comfy.distributed.server_stub import ServerStub
|
||||
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||
@ -95,3 +104,69 @@ async def test_distributed_prompt_queues_same_process():
|
||||
assert frontend_pov_result.outputs is not None
|
||||
assert len(frontend_pov_result.outputs) == 1
|
||||
assert frontend_pov_result.status is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_frontend_backend_workers():
|
||||
processes_to_close: List[subprocess.Popen] = []
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
try:
|
||||
params = rabbitmq.get_connection_params()
|
||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||
|
||||
frontend_command = [
|
||||
"comfyui",
|
||||
"--listen=0.0.0.0",
|
||||
"--cpu",
|
||||
"--distributed-queue-frontend",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
]
|
||||
|
||||
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
backend_command = [
|
||||
"comfyui-worker",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
]
|
||||
|
||||
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
server_address = f"http://{get_lan_ip()}:8188"
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 60:
|
||||
try:
|
||||
response = requests.get(server_address)
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logging.warning("", exc_info=exc)
|
||||
time.sleep(1)
|
||||
|
||||
client = AsyncRemoteComfyClient(server_address=server_address)
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||
png_image_bytes = await client.queue_prompt(prompt)
|
||||
assert len(png_image_bytes) > 1000
|
||||
finally:
|
||||
for process in processes_to_close:
|
||||
process.terminate()
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
"""
|
||||
Finds the host's IP address on the LAN it's connected to.
|
||||
|
||||
Returns:
|
||||
str: The IP address of the host on the LAN.
|
||||
"""
|
||||
# Create a dummy socket
|
||||
s = None
|
||||
try:
|
||||
# Connect to a dummy address (Here, Google's public DNS server)
|
||||
# The actual connection is not made, but this allows finding out the LAN IP
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(("8.8.8.8", 80))
|
||||
ip = s.getsockname()[0]
|
||||
finally:
|
||||
if s is not None:
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
Loading…
Reference in New Issue
Block a user