# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# SAM3 video tracker: memory encoder, memory attention, SAM mask decoder/prompt encoder.
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from tqdm import tqdm
|
|
|
|
try:
|
|
import cv2
|
|
_HAS_CV2 = True
|
|
except ImportError:
|
|
from scipy import ndimage
|
|
_HAS_CV2 = False
|
|
|
|
import comfy.model_management
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
from comfy.ldm.sam3.sam import rope_2d, PositionEmbeddingSine
|
|
from comfy.ops import cast_to_input
|
|
from comfy.ldm.flux.math import apply_rope1
|
|
from comfy.ldm.cascade.common import LayerNorm2d_op
|
|
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingRandom
|
|
from comfy.ldm.sam3.sam import TwoWayTransformer as SAMTwoWayTransformer
|
|
|
|
NO_OBJ_SCORE = -1024.0
|
|
|
|
|
|
def to_spatial(x, H, W):
    """Convert a flattened token sequence (B, H*W, C) into a spatial map (B, C, H, W)."""
    batch = x.shape[0]
    grid = x.view(batch, H, W, -1)
    return grid.permute(0, 3, 1, 2)
class MultiplexState:
    """Tracks object-to-slot assignments for multiplex tracking. Provides mux/demux operations."""

    def __init__(self, num_objects, multiplex_count, device, dtype):
        self.multiplex_count = multiplex_count
        self.device = device
        self.dtype = dtype
        self._build(num_objects)

    def mux(self, x):
        """[N_obj, ...] -> [num_buckets, multiplex_count, ...]"""
        scatter = self.mux_matrix.to(device=x.device, dtype=x.dtype)
        flat = x.reshape(self.total_valid_entries, -1)
        target_shape = (self.num_buckets, self.multiplex_count) + x.shape[1:]
        return (scatter @ flat).view(target_shape)

    def demux(self, x):
        """[num_buckets, multiplex_count, ...] -> [N_obj, ...]"""
        gather = self.demux_matrix.to(device=x.device, dtype=x.dtype)
        flat = x.reshape(self.num_buckets * self.multiplex_count, -1)
        return (gather @ flat).view((self.total_valid_entries,) + x.shape[2:])

    def get_valid_object_mask(self):
        """[num_buckets, multiplex_count] bool tensor, True for valid slots."""
        occupied = self.mux_matrix.sum(dim=1) > 0
        return occupied.reshape(self.num_buckets, self.multiplex_count)

    def _build(self, num_objects):
        # Lay objects into ceil(N / M) buckets of M slots each, and build the
        # (slots x objects) scatter matrix plus its (objects x slots) gather inverse.
        M = self.multiplex_count
        self.num_buckets = (num_objects + M - 1) // M
        self.total_valid_entries = num_objects
        slot_count = self.num_buckets * M
        self.mux_matrix = torch.zeros(slot_count, num_objects, device=self.device, dtype=self.dtype)
        self.demux_matrix = torch.zeros(num_objects, slot_count, device=self.device, dtype=self.dtype)
        obj_ids = torch.arange(num_objects, device=self.device)
        slot_ids = (obj_ids // M) * M + (obj_ids % M)
        self.mux_matrix[slot_ids, obj_ids] = 1.0
        self.demux_matrix[obj_ids, slot_ids] = 1.0

    def add_objects(self, n_new):
        """Grow multiplex state for n_new additional objects."""
        self._build(self.total_valid_entries + n_new)
def _compute_mask_overlap(masks_a, masks_b):
|
|
"""Max of IoU and IoM (intersection over minimum area). More robust to size differences."""
|
|
a_flat = (masks_a > 0).float().flatten(1)
|
|
b_flat = (masks_b > 0).float().flatten(1)
|
|
intersection = a_flat @ b_flat.T
|
|
area_a = a_flat.sum(1, keepdim=True)
|
|
area_b = b_flat.sum(1, keepdim=True).T
|
|
iou = intersection / (area_a + area_b - intersection).clamp(min=1)
|
|
iom = intersection / torch.min(area_a.expand_as(iou), area_b.expand_as(iou)).clamp(min=1)
|
|
return torch.max(iou, iom)
|
|
|
|
|
|
def _nms_masks(masks, scores, thresh=0.5):
    """Mask-based NMS using IoU+IoM overlap. Returns (filtered_masks, filtered_scores)."""
    # Process candidates best-first so higher-scoring masks suppress lower ones.
    order = scores.argsort(descending=True)
    masks = masks[order]
    scores = scores[order]
    kept = []
    for idx in range(masks.shape[0]):
        if kept:
            prior = masks[torch.tensor(kept, device=masks.device)]
            if _compute_mask_overlap(masks[idx:idx + 1], prior).max() >= thresh:
                continue
        kept.append(idx)
    return masks[kept], scores[kept]
def _get_connected_components(mask_bin):
    """Get connected component labels and areas. mask_bin: [B, 1, H, W] uint8."""
    all_labels, all_areas = [], []
    for b in range(mask_bin.shape[0]):
        np_mask = mask_bin[b, 0].cpu().numpy()
        if _HAS_CV2:
            # OpenCV path: per-pixel area via lookup into the component stats table.
            _, labeled, stats, _ = cv2.connectedComponentsWithStats(np_mask, connectivity=8)
            areas = stats[labeled, cv2.CC_STAT_AREA].astype('int32')
        else:
            # SciPy fallback: label, then paint each component with its own pixel count.
            labeled, n_comp = ndimage.label(np_mask)
            areas = np.zeros_like(np_mask, dtype=np.int32)
            for comp_id in range(1, n_comp + 1):
                member = labeled == comp_id
                areas[member] = member.sum()
        all_labels.append(torch.from_numpy(labeled).to(mask_bin.device))
        all_areas.append(torch.from_numpy(areas).to(device=mask_bin.device, dtype=torch.int32))
    return torch.stack(all_labels).unsqueeze(1), torch.stack(all_areas).unsqueeze(1)
def fill_holes_in_mask_scores(mask, max_area=0):
    """Remove small foreground sprinkles and fill small background holes using connected components."""
    if max_area <= 0:
        return mask

    # Step 1: background components smaller than max_area become weak foreground (0.1).
    bg = (mask <= 0).to(torch.uint8)
    _, bg_areas = _get_connected_components(bg)
    hole = bg.bool() & (bg_areas <= max_area)
    mask = torch.where(hole, 0.1, mask)

    # Step 2: small foreground components become weak background (-0.1), but only
    # when smaller than min(max_area, half the total foreground area).
    fg = (mask > 0).to(torch.uint8)
    thresh = (fg.sum(dim=(2, 3), keepdim=True, dtype=torch.int32) // 2).clamp(max=max_area)
    _, fg_areas = _get_connected_components(fg)
    sprinkle = fg.bool() & (fg_areas <= thresh)
    return torch.where(sprinkle, -0.1, mask)
def apply_rope_memory(q, k, freqs, num_heads, num_k_exclude_rope=0):
    """Apply 2D axial RoPE to memory attention using flux rope format.

    Args:
        q: [B, Nq, C] projected queries (current frame features)
        k: [B, Nk, C] projected keys (memory tokens)
        freqs: flux-format rotation matrices for one frame's spatial positions.
            NOTE(review): this docstring originally said rank-5 [1, Nq, dim//2, 2, 2],
            but the indexing below (freqs.shape[2], freqs[:, :, :n]) treats it as
            rank-6 [1, 1, Nq, dim//2, 2, 2] with a heads-broadcast dim — confirm
            against the caller (rope_2d output).
        num_heads: number of attention heads
        num_k_exclude_rope: number of trailing k tokens to skip RoPE (object pointers)

    Returns:
        (q, k) with rotary embeddings applied. The caller's k tensor is not
        mutated; a clone is modified instead.
    """
    B, Nq, C = q.shape
    head_dim = C // num_heads

    # freqs shape: [1, 1, Nq, dim//2, 2, 2] (heads broadcast dim already included)
    # Split channels into heads before rotating, then flatten back.
    q_h = q.view(B, Nq, num_heads, head_dim).transpose(1, 2)
    q_h = apply_rope1(q_h, freqs)
    q = q_h.transpose(1, 2).reshape(B, Nq, C)

    # Apply RoPE to k (excluding last num_k_exclude_rope tokens)
    Nk = k.shape[1]
    num_k_rope = Nk - num_k_exclude_rope
    if num_k_rope > 0:
        # Repeat freqs for multiple frames of spatial memory: memory keys
        # concatenate several frames, so tile the one-frame table to cover them.
        Nf = freqs.shape[2]  # spatial positions in one frame
        if num_k_rope > Nf:
            r = (num_k_rope + Nf - 1) // Nf  # ceil(num_k_rope / Nf) repeats
            pe_k = freqs.repeat(1, 1, r, 1, 1, 1)[:, :, :num_k_rope]
        else:
            pe_k = freqs[:, :, :num_k_rope]

        k_h = k[:, :num_k_rope].view(B, num_k_rope, num_heads, head_dim).transpose(1, 2)
        k_h = apply_rope1(k_h, pe_k)
        # Clone so the slice assignment below does not write into the caller's tensor.
        k = k.clone()
        k[:, :num_k_rope] = k_h.transpose(1, 2).reshape(B, num_k_rope, C)

    return q, k
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """1D sinusoidal positional encoding for temporal positions."""
    half = dim // 2
    freq_idx = torch.arange(half, dtype=torch.float32, device=pos_inds.device)
    # Pairs of frequency channels share the same wavelength (t // 2).
    inv_freq = temperature ** (2 * (freq_idx // 2) / half)
    scaled = pos_inds.unsqueeze(-1) / inv_freq
    return torch.cat([scaled.sin(), scaled.cos()], dim=-1)
def _pad_to_buckets(tensor, target_buckets):
|
|
"""Pad a [num_buckets, ...] tensor to target_buckets along dim 0 if needed."""
|
|
if tensor.shape[0] >= target_buckets:
|
|
return tensor
|
|
pad_shape = (target_buckets - tensor.shape[0],) + tensor.shape[1:]
|
|
return torch.cat([tensor, torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)], dim=0)
|
|
|
|
|
|
def pack_masks(masks):
    """Pack binary masks [*, H, W] to bit-packed [*, H, W//8] uint8. W must be divisible by 8."""
    bits = masks > 0
    weights = 1 << torch.arange(8, device=masks.device)  # LSB-first bit order
    grouped = bits.view(*masks.shape[:-1], -1, 8)
    return (grouped * weights).sum(-1).byte()
def unpack_masks(packed):
    """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
    bit_idx = torch.arange(8, device=packed.device)  # LSB-first, matching pack_masks
    bits = (packed.unsqueeze(-1) >> bit_idx) & 1
    return bits.view(*packed.shape[:-1], -1).bool()
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
|
"""Compute backbone features for a single frame. Returns (vision_feats, vision_pos, feat_sizes, features, trunk_out)."""
|
|
features, positions, trunk_out = backbone_fn(frame, frame_idx=frame_idx)
|
|
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in features]
|
|
vision_feats = [x.flatten(2).permute(0, 2, 1) for x in features]
|
|
vision_pos = [x.flatten(2).permute(0, 2, 1) for x in positions]
|
|
return vision_feats, vision_pos, feat_sizes, features, trunk_out
|
|
|
|
|
|
def collect_memory_tokens(output_dict, frame_idx, num_maskmem, maskmem_tpos_enc, device,
                          collect_image_feats=False, tpos_v2=False, num_buckets=None):
    """Collect spatial memory, position encodings, and optionally image features from past frames."""
    memories, memory_positions = [], []
    image_feats, image_positions = [], []

    def _gather(entry, tpos_idx):
        # Flatten (B, C, H, W) memory features into (B, H*W, C) tokens.
        feats = entry["maskmem_features"].to(device)
        if num_buckets is not None:
            feats = _pad_to_buckets(feats, num_buckets)
        memories.append(feats.flatten(2).permute(0, 2, 1))
        pos = entry["maskmem_pos_enc"][-1].to(device).flatten(2).permute(0, 2, 1)
        if num_buckets is not None:
            pos = _pad_to_buckets(pos, num_buckets)
        # Temporal position embedding selects a slot by relative frame distance.
        tpos = cast_to_input(maskmem_tpos_enc[tpos_idx], pos)
        memory_positions.append(pos + tpos)
        if collect_image_feats and "image_features" in entry:
            image_feats.append(entry["image_features"].to(device))
            image_positions.append(entry["image_pos_enc"].to(device) + tpos)

    # Conditioning frames always contribute; v2 assigns them distance-based slots.
    cond_outputs = output_dict["cond_frame_outputs"]
    for t, entry in cond_outputs.items():
        if tpos_v2:
            dist = frame_idx - t
            idx = num_maskmem - dist - 1 if 0 < dist < num_maskmem else num_maskmem - 1
        else:
            idx = num_maskmem - 1
        _gather(entry, idx)

    # Then the most recent (num_maskmem - 1) non-conditioning frames, oldest first.
    for t_pos in range(1, num_maskmem):
        entry = output_dict["non_cond_frame_outputs"].get(frame_idx - (num_maskmem - t_pos), None)
        if entry is None or entry.get("maskmem_features") is None:
            continue
        _gather(entry, num_maskmem - t_pos - 1)

    return memories, memory_positions, image_feats, image_positions, cond_outputs
def compute_tpos_enc(rel_pos_list, device, d_model, proj_layer, dtype=None, max_abs_pos=None):
    """Temporal position encoding for object pointers."""
    # Normalize relative positions by the largest possible distance (at least 1).
    denom = max((max_abs_pos or 2) - 1, 1)
    rel = torch.tensor(rel_pos_list, dtype=torch.float32, device=device) / denom
    enc = get_1d_sine_pe(rel, dim=d_model)
    if dtype is not None:
        enc = enc.to(dtype)
    return proj_layer(enc)
def forward_sam_heads(backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn,
                      image_size, point_inputs=None, mask_inputs=None, box_inputs=None,
                      high_res_features=None, multimask_output=False):
    """Shared SAM prompt encoder + mask decoder forward for both SAM3 and SAM3.1 trackers.

    Args:
        backbone_features: image embeddings fed to the mask decoder (batch may be 1
            while prompts carry N_obj entries).
        prompt_encoder / mask_decoder: SAM prompt-encoder and mask-decoder modules.
        obj_ptr_proj: projection applied to the selected mask token to get the object pointer.
        no_obj_fn: callable (obj_ptr, is_obj_appearing) -> obj_ptr handling the no-object case.
        image_size: side length of the upsampled high-res mask output.
        point_inputs: optional dict with "point_coords" and "point_labels".
        mask_inputs: optional mask prompt, resized to 4x the prompt-encoder embedding grid.
        box_inputs: optional box prompt.
        high_res_features: optional high-res features forwarded to the decoder.
        multimask_output: when True, decode alternatives and keep the best-IoU mask.

    Returns:
        (low_res_masks, high_res_masks, obj_ptr, object_score_logits)
    """
    device = backbone_features.device
    # Batch size from inputs (mask_inputs may have N_obj > 1 while backbone is batch 1)
    if mask_inputs is not None:
        B = mask_inputs.shape[0]
    elif box_inputs is not None:
        B = box_inputs.shape[0]
    elif point_inputs is not None:
        B = point_inputs["point_coords"].shape[0]
    else:
        B = backbone_features.shape[0]

    if point_inputs is not None:
        sam_point_coords = point_inputs["point_coords"]
        sam_point_labels = point_inputs["point_labels"]
    else:
        # No points provided: feed a single placeholder point with label -1
        # (treated as "not a point" by the prompt encoder).
        sam_point_coords = torch.zeros(B, 1, 2, device=device)
        sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

    if mask_inputs is not None:
        # The prompt encoder expects masks at 4x its image-embedding grid size.
        prompt_size = (prompt_encoder.image_embedding_size[0] * 4, prompt_encoder.image_embedding_size[1] * 4)
        if mask_inputs.shape[-2:] != prompt_size:
            sam_mask_prompt = F.interpolate(mask_inputs, size=prompt_size, mode="bilinear", align_corners=False, antialias=True)
        else:
            sam_mask_prompt = mask_inputs
    else:
        sam_mask_prompt = None

    sparse, dense = prompt_encoder(points=(sam_point_coords, sam_point_labels), boxes=box_inputs, masks=sam_mask_prompt)
    # Match prompt embeddings to the backbone's device/dtype before decoding.
    sparse = cast_to_input(sparse, backbone_features)
    dense = cast_to_input(dense, backbone_features)
    image_pe = cast_to_input(prompt_encoder.get_dense_pe(), backbone_features)

    low_res_multimasks, ious, sam_output_tokens, object_score_logits = mask_decoder(
        image_embeddings=backbone_features, image_pe=image_pe,
        sparse_prompt_embeddings=sparse, dense_prompt_embeddings=dense,
        high_res_features=high_res_features, multimask_output=multimask_output, return_all=True,
    )

    # Objects scored as absent get all mask logits forced to NO_OBJ_SCORE.
    is_obj_appearing = object_score_logits > 0
    low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks,
                                     torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_multimasks.dtype))
    high_res_multimasks = F.interpolate(low_res_multimasks, size=(image_size, image_size), mode="bilinear", align_corners=False)

    sam_output_token = sam_output_tokens[:, 0]
    if multimask_output:
        # Keep the candidate mask (and its token) with the highest predicted IoU.
        best_iou_inds = torch.argmax(ious, dim=-1)
        batch_inds = torch.arange(B, device=device)
        low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
        if sam_output_tokens.size(1) > 1:
            sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
    else:
        low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

    # Object pointer: projected mask token, adjusted for the no-object case.
    obj_ptr = obj_ptr_proj(sam_output_token)
    obj_ptr = no_obj_fn(obj_ptr, is_obj_appearing)

    return low_res_masks, high_res_masks, obj_ptr, object_score_logits
def use_mask_as_output(backbone_features, high_res_features, mask_inputs, mask_downsample,
                       prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, image_size, backbone_stride):
    """Shared mask-as-output for both SAM3 and SAM3.1 trackers."""
    # Turn {0, 1} mask inputs into confident logits in [-10, 10].
    out_scale, out_bias = 20.0, -10.0
    mask_inputs_float = cast_to_input(mask_inputs, backbone_features)
    high_res_masks = mask_inputs_float * out_scale + out_bias
    low_side = image_size // backbone_stride * 4
    low_res_masks = F.interpolate(high_res_masks, size=(low_side, low_side),
                                  mode="bilinear", align_corners=False, antialias=True)
    # Run the SAM heads only to obtain an object pointer for the memory bank.
    _, _, obj_ptr, _ = forward_sam_heads(
        backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn,
        image_size, mask_inputs=mask_downsample(mask_inputs_float), high_res_features=high_res_features,
    )
    # Object is "appearing" iff the input mask has any positive pixel.
    is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1)[..., None]
    object_score_logits = out_scale * is_obj_appearing.to(obj_ptr.dtype) + out_bias
    return low_res_masks, high_res_masks, obj_ptr, object_score_logits
# Split attention with configurable input dims (for asymmetric cross-attention)
class SplitAttn(nn.Module):
    """Multi-head attention with separate q/k/v projections and configurable kv/internal widths."""

    def __init__(self, embed_dim, num_heads=1, kv_dim=None, internal_dim=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        kv_dim = kv_dim or embed_dim
        internal_dim = internal_dim or embed_dim
        self.q_proj = operations.Linear(embed_dim, internal_dim, device=device, dtype=dtype)
        self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
        self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
        self.out_proj = operations.Linear(internal_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0):
        # Defaults give self-attention: k falls back to q, v falls back to k.
        k = q if k is None else k
        v = k if v is None else v
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        if rope is not None:
            q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
        return self.out_proj(optimized_attention(q, k, v, self.num_heads))
class MemoryAttnLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention to memory, then a ReLU FFN."""

    def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        self.self_attn = SplitAttn(d_model, num_heads, device=device, dtype=dtype, operations=operations)
        self.cross_attn_image = SplitAttn(d_model, num_heads, kv_dim=kv_dim, device=device, dtype=dtype, operations=operations)
        self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
        self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, x, memory, memory_pos=None, rope=None, num_k_exclude_rope=0):
        # Self-attention over current-frame tokens.
        x = x + self.self_attn(self.norm1(x), rope=rope)
        # Cross-attention: keys carry the memory positional encoding, values do not.
        keys = memory if memory_pos is None else memory + memory_pos
        x = x + self.cross_attn_image(self.norm2(x), keys, memory, rope=rope, num_k_exclude_rope=num_k_exclude_rope)
        # Position-wise feed-forward.
        return x + self.linear2(F.relu(self.linear1(self.norm3(x))))
class MemoryAttnEncoder(nn.Module):
    """Stack of MemoryAttnLayer blocks sharing one 2D RoPE table, with a final LayerNorm."""

    def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14,
                 device=None, dtype=None, operations=None):
        super().__init__()
        blocks = [MemoryAttnLayer(d_model, num_heads, kv_dim, dim_ff, device=device, dtype=dtype, operations=operations)
                  for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        side = image_size // patch_size
        # Precompute rotary matrices for a side x side token grid (not stored in checkpoints).
        self.register_buffer("_rope", rope_2d(side, side, d_model // num_heads), persistent=False)

    def forward(self, x, memory, src_pos=None, memory_pos=None, num_k_exclude_rope=0):
        if src_pos is not None:
            # Lightly blend in the source positional encoding (fixed 0.1 weight).
            x = x + 0.1 * src_pos
        rope = self._rope.to(device=x.device)
        for block in self.layers:
            x = block(x, memory, memory_pos=memory_pos, rope=rope, num_k_exclude_rope=num_k_exclude_rope)
        return self.norm(x)
class MemoryTransformer(nn.Module):
    """Wrapper module exposing a MemoryAttnEncoder under the `.encoder` attribute."""

    def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None):
        super().__init__()
        self.encoder = MemoryAttnEncoder(
            d_model, num_heads, kv_dim, dim_ff, num_layers,
            device=device, dtype=dtype, operations=operations)
def _upscale_masks(output_upscaling, conv_s0, conv_s1, src_out, high_res_features):
|
|
"""Shared upscaling for SAM mask decoders: deconv + high-res feature integration."""
|
|
dc1, ln1, act1, dc2, act2 = output_upscaling
|
|
if high_res_features is not None:
|
|
upscaled = act1(ln1(dc1(src_out) + conv_s1(high_res_features[1])))
|
|
upscaled = act2(dc2(upscaled) + conv_s0(high_res_features[0]))
|
|
else:
|
|
upscaled = act2(dc2(act1(ln1(dc1(src_out)))))
|
|
return upscaled
|
|
|
|
|
|
class SAMMaskDecoder(nn.Module):
    """SAM-style mask decoder with an extra object-score head.

    Runs a TwoWayTransformer between learned output tokens and the image
    embedding, then predicts masks via per-token hypernetwork MLPs applied to
    upscaled features, along with IoU estimates and an object-presence logit.
    """

    def __init__(self, d_model=256, num_multimask_outputs=3, device=None, dtype=None, operations=None):
        super().__init__()
        # One single-mask token plus num_multimask_outputs alternative-mask tokens.
        self.num_mask_tokens = num_multimask_outputs + 1

        self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations)

        # Learned output tokens prepended to the prompt embeddings.
        self.iou_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
        self.mask_tokens = operations.Embedding(self.num_mask_tokens, d_model, device=device, dtype=dtype)
        self.obj_score_token = operations.Embedding(1, d_model, device=device, dtype=dtype)

        # Output upscaling: d_model -> d_model//4 -> d_model//8 at 4x resolution
        LN2d = LayerNorm2d_op(operations)
        self.output_upscaling = nn.Sequential(
            operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(),
            operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(),
        )

        # High-res feature integration (1x1 projections for the two skip levels)
        self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype)
        self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype)

        # Per-mask hypernetwork MLPs
        self.output_hypernetworks_mlps = nn.ModuleList([
            MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations)
            for _ in range(self.num_mask_tokens)
        ])

        self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_tokens, 3, device=device, dtype=dtype, operations=operations)
        self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)

    def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings,
                high_res_features=None, multimask_output=False, return_all=False):
        """Decode masks from image embeddings and prompt embeddings.

        Returns (masks, iou_pred), or with return_all=True the 4-tuple
        (masks, iou_pred, mask_tokens_out, object_score_logits).
        """
        B = sparse_prompt_embeddings.shape[0]
        ref = sparse_prompt_embeddings
        # Token order: [obj_score(1), iou(1), mask(num_mask_tokens)]
        tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref),
                            cast_to_input(self.iou_token.weight, ref),
                            cast_to_input(self.mask_tokens.weight, ref)], dim=0)
        tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1)

        # Broadcast a batch-1 image embedding across all prompt entries.
        src = image_embeddings
        if src.shape[0] != B:
            src = src.expand(B, -1, -1, -1)
        src = src + dense_prompt_embeddings
        pos_src = image_pe.expand(B, -1, -1, -1)

        b, c, h, w = src.shape
        # Flatten spatial maps to (B, H*W, C) token sequences for the transformer.
        src_flat = src.flatten(2).permute(0, 2, 1)
        pos_flat = pos_src.flatten(2).permute(0, 2, 1)

        hs, src_out = self.transformer(src_flat, pos_flat, tokens)

        # Slice token outputs following the order established above.
        obj_score_token_out = hs[:, 0, :]
        iou_token_out = hs[:, 1, :]
        mask_tokens_out = hs[:, 2:2 + self.num_mask_tokens, :]

        src_out = src_out.permute(0, 2, 1).view(b, c, h, w)
        upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features)

        # Each mask token drives its own MLP to produce per-pixel mask weights.
        hyper_in = torch.stack([
            mlp(mask_tokens_out[:, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps)
        ], dim=1)

        masks = (hyper_in @ upscaled.flatten(2)).view(B, self.num_mask_tokens, upscaled.shape[2], upscaled.shape[3])
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(obj_score_token_out)

        # Index 0 is the single-mask head; indices 1+ are the multimask alternatives.
        if multimask_output:
            out_masks = masks[:, 1:]
            out_iou = iou_pred[:, 1:]
            out_tokens = mask_tokens_out[:, 1:]
        else:
            out_masks = masks[:, 0:1]
            out_iou = iou_pred[:, 0:1]
            out_tokens = mask_tokens_out[:, 0:1]

        if return_all:
            return out_masks, out_iou, out_tokens, object_score_logits
        return out_masks, out_iou
class SAMPromptEncoder(nn.Module):
    """Encodes point/box/mask prompts into sparse and dense embeddings for the mask decoder."""

    def __init__(self, d_model=256, image_embedding_size=(72, 72), input_image_size=(1008, 1008), device=None, dtype=None, operations=None):
        super().__init__()
        self.embed_dim = d_model
        self.image_embedding_size = image_embedding_size
        self.input_image_size = input_image_size

        self.pe_layer = PositionEmbeddingRandom(d_model // 2)
        # Four learned type embeddings indexed by label; labels 2 and 3 double as
        # box top-left / bottom-right corners (see forward).
        self.point_embeddings = nn.ModuleList([
            operations.Embedding(1, d_model, device=device, dtype=dtype) for _ in range(4)
        ])
        self.not_a_point_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)

        # Downscales a full-res mask prompt (4x the embedding grid) to the
        # embedding resolution via two stride-2 convs.
        LN2d = LayerNorm2d_op(operations)
        self.mask_downscaling = nn.Sequential(
            operations.Conv2d(1, 4, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(4, device=device, dtype=dtype), nn.GELU(),
            operations.Conv2d(4, 16, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(16, device=device, dtype=dtype), nn.GELU(),
            operations.Conv2d(16, d_model, kernel_size=1, device=device, dtype=dtype),
        )
        self.no_mask_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)

    def get_dense_pe(self):
        # Dense positional encoding over the image-embedding grid.
        return self.pe_layer(self.image_embedding_size)

    def forward(self, points=None, boxes=None, masks=None):
        """Encode the given prompts; returns (sparse [B, N, C], dense [B, C, H, W]).

        NOTE(review): assumes at least one of points/boxes/masks is provided;
        with all three None, `ref` is None and the allocation below fails.
        """
        ref = points[0] if points is not None else boxes if boxes is not None else masks
        B = 1
        sparse = torch.empty((B, 0, self.embed_dim), device=ref.device, dtype=ref.dtype)

        if points is not None:
            coords, labels = points
            B = coords.shape[0]
            # Pad with an extra point (label=-1) when no boxes are provided (matching reference)
            if boxes is None:
                coords = torch.cat([coords, torch.zeros(B, 1, 2, device=coords.device, dtype=coords.dtype)], dim=1)
                labels = torch.cat([labels, -torch.ones(B, 1, device=labels.device, dtype=labels.dtype)], dim=1)
            # +0.5 shifts coordinates to pixel centers before normalization.
            pe = self.pe_layer.forward_with_coords(coords + 0.5, self.input_image_size)
            # Add the label-specific type embedding to each point.
            for i in range(4):
                pe[labels == i] += cast_to_input(self.point_embeddings[i].weight, ref)
            # Label -1 (padding / "not a point"): zero the PE and use the dedicated embedding.
            invalid = (labels == -1)
            pe[invalid] = 0.0
            pe[invalid] += cast_to_input(self.not_a_point_embed.weight, ref)
            sparse = torch.cat([sparse.expand(B, -1, -1), pe], dim=1)

        if boxes is not None:
            B = boxes.shape[0]
            # A box becomes two corner points with their own type embeddings (2, 3).
            corners = self.pe_layer.forward_with_coords((boxes.reshape(-1, 2, 2) + 0.5), self.input_image_size)
            corners[:, 0] += cast_to_input(self.point_embeddings[2].weight, ref)
            corners[:, 1] += cast_to_input(self.point_embeddings[3].weight, ref)
            sparse = torch.cat([sparse.expand(B, -1, -1), corners], dim=1)

        if masks is not None:
            dense = self.mask_downscaling(masks)
        else:
            # No mask prompt: broadcast the learned no-mask embedding over the grid.
            dense = cast_to_input(self.no_mask_embed.weight, ref).reshape(1, -1, 1, 1).expand(
                B, -1, self.image_embedding_size[0], self.image_embedding_size[1])

        return sparse, dense
class CXBlock(nn.Module):
    """ConvNeXt-style block: depthwise conv, LayerNorm, pointwise MLP, learned scale, residual."""

    def __init__(self, dim=256, kernel_size=7, device=None, dtype=None, operations=None):
        super().__init__()
        self.dwconv = operations.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(dim, device=device, dtype=dtype)
        self.pwconv1 = operations.Linear(dim, 4 * dim, device=device, dtype=dtype)
        self.pwconv2 = operations.Linear(4 * dim, dim, device=device, dtype=dtype)
        self.gamma = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        shortcut = x
        # NCHW -> NHWC so LayerNorm and the Linear layers act on the channel dim.
        h = self.dwconv(x).permute(0, 2, 3, 1)
        h = self.pwconv2(F.gelu(self.pwconv1(self.norm(h))))
        h = h * cast_to_input(self.gamma, h)
        return shortcut + h.permute(0, 3, 1, 2)
class MaskDownSampler(nn.Module):
    """Encodes mask inputs with a chain of stride-2 conv stages into out_dim channels."""

    def __init__(self, out_dim=256, in_chans=1, channels=None, interpol_size=(1152, 1152), device=None, dtype=None, operations=None):
        super().__init__()
        self.interpol_size = list(interpol_size) if interpol_size else None
        if channels is None:
            channels = [4, 16, 64, out_dim]  # SAM3 default
        LN2d = LayerNorm2d_op(operations)
        stages = []
        in_ch = in_chans
        for out_ch in channels:
            stages.append(operations.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype))
            stages.append(LN2d(out_ch, device=device, dtype=dtype))
            stages.append(nn.GELU())
            in_ch = out_ch
        # Final 1x1 projection to the output width.
        stages.append(operations.Conv2d(in_ch, out_dim, kernel_size=1, device=device, dtype=dtype))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        # Resize to the canonical input resolution before encoding, if configured.
        if self.interpol_size is not None and list(x.shape[-2:]) != self.interpol_size:
            x = F.interpolate(x, size=self.interpol_size, mode="bilinear", align_corners=False, antialias=True)
        return self.encoder(x)
class Fuser(nn.Module):
    """Applies num_layers CXBlock layers in sequence."""

    def __init__(self, dim=256, num_layers=2, device=None, dtype=None, operations=None):
        super().__init__()
        blocks = [CXBlock(dim, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
# --- SAM3.1 Multiplex components ---


class DecoupledMemoryAttnLayer(nn.Module):
    """Decoupled cross-attention layer for SAM3.1: fuses image and memory projections."""

    def __init__(self, d_model=256, num_heads=1, dim_ff=2048, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        # Self-attention projections (flat, not nested)
        self.self_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.self_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.self_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.self_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        # Cross-attention projections
        self.cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.cross_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.cross_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        # Image cross-attention (q/k only, fused with cross_attn)
        self.image_cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.image_cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        # FFN
        self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
        self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, image, x, memory_image, memory, memory_image_pos=None,
                rope=None, num_k_exclude_rope=0):
        """One decoupled layer pass.

        Args:
            image: raw image features (constant residual stream; returned unchanged).
            x: evolving query features.
            memory_image: raw backbone features of the memory frames (key side).
            memory: encoded memory tokens (key and value side).
            memory_image_pos: optional positional encoding added to the fused keys.
            rope: optional flux-format rotary table (see apply_rope_memory).
            num_k_exclude_rope: trailing memory tokens excluded from RoPE.

        Returns:
            (image, x): image passed through untouched, x updated.
        """
        # Self-attention with RoPE
        normed = self.norm1(x)
        q = self.self_attn_q_proj(normed)
        k = self.self_attn_k_proj(normed)
        v = self.self_attn_v_proj(normed)
        if rope is not None:
            q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
        x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))

        # Decoupled cross-attention: queries/keys are sums of an image-stream
        # projection and a memory-stream projection; values come from memory only.
        normed = self.norm2(x)
        q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(normed)
        k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory)
        if memory_image_pos is not None:
            k = k + memory_image_pos
        v = self.cross_attn_v_proj(memory)
        if rope is not None:
            q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
        x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))

        # FFN
        x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))
        return image, x
class DecoupledMemoryEncoder(nn.Module):
    """Memory attention encoder for SAM3.1 with decoupled cross-attention."""

    def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoupledMemoryAttnLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        ])
        self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
        hw = image_size // patch_size
        # Shared rotary table for an hw x hw token grid (not stored in checkpoints).
        self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False)

    def forward(self, x, memory, memory_pos=None, src_pos=None, num_k_exclude_rope=0,
                memory_image=None, memory_image_pos=None):
        """Run the decoupled layer stack and return normalized query features.

        The last `num_k_exclude_rope` memory tokens are object pointers: they get
        no RoPE and have no raw-image counterpart (zero-padded below).
        """
        image = x  # constant residual for decoupled cross-attention
        output = x
        if src_pos is not None:
            # Lightly blend in the source positional encoding (fixed 0.1 weight).
            output = output + 0.1 * src_pos

        B, _, C = x.shape
        rope = self._rope.to(device=x.device)

        # memory_image: raw backbone features from past frames for decoupled cross-attention
        if memory_image is None:
            # Fallback: use spatial portion of memory (without obj pointers)
            num_spatial = memory.shape[1] - num_k_exclude_rope
            memory_image = memory[:, :num_spatial]
            memory_image_pos = memory_pos[:, :num_spatial] if memory_pos is not None else None
        # Pad memory_image to match memory length (zeros for obj pointer tokens)
        if memory_image.shape[1] < memory.shape[1]:
            pad_len = memory.shape[1] - memory_image.shape[1]
            pad = torch.zeros(B, pad_len, C, device=memory.device, dtype=memory.dtype)
            memory_image = torch.cat([memory_image, pad], dim=1)
            if memory_image_pos is not None:
                # Reuse the pointer-token positions from memory_pos when available.
                ptr_pos = memory_pos[:, -pad_len:] if memory_pos is not None else torch.zeros_like(pad)
                memory_image_pos = torch.cat([memory_image_pos, ptr_pos], dim=1)

        for layer in self.layers:
            image, output = layer(image, output, memory_image, memory,
                                  memory_image_pos=memory_image_pos, rope=rope,
                                  num_k_exclude_rope=num_k_exclude_rope)

        return self.norm(output)
class DecoupledMemoryTransformer(nn.Module):
    """Thin container exposing a DecoupledMemoryEncoder as ``self.encoder``."""

    def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None):
        super().__init__()
        # Only the encoder is held; callers invoke self.encoder(...) directly.
        self.encoder = DecoupledMemoryEncoder(
            d_model, num_heads, dim_ff, num_layers,
            device=device, dtype=dtype, operations=operations,
        )
|
|
|
|
|
|
class MemoryBackbone(nn.Module):
    """Memory encoder: downsamples the mask, fuses it with pixel features,
    and optionally projects the result down to the memory dimension."""

    def __init__(self, d_model=256, out_dim=None, in_chans=1, channels=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.mask_downsampler = MaskDownSampler(d_model, in_chans=in_chans, channels=channels, device=device, dtype=dtype, operations=operations)
        self.pix_feat_proj = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
        self.fuser = Fuser(d_model, num_layers=2, device=device, dtype=dtype, operations=operations)
        # A 1x1 output projection is only needed when a distinct memory dim is requested.
        self.has_out_proj = out_dim is not None and out_dim != d_model
        if self.has_out_proj:
            self.out_proj = operations.Conv2d(d_model, out_dim, kernel_size=1, device=device, dtype=dtype)
        feat_dim = out_dim if self.has_out_proj else d_model
        self.position_encoding = PositionEmbeddingSine(num_pos_feats=feat_dim, normalize=True)

    def forward(self, image_features, mask_for_mem, skip_mask_sigmoid=False):
        """Encode a (logit or pre-activated) mask with pixel features into memory.

        Returns a dict with "vision_features" (fused feature map) and
        "vision_pos_enc" (single-element list with its sine position encoding).
        """
        if not skip_mask_sigmoid:
            mask_for_mem = mask_for_mem.sigmoid()
        mask_feats = self.mask_downsampler(cast_to_input(mask_for_mem, image_features))
        # Resize downsampled mask features onto the pixel-feature grid if they differ.
        if mask_feats.shape[-2:] != image_features.shape[-2:]:
            mask_feats = F.interpolate(mask_feats, size=image_features.shape[-2:], mode="bilinear", align_corners=False)
        fused = self.fuser(self.pix_feat_proj(image_features) + mask_feats)
        if self.has_out_proj:
            fused = self.out_proj(fused)
        pos = cast_to_input(self.position_encoding(fused), fused)
        return {"vision_features": fused, "vision_pos_enc": [pos]}
|
|
|
|
|
|
class MultiplexMaskDecoder(nn.Module):
    """SAM mask decoder for SAM3.1 multiplex: predicts masks for num_multiplex objects simultaneously.

    Uses multimask_outputs_only=True: num_mask_output_per_object = num_multimask_outputs (no +1).
    Hypernetwork MLPs are shared across multiplex objects.
    Token order: [obj_score_token(M), iou_token(M), mask_tokens(M*T)].
    """

    def __init__(self, d_model=256, num_multiplex=16, num_multimask_outputs=3, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_multiplex = num_multiplex
        self.num_mask_output_per_object = num_multimask_outputs  # 3 (multimask_outputs_only)
        total_mask_tokens = num_multiplex * self.num_mask_output_per_object  # 48

        self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations)

        # One obj-score and one IoU token per multiplex slot; T mask tokens per slot.
        self.obj_score_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype)
        self.iou_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype)
        self.mask_tokens = operations.Embedding(total_mask_tokens, d_model, device=device, dtype=dtype)

        # 4x upscaling head (two stride-2 transposed convs) for mask logits.
        LN2d = LayerNorm2d_op(operations)
        self.output_upscaling = nn.Sequential(
            operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(),
            operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(),
        )
        # 1x1 convs projecting high-res skip features for _upscale_masks.
        self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype)
        self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype)

        # Shared across all multiplex objects (one per mask output)
        self.output_hypernetworks_mlps = nn.ModuleList([
            MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations)
            for _ in range(self.num_mask_output_per_object)
        ])
        self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_output_per_object, 3, device=device, dtype=dtype, operations=operations)
        self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)

    def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings,
                high_res_features=None, multimask_output=False, return_all=False, extra_per_object_embeddings=None):
        """Decode multiplexed masks.

        Args:
            image_embeddings: [1 or B, C, H, W] image features.
            image_pe: positional encoding for the image grid.
            sparse_prompt_embeddings: [B, N_prompt, C] prompt tokens (may be empty).
            dense_prompt_embeddings: [B, C, H, W] dense prompt added to the image features.
            high_res_features: optional skip features passed to _upscale_masks.
            extra_per_object_embeddings: optional [B, M, C] per-slot offsets added
                to every mask token of the corresponding multiplex slot.

        Returns:
            (masks [B, M, T, h, w], iou_pred [B, M, T]) or, with return_all,
            additionally (sam_tokens_out [B, M, 1, C], object_score_logits [B, M, 1]).
        """
        B = sparse_prompt_embeddings.shape[0]
        M = self.num_multiplex
        T = self.num_mask_output_per_object

        # Token order: [obj_score(M), iou(M), mask(M*T)]
        ref = sparse_prompt_embeddings
        mask_tokens = cast_to_input(self.mask_tokens.weight, ref)
        if extra_per_object_embeddings is not None:
            # Add the per-slot embedding to each of its T mask tokens.
            mask_tokens = mask_tokens.view(1, M, T, -1).expand(B, -1, -1, -1) + extra_per_object_embeddings.unsqueeze(2)
            mask_tokens = mask_tokens.flatten(1, 2)  # [B, M*T, C]
            other_tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref),
                                      cast_to_input(self.iou_token.weight, ref)], dim=0).unsqueeze(0).expand(B, -1, -1)
            tokens = torch.cat([other_tokens, mask_tokens, sparse_prompt_embeddings], dim=1)
        else:
            tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref),
                                cast_to_input(self.iou_token.weight, ref), mask_tokens], dim=0)
            tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1)

        src = image_embeddings
        if src.shape[0] != B:
            src = src.expand(B, -1, -1, -1)
        src = src + dense_prompt_embeddings
        pos_src = image_pe.expand(B, -1, -1, -1)

        b, c, h, w = src.shape
        hs, src_out = self.transformer(src.flatten(2).permute(0, 2, 1), pos_src.flatten(2).permute(0, 2, 1), tokens)

        # Parse output tokens (same order they were concatenated in).
        obj_score_token_out = hs[:, :M]
        iou_token_out = hs[:, M:2 * M]
        mask_tokens_out = hs[:, 2 * M:2 * M + M * T]

        src_out = src_out.permute(0, 2, 1).view(b, c, h, w)
        upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features)

        # Reshape mask tokens to [B, M, T, C] and apply shared hypernetwork MLPs per mask output index
        mask_tokens_2d = mask_tokens_out.view(B, M, T, -1)
        hyper_in = torch.stack([
            self.output_hypernetworks_mlps[i](mask_tokens_2d[:, :, i, :])  # [B, M, C//8]
            for i in range(T)
        ], dim=2)  # [B, M, T, C//8]

        # Generate masks: [B, M*T, H*W] -> [B, M, T, H, W]
        masks = torch.bmm(hyper_in.flatten(1, 2), upscaled.flatten(2)).view(b, M, T, upscaled.shape[2], upscaled.shape[3])

        # IoU and object scores
        iou_pred = self.iou_prediction_head(iou_token_out).view(b, M, T)
        object_score_logits = self.pred_obj_score_head(obj_score_token_out)  # [B, M, 1]

        # multimask_outputs_only: always output all T masks (no singlemask token)
        sam_tokens_out = mask_tokens_2d[:, :, 0:1]  # [B, M, 1, C]

        if return_all:
            return masks, iou_pred, sam_tokens_out, object_score_logits
        return masks, iou_pred
|
|
|
|
|
|
class SAM3Tracker(nn.Module):
    """SAM3 video tracker: per-frame mask prediction conditioned on a memory
    bank of past frames.

    Pipeline per frame: fuse current backbone features with stored memory
    (spatial mask memories + object-pointer tokens) via a memory-attention
    transformer, decode a mask with the SAM prompt-encoder/decoder pair, then
    encode the prediction back into memory. Objects are tracked one at a time.
    """

    def __init__(self, d_model=256, mem_dim=64, num_maskmem=7, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()

        # Memory attention transformer
        self.transformer = MemoryTransformer(d_model, num_heads=1, kv_dim=mem_dim, dim_ff=2048, num_layers=4,
                                             device=device, dtype=dtype, operations=operations)
        # SAM components
        self.sam_mask_decoder = SAMMaskDecoder(d_model, device=device, dtype=dtype, operations=operations)
        self.sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations)

        # Memory backbone (compresses d_model features down to mem_dim)
        self.maskmem_backbone = MemoryBackbone(d_model, out_dim=mem_dim, device=device, dtype=dtype, operations=operations)

        # Standalone parameters
        self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype))
        self.no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype))
        self.register_buffer("no_mem_pos_enc", torch.zeros(1, 1, d_model, device=device, dtype=dtype))  # checkpoint key, unused in forward
        self.no_obj_embed_spatial = nn.Parameter(torch.zeros(1, mem_dim, device=device, dtype=dtype))
        self.no_obj_ptr = nn.Parameter(torch.zeros(1, d_model, device=device, dtype=dtype))

        # Object pointer projection
        self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
        self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype)

        # Mask downsample: Conv2d stride 4 to reduce GT mask to SAM logit scale
        self.mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype)

        # Config
        self.d_model = d_model
        self.mem_dim = mem_dim
        self.num_maskmem = num_maskmem
        self.image_size = 1008
        self.backbone_stride = 14
        self.max_obj_ptrs_in_encoder = 16
        self.sigmoid_scale_for_mem_enc = 20.0
        self.sigmoid_bias_for_mem_enc = -10.0

    def _no_obj_blend(self, obj_ptr, is_obj):
        # Lerp per object between the learned "no object" pointer (is_obj=0)
        # and the predicted pointer (is_obj=1).
        alpha = is_obj.to(obj_ptr.dtype)
        return torch.lerp(cast_to_input(self.no_obj_ptr, obj_ptr), obj_ptr, alpha)

    def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None,
                           high_res_features=None, multimask_output=False):
        # Thin wrapper binding this tracker's SAM modules to the shared helper.
        return forward_sam_heads(backbone_features, self.sam_prompt_encoder, self.sam_mask_decoder,
                                 self.obj_ptr_proj, self._no_obj_blend, self.image_size,
                                 point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output)

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        # Conditioning-frame path: treat the provided mask as the prediction.
        return use_mask_as_output(backbone_features, high_res_features, mask_inputs,
                                  self.mask_downsample, self.sam_prompt_encoder, self.sam_mask_decoder,
                                  self.obj_ptr_proj, self._no_obj_blend, self.image_size, self.backbone_stride)

    def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames):
        """Fuse current frame features with memory from previous frames.

        Returns fused features as a [B, C, H, W] map. On the initial
        conditioning frame (or when no memory exists) the learned no-memory
        embedding is added instead of attending to memory.
        """
        B = current_vision_feats[-1].shape[0]
        C = self.d_model
        H, W = feat_sizes[-1]
        device = current_vision_feats[-1].device

        if self.num_maskmem == 0:
            return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W)

        if is_init_cond_frame:
            # First conditioning frame: no memory yet, add no_mem_embed
            pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1])
            return to_spatial(pix_feat, H, W)

        to_cat_memory, to_cat_memory_pos, _, _, cond_outputs = collect_memory_tokens(
            output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device)

        # Gather (temporal distance, object pointer) pairs: conditioning frames
        # first, then up to max_obj_ptrs-1 most recent non-conditioning frames.
        max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder)
        pos_and_ptrs = []
        for t, out in cond_outputs.items():
            if t <= frame_idx:
                pos_and_ptrs.append(((frame_idx - t), out["obj_ptr"].to(device)))
        for t_diff in range(1, max_obj_ptrs):
            t = frame_idx - t_diff
            if t < 0:
                break
            out = output_dict["non_cond_frame_outputs"].get(t, None)
            if out is not None:
                pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device)))

        num_obj_ptr_tokens = 0
        if len(pos_and_ptrs) > 0:
            pos_list, ptrs_list = zip(*pos_and_ptrs)
            obj_ptrs = torch.stack(ptrs_list, dim=1)  # [B, N, C=256]

            # Temporal position encoding for pointers
            obj_pos = compute_tpos_enc(
                list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj,
                max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype
            )  # [N, mem_dim=64]
            obj_pos = obj_pos.unsqueeze(0).expand(B, -1, -1)  # [B, N, 64]

            # Split each 256-dim pointer into 4 x 64-dim tokens
            if self.mem_dim < C:
                N = obj_ptrs.shape[1]
                obj_ptrs = obj_ptrs.view(B, N, C // self.mem_dim, self.mem_dim)  # [B, N, 4, 64]
                obj_ptrs = obj_ptrs.reshape(B, N * (C // self.mem_dim), self.mem_dim)  # [B, N*4, 64]
                obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, C // self.mem_dim, -1)
                obj_pos = obj_pos.reshape(B, N * (C // self.mem_dim), self.mem_dim)  # [B, N*4, 64]

            to_cat_memory.append(obj_ptrs)
            to_cat_memory_pos.append(obj_pos)
            num_obj_ptr_tokens = obj_ptrs.shape[1]

        if len(to_cat_memory) == 0:
            # No memory available yet, add no_mem_embed
            pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1])
            return to_spatial(pix_feat, H, W)

        # Concatenate all memory and position encodings [B, total_mem, mem_dim=64]
        memory = torch.cat(to_cat_memory, dim=1)
        memory_pos = torch.cat(to_cat_memory_pos, dim=1)

        # Run memory attention encoder
        pix_feat = current_vision_feats[-1]  # [B, HW, C]
        src_pos = current_vision_pos_embeds[-1]  # [B, HW, C]

        pix_feat_with_mem = self.transformer.encoder(
            x=pix_feat,
            memory=memory,
            src_pos=src_pos,
            memory_pos=memory_pos,
            num_k_exclude_rope=num_obj_ptr_tokens,
        )
        return to_spatial(pix_feat_with_mem, H, W)

    def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False):
        """Encode predicted mask into memory features.

        Point-prompted masks are binarized at 0; otherwise the sigmoid of the
        logits is used. The mask is then affinely rescaled (scale 20, bias -10)
        before being fused with ``pix_feat`` by the memory backbone.
        """
        if is_mask_from_pts:
            mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
        else:
            mask_for_mem = torch.sigmoid(pred_masks_high_res)

        mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc)

        maskmem_out = self.maskmem_backbone(pix_feat, mask_for_mem, skip_mask_sigmoid=True)
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]

        # Add no_obj_embed for occluded objects
        alpha = (object_score_logits > 0).to(maskmem_features.dtype)[..., None, None]
        no_obj = cast_to_input(self.no_obj_embed_spatial, maskmem_features)[..., None, None].expand_as(maskmem_features)
        return maskmem_features + (1 - alpha) * no_obj, maskmem_pos_enc

    def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, mask_inputs, output_dict,
                   num_frames, point_inputs=None):
        """Track one frame: fuse with memory, predict mask, encode memory."""
        current_out = {}

        # High-res features for SAM head [stride-8, stride-4]
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.view(x.shape[0], feat_sizes[i][0], feat_sizes[i][1], -1).permute(0, 3, 1, 2)
                for i, x in enumerate(current_vision_feats[:-1])
            ]
        else:
            high_res_features = None

        # Top-level feature for memory
        H, W = feat_sizes[-1]

        if mask_inputs is not None:
            # Conditioning frame: use mask directly
            pix_feat = to_spatial(current_vision_feats[-1], H, W)
            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
        else:
            # Track frame: fuse with memory, then SAM decoder
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                feat_sizes=feat_sizes,
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Use multimask for point prompts on init frames (picks best of 3 candidates)
            num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
            multimask_output = is_init_cond_frame and 0 < num_pts <= 1
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )

        (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs

        # Clean low-res masks: remove sprinkles and fill holes
        low_res_masks = fill_holes_in_mask_scores(low_res_masks, max_area=200)
        high_res_masks = F.interpolate(low_res_masks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False)

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        current_out["object_score_logits"] = object_score_logits

        # Encode memory
        if self.num_maskmem > 0:
            pix_feat = to_spatial(current_vision_feats[-1], H, W)
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                pix_feat=pix_feat,
                pred_masks_high_res=high_res_masks,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

        return current_out

    def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None):
        # Run the backbone on a single frame and keep all but the last FPN level.
        vision_feats, vision_pos, feat_sizes, _, _ = _compute_backbone(backbone_fn, frame, frame_idx)
        # SAM3: drop last FPN level
        return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]

    def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
        """Track one object, computing backbone per frame to save VRAM."""
        N = images.shape[0]
        device, dt = images.device, images.dtype
        output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
        all_masks = []

        for frame_idx in tqdm(range(N), desc="tracking"):
            vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
                backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
            mask_input = None
            if frame_idx == 0:
                # Resize + binarize the user mask for the conditioning frame.
                mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
                                           size=(self.image_size, self.image_size), mode="bilinear", align_corners=False)
                mask_input = (mask_input > 0.5).to(dt)

            current_out = self.track_step(
                frame_idx=frame_idx, is_init_cond_frame=(frame_idx == 0),
                current_vision_feats=vision_feats, current_vision_pos_embeds=vision_pos,
                feat_sizes=feat_sizes, mask_inputs=mask_input, output_dict=output_dict, num_frames=N)

            if frame_idx == 0:
                output_dict["cond_frame_outputs"][frame_idx] = current_out
            else:
                output_dict["non_cond_frame_outputs"][frame_idx] = current_out
                # Prune entries older than anything memory attention can reach.
                lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder)
                for old_idx in list(output_dict["non_cond_frame_outputs"]):
                    if old_idx < frame_idx - lookback:
                        del output_dict["non_cond_frame_outputs"][old_idx]
            # Move masks to CPU immediately to free VRAM
            all_masks.append(current_out["pred_masks_high_res"].to(comfy.model_management.intermediate_device()))
            if pbar is not None:
                pbar.update(1)

        return torch.cat(all_masks, dim=0)  # [N, 1, H, W]

    def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
        """Track one or more objects across video frames.

        Objects are tracked independently, one full pass over the video each.

        Args:
            backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
            images: [N, 3, 1008, 1008] video frames
            initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
            pbar: optional progress bar

        Returns:
            [N, N_obj, image_size, image_size] predicted mask logits per frame per object
        """
        N_obj = initial_masks.shape[0]
        per_object = []
        for obj_idx in range(N_obj):
            obj_masks = self._track_single_object(
                backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
            per_object.append(obj_masks)

        return torch.cat(per_object, dim=1)  # [N, N_obj, H, W]
|
|
|
|
|
|
class SAM31Tracker(nn.Module):
|
|
"""SAM3.1 multiplex tracker: decoupled memory attention, dual decoder, 16-object multiplex."""
|
|
|
|
    def __init__(self, d_model=256, mem_dim=256, num_maskmem=7, num_multiplex=16, device=None, dtype=None, operations=None, **kwargs):
        """Build the SAM3.1 multiplex tracker.

        Differences from SAM3Tracker visible here: memory keeps the full model
        dim (mem_dim=256, no compression), memory attention is the decoupled
        8-head variant, and two decoders are held — a multiplex propagation
        decoder and a single-object interactive decoder.
        """
        super().__init__()
        self.d_model = d_model
        self.mem_dim = mem_dim
        self.num_maskmem = num_maskmem
        self.num_multiplex = num_multiplex
        self.image_size = 1008
        self.backbone_stride = 14
        self.max_obj_ptrs_in_encoder = 16
        # Gentler mask rescaling than SAM3Tracker's (20, -10).
        self.sigmoid_scale_for_mem_enc = 2.0
        self.sigmoid_bias_for_mem_enc = -1.0

        # Memory attention (decoupled cross-attention, 8 heads matching reference)
        self.transformer = DecoupledMemoryTransformer(d_model, num_heads=8, dim_ff=2048, num_layers=4,
                                                      device=device, dtype=dtype, operations=operations)

        # Propagation decoder (multiplex: 16 objects, multimask_outputs_only)
        self.sam_mask_decoder = MultiplexMaskDecoder(d_model, num_multiplex, num_multimask_outputs=3,
                                                     device=device, dtype=dtype, operations=operations)
        # Interactive decoder (single object, same as SAM3)
        self.interactive_sam_mask_decoder = SAMMaskDecoder(d_model, num_multimask_outputs=3,
                                                           device=device, dtype=dtype, operations=operations)
        self.interactive_sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations)

        # Memory backbone (mem_dim=256, no out_proj compression).
        # Input has 2 channels per multiplex slot: mask + conditioning flag.
        self.maskmem_backbone = MemoryBackbone(d_model, in_chans=num_multiplex * 2, channels=[16, 64, 256, 1024],
                                               device=device, dtype=dtype, operations=operations)

        # Standalone parameters
        self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype))
        self.no_obj_embed_spatial = nn.Parameter(torch.zeros(num_multiplex, mem_dim, device=device, dtype=dtype))
        self.interactivity_no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype))

        # Object pointer projection
        self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
        self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype)
        # Learned transform (instead of a fixed no_obj_ptr vector) for occluded objects.
        self.no_obj_ptr_linear = operations.Linear(d_model, d_model, device=device, dtype=dtype)
        self.interactive_obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)

        # Interactive mask downsample
        self.interactive_mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype)

        # Multiplex validity embeddings
        self.output_valid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype))
        self.output_invalid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype))

        # Position encoding for image (used by multiplex decoder)
        self.image_pe_layer = PositionEmbeddingRandom(d_model // 2)
|
|
|
|
def _no_obj_blend(self, obj_ptr, is_obj):
|
|
alpha = is_obj.to(obj_ptr.dtype)
|
|
return torch.lerp(self.no_obj_ptr_linear(obj_ptr), obj_ptr, alpha)
|
|
|
|
def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None,
|
|
high_res_features=None, multimask_output=False):
|
|
return forward_sam_heads(backbone_features, self.interactive_sam_prompt_encoder, self.interactive_sam_mask_decoder,
|
|
self.interactive_obj_ptr_proj, self._no_obj_blend, self.image_size,
|
|
point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output)
|
|
|
|
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
|
return use_mask_as_output(backbone_features, high_res_features, mask_inputs,
|
|
self.interactive_mask_downsample, self.interactive_sam_prompt_encoder,
|
|
self.interactive_sam_mask_decoder, self.interactive_obj_ptr_proj,
|
|
self._no_obj_blend, self.image_size, self.backbone_stride)
|
|
|
|
    def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats,
                                             current_vision_pos_embeds, feat_sizes, output_dict, num_frames,
                                             multiplex_state=None):
        """Fuse current-frame features with memory via decoupled memory attention.

        Like SAM3Tracker's version, but memory is batched per multiplex bucket
        (object pointers are padded to num_buckets) and past-frame image
        features are passed separately for the decoupled cross-attention path.
        Returns a [*, C, H, W] feature map.
        """
        B = current_vision_feats[-1].shape[0]
        C = self.d_model
        H, W = feat_sizes[-1]
        device = current_vision_feats[-1].device
        num_buc = multiplex_state.num_buckets if multiplex_state is not None else None

        if self.num_maskmem == 0:
            return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W)

        if is_init_cond_frame:
            # No memory yet: add the learned no-memory embedding instead.
            pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1])
            return to_spatial(pix_feat, H, W)

        to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs = collect_memory_tokens(
            output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device,
            collect_image_feats=True, tpos_v2=True, num_buckets=num_buc)

        # Gather (temporal distance, pointer) pairs: conditioning frames first,
        # then the most recent non-conditioning frames.
        max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder)
        pos_and_ptrs = []
        for t, out in cond_outputs.items():
            if t <= frame_idx and "obj_ptr" in out:
                ptr = out["obj_ptr"].to(device)
                if num_buc is not None:
                    ptr = _pad_to_buckets(ptr, num_buc)
                pos_and_ptrs.append(((frame_idx - t), ptr))
        for t_diff in range(1, max_obj_ptrs):
            t = frame_idx - t_diff
            if t < 0:
                break
            out = output_dict["non_cond_frame_outputs"].get(t, None)
            if out is not None and "obj_ptr" in out:
                ptr = out["obj_ptr"].to(device)
                if num_buc is not None:
                    ptr = _pad_to_buckets(ptr, num_buc)
                pos_and_ptrs.append((t_diff, ptr))

        num_obj_ptr_tokens = 0
        if len(pos_and_ptrs) > 0:
            pos_list, ptrs_list = zip(*pos_and_ptrs)
            obj_ptrs = torch.stack(ptrs_list, dim=1)  # [num_buckets, N, M, C]
            B_ptr = obj_ptrs.shape[0]
            N_ptrs = obj_ptrs.shape[1]
            M = obj_ptrs.shape[2]
            # Flatten (frame, slot) into one token axis; all slots of a frame
            # share that frame's temporal position encoding.
            obj_ptrs = obj_ptrs.reshape(B_ptr, N_ptrs * M, -1)
            obj_pos = compute_tpos_enc(list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj,
                                       max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype)
            obj_pos = obj_pos.unsqueeze(0).expand(B_ptr, -1, -1)
            obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, M, -1).reshape(B_ptr, N_ptrs * M, -1)
            to_cat_memory.append(obj_ptrs)
            to_cat_memory_pos.append(obj_pos)
            num_obj_ptr_tokens = obj_ptrs.shape[1]

        if len(to_cat_memory) == 0:
            pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1])
            return to_spatial(pix_feat, H, W)

        memory = torch.cat(to_cat_memory, dim=1)
        memory_pos = torch.cat(to_cat_memory_pos, dim=1)

        # Expand vision features to num_buckets if memory has more buckets than B
        mem_B = memory.shape[0]
        x = current_vision_feats[-1]
        x_pos = current_vision_pos_embeds[-1]
        if x.shape[0] < mem_B:
            x = x.expand(mem_B, -1, -1)
            x_pos = x_pos.expand(mem_B, -1, -1)

        if len(to_cat_image_feat) > 0:
            # Decoupled cross-attention: separate image features from memory
            memory_image = cast_to_input(torch.cat(to_cat_image_feat, dim=1), x)
            memory_image_pos = cast_to_input(torch.cat(to_cat_image_pos, dim=1), x)
            if memory_image.shape[0] < mem_B:
                memory_image = memory_image.expand(mem_B, -1, -1)
                memory_image_pos = memory_image_pos.expand(mem_B, -1, -1)
            pix_feat_with_mem = self.transformer.encoder(
                x=x,
                memory=cast_to_input(memory, x),
                memory_pos=cast_to_input(memory_pos, x),
                src_pos=cast_to_input(x_pos, x),
                num_k_exclude_rope=num_obj_ptr_tokens,
                memory_image=memory_image,
                memory_image_pos=memory_image_pos,
            )
        else:
            # No stored image features: encoder falls back to the spatial
            # portion of `memory` for the image half of the cross-attention.
            pix_feat_with_mem = self.transformer.encoder(
                x=x,
                memory=memory,
                memory_pos=memory_pos,
                src_pos=x_pos,
                num_k_exclude_rope=num_obj_ptr_tokens,
            )
        return to_spatial(pix_feat_with_mem, H, W)
|
|
|
|
    def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False,
                           multiplex_state=None, is_conditioning=False, cond_obj_mask=None):
        """Encode per-object predicted masks into multiplexed memory features.

        Masks are muxed into [num_buckets, M, H, W] slots and concatenated with
        a per-slot conditioning channel (1.0 for trusted/clean masks, 0.0 for
        propagated ones) before the memory backbone. Occluded slots
        (object_score_logits <= 0) get their learned no-object embedding added.

        Returns (maskmem_features, maskmem_pos_enc).
        """
        if is_mask_from_pts:
            mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype)
        else:
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc)

        # Mux masks: [N_obj, 1, H, W] -> [num_buckets, M, H, W]
        mux_masks = multiplex_state.mux(mask_for_mem[:, 0])

        # Conditioning channel: 1.0 = clean mask (trust it), 0.0 = propagation (noisy)
        N_obj = mask_for_mem.shape[0]
        cond_values = torch.full((N_obj,), 0.0, device=mask_for_mem.device, dtype=mask_for_mem.dtype)
        if is_conditioning:
            cond_values[:] = 1.0
        elif cond_obj_mask is not None:
            # Per-object override: only the flagged objects are trusted.
            cond_values[cond_obj_mask] = 1.0
        cond_spatial = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem[:, 0:1, :, :]).squeeze(1)
        mux_cond = multiplex_state.mux(cond_spatial)  # [num_buckets, M, H, W]
        mux_input = torch.cat([mux_masks, mux_cond], dim=1)  # [num_buckets, 2*M, H, W]

        maskmem_out = self.maskmem_backbone(pix_feat, mux_input, skip_mask_sigmoid=True)
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]

        # Add no_obj_embed_spatial for occluded objects
        is_obj = (object_score_logits > 0).float()  # [N_obj, 1]
        mux_is_obj = multiplex_state.mux(is_obj)  # [num_buckets, M, 1]
        no_obj_embed = cast_to_input(self.no_obj_embed_spatial, maskmem_features)  # [M, C]
        no_obj_spatial = no_obj_embed.unsqueeze(0)[..., None, None]  # [1, M, C, 1, 1]
        # Expand and sum across multiplex slots weighted by (1 - is_obj)
        alpha = mux_is_obj[..., None, None]  # [num_buckets, M, 1, 1, 1]
        per_slot_no_obj = ((1 - alpha) * no_obj_spatial).sum(dim=1)  # [num_buckets, C, 1, 1]
        maskmem_features = maskmem_features + per_slot_no_obj.expand_as(maskmem_features)

        return maskmem_features, maskmem_pos_enc
|
|
|
|
def _forward_propagation(self, backbone_features, high_res_features=None, multiplex_state=None):
    """Propagation path using the multiplex SAM decoder (no prompts).

    Decodes all multiplexed objects in one pass from memory-conditioned backbone
    features, then demuxes to per-object outputs.

    Returns:
        (low_res_masks [N_obj, 1, h, w], high_res_masks [N_obj, 1, image_size, image_size],
         obj_ptr_muxed [num_buckets, M, C], score_obj [N_obj, 1]).
    """
    B = backbone_features.shape[0]
    device = backbone_features.device

    # Suppression embeddings from valid object mask: valid slots get the "valid"
    # embedding, unused multiplex slots get the "invalid" one.
    valid_mask = cast_to_input(multiplex_state.get_valid_object_mask().unsqueeze(-1).float(), backbone_features)
    output_valid = cast_to_input(self.output_valid_embed, backbone_features).unsqueeze(0)
    output_invalid = cast_to_input(self.output_invalid_embed, backbone_features).unsqueeze(0)
    extra_embed = valid_mask * output_valid + (1 - valid_mask) * output_invalid

    image_pe = self.image_pe_layer((backbone_features.shape[-2], backbone_features.shape[-1]), device=backbone_features.device)
    image_pe = cast_to_input(image_pe, backbone_features)

    # No sparse/dense prompts in propagation: empty sparse tokens, zero dense embedding.
    masks, iou_pred, sam_tokens_out, object_score_logits = self.sam_mask_decoder(
        image_embeddings=backbone_features, image_pe=image_pe,
        sparse_prompt_embeddings=torch.empty(B, 0, self.d_model, device=device, dtype=backbone_features.dtype),
        dense_prompt_embeddings=torch.zeros(B, self.d_model, *backbone_features.shape[-2:], device=device, dtype=backbone_features.dtype),
        high_res_features=high_res_features, multimask_output=True, return_all=True,
        extra_per_object_embeddings=extra_embed.expand(B, -1, -1),
    )
    # masks: [B=num_buckets, M, T, H, W]
    # Demux to per-object: [N_obj, T, H, W]
    masks_obj = multiplex_state.demux(masks)
    iou_obj = multiplex_state.demux(iou_pred)
    score_obj = multiplex_state.demux(object_score_logits)
    tokens_obj = multiplex_state.demux(sam_tokens_out)

    # Select best mask by IoU for each object
    best_idx = torch.argmax(iou_obj, dim=-1)  # [N_obj]
    N_obj = masks_obj.shape[0]
    obj_range = torch.arange(N_obj, device=device)
    low_res_masks = masks_obj[obj_range, best_idx].unsqueeze(1)  # [N_obj, 1, H, W]
    # Suppress masks for objects with low confidence
    is_obj = score_obj > 0
    low_res_masks = torch.where(is_obj[:, :, None, None], low_res_masks,
                                torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_masks.dtype))
    high_res_masks = F.interpolate(low_res_masks.float(), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False)

    # Object pointer: compute per-object, mux for storage as [num_buckets, M, C]
    sam_token = tokens_obj[:, 0]  # [N_obj, C]
    obj_ptr = self.obj_ptr_proj(sam_token)
    # Blend toward the learned no-object pointer for occluded objects.
    is_obj = (score_obj > 0).float()
    no_obj = self.no_obj_ptr_linear(obj_ptr)
    obj_ptr = is_obj * obj_ptr + (1 - is_obj) * no_obj
    obj_ptr_muxed = multiplex_state.mux(obj_ptr)  # [num_buckets, M, C]

    return low_res_masks, high_res_masks, obj_ptr_muxed, score_obj
def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds,
               feat_sizes, mask_inputs, output_dict, num_frames, point_inputs=None,
               interactive_high_res=None, interactive_backbone=None, propagation_high_res=None,
               multiplex_state=None, run_mem_encoder=True):
    """Run one tracking step on a single frame and return its output dict.

    Dispatches to one of three paths:
      * mask conditioning (``mask_inputs`` given): encode the provided masks directly,
      * interactive (``point_inputs`` given): memory-conditioned SAM decoding with prompts,
      * propagation (neither given): memory-conditioned multiplex SAM decoding.

    Returns:
        dict with "pred_masks", "pred_masks_high_res", "obj_ptr",
        "object_score_logits", "image_features"/"image_pos_enc" (for decoupled
        memory attention), and "maskmem_features"/"maskmem_pos_enc" (None when
        memory encoding is deferred via run_mem_encoder=False).
    """
    current_out = {}
    H, W = feat_sizes[-1]

    if mask_inputs is not None:
        # Conditioning frame: use interactive features if available, else propagation
        if interactive_backbone is not None:
            pix_feat = interactive_backbone
            # Add no_mem_embed for interactive path
            pix_flat = pix_feat.flatten(2)
            bf = pix_flat.permute(0, 2, 1) + cast_to_input(self.interactivity_no_mem_embed, pix_flat)
            pix_feat = to_spatial(bf, H, W)
            hi_res = interactive_high_res
        else:
            # Fallback: interactive backbone not available (e.g. called outside track_video).
            # Propagation features work but may produce lower-quality conditioning.
            pix_feat = to_spatial(current_vision_feats[-1], H, W)
            hi_res = propagation_high_res
        sam_outputs = self._use_mask_as_output(pix_feat, hi_res, mask_inputs)
    elif point_inputs is not None:
        # Interactive path: use interactive SAM decoder
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames,
            multiplex_state=multiplex_state,
        )
        hi_res = interactive_high_res if interactive_high_res is not None else propagation_high_res
        num_pts = point_inputs["point_labels"].size(1)
        # Multi-mask disambiguation only for a single point on an initial conditioning frame.
        multimask_output = is_init_cond_frame and 0 < num_pts <= 1
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat_with_mem, point_inputs=point_inputs,
            high_res_features=hi_res, multimask_output=multimask_output,
        )
    else:
        # Propagation path: use multiplex SAM decoder with propagation features
        pix_feat_with_mem = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames,
            multiplex_state=multiplex_state,
        )
        sam_outputs = self._forward_propagation(pix_feat_with_mem, propagation_high_res,
                                                multiplex_state=multiplex_state)

    (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs

    # Mux obj_ptr if it came from interactive path (shape [B, C]) vs propagation ([num_buckets, M, C])
    if multiplex_state is not None and obj_ptr.dim() == 2:
        obj_ptr = multiplex_state.mux(obj_ptr)  # [N_obj, C] -> [num_buckets, M, C]

    # Encode memory (can be deferred with run_mem_encoder=False)
    if run_mem_encoder and self.num_maskmem > 0:
        pix_feat = to_spatial(current_vision_feats[-1], H, W)
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            pix_feat=pix_feat, pred_masks_high_res=high_res_masks,
            object_score_logits=object_score_logits,
            is_mask_from_pts=(point_inputs is not None),
            multiplex_state=multiplex_state,
            is_conditioning=(mask_inputs is not None),
        )
        current_out["maskmem_features"] = maskmem_features
        current_out["maskmem_pos_enc"] = maskmem_pos_enc
    else:
        current_out["maskmem_features"] = None
        current_out["maskmem_pos_enc"] = None

    # Store propagation image features for decoupled memory attention
    current_out["image_features"] = current_vision_feats[-1]  # [B, HW, C]
    current_out["image_pos_enc"] = current_vision_pos_embeds[-1]  # [B, HW, C]

    current_out["pred_masks"] = low_res_masks
    current_out["pred_masks_high_res"] = high_res_masks
    current_out["obj_ptr"] = obj_ptr
    current_out["object_score_logits"] = object_score_logits

    return current_out
def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None):
    """Run the image backbone on one frame and split out high-res feature levels.

    Returns (vision_feats, vision_pos, feat_sizes, high_res_features, trunk_out),
    where high_res_features is a list of every backbone level except the last.
    """
    feats, pos, sizes, level_features, trunk = _compute_backbone(backbone_fn, frame, frame_idx)
    high_res = list(level_features[:-1])
    return feats, pos, sizes, high_res, trunk
@staticmethod
def _suppress_recently_occluded(low_res_masks, last_occluded, frame_idx, threshold=0.3):
    """Suppress overlapping masks for objects that were most recently occluded.

    Prevents corrupted masks from occluded objects from contaminating other
    objects. Mutates both ``low_res_masks`` (suppressed entries set to -10.0)
    and ``last_occluded`` (updated to ``frame_idx`` for empty/suppressed tracks)
    in place, and also returns ``low_res_masks``.
    """
    num_objects = low_res_masks.shape[0]
    if num_objects <= 1:
        return low_res_masks

    mask_logits = low_res_masks[:, 0]  # [N_obj, H, W]
    has_pixels = (mask_logits > 0).any(dim=(-1, -2))  # [N_obj]

    # Pairwise overlap; only the upper triangle matters (each unordered pair once).
    pairwise = _compute_mask_overlap(mask_logits, mask_logits)
    conflict = torch.triu(pairwise >= threshold, diagonal=1)  # [N, N]

    occ_row = last_occluded.unsqueeze(1)  # [N, 1]
    occ_col = last_occluded.unsqueeze(0)  # [1, N]
    # In each conflicting pair, the object occluded more recently loses — but only
    # when the other object has itself been occluded at some point (> -1).
    row_loses = conflict & (occ_row > occ_col) & (occ_col > -1)
    col_loses = conflict & (occ_col > occ_row) & (occ_row > -1)
    to_suppress = row_loses.any(dim=1) | col_loses.any(dim=0)

    # Record this frame as the latest occlusion for empty or suppressed tracks.
    last_occluded[~has_pixels | to_suppress] = frame_idx

    low_res_masks[to_suppress] = -10.0
    return low_res_masks
def _deferred_memory_encode(self, current_out, N_obj, vision_feats, feat_sizes, mux_state, device,
                            cond_obj_mask=None):
    """Deferred memory encoding for propagation frames.

    Args:
        current_out: per-frame output dict; "maskmem_features"/"maskmem_pos_enc"
            are written into it.
        N_obj: number of tracked objects.
        vision_feats / feat_sizes: backbone features and their spatial sizes.
        mux_state: MultiplexState for mux/demux of per-object tensors.
        device: device for index tensors.
        cond_obj_mask: per-object bool marking which objects count as conditioning
            (clean) masks.
    """
    low_res_masks = current_out["pred_masks"]  # [N_obj, 1, H_low, W_low]

    if N_obj > 1:
        lr = low_res_masks.squeeze(1)  # [N_obj, H, W]
        # Per-pixel winning object index (highest logit wins the pixel).
        max_obj = torch.argmax(lr, dim=0, keepdim=True)
        batch_inds = torch.arange(N_obj, device=device)[:, None, None]
        # pixel_nol: each object's logits with pixels it loses to another object
        # clamped to <= -10. Used only to MEASURE area loss — not written back.
        pixel_nol = torch.where(max_obj == batch_inds, lr, torch.clamp(lr, max=-10.0))
        area_before = (lr > 0).sum(dim=(-1, -2)).float().clamp(min=1)
        area_after = (pixel_nol > 0).sum(dim=(-1, -2)).float()
        shrink_ok = (area_after / area_before) >= 0.3
        # Objects that would lose >70% of their area to overlap are fully
        # suppressed before memory encoding; others keep their original mask.
        low_res_masks = torch.where(
            shrink_ok[:, None, None, None].expand_as(low_res_masks),
            low_res_masks, torch.clamp(low_res_masks, max=-10.0))

    # Resize to the mask downsampler's expected input size.
    interpol_size = self.maskmem_backbone.mask_downsampler.interpol_size
    mem_masks = F.interpolate(low_res_masks, size=interpol_size,
                              mode="bilinear", align_corners=False)

    # Synthesize objectness logits: +10 if any positive pixel survives, else -10.
    obj_scores = torch.where(
        (mem_masks > 0).any(dim=(-1, -2)), 10.0, -10.0)

    pix_feat = to_spatial(vision_feats[-1], feat_sizes[-1][0], feat_sizes[-1][1])
    maskmem_features, maskmem_pos_enc = self._encode_new_memory(
        pix_feat=pix_feat, pred_masks_high_res=mem_masks,
        object_score_logits=obj_scores,
        multiplex_state=mux_state, cond_obj_mask=cond_obj_mask)
    current_out["maskmem_features"] = maskmem_features
    current_out["maskmem_pos_enc"] = maskmem_pos_enc
def _add_detected_objects(self, new_masks, mux_state, vision_feats, feat_sizes, current_out):
    """Grow MultiplexState with new detections, merge masks, re-encode memory.

    Modifies ``current_out`` in place: appends the new objects' masks to
    "pred_masks"/"pred_masks_high_res" and refreshes the encoded memory.
    """
    prev_count = mux_state.total_valid_entries
    mux_state.add_objects(new_masks.shape[0])
    total_count = mux_state.total_valid_entries

    # Stored memory with old bucket counts is padded at read time by _pad_to_buckets
    for key in ("pred_masks", "pred_masks_high_res"):
        target_size = current_out[key].shape[-2:]
        resized = F.interpolate(new_masks.unsqueeze(1), size=target_size,
                                mode="bilinear", align_corners=False)
        current_out[key] = torch.cat([current_out[key], resized], dim=0)

    if self.num_maskmem > 0:
        # Mark new objects as conditioning (clean detection masks) so model trusts them
        conditioning = torch.zeros(total_count, dtype=torch.bool, device=new_masks.device)
        conditioning[prev_count:] = True
        self._deferred_memory_encode(current_out, total_count, vision_feats, feat_sizes,
                                     mux_state, new_masks.device, cond_obj_mask=conditioning)
def _condition_with_masks(self, masks, frame_idx, vision_feats, vision_pos, feat_sizes,
                          high_res_prop, output_dict, N, mux_state, backbone_obj, frame,
                          trunk_out, threshold=0.5):
    """Condition tracker with masks on a frame.

    Resizes/binarizes ``masks`` to the model's input resolution, optionally runs
    the interactive-mode backbone, then executes a conditioning track_step and
    records its output under output_dict["cond_frame_outputs"][frame_idx].
    """
    source = masks if masks.dim() == 4 else masks.unsqueeze(1)
    mask_input = F.interpolate(source, size=(self.image_size, self.image_size),
                               mode="bilinear", align_corners=False)
    mask_input = (mask_input > threshold).to(masks.dtype)

    interactive_hi = interactive_lo = None
    if backbone_obj is not None and backbone_obj.multiplex:
        # Interactive-mode features reuse the cached trunk computation.
        _, _, itf, _ = backbone_obj(frame, tracker_mode="interactive", cached_trunk=trunk_out, tracker_only=True)
        interactive_hi = itf[:-1]
        interactive_lo = itf[-1]

    current_out = self.track_step(
        frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=vision_feats,
        current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=mask_input,
        output_dict=output_dict, num_frames=N, interactive_high_res=interactive_hi,
        interactive_backbone=interactive_lo, propagation_high_res=high_res_prop,
        multiplex_state=mux_state, run_mem_encoder=True)
    output_dict["cond_frame_outputs"][frame_idx] = current_out
    return current_out
def _match_and_add_detections(self, det_masks, det_scores, current_out, mux_state,
                              vision_feats, feat_sizes, device, max_objects=0,
                              keep_alive=None):
    """Match detections against tracked masks, add new objects, recondition degraded tracks.

    Updates keep_alive counters: +1 for matched tracks, -1 for unmatched
    (clamped to [-4, 8]).

    Returns:
        List of detection scores for the newly added objects ([] when none added).
    """
    N_obj = mux_state.total_valid_entries
    if det_masks.shape[0] == 0:
        # No detections this frame: every track moves one step toward expiry.
        if keep_alive is not None:
            for i in range(N_obj):
                keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1)
        return []

    # Match at low-res (like reference)
    trk_masks = current_out["pred_masks"][:, 0]  # [N_obj, H_low, W_low]
    det_resized = F.interpolate(det_masks.unsqueeze(1), size=trk_masks.shape[-2:],
                                mode="bilinear", align_corners=False)[:, 0]
    overlap = _compute_mask_overlap(det_resized, trk_masks)  # [N_det, N_obj]

    # Update keep_alive and find matched tracks
    matched = set()
    if overlap.shape[1] > 0:
        # A track counts as matched if any detection overlaps it by >= 0.5.
        matched = set((overlap >= 0.5).any(dim=0).nonzero(as_tuple=True)[0].tolist())
        if keep_alive is not None:
            for i in range(N_obj):
                if i in matched:
                    keep_alive[i] = min(8, keep_alive.get(i, 0) + 1)
                else:
                    keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1)

    # Recondition: high-confidence detections (>=0.8) with high overlap refresh tracked masks
    reconditioned = False
    if det_scores is not None and overlap.shape[1] > 0:
        HIGH_CONF = 0.8
        for det_idx in range(overlap.shape[0]):
            if det_scores[det_idx] < HIGH_CONF:
                continue
            best_trk = overlap[det_idx].argmax().item()
            if overlap[det_idx, best_trk] >= 0.5:
                # Replace tracked mask with fresh detection mask
                current_out["pred_masks"][best_trk] = det_resized[det_idx].unsqueeze(0)
                det_hr = F.interpolate(det_masks[det_idx:det_idx+1].unsqueeze(1),
                                       size=current_out["pred_masks_high_res"].shape[-2:],
                                       mode="bilinear", align_corners=False)
                current_out["pred_masks_high_res"][best_trk] = det_hr[0]
                reconditioned = True

    # Re-encode memory if any tracks were reconditioned
    if reconditioned and self.num_maskmem > 0:
        self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device)

    # Add new detections (not matching any track)
    if max_objects > 0 and N_obj >= max_objects:
        return []
    max_overlap = overlap.max(dim=1)[0] if overlap.shape[1] > 0 else torch.zeros(overlap.shape[0], device=device)
    new_dets = max_overlap < 0.5
    if new_dets.any():
        if max_objects > 0:
            # Keep only as many new detections as remaining object slots.
            slots = max_objects - N_obj
            new_dets = new_dets & (torch.cumsum(new_dets.int(), 0) <= slots)
        self._add_detected_objects(det_masks[new_dets], mux_state,
                                   vision_feats, feat_sizes, current_out)
        if keep_alive is not None:
            # New objects start with a keep-alive of 1.
            for i in range(N_obj, mux_state.total_valid_entries):
                keep_alive[i] = 1
        return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
    return []
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
                               new_det_thresh=0.5, max_objects=0, detect_interval=1,
                               backbone_obj=None, pbar=None):
    """Track objects through a video with optional per-frame detection.

    Args:
        backbone_fn: callable running the image backbone on a single frame.
        images: [N, ...] video frames.
        initial_masks: [N_obj, H, W] masks to condition on at frame 0, or None to
            bootstrap tracking purely from detections.
        detect_fn: optional detector taking the backbone trunk output; its results
            seed new objects and refresh degraded tracks.
        new_det_thresh: sigmoid-score threshold for accepting detections.
        max_objects: cap on concurrently tracked objects (0 = unlimited).
        detect_interval: run the detector every this many frames.
        backbone_obj: optional backbone module for interactive-mode conditioning features.
        pbar: optional progress bar updated once per frame.

    Returns:
        dict with "packed_masks" (stacked per-frame packed masks padded to the
        max object count, or None when nothing was tracked), "n_frames", and
        "scores" (per-object detection scores; 1.0 for initial masks).
    """
    N, device, dt = images.shape[0], images.device, images.dtype
    output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
    all_masks = []
    idev = comfy.model_management.intermediate_device()
    empty = lambda: torch.zeros(0, self.image_size, self.image_size, device=idev, dtype=dt)
    mux_state = None
    if initial_masks is not None:
        mux_state = MultiplexState(initial_masks.shape[0], self.num_multiplex, device, dt)
    obj_scores = []  # per-object detection score (1.0 for initial masks)
    keep_alive = {} if detect_fn is not None else None
    last_occluded = torch.empty(0, device=device, dtype=torch.long)  # per-object last occluded frame

    # Prefetch next frame's backbone on a separate CUDA stream
    prefetch = False
    backbone_stream = None
    if comfy.model_management.is_device_cuda(device):
        try:
            backbone_stream = torch.cuda.Stream(device=device)
            prefetch = True
        except RuntimeError:
            # Stream creation failed; fall back to serial backbone computation.
            pass
    cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)

    for frame_idx in tqdm(range(N), desc="tracking"):
        vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb

        # Start next frame's backbone on separate stream (overlaps with current frame's work)
        if prefetch and frame_idx + 1 < N:
            backbone_stream.wait_stream(torch.cuda.current_stream(device))
            with torch.cuda.stream(backbone_stream):
                next_bb = self._compute_backbone_frame(
                    backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)

        # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
        det_masks = torch.empty(0, device=device)
        det_scores = None
        run_det = (detect_fn is not None
                   and frame_idx % max(detect_interval, 1) == 0
                   and not (max_objects > 0 and mux_state is not None
                            and mux_state.total_valid_entries >= max_objects))
        if run_det:
            det_out = detect_fn(trunk_out)
            scores = det_out["scores"][0].sigmoid()
            keep = scores > new_det_thresh
            det_masks, det_scores = det_out["masks"][0][keep], scores[keep]
            if det_masks.shape[0] > 1:
                det_masks, det_scores = _nms_masks(det_masks, det_scores)

        if frame_idx == 0 and initial_masks is not None:
            # First frame with user-supplied masks: condition the tracker on them.
            current_out = self._condition_with_masks(
                initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
                feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
                images[frame_idx:frame_idx + 1], trunk_out)
            last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
            obj_scores = [1.0] * mux_state.total_valid_entries
            if keep_alive is not None:
                # User-provided objects start with the maximum keep-alive budget.
                for i in range(mux_state.total_valid_entries):
                    keep_alive[i] = 8
        elif mux_state is None or mux_state.total_valid_entries == 0:
            # No tracked objects yet: bootstrap from this frame's detections if any.
            if det_masks.shape[0] > 0:
                if max_objects > 0:
                    det_scores = det_scores[:max_objects]
                    det_masks = det_masks[:max_objects]
                mux_state = MultiplexState(det_masks.shape[0], self.num_multiplex, device, dt)
                current_out = self._condition_with_masks(
                    det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
                    output_dict, N, mux_state, backbone_obj,
                    images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
                last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
                obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
                if keep_alive is not None:
                    for i in range(mux_state.total_valid_entries):
                        keep_alive[i] = 1
            else:
                # NOTE(review): object-less frames elsewhere append None, while this
                # branch appends a 0-object float tensor — confirm both kinds are
                # compatible with the uint8 padding/stacking at the end when they
                # occur in the same video.
                all_masks.append(empty())
                if pbar is not None:
                    pbar.update(1)
                # Skip to backbone advance at end of loop
                if frame_idx + 1 < N:
                    if prefetch:
                        torch.cuda.current_stream(device).wait_stream(backbone_stream)
                        cur_bb = next_bb
                    else:
                        cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
                continue
        else:
            # Propagation frame: track existing objects, then reconcile with detections.
            N_obj = mux_state.total_valid_entries
            current_out = self.track_step(
                frame_idx=frame_idx, is_init_cond_frame=False, current_vision_feats=vision_feats,
                current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=None,
                output_dict=output_dict, num_frames=N, propagation_high_res=high_res_prop,
                multiplex_state=mux_state, run_mem_encoder=False)
            current_out["pred_masks"] = fill_holes_in_mask_scores(
                current_out["pred_masks"], max_area=16)
            if last_occluded.shape[0] == N_obj and N_obj > 1:
                # Mutates pred_masks and last_occluded in place.
                self._suppress_recently_occluded(
                    current_out["pred_masks"], last_occluded, frame_idx)
            if self.num_maskmem > 0:
                self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device)
            output_dict["non_cond_frame_outputs"][frame_idx] = current_out
            # Prune outputs older than anything memory attention can still reach.
            lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder)
            for old_idx in list(output_dict["non_cond_frame_outputs"]):
                if old_idx < frame_idx - lookback:
                    del output_dict["non_cond_frame_outputs"][old_idx]
            n_before = mux_state.total_valid_entries
            new_obj_scores = self._match_and_add_detections(det_masks, det_scores, current_out, mux_state,
                                                            vision_feats, feat_sizes, device, max_objects,
                                                            keep_alive if run_det else None)
            n_added = mux_state.total_valid_entries - n_before
            if n_added > 0:
                last_occluded = torch.cat([last_occluded,
                                           torch.full((n_added,), -1, device=device, dtype=torch.long)])
                obj_scores.extend(new_obj_scores)

        masks_out = current_out["pred_masks_high_res"][:, 0]
        if keep_alive is not None:
            # Hide tracks left unconfirmed by the detector for too long.
            for i in range(masks_out.shape[0]):
                if keep_alive.get(i, 0) <= 0:
                    masks_out[i] = NO_OBJ_SCORE
        N_obj_now = mux_state.total_valid_entries if mux_state is not None else 0
        if N_obj_now > 0:
            all_masks.append(pack_masks(masks_out).to(idev))
        else:
            all_masks.append(None)
        if pbar is not None:
            pbar.update(1)

        # Next frame's backbone
        if frame_idx + 1 < N:
            if prefetch:
                torch.cuda.current_stream(device).wait_stream(backbone_stream)
                cur_bb = next_bb
            else:
                cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)

    if not all_masks or all(m is None for m in all_masks):
        return {"packed_masks": None, "n_frames": N, "scores": []}

    # Pad every frame's packed masks to the maximum object count before stacking.
    max_obj = max(m.shape[0] for m in all_masks if m is not None)
    sample = next(m for m in all_masks if m is not None)
    empty_packed = torch.zeros(max_obj, *sample.shape[1:], dtype=torch.uint8, device=sample.device)
    for i, m in enumerate(all_masks):
        if m is None:
            all_masks[i] = empty_packed
        elif m.shape[0] < max_obj:
            pad = torch.zeros(max_obj - m.shape[0], *m.shape[1:], dtype=torch.uint8, device=m.device)
            all_masks[i] = torch.cat([m, pad], dim=0)
    return {"packed_masks": torch.stack(all_masks, dim=0), "n_frames": N, "scores": obj_scores}