diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js
new file mode 100644
index 000000000..d2a318da5
--- /dev/null
+++ b/tests-ui/tests/groupNode.test.js
@@ -0,0 +1,522 @@
+// @ts-check
+///
+
+const { start, createDefaultWorkflow } = require("../utils");
+const lg = require("../utils/litegraph");
+
+describe("group node", () => {
+ beforeEach(() => {
+ lg.setup(global);
+ });
+
+ afterEach(() => {
+ lg.teardown(global);
+ });
+
+ /**
+ *
+ * @param {*} app
+ * @param {*} graph
+ * @param {*} name
+ * @param {*} nodes
+ * @returns { Promise> }
+ */
+ async function convertToGroup(app, graph, name, nodes) {
+ // Select the nodes we are converting
+ for (const n of nodes) {
+ n.select(true);
+ }
+
+ expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
+ nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
+ );
+
+ global.prompt = jest.fn().mockImplementation(() => name);
+ const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
+
+ // Check group name was requested
+ expect(window.prompt).toHaveBeenCalled();
+
+ // Ensure old nodes are removed
+ for (const n of nodes) {
+ expect(n.isRemoved).toBeTruthy();
+ }
+
+ expect(groupNode.type).toEqual("workflow/" + name);
+
+ return graph.find(groupNode);
+ }
+
+ /**
+ * @param { Record } idMap
+ * @param { Record> } valueMap
+ */
+ function getOutput(idMap = {}, valueMap = {}) {
+ const expected = {
+ 1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
+ 2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
+ 3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
+ 4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
+ 5: {
+ inputs: {
+ seed: 0,
+ steps: 20,
+ cfg: 8,
+ sampler_name: "euler",
+ scheduler: "normal",
+ denoise: 1,
+ model: ["1", 0],
+ positive: ["2", 0],
+ negative: ["3", 0],
+ latent_image: ["4", 0],
+ ...valueMap?.[5],
+ },
+ class_type: "KSampler",
+ },
+ 6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
+ 7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
+ };
+
+ for (const oldId in idMap) {
+ const old = expected[oldId];
+ delete expected[oldId];
+ expected[idMap[oldId]] = old;
+
+ for (const k in expected) {
+ for (const input in expected[k].inputs) {
+ const v = expected[k].inputs[input];
+ if (v instanceof Array) {
+ if (v[0] in idMap) {
+ v[0] = idMap[v[0]];
+ }
+ }
+ }
+ }
+ }
+
+ return expected;
+ }
+
+ test("can be created from selected nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
+
+ // Ensure links are now to the group node
+ expect(group.inputs).toHaveLength(2);
+ expect(group.outputs).toHaveLength(3);
+
+ expect(group.inputs.map((i) => i.input.name)).toEqual(["CLIPTextEncode clip", "CLIPTextEncode 2 clip"]);
+ expect(group.outputs.map((i) => i.output.name)).toEqual([
+ "CLIPTextEncode CONDITIONING",
+ "CLIPTextEncode 2 CONDITIONING",
+ "EmptyLatentImage LATENT",
+ ]);
+
+ // ckpt clip to both clip inputs on the group
+ expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [group.id, 0],
+ [group.id, 1],
+ ]);
+
+ // group conditioning to sampler
+ expect(
+ group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
+ ).toEqual([[nodes.sampler.id, 1]]);
+ // group conditioning 2 to sampler
+ expect(
+ group.outputs["CLIPTextEncode 2 CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
+ ).toEqual([[nodes.sampler.id, 2]]);
+ // group latent to sampler
+ expect(
+ group.outputs["EmptyLatentImage LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
+ ).toEqual([[nodes.sampler.id, 3]]);
+ });
+ test("can be be converted back to nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
+
+ /** @type { Array } */
+ const newNodes = group.menu["Convert to nodes"].call();
+
+ const pos = graph.find(
+ newNodes.find(
+ (n) => n.type === "CLIPTextEncode" && n.widgets.find((w) => w.name === "text")?.value === "positive"
+ )
+ );
+ const neg = graph.find(
+ newNodes.find(
+ (n) => n.type === "CLIPTextEncode" && n.widgets.find((w) => w.name === "text")?.value === "negative"
+ )
+ );
+ const empty = graph.find(newNodes.find((n) => n.type === "EmptyLatentImage"));
+
+ // validate links
+ expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [pos.id, 0],
+ [neg.id, 0],
+ ]);
+
+ expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 1],
+ ]);
+
+ expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 2],
+ ]);
+
+ expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
+ [nodes.sampler.id, 3],
+ ]);
+ });
+ test("it can embed reroutes as inputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Add and connect a reroute to the clip text encodes
+ const reroute = ez.Reroute();
+ nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
+ reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
+ reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
+
+ // Convert to group and ensure we only have 1 input of the correct type
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
+ expect(group.inputs).toHaveLength(1);
+ expect(group.inputs[0].input.type).toEqual("CLIP");
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ [nodes.empty.id]: `${group.id}:2`,
+ })
+ );
+ });
+ test("it can embed reroutes as outputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Add a reroute with no output so we output IMAGE even though its used internally
+ const reroute = ez.Reroute();
+ nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
+
+ // Convert to group and ensure there is an IMAGE output
+ const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
+ expect(group.outputs).toHaveLength(1);
+ expect(group.outputs[0].output.type).toEqual("IMAGE");
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.decode.id]: `${group.id}:0`,
+ [nodes.save.id]: `${group.id}:1`,
+ })
+ );
+ });
+ test("it can embed reroutes as pipes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ // Use reroutes as a pipe
+ const rerouteModel = ez.Reroute();
+ const rerouteClip = ez.Reroute();
+ const rerouteVae = ez.Reroute();
+ nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
+ nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
+ nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
+
+ const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
+
+ expect(group.outputs).toHaveLength(3);
+ expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
+
+ expect(group.outputs).toHaveLength(3);
+ expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
+
+ group.outputs[0].connectTo(nodes.sampler.inputs.model);
+ group.outputs[1].connectTo(nodes.pos.inputs.clip);
+ group.outputs[1].connectTo(nodes.neg.inputs.clip);
+ });
+ test("creates with widget values from inner nodes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
+ nodes.pos.widgets.text.value = "hello";
+ nodes.neg.widgets.text.value = "world";
+ nodes.empty.widgets.width.value = 256;
+ nodes.empty.widgets.height.value = 1024;
+ nodes.sampler.widgets.seed.value = 1;
+ nodes.sampler.widgets.control_after_generate.value = "increment";
+ nodes.sampler.widgets.steps.value = 8;
+ nodes.sampler.widgets.cfg.value = 4.5;
+ nodes.sampler.widgets.sampler_name.value = "uni_pc";
+ nodes.sampler.widgets.scheduler.value = "karras";
+ nodes.sampler.widgets.denoise.value = 0.9;
+
+ const group = await convertToGroup(app, graph, "test", [
+ nodes.ckpt,
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ ]);
+
+ expect(group.widgets["CheckpointLoaderSimple ckpt_name"].value).toEqual("model2.ckpt");
+ expect(group.widgets["CLIPTextEncode text"].value).toEqual("hello");
+ expect(group.widgets["CLIPTextEncode 2 text"].value).toEqual("world");
+ expect(group.widgets["EmptyLatentImage width"].value).toEqual(256);
+ expect(group.widgets["EmptyLatentImage height"].value).toEqual(1024);
+ expect(group.widgets["KSampler seed"].value).toEqual(1);
+ expect(group.widgets["KSampler control_after_generate"].value).toEqual("increment");
+ expect(group.widgets["KSampler steps"].value).toEqual(8);
+ expect(group.widgets["KSampler cfg"].value).toEqual(4.5);
+ expect(group.widgets["KSampler sampler_name"].value).toEqual("uni_pc");
+ expect(group.widgets["KSampler scheduler"].value).toEqual("karras");
+ expect(group.widgets["KSampler denoise"].value).toEqual(0.9);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput(
+ {
+ [nodes.ckpt.id]: `${group.id}:0`,
+ [nodes.pos.id]: `${group.id}:1`,
+ [nodes.neg.id]: `${group.id}:2`,
+ [nodes.empty.id]: `${group.id}:3`,
+ [nodes.sampler.id]: `${group.id}:4`,
+ },
+ {
+ [nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
+ [nodes.pos.id]: { text: "hello" },
+ [nodes.neg.id]: { text: "world" },
+ [nodes.empty.id]: { width: 256, height: 1024 },
+ [nodes.sampler.id]: {
+ seed: 1,
+ steps: 8,
+ cfg: 4.5,
+ sampler_name: "uni_pc",
+ scheduler: "karras",
+ denoise: 0.9,
+ },
+ }
+ )
+ );
+ });
+
+ test("group inputs can be reroutes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+
+ const reroute = ez.Reroute();
+ nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
+
+ reroute.outputs[0].connectTo(group.inputs[0]);
+ reroute.outputs[0].connectTo(group.inputs[1]);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ })
+ );
+ });
+ test("group outputs can be reroutes", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+
+ const reroute1 = ez.Reroute();
+ const reroute2 = ez.Reroute();
+ group.outputs[0].connectTo(reroute1.inputs[0]);
+ group.outputs[1].connectTo(reroute2.inputs[0]);
+
+ reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
+ reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ })
+ );
+ });
+ test("groups can connect to each other", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
+
+ group1.outputs[0].connectTo(group2.inputs["KSampler positive"]);
+ group1.outputs[1].connectTo(group2.inputs["KSampler negative"]);
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group1.id}:0`,
+ [nodes.neg.id]: `${group1.id}:1`,
+ [nodes.empty.id]: `${group2.id}:0`,
+ [nodes.sampler.id]: `${group2.id}:1`,
+ })
+ );
+ });
+ test("displays generated image on group node", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ nodes.decode,
+ nodes.save,
+ ]);
+
+ const { api } = require("../../web/scripts/api");
+ api.dispatchEvent(new CustomEvent("execution_start", {}));
+ api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:3` }));
+ // Event should be forwarded to group node id
+ expect(+app.runningNodeId).toEqual(group.id);
+ expect(group.node["imgs"]).toBeFalsy();
+ api.dispatchEvent(
+ new CustomEvent("executed", {
+ detail: {
+ node: `${group.id}:3`,
+ output: {
+ images: [
+ {
+ filename: "test.png",
+ type: "output",
+ },
+ ],
+ },
+ },
+ })
+ );
+
+ // Trigger paint
+ group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
+
+ expect(group.node["images"]).toEqual([
+ {
+ filename: "test.png",
+ type: "output",
+ },
+ ]);
+ });
+ test("allows widgets to be converted to inputs", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ group.widgets[0].convertToInput();
+
+ const primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(group.inputs["CLIPTextEncode text"]);
+ primitive.widgets[0].value = "hello";
+
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput(
+ {
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ },
+ {
+ [nodes.pos.id]: { text: "hello" },
+ }
+ )
+ );
+ });
+ test("can be copied", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ const group1 = await convertToGroup(app, graph, "test", [
+ nodes.pos,
+ nodes.neg,
+ nodes.empty,
+ nodes.sampler,
+ nodes.decode,
+ nodes.save,
+ ]);
+
+ group1.widgets["CLIPTextEncode text"].value = "hello";
+ group1.widgets["EmptyLatentImage width"].value = 256;
+ group1.widgets["KSampler seed"].value = 1;
+
+ // Clone the node
+ group1.menu.Clone.call();
+ expect(app.graph._nodes).toHaveLength(3);
+ const group2 = graph.find(app.graph._nodes[2]);
+ expect(group2.node.type).toEqual("workflow/test");
+ expect(group2.id).not.toEqual(group1.id);
+
+ // Reconnect ckpt
+ nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["KSampler model"]);
+ nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
+ nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode 2 clip"]);
+ nodes.ckpt.outputs.VAE.connectTo(group2.inputs["VAEDecode vae"]);
+
+ group2.widgets["CLIPTextEncode text"].value = "world";
+ group2.widgets["EmptyLatentImage width"].value = 1024;
+ group2.widgets["KSampler seed"].value = 100;
+
+ expect((await graph.toPrompt()).output).toEqual({
+ ...getOutput(
+ {
+ [nodes.pos.id]: `${group1.id}:0`,
+ [nodes.neg.id]: `${group1.id}:1`,
+ [nodes.empty.id]: `${group1.id}:2`,
+ [nodes.sampler.id]: `${group1.id}:3`,
+ [nodes.decode.id]: `${group1.id}:4`,
+ [nodes.save.id]: `${group1.id}:5`,
+ },
+ {
+ [nodes.pos.id]: { text: "hello" },
+ [nodes.empty.id]: { width: 256 },
+ [nodes.sampler.id]: { seed: 1 },
+ }
+ ),
+ ...getOutput(
+ {
+ [nodes.pos.id]: `${group2.id}:0`,
+ [nodes.neg.id]: `${group2.id}:1`,
+ [nodes.empty.id]: `${group2.id}:2`,
+ [nodes.sampler.id]: `${group2.id}:3`,
+ [nodes.decode.id]: `${group2.id}:4`,
+ [nodes.save.id]: `${group2.id}:5`,
+ },
+ {
+ [nodes.pos.id]: { text: "world" },
+ [nodes.empty.id]: { width: 1024 },
+ [nodes.sampler.id]: { seed: 100 },
+ }
+ ),
+ });
+
+ graph.arrange();
+ });
+ test("is embedded in workflow", async () => {
+ let { ez, graph, app } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+ let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
+ const workflow = JSON.stringify((await graph.toPrompt()).workflow);
+
+ // Clear the environment
+ ({ ez, graph, app } = await start({
+ resetEnv: true,
+ }));
+ // Ensure the node isnt registered
+ expect(() => ez["workflow/test"]).toThrow();
+
+ // Relaod the workflow
+ await app.loadGraphData(JSON.parse(workflow));
+
+ // Ensure the node is found
+ group = graph.find(group);
+
+ // Generate prompt and ensure it is as expected
+ expect((await graph.toPrompt()).output).toEqual(
+ getOutput({
+ [nodes.pos.id]: `${group.id}:0`,
+ [nodes.neg.id]: `${group.id}:1`,
+ })
+ );
+ });
+});
diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js
index 0e81fd47b..898b82db0 100644
--- a/tests-ui/utils/ezgraph.js
+++ b/tests-ui/utils/ezgraph.js
@@ -150,7 +150,7 @@ export class EzNodeMenuItem {
if (selectNode) {
this.node.select();
}
- this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
+ return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
@@ -240,8 +240,12 @@ export class EzNode {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
- select() {
- this.app.canvas.selectNode(this.node);
+ get isRemoved() {
+ return !this.app.graph.getNodeById(this.id);
+ }
+
+ select(addToSelection = false) {
+ this.app.canvas.selectNode(this.node, addToSelection);
}
// /**
@@ -275,12 +279,17 @@ export class EzNode {
if (!s) return p;
const name = s[nameProperty];
+ const item = new ctor(this, i, s);
// @ts-ignore
- if (!name || name in p) {
- throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
+ p.push(item);
+ if (name) {
+ // @ts-ignore
+ if (name in p) {
+ throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
+ }
}
// @ts-ignore
- p.push((p[name] = new ctor(this, i, s)));
+ p[name] = item;
return p;
}, Object.assign([], { $: this }));
}
@@ -348,6 +357,19 @@ export class EzGraph {
}, 10);
});
}
+
+ /**
+ * @returns { Promise<{
+ * workflow: {},
+ * output: Record
+ * }>}> }
+ */
+ toPrompt() {
+ // @ts-ignore
+ return this.app.graphToPrompt();
+ }
}
export const Ez = {
@@ -356,12 +378,12 @@ export const Ez = {
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
- * const [model, clip, vae] = ez.CheckpointLoaderSimple();
- * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" });
- * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" });
- * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage());
- * const [image] = ez.VAEDecode(latent, vae);
- * const saveNode = ez.SaveImage(image).node;
+ * const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
+ * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
+ * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
+ * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
+ * const [image] = ez.VAEDecode(latent, vae).outputs;
+ * const saveNode = ez.SaveImage(image);
* console.log(saveNode);
* graph.arrange();
* @param { app } app
diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js
index 01c58b21f..eeccdb3d9 100644
--- a/tests-ui/utils/index.js
+++ b/tests-ui/utils/index.js
@@ -1,21 +1,28 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
+const lg = require("./litegraph");
/**
*
- * @param { Parameters[0] } config
+ * @param { Parameters[0] & { resetEnv?: boolean } } config
* @returns
*/
export async function start(config = undefined) {
+ if(config?.resetEnv) {
+ jest.resetModules();
+ jest.resetAllMocks();
+ lg.setup(global);
+ }
+
mockApi(config);
const { app } = require("../../web/scripts/app");
await app.setup();
- return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]);
+ return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
/**
- * @param { ReturnType["graph"] } graph
- * @param { (hasReloaded: boolean) => (Promise | void) } cb
+ * @param { ReturnType["graph"] } graph
+ * @param { (hasReloaded: boolean) => (Promise | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
@@ -24,10 +31,10 @@ export async function checkBeforeAndAfterReload(graph, cb) {
}
/**
- * @param { string } name
- * @param { Record } input
+ * @param { string } name
+ * @param { Record } input
* @param { (string | string[])[] | Record } output
- * @returns { Record }
+ * @returns { Record }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
@@ -37,19 +44,19 @@ export function makeNodeDef(name, input, output = {}) {
output_name: [],
output_is_list: [],
input: {
- required: {}
+ required: {},
},
};
- for(const k in input) {
+ for (const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
- if(output instanceof Array) {
+ if (output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
- }, {})
+ }, {});
}
- for(const k in output) {
+ for (const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
@@ -68,4 +75,31 @@ export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
-}
\ No newline at end of file
+}
+
+/**
+ *
+ * @param { ReturnType["ez"] } ez
+ * @param { ReturnType["graph"] } graph
+ */
+export function createDefaultWorkflow(ez, graph) {
+ graph.clear();
+ const ckpt = ez.CheckpointLoaderSimple();
+
+ const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
+ const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
+
+ const empty = ez.EmptyLatentImage();
+ const sampler = ez.KSampler(
+ ckpt.outputs.MODEL,
+ pos.outputs.CONDITIONING,
+ neg.outputs.CONDITIONING,
+ empty.outputs.LATENT
+ );
+
+ const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
+ const save = ez.SaveImage(decode.outputs.IMAGE);
+ graph.arrange();
+
+ return { ckpt, pos, neg, empty, sampler, decode, save };
+}
diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js
index 17e8ac1ad..5793ded2b 100644
--- a/tests-ui/utils/setup.js
+++ b/tests-ui/utils/setup.js
@@ -30,16 +30,19 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
+ const events = new EventTarget();
+ const mockApi = {
+ addEventListener: events.addEventListener.bind(events),
+ dispatchEvent: events.dispatchEvent.bind(events),
+ getSystemStats: jest.fn(),
+ getExtensions: jest.fn(() => mockExtensions),
+ getNodeDefs: jest.fn(() => mockNodeDefs),
+ init: jest.fn(),
+ apiURL: jest.fn((x) => "../../web/" + x),
+ };
jest.mock("../../web/scripts/api", () => ({
get api() {
- return {
- addEventListener: jest.fn(),
- getSystemStats: jest.fn(),
- getExtensions: jest.fn(() => mockExtensions),
- getNodeDefs: jest.fn(() => mockNodeDefs),
- init: jest.fn(),
- apiURL: jest.fn((x) => "../../web/" + x),
- };
+ return mockApi;
},
}));
}