initial SCAIL2 support

This commit is contained in:
kijai 2026-06-09 01:02:50 +03:00
parent a0a055bc4e
commit 01c4fa4c74
7 changed files with 498 additions and 7 deletions

View File

@ -1739,3 +1739,132 @@ class SCAILWanModel(WanModel):
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
class SCAIL2WanModel(SCAILWanModel):
    """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.

    On top of the parent model this class owns one extra patch embedding,
    ``patch_embedding_mask``. Its output is added to the main token stream
    (for ``ref_mask_latents``) and to the pose token stream (for
    ``sam_latents``).
    """

    def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
        super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
        # Built with an explicit float32 dtype (not `dtype`); its inputs are
        # cast with .float() at the call sites below.
        self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)

    def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options=None, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
        """Run the transformer stack on already padded inputs.

        Args:
            x: main latent; ``reference_latent`` (if given) is concatenated in
                front of it along dim 2 before patch embedding.
            t: timestep tensor fed to the sinusoidal time embedding.
            context: text context, passed through ``text_embedding``.
            clip_fea: optional CLIP features, embedded with ``img_emb`` when
                that module exists.
            freqs: rotary frequencies forwarded to every block.
            transformer_options: dict that receives ``grid_sizes``,
                ``total_blocks``, ``block_type`` and ``block_index`` and may
                carry ``patches_replace``. ``None`` means a fresh dict.
            pose_latents: optional pose stream, embedded and appended as extra
                tokens that are dropped again after the head.
            reference_latent: optional reference latent (see ``x``).
            ref_mask_latents: optional mask latent added to the embedded main
                stream; it must match that stream's embedded shape.
            sam_latents: optional mask latent added to the embedded pose
                stream; only used when ``pose_latents`` is given.

        Returns:
            The unpatchified output with the reference frames removed.
        """
        # This method writes bookkeeping keys into transformer_options, so a
        # shared `{}` default would leak state from one call into the next.
        if transformer_options is None:
            transformer_options = {}

        if reference_latent is not None:
            x = torch.cat((reference_latent, x), dim=2)

        # Main-stream embedding, plus the additive mask stream when present.
        x = self.patch_embedding(x.float()).to(x.dtype)
        if ref_mask_latents is not None:
            x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        transformer_options["grid_sizes"] = grid_sizes
        x = x.flatten(2).transpose(1, 2)

        # Pose tokens are appended after the main tokens; remember how many so
        # they can be stripped once the blocks have run.
        scail_pose_seq_len = 0
        if pose_latents is not None:
            scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
            if sam_latents is not None:
                scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
            scail_x = scail_x.flatten(2).transpose(1, 2)
            scail_pose_seq_len = scail_x.shape[1]
            x = torch.cat([x, scail_x], dim=1)
            del scail_x

        # Time embedding.
        e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
        e = e.reshape(t.shape[0], -1, e.shape[-1])
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))

        # Text (and optional CLIP image) context.
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)
                context = torch.cat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                # A replacement patch gets the original block wrapped as a
                # callable; the wrapper is invoked within this same iteration.
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)

        x = self.head(x, e)

        # Drop the appended pose tokens before unpatchifying.
        if scail_pose_seq_len > 0:
            x = x[:, :-scail_pose_seq_len]

        x = self.unpatchify(x, grid_sizes)

        # Drop the reference frames that were concatenated in front.
        if reference_latent is not None:
            x = x[:, :, reference_latent.shape[2]:]
        return x

    # Reads the first element of ref_mask_flag and assumes a uniform mode across the batch.
    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options=None):
        """Build rotary frequencies for the reference, main and pose segments.

        When ``ref_mask_flag`` is ``None`` or its first element is truthy the
        parent implementation is used unchanged. Otherwise (replacement mode)
        the three segments are encoded separately with the grandparent
        ``rope_encode`` and concatenated along dim 1:

        * reference frames: ``shift_y`` of ``REF_ROPE_H``;
        * main frames: the caller's ``transformer_options``;
        * pose frames: scaled to the main grid and shifted in x by
          ``POSE_ROPE_W``.

        Note that ``t_start`` and ``steps_*`` are only honoured on the
        non-replacement path.
        """
        if transformer_options is None:
            transformer_options = {}

        is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag.flatten()[0].item())
        if not is_replacement:
            return super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, pose_latents=pose_latents, reference_latent=reference_latent, transformer_options=transformer_options)

        REF_ROPE_H = 120.0
        POSE_ROPE_W = 120.0

        ref_t_patches = 0
        if reference_latent is not None:
            ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
        # `t` counts reference + main frames; the remainder is the main video.
        main_t_patches = t - ref_t_patches

        parts = []
        if ref_t_patches > 0:
            ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
            parts.append(super(SCAILWanModel, self).rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
        if main_t_patches > 0:
            parts.append(super(SCAILWanModel, self).rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
        if pose_latents is not None:
            F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
            h_scale = h / H_pose
            w_scale = w / W_pose
            h_shift = (h_scale - 1) / 2
            w_shift = (w_scale - 1) / 2
            pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
            parts.append(super(SCAILWanModel, self).rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
        return torch.cat(parts, dim=1)

    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options=None, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
        """Pad all inputs to the patch size, build freqs and call ``forward_orig``.

        ``reference_latent`` and ``ref_mask_flag`` are taken out of ``kwargs``;
        everything left in ``kwargs`` is forwarded to ``forward_orig``. The
        result is cropped back to the unpadded ``t``, ``h``, ``w`` of ``x``.
        """
        # forward_orig mutates this dict, so never share a default instance.
        if transformer_options is None:
            transformer_options = {}

        _, _, t, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        if pose_latents is not None:
            pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
        if ref_mask_latents is not None:
            ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
        if sam_latents is not None:
            sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)

        t_len = t
        if time_dim_concat is not None:
            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
            x = torch.cat([x, time_dim_concat], dim=2)
            t_len = x.shape[2]

        reference_latent = None
        if "reference_latent" in kwargs:
            reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
            # The rope sequence covers the reference frames as well.
            t_len += reference_latent.shape[2]

        ref_mask_flag = kwargs.pop("ref_mask_flag", None)
        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]

View File

@ -1766,6 +1766,83 @@ class WAN21_SCAIL(WAN21):
return out return out
class WAN21_SCAIL2(WAN21_SCAIL):
    """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # Skip WAN21.__init__ on purpose so the UNet class can be SCAIL2WanModel.
        super(WAN21, self).__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.wan.model.SCAIL2WanModel,
        )

        def _memory_shape(shape):
            # Same batch/channel/spatial entries, third entry fixed at 1.5.
            return [shape[0], shape[1], 1.5, shape[-2], shape[-1]]

        self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
        self.memory_usage_shape_process = {
            "pose_latents": _memory_shape,
            "sam_latents": _memory_shape,
        }
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        """Extend the parent conds with the SCAIL-2 mask latents and mode flag."""
        out = super().extra_conds(**kwargs)

        # Both mask bundles get dims 1 and 2 swapped before being wrapped.
        for source_key, cond_key in (("sam_28ch", "sam_latents"), ("ref_sam_28ch", "ref_mask_latents")):
            value = kwargs.get(source_key, None)
            if value is None:
                continue
            out[cond_key] = comfy.conds.CONDRegular(value.movedim(1, 2).contiguous())

        flag = kwargs.get("ref_mask_flag", None)
        if flag is not None:
            out['ref_mask_flag'] = comfy.conds.CONDRegular(flag)
        return out

    def extra_conds_shapes(self, **kwargs):
        """Report the shapes the mask conds will have after ``extra_conds``."""
        out = super().extra_conds_shapes(**kwargs)
        for source_key, cond_key in (("sam_28ch", "sam_latents"), ("ref_sam_28ch", "ref_mask_latents")):
            value = kwargs.get(source_key, None)
            if value is None:
                continue
            dims = value.shape
            out[cond_key] = [dims[0], 28, dims[1], dims[3], dims[4]]
        return out

    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
        """Slice the per-frame SCAIL conds for a context window; defer the rest."""
        if cond_key not in ("sam_latents", "pose_latents"):
            return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)

    def concat_cond(self, **kwargs):
        """Return the 4-channel concat mask, or defer to the parent."""
        # Override base path that short-circuits to 4 zeros when image_to_video=False
        # and extra_channels == image.shape[1]: history needs the mask channels to be 1
        # at anchor slots.
        noise = kwargs.get("noise", None)
        embed_in_channels = self.diffusion_model.patch_embedding.weight.shape[1]
        if embed_in_channels - noise.shape[1] != 4:
            return super().concat_cond(**kwargs)

        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            return torch.zeros_like(noise)[:, :4]

        device = kwargs["device"]
        if mask.shape[1] != 4:
            mask = torch.mean(mask, dim=1, keepdim=True)
        # Invert before resizing to the noise resolution.
        mask = utils.common_upscale((1.0 - mask).to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        missing_frames = noise.shape[-3] - mask.shape[-3]
        if missing_frames > 0:
            # Zero-pad at the end of the frame axis.
            mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, missing_frames), mode='constant', value=0)
        if mask.shape[1] == 1:
            mask = mask.repeat(1, 4, 1, 1, 1)
        return utils.resize_to_batch_size(mask, noise.shape[0])

    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
        # Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
        return latent_image
class WAN22_WanDancer(WAN21): class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)

View File

@ -680,6 +680,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo" dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate" dit_config["model_type"] = "animate"
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail2"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail" dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys: elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:

View File

@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out return out
class WAN21_SCAIL2(WAN21_T2V):
    """Supported-model entry matched by ``model_type == "scail2"``."""

    unet_config = {"image_model": "wan2.1", "model_type": "scail2"}

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SCAIL-2 base model with image_to_video disabled."""
        return model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
class WAN22_WanDancer(WAN21_T2V): class WAN22_WanDancer(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -2287,6 +2298,7 @@ models = [
WAN22_Animate, WAN22_Animate,
WAN21_FlowRVS, WAN21_FlowRVS,
WAN21_SCAIL, WAN21_SCAIL,
WAN21_SCAIL2,
WAN22_WanDancer, WAN22_WanDancer,
Hunyuan3Dv2mini, Hunyuan3Dv2mini,
Hunyuan3Dv2, Hunyuan3Dv2,

View File

@ -0,0 +1,172 @@
"""SCAIL-2 preprocessing nodes that turn SAM3 video tracks into the conditioning
bundle the SCAIL-2 model consumes."""
from typing_extensions import override
import torch
import torch.nn.functional as F
import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, io
# Custom socket type used by the track-data inputs of the node defined below.
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# Model was trained on these exact colors; deviating degrades multi-identity quality.
# Entries are (R, G, B) in [0, 1]; object index i is drawn with entry i modulo
# the palette length (see _render_colored_masks).
DEFAULT_PALETTE = [
    (0.0, 0.0, 1.0), # Blue
    (1.0, 0.0, 0.0), # Red
    (0.0, 1.0, 0.0), # Green
    (1.0, 0.0, 1.0), # Magenta
    (0.0, 1.0, 1.0), # Cyan
    (1.0, 1.0, 0.0), # Yellow
]
def _unpack(track_data):
    """Unpack ``track_data["packed_masks"]``; None when it is missing or has no objects."""
    from comfy.ldm.sam3.tracker import unpack_masks

    packed = track_data["packed_masks"]
    has_objects = packed is not None and packed.shape[1] != 0
    return unpack_masks(packed) if has_objects else None
def _first_frame_cx_area(masks_bool):
first = masks_bool[0].float()
H, W = first.shape[-2], first.shape[-1]
n_pixels = H * W
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
area = first.sum(dim=(-1, -2)).clamp_(min=1)
cx = (first * grid_x).sum(dim=(-1, -2)) / area
return (cx / W).tolist(), (area / n_pixels).tolist()
def _sort_tracks(track_data, sort_by):
    """Return object indices ordered by first-frame position or size.

    ``sort_by == "x"`` orders by ascending centroid x; any other value orders
    by descending area. An empty list is returned when there are no masks.
    """
    masks = _unpack(track_data)
    if masks is None:
        return []

    cx, area = _first_frame_cx_area(masks)
    if sort_by == "x":
        sort_keys = cx
    else:  # "area"
        sort_keys = [-a for a in area]
    return sorted(range(len(sort_keys)), key=sort_keys.__getitem__)
def _subset_track_data(track_data, obj_indices):
out = dict(track_data)
packed = track_data["packed_masks"]
if packed is None or not obj_indices:
out["packed_masks"] = None
if "scores" in out:
out["scores"] = []
return out
out["packed_masks"] = packed[:, obj_indices].contiguous()
scores = track_data.get("scores")
if scores is not None:
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
return out
def _bg_to_rgb(background):
if background.startswith("white"):
return (1.0, 1.0, 1.0)
return (0.0, 0.0, 0.0)
def _render_colored_masks(track_data, background="black"):
    """Render the tracked objects as a colored (T, H, W, 3) image sequence.

    Each object is painted with its ``DEFAULT_PALETTE`` color (index modulo
    the palette length) at ``track_data["orig_size"]`` resolution. Where
    objects overlap, the lowest object index wins. Pixels covered by no object
    take the background color. With no objects the result is a plain
    background sequence.
    """
    from comfy.ldm.sam3.tracker import unpack_masks

    packed = track_data["packed_masks"]
    H, W = track_data["orig_size"]
    device = comfy.model_management.intermediate_device()
    bg_rgb = _bg_to_rgb(background)

    if packed is None or packed.shape[1] == 0:
        # No objects: the frame count comes from the tensor when there is one,
        # otherwise from the optional "n_frames" entry.
        if packed is None:
            n_frames = track_data.get("n_frames", 1)
        else:
            n_frames = packed.shape[0]
        canvas = torch.empty(n_frames, H, W, 3, device=device)
        for channel, level in enumerate(bg_rgb):
            canvas[..., channel] = level
        return canvas

    n_frames, n_objects = packed.shape[0], packed.shape[1]
    palette = torch.tensor(
        [DEFAULT_PALETTE[k % len(DEFAULT_PALETTE)] for k in range(n_objects)],
        device=device, dtype=torch.float32,
    )

    masks = unpack_masks(packed.to(device)).float()
    src_h, src_w = masks.shape[-2], masks.shape[-1]
    flat = masks.view(n_frames * n_objects, 1, src_h, src_w)
    resized = F.interpolate(flat, size=(H, W), mode="nearest")
    present = resized.view(n_frames, n_objects, H, W) > 0.5

    covered = present.any(dim=1)
    # argmax over 0/1 values picks the first object that covers each pixel.
    winner = present.to(torch.uint8).argmax(dim=1)
    overlay = palette[winner]
    bg_pixel = torch.tensor(bg_rgb, device=device, dtype=overlay.dtype).view(1, 1, 1, 3)
    return torch.where(covered.unsqueeze(-1), overlay, bg_pixel.expand_as(overlay))
class SCAIL2ColoredMask(io.ComfyNode):
    """Render SAM3 tracks for the driving video and (optionally) the reference
    image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
    across both outputs guarantees identity K maps to the same color on both
    sides, so multi-person workflows stay consistent without a separate
    alignment node. ref_mask is always rendered black-bg (model convention);
    mask_video bg follows the mode you'll set on WanSCAILToVideo."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            SAM3TrackData.Input("driving_track_data"),
            SAM3TrackData.Input(
                "ref_track_data",
                optional=True,
                tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output.",
            ),
            io.String.Input(
                "object_indices",
                default="",
                tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all.",
            ),
            io.Combo.Input(
                "sort_by",
                options=["none", "x", "area"],
                tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order.",
            ),
            io.Boolean.Input(
                "replacement_mode",
                default=False,
                tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). WanSCAILToVideo auto-detects mode from the wired mask_video's bg color, so this is the single source of truth. ref_mask_image is always black-bg regardless.",
            ),
        ]
        node_outputs = [
            io.Image.Output("driving_mask_video"),
            io.Image.Output("ref_mask_image"),
        ]
        return io.Schema(
            node_id="SCAIL2ColoredMask",
            display_name="SCAIL-2 Colored Mask",
            category="conditioning/video_models/scail",
            inputs=node_inputs,
            outputs=node_outputs,
            is_experimental=True,
        )

    @classmethod
    def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
        # The index filter is the same text for both sides, so parse it once.
        # Tokens that are not plain digits are dropped; None means "keep all".
        requested = None
        if object_indices.strip():
            requested = [int(tok.strip()) for tok in object_indices.split(",") if tok.strip().isdigit()]

        def _prep(td):
            # Sort first, so the requested indices address the sorted order.
            if sort_by != "none":
                td = _subset_track_data(td, _sort_tracks(td, sort_by))
            if requested is None:
                return td
            packed = td.get("packed_masks")
            n_obj = 0 if packed is None else packed.shape[1]
            in_range = [i for i in requested if 0 <= i < n_obj]
            return _subset_track_data(td, in_range)

        drv = _prep(driving_track_data)
        driving_bg = "white" if replacement_mode else "black"
        mask_video = _render_colored_masks(drv, driving_bg)

        if ref_track_data is None:
            # No reference track: emit one black frame at the driving size.
            H, W = drv["orig_size"]
            ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device())
        else:
            ref_mask_image = _render_colored_masks(_prep(ref_track_data), "black")

        return io.NodeOutput(mask_video, ref_mask_image)
class SCAIL2Extension(ComfyExtension):
    """Extension exposing the SCAIL-2 preprocessing node defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes = [SCAIL2ColoredMask]
        return node_classes
async def comfy_entrypoint() -> SCAIL2Extension:
    """Create and return the SCAIL-2 extension instance."""
    extension = SCAIL2Extension()
    return extension

View File

@ -1456,6 +1456,37 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
T, H, W, _ = rgb_video.shape
_ON_THRESH = 225.0 / 255.0
mask = rgb_video.movedim(-1, 1).float()
R = (mask[:, 0:1] > _ON_THRESH).float()
G = (mask[:, 1:2] > _ON_THRESH).float()
B = (mask[:, 2:3] > _ON_THRESH).float()
nR, nG, nB = 1 - R, 1 - G, 1 - B
binary_7ch = torch.cat([
R * G * B, # white
R * nG * nB, # red
nR * G * nB, # green
nR * nG * B, # blue
R * G * nB, # yellow
R * nG * B, # magenta
nR * G * B, # cyan
], dim=1)
H_lat, W_lat = H, W
for _ in range(3):
H_lat = (H_lat + 1) // 2
W_lat = (W_lat + 1) // 2
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
T_latent = (T - 1) // 4 + 1
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
out = padded.view(T_latent, 28, H_lat, W_lat)
return out.unsqueeze(0)
class WanSCAILToVideo(io.ComfyNode): class WanSCAILToVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -1470,47 +1501,114 @@ class WanSCAILToVideo(io.ComfyNode):
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096), io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Image.Input("driving_mask_video", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video. Mode is auto-detected from bg color: black bg = Animation, white bg = Replacement."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
io.Image.Input("reference_image", optional=True),
io.Image.Input("ref_mask_image", optional=True, tooltip="SCAIL-2 only. Single-frame colored ref mask at the reference image's full resolution."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the inpainting anchor."),
], ],
outputs=[ outputs=[
io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"), io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
], ],
is_experimental=True, is_experimental=True,
) )
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size,
            pose_strength, pose_start, pose_end,
            video_frame_offset, previous_frame_count,
            reference_image=None, clip_vision_output=None, pose_video=None,
            driving_mask_video=None, ref_mask_image=None, previous_frames=None) -> io.NodeOutput:
    """Build SCAIL / SCAIL-2 conditioning and the starting latent for one chunk.

    Returns ``(positive, negative, latent, next_offset)``. ``latent`` carries a
    ``noise_mask`` only when ``previous_frames`` supplies anchor frames.
    ``next_offset`` is the adjusted ``video_frame_offset`` plus ``length``.
    """
    latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
    noise_mask = None

    # Auto-detect mode from driving_mask_video bg color. White bg => Replacement, else Animation.
    replacement_mode = driving_mask_video is not None and driving_mask_video[0, ..., :3].mean().item() > 0.5
    # The flag is True for Animation mode and False for Replacement mode.
    ref_mask_flag = torch.tensor([not replacement_mode], dtype=torch.bool, device=latent.device)
    positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
    negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})

    # The anchor frames overlap the previous chunk, so step the offset back by
    # that many frames (never below zero).
    prev_trimmed = None
    if previous_frames is not None and previous_frames.shape[0] > 0:
        prev_trimmed = previous_frames[-previous_frame_count:]
        video_frame_offset -= prev_trimmed.shape[0]
        video_frame_offset = max(0, video_frame_offset)

    ref_latent = None
    if reference_image is not None:
        reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        # Replacement Mode: composite ref on black bg using ref_mask_image as alpha matte
        # (matches the pre-composited examples that ship with SCAIL-2). Pixels where the
        # mask is non-black (max channel > 0.1) are kept; bg pixels go to black.
        if replacement_mode and ref_mask_image is not None:
            rm = comfy.utils.common_upscale(ref_mask_image[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
            is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
            reference_image = reference_image * is_char
        ref_latent = vae.encode(reference_image[:, :, :, :3])

    if ref_latent is not None:
        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)

    if clip_vision_output is not None:
        positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
        negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

    # Skip the frames earlier chunks already consumed; a video that ends
    # before the offset contributes nothing to this chunk.
    if pose_video is not None:
        if pose_video.shape[0] <= video_frame_offset:
            pose_video = None
        else:
            pose_video = pose_video[video_frame_offset:]
    if driving_mask_video is not None:
        if driving_mask_video.shape[0] <= video_frame_offset:
            driving_mask_video = None
        else:
            driving_mask_video = driving_mask_video[video_frame_offset:]

    if pose_video is not None:
        pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
        pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
        positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
        negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)

    if driving_mask_video is not None:
        # Same half-resolution resize as the pose video.
        mask_video_hw = comfy.utils.common_upscale(driving_mask_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
        sam_28ch = _extract_mask_to_28ch(mask_video_hw)
        positive = node_helpers.conditioning_set_values(positive, {"sam_28ch": sam_28ch})
        negative = node_helpers.conditioning_set_values(negative, {"sam_28ch": sam_28ch})

    # The ref mask bundle is one reference slot followed by T_lat zero slots,
    # i.e. T_lat + 1 slots. That only lines up with the model's main stream
    # when a reference latent is prepended to it, so without a reference image
    # the mask is skipped instead of producing a mismatched tensor.
    if ref_mask_image is not None and ref_latent is not None:
        ref_mask_hw = comfy.utils.common_upscale(ref_mask_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        ref_sam_1f = _extract_mask_to_28ch(ref_mask_hw)
        T_lat = ((length - 1) // 4) + 1
        zeros = torch.zeros((1, T_lat, 28, ref_sam_1f.shape[-2], ref_sam_1f.shape[-1]),
                            device=ref_sam_1f.device, dtype=ref_sam_1f.dtype)
        ref_sam_28ch = torch.cat([ref_sam_1f, zeros], dim=1)
        positive = node_helpers.conditioning_set_values(positive, {"ref_sam_28ch": ref_sam_28ch})
        negative = node_helpers.conditioning_set_values(negative, {"ref_sam_28ch": ref_sam_28ch})

    if prev_trimmed is not None:
        # Write the encoded anchor frames into the start of the latent and
        # mark those slots as fixed (mask 0) for the sampler.
        pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        prev_latent = vae.encode(pf[:, :, :, :3])
        T_p_lat = min(prev_latent.shape[2], latent.shape[2])
        latent[:, :, :T_p_lat] = prev_latent[:, :, :T_p_lat].to(latent.dtype)
        noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]),
                                device=latent.device, dtype=latent.dtype)
        noise_mask[:, :, :T_p_lat] = 0.0

    out_latent = {"samples": latent}
    if noise_mask is not None:
        out_latent["noise_mask"] = noise_mask
    return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
class WanExtension(ComfyExtension): class WanExtension(ComfyExtension):

View File

@ -2472,6 +2472,7 @@ async def init_builtin_extra_nodes():
"nodes_rtdetr.py", "nodes_rtdetr.py",
"nodes_frame_interpolation.py", "nodes_frame_interpolation.py",
"nodes_sam3.py", "nodes_sam3.py",
"nodes_scail2.py",
"nodes_void.py", "nodes_void.py",
"nodes_wandancer.py", "nodes_wandancer.py",
"nodes_hidream_o1.py", "nodes_hidream_o1.py",