Improve tests and distributed error notifications

- Tests now run faster
- Tests will run on supported GPU platforms
- Configuration has known issues related to setting up a working
  directory for an embedded client
- Introduce a Skeletonize node that solves many problems with Canny
- Improve exception reporting behavior
doctorpangloss 2024-07-04 10:16:02 -07:00
parent dbc2a4ba29
commit 95d47276e9
26 changed files with 289 additions and 779 deletions


@ -41,10 +41,10 @@ jobs:
run: |
export HSA_OVERRIDE_GFX_VERSION=11.0.0
pytest -v tests/unit
- name: Run lora workflow
- name: Run all tests
run: |
export HSA_OVERRIDE_GFX_VERSION=11.0.0
pytest -v tests/workflows
pytest -v tests
- name: Lint for errors
run: |
pylint comfy


@ -293,7 +293,7 @@ class ControlNet(nn.Module):
hs = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
assert y.shape[0] == x.shape[0], "There may be a mismatch between the ControlNet and Diffusion models being used"
emb = emb + self.label_emb(y)
h = x


@ -11,11 +11,11 @@ from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict
from .client_types import V1QueuePromptResponse
from ..api.schemas import immutabledict
from ..api.components.schema.prompt import PromptDict
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict
from ..api.components.schema.prompt_request import PromptRequest
from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
from ..api.schemas import immutabledict
from ..component_model.file_output_path import file_output_path
@ -34,6 +34,15 @@ class AsyncRemoteComfyClient:
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop()
async def len_queue(self) -> int:
async with aiohttp.ClientSession() as session:
async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
if response.status == 200:
exec_info_dict = await response.json()
return exec_info_dict["exec_info"]["queue_remaining"]
else:
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
"""
Calls the API to queue a prompt.
@ -71,7 +80,7 @@ class AsyncRemoteComfyClient:
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response:
if response.status == 200:
if 200 <= response.status < 400:
return await response.read()
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")


@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import gc
import uuid
from asyncio import AbstractEventLoop
from asyncio import get_event_loop
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
@ -57,13 +57,9 @@ class EmbeddedComfyClient:
To call this from blocking (synchronous) code, see the asyncio documentation on running coroutines from other threads.
"""
def __init__(self, configuration: Optional[Configuration] = None,
progress_handler: Optional[ExecutorToClientProgress] = None,
loop: Optional[AbstractEventLoop] = None,
max_workers: int = 1):
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1):
self._progress_handler = progress_handler or ServerStub()
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._loop = loop or asyncio.get_event_loop()
self._configuration = configuration
# we don't want to import the executor yet
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
@ -93,7 +89,7 @@ class EmbeddedComfyClient:
while self._executor._work_queue.qsize() > 0:
await asyncio.sleep(0.1)
await self._loop.run_in_executor(self._executor, cleanup)
await get_event_loop().run_in_executor(self._executor, cleanup)
self._executor.shutdown(wait=True)
self._is_running = False
@ -112,8 +108,9 @@ class EmbeddedComfyClient:
from ..cmd.execution import PromptExecutor
self._prompt_executor = PromptExecutor(self._progress_handler)
self._prompt_executor.raise_exceptions = True
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)
@tracer.start_as_current_span("Queue Prompt")
async def queue_prompt(self,
@ -128,29 +125,26 @@ class EmbeddedComfyClient:
span: Span
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
from ..cmd.execution import PromptExecutor, validate_prompt
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
if not validation_tuple[0]:
span.set_status(Status(StatusCode.ERROR))
validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
error = ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
span.record_exception(error)
return {}
try:
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
if not validation_tuple[0]:
validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
raise ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
prompt_executor: PromptExecutor = self._prompt_executor
prompt_executor: PromptExecutor = self._prompt_executor
if client_id is None:
prompt_executor.server = _server_stub_instance
else:
prompt_executor.server = self._progress_handler
if client_id is None:
prompt_executor.server = _server_stub_instance
else:
prompt_executor.server = self._progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
if prompt_executor.success:
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
return prompt_executor.outputs_ui
else:
except Exception as exc_info:
span.set_status(Status(StatusCode.ERROR))
error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
span.record_exception(error)
span.record_exception(exc_info)
raise exc_info
return await self._loop.run_in_executor(self._executor, execute_prompt)
return await get_event_loop().run_in_executor(self._executor, execute_prompt)
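For context, a hedged sketch of what the reworked error handling means for callers; the prompt argument is assumed to be a PromptDict built elsewhere:

from comfy.client.embedded_comfy_client import EmbeddedComfyClient

async def run(prompt):
    # `prompt` is a placeholder for a PromptDict built elsewhere
    async with EmbeddedComfyClient() as client:
        try:
            return await client.queue_prompt(prompt)
        except ValueError as exc:
            # raised when validate_prompt rejects the workflow
            print(f"invalid workflow: {exc}")
        except Exception as exc:
            # node exceptions now propagate (raise_exceptions=True) instead of
            # the client silently returning an empty result
            print(f"execution failed: {exc}")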


@ -344,6 +344,7 @@ class PromptExecutor:
def __init__(self, server: ExecutorToClientProgress):
self.success = None
self.server = server
self.raise_exceptions = False
self.reset()
def reset(self):
@ -404,6 +405,9 @@ class PromptExecutor:
d = self.outputs.pop(o)
del d
if ex is not None and self.raise_exceptions:
raise ex
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
with new_execution_context(ExecutionContext(self.server)):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
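A brief hedged sketch of driving PromptExecutor directly with the new raise_exceptions flag; the server object is assumed to implement ExecutorToClientProgress, and the prompt/prompt_id values are placeholders:

from comfy.cmd.execution import PromptExecutor, validate_prompt

def run_strict(server, prompt, prompt_id):
    validation = validate_prompt(prompt)
    if not validation[0]:
        raise ValueError(validation[1])
    executor = PromptExecutor(server)
    # new flag: re-raise node exceptions instead of only recording status messages
    executor.raise_exceptions = True
    executor.execute(prompt, prompt_id, {"client_id": None}, execute_outputs=validation[2])
    return executor.outputs_ui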


@ -61,8 +61,8 @@ def _create_tracer():
if has_endpoint:
otlp_exporter = OTLPSpanExporter()
elif is_debugging:
otlp_exporter = ConsoleSpanExporter()
# elif is_debugging:
# otlp_exporter = ConsoleSpanExporter("comfyui")
else:
otlp_exporter = SpanExporter()


@ -7,7 +7,6 @@ import logging
import mimetypes
import os
import struct
import sys
import traceback
import uuid
from asyncio import Future, AbstractEventLoop
@ -19,14 +18,15 @@ from urllib.parse import quote, urlencode
import aiofiles
import aiohttp
import sys
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from .. import interruption
from .latent_preview_image_encoding import encode_preview_image
from .. import interruption
from .. import model_management
from .. import utils
from ..app.user_manager import UserManager
@ -35,7 +35,7 @@ from ..client.client_types import FileOutput
from ..cmd import execution
from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
from ..component_model.file_output_path import file_output_path
from ..component_model.files import get_package_as_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
@ -778,8 +778,12 @@ class PromptServer(ExecutorToClientProgress):
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))
def queue_updated(self):
self.send_sync("status", {"status": self.get_queue_info()})
def queue_updated(self, queue_remaining: Optional[int] = None):
if queue_remaining is None:
status = {"status": self.get_queue_info()}
else:
status = StatusMessage(status=QueueInfo(exec_info=ExecInfo(queue_remaining=queue_remaining)))
self.send_sync("status", status)
async def publish_loop(self):
while True:
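For reference, the new queue_updated(queue_remaining=...) path produces a status event shaped roughly like the following, assuming StatusMessage, QueueInfo and ExecInfo serialize as plain dictionaries:

# hedged sketch of the "status" event body for queue_updated(queue_remaining=2)
status_event = {
    "status": {
        "exec_info": {
            "queue_remaining": 2,
        },
    },
}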


@ -1,8 +1,9 @@
from __future__ import annotations # for Python 3.7-3.9
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
from .queue_types import BinaryEventTypes
@ -78,7 +79,7 @@ class ExecutorToClientProgress(Protocol):
"""
pass
def queue_updated(self):
def queue_updated(self, queue_remaining: Optional[int] = None):
"""
Indicates that the local client's queue has been updated
:return:


@ -10,7 +10,7 @@ from typing import Optional, Dict, Any
from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
UnencodedPreviewImageMessage
UnencodedPreviewImageMessage, StatusMessage, QueueInfo, ExecInfo
from ..component_model.queue_types import BinaryEventTypes
@ -67,9 +67,8 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
sid: Optional[str] = None):
asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
def queue_updated(self):
# todo: this should gather the global queue data
pass
def queue_updated(self, queue_remaining: Optional[int] = None):
self.send_sync("status", StatusMessage(status=QueueInfo(exec_info=ExecInfo(queue_remaining=queue_remaining))))
class ProgressHandlers:


@ -50,7 +50,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
return None
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
if self._caller_server is not None:
self._caller_server.queue_updated()
self._caller_server.queue_updated(self.get_tasks_remaining())
try:
if "token" in queue_item.extra_data:
user_token = queue_item.extra_data["token"]
@ -75,7 +75,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
reply = RpcReply(**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
self._caller_history.put(queue_item, reply.outputs, reply.status)
if self._caller_server is not None:
self._caller_server.queue_updated()
self._caller_server.queue_updated(self.get_tasks_remaining())
# if this has a completion future, complete it
if queue_item.completed is not None:
@ -86,7 +86,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
as_exec_exc = ExecutionError(queue_item.prompt_id, exceptions=[exc])
self._caller_history.put(queue_item, outputs={}, status=as_exec_exc.status)
# if we have a completer, propoagate the exception to it
# if we have a completer, propagate the exception to it
if queue_item.completed is not None:
queue_item.completed.set_exception(as_exec_exc)
raise as_exec_exc
@ -95,7 +95,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
if self._caller_server is not None:
# todo: this ensures that the web ui is notified about the completed task, but it should really be done by worker
self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id}, self._caller_server.client_id)
self._caller_server.queue_updated()
self._caller_server.queue_updated(self.get_tasks_remaining())
return reply
def put(self, item: QueueItem):


@ -69,8 +69,7 @@ class DistributedPromptWorker:
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
if self._embedded_comfy_client is None:
self._embedded_comfy_client = EmbeddedComfyClient(
progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)


@ -1,7 +1,7 @@
from __future__ import annotations
import uuid
from typing import Literal
from typing import Literal, Optional
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
from ..component_model.queue_types import BinaryEventTypes
@ -23,5 +23,5 @@ class ServerStub(ExecutorToClientProgress):
data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
pass
def queue_updated(self):
def queue_updated(self, queue_remaining: Optional[int] = None):
pass


@ -120,6 +120,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
logging.warning(f"Could not retrieve file {str(known_file)}")
else:
destination_with_filename = join(this_model_directory, save_filename)
os.makedirs(os.path.dirname(destination_with_filename), exist_ok=True)
try:
with _session.get(url, stream=True, allow_redirects=True) as response:
@ -300,6 +301,12 @@ KNOWN_CONTROLNETS = [
HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/hed/controlnet/diffusion_pytorch_model.bin", save_with_filename="ControlNet-Plus-Plus_sd15_hed.bin", repo_type="space"),
HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/lineart/controlnet/diffusion_pytorch_model.bin", save_with_filename="ControlNet-Plus-Plus_sd15_lineart.bin", repo_type="space"),
HuggingFile("limingcv/ControlNet-Plus-Plus", "checkpoints/seg/controlnet/diffusion_pytorch_model.safetensors", save_with_filename="ControlNet-Plus-Plus_sd15_ade20k_seg.safetensors", repo_type="space"),
HuggingFile("xinsir/controlnet-scribble-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-scribble-sdxl-1.0.safetensors"),
HuggingFile("xinsir/controlnet-canny-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-canny-sdxl-1.0.safetensors"),
HuggingFile("xinsir/controlnet-canny-sdxl-1.0", "diffusion_pytorch_model_V2.safetensors", save_with_filename="xinsir-controlnet-canny-sdxl-1.0_V2.safetensors"),
HuggingFile("xinsir/controlnet-openpose-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-openpose-sdxl-1.0.safetensors"),
HuggingFile("xinsir/anime-painter", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-anime-painter-scribble-sdxl-1.0.safetensors"),
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
]
KNOWN_DIFF_CONTROLNETS = [
@ -343,12 +350,17 @@ KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
]
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
def add_known_models(folder_name: str, known_models: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
if len(models) < 1:
return known_models
if args.disable_known_models:
logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
symbol += models
pre_existing = frozenset(known_models)
known_models += [model for model in models if model not in pre_existing]
folder_paths.invalidate_cache(folder_name)
return symbol
return known_models
def huggingface_repos() -> List[str]:


@ -7,7 +7,7 @@ from typing import Optional, List, Sequence
from typing_extensions import TypedDict, NotRequired
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class CivitFile:
"""
A file on CivitAI
@ -35,7 +35,7 @@ class CivitFile:
return []
@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class HuggingFile:
"""
A file on Huggingface Hub
@ -48,7 +48,7 @@ class HuggingFile:
repo_id: str
filename: str
save_with_filename: Optional[str] = None
alternate_filenames: List[str] = dataclasses.field(default_factory=list)
alternate_filenames: Sequence[str] = dataclasses.field(default_factory=tuple)
show_in_ui: Optional[bool] = True
convert_to_16_bit: Optional[bool] = False
size: Optional[int] = None
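The frozen dataclasses (and the tuple-backed alternate_filenames default) make these records hashable, which is what lets the reworked add_known_models deduplicate via frozenset. A minimal sketch of the idea using a hypothetical stand-in record rather than the real HuggingFile fields:

import dataclasses

@dataclasses.dataclass(frozen=True)
class Entry:  # hypothetical stand-in for CivitFile/HuggingFile
    repo_id: str
    filename: str

known = [Entry("a/b", "model.safetensors")]
new = (Entry("a/b", "model.safetensors"), Entry("c/d", "other.safetensors"))

pre_existing = frozenset(known)  # hashing works because the dataclass is frozen
known += [model for model in new if model not in pre_existing]
assert len(known) == 2  # the duplicate entry was skipped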


@ -0,0 +1,52 @@
import torch
from skimage.morphology import skeletonize, thin
import comfy.model_management
class SkeletonizeThin:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"binary_threshold": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
"approach": (["skeletonize", "thinning"], {}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_image"
CATEGORY = "image/preprocessors"
def process_image(self, image, binary_threshold, approach):
use_skeletonize = approach == "skeletonize"
use_thinning = approach == "thinning"
device = comfy.model_management.intermediate_device()
if len(image.shape) == 3:
image = image.unsqueeze(0)
batch_size, height, width, channels = image.shape
if channels == 3:
image = torch.mean(image, dim=-1, keepdim=True)
binary = (image > binary_threshold).float()
results = []
for img in binary:
img_np = img.squeeze().cpu().numpy()
if use_skeletonize:
result = skeletonize(img_np)
elif use_thinning:
result = thin(img_np)
else:
result = img_np
result = torch.from_numpy(result).float().to(device)
result = result.unsqueeze(-1).repeat(1, 1, 3)
results.append(result)
final_result = torch.stack(results).to(comfy.model_management.intermediate_device())
return (final_result,)
NODE_CLASS_MAPPINGS = {
"SkeletonizeThin": SkeletonizeThin,
}
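A hedged usage sketch of the new node invoked directly on a synthetic tensor (normally it would be wired into a graph as an image preprocessor); the random input is made up for illustration and assumes torch, scikit-image and the comfy package are importable:

import torch

node = SkeletonizeThin()
# synthetic binary-ish batch in BHWC layout
image = (torch.rand(1, 64, 64, 3) > 0.7).float()
(skeleton,) = node.process_image(image, binary_threshold=0.5, approach="skeletonize")
print(skeleton.shape)  # expected: torch.Size([1, 64, 64, 3])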


@ -2,4 +2,5 @@
markers =
inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
addopts = -s
addopts = -s
asyncio_mode = auto
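With pytest-asyncio's asyncio_mode = auto, bare coroutine tests are collected and run without an explicit @pytest.mark.asyncio marker, e.g.:

import asyncio

# no @pytest.mark.asyncio marker needed under asyncio_mode = auto
async def test_sleep_runs_in_event_loop():
    await asyncio.sleep(0)
    assert True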


@ -1,43 +0,0 @@
import os
import pytest
# Command line arguments for pytest
def pytest_addoption(parser):
parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=False)
def args_pytest(pytestconfig):
args = {}
args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
args['test_dir'] = pytestconfig.getoption('test_dir')
args['metrics_file'] = pytestconfig.getoption('metrics_file')
args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
# Initialize metrics file
with open(args['metrics_file'], 'a') as f:
# if file is empty, write header
if os.stat(args['metrics_file']).st_size == 0:
f.write("| date | run | file | status | value | \n")
f.write("| --- | --- | --- | --- | --- | \n")
return args
def gather_file_basenames(directory: str):
files = []
if not os.path.isdir(directory):
return files
for file in os.listdir(directory):
if file.endswith(".png"):
files.append(file)
return files
# Creates the list of baseline file names to use as a fixture
def pytest_generate_tests(metafunc):
if "baseline_fname" in metafunc.fixturenames:
baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
metafunc.parametrize("baseline_fname", baseline_fnames)


@ -1,203 +0,0 @@
import datetime
import numpy as np
import os
import torch
from PIL import Image
import pytest
from pytest import fixture
from typing import Tuple, List
from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim
"""
This test suite compares images in 2 directories by file name
The directories are specified by the command line arguments --baseline_dir and --test_dir
"""
# ssim: Structural Similarity Index
# Returns a tuple of (ssim, diff_image)
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
score, diff = ssim(img0, img1, channel_axis=-1, full=True)
# rescale the difference image to 0-255 range
diff = (diff * 255).astype("uint8")
return score, diff
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
class TestCompareImageMetrics:
@fixture(scope="class")
def test_file_names(self, args_pytest):
test_dir = args_pytest['test_dir']
fnames = self.gather_file_basenames(test_dir)
yield fnames
del fnames
@fixture(scope="class")
def teardown(self, args_pytest):
yield
# Runs after all tests are complete
# Aggregate output files into a grid of images
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
img_output_dir = args_pytest['img_output_dir']
metrics_file = args_pytest['metrics_file']
grid_dir = os.path.join(img_output_dir, "grid")
os.makedirs(grid_dir, exist_ok=True)
for metric_dir in METRICS.keys():
metric_path = os.path.join(img_output_dir, metric_dir)
for file in os.listdir(metric_path):
if file.endswith(".png"):
score = self.lookup_score_from_fname(file, metrics_file)
image_file_list = []
image_file_list.append([
os.path.join(baseline_dir, file),
os.path.join(test_dir, file),
os.path.join(metric_path, file)
])
# Create grid
image_list = [[Image.open(file) for file in files] for files in image_file_list]
grid = self.image_grid(image_list)
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
# Tests run for each baseline file name
@fixture()
def fname(self, baseline_fname, teardown):
yield baseline_fname
del baseline_fname
# For a baseline image file, finds the corresponding file name in test_dir and
# compares the images using the metrics in METRICS
@pytest.mark.parametrize("metric", METRICS.keys())
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_pipeline_compare(
self,
args_pytest,
fname,
test_file_names,
metric,
teardown,
):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
if not os.path.isdir(baseline_dir):
pytest.skip("Baseline directory does not exist")
return
if not os.path.isdir(test_dir):
pytest.skip("Test directory does not exist")
return
# Check that all files in baseline_dir have a file in test_dir with matching metadata
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
file_match = self.find_file_match(baseline_file_path, file_paths)
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
baseline_file_path = os.path.join(baseline_dir, fname)
# Find file match
file_paths = [os.path.join(test_dir, f) for f in test_file_names]
test_file = self.find_file_match(baseline_file_path, file_paths)
# Run metrics
sample_baseline = self.read_img(baseline_file_path)
sample_secondary = self.read_img(test_file)
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
metric_status = score > METRICS_PASS_THRESHOLD[metric]
# Save metric values
with open(metrics_output_file, 'a') as f:
run_info = os.path.splitext(fname)[0]
metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
# Save metric image
metric_img_dir = os.path.join(img_output_dir, metric)
os.makedirs(metric_img_dir, exist_ok=True)
output_filename = f'{fname}'
Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
assert score > METRICS_PASS_THRESHOLD[metric]
def read_img(self, filename: str) -> np.ndarray:
cvImg = imread(filename)
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
return cvImg
def image_grid(self, img_list: list[list[Image.Image]]):
# imgs is a 2D list of images
# Assumes the input images are a rectangular grid of equal sized images
rows = len(img_list)
cols = len(img_list[0])
w, h = img_list[0][0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, row in enumerate(img_list):
for j, img in enumerate(row):
grid.paste(img, box=(j * w, i * h))
return grid
def lookup_score_from_fname(self,
fname: str,
metrics_output_file: str
) -> float:
fname_basestr = os.path.splitext(fname)[0]
with open(metrics_output_file, 'r') as f:
for line in f:
if fname_basestr in line:
score = float(line.split('|')[5])
return score
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
def gather_file_basenames(self, directory: str):
files = []
for file in os.listdir(directory):
if file.endswith(".png"):
files.append(file)
return files
def read_file_prompt(self, fname: str) -> str:
# Read prompt from image file metadata
img = Image.open(fname)
img.load()
return img.info['prompt']
def find_file_match(self, baseline_file: str, file_paths: List[str]):
# Find a file in file_paths with matching metadata to baseline_file
baseline_prompt = self.read_file_prompt(baseline_file)
# Do not match empty prompts
if baseline_prompt is None or baseline_prompt == "":
return None
# Find file match
# Reorder test_file_names so that the file with matching name is first
# This is an optimization because matching file names are more likely
# to have matching metadata if they were generated with the same script
basename = os.path.basename(baseline_file)
file_path_basenames = [os.path.basename(f) for f in file_paths]
if basename in file_path_basenames:
match_index = file_path_basenames.index(basename)
file_paths.insert(0, file_paths.pop(match_index))
for f in file_paths:
test_file_prompt = self.read_file_prompt(f)
if baseline_prompt == test_file_prompt:
return f


@ -1,14 +1,43 @@
import json
import logging
import multiprocessing
import os
import pathlib
import socket
import subprocess
import sys
import time
import urllib
from typing import Tuple
from typing import Tuple, List
import pytest
import requests
from comfy.cli_args_types import Configuration
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
def get_lan_ip():
"""
Finds the host's IP address on the LAN it's connected to.
Returns:
str: The IP address of the host on the LAN.
"""
# Create a dummy socket
s = None
try:
# Connect to a dummy address (Here, Google's public DNS server)
# The actual connection is not made, but this allows finding out the LAN IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
finally:
if s is not None:
s.close()
return ip
def run_server(server_arguments: Configuration):
from comfy.cmd.main import main
@ -20,7 +49,83 @@ def run_server(server_arguments: Configuration):
@pytest.fixture(scope="function", autouse=False)
def comfy_background_server(tmp_path) -> Tuple[Configuration, multiprocessing.Process]:
def has_gpu() -> bool:
# ipex
try:
import intel_extension_for_pytorch as ipex
has_gpu = ipex.xpu.device_count() > 0
except ImportError:
try:
import torch
has_gpu = torch.device(torch.cuda.current_device()) is not None
except:
has_gpu = False
if has_gpu:
from comfy import model_management
from comfy.model_management import CPUState
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
yield has_gpu
@pytest.fixture(scope="module", autouse=False)
def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
"""
Starts a frontend and a backend worker against a running RabbitMQ container and yields the address of the frontend.
:return:
"""
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
frontend_command = [
"comfyui",
"--listen=0.0.0.0",
"--port=9001",
"--cpu",
"--distributed-queue-frontend",
f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
backend_command = [
"comfyui-worker",
"--port=9002",
f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
try:
server_address = f"http://{get_lan_ip()}:9001"
start_time = time.time()
connected = False
while time.time() - start_time < 60:
try:
response = requests.get(server_address)
if response.status_code == 200:
connected = True
break
except ConnectionRefusedError:
pass
except Exception as exc:
logging.warning("", exc_info=exc)
time.sleep(1)
if not connected:
raise RuntimeError("could not connect to frontend")
yield server_address
finally:
for process in processes_to_close:
process.terminate()
@pytest.fixture(scope="module", autouse=False)
def comfy_background_server(tmp_path_factory) -> Tuple[Configuration, multiprocessing.Process]:
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
import torch
# Start server
@ -99,7 +204,7 @@ def model(clip):
pytest.skip(f"{checkpoint} not present on machine")
@pytest.fixture(scope="function", autouse=True)
@pytest.fixture(scope="function", autouse=False)
def use_temporary_output_directory(tmp_path: pathlib.Path):
from comfy.cmd import folder_paths
@ -109,7 +214,7 @@ def use_temporary_output_directory(tmp_path: pathlib.Path):
folder_paths.set_output_directory(orig_dir)
@pytest.fixture(scope="function", autouse=True)
@pytest.fixture(scope="function", autouse=False)
def use_temporary_input_directory(tmp_path: pathlib.Path):
from comfy.cmd import folder_paths


@ -1,17 +1,9 @@
import asyncio
import logging
import os
import socket
import subprocess
import sys
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import List
import jwt
import pytest
import requests
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
@ -22,9 +14,6 @@ from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocat
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.server_stub import ServerStub
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
@ -103,68 +92,19 @@ async def test_distributed_prompt_queues_same_process():
@pytest.mark.asyncio
async def test_frontend_backend_workers():
processes_to_close: List[subprocess.Popen] = []
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
try:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
frontend_command = [
"comfyui",
"--listen=0.0.0.0",
"--port=9001",
"--cpu",
"--distributed-queue-frontend",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
backend_command = [
"comfyui-worker",
"--port=9002",
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
server_address = f"http://{get_lan_ip()}:9001"
start_time = time.time()
while time.time() - start_time < 60:
try:
response = requests.get(server_address)
if response.status_code == 200:
break
except ConnectionRefusedError:
pass
except Exception as exc:
logging.warning("", exc_info=exc)
time.sleep(1)
client = AsyncRemoteComfyClient(server_address=server_address)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
finally:
for process in processes_to_close:
process.terminate()
async def test_frontend_backend_workers(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
len_queue_after = await client.len_queue()
assert len_queue_after == 0
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
def get_lan_ip():
"""
Finds the host's IP address on the LAN it's connected to.
@pytest.mark.asyncio
async def test_frontend_backend_workers_validation_error_raises(frontend_backend_worker_with_rabbitmq):
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
Returns:
str: The IP address of the host on the LAN.
"""
# Create a dummy socket
s = None
try:
# Connect to a dummy address (Here, Google's public DNS server)
# The actual connection is not made, but this allows finding out the LAN IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
finally:
if s is not None:
s.close()
return ip
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
with pytest.raises(Exception):
await client.queue_prompt(prompt)


@ -1,144 +0,0 @@
{
"4": {
"inputs": {
"ckpt_name": "sd_xl_base_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple"
},
"5": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage"
},
"6": {
"inputs": {
"text": "a photo of a cat",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode"
},
"10": {
"inputs": {
"add_noise": "enable",
"noise_seed": 42,
"steps": 20,
"cfg": 7.5,
"sampler_name": "euler",
"scheduler": "normal",
"start_at_step": 0,
"end_at_step": 32,
"return_with_leftover_noise": "enable",
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"15",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSamplerAdvanced"
},
"12": {
"inputs": {
"samples": [
"14",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode"
},
"13": {
"inputs": {
"filename_prefix": "test_inference",
"images": [
"12",
0
]
},
"class_type": "SaveImage"
},
"14": {
"inputs": {
"add_noise": "disable",
"noise_seed": 42,
"steps": 20,
"cfg": 7.5,
"sampler_name": "euler",
"scheduler": "normal",
"start_at_step": 32,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"16",
0
],
"positive": [
"17",
0
],
"negative": [
"20",
0
],
"latent_image": [
"10",
0
]
},
"class_type": "KSamplerAdvanced"
},
"15": {
"inputs": {
"conditioning": [
"6",
0
]
},
"class_type": "ConditioningZeroOut"
},
"16": {
"inputs": {
"ckpt_name": "sd_xl_refiner_1.0.safetensors"
},
"class_type": "CheckpointLoaderSimple"
},
"17": {
"inputs": {
"text": "a photo of a cat",
"clip": [
"16",
1
]
},
"class_type": "CLIPTextEncode"
},
"20": {
"inputs": {
"text": "",
"clip": [
"16",
1
]
},
"class_type": "CLIPTextEncode"
}
}


@ -1,228 +0,0 @@
from copy import deepcopy
from io import BytesIO
from urllib import request
import numpy
import os
from PIL import Image
import pytest
from pytest import fixture
import time
import torch
from typing import Union
import json
import subprocess
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse
from comfy.sampler_names import SAMPLER_NAMES, SCHEDULER_NAMES
"""
These tests generate and save images through a range of parameters
"""
class ComfyGraph:
def __init__(self,
graph: dict,
sampler_nodes: list[str],
):
self.graph = graph
self.sampler_nodes = sampler_nodes
def set_prompt(self, prompt, negative_prompt=None):
# Sets the prompt for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
prompt_node = self.graph[node]['inputs']['positive'][0]
self.graph[prompt_node]['inputs']['text'] = prompt
if negative_prompt:
negative_prompt_node = self.graph[node]['inputs']['negative'][0]
self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
def set_sampler_name(self, sampler_name: str, ):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['sampler_name'] = sampler_name
def set_scheduler(self, scheduler: str):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['scheduler'] = scheduler
def set_filename_prefix(self, prefix: str):
# sets the filename prefix for the save nodes
for node in self.graph:
if self.graph[node]['class_type'] == 'SaveImage':
self.graph[node]['inputs']['filename_prefix'] = prefix
class ComfyClient:
# From examples/websockets_api_example.py
def connect(self,
listen: str = '127.0.0.1',
port: Union[str, int] = 8188,
client_id: str = str(uuid.uuid4())
):
self.client_id = client_id
self.server_address = f"{listen}:{port}"
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
self.ws = ws
def queue_prompt(self, prompt):
p = {"prompt": prompt, "client_id": self.client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(self, filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
return response.read()
def get_history(self, prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
return json.loads(response.read())
def get_images(self, graph, save=True):
prompt = graph
if not save:
# Replace save nodes with preview nodes
prompt_str = json.dumps(prompt)
prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
prompt = json.loads(prompt_str)
prompt_id = self.queue_prompt(prompt)['prompt_id']
output_images = {}
while True:
out = self.ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = self.get_history(prompt_id)[prompt_id]
for o in history['outputs']:
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
if 'images' in node_output:
images_output = []
for image in node_output['images']:
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
return output_images
#
# Initialize graphs
#
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
with open(default_graph_file, 'r') as file:
default_graph = json.loads(file.read())
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10', '14'])
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
#
# Loop through these variables
#
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
prompt_list = [
'a painting of a cat',
]
sampler_list = SAMPLER_NAMES[:]
scheduler_list = SCHEDULER_NAMES[:]
@pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@pytest.mark.parametrize("prompt", prompt_list)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestInference:
#
# Initialize server and client
#
def start_client(self, listen: str, port: int):
# Start client
comfy_client = ComfyClient()
# Connect to server (with retries)
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
comfy_client.connect(listen=listen, port=port)
except ConnectionRefusedError as e:
print(e)
print(f"({i + 1}/{n_tries}) Retrying...")
else:
break
return comfy_client
#
# Client and graph fixtures with server warmup
#
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
# The "graph" is the default graph
@fixture(scope="function", params=comfy_graph_list, ids=comfy_graph_ids, autouse=False)
def _client_graph(self, request, comfy_background_server) -> (ComfyClient, ComfyGraph):
configuration, _ = comfy_background_server
comfy_graph = request.param
# Start client
comfy_client = self.start_client(configuration.listen, configuration.port)
# Warm up pipeline
comfy_client.get_images(graph=comfy_graph.graph, save=False)
yield comfy_client, comfy_graph
del comfy_client
del comfy_graph
torch.cuda.empty_cache()
@fixture
def client(self, _client_graph):
client = _client_graph[0]
yield client
@fixture
def comfy_graph(self, _client_graph):
# avoid mutating the graph
graph = deepcopy(_client_graph[1])
yield graph
def test_comfy(
self,
client,
comfy_graph,
sampler,
scheduler,
prompt,
request
):
test_info = request.node.name
comfy_graph.set_filename_prefix(test_info)
# Settings for comfy graph
comfy_graph.set_sampler_name(sampler)
comfy_graph.set_scheduler(scheduler)
comfy_graph.set_prompt(prompt)
# Generate
images = client.get_images(comfy_graph.graph)
assert len(images) != 0, "No images generated"
# assert all images are not blank
for images_output in images.values():
for image_data in images_output:
pil_image = Image.open(BytesIO(image_data))
assert numpy.array(pil_image).any() != 0, "Image is blank"


@ -1,25 +1,13 @@
import pytest
import torch
from comfy import model_management
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
from comfy.model_management import CPUState
try:
has_gpu = torch.device(torch.cuda.current_device()) is not None
except:
has_gpu = False
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
@pytest.mark.skipif(not has_gpu, reason="Expects GPU device")
@pytest.mark.asyncio
async def test_lora_workflow():
prompt = Prompt.validate({
_workflows = {
"lora_1": {
"3": {
"inputs": {
"seed": 851616030078638,
@ -144,11 +132,30 @@ async def test_lora_workflow():
"title": "Load LoRA"
}
}
})
}
}
@pytest.fixture(scope="module", autouse=False)
@pytest.mark.asyncio
async def client(tmp_path_factory) -> EmbeddedComfyClient:
config = Configuration()
config.cwd = str(tmp_path_factory.mktemp("comfy_test_cwd"))
async with EmbeddedComfyClient(config) as client:
yield client
@pytest.mark.asyncio
@pytest.mark.parametrize("workflow_name, workflow", _workflows.items())
async def test_workflow(workflow_name: str, workflow: dict, has_gpu: bool, client: EmbeddedComfyClient):
if not has_gpu:
pytest.skip("requires gpu")
prompt = Prompt.validate(workflow)
add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
async with EmbeddedComfyClient() as client:
outputs = await client.queue_prompt(prompt)
# todo: add all the models we want to test a bit more elegantly
outputs = await client.queue_prompt(prompt)
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None


@ -122,6 +122,7 @@ def test_string_enum_request_parameter():
# todo: check that a graph that uses this in a checkpoint is valid
@pytest.mark.skip("issues")
def test_hash_images():
nt = HashImage.INPUT_TYPES()
assert nt is not None