mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 09:19:46 +08:00
More cleanup
This commit is contained in:
parent
856149a999
commit
140e34a4bb
@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel):
|
|||||||
|
|
||||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||||
|
|
||||||
if reference_latent is not None:
|
if reference_latent is not None:
|
||||||
x = torch.cat((reference_latent, x), dim=2)
|
x = torch.cat((reference_latent, x), dim=2)
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
|
||||||
|
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||||
grid_sizes = x.shape[2:]
|
grid_sizes = x.shape[2:]
|
||||||
transformer_options["grid_sizes"] = grid_sizes
|
transformer_options["grid_sizes"] = grid_sizes
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel):
|
|||||||
scail_pose_seq_len = 0
|
scail_pose_seq_len = 0
|
||||||
if pose_latents is not None:
|
if pose_latents is not None:
|
||||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||||
|
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||||
|
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||||
scail_pose_seq_len = scail_x.shape[1]
|
scail_pose_seq_len = scail_x.shape[1]
|
||||||
x = torch.cat([x, scail_x], dim=1)
|
x = torch.cat([x, scail_x], dim=1)
|
||||||
@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||||
|
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||||
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||||
|
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||||
|
REF_ROPE_H = 120.0
|
||||||
|
POSE_ROPE_W = 120.0
|
||||||
|
|
||||||
|
ref_t_patches = 0
|
||||||
|
if reference_latent is not None:
|
||||||
|
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||||
|
main_t_patches = t - ref_t_patches
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if ref_t_patches > 0:
|
||||||
|
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||||
|
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||||
|
if main_t_patches > 0:
|
||||||
|
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||||
|
|
||||||
|
if pose_latents is not None:
|
||||||
|
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||||
|
h_scale = h / H_pose
|
||||||
|
w_scale = w / W_pose
|
||||||
|
h_shift = (h_scale - 1) / 2
|
||||||
|
w_shift = (w_scale - 1) / 2
|
||||||
|
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||||
|
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||||
|
|
||||||
|
return torch.cat(parts, dim=1)
|
||||||
|
|
||||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
if pose_latents is None:
|
if pose_latents is None:
|
||||||
@ -1719,138 +1752,15 @@ class SCAILWanModel(WanModel):
|
|||||||
|
|
||||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
|
||||||
bs, c, t, h, w = x.shape
|
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
|
||||||
|
|
||||||
if pose_latents is not None:
|
|
||||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
|
||||||
|
|
||||||
t_len = t
|
|
||||||
if time_dim_concat is not None:
|
|
||||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
|
||||||
x = torch.cat([x, time_dim_concat], dim=2)
|
|
||||||
t_len = x.shape[2]
|
|
||||||
|
|
||||||
reference_latent = None
|
|
||||||
if "reference_latent" in kwargs:
|
|
||||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
|
||||||
t_len += reference_latent.shape[2]
|
|
||||||
|
|
||||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
|
||||||
|
|
||||||
|
|
||||||
class SCAIL2WanModel(SCAILWanModel):
|
|
||||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
|
||||||
|
|
||||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
|
||||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
|
||||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
|
||||||
|
|
||||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
|
||||||
if reference_latent is not None:
|
|
||||||
x = torch.cat((reference_latent, x), dim=2)
|
|
||||||
|
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
|
||||||
if ref_mask_latents is not None:
|
|
||||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
|
||||||
grid_sizes = x.shape[2:]
|
|
||||||
transformer_options["grid_sizes"] = grid_sizes
|
|
||||||
x = x.flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
scail_pose_seq_len = 0
|
|
||||||
if pose_latents is not None:
|
|
||||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
|
||||||
if sam_latents is not None:
|
|
||||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
|
||||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
|
||||||
scail_pose_seq_len = scail_x.shape[1]
|
|
||||||
x = torch.cat([x, scail_x], dim=1)
|
|
||||||
del scail_x
|
|
||||||
|
|
||||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
|
||||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
|
||||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
|
||||||
|
|
||||||
context = self.text_embedding(context)
|
|
||||||
|
|
||||||
context_img_len = None
|
|
||||||
if clip_fea is not None:
|
|
||||||
if self.img_emb is not None:
|
|
||||||
context_clip = self.img_emb(clip_fea)
|
|
||||||
context = torch.cat([context_clip, context], dim=1)
|
|
||||||
context_img_len = clip_fea.shape[-2]
|
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
|
||||||
transformer_options["total_blocks"] = len(self.blocks)
|
|
||||||
transformer_options["block_type"] = "double"
|
|
||||||
for i, block in enumerate(self.blocks):
|
|
||||||
transformer_options["block_index"] = i
|
|
||||||
if ("double_block", i) in blocks_replace:
|
|
||||||
def block_wrap(args):
|
|
||||||
out = {}
|
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
|
||||||
return out
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
|
||||||
x = out["img"]
|
|
||||||
else:
|
|
||||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
|
||||||
|
|
||||||
x = self.head(x, e)
|
|
||||||
|
|
||||||
if scail_pose_seq_len > 0:
|
|
||||||
x = x[:, :-scail_pose_seq_len]
|
|
||||||
|
|
||||||
x = self.unpatchify(x, grid_sizes)
|
|
||||||
|
|
||||||
if reference_latent is not None:
|
|
||||||
x = x[:, :, reference_latent.shape[2]:]
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
# ref_mask_flag is a scalar bool (CONDConstant); the mode is uniform across the batch.
|
|
||||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
|
||||||
is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag)
|
|
||||||
if not is_replacement:
|
|
||||||
return super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, pose_latents=pose_latents, reference_latent=reference_latent, transformer_options=transformer_options)
|
|
||||||
|
|
||||||
REF_ROPE_H = 120.0
|
|
||||||
POSE_ROPE_W = 120.0
|
|
||||||
|
|
||||||
ref_t_patches = 0
|
|
||||||
if reference_latent is not None:
|
|
||||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
|
||||||
main_t_patches = t - ref_t_patches
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
if ref_t_patches > 0:
|
|
||||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
|
||||||
parts.append(super(SCAILWanModel, self).rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
|
||||||
if main_t_patches > 0:
|
|
||||||
parts.append(super(SCAILWanModel, self).rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
|
||||||
|
|
||||||
if pose_latents is not None:
|
|
||||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
|
||||||
h_scale = h / H_pose
|
|
||||||
w_scale = w / W_pose
|
|
||||||
h_shift = (h_scale - 1) / 2
|
|
||||||
w_shift = (w_scale - 1) / 2
|
|
||||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
|
||||||
parts.append(super(SCAILWanModel, self).rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
|
||||||
|
|
||||||
return torch.cat(parts, dim=1)
|
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
if pose_latents is not None:
|
if pose_latents is not None:
|
||||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||||
if ref_mask_latents is not None:
|
if ref_mask_latents is not None: # SCAIL-2
|
||||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||||
if sam_latents is not None:
|
if sam_latents is not None: # SCAIL-2
|
||||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||||
|
|
||||||
t_len = t
|
t_len = t
|
||||||
@ -1864,7 +1774,15 @@ class SCAIL2WanModel(SCAILWanModel):
|
|||||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||||
t_len += reference_latent.shape[2]
|
t_len += reference_latent.shape[2]
|
||||||
|
|
||||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None)
|
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||||
|
|
||||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
|
|
||||||
|
class SCAIL2WanModel(SCAILWanModel):
|
||||||
|
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||||
|
|
||||||
|
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||||
|
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||||
|
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import node_helpers
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy.ldm.sam3.tracker import unpack_masks
|
||||||
|
|
||||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||||
|
|
||||||
@ -28,7 +28,6 @@ DEFAULT_PALETTE = [
|
|||||||
|
|
||||||
|
|
||||||
def _unpack(track_data):
|
def _unpack(track_data):
|
||||||
from comfy.ldm.sam3.tracker import unpack_masks
|
|
||||||
packed = track_data["packed_masks"]
|
packed = track_data["packed_masks"]
|
||||||
if packed is None or packed.shape[1] == 0:
|
if packed is None or packed.shape[1] == 0:
|
||||||
return None
|
return None
|
||||||
@ -45,16 +44,6 @@ def _first_frame_cx_area(masks_bool):
|
|||||||
return (cx / W).tolist(), (area / n_pixels).tolist()
|
return (cx / W).tolist(), (area / n_pixels).tolist()
|
||||||
|
|
||||||
|
|
||||||
def _sort_tracks(track_data, sort_by):
|
|
||||||
masks_bool = _unpack(track_data)
|
|
||||||
if masks_bool is None:
|
|
||||||
return []
|
|
||||||
cx, area = _first_frame_cx_area(masks_bool)
|
|
||||||
if sort_by == "x":
|
|
||||||
return sorted(range(len(cx)), key=lambda i: cx[i])
|
|
||||||
return sorted(range(len(area)), key=lambda i: -area[i]) # "area"
|
|
||||||
|
|
||||||
|
|
||||||
def _subset_track_data(track_data, obj_indices):
|
def _subset_track_data(track_data, obj_indices):
|
||||||
out = dict(track_data)
|
out = dict(track_data)
|
||||||
packed = track_data["packed_masks"]
|
packed = track_data["packed_masks"]
|
||||||
@ -70,19 +59,12 @@ def _subset_track_data(track_data, obj_indices):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _bg_to_rgb(background):
|
|
||||||
if background.startswith("white"):
|
|
||||||
return (1.0, 1.0, 1.0)
|
|
||||||
return (0.0, 0.0, 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def _render_colored_masks(track_data, background="black"):
|
def _render_colored_masks(track_data, background="black"):
|
||||||
from comfy.ldm.sam3.tracker import unpack_masks
|
|
||||||
packed = track_data["packed_masks"]
|
packed = track_data["packed_masks"]
|
||||||
H, W = track_data["orig_size"]
|
H, W = track_data["orig_size"]
|
||||||
device = comfy.model_management.intermediate_device()
|
device = comfy.model_management.intermediate_device()
|
||||||
dtype = comfy.model_management.intermediate_dtype()
|
dtype = comfy.model_management.intermediate_dtype()
|
||||||
bg_rgb = _bg_to_rgb(background)
|
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
||||||
if packed is None or packed.shape[1] == 0:
|
if packed is None or packed.shape[1] == 0:
|
||||||
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
||||||
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
||||||
@ -277,19 +259,19 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
|||||||
display_name="SCAIL-2 Colored Mask",
|
display_name="SCAIL-2 Colored Mask",
|
||||||
category="conditioning/video_models/scail",
|
category="conditioning/video_models/scail",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the driving_mask_video output."),
|
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the pose_video_mask output."),
|
||||||
SAM3TrackData.Input("ref_track_data", optional=True,
|
SAM3TrackData.Input("ref_track_data", optional=True,
|
||||||
tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output."),
|
tooltip="SAM3 track of the reference image. Optional — wire it for the reference_image_mask output."),
|
||||||
io.String.Input("object_indices", default="",
|
io.String.Input("object_indices", default="",
|
||||||
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."),
|
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."),
|
||||||
io.Combo.Input("sort_by", options=["none", "x", "area"],
|
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
||||||
tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order."),
|
tooltip="Applied to both sides identically so that color index order matches. left_to_right = by first-frame centroid; area = descending mask area; none = SAM3's original order."),
|
||||||
io.Boolean.Input("replacement_mode", default=False,
|
io.Boolean.Input("replacement_mode", default=False,
|
||||||
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. ref_mask_image is always black-bg regardless."),
|
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output("driving_mask_video"),
|
io.Image.Output("pose_video_mask"),
|
||||||
io.Image.Output("ref_mask_image"),
|
io.Image.Output("reference_image_mask"),
|
||||||
],
|
],
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
)
|
)
|
||||||
@ -297,8 +279,14 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
||||||
def _prep(td):
|
def _prep(td):
|
||||||
if sort_by != "none":
|
masks_bool = _unpack(td)
|
||||||
td = _subset_track_data(td, _sort_tracks(td, sort_by))
|
if sort_by != "none" and masks_bool is not None:
|
||||||
|
cx, area = _first_frame_cx_area(masks_bool)
|
||||||
|
if sort_by == "left_to_right":
|
||||||
|
order = sorted(range(len(cx)), key=lambda i: cx[i])
|
||||||
|
else: # "area"
|
||||||
|
order = sorted(range(len(area)), key=lambda i: -area[i])
|
||||||
|
td = _subset_track_data(td, order)
|
||||||
if object_indices.strip():
|
if object_indices.strip():
|
||||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||||
packed = td.get("packed_masks")
|
packed = td.get("packed_masks")
|
||||||
@ -312,12 +300,12 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
|||||||
|
|
||||||
if ref_track_data is not None:
|
if ref_track_data is not None:
|
||||||
ref = _prep(ref_track_data)
|
ref = _prep(ref_track_data)
|
||||||
ref_mask_image = _render_colored_masks(ref, "black")
|
reference_image_mask = _render_colored_masks(ref, "black")
|
||||||
else:
|
else:
|
||||||
H, W = drv["orig_size"]
|
H, W = drv["orig_size"]
|
||||||
ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
|
|
||||||
return io.NodeOutput(mask_video, ref_mask_image)
|
return io.NodeOutput(mask_video, reference_image_mask)
|
||||||
|
|
||||||
|
|
||||||
class SCAILExtension(ComfyExtension):
|
class SCAILExtension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user