mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 20:30:25 +08:00
Improve distributed testing
This commit is contained in:
parent
3f4049c5f4
commit
59cf9e5d93
5
.github/workflows/test-build.yml
vendored
5
.github/workflows/test-build.yml
vendored
@ -28,4 +28,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .
|
pip install .
|
||||||
|
- name: Run distributed tests
|
||||||
|
run: |
|
||||||
|
pytest tests/distributed
|
||||||
@ -23,6 +23,7 @@ class ExecutingMessage(TypedDict):
|
|||||||
node: str | None
|
node: str | None
|
||||||
prompt_id: NotRequired[str]
|
prompt_id: NotRequired[str]
|
||||||
output: NotRequired[dict]
|
output: NotRequired[dict]
|
||||||
|
sid: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
class ProgressMessage(TypedDict):
|
class ProgressMessage(TypedDict):
|
||||||
@ -30,6 +31,7 @@ class ProgressMessage(TypedDict):
|
|||||||
max: float
|
max: float
|
||||||
prompt_id: Optional[str]
|
prompt_id: Optional[str]
|
||||||
node: Optional[str]
|
node: Optional[str]
|
||||||
|
sid: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
ExecutedMessage: TypeAlias = ExecutingMessage
|
ExecutedMessage: TypeAlias = ExecutingMessage
|
||||||
|
|||||||
@ -25,7 +25,9 @@ def _get_name(queue_name: str, user_id: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
|
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=False):
|
||||||
|
if receive_all_progress_notifications:
|
||||||
|
raise NotImplementedError("this release does not yet support sending all progress notifications from the hook over the network")
|
||||||
self._rpc = rpc
|
self._rpc = rpc
|
||||||
self._queue_name = queue_name
|
self._queue_name = queue_name
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
@ -33,6 +35,7 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
|||||||
self.client_id = None
|
self.client_id = None
|
||||||
self.node_id = None
|
self.node_id = None
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
|
self.last_prompt_id = None
|
||||||
if receive_all_progress_notifications:
|
if receive_all_progress_notifications:
|
||||||
hijack_progress(self)
|
hijack_progress(self)
|
||||||
|
|
||||||
|
|||||||
@ -92,7 +92,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
|||||||
# load from entrypoints
|
# load from entrypoints
|
||||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||||
# Load the module associated with the current entry point
|
# Load the module associated with the current entry point
|
||||||
module = entry_point.load()
|
try:
|
||||||
|
module = entry_point.load()
|
||||||
|
except ModuleNotFoundError as module_not_found_error:
|
||||||
|
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||||
|
continue
|
||||||
|
|
||||||
# Ensure that what we've loaded is indeed a module
|
# Ensure that what we've loaded is indeed a module
|
||||||
if isinstance(module, types.ModuleType):
|
if isinstance(module, types.ModuleType):
|
||||||
|
|||||||
@ -8,7 +8,8 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completes_prompt(comfy_background_server):
|
async def test_completes_prompt(comfy_background_server):
|
||||||
client = AsyncRemoteComfyClient()
|
client = AsyncRemoteComfyClient()
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
random_seed = random.randint(1,4294967295)
|
||||||
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||||
png_image_bytes = await client.queue_prompt(prompt)
|
png_image_bytes = await client.queue_prompt(prompt)
|
||||||
assert len(png_image_bytes) > 1000
|
assert len(png_image_bytes) > 1000
|
||||||
|
|
||||||
@ -16,7 +17,8 @@ async def test_completes_prompt(comfy_background_server):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_completes_prompt_with_ui(comfy_background_server):
|
async def test_completes_prompt_with_ui(comfy_background_server):
|
||||||
client = AsyncRemoteComfyClient()
|
client = AsyncRemoteComfyClient()
|
||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
random_seed = random.randint(1,4294967295)
|
||||||
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, seed=random_seed, refiner_steps=1)
|
||||||
result_dict = await client.queue_prompt_ui(prompt)
|
result_dict = await client.queue_prompt_ui(prompt)
|
||||||
# should contain one output
|
# should contain one output
|
||||||
assert len(result_dict) == 1
|
assert len(result_dict) == 1
|
||||||
|
|||||||
@ -1,11 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
import subprocess
|
||||||
|
import socket
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||||
from comfy.distributed.server_stub import ServerStub
|
from comfy.distributed.server_stub import ServerStub
|
||||||
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||||
@ -95,3 +104,69 @@ async def test_distributed_prompt_queues_same_process():
|
|||||||
assert frontend_pov_result.outputs is not None
|
assert frontend_pov_result.outputs is not None
|
||||||
assert len(frontend_pov_result.outputs) == 1
|
assert len(frontend_pov_result.outputs) == 1
|
||||||
assert frontend_pov_result.status is not None
|
assert frontend_pov_result.status is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_frontend_backend_workers():
|
||||||
|
processes_to_close: List[subprocess.Popen] = []
|
||||||
|
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||||
|
try:
|
||||||
|
params = rabbitmq.get_connection_params()
|
||||||
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
|
||||||
|
frontend_command = [
|
||||||
|
"comfyui",
|
||||||
|
"--listen=0.0.0.0",
|
||||||
|
"--cpu",
|
||||||
|
"--distributed-queue-frontend",
|
||||||
|
f"--distributed-queue-connection-uri={connection_uri}",
|
||||||
|
]
|
||||||
|
|
||||||
|
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||||
|
backend_command = [
|
||||||
|
"comfyui-worker",
|
||||||
|
f"--distributed-queue-connection-uri={connection_uri}",
|
||||||
|
]
|
||||||
|
|
||||||
|
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||||
|
server_address = f"http://{get_lan_ip()}:8188"
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < 60:
|
||||||
|
try:
|
||||||
|
response = requests.get(server_address)
|
||||||
|
if response.status_code == 200:
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logging.warning("", exc_info=exc)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
client = AsyncRemoteComfyClient(server_address=server_address)
|
||||||
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||||
|
png_image_bytes = await client.queue_prompt(prompt)
|
||||||
|
assert len(png_image_bytes) > 1000
|
||||||
|
finally:
|
||||||
|
for process in processes_to_close:
|
||||||
|
process.terminate()
|
||||||
|
|
||||||
|
|
||||||
|
def get_lan_ip():
|
||||||
|
"""
|
||||||
|
Finds the host's IP address on the LAN it's connected to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The IP address of the host on the LAN.
|
||||||
|
"""
|
||||||
|
# Create a dummy socket
|
||||||
|
s = None
|
||||||
|
try:
|
||||||
|
# Connect to a dummy address (Here, Google's public DNS server)
|
||||||
|
# The actual connection is not made, but this allows finding out the LAN IP
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
s.connect(("8.8.8.8", 80))
|
||||||
|
ip = s.getsockname()[0]
|
||||||
|
finally:
|
||||||
|
if s is not None:
|
||||||
|
s.close()
|
||||||
|
return ip
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user