add repeat latent batch

This commit is contained in:
BlenderNeko 2023-05-06 20:09:32 +02:00
parent 3859cbb98d
commit 5edf59fb1f
2 changed files with 38 additions and 7 deletions

View File

@ -12,7 +12,7 @@ class LatentRebatch:
FUNCTION = "rebatch" FUNCTION = "rebatch"
CATEGORY = "latent" CATEGORY = "latent/batch"
@staticmethod @staticmethod
def get_batch(latents, list_ind, offset): def get_batch(latents, list_ind, offset):

View File

@ -632,23 +632,51 @@ class LatentFromBatch:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "frombatch" FUNCTION = "frombatch"
CATEGORY = "latent" CATEGORY = "latent/batch"
def frombatch(self, samples, batch_index): def frombatch(self, samples, batch_index, length):
s = samples.copy() s = samples.copy()
s_in = samples["samples"] s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index) batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone() length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples: if "noise_mask" in samples:
s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + 1].clone() s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + length].clone()
if "batch_index" not in s: if "batch_index" not in s:
s["batch_index"] = [batch_index] s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else: else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + 1] s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,) return (s,)
class LatentUpscale: class LatentUpscale:
@ -1176,6 +1204,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage, "EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale, "LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch, "LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage, "SaveImage": SaveImage,
"PreviewImage": PreviewImage, "PreviewImage": PreviewImage,
"LoadImage": LoadImage, "LoadImage": LoadImage,
@ -1250,6 +1279,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentImage": "Empty Latent Image", "EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent", "LatentUpscale": "Upscale Latent",
"LatentComposite": "Latent Composite", "LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image # Image
"SaveImage": "Save Image", "SaveImage": "Save Image",
"PreviewImage": "Preview Image", "PreviewImage": "Preview Image",