From f1bd46c519d202f3cc5467da15b98864fb641308 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 28 Apr 2023 18:03:22 +0200 Subject: [PATCH] add rebatch node --- comfy_extras/nodes_rebatch.py | 89 +++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 90 insertions(+) create mode 100644 comfy_extras/nodes_rebatch.py diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..b5601adc8 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,89 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1000}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent" + + def get_batch(self, latent, i): + samples = latent[i]['samples'] + shape = samples.shape + mask = latent[i]['noise_mask'] if 'noise_mask' in latent[i] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + return samples, mask + + def get_slices(self, tensors, num, batch_size): + slices = [] + for i in range(num): + slices.append(tensors[i*batch_size:(i+1)*batch_size]) + if num * batch_size < tensors.shape[0]: + return slices, tensors[num * batch_size:] + else: + return slices, None + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_samples = None + current_masks = None + + for i in range(len(latents)): + # fetch new entry of list + samples, masks = self.get_batch(latents, i) + # set to current if current is None + if current_samples is None: + current_samples = samples + current_masks = masks + # add previous to list if dimensions do not match + elif samples.shape[-1] != current_samples.shape[-1] or samples.shape[-2] != current_samples.shape[-2]: + s = dict() + sample_slices, _ = self.get_slices(current_samples, 1, batch_size) + mask_slices, _ = self.get_slices(current_masks, 1, batch_size) + output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]}) + current_samples = samples + current_masks = masks + # cat if everything checks out + else: + current_samples = torch.cat((current_samples, samples)) + current_masks = torch.cat((current_masks, masks)) + + # add to list if dimensions gone above target batch size + if current_samples.shape[0] > batch_size: + num = current_samples.shape[0] // batch_size + sample_slices, latent_remainder = self.get_slices(current_samples, num, batch_size) + mask_slices, mask_remainder = self.get_slices(current_masks, num, batch_size) + + for i in range(num): + output_list.append({'samples': sample_slices[i], 'noise_mask': mask_slices[i]}) + + current_samples = latent_remainder + current_masks = mask_remainder + + #add remainder + if current_samples is not None: + sample_slices, _ = self.get_slices(current_samples, 1, batch_size) + mask_slices, _ = self.get_slices(current_masks, 1, batch_size) + output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]}) + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index 0a9513bed..6c8b1b167 100644 --- a/nodes.py +++ b/nodes.py @@ -1231,3 +1231,4 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))