From 5edf59fb1fbbedb4a58fb7ebde4062d7edb9a0f3 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 6 May 2023 20:09:32 +0200 Subject: [PATCH] add repeat latent batch --- comfy_extras/nodes_rebatch.py | 2 +- nodes.py | 43 ++++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 14d948ec0..0a9daf272 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -12,7 +12,7 @@ class LatentRebatch: FUNCTION = "rebatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" @staticmethod def get_batch(latents, list_ind, offset): diff --git a/nodes.py b/nodes.py index d7f3e8f0c..1fab809b8 100644 --- a/nodes.py +++ b/nodes.py @@ -632,23 +632,51 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def frombatch(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() if "noise_mask" in samples: - s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + 1].clone() + s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + length].clone() if "batch_index" not in s: - s["batch_index"] = [batch_index] + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] else: - s["batch_index"] = samples["batch_index"][batch_index:batch_index + 1] + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -1176,6 +1204,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1250,6 +1279,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image",