Match upstream mask background color fix

This commit is contained in:
kijai 2026-06-11 10:51:04 +03:00
parent 593e7c9267
commit c895fc8d06

View File

@ -267,7 +267,7 @@ class SCAIL2ColoredMask(io.ComfyNode):
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False, io.Boolean.Input("replacement_mode", default=False,
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."), tooltip="False = Animation Mode (pose black bg, reference white bg). True = Replacement Mode (pose white bg, reference black bg)."),
], ],
outputs=[ outputs=[
io.Image.Output("pose_video_mask"), io.Image.Output("pose_video_mask"),
@ -296,14 +296,17 @@ class SCAIL2ColoredMask(io.ComfyNode):
return td return td
drv = _prep(driving_track_data) drv = _prep(driving_track_data)
# Animation: driving=black, ref=white. Replacement: driving=white, ref=black.
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
ref_bg = "black" if replacement_mode else "white"
if ref_track_data is not None: if ref_track_data is not None:
ref = _prep(ref_track_data) ref = _prep(ref_track_data)
reference_image_mask = _render_colored_masks(ref, "black") reference_image_mask = _render_colored_masks(ref, ref_bg)
else: else:
H, W = drv["orig_size"] H, W = drv["orig_size"]
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) fill_value = 1.0 if ref_bg == "white" else 0.0
reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(mask_video, reference_image_mask) return io.NodeOutput(mask_video, reference_image_mask)