diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 939f3303e..9178b3344 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel): self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs): if reference_latent is not None: x = torch.cat((reference_latent, x), dim=2) # embeddings x = self.patch_embedding(x.float()).to(x.dtype) + if ref_mask_latents is not None: # SCAIL-2 additive mask stream + x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) @@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel): scail_pose_seq_len = 0 if pose_latents is not None: scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + if sam_latents is not None: # SCAIL-2 additive mask stream + scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype) scail_x = scail_x.flatten(2).transpose(1, 2) scail_pose_seq_len = scail_x.shape[1] x = torch.cat([x, scail_x], dim=1) @@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel): return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, + # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + if ref_mask_flag is not None and not bool(ref_mask_flag): + REF_ROPE_H = 120.0 + POSE_ROPE_W = 120.0 + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + main_t_patches = t - ref_t_patches + + parts = [] + if ref_t_patches > 0: + ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} + parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) + if main_t_patches > 0: + parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + + if pose_latents is not None: + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + h_scale = h / H_pose + w_scale = w / W_pose + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) + + return torch.cat(parts, dim=1) + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) if pose_latents is None: @@ -1719,138 +1752,15 @@ class SCAILWanModel(WanModel): return torch.cat([main_freqs, pose_freqs], dim=1) - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): - bs, c, t, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) - - if pose_latents is not None: - pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) - - t_len = t - if time_dim_concat is not None: - time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) - x = torch.cat([x, time_dim_concat], dim=2) - t_len = x.shape[2] - - reference_latent = None - if "reference_latent" in kwargs: - reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) - t_len += reference_latent.shape[2] - - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] - - -class SCAIL2WanModel(SCAILWanModel): - """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" - - def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs): - super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) - self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs): - if reference_latent is not None: - x = torch.cat((reference_latent, x), dim=2) - - x = self.patch_embedding(x.float()).to(x.dtype) - if ref_mask_latents is not None: - x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) - grid_sizes = x.shape[2:] - transformer_options["grid_sizes"] = grid_sizes - x = x.flatten(2).transpose(1, 2) - - scail_pose_seq_len = 0 - if pose_latents is not None: - scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) - if sam_latents is not None: - scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype) - scail_x = scail_x.flatten(2).transpose(1, 2) - scail_pose_seq_len = scail_x.shape[1] - x = torch.cat([x, scail_x], dim=1) - del scail_x - - e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) - e = e.reshape(t.shape[0], -1, e.shape[-1]) - e0 = self.time_projection(e).unflatten(2, (6, self.dim)) - - context = self.text_embedding(context) - - context_img_len = None - if clip_fea is not None: - if self.img_emb is not None: - context_clip = self.img_emb(clip_fea) - context = torch.cat([context_clip, context], dim=1) - context_img_len = clip_fea.shape[-2] - - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - transformer_options["total_blocks"] = len(self.blocks) - transformer_options["block_type"] = "double" - for i, block in enumerate(self.blocks): - transformer_options["block_index"] = i - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) - return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) - x = out["img"] - else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) - - x = self.head(x, e) - - if scail_pose_seq_len > 0: - x = x[:, :-scail_pose_seq_len] - - x = self.unpatchify(x, grid_sizes) - - if reference_latent is not None: - x = x[:, :, reference_latent.shape[2]:] - - return x - - # ref_mask_flag is a scalar bool (CONDConstant); the mode is uniform across the batch. - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): - is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag) - if not is_replacement: - return super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, pose_latents=pose_latents, reference_latent=reference_latent, transformer_options=transformer_options) - - REF_ROPE_H = 120.0 - POSE_ROPE_W = 120.0 - - ref_t_patches = 0 - if reference_latent is not None: - ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] - main_t_patches = t - ref_t_patches - - parts = [] - if ref_t_patches > 0: - ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} - parts.append(super(SCAILWanModel, self).rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) - if main_t_patches > 0: - parts.append(super(SCAILWanModel, self).rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) - - if pose_latents is not None: - F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] - h_scale = h / H_pose - w_scale = w / W_pose - h_shift = (h_scale - 1) / 2 - w_shift = (w_scale - 1) / 2 - pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} - parts.append(super(SCAILWanModel, self).rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) - - return torch.cat(parts, dim=1) - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) if pose_latents is not None: pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) - if ref_mask_latents is not None: + if ref_mask_latents is not None: # SCAIL-2 ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size) - if sam_latents is not None: + if sam_latents is not None: # SCAIL-2 sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size) t_len = t @@ -1864,7 +1774,15 @@ class SCAIL2WanModel(SCAILWanModel): reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) t_len += reference_latent.shape[2] - ref_mask_flag = kwargs.pop("ref_mask_flag", None) + ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2 freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w] + + +class SCAIL2WanModel(SCAILWanModel): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py index afb27f566..acb64ad26 100644 --- a/comfy_extras/nodes_scail.py +++ b/comfy_extras/nodes_scail.py @@ -11,7 +11,7 @@ import node_helpers import comfy.model_management import comfy.utils from comfy_api.latest import ComfyExtension, io - +from comfy.ldm.sam3.tracker import unpack_masks SAM3TrackData = io.Custom("SAM3_TRACK_DATA") @@ -28,7 +28,6 @@ DEFAULT_PALETTE = [ def _unpack(track_data): - from comfy.ldm.sam3.tracker import unpack_masks packed = track_data["packed_masks"] if packed is None or packed.shape[1] == 0: return None @@ -45,16 +44,6 @@ def _first_frame_cx_area(masks_bool): return (cx / W).tolist(), (area / n_pixels).tolist() -def _sort_tracks(track_data, sort_by): - masks_bool = _unpack(track_data) - if masks_bool is None: - return [] - cx, area = _first_frame_cx_area(masks_bool) - if sort_by == "x": - return sorted(range(len(cx)), key=lambda i: cx[i]) - return sorted(range(len(area)), key=lambda i: -area[i]) # "area" - - def _subset_track_data(track_data, obj_indices): out = dict(track_data) packed = track_data["packed_masks"] @@ -70,19 +59,12 @@ def _subset_track_data(track_data, obj_indices): return out -def _bg_to_rgb(background): - if background.startswith("white"): - return (1.0, 1.0, 1.0) - return (0.0, 0.0, 0.0) - - def _render_colored_masks(track_data, background="black"): - from comfy.ldm.sam3.tracker import unpack_masks packed = track_data["packed_masks"] H, W = track_data["orig_size"] device = comfy.model_management.intermediate_device() dtype = comfy.model_management.intermediate_dtype() - bg_rgb = _bg_to_rgb(background) + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) if packed is None or packed.shape[1] == 0: T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] out = torch.empty(T, H, W, 3, device=device, dtype=dtype) @@ -277,19 +259,19 @@ class SCAIL2ColoredMask(io.ComfyNode): display_name="SCAIL-2 Colored Mask", category="conditioning/video_models/scail", inputs=[ - SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the driving_mask_video output."), + SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the pose_video_mask output."), SAM3TrackData.Input("ref_track_data", optional=True, - tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output."), + tooltip="SAM3 track of the reference image. Optional — wire it for the reference_image_mask output."), io.String.Input("object_indices", default="", tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."), - io.Combo.Input("sort_by", options=["none", "x", "area"], - tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order."), + io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", + tooltip="Applied to both sides identically so that color index order matches. left_to_right = by first-frame centroid; area = descending mask area; none = SAM3's original order."), io.Boolean.Input("replacement_mode", default=False, - tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. ref_mask_image is always black-bg regardless."), + tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."), ], outputs=[ - io.Image.Output("driving_mask_video"), - io.Image.Output("ref_mask_image"), + io.Image.Output("pose_video_mask"), + io.Image.Output("reference_image_mask"), ], is_experimental=True, ) @@ -297,8 +279,14 @@ class SCAIL2ColoredMask(io.ComfyNode): @classmethod def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): def _prep(td): - if sort_by != "none": - td = _subset_track_data(td, _sort_tracks(td, sort_by)) + masks_bool = _unpack(td) + if sort_by != "none" and masks_bool is not None: + cx, area = _first_frame_cx_area(masks_bool) + if sort_by == "left_to_right": + order = sorted(range(len(cx)), key=lambda i: cx[i]) + else: # "area" + order = sorted(range(len(area)), key=lambda i: -area[i]) + td = _subset_track_data(td, order) if object_indices.strip(): indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] packed = td.get("packed_masks") @@ -312,12 +300,12 @@ class SCAIL2ColoredMask(io.ComfyNode): if ref_track_data is not None: ref = _prep(ref_track_data) - ref_mask_image = _render_colored_masks(ref, "black") + reference_image_mask = _render_colored_masks(ref, "black") else: H, W = drv["orig_size"] - ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - return io.NodeOutput(mask_video, ref_mask_image) + return io.NodeOutput(mask_video, reference_image_mask) class SCAILExtension(ComfyExtension):