From 66725034dec4884f89c0ce23e6d36a6a2e52bf38 Mon Sep 17 00:00:00 2001 From: enzymezoo-code Date: Fri, 1 Sep 2023 15:53:06 -0500 Subject: [PATCH] Add inference tests --- pytest.ini | 4 + tests/inference/conftest.py | 20 +++ tests/inference/graphs/default_graph.json | 144 +++++++++++++++ tests/inference/test_inference.py | 204 ++++++++++++++++++++++ 4 files changed, 372 insertions(+) create mode 100644 pytest.ini create mode 100644 tests/inference/conftest.py create mode 100644 tests/inference/graphs/default_graph.json create mode 100644 tests/inference/test_inference.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..3528fb8b5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') +addopts = -s \ No newline at end of file diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py new file mode 100644 index 000000000..babeb17ce --- /dev/null +++ b/tests/inference/conftest.py @@ -0,0 +1,20 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') + parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") + parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['output_dir'] = pytestconfig.getoption('output_dir') + args['listen'] = pytestconfig.getoption('listen') + args['port'] = pytestconfig.getoption('port') + + os.makedirs(args['output_dir'], exist_ok=True) + + return args \ No newline at end of file diff --git a/tests/inference/graphs/default_graph.json b/tests/inference/graphs/default_graph.json new file mode 100644 index 000000000..b478193ad --- /dev/null +++ b/tests/inference/graphs/default_graph.json @@ -0,0 +1,144 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 42, + "steps": 40, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 32, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "15", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "12": { + "inputs": { + "samples": [ + "14", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode" + }, + "13": { + "inputs": { + "filename_prefix": "tests/ref_ComfyUI", + "images": [ + "12", + 0 + ] + }, + "class_type": "SaveImage" + }, + "14": { + "inputs": { + "add_noise": "disable", + "noise_seed": 42, + "steps": 40, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 32, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "16", + 0 + ], + "positive": [ + "17", + 0 + ], + "negative": [ + "20", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "15": { + "inputs": { + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "ConditioningZeroOut" + }, + "16": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "17": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "20": { + "inputs": { + "text": "", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + } + } \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 000000000..e3116f1ea --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,204 @@ +from copy import deepcopy +from urllib import request +import numpy +import os +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Tuple, Union +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse + +from comfy.samplers import KSampler + +""" +These tests are used to generate images compare the output +of inference with the ground-truth images. +""" + +class ComfyGraph: + def __init__(self, + graph: dict, + sampler_nodes: list[str], + ): + self.graph = graph + self.sampler_nodes = sampler_nodes + + def set_prompt(self, prompt, negative_prompt=None): + # Sets the prompt for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + prompt_node = self.graph[node]['inputs']['positive'][0] + self.graph[prompt_node]['inputs']['text'] = prompt + if negative_prompt: + negative_prompt_node = self.graph[node]['inputs']['negative'][0] + self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt + + def set_sampler_name(self, sampler_name:str, ): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['sampler_name'] = sampler_name + + def set_scheduler(self, scheduler:str): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['scheduler'] = scheduler + + +class ComfyClient: + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + # From examples/websockets_api_example.py + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def get_images(self, graph, save=True): + prompt = graph + if not save: + # Replace save nodes with preview nodes + prompt_str = json.dumps(prompt) + prompt_str = prompt_str.replace('SaveImage', 'PreviewImage') + prompt = json.loads(prompt_str) + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +# +# Initialize graphs +# +default_graph_file = 'tests/inference/graphs/default_graph.json' +with open(default_graph_file, 'r') as file: + default_graph = json.loads(file.read()) +DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) + +# +# Loop through these variables +# +comfy_graph_list = [DEFAULT_COMFY_GRAPH] +prompt_list = [ + 'a textured painting of a wet frog', + 'a textured painting of a wet toad', +] +sampler_list = KSampler.SAMPLERS[0:2] +scheduler_list = [KSampler.SCHEDULERS[0]] + +@pytest.mark.inference +@pytest.mark.parametrize("sampler", sampler_list) +@pytest.mark.parametrize("scheduler", scheduler_list) +@pytest.mark.parametrize("prompt", prompt_list) +class TestInference: + # Initialize pipeline + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, autouse=True) + def _client_graph(self, request, args_pytest) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + ]) + + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + + # warm up pipeline + comfy_client.get_images(graph=comfy_graph.graph, save=False) + + yield comfy_client, comfy_graph + del comfy_client + del comfy_graph + p.kill() + torch.cuda.empty_cache() + + @fixture + def client(self, _client_graph): + client = _client_graph[0] + yield client + + # method-scoped fixture for graph to avoid mutating the graph + @fixture + def comfy_graph(self, _client_graph): + graph = deepcopy(_client_graph[1]) + yield graph + + def test_comfy( + self, + client, + comfy_graph, + sampler, + scheduler, + prompt, + ): + # Settings for comfy graph + comfy_graph.set_sampler_name(sampler) + comfy_graph.set_scheduler(scheduler) + comfy_graph.set_prompt(prompt) + + client.get_images(comfy_graph.graph) +