mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 06:19:38 +08:00
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* SCAIl-2: support multiref
352 lines
20 KiB
Python
352 lines
20 KiB
Python
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
|
|
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
|
|
|
|
from typing_extensions import override
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import nodes
|
|
import node_helpers
|
|
import comfy.model_management
|
|
import comfy.utils
|
|
from comfy_api.latest import ComfyExtension, io
|
|
from comfy.ldm.sam3.tracker import unpack_masks
|
|
|
|
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
|
|
|
|
|
# Model was trained on these exact colors; deviating degrades multi-identity quality.
|
|
DEFAULT_PALETTE = [
|
|
(0.0, 0.0, 1.0), # Blue
|
|
(1.0, 0.0, 0.0), # Red
|
|
(0.0, 1.0, 0.0), # Green
|
|
(1.0, 0.0, 1.0), # Magenta
|
|
(0.0, 1.0, 1.0), # Cyan
|
|
(1.0, 1.0, 0.0), # Yellow
|
|
]
|
|
|
|
|
|
def _unpack(track_data):
|
|
packed = track_data["packed_masks"]
|
|
if packed is None or packed.shape[1] == 0:
|
|
return None
|
|
return unpack_masks(packed)
|
|
|
|
|
|
def _first_appearance_cx_area(masks_bool):
|
|
"""Per object: first frame it appears in, plus centroid-x and area in that frame."""
|
|
m = masks_bool.float()
|
|
T, H, W = m.shape[0], m.shape[-2], m.shape[-1]
|
|
grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W)
|
|
area_t = m.sum(dim=(-1, -2))
|
|
cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1)
|
|
present = area_t > 0
|
|
frame_idx = torch.arange(T, device=m.device).unsqueeze(1)
|
|
first_t = torch.where(present, frame_idx, T).amin(dim=0)
|
|
sel = first_t.clamp(max=T - 1).unsqueeze(0)
|
|
cx = cx_t.gather(0, sel).squeeze(0)
|
|
area = area_t.gather(0, sel).squeeze(0)
|
|
return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist()
|
|
|
|
|
|
def _subset_track_data(track_data, obj_indices):
|
|
out = dict(track_data)
|
|
packed = track_data["packed_masks"]
|
|
if packed is None or not obj_indices:
|
|
out["packed_masks"] = None
|
|
if "scores" in out:
|
|
out["scores"] = []
|
|
return out
|
|
out["packed_masks"] = packed[:, obj_indices].contiguous()
|
|
scores = track_data.get("scores")
|
|
if scores is not None:
|
|
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
|
|
return out
|
|
|
|
|
|
def _render_colored_masks(track_data, background="black"):
|
|
packed = track_data["packed_masks"]
|
|
H, W = track_data["orig_size"]
|
|
device = comfy.model_management.intermediate_device()
|
|
dtype = comfy.model_management.intermediate_dtype()
|
|
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
|
if packed is None or packed.shape[1] == 0:
|
|
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
|
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
|
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
|
|
return out
|
|
T, N_obj = packed.shape[0], packed.shape[1]
|
|
colors = torch.tensor(
|
|
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
|
|
device=device, dtype=dtype,
|
|
)
|
|
masks_full = unpack_masks(packed.to(device)).float()
|
|
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
|
|
masks_full = F.interpolate(
|
|
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
|
|
).view(T, N_obj, H, W) > 0.5
|
|
any_mask = masks_full.any(dim=1)
|
|
color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)]
|
|
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
|
|
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
|
|
|
|
|
|
def _render_mask_as_identity(mask, background="black"):
|
|
"""Plain comfy MASK (B,H,W) or (H,W) -> (B,H,W,3) rendered as a single identity (palette[0])
|
|
on the given background. A batch is treated as multiple views of that one subject."""
|
|
device = comfy.model_management.intermediate_device()
|
|
dtype = comfy.model_management.intermediate_dtype()
|
|
if mask.ndim == 2:
|
|
mask = mask.unsqueeze(0)
|
|
mask = mask.to(device=device, dtype=dtype)
|
|
B, H, W = mask.shape
|
|
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
|
color = torch.tensor(DEFAULT_PALETTE[0], device=device, dtype=dtype).view(1, 1, 1, 3)
|
|
bg = torch.tensor(bg_rgb, device=device, dtype=dtype).view(1, 1, 1, 3)
|
|
return torch.where((mask > 0.5).unsqueeze(-1), color.expand(B, H, W, 3), bg.expand(B, H, W, 3))
|
|
|
|
|
|
def _extract_mask_to_28ch(rgb_video):
|
|
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
|
|
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
|
|
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
|
|
T, H, W, _ = rgb_video.shape
|
|
_ON_THRESH = 225.0 / 255.0
|
|
mask = rgb_video.movedim(-1, 1).float()
|
|
R = (mask[:, 0:1] > _ON_THRESH).float()
|
|
G = (mask[:, 1:2] > _ON_THRESH).float()
|
|
B = (mask[:, 2:3] > _ON_THRESH).float()
|
|
nR, nG, nB = 1 - R, 1 - G, 1 - B
|
|
binary_7ch = torch.cat([
|
|
R * G * B, # white
|
|
R * nG * nB, # red
|
|
nR * G * nB, # green
|
|
nR * nG * B, # blue
|
|
R * G * nB, # yellow
|
|
R * nG * B, # magenta
|
|
nR * G * B, # cyan
|
|
], dim=1)
|
|
H_lat, W_lat = H, W
|
|
for _ in range(3):
|
|
H_lat = (H_lat + 1) // 2
|
|
W_lat = (W_lat + 1) // 2
|
|
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
|
|
T_latent = (T - 1) // 4 + 1
|
|
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
|
|
out = padded.view(T_latent, 28, H_lat, W_lat)
|
|
return out.unsqueeze(0)
|
|
|
|
|
|
class WanSCAILToVideo(io.ComfyNode):
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="WanSCAILToVideo",
|
|
category="model/conditioning/wan/scail",
|
|
inputs=[
|
|
io.Conditioning.Input("positive"),
|
|
io.Conditioning.Input("negative"),
|
|
io.Vae.Input("vae"),
|
|
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
|
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
|
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
|
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
|
|
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
|
|
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
|
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
|
|
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
|
|
io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."),
|
|
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."),
|
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
|
|
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
|
|
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
|
|
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
|
|
],
|
|
outputs=[
|
|
io.Conditioning.Output(display_name="positive"),
|
|
io.Conditioning.Output(display_name="negative"),
|
|
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
|
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
|
|
],
|
|
is_experimental=True,
|
|
)
|
|
|
|
@classmethod
|
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
|
|
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
|
|
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
|
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
|
noise_mask = None
|
|
|
|
ref_mask_flag = not replacement_mode
|
|
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
|
|
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
|
|
|
|
prev_trimmed = None
|
|
if previous_frames is not None and previous_frames.shape[0] > 0:
|
|
prev_trimmed = previous_frames[-previous_frame_count:]
|
|
video_frame_offset -= prev_trimmed.shape[0]
|
|
video_frame_offset = max(0, video_frame_offset)
|
|
|
|
if reference_image is not None:
|
|
ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
|
n_ref = ref_imgs.shape[0]
|
|
# SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references.
|
|
|
|
# Replacement Mode: composite each ref on black bg using its mask as alpha matte
|
|
if replacement_mode and reference_image_mask is not None:
|
|
rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
|
|
rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]]
|
|
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype)
|
|
ref_imgs = ref_imgs * is_char
|
|
# encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video)
|
|
ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)]
|
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True)
|
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True)
|
|
|
|
if clip_vision_output is not None:
|
|
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
|
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
|
|
|
if pose_video is not None:
|
|
if pose_video.shape[0] <= video_frame_offset:
|
|
pose_video = None
|
|
else:
|
|
pose_video = pose_video[video_frame_offset:]
|
|
if pose_video_mask is not None:
|
|
if pose_video_mask.shape[0] <= video_frame_offset:
|
|
pose_video_mask = None
|
|
else:
|
|
pose_video_mask = pose_video_mask[video_frame_offset:]
|
|
|
|
# Truncate pose+mask jointly to the shorter of the two, capped at length.
|
|
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
|
|
if ts:
|
|
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
|
|
if pose_video is not None:
|
|
pose_video = pose_video[:T_kept]
|
|
if pose_video_mask is not None:
|
|
pose_video_mask = pose_video_mask[:T_kept]
|
|
|
|
if pose_video is not None:
|
|
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
|
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
|
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
|
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
|
|
|
if pose_video_mask is not None:
|
|
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
|
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
|
|
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
|
|
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
|
|
|
|
# The ref mask binds reference frames to identities, so it only applies when there's a reference image.
|
|
if reference_image_mask is not None and reference_image is not None:
|
|
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
|
|
n_masks = ref_mask_hw.shape[0]
|
|
n_ref = reference_image.shape[0]
|
|
|
|
add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)]
|
|
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1])
|
|
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
|
|
ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1)
|
|
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
|
|
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
|
|
|
|
if prev_trimmed is not None:
|
|
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
|
prev_latent = vae.encode(pf[:, :, :, :3])
|
|
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
|
|
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
|
|
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
|
|
noise_mask[:, :, :prev_latent_frames] = 0.0
|
|
|
|
out_latent = {"samples": latent}
|
|
if noise_mask is not None:
|
|
out_latent["noise_mask"] = noise_mask
|
|
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
|
|
|
|
|
|
class SCAIL2ColoredMask(io.ComfyNode):
|
|
"""Render SAM3 tracks for the driving pose video and reference image(s) into the
|
|
colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the
|
|
same color across both outputs.
|
|
"""
|
|
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SCAIL2ColoredMask",
|
|
display_name="Create SCAIL-2 Colored Mask",
|
|
category="model/conditioning/wan/scail",
|
|
inputs=[
|
|
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
|
|
io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks",
|
|
tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."),
|
|
io.String.Input("object_indices", default="",
|
|
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
|
|
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
|
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."),
|
|
io.Boolean.Input("replacement_mode", default=False,
|
|
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
|
|
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
|
|
],
|
|
outputs=[
|
|
io.Image.Output("pose_video_mask"),
|
|
io.Image.Output("reference_image_mask"),
|
|
],
|
|
is_experimental=True,
|
|
)
|
|
|
|
@classmethod
|
|
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
|
def _prep(td):
|
|
masks_bool = _unpack(td)
|
|
if sort_by != "none" and masks_bool is not None:
|
|
first_t, cx, area = _first_appearance_cx_area(masks_bool)
|
|
if sort_by == "left_to_right":
|
|
order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i]))
|
|
else: # "area"
|
|
order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i]))
|
|
td = _subset_track_data(td, order)
|
|
if object_indices.strip():
|
|
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
|
packed = td.get("packed_masks")
|
|
n_obj = packed.shape[1] if packed is not None else 0
|
|
indices = [i for i in indices if 0 <= i < n_obj]
|
|
td = _subset_track_data(td, indices)
|
|
return td
|
|
|
|
drv = _prep(driving_track_data)
|
|
# Animation: driving=black, ref=white. Replacement: driving=white, ref=black.
|
|
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
|
|
ref_bg = "black" if replacement_mode else "white"
|
|
|
|
if ref_track_data is not None:
|
|
if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK
|
|
reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg)
|
|
else:
|
|
reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg)
|
|
else:
|
|
H, W = drv["orig_size"]
|
|
fill_value = 1.0 if ref_bg == "white" else 0.0
|
|
reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
|
|
|
return io.NodeOutput(mask_video, reference_image_mask)
|
|
|
|
|
|
class SCAILExtension(ComfyExtension):
|
|
@override
|
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
|
return [
|
|
WanSCAILToVideo,
|
|
SCAIL2ColoredMask,
|
|
]
|
|
|
|
|
|
async def comfy_entrypoint() -> SCAILExtension:
|
|
return SCAILExtension()
|