Fix widget inputs being incorrect order and value

This commit is contained in:
pythongosssss 2023-12-11 17:03:37 +00:00
parent de230392ff
commit f17249cb51
4 changed files with 76 additions and 4 deletions

View File

@ -1,7 +1,7 @@
// @ts-check // @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" /> /// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils"); const { start, createDefaultWorkflow, getNodeDef } = require("../utils");
const lg = require("../utils/litegraph"); const lg = require("../utils/litegraph");
describe("group node", () => { describe("group node", () => {
@ -679,6 +679,50 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
}); });
}); });
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");
// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();
// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[0]);
primitive.widgets.value.value = upscaleMethods[1];
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1.id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2.id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1.id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2.id}`, 0] }, class_type: "PreviewImage" },
});
});
test("adds widgets in node execution order", async () => { test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start(); const { ez, graph, app } = await start();
const scale = ez.LatentUpscale(); const scale = ez.LatentUpscale();

View File

@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input; this.input = input;
} }
get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}
disconnect() { disconnect() {
this.node.node.disconnectInput(this.index); this.node.node.disconnectInput(this.index);
} }

View File

@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
return { ckpt, pos, neg, empty, sampler, decode, save }; return { ckpt, pos, neg, empty, sampler, decode, save };
} }
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}
export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}

View File

@ -174,6 +174,11 @@ export class GroupNodeConfig {
node.index = i; node.index = i;
this.processNode(node, seenInputs, seenOutputs); this.processNode(node, seenInputs, seenOutputs);
} }
for (const p of this.#convertedToProcess) {
p();
}
this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef); await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
} }
@ -420,10 +425,12 @@ export class GroupNodeConfig {
defaultInput: true, defaultInput: true,
}); });
this.nodeDef.input.required[name] = config; this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
inputMap[slots.length + i] = this.inputCount++; inputMap[slots.length + i] = this.inputCount++;
} }
} }
#convertedToProcess = [];
processNodeInputs(node, seenInputs, inputs) { processNodeInputs(node, seenInputs, inputs) {
const inputMapping = []; const inputMapping = [];
@ -434,7 +441,11 @@ export class GroupNodeConfig {
const linksTo = this.linksTo[node.index] ?? {}; const linksTo = this.linksTo[node.index] ?? {};
const inputMap = (this.oldToNewInputMap[node.index] = {}); const inputMap = (this.oldToNewInputMap[node.index] = {});
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() =>
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
return inputMapping; return inputMapping;
} }