Group node updates

- persist internal ids in current session
- copy widget values when converting to nodes
- fix issue serializing converted inputs
This commit is contained in:
pythongosssss 2023-11-08 20:53:16 +00:00
parent c7eea7cb8e
commit 2870b1c68c
4 changed files with 180 additions and 141 deletions

View File

@ -48,10 +48,16 @@ describe("group node", () => {
}
/**
* @param { Record<string, string> } idMap
* @param { Record<string, string> | number[] } idMap
* @param { Record<string, Record<string, unknown>> } valueMap
*/
function getOutput(idMap = {}, valueMap = {}) {
if (idMap instanceof Array) {
idMap = idMap.reduce((p, n) => {
p[n] = n + "";
return p;
}, {});
}
const expected = {
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
@ -153,26 +159,41 @@ describe("group node", () => {
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
});
test("can be be converted back to nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["CLIPTextEncode text"].value).toBe("positive");
group.widgets["CLIPTextEncode text"].value = "pos";
expect(group.widgets["CLIPTextEncode 2 text"].value).toBe("negative");
group.widgets["CLIPTextEncode 2 text"].value = "neg";
expect(group.widgets["EmptyLatentImage width"].value).toBe(512);
group.widgets["EmptyLatentImage width"].value = 1024;
expect(group.widgets["KSampler sampler_name"].value).toBe("euler");
group.widgets["KSampler sampler_name"].value = "ddim";
expect(group.widgets["KSampler control_after_generate"].value).toBe("randomize");
group.widgets["KSampler control_after_generate"].value = "fixed";
/** @type { Array<any> } */
const newNodes = group.menu["Convert to nodes"].call();
group.menu["Convert to nodes"].call();
const pos = graph.find(
newNodes.find(
(n) => n.type === "CLIPTextEncode" && n.widgets.find((w) => w.name === "text")?.value === "positive"
)
);
const neg = graph.find(
newNodes.find(
(n) => n.type === "CLIPTextEncode" && n.widgets.find((w) => w.name === "text")?.value === "negative"
)
);
const empty = graph.find(newNodes.find((n) => n.type === "EmptyLatentImage"));
// ensure widget values are set
const pos = graph.find(nodes.pos.id);
expect(pos.node.type).toBe("CLIPTextEncode");
expect(pos.widgets["text"].value).toBe("pos");
const neg = graph.find(nodes.neg.id);
expect(neg.node.type).toBe("CLIPTextEncode");
expect(neg.widgets["text"].value).toBe("neg");
const empty = graph.find(nodes.empty.id);
expect(empty.node.type).toBe("EmptyLatentImage");
expect(empty.widgets["width"].value).toBe(1024);
const sampler = graph.find(nodes.sampler.id);
expect(sampler.node.type).toBe("KSampler");
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
// validate links
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
@ -207,13 +228,7 @@ describe("group node", () => {
expect(group.inputs).toHaveLength(1);
expect(group.inputs[0].input.type).toEqual("CLIP");
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
[nodes.empty.id]: `${group.id}:2`,
})
);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id]));
});
test("it can embed reroutes as outputs", async () => {
const { ez, graph, app } = await start();
@ -227,12 +242,7 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
expect(group.outputs).toHaveLength(1);
expect(group.outputs[0].output.type).toEqual("IMAGE");
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.decode.id]: `${group.id}:0`,
[nodes.save.id]: `${group.id}:1`,
})
);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
});
test("it can embed reroutes as pipes", async () => {
const { ez, graph, app } = await start();
@ -297,32 +307,22 @@ describe("group node", () => {
expect(group.widgets["KSampler denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput(
{
[nodes.ckpt.id]: `${group.id}:0`,
[nodes.pos.id]: `${group.id}:1`,
[nodes.neg.id]: `${group.id}:2`,
[nodes.empty.id]: `${group.id}:3`,
[nodes.sampler.id]: `${group.id}:4`,
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
{
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
}
)
})
);
});
test("group inputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
@ -334,12 +334,7 @@ describe("group node", () => {
reroute.outputs[0].connectTo(group.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[1]);
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("group outputs can be reroutes", async () => {
const { ez, graph, app } = await start();
@ -354,12 +349,7 @@ describe("group node", () => {
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("groups can connect to each other", async () => {
const { ez, graph, app } = await start();
@ -371,12 +361,7 @@ describe("group node", () => {
group1.outputs[1].connectTo(group2.inputs["KSampler negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group1.id}:0`,
[nodes.neg.id]: `${group1.id}:1`,
[nodes.empty.id]: `${group2.id}:0`,
[nodes.sampler.id]: `${group2.id}:1`,
})
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("displays generated image on group node", async () => {
@ -434,15 +419,9 @@ describe("group node", () => {
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
getOutput(
{
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
},
{
[nodes.pos.id]: { text: "hello" },
}
)
getOutput([nodes.pos.id, nodes.neg.id], {
[nodes.pos.id]: { text: "hello" },
})
);
});
test("can be copied", async () => {
@ -480,21 +459,11 @@ describe("group node", () => {
group2.widgets["KSampler seed"].value = 100;
expect((await graph.toPrompt()).output).toEqual({
...getOutput(
{
[nodes.pos.id]: `${group1.id}:0`,
[nodes.neg.id]: `${group1.id}:1`,
[nodes.empty.id]: `${group1.id}:2`,
[nodes.sampler.id]: `${group1.id}:3`,
[nodes.decode.id]: `${group1.id}:4`,
[nodes.save.id]: `${group1.id}:5`,
},
{
[nodes.pos.id]: { text: "hello" },
[nodes.empty.id]: { width: 256 },
[nodes.sampler.id]: { seed: 1 },
}
),
...getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
[nodes.pos.id]: { text: "hello" },
[nodes.empty.id]: { width: 256 },
[nodes.sampler.id]: { seed: 1 },
}),
...getOutput(
{
[nodes.pos.id]: `${group2.id}:0`,
@ -527,7 +496,7 @@ describe("group node", () => {
// Ensure the node isnt registered
expect(() => ez["workflow/test"]).toThrow();
// Relaod the workflow
// Reload the workflow
await app.loadGraphData(JSON.parse(workflow));
// Ensure the node is found
@ -594,6 +563,7 @@ describe("group node", () => {
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// These will both be the same due to the primitive
expect(group.widgets["Positive text"].value).toBe("positive");
expect(group.widgets["Negative text"].value).toBe("positive");
@ -605,5 +575,10 @@ describe("group node", () => {
expect(pos.inputs).toHaveLength(2);
expect(neg.inputs).toHaveLength(2);
expect(primitive.outputs[0].connections).toHaveLength(2);
expect((await graph.toPrompt()).output).toEqual({
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
});

View File

@ -33,6 +33,7 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
const events = new EventTarget();
const mockApi = {
addEventListener: events.addEventListener.bind(events),
removeEventListener: events.removeEventListener.bind(events),
dispatchEvent: events.dispatchEvent.bind(events),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),

View File

@ -5,6 +5,7 @@ import { getWidgetType } from "../../scripts/widgets.js";
export const IS_GROUP_NODE = Symbol();
export const GROUP_DATA = Symbol();
const GROUP_SLOTS = Symbol();
const GROUP_IDS = Symbol();
export async function registerGroupNodes(groupNodes, source, prefix, missingNodeTypes) {
if (!groupNodes) return;
@ -297,6 +298,7 @@ class ConvertToGroupAction {
let left;
let index = 0;
const slots = def[GROUP_SLOTS];
newNode[GROUP_IDS] = {};
for (const id in app.canvas.selected_nodes) {
const node = app.graph.getNodeById(id);
if (left == null || node.pos[0] < left) {
@ -308,6 +310,8 @@ class ConvertToGroupAction {
this.linkOutputs(newNode, node, slots, index++);
// Store the original ID so the node is reused in this session
newNode[GROUP_IDS][node._relative_id] = id;
app.graph.remove(node);
}
@ -353,39 +357,6 @@ const ext = {
return options;
};
api.addEventListener("executing", ({ detail }) => {
if (detail) {
const node = app.graph.getNodeById(detail);
if (!node) {
const split = detail.split(":");
if (split.length === 2) {
const outerNode = app.graph.getNodeById(+split[0]);
if (outerNode?.constructor.nodeData?.[IS_GROUP_NODE]) {
outerNode.runningInternalNodeId = +split[1];
api.dispatchEvent(new CustomEvent("executing", { detail: split[0] }));
}
}
}
}
});
api.addEventListener("executed", ({ detail }) => {
const node = app.graph.getNodeById(detail.node);
if (!node) {
const split = detail.node.split(":");
if (split.length === 2) {
const outerNode = app.graph.getNodeById(+split[0]);
if (outerNode?.constructor.nodeData?.[IS_GROUP_NODE]) {
outerNode.runningInternalNodeId = null;
api.dispatchEvent(
new CustomEvent("executed", { detail: { ...detail, node: split[0], merge: !outerNode.resetExecution } })
);
outerNode.resetExecution = false;
}
}
}
});
// Attach handlers after everything is registered to ensure all nodes are found
for (const k in LiteGraph.registered_node_types) {
const nodeType = LiteGraph.registered_node_types[k];
@ -409,7 +380,7 @@ const ext = {
}
},
async beforeConfigureGraph(graphData, missingNodeTypes) {
registerGroupNodes(graphData?.extra?.groupNodes, "workflow", undefined, missingNodeTypes);
await registerGroupNodes(graphData?.extra?.groupNodes, "workflow", undefined, missingNodeTypes);
},
addCustomNodeDefs(defs) {
globalDefs = defs;
@ -469,6 +440,7 @@ const ext = {
return r;
};
let executing, executed;
const onNodeCreated = node.onNodeCreated;
node.onNodeCreated = function () {
for (let innerNodeId = 0; innerNodeId < config.nodes.length; innerNodeId++) {
@ -499,20 +471,80 @@ const ext = {
}
}
function handleEvent(type, getId, getEvent) {
const handler = ({ detail }) => {
const id = getId(detail);
if (!id) return;
const node = app.graph.getNodeById(id);
if (node) return;
const split = id.split(":");
let groupNode;
let runningId;
if (split.length === 2) {
const outerNode = app.graph.getNodeById(+split[0]);
if (outerNode?.constructor.nodeData?.[IS_GROUP_NODE]) {
groupNode = outerNode;
runningId = split[1];
}
} else if (this[GROUP_IDS]) {
// Check if this is an internal node using its original ID
const isInternal = Object.values(this[GROUP_IDS]).indexOf(id) > -1;
if (isInternal) {
groupNode = this;
runningId = id;
}
}
if (groupNode) {
groupNode.runningInternalNodeId = +runningId;
api.dispatchEvent(new CustomEvent(type, { detail: getEvent(detail, groupNode.id + "", groupNode) }));
}
};
api.addEventListener(type, handler);
return handler;
}
executed = handleEvent.call(
this,
"executing",
(d) => d,
(d, id, node) => id
);
executed = handleEvent.call(
this,
"executed",
(d) => d?.node,
(d, id, node) => ({ ...d, node: id, merge: !node.resetExecution })
);
return onNodeCreated?.apply(this, arguments);
};
const onRemoved = node.onRemoved;
node.onRemoved = function () {
onRemoved?.apply(this, arguments);
api.removeEventListener("executing", executing);
api.removeEventListener("executed", executed);
};
const getExtraMenuOptions = node.getExtraMenuOptions ?? node.prototype.getExtraMenuOptions;
node.getExtraMenuOptions = function (_, options) {
let i = options.findIndex((o) => o.content === "Outputs");
if (i === -1) i = options.length;
else i++;
let optionIndex = options.findIndex((o) => o.content === "Outputs");
if (optionIndex === -1) optionIndex = options.length;
else optionIndex++;
options.splice(i, 0, null, {
options.splice(optionIndex, 0, null, {
content: "Convert to nodes",
callback: () => {
const backup = localStorage.getItem("litegrapheditor_clipboard");
localStorage.setItem("litegrapheditor_clipboard", JSON.stringify(config));
let c = config;
if (node[GROUP_IDS]) {
for (let i = 0; i < c.nodes.length; i++) {
c.nodes[i].id = +node[GROUP_IDS][i + ""];
}
}
localStorage.setItem("litegrapheditor_clipboard", JSON.stringify(c));
app.canvas.pasteFromClipboard();
localStorage.setItem("litegrapheditor_clipboard", backup);
@ -520,9 +552,11 @@ const ext = {
const [x, y] = this.pos;
let top;
let left;
const slots = def[GROUP_SLOTS];
const selectedIds = Object.keys(app.canvas.selected_nodes);
const newNodes = [];
for (const id of selectedIds) {
for (let i = 0; i < selectedIds.length; i++) {
const id = selectedIds[i];
const newNode = app.graph.getNodeById(id);
newNodes.push(newNode);
if (left == null || newNode.pos[0] < left) {
@ -531,6 +565,26 @@ const ext = {
if (top == null || newNode.pos[1] < top) {
top = newNode.pos[1];
}
// Copy values
for (const innerWidget of newNode.widgets ?? []) {
const groupWidgetName = slots.widgets[i]?.[innerWidget.name];
if (!groupWidgetName) continue;
const groupWidget = node.widgets.find((w) => w.name === groupWidgetName);
if (groupWidget) {
innerWidget.value = groupWidget.value;
// Copy linked widget values (control_after_generate)
if (groupWidget.linkedWidgets && innerWidget.linkedWidgets) {
for (let linkIndex = 0; linkIndex < groupWidget.linkedWidgets.length; linkIndex++) {
const w = innerWidget.linkedWidgets[linkIndex];
if (w) {
w.value = groupWidget.linkedWidgets[linkIndex].value;
}
}
}
}
}
}
// Shift each node
@ -540,7 +594,6 @@ const ext = {
}
// Reconnect inputs
const slots = def[GROUP_SLOTS];
for (const nodeIndex in slots.inputs) {
const id = selectedIds[nodeIndex];
const newNode = app.graph.getNodeById(id);
@ -594,12 +647,22 @@ const ext = {
return link;
};
function getInnerNodeId(i) {
// Use the node id from the instance if available
return node[GROUP_IDS]?.[i + ""] ?? `${node.id}:${i}`;
}
node.getInnerNodes = function () {
const links = getLinks(config);
const innerNodes = config.nodes.map((n, i) => {
const innerNode = LiteGraph.createNode(n.type);
innerNode.configure(n);
const config = { ...n };
// Remove any converted widget inputs
if (config.inputs) {
config.inputs = config.inputs.filter((c) => !c.widget);
}
const innerNode = LiteGraph.createNode(config.type);
innerNode.configure(config);
for (const innerWidget of innerNode.widgets ?? []) {
const groupWidgetName = slots.widgets[i]?.[innerWidget.name];
@ -610,7 +673,7 @@ const ext = {
}
}
innerNode.id = node.id + ":" + i;
innerNode.id = getInnerNodeId(i);
innerNode.getInputNode = function (slot) {
if (!innerNode.comfyClass) slot = 0;
const outerSlot = slots.inputs?.[i]?.[slot];
@ -648,9 +711,9 @@ const ext = {
if (!link) return null;
// Use the inner link, but update the origin node to be inner node id
link = {
origin_id: node.id + ":" + link[0],
origin_id: getInnerNodeId(link[0]),
origin_slot: link[1],
target_id: node.id + ":" + i,
target_id: getInnerNodeId(i),
target_slot: slot,
};

View File

@ -1271,6 +1271,7 @@ export class ComfyApp {
this.#addProcessMouseHandler();
this.#addProcessKeyHandler();
this.#addConfigureHandler();
this.#addApiUpdateHandlers();
this.graph = new LGraph();
@ -1324,7 +1325,6 @@ export class ComfyApp {
this.#addDrawNodeHandler();
this.#addDrawGroupsHandler();
this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addCopyHandler();
this.#addPasteHandler();