ComfyUI/tests-ui/tests/groupNode.test.js
Jacob Segal 36b2214e30 Execution Model Inversion
This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:

    1. The implementation of lazy evaluation in nodes. For example, if a
    "Mix Images" node has a mix factor of exactly 0.0, the second image
    input doesn't even need to be evaluated (and visa-versa if the mix
    factor is 1.0).

    2. Dynamic expansion of nodes. This allows for the creation of dynamic
    "node groups". Specifically, custom nodes can return subgraphs that
    replace the original node in the graph. This is an incredibly
    powerful concept. Using this functionality, it was easy to
    implement:
        a. Components (a.k.a. node groups)
        b. Flow control (i.e. while loops) via tail recursion
        c. All-in-one nodes that replicate the WebUI functionality
        d. and more
    All of those were able to be implemented entirely via custom nodes,
    so those features are *not* a part of this PR. (There are some
    front-end changes that should occur before that functionality is
    made widely available, particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.
2024-01-28 20:48:42 -08:00

1008 lines
37 KiB
JavaScript

// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
/**
*
* @param {*} app
* @param {*} graph
* @param {*} name
* @param {*} nodes
* @returns { Promise<InstanceType<import("../utils/ezgraph")["EzNode"]>> }
*/
async function convertToGroup(app, graph, name, nodes) {
// Select the nodes we are converting
for (const n of nodes) {
n.select(true);
}
expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
);
global.prompt = jest.fn().mockImplementation(() => name);
const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
// Check group name was requested
expect(window.prompt).toHaveBeenCalled();
// Ensure old nodes are removed
for (const n of nodes) {
expect(n.isRemoved).toBeTruthy();
}
expect(groupNode.type).toEqual("workflow/" + name);
return graph.find(groupNode);
}
/**
* @param { Record<string, string | number> | number[] } idMap
* @param { Record<string, Record<string, unknown>> } valueMap
*/
function getOutput(idMap = {}, valueMap = {}) {
if (idMap instanceof Array) {
idMap = idMap.reduce((p, n) => {
p[n] = n + "";
return p;
}, {});
}
const expected = {
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
...valueMap?.[5],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
};
// Map old IDs to new at the top level
const mapped = {};
for (const oldId in idMap) {
mapped[idMap[oldId]] = expected[oldId];
delete expected[oldId];
}
Object.assign(mapped, expected);
// Map old IDs to new inside links
for (const k in mapped) {
for (const input in mapped[k].inputs) {
const v = mapped[k].inputs[input];
if (v instanceof Array) {
if (v[0] in idMap) {
v[0] = idMap[v[0]] + "";
}
}
}
}
return mapped;
}
test("can be created from selected nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
// Ensure links are now to the group node
expect(group.inputs).toHaveLength(2);
expect(group.outputs).toHaveLength(3);
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
// ckpt clip to both clip inputs on the group
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[group.id, 0],
[group.id, 1],
]);
// group conditioning to sampler
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
// group conditioning 2 to sampler
expect(
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 2]]);
// group latent to sampler
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("maintains all output links on conversion", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const save2 = ez.SaveImage(...nodes.decode.outputs);
const save3 = ez.SaveImage(...nodes.decode.outputs);
// Ensure an output with multiple links maintains them on convert to group
const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
expect(group.outputs[0].connections.length).toBe(3);
expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
// and they're still linked when converting back to nodes
const newNodes = group.menu["Convert to nodes"].call();
const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
expect(decode.outputs[0].connections.length).toBe(3);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
});
test("can be be converted back to nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["text"].value).toBe("positive");
group.widgets["text"].value = "pos";
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
group.widgets["CLIPTextEncode text"].value = "neg";
expect(group.widgets["width"].value).toBe(512);
group.widgets["width"].value = 1024;
expect(group.widgets["sampler_name"].value).toBe("euler");
group.widgets["sampler_name"].value = "ddim";
expect(group.widgets["control_after_generate"].value).toBe("randomize");
group.widgets["control_after_generate"].value = "fixed";
/** @type { Array<any> } */
group.menu["Convert to nodes"].call();
// ensure widget values are set
const pos = graph.find(nodes.pos.id);
expect(pos.node.type).toBe("CLIPTextEncode");
expect(pos.widgets["text"].value).toBe("pos");
const neg = graph.find(nodes.neg.id);
expect(neg.node.type).toBe("CLIPTextEncode");
expect(neg.widgets["text"].value).toBe("neg");
const empty = graph.find(nodes.empty.id);
expect(empty.node.type).toBe("EmptyLatentImage");
expect(empty.widgets["width"].value).toBe(1024);
const sampler = graph.find(nodes.sampler.id);
expect(sampler.node.type).toBe("KSampler");
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
// validate links
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[pos.id, 0],
[neg.id, 0],
]);
expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 2],
]);
expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("it can embed reroutes as inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add and connect a reroute to the clip text encodes
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
// Convert to group and ensure we only have 1 input of the correct type
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
expect(group.inputs).toHaveLength(1);
expect(group.inputs[0].input.type).toEqual("CLIP");
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("it can embed reroutes as outputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add a reroute with no output so we output IMAGE even though its used internally
const reroute = ez.Reroute();
nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
// Convert to group and ensure there is an IMAGE output
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
expect(group.outputs).toHaveLength(1);
expect(group.outputs[0].output.type).toEqual("IMAGE");
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
});
test("it can embed reroutes as pipes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Use reroutes as a pipe
const rerouteModel = ez.Reroute();
const rerouteClip = ez.Reroute();
const rerouteVae = ez.Reroute();
nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
group.outputs[0].connectTo(nodes.sampler.inputs.model);
group.outputs[1].connectTo(nodes.pos.inputs.clip);
group.outputs[1].connectTo(nodes.neg.inputs.clip);
});
test("can handle reroutes used internally", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let reroutes = [];
let prevNode = nodes.ckpt;
for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
reroutes.push(reroute);
}
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("creates with widget values from inner nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
nodes.pos.widgets.text.value = "hello";
nodes.neg.widgets.text.value = "world";
nodes.empty.widgets.width.value = 256;
nodes.empty.widgets.height.value = 1024;
nodes.sampler.widgets.seed.value = 1;
nodes.sampler.widgets.control_after_generate.value = "increment";
nodes.sampler.widgets.steps.value = 8;
nodes.sampler.widgets.cfg.value = 4.5;
nodes.sampler.widgets.sampler_name.value = "uni_pc";
nodes.sampler.widgets.scheduler.value = "karras";
nodes.sampler.widgets.denoise.value = 0.9;
const group = await convertToGroup(app, graph, "test", [
nodes.ckpt,
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
]);
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
expect(group.widgets["width"].value).toEqual(256);
expect(group.widgets["height"].value).toEqual(1024);
expect(group.widgets["seed"].value).toEqual(1);
expect(group.widgets["control_after_generate"].value).toEqual("increment");
expect(group.widgets["steps"].value).toEqual(8);
expect(group.widgets["cfg"].value).toEqual(4.5);
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["scheduler"].value).toEqual("karras");
expect(group.widgets["denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
})
);
});
test("group inputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[1]);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("group outputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute1 = ez.Reroute();
const reroute2 = ez.Reroute();
group.outputs[0].connectTo(reroute1.inputs[0]);
group.outputs[1].connectTo(reroute2.inputs[0]);
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("groups can connect to each other", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
group1.outputs[0].connectTo(group2.inputs["positive"]);
group1.outputs[1].connectTo(group2.inputs["negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("groups can connect to each other via internal reroutes", async () => {
const { ez, graph, app } = await start();
const latent = ez.EmptyLatentImage();
const vae = ez.VAELoader();
const latentReroute = ez.Reroute();
const vaeReroute = ez.Reroute();
latent.outputs[0].connectTo(latentReroute.inputs[0]);
vae.outputs[0].connectTo(vaeReroute.inputs[0]);
const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(4);
const group2 = graph.find(app.graph._nodes[3]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
group1.outputs.VAE.connectTo(group2.inputs.VAE);
group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
const preview = ez.PreviewImage(decode.outputs[0]);
const output = {
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
};
expect((await graph.toPrompt()).output).toEqual(output);
// Ensure missing connections dont cause errors
group2.inputs.VAE.disconnect();
delete output[decode.id].inputs.vae;
expect((await graph.toPrompt()).output).toEqual(output);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
const { api } = require("../../web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${nodes.save.id}`,
display_node: `${nodes.save.id}`,
output: {
images: [
{
filename: "test.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test.png",
type: "output",
},
]);
// Reload
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
await app.loadGraphData(JSON.parse(workflow));
group = graph.find(group);
// Trigger inner nodes to get created
group.node["getInnerNodes"]();
// Check it works for internal node ids
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
display_node: `${group.id}:5`,
output: {
images: [
{
filename: "test2.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test2.png",
type: "output",
},
]);
});
test("allows widgets to be converted to inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
group.widgets[0].convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs["text"]);
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id], {
[nodes.pos.id]: { text: "hello" },
})
);
});
test("can be copied", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
group1.widgets["text"].value = "hello";
group1.widgets["width"].value = 256;
group1.widgets["seed"].value = 1;
// Clone the node
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(3);
const group2 = graph.find(app.graph._nodes[2]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
// Reconnect ckpt
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
group2.widgets["text"].value = "world";
group2.widgets["width"].value = 1024;
group2.widgets["seed"].value = 100;
let i = 0;
expect((await graph.toPrompt()).output).toEqual({
...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
[nodes.empty.id]: { width: 256 },
[nodes.pos.id]: { text: "hello" },
[nodes.sampler.id]: { seed: 1 },
}),
...getOutput(
{
[nodes.empty.id]: `${group2.id}:${i++}`,
[nodes.pos.id]: `${group2.id}:${i++}`,
[nodes.neg.id]: `${group2.id}:${i++}`,
[nodes.sampler.id]: `${group2.id}:${i++}`,
[nodes.decode.id]: `${group2.id}:${i++}`,
[nodes.save.id]: `${group2.id}:${i++}`,
},
{
[nodes.empty.id]: { width: 1024 },
[nodes.pos.id]: { text: "world" },
[nodes.sampler.id]: { seed: 100 },
}
),
});
graph.arrange();
});
test("is embedded in workflow", async () => {
let { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
// Clear the environment
({ ez, graph, app } = await start({
resetEnv: true,
}));
// Ensure the node isnt registered
expect(() => ez["workflow/test"]).toThrow();
// Reload the workflow
await app.loadGraphData(JSON.parse(workflow));
// Ensure the node is found
group = graph.find(group);
// Generate prompt and ensure it is as expected
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
});
test("shows missing node error on missing internal node when loading graph data", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 1,
nodes: [
{
id: 3,
type: "workflow/testerror",
},
],
links: [],
groups: [],
config: {},
extra: {
groupNodes: {
testerror: {
nodes: [
{
type: "NotKSampler",
},
{
type: "NotVAEDecode",
},
],
},
},
},
});
expect(dialogShow).toBeCalledTimes(1);
const call = dialogShow.mock.calls[0][0].innerHTML;
expect(call).toContain("the following node types were not found");
expect(call).toContain("NotKSampler");
expect(call).toContain("NotVAEDecode");
expect(call).toContain("workflow/testerror");
});
test("maintains widget inputs on conversion back to nodes", async () => {
const { ez, graph, app } = await start();
let pos = ez.CLIPTextEncode({ text: "positive" });
pos.node.title = "Positive";
let neg = ez.CLIPTextEncode({ text: "negative" });
neg.node.title = "Negative";
pos.widgets.text.convertToInput();
neg.widgets.text.convertToInput();
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(pos.inputs.text);
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// This will use a primitive widget named 'value'
expect(group.widgets.length).toBe(1);
expect(group.widgets["value"].value).toBe("positive");
const newNodes = group.menu["Convert to nodes"].call();
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
neg = graph.find(newNodes.find((n) => n.title === "Negative"));
primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
expect(pos.inputs).toHaveLength(2);
expect(neg.inputs).toHaveLength(2);
expect(primitive.outputs[0].connections).toHaveLength(2);
expect((await graph.toPrompt()).output).toEqual({
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");
// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();
// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];
await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
const save = ez.SaveImage();
const empty = ez.EmptyLatentImage();
const decode = ez.VAEDecode();
scale.outputs.LATENT.connectTo(decode.inputs.samples);
decode.outputs.IMAGE.connectTo(save.inputs.images);
empty.outputs.LATENT.connectTo(scale.inputs.samples);
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
const widgets = group.widgets.map((w) => w.widget.name);
expect(widgets).toStrictEqual([
"width",
"height",
"batch_size",
"upscale_method",
"LatentUpscale width",
"LatentUpscale height",
"crop",
"filename_prefix",
]);
});
test("adds output for external links when converting to group", async () => {
const { ez, graph, app } = await start();
const img = ez.EmptyLatentImage();
let decode = ez.VAEDecode(...img.outputs);
const preview1 = ez.PreviewImage(...decode.outputs);
const preview2 = ez.PreviewImage(...decode.outputs);
const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
// Ensure we have an output connected to the 2nd preview node
expect(group.outputs.length).toBe(1);
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
// Convert back and ensure bothe previews are still connected
group.menu["Convert to nodes"].call();
decode = graph.find(decode);
expect(decode.outputs[0].connections.length).toBe(2);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
});
test("adds output for external links when converting to group when nodes are not in execution order", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
const ckpt = ez.CheckpointLoaderSimple();
const empty = ez.EmptyLatentImage();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode1.outputs.IMAGE);
ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
const encode = ez.VAEEncode(decode1.outputs.IMAGE);
const vae = ez.VAELoader();
const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
const preview = ez.PreviewImage(decode2.outputs.IMAGE);
vae.outputs.VAE.connectTo(encode.inputs.vae);
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
expect(group.outputs.length).toBe(3);
expect(group.outputs[0].output.name).toBe("VAE");
expect(group.outputs[0].output.type).toBe("VAE");
expect(group.outputs[1].output.name).toBe("IMAGE");
expect(group.outputs[1].output.type).toBe("IMAGE");
expect(group.outputs[2].output.name).toBe("LATENT");
expect(group.outputs[2].output.type).toBe("LATENT");
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
expect(group.outputs[1].connections.length).toBe(1);
expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
expect(group.outputs[2].connections.length).toBe(1);
expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
expect((await graph.toPrompt()).output).toEqual({
...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
[encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
[decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
[preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
});
});
test("works with IMAGEUPLOAD widget", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const preview1 = ez.PreviewImage(img.outputs[0]);
const group = await convertToGroup(app, graph, "test", [img, preview1]);
const widget = group.widgets["upload"];
expect(widget).toBeTruthy();
expect(widget.widget.type).toBe("button");
});
test("internal primitive populates widgets for all linked inputs", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const scale1 = ez.ImageScale(img.outputs[0]);
const scale2 = ez.ImageScale(img.outputs[0]);
ez.PreviewImage(scale1.outputs[0]);
ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.width.convertToInput();
scale2.widgets.height.convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(scale1.inputs.width);
primitive.outputs[0].connectTo(scale2.inputs.height);
const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
group.widgets.value.value = 100;
expect((await graph.toPrompt()).output).toEqual({
1: {
inputs: { image: img.widgets.image.value, upload: "image" },
class_type: "LoadImage",
},
2: {
inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
3: {
inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
});
});
test("primitive control widgets values are copied on convert", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
sampler.widgets.seed.convertToInput();
sampler.widgets.sampler_name.convertToInput();
let p1 = ez.PrimitiveNode();
let p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(sampler.inputs.seed);
p2.outputs[0].connectTo(sampler.inputs.sampler_name);
p1.widgets.control_after_generate.value = "increment";
p2.widgets.control_after_generate.value = "decrement";
p2.widgets.control_filter_list.value = "/.*/";
p2.node.title = "p2";
const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
expect(group.widgets.control_after_generate.value).toBe("increment");
expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
group.widgets.control_after_generate.value = "fixed";
group.widgets["p2 control_after_generate"].value = "randomize";
group.widgets["p2 control_filter_list"].value = "/.+/";
group.menu["Convert to nodes"].call();
p1 = graph.find(p1);
p2 = graph.find(p2);
expect(p1.widgets.control_after_generate.value).toBe("fixed");
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
test("internal reroutes work with converted inputs and merge options", async () => {
const { ez, graph, app } = await start();
const vae = ez.VAELoader();
const latent = ez.EmptyLatentImage();
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
const scale = ez.ImageScale(decode.outputs.IMAGE);
ez.PreviewImage(scale.outputs.IMAGE);
const r1 = ez.Reroute();
const r2 = ez.Reroute();
latent.widgets.width.value = 64;
latent.widgets.height.value = 128;
latent.widgets.width.convertToInput();
latent.widgets.height.convertToInput();
latent.widgets.batch_size.convertToInput();
scale.widgets.width.convertToInput();
scale.widgets.height.convertToInput();
r1.inputs[0].input.label = "hbw";
r1.outputs[0].connectTo(latent.inputs.height);
r1.outputs[0].connectTo(latent.inputs.batch_size);
r1.outputs[0].connectTo(scale.inputs.width);
r2.inputs[0].input.label = "wh";
r2.outputs[0].connectTo(latent.inputs.width);
r2.outputs[0].connectTo(scale.inputs.height);
const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
expect(group.inputs[0].input.type).toBe("VAE");
expect(group.inputs[1].input.type).toBe("INT");
expect(group.inputs[2].input.type).toBe("INT");
const p1 = ez.PrimitiveNode();
const p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(group.inputs[1]);
p2.outputs[0].connectTo(group.inputs[2]);
expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p1.widgets.value.value).toBe(128);
expect(p2.widgets.value.value).toBe(64);
p1.widgets.value.value = 16;
p2.widgets.value.value = 32;
await checkBeforeAndAfterReload(graph, async (r) => {
const id = (v) => (r ? `${group.id}:` : "") + v;
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
[id(4)]: {
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
class_type: "ImageScale",
},
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
});
});
});
test("converted inputs with linked widgets map values correctly on creation", async () => {
const { ez, graph, app } = await start();
const k1 = ez.KSampler();
const k2 = ez.KSampler();
k1.widgets.seed.convertToInput();
k2.widgets.seed.convertToInput();
const rr = ez.Reroute();
rr.outputs[0].connectTo(k1.inputs.seed);
rr.outputs[0].connectTo(k2.inputs.seed);
const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
expect(group.widgets.steps.value).toBe(20);
expect(group.widgets.cfg.value).toBe(8);
expect(group.widgets.scheduler.value).toBe("normal");
expect(group.widgets["KSampler steps"].value).toBe(20);
expect(group.widgets["KSampler cfg"].value).toBe(8);
expect(group.widgets["KSampler scheduler"].value).toBe("normal");
});
test("allow multiple of the same node type to be added", async () => {
const { ez, graph, app } = await start();
const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
const group = await convertToGroup(app, graph, "test", nodes);
expect(group.inputs.length).toBe(10);
expect(group.outputs.length).toBe(10);
expect(group.widgets.length).toBe(20);
expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
[...Array(10)]
.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
);
});
});