From 739195b1270c3b6a79d7e7ae1f590603505f8c66 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Fri, 5 Jan 2024 15:11:21 -0800
Subject: [PATCH] Fix tests
---
.github/workflows/test-build.yml | 4 +-
.github/workflows/test-ui.yaml | 7 ++-
README.md | 8 ++-
requirements-tests.txt | 2 +
setup.py | 5 ++
tests-ui/tests/groupNode.test.js | 2 +-
tests-ui/utils/ezgraph.js | 14 +++---
tests-ui/utils/index.js | 6 +--
tests-ui/utils/litegraph.js | 4 +-
tests-ui/utils/setup.js | 10 ++--
tests/README.md | 5 +-
tests/inference/test_inference.py | 82 +++++++++++++++++--------------
12 files changed, 83 insertions(+), 66 deletions(-)
create mode 100644 requirements-tests.txt
diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml
index 444d6b254..77e15f3a1 100644
--- a/.github/workflows/test-build.yml
+++ b/.github/workflows/test-build.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11"]
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -28,4 +28,4 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
\ No newline at end of file
+ pip install .
\ No newline at end of file
diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml
index 4b8b97934..0b25336ad 100644
--- a/.github/workflows/test-ui.yaml
+++ b/.github/workflows/test-ui.yaml
@@ -11,13 +11,12 @@ jobs:
with:
node-version: 18
- uses: actions/setup-python@v4
- with:
- python-version: '3.10'
+ with:
+ python-version: '3.11'
- name: Install requirements
run: |
python -m pip install --upgrade pip
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- pip install -r requirements.txt
+ pip install .
- name: Run Tests
run: |
npm ci
diff --git a/README.md b/README.md
index cf01a0d24..78b501711 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
5. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux):
```shell
- pip install -e .
+ pip install -e .[test]
```
6. To run the web server:
```shell
@@ -189,7 +189,11 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
```shell
comfyui-openapi-gen
```
-
+ To run tests:
+ ```shell
+ pytest tests/inference
+ (cd tests-ui && npm ci && npm test:generate && npm test)
+ ```
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
### Authoring Custom Nodes
diff --git a/requirements-tests.txt b/requirements-tests.txt
new file mode 100644
index 000000000..dce71e32c
--- /dev/null
+++ b/requirements-tests.txt
@@ -0,0 +1,2 @@
+pytest
+websocket-client==1.6.1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 9894d35b1..95843ba21 100644
--- a/setup.py
+++ b/setup.py
@@ -151,6 +151,7 @@ def dependencies() -> List[str]:
package_data = ['sd1_tokenizer/*', '**/*.json', '**/*.yaml']
if not is_editable:
package_data.append('comfy/web/**/*')
+test_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-tests.txt")).readlines()
setup(
name=package_name,
description="",
@@ -172,4 +173,8 @@ setup(
package_data={
'comfy': package_data
},
+ tests_require=test_dependencies,
+ extras_require={
+ 'test': test_dependencies
+ },
)
diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js
index e6ebedd91..1afde42c6 100644
--- a/tests-ui/tests/groupNode.test.js
+++ b/tests-ui/tests/groupNode.test.js
@@ -432,7 +432,7 @@ describe("group node", () => {
nodes.save,
]);
- const { api } = require("../../web/scripts/api");
+ const { api } = require("../../comfy/web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js
index 8a55246ee..78c2cad95 100644
--- a/tests-ui/utils/ezgraph.js
+++ b/tests-ui/utils/ezgraph.js
@@ -1,13 +1,13 @@
// @ts-check
-///
+///
/**
- * @typedef { import("../../web/scripts/app")["app"] } app
- * @typedef { import("../../web/types/litegraph") } LG
- * @typedef { import("../../web/types/litegraph").IWidget } IWidget
- * @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
- * @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
- * @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
+ * @typedef { import("../../comfy/web/scripts/app")["app"] } app
+ * @typedef { import("../../comfy/web/types/litegraph") } LG
+ * @typedef { import("../../comfy/web/types/litegraph").IWidget } IWidget
+ * @typedef { import("../../comfy/web/types/litegraph").ContextMenuItem } ContextMenuItem
+ * @typedef { import("../../comfy/web/types/litegraph").INodeInputSlot } INodeInputSlot
+ * @typedef { import("../../comfy/web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { InstanceType & { widgets?: Array } } LGNode
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record]) => EzNode } EzNodeFactory
*/
diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js
index 6a08e8594..19f530f91 100644
--- a/tests-ui/utils/index.js
+++ b/tests-ui/utils/index.js
@@ -15,7 +15,7 @@ export async function start(config = {}) {
}
mockApi(config);
- const { app } = require("../../web/scripts/app");
+ const { app } = require("../../comfy/web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
@@ -35,7 +35,7 @@ export async function checkBeforeAndAfterReload(graph, cb) {
* @param { string } name
* @param { Record } input
* @param { (string | string[])[] | Record } output
- * @returns { Record }
+ * @returns { Record }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
@@ -106,7 +106,7 @@ export function createDefaultWorkflow(ez, graph) {
}
export async function getNodeDefs() {
- const { api } = require("../../web/scripts/api");
+ const { api } = require("../../comfy/web/scripts/api");
return api.getNodeDefs();
}
diff --git a/tests-ui/utils/litegraph.js b/tests-ui/utils/litegraph.js
index 777f8c3ba..82e83e522 100644
--- a/tests-ui/utils/litegraph.js
+++ b/tests-ui/utils/litegraph.js
@@ -18,14 +18,14 @@ function forEachKey(cb) {
}
export function setup(ctx) {
- const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
+ const lg = fs.readFileSync(path.resolve("../comfy/web/lib/litegraph.core.js"), "utf-8");
const globalTemp = {};
(function (console) {
eval(lg);
}).call(globalTemp, nop);
forEachKey((k) => (ctx[k] = globalTemp[k]));
- require(path.resolve("../web/lib/litegraph.extensions.js"));
+ require(path.resolve("../comfy/web/lib/litegraph.extensions.js"));
}
export function teardown(ctx) {
diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js
index dd150214a..4807167ee 100644
--- a/tests-ui/utils/setup.js
+++ b/tests-ui/utils/setup.js
@@ -1,4 +1,4 @@
-require("../../web/scripts/api");
+require("../../comfy/web/scripts/api");
const fs = require("fs");
const path = require("path");
@@ -14,7 +14,7 @@ function* walkSync(dir) {
}
/**
- * @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
+ * @typedef { import("../../comfy/web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
*/
/**
@@ -22,9 +22,9 @@ function* walkSync(dir) {
*/
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
if (!mockExtensions) {
- mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
+ mockExtensions = Array.from(walkSync(path.resolve("../comfy/web/extensions/core")))
.filter((x) => x.endsWith(".js"))
- .map((x) => path.relative(path.resolve("../web"), x));
+ .map((x) => path.relative(path.resolve("../comfy/web"), x));
}
if (!mockNodeDefs) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
@@ -41,7 +41,7 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
- jest.mock("../../web/scripts/api", () => ({
+ jest.mock("../../comfy/web/scripts/api", () => ({
get api() {
return mockApi;
},
diff --git a/tests/README.md b/tests/README.md
index 2005fd45b..6a71005ee 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -4,10 +4,7 @@
Additional requirements for running tests:
```
-pip install pytest
-pip install websocket-client==1.6.1
-opencv-python==4.6.0.66
-scikit-image==0.21.0
+pip install .[test]
```
Run inference tests:
```
diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 141cc5c7e..76910ad86 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -11,20 +11,20 @@ import torch
from typing import Union
import json
import subprocess
-import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse
-
from comfy.samplers import KSampler
"""
These tests generate and save images through a range of parameters
"""
+
class ComfyGraph:
- def __init__(self,
+ def __init__(self,
graph: dict,
sampler_nodes: list[str],
):
@@ -40,17 +40,17 @@ class ComfyGraph:
negative_prompt_node = self.graph[node]['inputs']['negative'][0]
self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
- def set_sampler_name(self, sampler_name:str, ):
+ def set_sampler_name(self, sampler_name: str, ):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['sampler_name'] = sampler_name
-
- def set_scheduler(self, scheduler:str):
+
+ def set_scheduler(self, scheduler: str):
# sets the sampler name for the sampler nodes (eg. base and refiner)
for node in self.sampler_nodes:
self.graph[node]['inputs']['scheduler'] = scheduler
-
- def set_filename_prefix(self, prefix:str):
+
+ def set_filename_prefix(self, prefix: str):
# sets the filename prefix for the save nodes
for node in self.graph:
if self.graph[node]['class_type'] == 'SaveImage':
@@ -60,11 +60,11 @@ class ComfyGraph:
class ComfyClient:
# From examples/websockets_api_example.py
- def connect(self,
- listen:str = '127.0.0.1',
- port:Union[str,int] = 8188,
- client_id: str = str(uuid.uuid4())
- ):
+ def connect(self,
+ listen: str = '127.0.0.1',
+ port: Union[str, int] = 8188,
+ client_id: str = str(uuid.uuid4())
+ ):
self.client_id = client_id
self.server_address = f"{listen}:{port}"
ws = websocket.WebSocket()
@@ -74,7 +74,7 @@ class ComfyClient:
def queue_prompt(self, prompt):
p = {"prompt": prompt, "client_id": self.client_id}
data = json.dumps(p).encode('utf-8')
- req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
+ req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(self, filename, subfolder, folder_type):
@@ -104,9 +104,9 @@ class ComfyClient:
if message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
- break #Execution is done
+ break # Execution is done
else:
- continue #previews are binary data
+ continue # previews are binary data
history = self.get_history(prompt_id)[prompt_id]
for o in history['outputs']:
@@ -121,13 +121,14 @@ class ComfyClient:
return output_images
+
#
# Initialize graphs
#
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
with open(default_graph_file, 'r') as file:
default_graph = json.loads(file.read())
-DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
+DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10', '14'])
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
#
@@ -142,6 +143,14 @@ prompt_list = [
sampler_list = KSampler.SAMPLERS
scheduler_list = KSampler.SCHEDULERS
+def run_server(args_pytest):
+ from comfy.cmd.main import main
+ from comfy.cli_args import args
+ args.output_directory = args_pytest["output_dir"]
+ args.listen = args_pytest["listen"]
+ args.port = args_pytest["port"]
+ main()
+
@pytest.mark.inference
@pytest.mark.parametrize("sampler", sampler_list)
@pytest.mark.parametrize("scheduler", scheduler_list)
@@ -152,18 +161,21 @@ class TestInference:
#
@fixture(scope="class", autouse=True)
def _server(self, args_pytest):
+ import multiprocessing
# Start server
- p = subprocess.Popen([
- 'python','main.py',
- '--output-directory', args_pytest["output_dir"],
- '--listen', args_pytest["listen"],
- '--port', str(args_pytest["port"]),
- ])
+
+ pickled_args = {
+ "output_dir": args_pytest["output_dir"],
+ "listen": args_pytest["listen"],
+ "port": args_pytest["port"]
+ }
+ p = multiprocessing.Process(target=run_server, args=(pickled_args,))
+ p.start()
yield
p.kill()
torch.cuda.empty_cache()
- def start_client(self, listen:str, port:int):
+ def start_client(self, listen: str, port: int):
# Start client
comfy_client = ComfyClient()
# Connect to server (with retries)
@@ -174,7 +186,7 @@ class TestInference:
comfy_client.connect(listen=listen, port=port)
except ConnectionRefusedError as e:
print(e)
- print(f"({i+1}/{n_tries}) Retrying...")
+ print(f"({i + 1}/{n_tries}) Retrying...")
else:
break
return comfy_client
@@ -187,7 +199,7 @@ class TestInference:
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
-
+
# Start client
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
@@ -203,7 +215,7 @@ class TestInference:
def client(self, _client_graph):
client = _client_graph[0]
yield client
-
+
@fixture
def comfy_graph(self, _client_graph):
# avoid mutating the graph
@@ -211,13 +223,13 @@ class TestInference:
yield graph
def test_comfy(
- self,
- client,
- comfy_graph,
- sampler,
- scheduler,
- prompt,
- request
+ self,
+ client,
+ comfy_graph,
+ sampler,
+ scheduler,
+ prompt,
+ request
):
test_info = request.node.name
comfy_graph.set_filename_prefix(test_info)
@@ -235,5 +247,3 @@ class TestInference:
for image_data in images_output:
pil_image = Image.open(BytesIO(image_data))
assert numpy.array(pil_image).any() != 0, "Image is blank"
-
-