diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ece50cfd7..939f3303e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1810,9 +1810,9 @@ class SCAIL2WanModel(SCAILWanModel): return x - # Reads the first element of ref_mask_flag and assumes a uniform mode across the batch. + # ref_mask_flag is a scalar bool (CONDConstant); the mode is uniform across the batch. def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): - is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag.flatten()[0].item()) + is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag) if not is_replacement: return super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, pose_latents=pose_latents, reference_latent=reference_latent, transformer_options=transformer_options) diff --git a/comfy/model_base.py b/comfy/model_base.py index cb795f792..d212a7c2a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1758,7 +1758,6 @@ class WAN21_SCAIL2(WAN21_SCAIL): """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): - # Bypass WAN21.__init__ to override unet_model to SCAIL2WanModel. super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel) self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents") self.memory_usage_shape_process = { @@ -1770,29 +1769,29 @@ class WAN21_SCAIL2(WAN21_SCAIL): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) - sam_28ch = kwargs.get("sam_28ch", None) - if sam_28ch is not None: - out['sam_latents'] = comfy.conds.CONDRegular(sam_28ch.movedim(1, 2).contiguous()) + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) - ref_sam_28ch = kwargs.get("ref_sam_28ch", None) - if ref_sam_28ch is not None: - out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_sam_28ch.movedim(1, 2).contiguous()) + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) ref_mask_flag = kwargs.get("ref_mask_flag", None) if ref_mask_flag is not None: - out['ref_mask_flag'] = comfy.conds.CONDRegular(ref_mask_flag) + out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag) return out def extra_conds_shapes(self, **kwargs): out = super().extra_conds_shapes(**kwargs) - sam_28ch = kwargs.get("sam_28ch", None) - if sam_28ch is not None: - s = sam_28ch.shape + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + s = driving_mask_28ch.shape out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]] - ref_sam_28ch = kwargs.get("ref_sam_28ch", None) - if ref_sam_28ch is not None: - s = ref_sam_28ch.shape + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + s = ref_mask_28ch.shape out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]] return out @@ -1802,9 +1801,7 @@ class WAN21_SCAIL2(WAN21_SCAIL): return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) def concat_cond(self, **kwargs): - # Override base path that short-circuits to 4 zeros when image_to_video=False - # and extra_channels == image.shape[1]: history needs the mask channels to be 1 - # at anchor slots. + # The 4 extra channels are the history_mask (1 at clean-anchor frames). noise = kwargs.get("noise", None) extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] if extra_channels != 4: diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py new file mode 100644 index 000000000..afb27f566 --- /dev/null +++ b/comfy_extras/nodes_scail.py @@ -0,0 +1,333 @@ +"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3 +preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes.""" + +from typing_extensions import override + +import torch +import torch.nn.functional as F + +import nodes +import node_helpers +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, io + + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + + +# Model was trained on these exact colors; deviating degrades multi-identity quality. +DEFAULT_PALETTE = [ + (0.0, 0.0, 1.0), # Blue + (1.0, 0.0, 0.0), # Red + (0.0, 1.0, 0.0), # Green + (1.0, 0.0, 1.0), # Magenta + (0.0, 1.0, 1.0), # Cyan + (1.0, 1.0, 0.0), # Yellow +] + + +def _unpack(track_data): + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + if packed is None or packed.shape[1] == 0: + return None + return unpack_masks(packed) + + +def _first_frame_cx_area(masks_bool): + first = masks_bool[0].float() + H, W = first.shape[-2], first.shape[-1] + n_pixels = H * W + grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) + area = first.sum(dim=(-1, -2)).clamp_(min=1) + cx = (first * grid_x).sum(dim=(-1, -2)) / area + return (cx / W).tolist(), (area / n_pixels).tolist() + + +def _sort_tracks(track_data, sort_by): + masks_bool = _unpack(track_data) + if masks_bool is None: + return [] + cx, area = _first_frame_cx_area(masks_bool) + if sort_by == "x": + return sorted(range(len(cx)), key=lambda i: cx[i]) + return sorted(range(len(area)), key=lambda i: -area[i]) # "area" + + +def _subset_track_data(track_data, obj_indices): + out = dict(track_data) + packed = track_data["packed_masks"] + if packed is None or not obj_indices: + out["packed_masks"] = None + if "scores" in out: + out["scores"] = [] + return out + out["packed_masks"] = packed[:, obj_indices].contiguous() + scores = track_data.get("scores") + if scores is not None: + out["scores"] = [scores[i] for i in obj_indices if i < len(scores)] + return out + + +def _bg_to_rgb(background): + if background.startswith("white"): + return (1.0, 1.0, 1.0) + return (0.0, 0.0, 0.0) + + +def _render_colored_masks(track_data, background="black"): + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + bg_rgb = _bg_to_rgb(background) + if packed is None or packed.shape[1] == 0: + T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] + out = torch.empty(T, H, W, 3, device=device, dtype=dtype) + out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2] + return out + T, N_obj = packed.shape[0], packed.shape[1] + colors = torch.tensor( + [DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)], + device=device, dtype=dtype, + ) + masks_full = unpack_masks(packed.to(device)).float() + Hm, Wm = masks_full.shape[-2], masks_full.shape[-1] + masks_full = F.interpolate( + masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" + ).view(T, N_obj, H, W) > 0.5 + any_mask = masks_full.any(dim=1) + obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) + color_overlay = colors[obj_idx_map] + bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) + return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) + + +def _extract_mask_to_28ch(rgb_video): + """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent + (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) + threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking.""" + T, H, W, _ = rgb_video.shape + _ON_THRESH = 225.0 / 255.0 + mask = rgb_video.movedim(-1, 1).float() + R = (mask[:, 0:1] > _ON_THRESH).float() + G = (mask[:, 1:2] > _ON_THRESH).float() + B = (mask[:, 2:3] > _ON_THRESH).float() + nR, nG, nB = 1 - R, 1 - G, 1 - B + binary_7ch = torch.cat([ + R * G * B, # white + R * nG * nB, # red + nR * G * nB, # green + nR * nG * B, # blue + R * G * nB, # yellow + R * nG * B, # magenta + nR * G * B, # cyan + ], dim=1) + H_lat, W_lat = H, W + for _ in range(3): + H_lat = (H_lat + 1) // 2 + W_lat = (W_lat + 1) // 2 + binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area') + T_latent = (T - 1) // 4 + 1 + padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0) + out = padded.view(T_latent, 28, H_lat, W_lat) + return out.unsqueeze(0) + + +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="model/conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."), + io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (mask bg black). True = Replacement Mode (mask bg white; the reference is composited onto black using ref_image_mask as an alpha matte). Must match how the colored masks were rendered."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Single-frame colored ref mask at the reference image's full resolution."), + io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), + io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), + io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, + video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None, + pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + noise_mask = None + + ref_mask_flag = not replacement_mode + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag}) + + prev_trimmed = None + if previous_frames is not None and previous_frames.shape[0] > 0: + prev_trimmed = previous_frames[-previous_frame_count:] + video_frame_offset -= prev_trimmed.shape[0] + video_frame_offset = max(0, video_frame_offset) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) + reference_image = reference_image * is_char + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + if pose_video_mask is not None: + if pose_video_mask.shape[0] <= video_frame_offset: + pose_video_mask = None + else: + pose_video_mask = pose_video_mask[video_frame_offset:] + + # Truncate pose+mask jointly to the shorter of the two, capped at length. + ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None] + if ts: + T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1 + if pose_video is not None: + pose_video = pose_video[:T_kept] + if pose_video_mask is not None: + pose_video_mask = pose_video_mask[:T_kept] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + if pose_video_mask is not None: + mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw) + positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) + + if reference_image_mask is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) + ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) + + if prev_trimmed is not None: + pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + prev_latent = vae.encode(pf[:, :, :, :3]) + prev_latent_frames = min(prev_latent.shape[2], latent.shape[2]) + latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype) + noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype) + noise_mask[:, :, :prev_latent_frames] = 0.0 + + out_latent = {"samples": latent} + if noise_mask is not None: + out_latent["noise_mask"] = noise_mask + return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length) + + +class SCAIL2ColoredMask(io.ComfyNode): + """Render SAM3 tracks for the driving video and (optionally) the reference + image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` + across both outputs guarantees identity K maps to the same color on both + sides, for multi-person workflow consistency. + ref_mask is always rendered black-bg (model convention) + mask_video bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SCAIL2ColoredMask", + display_name="SCAIL-2 Colored Mask", + category="conditioning/video_models/scail", + inputs=[ + SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the driving_mask_video output."), + SAM3TrackData.Input("ref_track_data", optional=True, + tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."), + io.Combo.Input("sort_by", options=["none", "x", "area"], + tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order."), + io.Boolean.Input("replacement_mode", default=False, + tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. ref_mask_image is always black-bg regardless."), + ], + outputs=[ + io.Image.Output("driving_mask_video"), + io.Image.Output("ref_mask_image"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): + def _prep(td): + if sort_by != "none": + td = _subset_track_data(td, _sort_tracks(td, sort_by)) + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + packed = td.get("packed_masks") + n_obj = packed.shape[1] if packed is not None else 0 + indices = [i for i in indices if 0 <= i < n_obj] + td = _subset_track_data(td, indices) + return td + + drv = _prep(driving_track_data) + mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") + + if ref_track_data is not None: + ref = _prep(ref_track_data) + ref_mask_image = _render_colored_masks(ref, "black") + else: + H, W = drv["orig_size"] + ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + + return io.NodeOutput(mask_video, ref_mask_image) + + +class SCAILExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanSCAILToVideo, + SCAIL2ColoredMask, + ] + + +async def comfy_entrypoint() -> SCAILExtension: + return SCAILExtension() diff --git a/comfy_extras/nodes_scail2.py b/comfy_extras/nodes_scail2.py deleted file mode 100644 index 70e94ed7f..000000000 --- a/comfy_extras/nodes_scail2.py +++ /dev/null @@ -1,172 +0,0 @@ -"""SCAIL-2 preprocessing nodes that turn SAM3 video tracks into the conditioning -bundle the SCAIL-2 model consumes.""" - -from typing_extensions import override - -import torch -import torch.nn.functional as F - -import comfy.model_management -import comfy.utils -from comfy_api.latest import ComfyExtension, io - - -SAM3TrackData = io.Custom("SAM3_TRACK_DATA") - - -# Model was trained on these exact colors; deviating degrades multi-identity quality. -DEFAULT_PALETTE = [ - (0.0, 0.0, 1.0), # Blue - (1.0, 0.0, 0.0), # Red - (0.0, 1.0, 0.0), # Green - (1.0, 0.0, 1.0), # Magenta - (0.0, 1.0, 1.0), # Cyan - (1.0, 1.0, 0.0), # Yellow -] - - -def _unpack(track_data): - from comfy.ldm.sam3.tracker import unpack_masks - packed = track_data["packed_masks"] - if packed is None or packed.shape[1] == 0: - return None - return unpack_masks(packed) - - -def _first_frame_cx_area(masks_bool): - first = masks_bool[0].float() - H, W = first.shape[-2], first.shape[-1] - n_pixels = H * W - grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) - area = first.sum(dim=(-1, -2)).clamp_(min=1) - cx = (first * grid_x).sum(dim=(-1, -2)) / area - return (cx / W).tolist(), (area / n_pixels).tolist() - - -def _sort_tracks(track_data, sort_by): - masks_bool = _unpack(track_data) - if masks_bool is None: - return [] - cx, area = _first_frame_cx_area(masks_bool) - if sort_by == "x": - return sorted(range(len(cx)), key=lambda i: cx[i]) - return sorted(range(len(area)), key=lambda i: -area[i]) # "area" - - -def _subset_track_data(track_data, obj_indices): - out = dict(track_data) - packed = track_data["packed_masks"] - if packed is None or not obj_indices: - out["packed_masks"] = None - if "scores" in out: - out["scores"] = [] - return out - out["packed_masks"] = packed[:, obj_indices].contiguous() - scores = track_data.get("scores") - if scores is not None: - out["scores"] = [scores[i] for i in obj_indices if i < len(scores)] - return out - - -def _bg_to_rgb(background): - if background.startswith("white"): - return (1.0, 1.0, 1.0) - return (0.0, 0.0, 0.0) - - -def _render_colored_masks(track_data, background="black"): - from comfy.ldm.sam3.tracker import unpack_masks - packed = track_data["packed_masks"] - H, W = track_data["orig_size"] - device = comfy.model_management.intermediate_device() - bg_rgb = _bg_to_rgb(background) - if packed is None or packed.shape[1] == 0: - T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] - out = torch.empty(T, H, W, 3, device=device) - out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2] - return out - T, N_obj = packed.shape[0], packed.shape[1] - colors = torch.tensor( - [DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)], - device=device, dtype=torch.float32, - ) - masks_full = unpack_masks(packed.to(device)).float() - Hm, Wm = masks_full.shape[-2], masks_full.shape[-1] - masks_full = F.interpolate( - masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" - ).view(T, N_obj, H, W) > 0.5 - any_mask = masks_full.any(dim=1) - obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) - color_overlay = colors[obj_idx_map] - bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) - return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) - - -class SCAIL2ColoredMask(io.ComfyNode): - """Render SAM3 tracks for the driving video and (optionally) the reference - image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` - across both outputs guarantees identity K maps to the same color on both - sides, so multi-person workflows stay consistent without a separate - alignment node. ref_mask is always rendered black-bg (model convention); - mask_video bg follows the mode you'll set on WanSCAILToVideo.""" - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SCAIL2ColoredMask", - display_name="SCAIL-2 Colored Mask", - category="conditioning/video_models/scail", - inputs=[ - SAM3TrackData.Input("driving_track_data"), - SAM3TrackData.Input("ref_track_data", optional=True, - tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output."), - io.String.Input("object_indices", default="", - tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."), - io.Combo.Input("sort_by", options=["none", "x", "area"], - tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order."), - io.Boolean.Input("replacement_mode", default=False, - tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). WanSCAILToVideo auto-detects mode from the wired mask_video's bg color, so this is the single source of truth. ref_mask_image is always black-bg regardless."), - ], - outputs=[ - io.Image.Output("driving_mask_video"), - io.Image.Output("ref_mask_image"), - ], - is_experimental=True, - ) - - @classmethod - def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): - def _prep(td): - if sort_by != "none": - td = _subset_track_data(td, _sort_tracks(td, sort_by)) - if object_indices.strip(): - indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] - packed = td.get("packed_masks") - n_obj = packed.shape[1] if packed is not None else 0 - indices = [i for i in indices if 0 <= i < n_obj] - td = _subset_track_data(td, indices) - return td - - drv = _prep(driving_track_data) - mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") - - if ref_track_data is not None: - ref = _prep(ref_track_data) - ref_mask_image = _render_colored_masks(ref, "black") - else: - H, W = drv["orig_size"] - ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device()) - - return io.NodeOutput(mask_video, ref_mask_image) - - -class SCAIL2Extension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [ - SCAIL2ColoredMask, - ] - - -async def comfy_entrypoint() -> SCAIL2Extension: - return SCAIL2Extension() diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8ef41ca8b..d73be8e00 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,161 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) -def _extract_mask_to_28ch(rgb_video): - """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent - (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) - threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking.""" - T, H, W, _ = rgb_video.shape - _ON_THRESH = 225.0 / 255.0 - mask = rgb_video.movedim(-1, 1).float() - R = (mask[:, 0:1] > _ON_THRESH).float() - G = (mask[:, 1:2] > _ON_THRESH).float() - B = (mask[:, 2:3] > _ON_THRESH).float() - nR, nG, nB = 1 - R, 1 - G, 1 - B - binary_7ch = torch.cat([ - R * G * B, # white - R * nG * nB, # red - nR * G * nB, # green - nR * nG * B, # blue - R * G * nB, # yellow - R * nG * B, # magenta - nR * G * B, # cyan - ], dim=1) - H_lat, W_lat = H, W - for _ in range(3): - H_lat = (H_lat + 1) // 2 - W_lat = (W_lat + 1) // 2 - binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area') - T_latent = (T - 1) // 4 + 1 - padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0) - out = padded.view(T_latent, 28, H_lat, W_lat) - return out.unsqueeze(0) - - -class WanSCAILToVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="WanSCAILToVideo", - category="model/conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Vae.Input("vae"), - io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), - io.Int.Input("batch_size", default=1, min=1, max=4096), - io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), - io.Image.Input("driving_mask_video", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video. Mode is auto-detected from bg color: black bg = Animation, white bg = Replacement."), - io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), - io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), - io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), - io.Image.Input("reference_image", optional=True), - io.Image.Input("ref_mask_image", optional=True, tooltip="SCAIL-2 only. Single-frame colored ref mask at the reference image's full resolution."), - io.ClipVisionOutput.Input("clip_vision_output", optional=True), - io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), - io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), - io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the inpainting anchor."), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), - io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."), - ], - is_experimental=True, - ) - - @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, - pose_strength, pose_start, pose_end, - video_frame_offset, previous_frame_count, - reference_image=None, clip_vision_output=None, pose_video=None, - driving_mask_video=None, ref_mask_image=None, previous_frames=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - noise_mask = None - - # Auto-detect mode from driving_mask_video bg color. White bg => Replacement, else Animation. - replacement_mode = driving_mask_video is not None and driving_mask_video[0, ..., :3].mean().item() > 0.5 - ref_mask_flag = torch.tensor([not replacement_mode], dtype=torch.bool, device=latent.device) - positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag}) - negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag}) - - prev_trimmed = None - if previous_frames is not None and previous_frames.shape[0] > 0: - prev_trimmed = previous_frames[-previous_frame_count:] - video_frame_offset -= prev_trimmed.shape[0] - video_frame_offset = max(0, video_frame_offset) - - ref_latent = None - if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - # Replacement Mode: composite ref on black bg using ref_mask_image as alpha matte - # (matches the pre-composited examples that ship with SCAIL-2). Pixels where the - # mask is non-black (max channel > 0.1) are kept; bg pixels go to black. - if replacement_mode and ref_mask_image is not None: - rm = comfy.utils.common_upscale(ref_mask_image[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) - is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) - reference_image = reference_image * is_char - ref_latent = vae.encode(reference_image[:, :, :, :3]) - - if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - if pose_video is not None: - if pose_video.shape[0] <= video_frame_offset: - pose_video = None - else: - pose_video = pose_video[video_frame_offset:] - if driving_mask_video is not None: - if driving_mask_video.shape[0] <= video_frame_offset: - driving_mask_video = None - else: - driving_mask_video = driving_mask_video[video_frame_offset:] - - if pose_video is not None: - pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) - pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength - positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - - if driving_mask_video is not None: - mask_video_hw = comfy.utils.common_upscale(driving_mask_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) - sam_28ch = _extract_mask_to_28ch(mask_video_hw) - positive = node_helpers.conditioning_set_values(positive, {"sam_28ch": sam_28ch}) - negative = node_helpers.conditioning_set_values(negative, {"sam_28ch": sam_28ch}) - - if ref_mask_image is not None: - ref_mask_hw = comfy.utils.common_upscale(ref_mask_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - ref_sam_1f = _extract_mask_to_28ch(ref_mask_hw) - T_lat = ((length - 1) // 4) + 1 - zeros = torch.zeros((1, T_lat, 28, ref_sam_1f.shape[-2], ref_sam_1f.shape[-1]), - device=ref_sam_1f.device, dtype=ref_sam_1f.dtype) - ref_sam_28ch = torch.cat([ref_sam_1f, zeros], dim=1) - positive = node_helpers.conditioning_set_values(positive, {"ref_sam_28ch": ref_sam_28ch}) - negative = node_helpers.conditioning_set_values(negative, {"ref_sam_28ch": ref_sam_28ch}) - - if prev_trimmed is not None: - pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - prev_latent = vae.encode(pf[:, :, :, :3]) - T_p_lat = min(prev_latent.shape[2], latent.shape[2]) - latent[:, :, :T_p_lat] = prev_latent[:, :, :T_p_lat].to(latent.dtype) - noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), - device=latent.device, dtype=latent.dtype) - noise_mask[:, :, :T_p_lat] = 0.0 - - out_latent = {"samples": latent} - if noise_mask is not None: - out_latent["noise_mask"] = noise_mask - return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length) - - class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1631,7 +1476,6 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, - WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/nodes.py b/nodes.py index 09190de60..4bf768045 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,7 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", - "nodes_scail2.py", + "nodes_scail.py", "nodes_void.py", "nodes_wandancer.py", "nodes_hidream_o1.py",