mirror of https://github.com/comfyanonymous/ComfyUI.git
add batch index logic
commit 9d8ed7b28e
parent f1bd46c519
@@ -2,17 +2,26 @@ import torch
 import comfy.model_management
 import comfy.samplers
 import math
+import numpy as np
 
-def prepare_noise(latent_image, seed, skip=0):
+def prepare_noise(latent_image, seed, noise_inds=None):
     """
     creates random noise given a latent image and a seed.
     optional arg skip can be used to skip and discard x number of noise generations for a given seed
     """
     generator = torch.manual_seed(seed)
-    for _ in range(skip):
-        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    return noise
+    if noise_inds is None:
+        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+
+    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
+    noises = []
+    for i in range(unique_inds[-1]+1):
+        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+        if i in unique_inds:
+            noises.append(noise)
+    noises = [noises[i] for i in inverse]
+    noises = torch.cat(noises, axis=0)
+    return noises
 
 def prepare_mask(noise_mask, shape, device):
     """ensures noise mask is of proper dimensions"""
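A minimal usage sketch, not part of the commit, illustrating what the new noise_inds argument buys; it assumes it is run from a ComfyUI checkout so that comfy.sample is importable. Noise for a given batch index depends only on the seed and that index, so a latent sliced out of a batch can reproduce the noise tied to its original position.

import torch
import comfy.sample

full = torch.zeros(4, 4, 64, 64)   # hypothetical 4-sample latent batch
sub = full[2:4]                    # e.g. the slice a LatentFromBatch node would pass along

# Indices 2 and 3 receive the same noise whether they are requested together with
# the whole batch or on their own, because generation walks index-by-index up to
# the largest requested index with the same seeded generator.
a = comfy.sample.prepare_noise(full, 42, noise_inds=[0, 1, 2, 3])
b = comfy.sample.prepare_noise(sub, 42, noise_inds=[2, 3])
assert torch.equal(a[2:4], b)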
@@ -4,7 +4,7 @@ class LatentRebatch:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "latents": ("LATENT",),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 1000}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                               }}
     RETURN_TYPES = ("LATENT",)
     INPUT_IS_LIST = True
@@ -14,69 +14,88 @@ class LatentRebatch:
 
     CATEGORY = "latent"
 
-    def get_batch(self, latent, i):
-        samples = latent[i]['samples']
+    @staticmethod
+    def get_batch(latents, list_ind, offset):
+        '''prepare a batch out of the list of latents'''
+        samples = latents[list_ind]['samples']
         shape = samples.shape
-        mask = latent[i]['noise_mask'] if 'noise_mask' in latent[i] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
+        mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
         if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
             torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
         if mask.shape[0] < samples.shape[0]:
             mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
-        return samples, mask
+        if 'batch_index' in latents[list_ind]:
+            batch_inds = latents[list_ind]['batch_index']
+        else:
+            batch_inds = [x+offset for x in range(shape[0])]
+        return samples, mask, batch_inds
 
-    def get_slices(self, tensors, num, batch_size):
+    @staticmethod
+    def get_slices(indexable, num, batch_size):
+        '''divides an indexable object into num slices of length batch_size, and a remainder'''
         slices = []
         for i in range(num):
-            slices.append(tensors[i*batch_size:(i+1)*batch_size])
-        if num * batch_size < tensors.shape[0]:
-            return slices, tensors[num * batch_size:]
+            slices.append(indexable[i*batch_size:(i+1)*batch_size])
+        if num * batch_size < len(indexable):
+            return slices, indexable[num * batch_size:]
         else:
             return slices, None
 
+    @staticmethod
+    def slice_batch(batch, num, batch_size):
+        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
+        return list(zip(*result))
+
+    @staticmethod
+    def cat_batch(batch1, batch2):
+        if batch1[0] is None:
+            return batch2
+        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
+        return result
+
     def rebatch(self, latents, batch_size):
         batch_size = batch_size[0]
 
         output_list = []
-        current_samples = None
-        current_masks = None
+        current_batch = (None, None, None)
+        processed = 0
 
         for i in range(len(latents)):
             # fetch new entry of list
-            samples, masks = self.get_batch(latents, i)
+            #samples, masks, indices = self.get_batch(latents, i)
+            next_batch = self.get_batch(latents, i, processed)
+            processed += len(next_batch[2])
             # set to current if current is None
-            if current_samples is None:
-                current_samples = samples
-                current_masks = masks
+            if current_batch[0] is None:
+                current_batch = next_batch
             # add previous to list if dimensions do not match
-            elif samples.shape[-1] != current_samples.shape[-1] or samples.shape[-2] != current_samples.shape[-2]:
-                s = dict()
-                sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
-                mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
-                output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
-                current_samples = samples
-                current_masks = masks
+            elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
+                sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+                output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+                current_batch = next_batch
             # cat if everything checks out
             else:
-                current_samples = torch.cat((current_samples, samples))
-                current_masks = torch.cat((current_masks, masks))
+                current_batch = self.cat_batch(current_batch, next_batch)
 
             # add to list if dimensions gone above target batch size
-            if current_samples.shape[0] > batch_size:
-                num = current_samples.shape[0] // batch_size
-                sample_slices, latent_remainder = self.get_slices(current_samples, num, batch_size)
-                mask_slices, mask_remainder = self.get_slices(current_masks, num, batch_size)
+            if current_batch[0].shape[0] > batch_size:
+                num = current_batch[0].shape[0] // batch_size
+                sliced, remainder = self.slice_batch(current_batch, num, batch_size)
 
                 for i in range(num):
-                    output_list.append({'samples': sample_slices[i], 'noise_mask': mask_slices[i]})
+                    output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
 
-                current_samples = latent_remainder
-                current_masks = mask_remainder
+                current_batch = remainder
 
         #add remainder
-        if current_samples is not None:
-            sample_slices, _ = self.get_slices(current_samples, 1, batch_size)
-            mask_slices, _ = self.get_slices(current_masks, 1, batch_size)
-            output_list.append({'samples': sample_slices[0], 'noise_mask': mask_slices[0]})
+        if current_batch[0] is not None:
+            sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+            output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+
+        #get rid of empty masks
+        for s in output_list:
+            if s['noise_mask'].mean() == 1.0:
+                del s['noise_mask']
 
         return (output_list,)
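A small self-contained sketch, not part of the commit, restating the new slicing helpers from the hunk above on toy data; it shows how samples, masks and batch indices stay aligned while being cut into batches of the target size.

import torch

def get_slices(indexable, num, batch_size):
    # same logic as LatentRebatch.get_slices above
    slices = [indexable[i*batch_size:(i+1)*batch_size] for i in range(num)]
    remainder = indexable[num*batch_size:] if num * batch_size < len(indexable) else None
    return slices, remainder

def slice_batch(batch, num, batch_size):
    # same logic as LatentRebatch.slice_batch: slice every component, then regroup
    return list(zip(*[get_slices(x, num, batch_size) for x in batch]))

samples = torch.zeros(5, 4, 64, 64)
masks = torch.ones(5, 1, 512, 512)
batch_inds = [10, 11, 12, 13, 14]          # indices carried over from upstream latents

sliced, remainder = slice_batch((samples, masks, batch_inds), num=2, batch_size=2)
print(sliced[0][0].shape, sliced[2][0])    # torch.Size([2, 4, 64, 64]) [10, 11]
print(remainder[0].shape, remainder[2])    # torch.Size([1, 4, 64, 64]) [14]

rebatch feeds these aligned triples straight into the output latent dicts, which is how 'batch_index' ends up on every rebatched latent.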
nodes.py (15 changed lines)
@@ -563,16 +563,21 @@ class LatentFromBatch:
                               "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
                               }}
     RETURN_TYPES = ("LATENT",)
-    FUNCTION = "rotate"
+    FUNCTION = "frombatch"
 
     CATEGORY = "latent"
 
-    def rotate(self, samples, batch_index):
+    def frombatch(self, samples, batch_index):
         s = samples.copy()
         s_in = samples["samples"]
         batch_index = min(s_in.shape[0] - 1, batch_index)
         s["samples"] = s_in[batch_index:batch_index + 1].clone()
-        s["batch_index"] = batch_index
+        if "noise_mask" in samples:
+            s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + 1].clone()
+        if "batch_index" not in s:
+            s["batch_index"] = [batch_index]
+        else:
+            s["batch_index"] = samples["batch_index"][batch_index:batch_index + 1]
         return (s,)
 
 class LatentUpscale:
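A toy illustration, not part of the commit, of the new bookkeeping in frombatch; it assumes it runs from a ComfyUI checkout so that nodes.py (which pulls in the rest of the project) is importable. Selecting one latent out of a batch now records which index it came from.

import torch
from nodes import LatentFromBatch

latent = {"samples": torch.zeros(4, 4, 64, 64)}    # hypothetical 4-sample latent dict
(picked,) = LatentFromBatch().frombatch(latent, 2)

print(picked["samples"].shape)   # torch.Size([1, 4, 64, 64])
print(picked["batch_index"])     # [2] -- consumed by common_ksampler in the next hunk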
@@ -747,8 +752,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
     else:
-        skip = latent["batch_index"] if "batch_index" in latent else 0
-        noise = comfy.sample.prepare_noise(latent_image, seed, skip)
+        batch_inds = latent["batch_index"] if "batch_index" in latent else None
+        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
 
     noise_mask = None
     if "noise_mask" in latent:
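Taken together, the pieces wire up as follows; this is a condensed restatement of the else-branch above with hypothetical values, not code from the commit, and assumes comfy.sample is importable.

import torch
import comfy.sample

latent = {"samples": torch.zeros(1, 4, 64, 64), "batch_index": [2]}  # e.g. output of LatentFromBatch
latent_image = latent["samples"]

batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, 42, batch_inds)
# With batch_inds == [2], the noise is the index-2 entry of the per-index sequence
# for seed 42 (see the prepare_noise sketch earlier), instead of the index-0 noise
# a lone single-sample latent would otherwise get.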