mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 22:12:33 +08:00
This allows the LatentCompositeMasked node to work with AnimateDiff. I tried to keep the old functionality too; I am unsure whether it is correct, but both a single mask and a batch of masks seem to work with this change.
358 lines
12 KiB
Python
358 lines
12 KiB
Python
import numpy as np
|
|
from scipy.ndimage import grey_dilation
|
|
import torch
|
|
|
|
from nodes import MAX_RESOLUTION
|
|
|
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    """Blend ``source`` onto ``destination`` at offset (x, y) and return ``destination``.

    ``destination`` is modified in place; callers that need the original
    must pass a clone.

    Args:
        destination: 4-D tensor indexed as [batch, channel, height, width].
        source: 4-D tensor with the same layout as ``destination``.
        x, y: Offset of the source's top-left corner, expressed in units of
            ``multiplier`` tensor cells (they are floor-divided by it).
        mask: Optional blend weights. Any tensor whose last two dimensions
            are (height, width) is accepted: a single 2-D mask, a batch of
            masks, or a batch that already carries a channel dimension.
            It is resized bilinearly to the source's spatial size. ``None``
            means the source fully replaces the covered region.
        multiplier: Ratio between the x/y coordinate space and tensor cells
            (the callers in this file pass 8 for latents and 1 for images).
        resize_source: When true, the source is first resized bilinearly to
            the destination's spatial size.

    Returns:
        The same ``destination`` tensor object, after blending.
    """
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    # Clamp the offset so the source can at most sit entirely outside the destination.
    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = (x // multiplier, y // multiplier)
    right, bottom = (left + source.shape[3], top + source.shape[2],)

    if mask is None:
        mask = torch.ones_like(source)
    else:
        # Collapse every leading dimension into the batch axis and add a single
        # channel, so one interpolate call handles a lone mask and a batch alike.
        # interpolate returns a new tensor, so the caller's mask is never modified.
        mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        mask = torch.nn.functional.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")

    # calculate the bounds of the source that will be overlapping the destination
    # this prevents the source trying to overwrite latent pixels that are out of bounds
    # of the destination
    visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)

    mask = mask[:, :, :visible_height, :visible_width]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, :visible_height, :visible_width]
    destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]

    destination[:, :, top:bottom, left:right] = source_portion + destination_portion
    return destination
|
|
|
|
class LatentCompositeMasked:
    """Node that pastes one latent onto another, optionally through a mask."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"

    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs; x and y advance in steps of 8."""
        required_inputs = {
            "destination": ("LATENT",),
            "source": ("LATENT",),
            "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
            "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
            "resize_source": ("BOOLEAN", {"default": False}),
        }
        optional_inputs = {
            "mask": ("MASK",),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Blend ``source["samples"]`` onto a clone of ``destination["samples"]``.

        The incoming destination dict is shallow-copied and only its
        "samples" entry is replaced, so the caller's latent is left untouched.
        Returns a 1-tuple holding the new latent dict.
        """
        # Inside a method the bare name resolves to the module-level composite().
        blended = composite(destination["samples"].clone(), source["samples"], x, y, mask, 8, resize_source)
        result = destination.copy()
        result["samples"] = blended
        return (result,)
|
|
|
|
class ImageCompositeMasked:
    """Node that pastes one image onto another, optionally through a mask."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"

    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs; x and y advance in steps of 1."""
        required_inputs = {
            "destination": ("IMAGE",),
            "source": ("IMAGE",),
            "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            "resize_source": ("BOOLEAN", {"default": False}),
        }
        optional_inputs = {
            "mask": ("MASK",),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    def composite(self, destination, source, x, y, resize_source, mask = None):
        """Blend ``source`` onto a clone of ``destination`` and return it in a 1-tuple.

        Both tensors have their last dimension moved to position 1 for the
        module-level composite() and the result is moved back afterwards.
        """
        # Clone first so the in-place blend never touches the caller's tensor.
        dest_moved = destination.clone().movedim(-1, 1)
        src_moved = source.movedim(-1, 1)
        blended = composite(dest_moved, src_moved, x, y, mask, 1, resize_source)
        return (blended.movedim(1, -1),)
|
|
|
|
class MaskToImage:
    """Node that turns a mask into a three-channel image."""

    CATEGORY = "mask"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mask_to_image"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the single required ``mask`` input."""
        required_inputs = {
            "mask": ("MASK",),
        }
        return {"required": required_inputs}

    def mask_to_image(self, mask):
        """Return ``mask`` as a [batch, height, width, 3] view in a 1-tuple.

        All leading dimensions are folded into the batch axis; the three
        channels are an ``expand`` of the single mask channel.
        """
        height = mask.shape[-2]
        width = mask.shape[-1]
        single_channel = mask.reshape((-1, 1, height, width))
        channel_last = single_channel.movedim(1, -1)
        return (channel_last.expand(-1, -1, -1, 3),)
|
|
|
|
class ImageToMask:
    """Node that extracts one colour channel of an image as a mask."""

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: an image and the channel name to extract."""
        required_inputs = {
            "image": ("IMAGE",),
            "channel": (["red", "green", "blue"],),
        }
        return {"required": required_inputs}

    def image_to_mask(self, image, channel):
        """Return the chosen channel of the first image in a 1-tuple.

        Only ``image[0]`` is read; any further images in the batch are ignored.
        Raises ValueError when ``channel`` is not red, green or blue.
        """
        channel_index = ["red", "green", "blue"].index(channel)
        first_image = image[0]
        return (first_image[:, :, channel_index],)
|
|
|
|
class ImageColorToMask:
    """Node that builds a mask selecting the pixels of one exact RGB colour."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: an image and a packed 0xRRGGBB colour value."""
        return {
            "required": {
                "image": ("IMAGE",),
                "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, color):
        """Return a mask that is 1.0 where the pixel equals ``color`` and 0.0 elsewhere.

        Only ``image[0]`` is read. Channel values are clamped to 0..1, scaled
        to 0..255 and rounded before being packed into one integer per pixel
        for comparison with ``color``.

        The result uses 1.0 for a match so it stays in the 0..1 range the
        other mask nodes in this file operate on (inversion as ``1.0 - mask``,
        clamping to 0..1); emitting 255 would break those operations.
        """
        temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int)
        # Pack R, G, B into a single 0xRRGGBB integer per pixel.
        temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2]
        mask = (temp == color).float()
        return (mask,)
|
|
|
|
class SolidMask:
    """Node that creates a mask filled with one constant value."""

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "solid"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: fill value (0..1) plus width and height."""
        required_inputs = {
            "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
        }
        return {"required": required_inputs}

    def solid(self, value, width, height):
        """Return a (height, width) float32 CPU tensor filled with ``value``, in a 1-tuple."""
        mask_shape = (height, width)
        filled = torch.full(size=mask_shape, fill_value=value, dtype=torch.float32, device="cpu")
        return (filled,)
|
|
|
|
class InvertMask:
    """Node that inverts a mask."""

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "invert"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required ``mask`` input."""
        required_inputs = {
            "mask": ("MASK",),
        }
        return {"required": required_inputs}

    def invert(self, mask):
        """Return ``1.0 - mask`` in a 1-tuple; the input tensor is not modified."""
        return (1.0 - mask,)
|
|
|
|
class CropMask:
    """Node that cuts a rectangular region out of a mask."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: the mask plus the crop rectangle."""
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "crop"

    def crop(self, mask, x, y, width, height):
        """Return the ``height`` x ``width`` region starting at (x, y), in a 1-tuple.

        The crop is applied to the last two dimensions (rows, columns), so a
        single 2-D mask and a batch of masks are both cropped spatially; any
        leading dimensions are kept as they are. A rectangle that extends
        past the mask is truncated by normal slicing rules. The result is a
        view of ``mask``, not a copy.
        """
        out = mask[..., y:y + height, x:x + width]
        return (out,)
|
|
|
|
class MaskComposite:
    """Node that combines a source mask into a destination mask at an offset."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: both masks, the offset and the operation name."""
        return {
            "required": {
                "destination": ("MASK",),
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "combine"

    def combine(self, destination, source, x, y, operation):
        """Apply ``operation`` between the overlapping parts of both masks.

        Args:
            destination: Mask whose last two dimensions are (rows, columns).
                It is cloned; the input is not modified.
            source: Mask placed with its top-left corner at (x, y) of the
                destination. Leading (batch) dimensions, if any, must
                broadcast against the destination's.
            x, y: Non-negative offset of the source inside the destination.
            operation: One of "multiply", "add", "subtract", "and", "or",
                "xor". The logical operations round both masks to booleans
                first. Any other name leaves the region untouched.

        Returns:
            A 1-tuple with the combined mask, clamped to 0..1.
        """
        output = destination.clone()

        dest_height, dest_width = destination.shape[-2], destination.shape[-1]
        src_height, src_width = source.shape[-2], source.shape[-1]

        # Clamp the offset to the destination so an offset past the edge yields
        # an empty overlap rather than a negative size (which would slice the
        # source from its end and mismatch the empty destination slice).
        left, top = (min(x, dest_width), min(y, dest_height),)
        right, bottom = (min(left + src_width, dest_width), min(top + src_height, dest_height))
        visible_width, visible_height = (right - left, bottom - top,)

        source_portion = source[..., :visible_height, :visible_width]
        destination_portion = destination[..., top:bottom, left:right]

        if operation == "multiply":
            output[..., top:bottom, left:right] = destination_portion * source_portion
        elif operation == "add":
            output[..., top:bottom, left:right] = destination_portion + source_portion
        elif operation == "subtract":
            output[..., top:bottom, left:right] = destination_portion - source_portion
        elif operation == "and":
            output[..., top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "or":
            output[..., top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "xor":
            output[..., top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()

        output = torch.clamp(output, 0.0, 1.0)

        return (output,)
|
|
|
|
class FeatherMask:
    """Node that fades a mask towards zero along its edges."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: the mask and the feather width of each edge."""
        return {
            "required": {
                "mask": ("MASK",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "feather"

    def feather(self, mask, left, top, right, bottom):
        """Return a copy of ``mask`` with linear ramps applied to its edges.

        For an edge of width ``n`` the i-th row/column counted from that edge
        (starting at 0) is multiplied by ``(i + 1) / n``, so the outermost
        line gets ``1 / n`` and the n-th line is left unchanged. Widths are
        limited to the mask size. The ramps act on the last two dimensions
        (rows, columns), so batched masks are feathered per item.

        Returns:
            A 1-tuple with the feathered clone; the input is not modified.
        """
        output = mask.clone()

        height, width = output.shape[-2], output.shape[-1]
        left = min(left, width)
        right = min(right, width)
        top = min(top, height)
        bottom = min(bottom, height)

        for x in range(left):
            feather_rate = (x + 1.0) / left
            output[..., x] *= feather_rate

        for x in range(right):
            feather_rate = (x + 1) / right
            # -x - 1, not -x: index -0 is the FIRST column, not the last one.
            output[..., -x - 1] *= feather_rate

        for y in range(top):
            feather_rate = (y + 1) / top
            output[..., y, :] *= feather_rate

        for y in range(bottom):
            feather_rate = (y + 1) / bottom
            # Same off-by-one guard as for the right edge.
            output[..., -y - 1, :] *= feather_rate

        return (output,)
|
|
|
|
class GrowMask:
    """Node that grows the bright areas of a mask by repeated dilation."""

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs: the mask, the number of steps and the corner style."""
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
            },
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "expand_mask"

    def expand_mask(self, mask, expand, tapered_corners):
        """Dilate ``mask`` ``expand`` times with a 3x3 footprint.

        Args:
            mask: Tensor convertible with ``Tensor.numpy()``. The last two
                dimensions are treated as (rows, columns); leading (batch)
                dimensions are dilated independently.
            expand: Number of dilation passes; values <= 0 leave the mask
                unchanged.
            tapered_corners: When true the footprint is a plus shape (the
                four corners are excluded); otherwise it is the full square.

        Returns:
            A 1-tuple with a new tensor; the input is not modified.
        """
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        output = mask.numpy().copy()
        # grey_dilation needs a footprint of the same rank as its input. Padding
        # with size-1 leading axes keeps the dilation purely spatial, so batch
        # items never bleed into each other. For a 2-D mask this is the kernel itself.
        footprint = kernel.reshape((1,) * (output.ndim - 2) + kernel.shape)
        for _ in range(expand):
            output = grey_dilation(output, footprint=footprint)
        output = torch.from_numpy(output)
        return (output,)
|
|
|
|
|
|
|
|
# Maps each node identifier to the class implementing it. Every identifier
# here is exactly the class name, so the table is derived from the classes.
# NOTE(review): presumably read by the host application's node loader — confirm.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        LatentCompositeMasked,
        ImageCompositeMasked,
        MaskToImage,
        ImageToMask,
        ImageColorToMask,
        SolidMask,
        InvertMask,
        CropMask,
        MaskComposite,
        FeatherMask,
        GrowMask,
    )
}
|
|
|
|
# Human-readable titles for the nodes whose identifier alone is not descriptive.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ImageToMask="Convert Image to Mask",
    MaskToImage="Convert Mask to Image",
)
|