Add BatchIndex node

Add "INT:batch_index" subtype
Add onInitBatch() callback for widgets
This commit is contained in:
flyingshutter 2023-04-06 01:04:21 +02:00
parent d5cce8345f
commit e19cb8368f
3 changed files with 54 additions and 1 deletions

View File

@ -1034,6 +1034,22 @@ class ImagePadForOutpaint:
return (new_image, mask) return (new_image, mask)
class BatchIndex:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"batch_index": ("INT", {}),
},
}
RETURN_TYPES = ("INT", "FLOAT")
FUNCTION = "getBatchIndex"
CATEGORY = "utils"
def getBatchIndex(self, batch_index):
return (batch_index, batch_index)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"KSampler": KSampler, "KSampler": KSampler,
@ -1076,6 +1092,7 @@ NODE_CLASS_MAPPINGS = {
"TomePatchModel": TomePatchModel, "TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"BatchIndex": BatchIndex,
} }
def load_custom_node(module_path): def load_custom_node(module_path):

View File

@ -946,7 +946,23 @@ class ComfyApp {
({ number, batchCount } = this.#queueItems.pop()); ({ number, batchCount } = this.#queueItems.pop());
for (let i = 0; i < batchCount; i++) { for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt(); let p = await this.graphToPrompt();
if (i == 0) {
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks on firts iteration of a batch
// e.g. random seed after every gen
if (widget.onInitBatch) {
widget.onInitBatch();
}
}
}
}
p = await this.graphToPrompt();
}
try { try {
await api.queuePrompt(number, p); await api.queuePrompt(number, p);

View File

@ -43,6 +43,25 @@ function seedWidget(node, inputName, inputData) {
return { widget: seed, randomize }; return { widget: seed, randomize };
} }
export function batchIndexWidget(node, inputName, inputData) {
const { val, config } = getNumberDefaults(inputData, 1);
Object.assign(config, { precision: 0 });
const batchIndex = node.addWidget("number", inputName, val, () => {}, config)
batchIndex.forbidConvertToInput = true;
batchIndex.disabled = true;
batchIndex.onInitBatch = () => {
batchIndex.value = 0;
};
batchIndex.afterQueued = () => {
batchIndex.value += 1;
};
return batchIndex;
}
const MultilineSymbol = Symbol(); const MultilineSymbol = Symbol();
const MultilineResizeSymbol = Symbol(); const MultilineResizeSymbol = Symbol();
@ -197,6 +216,7 @@ function addMultilineWidget(node, name, opts, app) {
} }
export const ComfyWidgets = { export const ComfyWidgets = {
"INT:batch_index": batchIndexWidget,
"INT:seed": seedWidget, "INT:seed": seedWidget,
"INT:noise_seed": seedWidget, "INT:noise_seed": seedWidget,
FLOAT(node, inputName, inputData) { FLOAT(node, inputName, inputData) {