From 3fd30d8fb3020f5f94fb6b9034fb14ad40fdd7af Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 11 Dec 2023 19:07:08 +0000 Subject: [PATCH] basic support for basic rerouted converted inputs --- tests-ui/tests/groupNode.test.js | 63 ++++++++++++++++++++++++++++++++ web/extensions/core/groupNode.js | 57 +++++++++++++++++++++++++---- 2 files changed, 113 insertions(+), 7 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index f270f8347..f02f9c223 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -901,4 +901,67 @@ describe("group node", () => { expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_filter_list.value).toBe("/.+/"); }); + test("internal reroutes work with converted inputs and merge options", async () => { + const { ez, graph, app } = await start(); + const vae = ez.VAELoader(); + const latent = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE); + const scale = ez.ImageScale(decode.outputs.IMAGE); + ez.PreviewImage(scale.outputs.IMAGE); + + const r1 = ez.Reroute(); + const r2 = ez.Reroute(); + + latent.widgets.width.convertToInput(); + latent.widgets.height.convertToInput(); + latent.widgets.batch_size.convertToInput(); + + scale.widgets.width.convertToInput(); + scale.widgets.height.convertToInput(); + + r1.inputs[0].input.label = "hbw"; + r1.outputs[0].connectTo(latent.inputs.height); + r1.outputs[0].connectTo(latent.inputs.batch_size); + r1.outputs[0].connectTo(scale.inputs.width); + + r2.inputs[0].input.label = "wh"; + r2.outputs[0].connectTo(latent.inputs.width); + r2.outputs[0].connectTo(scale.inputs.height); + + const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]); + + expect(group.inputs[0].input.type).toBe("VAE"); + expect(group.inputs[1].input.type).toBe("INT"); + expect(group.inputs[2].input.type).toBe("INT"); + + const p1 = ez.PrimitiveNode(); + const p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(group.inputs[1]); + p2.outputs[0].connectTo(group.inputs[2]); + + expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max + expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + p1.widgets.value.value = 16; + p2.widgets.value.value = 32; + + await checkBeforeAndAfterReload(graph, async (r) => { + const id = (v) => (r ? `${group.id}:` : "") + v; + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" }, + [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" }, + [id(4)]: { + inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] }, + class_type: "ImageScale", + }, + 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" }, + }); + }); + }); }); diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index e93c0cb2d..be171a7a5 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -197,7 +197,10 @@ export class GroupNodeConfig { if (!this.linksFrom[sourceNodeId]) { this.linksFrom[sourceNodeId] = {}; } - this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) { + this.linksFrom[sourceNodeId][sourceNodeSlot] = []; + } + this.linksFrom[sourceNodeId][sourceNodeSlot].push(l); if (!this.linksTo[targetNodeId]) { this.linksTo[targetNodeId] = {}; @@ -235,11 +238,11 @@ export class GroupNodeConfig { // Skip as its not linked if (!linksFrom) return; - let type = linksFrom["0"][5]; + let type = linksFrom["0"][0][5]; if (type === "COMBO") { // Use the array items const source = node.outputs[0].widget.name; - const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type; const fromType = globalDefs[fromTypeName]; const input = fromType.input.required[source] ?? fromType.input.optional[source]; type = input[0]; @@ -263,10 +266,33 @@ export class GroupNodeConfig { return null; } + let config = {}; let rerouteType = "*"; if (linksFrom) { - const [, , id, slot] = linksFrom["0"]; - rerouteType = this.nodeData.nodes[id].inputs[slot].type; + for (const [, , id, slot] of linksFrom["0"]) { + const node = this.nodeData.nodes[id]; + const input = node.inputs[slot]; + if (rerouteType === "*") { + rerouteType = input.type; + } + if (input.widget) { + const targetDef = globalDefs[node.type]; + const targetWidget = + targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + + const widget = [targetWidget[0], config]; + const res = mergeIfValid( + { + widget, + }, + targetWidget, + false, + null, + widget + ); + config = res?.customConfig ?? config; + } + } } else if (linksTo) { const [id, slot] = linksTo["0"]; rerouteType = this.nodeData.nodes[id].outputs[slot].type; @@ -287,10 +313,11 @@ export class GroupNodeConfig { } } + config.forceInput = true; return { input: { required: { - [rerouteType]: [rerouteType, {}], + [rerouteType]: [rerouteType, config], }, }, output: [rerouteType], @@ -690,6 +717,8 @@ export class GroupNodeHandler { top = newNode.pos[1]; } + if (!newNode.widgets) continue; + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; if (map) { const widgets = Object.keys(map); @@ -746,7 +775,7 @@ export class GroupNodeHandler { } }; - const reconnectOutputs = () => { + const reconnectOutputs = (selectedIds) => { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { const output = node.outputs[groupOutputId]; if (!output.links) continue; @@ -895,6 +924,18 @@ export class GroupNodeHandler { } } continue; + } else if (innerNode.type === "Reroute") { + const rerouteLinks = this.groupData.linksFrom[old.node.index]; + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } + } + } } const widget = innerNode.widgets?.find((w) => w.name === old.inputName); @@ -926,6 +967,8 @@ export class GroupNodeHandler { } populateWidgets() { + if (!this.node.widgets) return; + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { const node = this.groupData.nodeData.nodes[nodeId];