mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 16:59:29 +08:00
More cleanup
This commit is contained in:
parent
856149a999
commit
140e34a4bb
@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel):
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
main_t_patches = t - ref_t_patches
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
@ -1719,138 +1752,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
reference_latent = None
|
||||
if "reference_latent" in kwargs:
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None:
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None:
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
del scail_x
|
||||
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.cat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
x = self.head(x, e)
|
||||
|
||||
if scail_pose_seq_len > 0:
|
||||
x = x[:, :-scail_pose_seq_len]
|
||||
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
|
||||
if reference_latent is not None:
|
||||
x = x[:, :, reference_latent.shape[2]:]
|
||||
|
||||
return x
|
||||
|
||||
# ref_mask_flag is a scalar bool (CONDConstant); the mode is uniform across the batch.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
is_replacement = ref_mask_flag is not None and not bool(ref_mask_flag)
|
||||
if not is_replacement:
|
||||
return super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, pose_latents=pose_latents, reference_latent=reference_latent, transformer_options=transformer_options)
|
||||
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
main_t_patches = t - ref_t_patches
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super(SCAILWanModel, self).rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super(SCAILWanModel, self).rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super(SCAILWanModel, self).rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
if ref_mask_latents is not None:
|
||||
if ref_mask_latents is not None: # SCAIL-2
|
||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||
if sam_latents is not None:
|
||||
if sam_latents is not None: # SCAIL-2
|
||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
@ -1864,7 +1774,15 @@ class SCAIL2WanModel(SCAILWanModel):
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None)
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
@ -11,7 +11,7 @@ import node_helpers
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
@ -28,7 +28,6 @@ DEFAULT_PALETTE = [
|
||||
|
||||
|
||||
def _unpack(track_data):
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
return None
|
||||
@ -45,16 +44,6 @@ def _first_frame_cx_area(masks_bool):
|
||||
return (cx / W).tolist(), (area / n_pixels).tolist()
|
||||
|
||||
|
||||
def _sort_tracks(track_data, sort_by):
|
||||
masks_bool = _unpack(track_data)
|
||||
if masks_bool is None:
|
||||
return []
|
||||
cx, area = _first_frame_cx_area(masks_bool)
|
||||
if sort_by == "x":
|
||||
return sorted(range(len(cx)), key=lambda i: cx[i])
|
||||
return sorted(range(len(area)), key=lambda i: -area[i]) # "area"
|
||||
|
||||
|
||||
def _subset_track_data(track_data, obj_indices):
|
||||
out = dict(track_data)
|
||||
packed = track_data["packed_masks"]
|
||||
@ -70,19 +59,12 @@ def _subset_track_data(track_data, obj_indices):
|
||||
return out
|
||||
|
||||
|
||||
def _bg_to_rgb(background):
|
||||
if background.startswith("white"):
|
||||
return (1.0, 1.0, 1.0)
|
||||
return (0.0, 0.0, 0.0)
|
||||
|
||||
|
||||
def _render_colored_masks(track_data, background="black"):
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
device = comfy.model_management.intermediate_device()
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
bg_rgb = _bg_to_rgb(background)
|
||||
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
||||
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
||||
@ -277,19 +259,19 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
||||
display_name="SCAIL-2 Colored Mask",
|
||||
category="conditioning/video_models/scail",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the driving_mask_video output."),
|
||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving video. Will be rendered into the pose_video_mask output."),
|
||||
SAM3TrackData.Input("ref_track_data", optional=True,
|
||||
tooltip="SAM3 track of the reference image. Optional — wire it for the ref_mask_image output."),
|
||||
tooltip="SAM3 track of the reference image. Optional — wire it for the reference_image_mask output."),
|
||||
io.String.Input("object_indices", default="",
|
||||
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Applied to both sides. Empty = all."),
|
||||
io.Combo.Input("sort_by", options=["none", "x", "area"],
|
||||
tooltip="Applied to both sides identically so index K = same logical slot. x = left-to-right by first-frame centroid; area = descending mask area; none = SAM3's order."),
|
||||
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
||||
tooltip="Applied to both sides identically so that color index order matches. left_to_right = by first-frame centroid; area = descending mask area; none = SAM3's original order."),
|
||||
io.Boolean.Input("replacement_mode", default=False,
|
||||
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. ref_mask_image is always black-bg regardless."),
|
||||
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("driving_mask_video"),
|
||||
io.Image.Output("ref_mask_image"),
|
||||
io.Image.Output("pose_video_mask"),
|
||||
io.Image.Output("reference_image_mask"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
@ -297,8 +279,14 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
||||
def _prep(td):
|
||||
if sort_by != "none":
|
||||
td = _subset_track_data(td, _sort_tracks(td, sort_by))
|
||||
masks_bool = _unpack(td)
|
||||
if sort_by != "none" and masks_bool is not None:
|
||||
cx, area = _first_frame_cx_area(masks_bool)
|
||||
if sort_by == "left_to_right":
|
||||
order = sorted(range(len(cx)), key=lambda i: cx[i])
|
||||
else: # "area"
|
||||
order = sorted(range(len(area)), key=lambda i: -area[i])
|
||||
td = _subset_track_data(td, order)
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
packed = td.get("packed_masks")
|
||||
@ -312,12 +300,12 @@ class SCAIL2ColoredMask(io.ComfyNode):
|
||||
|
||||
if ref_track_data is not None:
|
||||
ref = _prep(ref_track_data)
|
||||
ref_mask_image = _render_colored_masks(ref, "black")
|
||||
reference_image_mask = _render_colored_masks(ref, "black")
|
||||
else:
|
||||
H, W = drv["orig_size"]
|
||||
ref_mask_image = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
return io.NodeOutput(mask_video, ref_mask_image)
|
||||
return io.NodeOutput(mask_video, reference_image_mask)
|
||||
|
||||
|
||||
class SCAILExtension(ComfyExtension):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user