From 50763d0f3975ff40a0de752a9cee40099f6cba0b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 29 Nov 2023 21:28:19 +0000 Subject: [PATCH] Fix control widget values --- tests-ui/tests/groupNode.test.js | 36 ++++++++++++++++++++++++++- web/extensions/core/groupNode.js | 42 ++++++++++++++++++++------------ web/scripts/widgets.js | 2 +- 3 files changed, 62 insertions(+), 18 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index 321f04a63..d60ff8de4 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -747,7 +747,7 @@ describe("group node", () => { expect(widget).toBeTruthy(); expect(widget.widget.type).toBe("button"); }); - test.only("internal primitive populates widgets for all linked inputs", async () => { + test("internal primitive populates widgets for all linked inputs", async () => { const { ez, graph, app } = await start(); const img = ez.LoadImage(); const scale1 = ez.ImageScale(img.outputs[0]); @@ -781,4 +781,38 @@ describe("group node", () => { 5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" }, }); }); + test("primitive control widgets values are copied on convert", async () => { + const { ez, graph, app } = await start(); + const sampler = ez.KSampler(); + sampler.widgets.seed.convertToInput(); + sampler.widgets.sampler_name.convertToInput(); + + let p1 = ez.PrimitiveNode(); + let p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(sampler.inputs.seed); + p2.outputs[0].connectTo(sampler.inputs.sampler_name); + + p1.widgets.control_after_generate.value = "increment"; + p2.widgets.control_after_generate.value = "decrement"; + p2.widgets.control_filter_list.value = "/.*/"; + + p2.node.title = "p2"; + + const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]); + expect(group.widgets.control_after_generate.value).toBe("increment"); + expect(group.widgets["p2 control_after_generate"].value).toBe("decrement"); + expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/"); + + group.widgets.control_after_generate.value = "fixed"; + group.widgets["p2 control_after_generate"].value = "randomize"; + group.widgets["p2 control_filter_list"].value = "/.+/"; + + group.menu["Convert to nodes"].call(); + p1 = graph.find(p1); + p2 = graph.find(p2); + + expect(p1.widgets.control_after_generate.value).toBe("fixed"); + expect(p2.widgets.control_after_generate.value).toBe("randomize"); + expect(p2.widgets.control_filter_list.value).toBe("/.+/"); + }); }); diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 4661e23ec..450b4f5f3 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -639,22 +639,23 @@ export class GroupNodeHandler { top = newNode.pos[1]; } - if (innerNode.type === "PrimitiveNode" && innerNode.primitiveValue != null) { - newNode.widgets[0].value = innerNode.primitiveValue; - newNode.widgets[0].callback?.(newNode.widgets[0].value); - } else { - const map = this.groupData.oldToNewWidgetMap[innerNode.index]; - if (map) { - const widgets = Object.keys(map); + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; + if (map) { + const widgets = Object.keys(map); - for (const oldName of widgets) { - const newName = map[oldName]; - if (!newName) continue; + for (const oldName of widgets) { + const newName = map[oldName]; + if (!newName) continue; - const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); - if (widgetIndex === -1) continue; + const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); + if (widgetIndex === -1) continue; - // Populate the main and any linked widget + // Populate the main and any linked widgets + if (innerNode.type === "PrimitiveNode") { + for (let i = 0; i < newNode.widgets.length; i++) { + newNode.widgets[i].value = this.node.widgets[widgetIndex + i].value; + } + } else { const outerWidget = this.node.widgets[widgetIndex]; const newWidget = newNode.widgets.find((w) => w.name === oldName); if (!newWidget) continue; @@ -857,9 +858,18 @@ export class GroupNodeHandler { const primitiveId = this.groupData.widgetToPrimitive[nodeId]?.[oldName]; if (primitiveId == null) return; const targetWidgetName = this.groupData.oldToNewWidgetMap[primitiveId]["value"]; - const targetWidget = this.node.widgets.find((w) => w.name === targetWidgetName); - if (targetWidget) { - targetWidget.value = node.widgets_values[i + linkedShift]; + const targetWidgetIndex = this.node.widgets.findIndex((w) => w.name === targetWidgetName); + if (targetWidgetIndex > -1) { + const primitiveNode = this.innerNodes[primitiveId]; + let len = primitiveNode.widgets.length; + if (len - 1 !== this.node.widgets[targetWidgetIndex].linkedWidgets?.length) { + // Fallback handling for if some reason the primitive has a different number of widgets + // we dont want to overwrite random widgets, better to leave blank + len = 1; + } + for (let i = 0; i < len; i++) { + this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; + } } } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index d73e4631d..de5877e54 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -304,7 +304,7 @@ export const ComfyWidgets = { } const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; if (inputData[1]?.control_after_generate) { - addValueControlWidgets(node, res.widget, undefined, undefined, inputData); + res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData); } return res; },