ComfyUI/comfy_extras/nodes_scail.py
Jukka Seppänen a590d60bb1
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
feat: SCAIL-2 multireference (CORE-310) (#14509)
* SCAIl-2: support multiref
2026-06-17 16:21:23 +03:00

352 lines
20 KiB
Python

"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
from typing_extensions import override
import torch
import torch.nn.functional as F
import nodes
import node_helpers
import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.sam3.tracker import unpack_masks
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# Model was trained on these exact colors; deviating degrades multi-identity quality.
DEFAULT_PALETTE = [
(0.0, 0.0, 1.0), # Blue
(1.0, 0.0, 0.0), # Red
(0.0, 1.0, 0.0), # Green
(1.0, 0.0, 1.0), # Magenta
(0.0, 1.0, 1.0), # Cyan
(1.0, 1.0, 0.0), # Yellow
]
def _unpack(track_data):
packed = track_data["packed_masks"]
if packed is None or packed.shape[1] == 0:
return None
return unpack_masks(packed)
def _first_appearance_cx_area(masks_bool):
"""Per object: first frame it appears in, plus centroid-x and area in that frame."""
m = masks_bool.float()
T, H, W = m.shape[0], m.shape[-2], m.shape[-1]
grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W)
area_t = m.sum(dim=(-1, -2))
cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1)
present = area_t > 0
frame_idx = torch.arange(T, device=m.device).unsqueeze(1)
first_t = torch.where(present, frame_idx, T).amin(dim=0)
sel = first_t.clamp(max=T - 1).unsqueeze(0)
cx = cx_t.gather(0, sel).squeeze(0)
area = area_t.gather(0, sel).squeeze(0)
return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist()
def _subset_track_data(track_data, obj_indices):
out = dict(track_data)
packed = track_data["packed_masks"]
if packed is None or not obj_indices:
out["packed_masks"] = None
if "scores" in out:
out["scores"] = []
return out
out["packed_masks"] = packed[:, obj_indices].contiguous()
scores = track_data.get("scores")
if scores is not None:
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
return out
def _render_colored_masks(track_data, background="black"):
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
device = comfy.model_management.intermediate_device()
dtype = comfy.model_management.intermediate_dtype()
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
if packed is None or packed.shape[1] == 0:
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
return out
T, N_obj = packed.shape[0], packed.shape[1]
colors = torch.tensor(
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
device=device, dtype=dtype,
)
masks_full = unpack_masks(packed.to(device)).float()
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
masks_full = F.interpolate(
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
).view(T, N_obj, H, W) > 0.5
any_mask = masks_full.any(dim=1)
color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)]
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
def _render_mask_as_identity(mask, background="black"):
"""Plain comfy MASK (B,H,W) or (H,W) -> (B,H,W,3) rendered as a single identity (palette[0])
on the given background. A batch is treated as multiple views of that one subject."""
device = comfy.model_management.intermediate_device()
dtype = comfy.model_management.intermediate_dtype()
if mask.ndim == 2:
mask = mask.unsqueeze(0)
mask = mask.to(device=device, dtype=dtype)
B, H, W = mask.shape
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
color = torch.tensor(DEFAULT_PALETTE[0], device=device, dtype=dtype).view(1, 1, 1, 3)
bg = torch.tensor(bg_rgb, device=device, dtype=dtype).view(1, 1, 1, 3)
return torch.where((mask > 0.5).unsqueeze(-1), color.expand(B, H, W, 3), bg.expand(B, H, W, 3))
def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
T, H, W, _ = rgb_video.shape
_ON_THRESH = 225.0 / 255.0
mask = rgb_video.movedim(-1, 1).float()
R = (mask[:, 0:1] > _ON_THRESH).float()
G = (mask[:, 1:2] > _ON_THRESH).float()
B = (mask[:, 2:3] > _ON_THRESH).float()
nR, nG, nB = 1 - R, 1 - G, 1 - B
binary_7ch = torch.cat([
R * G * B, # white
R * nG * nB, # red
nR * G * nB, # green
nR * nG * B, # blue
R * G * nB, # yellow
R * nG * B, # magenta
nR * G * B, # cyan
], dim=1)
H_lat, W_lat = H, W
for _ in range(3):
H_lat = (H_lat + 1) // 2
W_lat = (W_lat + 1) // 2
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
T_latent = (T - 1) // 4 + 1
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
out = padded.view(T_latent, 28, H_lat, W_lat)
return out.unsqueeze(0)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="model/conditioning/wan/scail",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
noise_mask = None
ref_mask_flag = not replacement_mode
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
prev_trimmed = None
if previous_frames is not None and previous_frames.shape[0] > 0:
prev_trimmed = previous_frames[-previous_frame_count:]
video_frame_offset -= prev_trimmed.shape[0]
video_frame_offset = max(0, video_frame_offset)
if reference_image is not None:
ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
n_ref = ref_imgs.shape[0]
# SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references.
# Replacement Mode: composite each ref on black bg using its mask as alpha matte
if replacement_mode and reference_image_mask is not None:
rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]]
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype)
ref_imgs = ref_imgs * is_char
# encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video)
ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
if pose_video.shape[0] <= video_frame_offset:
pose_video = None
else:
pose_video = pose_video[video_frame_offset:]
if pose_video_mask is not None:
if pose_video_mask.shape[0] <= video_frame_offset:
pose_video_mask = None
else:
pose_video_mask = pose_video_mask[video_frame_offset:]
# Truncate pose+mask jointly to the shorter of the two, capped at length.
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
if ts:
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
if pose_video is not None:
pose_video = pose_video[:T_kept]
if pose_video_mask is not None:
pose_video_mask = pose_video_mask[:T_kept]
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
if pose_video_mask is not None:
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
# The ref mask binds reference frames to identities, so it only applies when there's a reference image.
if reference_image_mask is not None and reference_image is not None:
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
n_masks = ref_mask_hw.shape[0]
n_ref = reference_image.shape[0]
add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)]
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1])
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1)
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
if prev_trimmed is not None:
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
prev_latent = vae.encode(pf[:, :, :, :3])
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
noise_mask[:, :, :prev_latent_frames] = 0.0
out_latent = {"samples": latent}
if noise_mask is not None:
out_latent["noise_mask"] = noise_mask
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
class SCAIL2ColoredMask(io.ComfyNode):
"""Render SAM3 tracks for the driving pose video and reference image(s) into the
colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the
same color across both outputs.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SCAIL2ColoredMask",
display_name="Create SCAIL-2 Colored Mask",
category="model/conditioning/wan/scail",
inputs=[
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks",
tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."),
io.String.Input("object_indices", default="",
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False,
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
],
outputs=[
io.Image.Output("pose_video_mask"),
io.Image.Output("reference_image_mask"),
],
is_experimental=True,
)
@classmethod
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
def _prep(td):
masks_bool = _unpack(td)
if sort_by != "none" and masks_bool is not None:
first_t, cx, area = _first_appearance_cx_area(masks_bool)
if sort_by == "left_to_right":
order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i]))
else: # "area"
order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i]))
td = _subset_track_data(td, order)
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
packed = td.get("packed_masks")
n_obj = packed.shape[1] if packed is not None else 0
indices = [i for i in indices if 0 <= i < n_obj]
td = _subset_track_data(td, indices)
return td
drv = _prep(driving_track_data)
# Animation: driving=black, ref=white. Replacement: driving=white, ref=black.
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
ref_bg = "black" if replacement_mode else "white"
if ref_track_data is not None:
if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK
reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg)
else:
reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg)
else:
H, W = drv["orig_size"]
fill_value = 1.0 if ref_bg == "white" else 0.0
reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(mask_video, reference_image_mask)
class SCAILExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanSCAILToVideo,
SCAIL2ColoredMask,
]
async def comfy_entrypoint() -> SCAILExtension:
return SCAILExtension()