diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 43f623a62..d585b3b81 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -17,9 +17,15 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if mask is None: mask = torch.ones_like(source) else: + if len(mask.shape) != 3: # Check if mask has a batch dimension + print("is batch") + mask = mask.unsqueeze(0) # Add a batch dimension to the mask tensor mask = mask.clone() - mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") - mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) + resized_masks = [] + for i in range(mask.shape[0]): + resized_mask = torch.nn.functional.interpolate(mask[i][None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") + resized_masks.append(resized_mask) + mask = torch.cat(resized_masks, dim=0) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds