diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 282408891..1c9782a38 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1665,7 +1665,7 @@ class SCAILWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) - if ref_mask_latents is not None: # SCAIL-2 additive mask stream + if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video) x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes @@ -1728,22 +1728,25 @@ class SCAILWanModel(WanModel): # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + # reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references. def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + if ref_mask_flag is not None and not bool(ref_mask_flag): REF_ROPE_H = 120.0 POSE_ROPE_W = 120.0 - ref_t_patches = 0 - if reference_latent is not None: - ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] main_t_patches = t - ref_t_patches + video_t_start = max(ref_t_patches - 1, 0) parts = [] if ref_t_patches > 0: ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) if main_t_patches > 0: - parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options)) if pose_latents is not None: F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] @@ -1752,7 +1755,7 @@ class SCAILWanModel(WanModel): h_shift = (h_scale - 1) / 2 w_shift = (w_scale - 1) / 2 pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} - parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf)) return torch.cat(parts, dim=1) @@ -1761,10 +1764,6 @@ class SCAILWanModel(WanModel): if pose_latents is None: return main_freqs - ref_t_patches = 0 - if reference_latent is not None: - ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] - F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames diff --git a/comfy/model_base.py b/comfy/model_base.py index ab4a11022..d143dc06f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1747,10 +1747,14 @@ class WAN21_SCAIL(WAN21): reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: - ref_latent = self.process_latent_in(reference_latents[-1]) - ref_mask = torch.ones_like(ref_latent[:, :4]) - ref_latent = torch.cat([ref_latent, ref_mask], dim=1) - out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + # SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional + # references. Stack as [additional..., primary] so the primary stays adjacent to the video. + ordered = list(reference_latents[1:]) + list(reference_latents[:1]) + stacked = [] + for lat in ordered: + lat = self.process_latent_in(lat) + stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1)) + out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2)) pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: @@ -1792,6 +1796,7 @@ class WAN21_SCAIL2(WAN21_SCAIL): if driving_mask_28ch is not None: out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) + # ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first, then the primary ref), followed by zeros over the video frames. ref_mask_28ch = kwargs.get("ref_mask_28ch", None) if ref_mask_28ch is not None: out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) @@ -1819,10 +1824,11 @@ class WAN21_SCAIL2(WAN21_SCAIL): # Return sliced view omitting retain_index_list return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0) if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - # The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows + # The ref mask is N leading ref frames padded with frames of zeros, so just grab the first frames for all windows full_ref_mask = cond_value.cond video_frame_count = x_in.shape[2] - if full_ref_mask.shape[2] != video_frame_count + 1: + ref_frame_count = full_ref_mask.shape[2] - video_frame_count + if ref_frame_count < 1: return None window_length = len(window.index_list) @@ -1831,7 +1837,7 @@ class WAN21_SCAIL2(WAN21_SCAIL): if anchor_index is not None and anchor_index >= 0: window_length += 1 - window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device) + window_ref_mask = full_ref_mask[:, :, :window_length + ref_frame_count].to(device) return cond_value._copy_with(window_ref_mask) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py index 007733efc..55c9897e3 100644 --- a/comfy_extras/nodes_scail.py +++ b/comfy_extras/nodes_scail.py @@ -34,14 +34,20 @@ def _unpack(track_data): return unpack_masks(packed) -def _first_frame_cx_area(masks_bool): - first = masks_bool[0].float() - H, W = first.shape[-2], first.shape[-1] - n_pixels = H * W - grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) - area = first.sum(dim=(-1, -2)).clamp_(min=1) - cx = (first * grid_x).sum(dim=(-1, -2)) / area - return (cx / W).tolist(), (area / n_pixels).tolist() +def _first_appearance_cx_area(masks_bool): + """Per object: first frame it appears in, plus centroid-x and area in that frame.""" + m = masks_bool.float() + T, H, W = m.shape[0], m.shape[-2], m.shape[-1] + grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W) + area_t = m.sum(dim=(-1, -2)) + cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1) + present = area_t > 0 + frame_idx = torch.arange(T, device=m.device).unsqueeze(1) + first_t = torch.where(present, frame_idx, T).amin(dim=0) + sel = first_t.clamp(max=T - 1).unsqueeze(0) + cx = cx_t.gather(0, sel).squeeze(0) + area = area_t.gather(0, sel).squeeze(0) + return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist() def _subset_track_data(track_data, obj_indices): @@ -81,12 +87,26 @@ def _render_colored_masks(track_data, background="black"): masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" ).view(T, N_obj, H, W) > 0.5 any_mask = masks_full.any(dim=1) - obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) - color_overlay = colors[obj_idx_map] + color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)] bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) +def _render_mask_as_identity(mask, background="black"): + """Plain comfy MASK (B,H,W) or (H,W) -> (B,H,W,3) rendered as a single identity (palette[0]) + on the given background. A batch is treated as multiple views of that one subject.""" + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + if mask.ndim == 2: + mask = mask.unsqueeze(0) + mask = mask.to(device=device, dtype=dtype) + B, H, W = mask.shape + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) + color = torch.tensor(DEFAULT_PALETTE[0], device=device, dtype=dtype).view(1, 1, 1, 3) + bg = torch.tensor(bg_rgb, device=device, dtype=dtype).view(1, 1, 1, 3) + return torch.where((mask > 0.5).unsqueeze(-1), color.expand(B, H, W, 3), bg.expand(B, H, W, 3)) + + def _extract_mask_to_28ch(rgb_video): """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) @@ -138,8 +158,8 @@ class WanSCAILToVideo(io.ComfyNode): io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), - io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), - io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."), io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), @@ -171,19 +191,21 @@ class WanSCAILToVideo(io.ComfyNode): video_frame_offset -= prev_trimmed.shape[0] video_frame_offset = max(0, video_frame_offset) - ref_latent = None if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) - # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte - if replacement_mode and reference_image_mask is not None: - rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) - is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) - reference_image = reference_image * is_char - ref_latent = vae.encode(reference_image[:, :, :, :3]) + ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + n_ref = ref_imgs.shape[0] + # SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references. - if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + # Replacement Mode: composite each ref on black bg using its mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]] + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype) + ref_imgs = ref_imgs * is_char + # encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video) + ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)] + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True) if clip_vision_output is not None: positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) @@ -221,11 +243,16 @@ class WanSCAILToVideo(io.ComfyNode): positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) - if reference_image_mask is not None: - ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) - ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + # The ref mask binds reference frames to identities, so it only applies when there's a reference image. + if reference_image_mask is not None and reference_image is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + n_masks = ref_mask_hw.shape[0] + n_ref = reference_image.shape[0] + + add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)] + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1]) zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) - ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1) positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) @@ -244,12 +271,9 @@ class WanSCAILToVideo(io.ComfyNode): class SCAIL2ColoredMask(io.ComfyNode): - """Render SAM3 tracks for the driving pose video and (optionally) the reference - image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` - across both outputs guarantees identity K maps to the same color on both - sides, for multi-person workflow consistency. - reference_image_mask is always rendered black-bg (model convention) - pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + """Render SAM3 tracks for the driving pose video and reference image(s) into the + colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the + same color across both outputs. """ @classmethod @@ -260,10 +284,12 @@ class SCAIL2ColoredMask(io.ComfyNode): category="model/conditioning/wan/scail", inputs=[ SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."), - SAM3TrackData.Input("ref_track_data", optional=True, tooltip="SAM3 track of the reference image."), - io.String.Input("object_indices", default="", tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), + io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks", + tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", - tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), + tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."), io.Boolean.Input("replacement_mode", default=False, tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). " "True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."), @@ -280,11 +306,11 @@ class SCAIL2ColoredMask(io.ComfyNode): def _prep(td): masks_bool = _unpack(td) if sort_by != "none" and masks_bool is not None: - cx, area = _first_frame_cx_area(masks_bool) + first_t, cx, area = _first_appearance_cx_area(masks_bool) if sort_by == "left_to_right": - order = sorted(range(len(cx)), key=lambda i: cx[i]) + order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i])) else: # "area" - order = sorted(range(len(area)), key=lambda i: -area[i]) + order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i])) td = _subset_track_data(td, order) if object_indices.strip(): indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] @@ -300,8 +326,10 @@ class SCAIL2ColoredMask(io.ComfyNode): ref_bg = "black" if replacement_mode else "white" if ref_track_data is not None: - ref = _prep(ref_track_data) - reference_image_mask = _render_colored_masks(ref, ref_bg) + if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK + reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg) + else: + reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg) else: H, W = drv["orig_size"] fill_value = 1.0 if ref_bg == "white" else 0.0