mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 12:22:37 +08:00
replace concat mask node
This commit is contained in:
parent
3aa8f900d6
commit
fce182c53b
@ -39,10 +39,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
|||||||
|
|
||||||
inverse_mask = torch.ones_like(mask) - mask
|
inverse_mask = torch.ones_like(mask) - mask
|
||||||
|
|
||||||
source_portion = mask * source[..., :visible_height, :visible_width]
|
source_rgb = source[:, :3, :visible_height, :visible_width]
|
||||||
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
|
dest_slice = destination[..., top:bottom, left:right]
|
||||||
|
|
||||||
|
if destination.shape[1] == 4:
|
||||||
|
if torch.max(dest_slice) == 0:
|
||||||
|
destination[:, :3, top:bottom, left:right] = source_rgb
|
||||||
|
destination[:, 3:4, top:bottom, left:right] = mask
|
||||||
|
else:
|
||||||
|
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
|
||||||
|
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
|
||||||
|
else:
|
||||||
|
source_portion = mask * source_rgb
|
||||||
|
destination_portion = inverse_mask * dest_slice
|
||||||
|
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||||
|
|
||||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
class LatentCompositeMasked(IO.ComfyNode):
|
class LatentCompositeMasked(IO.ComfyNode):
|
||||||
@ -82,18 +93,23 @@ class ImageCompositeMasked(IO.ComfyNode):
|
|||||||
search_aliases=["paste image", "overlay", "layer"],
|
search_aliases=["paste image", "overlay", "layer"],
|
||||||
category="image",
|
category="image",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("destination"),
|
|
||||||
IO.Image.Input("source"),
|
IO.Image.Input("source"),
|
||||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
IO.Boolean.Input("resize_source", default=False),
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
|
IO.Image.Input("destination", optional=True),
|
||||||
IO.Mask.Input("mask", optional=True),
|
IO.Mask.Input("mask", optional=True),
|
||||||
],
|
],
|
||||||
outputs=[IO.Image.Output()],
|
outputs=[IO.Image.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
|
||||||
|
if destination is None: # transparent rgba
|
||||||
|
B, H, W, C = source.shape
|
||||||
|
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
|
||||||
|
if C == 3:
|
||||||
|
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
|
||||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||||
destination = destination.clone().movedim(-1, 1)
|
destination = destination.clone().movedim(-1, 1)
|
||||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||||
@ -397,33 +413,6 @@ class ClipVisionToMask(IO.ComfyNode):
|
|||||||
|
|
||||||
clip_vision_to_mask = execute
|
clip_vision_to_mask = execute
|
||||||
|
|
||||||
class ConcatMask(IO.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return IO.Schema(
|
|
||||||
node_id="ConcatMask",
|
|
||||||
search_aliases=["add mask", "concat mask", "merge mask"],
|
|
||||||
category="mask",
|
|
||||||
inputs=[
|
|
||||||
IO.Mask.Input("mask"),
|
|
||||||
IO.Image.Input("image"),
|
|
||||||
],
|
|
||||||
outputs=[IO.Image.Output("rgba_image"), IO.Mask.Output("input_mask")],
|
|
||||||
)
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, mask, image):
|
|
||||||
if image.shape[-1] == 3:
|
|
||||||
image = image.movedim(-1, 1)
|
|
||||||
target_h, target_w = image.shape[2], image.shape[3]
|
|
||||||
if mask.shape[-2:] != (target_h, target_w):
|
|
||||||
mask = torch.nn.functional.interpolate(
|
|
||||||
mask, size=(target_h, target_w), mode='bicubic', align_corners=False
|
|
||||||
)
|
|
||||||
rgba = torch.cat([image, mask], dim = 1)
|
|
||||||
return IO.NodeOutput(rgba.movedim(1, -1), mask)
|
|
||||||
|
|
||||||
concat_mask = execute
|
|
||||||
|
|
||||||
class ThresholdMask(IO.ComfyNode):
|
class ThresholdMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -487,7 +476,6 @@ class MaskExtension(ComfyExtension):
|
|||||||
GrowMask,
|
GrowMask,
|
||||||
ThresholdMask,
|
ThresholdMask,
|
||||||
MaskPreview,
|
MaskPreview,
|
||||||
ConcatMask,
|
|
||||||
ClipVisionToMask
|
ClipVisionToMask
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user