Allow mask batches

This allows the LatentCompositeMasked node to work with AnimateDiff. I tried to keep the old functionality as well; I'm unsure if it's fully correct, but both a single mask and a batch of masks seem to work with this change.
This commit is contained in:
kijai 2023-09-24 22:14:07 +03:00
parent 77c124c5a1
commit 3e4c0f67d1

View File

@ -17,9 +17,15 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None:
mask = torch.ones_like(source)
else:
if len(mask.shape) != 3: # Check if mask has a batch dimension
print("is batch")
mask = mask.unsqueeze(0) # Add a batch dimension to the mask tensor
mask = mask.clone()
mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = mask.repeat((source.shape[0], source.shape[1], 1, 1))
resized_masks = []
for i in range(mask.shape[0]):
resized_mask = torch.nn.functional.interpolate(mask[i][None, None], size=(source.shape[2], source.shape[3]), mode="bilinear")
resized_masks.append(resized_mask)
mask = torch.cat(resized_masks, dim=0)
# calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds