Fix initial primitive values on connect

This commit is contained in:
pythongosssss 2023-12-11 17:57:34 +00:00
parent f17249cb51
commit 1361adb29a
2 changed files with 30 additions and 23 deletions

View File

@ -1,7 +1,7 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow, getNodeDef } = require("../utils");
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
@ -682,12 +682,13 @@ describe("group node", () => {
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
@ -705,22 +706,26 @@ describe("group node", () => {
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[0]);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1.id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2.id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1.id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2.id}`, 0] }, class_type: "PreviewImage" },
await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => {

View File

@ -426,6 +426,12 @@ export class GroupNodeConfig {
});
this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
if (!this.oldToNewWidgetMap[node.index]) {
this.oldToNewWidgetMap[node.index] = {};
}
this.oldToNewWidgetMap[node.index][inputName] = name;
inputMap[slots.length + i] = this.inputCount++;
}
}
@ -916,6 +922,7 @@ export class GroupNodeHandler {
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
}
}
return true;
}
populateWidgets() {
@ -933,16 +940,11 @@ export class GroupNodeHandler {
const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex];
if (!newName) {
// New name will be null if its a converted widget
this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) {
// Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
linkedShift += innerWidget.linkedWidgets?.length ?? 0;
continue;
}
if (widgetIndex === -1) {
continue;
}