diff --git a/nodes.py b/nodes.py index 1fab809b8..d0888988c 100644 --- a/nodes.py +++ b/nodes.py @@ -646,7 +646,13 @@ class LatentFromBatch: length = min(s_in.shape[0] - batch_index, length) s["samples"] = s_in[batch_index:batch_index + length].clone() if "noise_mask" in samples: - s["noise_mask"] = samples["noise_mask"][batch_index:batch_index + length].clone() + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() if "batch_index" not in s: s["batch_index"] = [x for x in range(batch_index, batch_index+length)] else: