Fix tests

This commit is contained in:
doctorpangloss 2024-01-05 15:11:21 -08:00
parent 42232f4d20
commit 739195b127
12 changed files with 83 additions and 66 deletions

View File

@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install .

View File

@ -11,13 +11,12 @@ jobs:
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10'
with:
python-version: '3.11'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install .
- name: Run Tests
run: |
npm ci

View File

@ -179,7 +179,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
5. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux):
```shell
pip install -e .
pip install -e .[test]
```
6. To run the web server:
```shell
@ -189,7 +189,11 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
```shell
comfyui-openapi-gen
```
To run tests:
```shell
pytest tests/inference
(cd tests-ui && npm ci && npm test:generate && npm test)
```
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
### Authoring Custom Nodes

2
requirements-tests.txt Normal file
View File

@ -0,0 +1,2 @@
pytest
websocket-client==1.6.1

View File

@ -151,6 +151,7 @@ def dependencies() -> List[str]:
package_data = ['sd1_tokenizer/*', '**/*.json', '**/*.yaml']
if not is_editable:
package_data.append('comfy/web/**/*')
test_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-tests.txt")).readlines()
setup(
name=package_name,
description="",
@ -172,4 +173,8 @@ setup(
package_data={
'comfy': package_data
},
tests_require=test_dependencies,
extras_require={
'test': test_dependencies
},
)

View File

@ -432,7 +432,7 @@ describe("group node", () => {
nodes.save,
]);
const { api } = require("../../web/scripts/api");
const { api } = require("../../comfy/web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));

View File

@ -1,13 +1,13 @@
// @ts-check
/// <reference path="../../web/types/litegraph.d.ts" />
/// <reference path="../../comfy/web/types/litegraph.d.ts" />
/**
* @typedef { import("../../web/scripts/app")["app"] } app
* @typedef { import("../../web/types/litegraph") } LG
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { import("../../comfy/web/scripts/app")["app"] } app
* @typedef { import("../../comfy/web/types/litegraph") } LG
* @typedef { import("../../comfy/web/types/litegraph").IWidget } IWidget
* @typedef { import("../../comfy/web/types/litegraph").ContextMenuItem } ContextMenuItem
* @typedef { import("../../comfy/web/types/litegraph").INodeInputSlot } INodeInputSlot
* @typedef { import("../../comfy/web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
*/

View File

@ -15,7 +15,7 @@ export async function start(config = {}) {
}
mockApi(config);
const { app } = require("../../web/scripts/app");
const { app } = require("../../comfy/web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
@ -35,7 +35,7 @@ export async function checkBeforeAndAfterReload(graph, cb) {
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { (string | string[])[] | Record<string, string | string[]> } output
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
* @returns { Record<string, import("../../comfy/web/types/comfy").ComfyObjectInfo> }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
@ -106,7 +106,7 @@ export function createDefaultWorkflow(ez, graph) {
}
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
const { api } = require("../../comfy/web/scripts/api");
return api.getNodeDefs();
}

View File

@ -18,14 +18,14 @@ function forEachKey(cb) {
}
export function setup(ctx) {
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
const lg = fs.readFileSync(path.resolve("../comfy/web/lib/litegraph.core.js"), "utf-8");
const globalTemp = {};
(function (console) {
eval(lg);
}).call(globalTemp, nop);
forEachKey((k) => (ctx[k] = globalTemp[k]));
require(path.resolve("../web/lib/litegraph.extensions.js"));
require(path.resolve("../comfy/web/lib/litegraph.extensions.js"));
}
export function teardown(ctx) {

View File

@ -1,4 +1,4 @@
require("../../web/scripts/api");
require("../../comfy/web/scripts/api");
const fs = require("fs");
const path = require("path");
@ -14,7 +14,7 @@ function* walkSync(dir) {
}
/**
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
* @typedef { import("../../comfy/web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
*/
/**
@ -22,9 +22,9 @@ function* walkSync(dir) {
*/
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core")))
.filter((x) => x.endsWith(".js"))
.map((x) => path.relative(path.resolve("../web"), x));
.map((x) => path.relative(path.resolve("../comfy/web"), x));
}
if (!mockNodeDefs) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
@ -41,7 +41,7 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
jest.mock("../../web/scripts/api", () => ({
jest.mock("../../comfy/web/scripts/api", () => ({
get api() {
return mockApi;
},

View File

@ -4,10 +4,7 @@
Additional requirements for running tests:
```
pip install pytest
pip install websocket-client==1.6.1
opencv-python==4.6.0.66
scikit-image==0.21.0
pip install .[test]
```
Run inference tests:
```

View File

@ -11,20 +11,20 @@ import torch
from typing import Union
import json
import subprocess
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse
from comfy.samplers import KSampler
"""
These tests generate and save images through a range of parameters
"""
class ComfyGraph:
def __init__(self,
def __init__(self,
graph: dict,
sampler_nodes: list[str],
):
@ -40,17 +40,17 @@ class ComfyGraph:
negative_prompt_node = self.graph[node]['inputs']['negative'][0]
self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
def set_sampler_name(self, sampler_name:str, ):
def set_sampler_name(self, sampler_name: str, ):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['sampler_name'] = sampler_name
def set_scheduler(self, scheduler:str):
def set_scheduler(self, scheduler: str):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['scheduler'] = scheduler
def set_filename_prefix(self, prefix:str):
def set_filename_prefix(self, prefix: str):
# sets the filename prefix for the save nodes
for node in self.graph:
if self.graph[node]['class_type'] == 'SaveImage':
@ -60,11 +60,11 @@ class ComfyGraph:
class ComfyClient:
# From examples/websockets_api_example.py
def connect(self,
listen:str = '127.0.0.1',
port:Union[str,int] = 8188,
client_id: str = str(uuid.uuid4())
):
def connect(self,
listen: str = '127.0.0.1',
port: Union[str, int] = 8188,
client_id: str = str(uuid.uuid4())
):
self.client_id = client_id
self.server_address = f"{listen}:{port}"
ws = websocket.WebSocket()
@ -74,7 +74,7 @@ class ComfyClient:
def queue_prompt(self, prompt):
p = {"prompt": prompt, "client_id": self.client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(self, filename, subfolder, folder_type):
@ -104,9 +104,9 @@ class ComfyClient:
if message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done
break # Execution is done
else:
continue #previews are binary data
continue # previews are binary data
history = self.get_history(prompt_id)[prompt_id]
for o in history['outputs']:
@ -121,13 +121,14 @@ class ComfyClient:
return output_images
#
# Initialize graphs
#
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
with open(default_graph_file, 'r') as file:
default_graph = json.loads(file.read())
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10', '14'])
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
#
@ -142,6 +143,14 @@ prompt_list = [
sampler_list = KSampler.SAMPLERS
scheduler_list = KSampler.SCHEDULERS
def run_server(args_pytest):
from comfy.cmd.main import main
from comfy.cli_args import args
args.output_directory = args_pytest["output_dir"]
args.listen = args_pytest["listen"]
args.port = args_pytest["port"]
main()
@pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@ -152,18 +161,21 @@ class TestInference:
#
@fixture(scope="class", autouse=True)
def _server(self, args_pytest):
import multiprocessing
# Start server
p = subprocess.Popen([
'python','main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
])
pickled_args = {
"output_dir": args_pytest["output_dir"],
"listen": args_pytest["listen"],
"port": args_pytest["port"]
}
p = multiprocessing.Process(target=run_server, args=(pickled_args,))
p.start()
yield
p.kill()
torch.cuda.empty_cache()
def start_client(self, listen:str, port:int):
def start_client(self, listen: str, port: int):
# Start client
comfy_client = ComfyClient()
# Connect to server (with retries)
@ -174,7 +186,7 @@ class TestInference:
comfy_client.connect(listen=listen, port=port)
except ConnectionRefusedError as e:
print(e)
print(f"({i+1}/{n_tries}) Retrying...")
print(f"({i + 1}/{n_tries}) Retrying...")
else:
break
return comfy_client
@ -187,7 +199,7 @@ class TestInference:
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
# Start client
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
@ -203,7 +215,7 @@ class TestInference:
def client(self, _client_graph):
client = _client_graph[0]
yield client
@fixture
def comfy_graph(self, _client_graph):
# avoid mutating the graph
@ -211,13 +223,13 @@ class TestInference:
yield graph
def test_comfy(
self,
client,
comfy_graph,
sampler,
scheduler,
prompt,
request
self,
client,
comfy_graph,
sampler,
scheduler,
prompt,
request
):
test_info = request.node.name
comfy_graph.set_filename_prefix(test_info)
@ -235,5 +247,3 @@ class TestInference:
for image_data in images_output:
pil_image = Image.open(BytesIO(image_data))
assert numpy.array(pil_image).any() != 0, "Image is blank"