Fix convert to nodes with internal reroutes

This commit is contained in:
pythongosssss 2023-11-29 17:43:42 +00:00
parent f02cb2d3df
commit e051b5b8d6
2 changed files with 41 additions and 20 deletions

View File

@ -267,6 +267,26 @@ describe("group node", () => {
group.outputs[1].connectTo(nodes.pos.inputs.clip); group.outputs[1].connectTo(nodes.pos.inputs.clip);
group.outputs[1].connectTo(nodes.neg.inputs.clip); group.outputs[1].connectTo(nodes.neg.inputs.clip);
}); });
test("can handle reroutes used internally", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
reroutes.push(reroute);
}
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("creates with widget values from inner nodes", async () => { test("creates with widget values from inner nodes", async () => {
const { ez, graph, app } = await start(); const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph); const nodes = createDefaultWorkflow(ez, graph);

View File

@ -1,6 +1,6 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { getWidgetType, addValueControlWidgets } from "../../scripts/widgets.js"; import { getWidgetType } from "../../scripts/widgets.js";
import { mergeIfValid } from "./widgetInputs.js"; import { mergeIfValid } from "./widgetInputs.js";
const GROUP = Symbol(); const GROUP = Symbol();
@ -107,7 +107,7 @@ class GroupNodeBuilder {
for (const l of output.links) { for (const l of output.links) {
const link = app.graph.links[l]; const link = app.graph.links[l];
if (!link) continue; if (!link) continue;
if(type === "*") type = link.type; if (type === "*") type = link.type;
if (!app.canvas.selected_nodes[link.target_id]) { if (!app.canvas.selected_nodes[link.target_id]) {
hasExternal = true; hasExternal = true;
@ -244,7 +244,7 @@ export class GroupNodeConfig {
const def = (this.primitiveDefs[node.index] = { const def = (this.primitiveDefs[node.index] = {
input: { input: {
required: { required: {
value: [type, { }], value: [type, {}],
}, },
}, },
output: [type], output: [type],
@ -274,10 +274,10 @@ export class GroupNodeConfig {
break; break;
} }
} }
if(rerouteType === "*") { if (rerouteType === "*") {
// Check for an external link // Check for an external link
const t = this.externalFrom[node.index]?.[0]; const t = this.externalFrom[node.index]?.[0];
if(t) { if (t) {
rerouteType = t; rerouteType = t;
} }
} }
@ -646,10 +646,10 @@ export class GroupNodeHandler {
newNode.widgets[0].callback?.(newNode.widgets[0].value); newNode.widgets[0].callback?.(newNode.widgets[0].value);
} else { } else {
const map = this.groupData.oldToNewWidgetMap[innerNode.index]; const map = this.groupData.oldToNewWidgetMap[innerNode.index];
if (map) {
const widgets = Object.keys(map); const widgets = Object.keys(map);
for (let i = 0; i < widgets.length; i++) { for (const oldName of widgets) {
const oldName = widgets[i];
const newName = map[oldName]; const newName = map[oldName];
if (!newName) continue; if (!newName) continue;
@ -668,6 +668,7 @@ export class GroupNodeHandler {
} }
} }
} }
}
// Shift each node // Shift each node
for (const newNode of newNodes) { for (const newNode of newNodes) {
@ -826,7 +827,7 @@ export class GroupNodeHandler {
updateInnerWidgets() { updateInnerWidgets() {
for (const newWidgetName in this.groupData.newToOldWidgetMap) { for (const newWidgetName in this.groupData.newToOldWidgetMap) {
const newWidget = this.node.widgets.find((w) => w.name === newWidgetName); const newWidget = this.node.widgets.find((w) => w.name === newWidgetName);
if(!newWidget) continue; if (!newWidget) continue;
const newValue = newWidget.value; const newValue = newWidget.value;
const old = this.groupData.newToOldWidgetMap[newWidgetName]; const old = this.groupData.newToOldWidgetMap[newWidgetName];