diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index b57cf7e09..5bd9b9787 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -28,27 +28,46 @@ test("converted widget works after reload", async () => { // Convert back to widget and ensure input is removed n.widgets.ckpt_name.convertToWidget(); expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy(); - expect(() => n.inputs.ckpt_name).toThrow(/Unknown input/); + expect(n.inputs.ckpt_name).toBeFalsy(); expect(n.inputs.length).toEqual(inputCount); // Convert again and reload the graph to ensure it maintains state n.widgets.ckpt_name.convertToInput(); expect(n.inputs.length).toEqual(inputCount + 1); - // TODO: connect primitive - await graph.reload(); - // TODO: ensure primitive connected, disconnect, reconnect + let { $: primitive } = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(n.inputs.ckpt_name); - // Find the reloaded node in the graph + await graph.reload(); + + // Find the reloaded nodes in the graph n = graph.find(n); + primitive = graph.find(primitive); + + // Ensure widget is converted expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy(); expect(n.inputs.ckpt_name).toBeTruthy(); expect(n.inputs.length).toEqual(inputCount + 1); + // Ensure primitive is connected + let { connections } = primitive.outputs[0]; + expect(connections).toHaveLength(1); + expect(connections[0].targetNode.id).toBe(n.node.id); + + // Disconnect & reconnect + connections[0].disconnect(); + ({ connections } = primitive.outputs[0]); + expect(connections).toHaveLength(0); + primitive.outputs[0].connectTo(n.inputs.ckpt_name); + + ({ connections } = primitive.outputs[0]); + expect(connections).toHaveLength(1); + expect(connections[0].targetNode.id).toBe(n.node.id); + // Convert back to widget and ensure input is removed n.widgets.ckpt_name.convertToWidget(); expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy(); - expect(() => n.inputs.ckpt_name).toThrow(/Unknown input/); + expect(n.inputs.ckpt_name).toBeFalsy(); expect(n.inputs.length).toEqual(inputCount); }); @@ -70,10 +89,14 @@ test("converted widget works on clone", async () => { expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy(); expect(clone.inputs.ckpt_name).toBeTruthy(); - // TODO: connect primitive to clone + // Ensure primitive connects to both nodes + let { $: primitive } = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(n.inputs.ckpt_name); + primitive.outputs[0].connectTo(clone.inputs.ckpt_name); + expect(primitive.outputs[0].connections).toHaveLength(2); // Convert back to widget and ensure input is removed clone.widgets.ckpt_name.convertToWidget(); expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy(); - expect(() => clone.inputs.ckpt_name).toThrow(/Unknown input/); + expect(clone.inputs.ckpt_name).toBeFalsy(); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 2e2a8ad31..a91602c60 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -1,40 +1,128 @@ // @ts-check /// -const NODE = Symbol(); - /** * @typedef { import("../../web/scripts/app")["app"] } app * @typedef { import("../../web/types/litegraph") } LG * @typedef { import("../../web/types/litegraph").IWidget } IWidget * @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem * @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot + * @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot * @typedef { InstanceType & { widgets?: Array } } LGNode - * @typedef { { [k in keyof typeof Ez["util"]]: typeof Ez["util"][k] extends (app: any, ...rest: infer A) => infer R ? (...args: A) => R : never } } EzUtils * @typedef { (...args: EzOutput[] | [...EzOutput[], Record]) => Array & { $: EzNode, node: LG["LGraphNode"]} } EzNodeFactory - * @typedef { ReturnType[0] } EzOutput */ -class EzInput { +class EzConnection { + /** @type { app } */ + app; + /** @type { InstanceType } */ + link; + + get originNode() { + return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id)); + } + + get originOutput() { + return this.originNode.outputs[this.link.origin_slot]; + } + + get targetNode() { + return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id)); + } + + get targetInput() { + return this.targetNode.inputs[this.link.target_slot]; + } + + /** + * @param { app } app + * @param { InstanceType } link + */ + constructor(app, link) { + this.app = app; + this.link = link; + } + + disconnect() { + this.targetInput.disconnect(); + } +} + +class EzSlot { /** @type { EzNode } */ node; - /** @type { INodeInputSlot } */ - input; /** @type { number } */ index; /** * @param { EzNode } node - * @param { INodeInputSlot } input * @param { number } index */ - constructor(node, input, index) { + constructor(node, index) { this.node = node; - this.input = input; this.index = index; } } +class EzInput extends EzSlot { + /** @type { INodeInputSlot } */ + input; + + /** + * @param { EzNode } node + * @param { number } index + * @param { INodeInputSlot } input + */ + constructor(node, index, input) { + super(node, index); + this.input = input; + } + + disconnect() { + this.node.node.disconnectInput(this.index); + } +} + +class EzOutput extends EzSlot { + /** @type { INodeOutputSlot } */ + output; + + /** + * @param { EzNode } node + * @param { number } index + * @param { INodeOutputSlot } output + */ + constructor(node, index, output) { + super(node, index); + this.output = output; + } + + get connections() { + return (this.node.node.outputs?.[this.index]?.links ?? []) + .map(l => new EzConnection(this.node.app, this.node.app.graph.links[l])); + } + + /** + * @param { EzInput } input + */ + connectTo(input) { + /** + * @type { LG["LLink"] | null } + */ + const link = this.node.node.connect(this.index, input.node.node, input.index); + if (!link) { + const inp = input.input; + const inName = inp.name || inp.label || inp.type; + throw new Error( + `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + this.output.name ?? this.output.type + }#${this.index}] failed.` + ); + } + return link; + } +} + class EzNodeMenuItem { /** @type { EzNode } */ node; @@ -105,12 +193,6 @@ class EzNode { app; /** @type { LGNode } */ node; - /** @type { { length: number } & Record } */ - inputs; - /** @type { Record } */ - widgets; - /** @type { Record } */ - menu; /** * @param { app } app @@ -119,72 +201,59 @@ class EzNode { constructor(app, node) { this.app = app; this.node = node; - - // @ts-ignore : this proxy returns the length - this.inputs = new Proxy( - {}, - { - get: (_, p) => { - if (typeof p !== "string") throw new Error(`Invalid widget name.`); - if (p === "length") return this.node.inputs?.length ?? 0; - const index = this.node.inputs.findIndex((i) => i.name === p); - if (index === -1) throw new Error(`Unknown input "${p}" on node "${this.node.type}".`); - return new EzInput(this, this.node.inputs[index], index); - }, - } - ); - - this.widgets = new Proxy( - {}, - { - get: (_, p) => { - if (typeof p !== "string") throw new Error(`Invalid widget name.`); - const widget = this.node.widgets?.find((w) => w.name === p); - if (!widget) throw new Error(`Unknown widget "${p}" on node "${this.node.type}".`); - - return new EzWidget(this, widget); - }, - } - ); - - this.menu = new Proxy( - {}, - { - get: (_, p) => { - if (typeof p !== "string") throw new Error(`Invalid menu item name.`); - const options = this.menuItems(); - const option = options.find((o) => o?.content === p); - if (!option) throw new Error(`Unknown menu item "${p}" on node "${this.node.type}".`); - - return new EzNodeMenuItem(this, option); - }, - } - ); } get id() { return this.node.id; } - menuItems() { - return this.app.canvas.getNodeMenuOptions(this.node); + get inputs() { + return this.#getSlotItems("inputs"); } - outputs() { - return ( - this.node.outputs?.map((data, index) => { - return { - [NODE]: this.node, - index, - data, - }; - }) ?? [] - ); + get outputs() { + return this.#getSlotItems("outputs"); + } + + /** @returns { Record } */ + get widgets() { + return (this.node.widgets ?? []).reduce((p, w, i) => { + p[w.name ?? i] = new EzWidget(this, w); + return p; + }, {}); + } + + get menu() { + const items = this.app.canvas.getNodeMenuOptions(this.node); + return items.reduce((p, w) => { + if(w?.content) { + p[w.content] = new EzNodeMenuItem(this, w); + } + return p; + }, {}); } select() { this.app.canvas.selectNode(this.node); } + + /** + * @template { "inputs" | "outputs" } T + * @param { T } type + * @returns { Record & (type extends "inputs" ? EzInput [] : EzOutput[]) } + */ + #getSlotItems(type) { + // @ts-ignore : these items are correct + return (this.node[type] ?? []).reduce((p, s, i) => { + if(s.name in p) { + throw new Error(`Unable to store input ${s.name} on array as name conflicts.`); + } + ; + // @ts-ignore + p.push(p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)); + return p; + }, Object.assign([], {$: this})) + } } class EzGraph { @@ -267,7 +336,7 @@ export const Ez = { * @param { boolean } clearGraph * @returns { { graph: EzGraph, ez: Record } } */ - graph(app, LiteGraph, LGraphCanvas, clearGraph = true) { + graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) { // Always set the active canvas so things work LGraphCanvas.active_canvas = app.canvas; @@ -290,13 +359,14 @@ export const Ez = { */ return function (...args) { const ezNode = new EzNode(app, node); + const outputs = ezNode.outputs; + const inputs = ezNode.inputs; - // console.log("Created " + node.type, "Populating:", args); let slot = 0; for (let i = 0; i < args.length; i++) { const arg = args[i]; - if (arg[NODE]) { - arg[NODE].connect(arg.index, node, slot++); + if (arg instanceof EzOutput) { + arg.connectTo(inputs[slot++]); } else { for (const k in arg) { ezNode.widgets[k].value = arg[k]; @@ -304,9 +374,6 @@ export const Ez = { } } - const outputs = ezNode.outputs(); - outputs["$"] = ezNode; - outputs["node"] = node; return outputs; }; },