From e19cb8368f299776e589b8ae6455654d3cb7ad61 Mon Sep 17 00:00:00 2001 From: flyingshutter Date: Thu, 6 Apr 2023 01:04:21 +0200 Subject: [PATCH] Add BatchIndex node Add "INT:batch_index" subtype Add onInitBatch() callback for widgets --- nodes.py | 17 +++++++++++++++++ web/scripts/app.js | 18 +++++++++++++++++- web/scripts/widgets.js | 20 ++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 187d54a11..82296cacd 100644 --- a/nodes.py +++ b/nodes.py @@ -1034,6 +1034,22 @@ class ImagePadForOutpaint: return (new_image, mask) +class BatchIndex: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "batch_index": ("INT", {}), + }, + } + + RETURN_TYPES = ("INT", "FLOAT") + FUNCTION = "getBatchIndex" + CATEGORY = "utils" + + def getBatchIndex(self, batch_index): + return (batch_index, batch_index) + NODE_CLASS_MAPPINGS = { "KSampler": KSampler, @@ -1076,6 +1092,7 @@ NODE_CLASS_MAPPINGS = { "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "CheckpointLoader": CheckpointLoader, + "BatchIndex": BatchIndex, } def load_custom_node(module_path): diff --git a/web/scripts/app.js b/web/scripts/app.js index 5aef31c38..2dfba0bed 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -946,7 +946,23 @@ class ComfyApp { ({ number, batchCount } = this.#queueItems.pop()); for (let i = 0; i < batchCount; i++) { - const p = await this.graphToPrompt(); + let p = await this.graphToPrompt(); + + if (i == 0) { + for (const n of p.workflow.nodes) { + const node = graph.getNodeById(n.id); + if (node.widgets) { + for (const widget of node.widgets) { + // Allow widgets to run callbacks on firts iteration of a batch + // e.g. random seed after every gen + if (widget.onInitBatch) { + widget.onInitBatch(); + } + } + } + } + p = await this.graphToPrompt(); + } try { await api.queuePrompt(number, p); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 5f5043cd0..4f0d7de9d 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -43,6 +43,25 @@ function seedWidget(node, inputName, inputData) { return { widget: seed, randomize }; } +export function batchIndexWidget(node, inputName, inputData) { + + const { val, config } = getNumberDefaults(inputData, 1); + Object.assign(config, { precision: 0 }); + const batchIndex = node.addWidget("number", inputName, val, () => {}, config) + batchIndex.forbidConvertToInput = true; + batchIndex.disabled = true; + + batchIndex.onInitBatch = () => { + batchIndex.value = 0; + }; + + batchIndex.afterQueued = () => { + batchIndex.value += 1; + }; + + return batchIndex; +} + const MultilineSymbol = Symbol(); const MultilineResizeSymbol = Symbol(); @@ -197,6 +216,7 @@ function addMultilineWidget(node, name, opts, app) { } export const ComfyWidgets = { + "INT:batch_index": batchIndexWidget, "INT:seed": seedWidget, "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData) {