From 9d8ed7b28e825cdddf33c494d94c9f475b902d21 Mon Sep 17 00:00:00 2001
From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com>
Date: Sat, 6 May 2023 15:15:49 +0200
Subject: [PATCH] add batch index logic

---
 comfy/sample.py               | 17 +++++--
 comfy_extras/nodes_rebatch.py | 91 +++++++++++++++++++++--------------
 nodes.py                      | 15 ++++--
 3 files changed, 78 insertions(+), 45 deletions(-)

diff --git a/comfy/sample.py b/comfy/sample.py
index f4132bbed..99c4c9a4a 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -2,17 +2,26 @@ import torch
 import comfy.model_management
 import comfy.samplers
 import math
+import numpy as np

-def prepare_noise(latent_image, seed, skip=0):
+def prepare_noise(latent_image, seed, noise_inds=None):
     """
     creates random noise given a latent image and a seed.
     optional arg skip can be used to skip and discard x number of noise generations for a given seed
     """
     generator = torch.manual_seed(seed)
-    for _ in range(skip):
+    if noise_inds is None:
+        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+
+    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
+    noises = []
+    for i in range(unique_inds[-1]+1):
         noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    return noise
+        if i in unique_inds:
+            noises.append(noise)
+    noises = [noises[i] for i in inverse]
+    noises = torch.cat(noises, axis=0)
+    return noises

 def prepare_mask(noise_mask, shape, device):
     """ensures noise mask is of proper dimensions"""
diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py
index b5601adc8..14d948ec0 100644
--- a/comfy_extras/nodes_rebatch.py
+++ b/comfy_extras/nodes_rebatch.py
@@ -4,7 +4,7 @@ class LatentRebatch:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "latents": ("LATENT",),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 1000}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                               }}
     RETURN_TYPES = ("LATENT",)
     INPUT_IS_LIST = True
@@ -14,69 +14,88 @@ class LatentRebatch:

     CATEGORY = "latent"

-    def get_batch(self, latent, i):
-        samples = latent[i]['samples']
+    @staticmethod
+    def get_batch(latents, list_ind, offset):
+        '''prepare a batch out of the list of latents'''
+        samples = latents[list_ind]['samples']
         shape = samples.shape
-        mask = latent[i]['noise_mask'] if 'noise_mask' in latent[i] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
+        mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
         if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
             torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
         if mask.shape[0] < samples.shape[0]:
             mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
-        return samples, mask
+        if 'batch_index' in latents[list_ind]:
+            batch_inds = latents[list_ind]['batch_index']
+        else:
+            batch_inds = [x+offset for x in range(shape[0])]
+        return samples, mask, batch_inds

-    def get_slices(self, tensors, num, batch_size):
+    @staticmethod
+    def get_slices(indexable, num, batch_size):
+        '''divides an indexable object into num slices of length batch_size, and a remainder'''
         slices = []
         for i in range(num):
-            slices.append(tensors[i*batch_size:(i+1)*batch_size])
-        if num * batch_size < tensors.shape[0]:
-            return slices, tensors[num * batch_size:]
+            slices.append(indexable[i*batch_size:(i+1)*batch_size])
+        if num * batch_size < len(indexable):
+            return slices, indexable[num * batch_size:]
         else:
             return slices, None
+
+    @staticmethod
+    def slice_batch(batch, num, batch_size):
+        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
+        return list(zip(*result))
+
+    @staticmethod
+    def cat_batch(batch1, batch2):
+        if batch1[0] is None:
+            return batch2
+        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
+        return result

     def rebatch(self, latents, batch_size):
         batch_size = batch_size[0]

         output_list = []
-        current_samples = None
-        current_masks = None
+        current_batch = (None, None, None)
+        processed = 0

         for i in range(len(latents)):
             # fetch new entry of list
-            samples, masks = self.get_batch(latents, i)
+            #samples, masks, indices = self.get_batch(latents, i)
+            next_batch = self.get_batch(latents, i, processed)
+            processed += len(next_batch[2])
             # set to current if current is None
-            if current_samples is None:
-                current_samples = samples
-                current_masks = masks
+            if current_batch[0] is None:
+                current_batch = next_batch
             # add previous to list if dimensions do not match
-            elif samples.shape[-1] != current_samples.shape[-1] or samples.shape[-2] != current_samples.shape[-2]:
-                s = dict()
-                sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
-                mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
-                output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
-                current_samples = samples
-                current_masks = masks
+            elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
+                sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+                output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+                current_batch = next_batch
             # cat if everything checks out
             else:
-                current_samples = torch.cat((current_samples, samples))
-                current_masks = torch.cat((current_masks, masks))
+                current_batch = self.cat_batch(current_batch, next_batch)
             # add to list if dimensions gone above target batch size
-            if current_samples.shape[0] > batch_size:
-                num = current_samples.shape[0] // batch_size
-                sample_slices, latent_remainder = self.get_slices(current_samples, num, batch_size)
-                mask_slices, mask_remainder = self.get_slices(current_masks, num, batch_size)
+            if current_batch[0].shape[0] > batch_size:
+                num = current_batch[0].shape[0] // batch_size
+                sliced, remainder = self.slice_batch(current_batch, num, batch_size)

                 for i in range(num):
-                    output_list.append({'samples': sample_slices[i], 'noise_mask': mask_slices[i]})
+                    output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
+
+                current_batch = remainder

-                current_samples = latent_remainder
-                current_masks = mask_remainder
-
         #add remainder
-        if current_samples is not None:
-            sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
-            mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
-            output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
+        if current_batch[0] is not None:
+            sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+            output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+
+        #get rid of empty masks
+        for s in output_list:
+            if s['noise_mask'].mean() == 1.0:
+                del s['noise_mask']

         return (output_list,)

diff --git a/nodes.py b/nodes.py
index 6c8b1b167..239ab0d5a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -563,16 +563,21 @@ class LatentFromBatch:
                               "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
                               }}
     RETURN_TYPES = ("LATENT",)
-    FUNCTION = "rotate"
+    FUNCTION = "frombatch"

     CATEGORY = "latent"

-    def rotate(self, samples, batch_index):
+    def frombatch(self, samples, batch_index):
         s = samples.copy()
         s_in = samples["samples"]
         batch_index = min(s_in.shape[0] - 1, batch_index)
         s["samples"] = s_in[batch_index:batch_index + 1].clone()
-        s["batch_index"] = batch_index
+        if "noise_mask" in samples:
+            s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + 1].clone()
+        if "batch_index" not in s:
+            s["batch_index"] = [batch_index]
+        else:
+            s["batch_index"] = samples["batch_index"][batch_index:batch_index + 1]
         return (s,)

 class LatentUpscale:
@@ -747,8 +752,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
     else:
-        skip = latent["batch_index"] if "batch_index" in latent else 0
-        noise = comfy.sample.prepare_noise(latent_image, seed, skip)
+        batch_inds = latent["batch_index"] if "batch_index" in latent else None
+        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

     noise_mask = None
     if "noise_mask" in latent:
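
Illustrative sketch (not part of the patch): how the "batch_index" values that LatentFromBatch and LatentRebatch now attach to a latent dict are meant to feed the new noise_inds argument of comfy.sample.prepare_noise. The seed and latent shapes below are arbitrary example values, and the final comparison states the intent of the batch index logic rather than a guarantee.

import torch
import comfy.sample

seed = 12345

# noise for a full batch of 4 latents, generated in one call (noise_inds=None)
full_samples = torch.zeros((4, 4, 64, 64))
full_noise = comfy.sample.prepare_noise(full_samples, seed)

# a single latent as LatentFromBatch would extract it, keeping its original index
single = {"samples": full_samples[2:3].clone(), "batch_index": [2]}
batch_inds = single["batch_index"] if "batch_index" in single else None
single_noise = comfy.sample.prepare_noise(single["samples"], seed, batch_inds)

# the batch index logic aims to regenerate, for index 2 alone, the same noise
# slice that index 2 receives when the whole batch is noised with this seed
print(torch.allclose(single_noise[0], full_noise[2]))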