Group node refactor

Support for primitives/converted widgets
This commit is contained in:
pythongosssss 2023-11-26 21:40:25 +00:00
parent c44d8df7b1
commit abdff29823
6 changed files with 1119 additions and 829 deletions

View File

@ -115,12 +115,8 @@ describe("group node", () => {
expect(group.inputs).toHaveLength(2);
expect(group.outputs).toHaveLength(3);
expect(group.inputs.map((i) => i.input.name)).toEqual(["CLIPTextEncode clip", "CLIPTextEncode 2 clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual([
"EmptyLatentImage LATENT",
"CLIPTextEncode CONDITIONING",
"CLIPTextEncode 2 CONDITIONING",
]);
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
// ckpt clip to both clip inputs on the group
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
@ -129,17 +125,17 @@ describe("group node", () => {
]);
// group conditioning to sampler
expect(
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 1]]);
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
// group conditioning 2 to sampler
expect(
group.outputs["CLIPTextEncode 2 CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 2]]);
// group latent to sampler
expect(
group.outputs["EmptyLatentImage LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 3]]);
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("maintains all output links on conversion", async () => {
@ -169,16 +165,16 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["CLIPTextEncode text"].value).toBe("positive");
group.widgets["CLIPTextEncode text"].value = "pos";
expect(group.widgets["CLIPTextEncode 2 text"].value).toBe("negative");
group.widgets["CLIPTextEncode 2 text"].value = "neg";
expect(group.widgets["EmptyLatentImage width"].value).toBe(512);
group.widgets["EmptyLatentImage width"].value = 1024;
expect(group.widgets["KSampler sampler_name"].value).toBe("euler");
group.widgets["KSampler sampler_name"].value = "ddim";
expect(group.widgets["KSampler control_after_generate"].value).toBe("randomize");
group.widgets["KSampler control_after_generate"].value = "fixed";
expect(group.widgets["text"].value).toBe("positive");
group.widgets["text"].value = "pos";
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
group.widgets["CLIPTextEncode text"].value = "neg";
expect(group.widgets["width"].value).toBe(512);
group.widgets["width"].value = 1024;
expect(group.widgets["sampler_name"].value).toBe("euler");
group.widgets["sampler_name"].value = "ddim";
expect(group.widgets["control_after_generate"].value).toBe("randomize");
group.widgets["control_after_generate"].value = "fixed";
/** @type { Array<any> } */
group.menu["Convert to nodes"].call();
@ -296,18 +292,18 @@ describe("group node", () => {
nodes.sampler,
]);
expect(group.widgets["CheckpointLoaderSimple ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode 2 text"].value).toEqual("world");
expect(group.widgets["EmptyLatentImage width"].value).toEqual(256);
expect(group.widgets["EmptyLatentImage height"].value).toEqual(1024);
expect(group.widgets["KSampler seed"].value).toEqual(1);
expect(group.widgets["KSampler control_after_generate"].value).toEqual("increment");
expect(group.widgets["KSampler steps"].value).toEqual(8);
expect(group.widgets["KSampler cfg"].value).toEqual(4.5);
expect(group.widgets["KSampler sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["KSampler scheduler"].value).toEqual("karras");
expect(group.widgets["KSampler denoise"].value).toEqual(0.9);
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
expect(group.widgets["width"].value).toEqual(256);
expect(group.widgets["height"].value).toEqual(1024);
expect(group.widgets["seed"].value).toEqual(1);
expect(group.widgets["control_after_generate"].value).toEqual("increment");
expect(group.widgets["steps"].value).toEqual(8);
expect(group.widgets["cfg"].value).toEqual(4.5);
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["scheduler"].value).toEqual("karras");
expect(group.widgets["denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
@ -360,8 +356,8 @@ describe("group node", () => {
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
group1.outputs[0].connectTo(group2.inputs["KSampler positive"]);
group1.outputs[1].connectTo(group2.inputs["KSampler negative"]);
group1.outputs[0].connectTo(group2.inputs["positive"]);
group1.outputs[1].connectTo(group2.inputs["negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
@ -370,7 +366,7 @@ describe("group node", () => {
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [
let group = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
@ -380,15 +376,16 @@ describe("group node", () => {
]);
const { api } = require("../../web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:3` }));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:3`,
node: `${nodes.save.id}`,
output: {
images: [
{
@ -410,6 +407,46 @@ describe("group node", () => {
type: "output",
},
]);
// Reload
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
await app.loadGraphData(JSON.parse(workflow));
group = graph.find(group);
// Trigger inner nodes to get created
group.node["getInnerNodes"]();
// Check it works for internal node ids
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
output: {
images: [
{
filename: "test2.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test2.png",
type: "output",
},
]);
});
test("allows widgets to be converted to inputs", async () => {
const { ez, graph, app } = await start();
@ -418,7 +455,7 @@ describe("group node", () => {
group.widgets[0].convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs["CLIPTextEncode text"]);
primitive.outputs[0].connectTo(group.inputs["text"]);
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
@ -440,9 +477,9 @@ describe("group node", () => {
nodes.save,
]);
group1.widgets["CLIPTextEncode text"].value = "hello";
group1.widgets["EmptyLatentImage width"].value = 256;
group1.widgets["KSampler seed"].value = 1;
group1.widgets["text"].value = "hello";
group1.widgets["width"].value = 256;
group1.widgets["seed"].value = 1;
// Clone the node
group1.menu.Clone.call();
@ -452,14 +489,14 @@ describe("group node", () => {
expect(group2.id).not.toEqual(group1.id);
// Reconnect ckpt
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["KSampler model"]);
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode 2 clip"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["VAEDecode vae"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
group2.widgets["CLIPTextEncode text"].value = "world";
group2.widgets["EmptyLatentImage width"].value = 1024;
group2.widgets["KSampler seed"].value = 100;
group2.widgets["text"].value = "world";
group2.widgets["width"].value = 1024;
group2.widgets["seed"].value = 100;
let i = 0;
expect((await graph.toPrompt()).output).toEqual({
@ -567,9 +604,9 @@ describe("group node", () => {
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// These will both be the same due to the primitive
expect(group.widgets["Positive text"].value).toBe("positive");
expect(group.widgets["Negative text"].value).toBe("positive");
// This will use a primitive widget named 'value'
expect(group.widgets.length).toBe(1);
expect(group.widgets["value"].value).toBe("positive");
const newNodes = group.menu["Convert to nodes"].call();
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
@ -599,14 +636,14 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
const widgets = group.widgets.map((w) => w.widget.name);
expect(widgets).toStrictEqual([
"EmptyLatentImage width",
"EmptyLatentImage height",
"EmptyLatentImage batch_size",
"LatentUpscale upscale_method",
"width",
"height",
"batch_size",
"upscale_method",
"LatentUpscale width",
"LatentUpscale height",
"LatentUpscale crop",
"SaveImage filename_prefix",
"crop",
"filename_prefix",
]);
});
test("adds output for external links when converting to group", async () => {
@ -653,11 +690,11 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
expect(group.outputs.length).toBe(3);
expect(group.outputs[0].output.name).toBe("VAELoader VAE");
expect(group.outputs[0].output.name).toBe("VAE");
expect(group.outputs[0].output.type).toBe("VAE");
expect(group.outputs[1].output.name).toBe("VAEDecode IMAGE");
expect(group.outputs[1].output.name).toBe("IMAGE");
expect(group.outputs[1].output.type).toBe("IMAGE");
expect(group.outputs[2].output.name).toBe("VAEEncode LATENT");
expect(group.outputs[2].output.name).toBe("LATENT");
expect(group.outputs[2].output.type).toBe("LATENT");
expect(group.outputs[0].connections.length).toBe(1);
@ -686,7 +723,7 @@ describe("group node", () => {
const preview1 = ez.PreviewImage(img.outputs[0]);
const group = await convertToGroup(app, graph, "test", [img, preview1]);
const widget = group.widgets["LoadImage upload"];
const widget = group.widgets["upload"];
expect(widget).toBeTruthy();
expect(widget.widget.type).toBe("button");
});

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
import { app } from "../../scripts/app.js";
import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GROUP_DATA, IS_GROUP_NODE, registerGroupNodes } from "./groupNode.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Adds the ability to save and add multiple nodes as a template
// To save:
@ -320,13 +320,15 @@ app.registerExtension({
for (let i = 0; i < nodeIds.length; i++) {
const node = app.graph.getNodeById(nodeIds[i]);
const nodeData = node?.constructor.nodeData;
if (nodeData?.[IS_GROUP_NODE]) {
const groupData = nodeData[GROUP_DATA];
let groupData = GroupNodeHandler.getGroupData(node);
if (groupData) {
groupData = groupData.nodeData;
if (!data.groupNodes) {
data.groupNodes = {};
}
data.groupNodes[nodeData.name] = groupData;
data.nodes[i].type = "workflow/" + nodeData.name;
data.nodes[i].type = nodeData.name;
}
}
@ -346,7 +348,7 @@ app.registerExtension({
callback: () => {
clipboardAction(async () => {
const data = JSON.parse(t.data);
await registerGroupNodes(data.groupNodes, "workflow", t.name);
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
localStorage.setItem("litegrapheditor_clipboard", t.data);
app.canvas.pasteFromClipboard();
});

View File

@ -152,13 +152,13 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]["max"];
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
@ -166,7 +166,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]["min"];
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
@ -211,7 +211,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.();
const widget = recreateWidget?.call(this);
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
@ -570,12 +570,12 @@ app.registerExtension({
}
}
if ((widget.type === "number" && !inputData?.[1]?.control_after_generate) || widget.type === "combo") {
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidgets(this, widget, control_value);
addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
@ -657,7 +657,7 @@ app.registerExtension({
// Only allow connections where the configs match
const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid(output, config2, forceUpdate, this.#recreateWidget);
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
}
#removeWidgets() {

View File

@ -44,7 +44,7 @@ function getClipPath(node, element, elRect) {
}
function computeSize(size) {
if (this.widgets?.[0].last_y == null) return;
if (this.widgets?.[0]?.last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
@ -195,7 +195,6 @@ export function addDomClippingSetting() {
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
console.log("enableDomClipping", enableDomClipping);
enableDomClipping = !!value;
},
});

View File

@ -37,29 +37,59 @@ export function getWidgetType(inputData, inputName) {
}
}
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName) {
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
let name = inputData[1]?.control_after_generate;
if(typeof name !== "string") {
name = widgetName;
}
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
addFilterList: false,
});
controlAfterGenerateName: name
}, inputData);
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
if (!defaultValue) defaultValue = "randomize";
if (!options) options = {};
const getName = (defaultName, optionName) => {
let name = defaultName;
if (options[optionName]) {
name = options[optionName];
} else if (typeof inputData?.[1]?.[defaultName] === "string") {
name = inputData?.[1]?.[defaultName];
} else if (inputData?.[1]?.control_prefix) {
name = inputData?.[1]?.control_prefix + " " + name
}
return name;
}
const widgets = [];
const valueControl = node.addWidget("combo", widgetName ?? "control_after_generate", defaultValue, function (v) { }, {
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
});
const valueControl = node.addWidget(
"combo",
getName("control_after_generate", "controlAfterGenerateName"),
defaultValue,
function () {},
{
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
}
);
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
serialize: false, // Don't include this in prompt.
});
comboFilter = node.addWidget(
"string",
getName("control_filter_list", "controlFilterListName"),
"",
function () {},
{
serialize: false, // Don't include this in prompt.
}
);
widgets.push(comboFilter);
}
@ -148,7 +178,7 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando
function seedWidget(node, inputName, inputData, app, widgetName) {
const seed = createIntWidget(node, inputName, inputData, app, true);
const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName);
const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
seed.widget.linkedWidgets = [seedControl];
return seed;
@ -181,7 +211,7 @@ function addMultilineWidget(node, name, opts, app) {
const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
inputEl.placeholder = opts.placeholder || "";
inputEl.placeholder = opts.placeholder || name;
const widget = node.addDOMWidget(name, "customtext", inputEl, {
getValue() {
@ -272,7 +302,11 @@ export const ComfyWidgets = {
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
if (inputData[1]?.control_after_generate) {
addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
}
return res;
},
IMAGEUPLOAD(node, inputName, inputData, app) {
const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));