mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
Ci quality workflows compare (#1)
* Add image comparison tests * Comparison tests do not pass with empty metadata * Ensure tests are run in correct order * Save image files with test name * Update tests readme
This commit is contained in:
parent
26c8e8bcdb
commit
dc4bb9206a
@ -6,8 +6,24 @@ Additional requirements for running tests:
|
|||||||
```
|
```
|
||||||
pip install pytest
|
pip install pytest
|
||||||
pip install websocket-client==1.6.1
|
pip install websocket-client==1.6.1
|
||||||
|
opencv-python==4.6.0.66
|
||||||
|
scikit-image==0.21.0
|
||||||
```
|
```
|
||||||
Run tests:
|
Run inference tests:
|
||||||
```
|
```
|
||||||
python -m pytest
|
pytest tests/inference
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quality regression test
|
||||||
|
Compares images in 2 directories to ensure they are the same
|
||||||
|
|
||||||
|
1) Run an inference test to save a directory of "ground truth" images
|
||||||
|
```
|
||||||
|
pytest tests/inference --output_dir tests/inference/baseline
|
||||||
|
```
|
||||||
|
2) Make code edits
|
||||||
|
|
||||||
|
3) Run inference and quality comparison tests
|
||||||
|
```
|
||||||
|
pytest
|
||||||
```
|
```
|
||||||
41
tests/compare/conftest.py
Normal file
41
tests/compare/conftest.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Command line arguments for pytest
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
|
||||||
|
parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
|
||||||
|
parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
|
||||||
|
parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
|
||||||
|
|
||||||
|
# This initializes args at the beginning of the test session
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def args_pytest(pytestconfig):
|
||||||
|
args = {}
|
||||||
|
args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
|
||||||
|
args['test_dir'] = pytestconfig.getoption('test_dir')
|
||||||
|
args['metrics_file'] = pytestconfig.getoption('metrics_file')
|
||||||
|
args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
|
||||||
|
|
||||||
|
# Initialize metrics file
|
||||||
|
with open(args['metrics_file'], 'a') as f:
|
||||||
|
# if file is empty, write header
|
||||||
|
if os.stat(args['metrics_file']).st_size == 0:
|
||||||
|
f.write("| date | run | file | status | value | \n")
|
||||||
|
f.write("| --- | --- | --- | --- | --- | \n")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def gather_file_basenames(directory: str):
|
||||||
|
files = []
|
||||||
|
for file in os.listdir(directory):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
files.append(file)
|
||||||
|
return files
|
||||||
|
|
||||||
|
# Creates the list of baseline file names to use as a fixture
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "baseline_fname" in metafunc.fixturenames:
|
||||||
|
baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
|
||||||
|
metafunc.parametrize("baseline_fname", baseline_fnames)
|
||||||
195
tests/compare/test_quality.py
Normal file
195
tests/compare/test_quality.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import pytest
|
||||||
|
from pytest import fixture
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
||||||
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This test suite compares images in 2 directories by file name
|
||||||
|
The directories are specified by the command line arguments --baseline_dir and --test_dir
|
||||||
|
|
||||||
|
"""
|
||||||
|
# ssim: Structural Similarity Index
|
||||||
|
# Returns a tuple of (ssim, diff_image)
|
||||||
|
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||||
|
score, diff = ssim(img0, img1, channel_axis=-1, full=True)
|
||||||
|
# rescale the difference image to 0-255 range
|
||||||
|
diff = (diff * 255).astype("uint8")
|
||||||
|
return score, diff
|
||||||
|
|
||||||
|
# Metrics must return a tuple of (score, diff_image)
|
||||||
|
METRICS = {"ssim": ssim_score}
|
||||||
|
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompareImageMetrics:
|
||||||
|
@fixture(scope="class")
|
||||||
|
def test_file_names(self, args_pytest):
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
fnames = self.gather_file_basenames(test_dir)
|
||||||
|
yield fnames
|
||||||
|
del fnames
|
||||||
|
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def teardown(self, args_pytest):
|
||||||
|
yield
|
||||||
|
# Runs after all tests are complete
|
||||||
|
# Aggregate output files into a grid of images
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
metrics_file = args_pytest['metrics_file']
|
||||||
|
|
||||||
|
grid_dir = os.path.join(img_output_dir, "grid")
|
||||||
|
os.makedirs(grid_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for metric_dir in METRICS.keys():
|
||||||
|
metric_path = os.path.join(img_output_dir, metric_dir)
|
||||||
|
for file in os.listdir(metric_path):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
score = self.lookup_score_from_fname(file, metrics_file)
|
||||||
|
image_file_list = []
|
||||||
|
image_file_list.append([
|
||||||
|
os.path.join(baseline_dir, file),
|
||||||
|
os.path.join(test_dir, file),
|
||||||
|
os.path.join(metric_path, file)
|
||||||
|
])
|
||||||
|
# Create grid
|
||||||
|
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||||
|
grid = self.image_grid(image_list)
|
||||||
|
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
||||||
|
|
||||||
|
# Tests run for each baseline file name
|
||||||
|
@fixture()
|
||||||
|
def fname(self, baseline_fname):
|
||||||
|
yield baseline_fname
|
||||||
|
del baseline_fname
|
||||||
|
|
||||||
|
def test_directories_not_empty(self, args_pytest):
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
|
||||||
|
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
|
||||||
|
|
||||||
|
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
|
||||||
|
# Check that all files in baseline_dir have a file in test_dir with matching metadata
|
||||||
|
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
||||||
|
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
||||||
|
file_match = self.find_file_match(baseline_file_path, file_paths)
|
||||||
|
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
||||||
|
|
||||||
|
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||||
|
# compares the images using the metrics in METRICS
|
||||||
|
@pytest.mark.parametrize("metric", METRICS.keys())
|
||||||
|
def test_pipeline_compare(
|
||||||
|
self,
|
||||||
|
args_pytest,
|
||||||
|
fname,
|
||||||
|
test_file_names,
|
||||||
|
metric,
|
||||||
|
):
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
metrics_output_file = args_pytest['metrics_file']
|
||||||
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
|
||||||
|
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||||
|
|
||||||
|
# Find file match
|
||||||
|
file_paths = [os.path.join(test_dir, f) for f in test_file_names]
|
||||||
|
test_file = self.find_file_match(baseline_file_path, file_paths)
|
||||||
|
|
||||||
|
# Run metrics
|
||||||
|
sample_baseline = self.read_img(baseline_file_path)
|
||||||
|
sample_secondary = self.read_img(test_file)
|
||||||
|
|
||||||
|
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
||||||
|
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
||||||
|
|
||||||
|
# Save metric values
|
||||||
|
with open(metrics_output_file, 'a') as f:
|
||||||
|
run_info = os.path.splitext(fname)[0]
|
||||||
|
metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
|
||||||
|
date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
|
||||||
|
|
||||||
|
# Save metric image
|
||||||
|
metric_img_dir = os.path.join(img_output_dir, metric)
|
||||||
|
os.makedirs(metric_img_dir, exist_ok=True)
|
||||||
|
output_filename = f'{fname}'
|
||||||
|
Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
|
||||||
|
|
||||||
|
assert score > METRICS_PASS_THRESHOLD[metric]
|
||||||
|
|
||||||
|
def read_img(self, filename: str) -> np.ndarray:
|
||||||
|
cvImg = imread(filename)
|
||||||
|
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
|
||||||
|
return cvImg
|
||||||
|
|
||||||
|
def image_grid(self, img_list: list[list[Image.Image]]):
|
||||||
|
# imgs is a 2D list of images
|
||||||
|
# Assumes the input images are a rectangular grid of equal sized images
|
||||||
|
rows = len(img_list)
|
||||||
|
cols = len(img_list[0])
|
||||||
|
|
||||||
|
w, h = img_list[0][0].size
|
||||||
|
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||||
|
|
||||||
|
for i, row in enumerate(img_list):
|
||||||
|
for j, img in enumerate(row):
|
||||||
|
grid.paste(img, box=(j*w, i*h))
|
||||||
|
return grid
|
||||||
|
|
||||||
|
def lookup_score_from_fname(self,
|
||||||
|
fname: str,
|
||||||
|
metrics_output_file: str
|
||||||
|
) -> float:
|
||||||
|
fname_basestr = os.path.splitext(fname)[0]
|
||||||
|
with open(metrics_output_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
if fname_basestr in line:
|
||||||
|
score = float(line.split('|')[5])
|
||||||
|
return score
|
||||||
|
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
|
||||||
|
|
||||||
|
def gather_file_basenames(self, directory: str):
|
||||||
|
files = []
|
||||||
|
for file in os.listdir(directory):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
files.append(file)
|
||||||
|
return files
|
||||||
|
|
||||||
|
def read_file_prompt(self, fname:str) -> str:
|
||||||
|
# Read prompt from image file metadata
|
||||||
|
img = Image.open(fname)
|
||||||
|
img.load()
|
||||||
|
return img.info['prompt']
|
||||||
|
|
||||||
|
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
||||||
|
# Find a file in file_paths with matching metadata to baseline_file
|
||||||
|
baseline_prompt = self.read_file_prompt(baseline_file)
|
||||||
|
|
||||||
|
# Do not match empty prompts
|
||||||
|
if baseline_prompt is None or baseline_prompt == "":
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find file match
|
||||||
|
# Reorder test_file_names so that the file with matching name is first
|
||||||
|
# This is an optimization because matching file names are more likely
|
||||||
|
# to have matching metadata if they were generated with the same script
|
||||||
|
basename = os.path.basename(baseline_file)
|
||||||
|
file_path_basenames = [os.path.basename(f) for f in file_paths]
|
||||||
|
if basename in file_path_basenames:
|
||||||
|
match_index = file_path_basenames.index(basename)
|
||||||
|
file_paths.insert(0, file_paths.pop(match_index))
|
||||||
|
|
||||||
|
for f in file_paths:
|
||||||
|
test_file_prompt = self.read_file_prompt(f)
|
||||||
|
if baseline_prompt == test_file_prompt:
|
||||||
|
return f
|
||||||
@ -17,4 +17,20 @@ def args_pytest(pytestconfig):
|
|||||||
|
|
||||||
os.makedirs(args['output_dir'], exist_ok=True)
|
os.makedirs(args['output_dir'], exist_ok=True)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(items):
|
||||||
|
# Modifies items so tests run in the correct order
|
||||||
|
|
||||||
|
LAST_TESTS = ['test_quality']
|
||||||
|
|
||||||
|
# Move the last items to the end
|
||||||
|
last_items = []
|
||||||
|
for test_name in LAST_TESTS:
|
||||||
|
for item in items.copy():
|
||||||
|
print(item.module.__name__, item)
|
||||||
|
if item.module.__name__ == test_name:
|
||||||
|
last_items.append(item)
|
||||||
|
items.remove(item)
|
||||||
|
|
||||||
|
items.extend(last_items)
|
||||||
|
|||||||
@ -8,12 +8,11 @@ import pytest
|
|||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
from typing import Tuple, Union
|
from typing import Union
|
||||||
import json
|
import json
|
||||||
import subprocess
|
import subprocess
|
||||||
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
@ -52,6 +51,12 @@ class ComfyGraph:
|
|||||||
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||||
for node in self.sampler_nodes:
|
for node in self.sampler_nodes:
|
||||||
self.graph[node]['inputs']['scheduler'] = scheduler
|
self.graph[node]['inputs']['scheduler'] = scheduler
|
||||||
|
|
||||||
|
def set_filename_prefix(self, prefix:str):
|
||||||
|
# sets the filename prefix for the save nodes
|
||||||
|
for node in self.graph:
|
||||||
|
if self.graph[node]['class_type'] == 'SaveImage':
|
||||||
|
self.graph[node]['inputs']['filename_prefix'] = prefix
|
||||||
|
|
||||||
|
|
||||||
class ComfyClient:
|
class ComfyClient:
|
||||||
@ -125,11 +130,13 @@ default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
|
|||||||
with open(default_graph_file, 'r') as file:
|
with open(default_graph_file, 'r') as file:
|
||||||
default_graph = json.loads(file.read())
|
default_graph = json.loads(file.read())
|
||||||
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
|
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
|
||||||
|
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
|
||||||
|
|
||||||
#
|
#
|
||||||
# Loop through these variables
|
# Loop through these variables
|
||||||
#
|
#
|
||||||
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
|
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
|
||||||
|
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
|
||||||
prompt_list = [
|
prompt_list = [
|
||||||
'a painting of a cat',
|
'a painting of a cat',
|
||||||
]
|
]
|
||||||
@ -185,7 +192,7 @@ class TestInference:
|
|||||||
#
|
#
|
||||||
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
|
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
|
||||||
# The "graph" is the default graph
|
# The "graph" is the default graph
|
||||||
@fixture(scope="class", params=comfy_graph_list, autouse=True)
|
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
||||||
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
||||||
comfy_graph = request.param
|
comfy_graph = request.param
|
||||||
|
|
||||||
@ -218,7 +225,10 @@ class TestInference:
|
|||||||
sampler,
|
sampler,
|
||||||
scheduler,
|
scheduler,
|
||||||
prompt,
|
prompt,
|
||||||
|
request
|
||||||
):
|
):
|
||||||
|
test_info = request.node.name
|
||||||
|
comfy_graph.set_filename_prefix(test_info)
|
||||||
# Settings for comfy graph
|
# Settings for comfy graph
|
||||||
comfy_graph.set_sampler_name(sampler)
|
comfy_graph.set_sampler_name(sampler)
|
||||||
comfy_graph.set_scheduler(scheduler)
|
comfy_graph.set_scheduler(scheduler)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user