deal with noise mask edge cases in latentfrombatch

This commit is contained in:
BlenderNeko 2023-05-06 20:26:16 +02:00
parent 5edf59fb1f
commit 946da9b1bb

View File

@ -646,7 +646,13 @@ class LatentFromBatch:
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + length].clone()
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else: