mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 15:32:35 +08:00
add rebatch node
This commit is contained in:
parent
0732fc8f2a
commit
f1bd46c519
89
comfy_extras/nodes_rebatch.py
Normal file
89
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class LatentRebatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "latents": ("LATENT",),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1000}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
OUTPUT_IS_LIST = (True, )
|
||||||
|
|
||||||
|
FUNCTION = "rebatch"
|
||||||
|
|
||||||
|
CATEGORY = "latent"
|
||||||
|
|
||||||
|
def get_batch(self, latent, i):
|
||||||
|
samples = latent[i]['samples']
|
||||||
|
shape = samples.shape
|
||||||
|
mask = latent[i]['noise_mask'] if 'noise_mask' in latent[i] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||||
|
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||||
|
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||||
|
if mask.shape[0] < samples.shape[0]:
|
||||||
|
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||||
|
return samples, mask
|
||||||
|
|
||||||
|
def get_slices(self, tensors, num, batch_size):
|
||||||
|
slices = []
|
||||||
|
for i in range(num):
|
||||||
|
slices.append(tensors[i*batch_size:(i+1)*batch_size])
|
||||||
|
if num * batch_size < tensors.shape[0]:
|
||||||
|
return slices, tensors[num * batch_size:]
|
||||||
|
else:
|
||||||
|
return slices, None
|
||||||
|
|
||||||
|
def rebatch(self, latents, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
current_samples = None
|
||||||
|
current_masks = None
|
||||||
|
|
||||||
|
for i in range(len(latents)):
|
||||||
|
# fetch new entry of list
|
||||||
|
samples, masks = self.get_batch(latents, i)
|
||||||
|
# set to current if current is None
|
||||||
|
if current_samples is None:
|
||||||
|
current_samples = samples
|
||||||
|
current_masks = masks
|
||||||
|
# add previous to list if dimensions do not match
|
||||||
|
elif samples.shape[-1] != current_samples.shape[-1] or samples.shape[-2] != current_samples.shape[-2]:
|
||||||
|
s = dict()
|
||||||
|
sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
|
||||||
|
mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
|
||||||
|
output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
|
||||||
|
current_samples = samples
|
||||||
|
current_masks = masks
|
||||||
|
# cat if everything checks out
|
||||||
|
else:
|
||||||
|
current_samples = torch.cat((current_samples, samples))
|
||||||
|
current_masks = torch.cat((current_masks, masks))
|
||||||
|
|
||||||
|
# add to list if dimensions gone above target batch size
|
||||||
|
if current_samples.shape[0] > batch_size:
|
||||||
|
num = current_samples.shape[0] // batch_size
|
||||||
|
sample_slices, latent_remainder = self.get_slices(current_samples, num, batch_size)
|
||||||
|
mask_slices, mask_remainder = self.get_slices(current_masks, num, batch_size)
|
||||||
|
|
||||||
|
for i in range(num):
|
||||||
|
output_list.append({'samples': sample_slices[i], 'noise_mask': mask_slices[i]})
|
||||||
|
|
||||||
|
current_samples = latent_remainder
|
||||||
|
current_masks = mask_remainder
|
||||||
|
|
||||||
|
#add remainder
|
||||||
|
if current_samples is not None:
|
||||||
|
sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
|
||||||
|
mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
|
||||||
|
output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
|
||||||
|
|
||||||
|
return (output_list,)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"RebatchLatents": LatentRebatch,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"RebatchLatents": "Rebatch Latents",
|
||||||
|
}
|
||||||
1
nodes.py
1
nodes.py
@ -1231,3 +1231,4 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user