From 4ba28caa8c0954205b6e1714bf28ca2c3a6d6079 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 14 Apr 2026 23:57:38 +0300 Subject: [PATCH] Initial SAM3.1 support --- comfy/ldm/sam3/detector.py | 596 ++++++++++ comfy/ldm/sam3/sam.py | 425 +++++++ comfy/ldm/sam3/tracker.py | 1786 ++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 12 + comfy/supported_models.py | 53 +- comfy/text_encoders/sam3_clip.py | 97 ++ comfy_extras/nodes_sam3.py | 514 +++++++++ nodes.py | 3 +- 9 files changed, 3489 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/sam3/detector.py create mode 100644 comfy/ldm/sam3/sam.py create mode 100644 comfy/ldm/sam3/tracker.py create mode 100644 comfy/text_encoders/sam3_clip.py create mode 100644 comfy_extras/nodes_sam3.py diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py new file mode 100644 index 000000000..6ae919a79 --- /dev/null +++ b/comfy/ldm/sam3/detector.py @@ -0,0 +1,596 @@ +# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker +from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__ +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine + +TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker} +from comfy.ops import cast_to_input + + +def box_cxcywh_to_xyxy(x): + cx, cy, w, h = x.unbind(-1) + return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1) + + +def gen_sineembed_for_position(pos_tensor, num_feats=256): + """Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats).""" + assert num_feats % 2 == 0 + hdim = num_feats // 2 + freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim) + embeds = [] + for c in range(pos_tensor.shape[-1]): + raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs + embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2)) + return torch.cat(embeds, dim=-1).to(pos_tensor.dtype) + + +class SplitMHA(nn.Module): + """Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight).""" + def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def forward(self, q_input, k_input=None, v_input=None, mask=None): + q = self.q_proj(q_input) + if k_input is None: + k = self.k_proj(q_input) + v = self.v_proj(q_input) + else: + k = self.k_proj(k_input) + v = self.v_proj(v_input if v_input is not None else k_input) + if mask is not None and mask.ndim == 2: + mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast + dtype = q.dtype # manual_cast may produce mixed dtypes + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask) + return self.out_proj(out) + + +class MLPWithNorm(nn.Module): + """MLP with residual connection and output LayerNorm.""" + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([ + operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) + for i in range(num_layers) + ]) + self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype) + self.residual = residual and (input_dim == output_dim) + + def forward(self, x): + orig = x + for i, layer in enumerate(self.layers): + x = layer(x) + if i < len(self.layers) - 1: + x = F.relu(x) + if self.residual: + x = x + orig + return self.out_norm(x) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, pos, text_memory=None, text_mask=None): + normed = self.norm1(x) + q_k = normed + pos + x = x + self.self_attn(q_k, q_k, normed) + if text_memory is not None: + normed = self.norm2(x) + x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class TransformerEncoder(nn.Module): + """Checkpoint: transformer.encoder.layers.N.*""" + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + + def forward(self, x, pos, text_memory=None, text_mask=None): + for layer in self.layers: + x = layer(x, pos, text_memory, text_mask) + return x + + +class DecoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + + def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None): + q_k = x + x_pos + x = self.norm2(x + self.self_attn(q_k, q_k, x)) + if text_memory is not None: + x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask)) + x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias)) + x = self.norm3(x + self.linear2(F.relu(self.linear1(x)))) + return x + + +class TransformerDecoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.num_queries = num_queries + + self.layers = nn.ModuleList([ + DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype) + self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes + self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each) + self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations) + + self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + + self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + + @staticmethod + def _inverse_sigmoid(x): + return torch.log(x / (1 - x + 1e-6) + 1e-6) + + def _compute_box_rpb(self, ref_points, H, W): + """Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias.""" + boxes_xyxy = box_cxcywh_to_xyxy(ref_points) + B, Q, _ = boxes_xyxy.shape + coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H + coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W + deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2] + deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2] + + log2_8 = float(math.log2(8)) + def log_scale(d): + return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8 + + rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype)) + rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype)) + + bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2) + pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype) + return torch.cat([pres_bias, bias], dim=2) + + def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72): + B = memory.shape[0] + tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1) + presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1) + ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid() + + for layer_idx, layer in enumerate(self.layers): + query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model)) + tgt_with_pres = torch.cat([presence_out, tgt], dim=1) + pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1) + tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos, + text_memory, text_mask, self._compute_box_rpb(ref_points, H, W)) + presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:] + if layer_idx < len(self.layers) - 1: + ref_inv = self._inverse_sigmoid(ref_points) + ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach() + + query_out = self.norm(tgt) + ref_inv = self._inverse_sigmoid(ref_points) + boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid() + presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1) + return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence} + + +class Transformer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations) + self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations) + + +class GeometryEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.roi_size = roi_size + self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype) + self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype) + self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype) + self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype) + self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype) + self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.encode = nn.ModuleList([ + EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def _encode_points(self, coords, labels, img_feat_2d): + """Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized.""" + B, N, _ = coords.shape + embed = self.points_direct_project(coords) + # Pool features from backbone at point locations via grid_sample + grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1] + sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1] + embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C] + # Positional encoding of coordinates + x, y = coords[:, :, 0], coords[:, :, 1] # [B, N] + pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten()) + enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1) + embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def _encode_boxes(self, boxes, labels, img_feat_2d): + """Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh.""" + B, N, _ = boxes.shape + embed = self.boxes_direct_project(boxes) + # ROI align from backbone at box regions + H, W = img_feat_2d.shape[-2:] + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device) + boxes_scaled = boxes_xyxy * scale + sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size) + proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C] + embed = embed + proj + # Positional encoding of box center + size + cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3] + enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten()) + enc = enc.view(B, N, -1) + embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def forward(self, points=None, boxes=None, image_features=None): + """Encode geometry prompts. image_features: [B, HW, C] flattened backbone features.""" + # Prepare 2D image features for pooling + img_feat_2d = None + if image_features is not None: + B = image_features.shape[0] + HW, C = image_features.shape[1], image_features.shape[2] + hw = int(math.sqrt(HW)) + img_normed = self.img_pre_norm(image_features) + img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw) + + embeddings = [] + if points is not None: + coords, labels = points + embeddings.append(self._encode_points(coords, labels, img_feat_2d)) + if boxes is not None: + B = boxes.shape[0] + box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device) + embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d)) + if not embeddings: + return None + geo = torch.cat(embeddings, dim=1) + geo = self.norm(geo) + if image_features is not None: + for layer in self.encode: + geo = layer(geo, torch.zeros_like(geo), image_features) + geo = self.encode_norm(geo) + return self.final_proj(geo) + + +class PixelDecoder(nn.Module): + """Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation.""" + def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)]) + self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)]) + + def forward(self, backbone_features): + prev = backbone_features[-1] + for i, feat in enumerate(backbone_features[:-1][::-1]): + prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest")))) + return prev + + +class MaskPredictor(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, query_embeddings, pixel_features): + mask_embed = self.mask_embed(query_embeddings) + return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features) + + +class SegmentationHead(nn.Module): + def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations) + self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations) + self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype) + + def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None): + if encoder_hidden_states is not None and prompt is not None: + enc_normed = self.cross_attn_norm(encoder_hidden_states) + enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask) + encoder_hidden_states = enc_cross + encoder_hidden_states + + if encoder_hidden_states is not None: + B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1] + encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W) + backbone_features = list(backbone_features) + backbone_features[-1] = encoder_visual + + pixel_features = self.pixel_decoder(backbone_features) + instance_features = self.instance_seg_head(pixel_features) + masks = self.mask_predictor(query_embeddings, instance_features) + return masks + + +class DotProductScoring(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations) + self.scale = 1.0 / (d_model ** 0.5) + + def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None): + prompt = self.prompt_mlp(prompt_embeddings) + if prompt_mask is not None: + weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype) + pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1) + else: + pooled = prompt.mean(dim=1) + hs = self.hs_proj(query_embeddings) + pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype) + scores = torch.matmul(hs, pp) + return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1) + + +class SAM3Detector(nn.Module): + def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + image_model = kwargs.pop("image_model", "SAM3") + for k in ("num_heads", "num_head_channels"): + kwargs.pop(k, None) + multiplex = image_model == "SAM31" + # SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0) + self.scalp = 0 if multiplex else 1 + self.backbone = nn.ModuleDict({ + "vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs), + "language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}), + }) + self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations) + self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations) + + def _get_backbone_features(self, images): + """Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions).""" + bb = self.backbone["vision_backbone"] + if bb.multiplex: + all_f, all_p, tf, tp = bb(images, tracker_mode="propagation") + else: + all_f, all_p, tf, tp = bb(images, need_tracker=True) + return all_f, all_p, tf, tp + + @staticmethod + def _run_geo_layer(layer, x, memory, memory_pos): + x = x + layer.self_attn(layer.norm1(x)) + x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory) + x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x)))) + return x + + def _detect(self, features, positions, text_embeddings=None, text_mask=None, + points=None, boxes=None): + """Shared detection: geometry encoding, transformer, scoring, segmentation.""" + B = features[0].shape[0] + # Scalp for encoder (use top-level feature), but keep all levels for segmentation head + seg_features = features + if self.scalp > 0: + features = features[:-self.scalp] + positions = positions[:-self.scalp] + enc_feat, enc_pos = features[-1], positions[-1] + _, _, H, W = enc_feat.shape + img_flat = enc_feat.flatten(2).permute(0, 2, 1) + pos_flat = enc_pos.flatten(2).permute(0, 2, 1) + + has_prompts = text_embeddings is not None or points is not None or boxes is not None + if has_prompts: + geo_enc = self.geometry_encoder + geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat) + geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1))) + for layer in geo_enc.encode: + geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat) + geo_cls = geo_enc.encode_norm(geo_cls) + if text_embeddings is not None and text_embeddings.shape[0] != B: + text_embeddings = text_embeddings.expand(B, -1, -1) + if text_mask is not None and text_mask.shape[0] != B: + text_mask = text_mask.expand(B, -1) + parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None] + text_embeddings = torch.cat(parts, dim=1) + n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0) + if text_mask is not None: + text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1) + else: + text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device) + + memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask) + dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W) + query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"] + + if text_embeddings is not None: + scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask) + else: + scores = torch.zeros(B, query_out.shape[1], device=query_out.device) + + masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask) + return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out + + def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None): + features, positions, _, _ = self._get_backbone_features(images) + + if text_embeddings is not None: + text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings) + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, dec_out = self._detect( + features, positions, text_embeddings, text_mask, points, boxes) + + if orig_size is not None: + oh, ow = orig_size + boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype) + masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False) + + return { + "boxes": boxes_xyxy, + "scores": scores, + "masks": masks, + "presence": dec_out.get("presence"), + } + + def forward_from_trunk(self, trunk_out, text_embeddings, text_mask): + """Run detection using a pre-computed ViTDet trunk output. + + text_embeddings must already be resized through language_backbone.resizer. + Returns dict with boxes (normalized xyxy), scores, masks at detector resolution. + """ + bb = self.backbone["vision_backbone"] + features = [conv(trunk_out) for conv in bb.convs] + positions = [cast_to_input(bb.position_encoding(f), f) for f in features] + + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask) + return {"boxes": boxes_xyxy, "scores": scores, "masks": masks} + + +class SAM3Model(nn.Module): + def __init__(self, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + image_model = kwargs.get("image_model", "SAM3") + tracker_cls = TRACKER_CLASSES[image_model] + self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs) + self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs) + + def forward(self, images, **kwargs): + return self.detector(images, **kwargs) + + def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None): + """Interactive segmentation using SAM decoder with point/box/mask prompts. + + Args: + images: [B, 3, 1008, 1008] preprocessed images + point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space + box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space + mask_inputs: [B, 1, H, W] coarse mask logits to refine + Returns: + [B, 1, image_size, image_size] high-res mask logits + """ + bb = self.detector.backbone["vision_backbone"] + if bb.multiplex: + _, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive") + else: + _, _, tracker_features, tracker_positions = bb(images, need_tracker=True) + if self.detector.scalp > 0: + tracker_features = tracker_features[:-self.detector.scalp] + tracker_positions = tracker_positions[:-self.detector.scalp] + + high_res = list(tracker_features[:-1]) + backbone_feat = tracker_features[-1] + B, C, H, W = backbone_feat.shape + # Add no-memory embedding (init frame path) + no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None) + if no_mem is None: + no_mem = getattr(self.tracker, 'no_mem_embed', None) + if no_mem is not None: + feat_flat = backbone_feat.flatten(2).permute(0, 2, 1) + feat_flat = feat_flat + cast_to_input(no_mem, feat_flat) + backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2) + + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + _, high_res_masks, _, _ = self.tracker._forward_sam_heads( + backbone_features=backbone_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + box_inputs=box_inputs, + high_res_features=high_res, + multimask_output=(0 < num_pts <= 1), + ) + return high_res_masks + + def forward_video(self, images, initial_masks, pbar=None, text_prompts=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1): + """Track video with optional per-frame text-prompted detection.""" + bb = self.detector.backbone["vision_backbone"] + + def backbone_fn(frame, frame_idx=None): + trunk_out = bb.trunk(frame) + if bb.multiplex: + _, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True) + else: + _, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True) + return tf, tp, trunk_out + + detect_fn = None + if text_prompts: + resizer = self.detector.backbone["language_backbone"]["resizer"] + resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts] + def detect_fn(trunk_out): + all_scores, all_masks = [], [] + for emb, mask in resized: + det = self.detector.forward_from_trunk(trunk_out, emb, mask) + all_scores.append(det["scores"]) + all_masks.append(det["masks"]) + return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)} + + if hasattr(self.tracker, 'track_video_with_detection'): + return self.tracker.track_video_with_detection( + backbone_fn, images, initial_masks, detect_fn, + new_det_thresh=new_det_thresh, max_objects=max_objects, + detect_interval=detect_interval, backbone_obj=bb, pbar=pbar) + # SAM3 (non-multiplex) — no detection support, requires initial masks + if initial_masks is None: + raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking") + return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py new file mode 100644 index 000000000..272781d45 --- /dev/null +++ b/comfy/ldm/sam3/sam.py @@ -0,0 +1,425 @@ +# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.layers import EmbedND +from comfy.ops import cast_to_input + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)]) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x) + return torch.sigmoid(x) if self.sigmoid_output else x + + +class SAMAttention(nn.Module): + def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + internal_dim = embedding_dim // downsample_rate + kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype) + + def forward(self, q, k, v): + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + return self.out_proj(optimized_attention(q, k, v, self.num_heads)) + + +class TwoWayAttentionBlock(nn.Module): + def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None): + super().__init__() + self.skip_first_layer_pe = skip_first_layer_pe + self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype)) + self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, queries, keys, query_pe, key_pe): + if self.skip_first_layer_pe: + queries = self.norm1(self.self_attn(queries, queries, queries)) + else: + q = queries + query_pe + queries = self.norm1(queries + self.self_attn(q, q, queries)) + q, k = queries + query_pe, keys + key_pe + queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys)) + queries = self.norm3(queries + self.mlp(queries)) + q, k = queries + query_pe, keys + key_pe + keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries)) + return queries, keys + + +class TwoWayTransformer(nn.Module): + def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate, + skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations) + for i in range(depth) + ]) + self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, image_embedding, image_pe, point_embedding): + queries, keys = point_embedding, image_embedding + for layer in self.layers: + queries, keys = layer(queries, keys, point_embedding, image_pe) + q, k = queries + point_embedding, keys + image_pe + queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys)) + return queries, keys + + +class PositionEmbeddingRandom(nn.Module): + """Fourier feature positional encoding with random gaussian projection.""" + def __init__(self, num_pos_feats=64, scale=None): + super().__init__() + self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats)) + + def _encode(self, normalized_coords): + """Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32.""" + orig_dtype = normalized_coords.dtype + proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32) + projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix + return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype) + + def forward(self, size, device=None): + h, w = size + dev = device if device is not None else self.positional_encoding_gaussian_matrix.device + ones = torch.ones((h, w), device=dev, dtype=torch.float32) + norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1) + return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0) + + def forward_with_coords(self, pixel_coords, image_size): + norm = pixel_coords.clone() + norm[:, :, 0] /= image_size[1] + norm[:, :, 1] /= image_size[0] + return self._encode(norm) + + +# ViTDet backbone + FPN neck + +def window_partition(x: torch.Tensor, window_size: int): + B, H, W, C = x.shape + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw): + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0): + """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2].""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + ids = torch.stack([(t % end_x) * scale_pos, + torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1) + return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0)) + + +class _ViTMLP(nn.Module): + def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None): + super().__init__() + hidden = int(dim * mlp_ratio) + self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype) + self.act = nn.GELU() + self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class Attention(nn.Module): + """ViTDet multi-head attention with fused QKV projection.""" + + def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.use_rope = use_rope + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x, freqs_cis=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) + if self.use_rope and freqs_cis is not None: + q, k = apply_rope(q, k, freqs_cis) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True)) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.window_size = window_size + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations) + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + + def forward(self, x, freqs_cis=None): + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + x = x.view(x.shape[0], self.window_size * self.window_size, -1) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(-1, self.window_size, self.window_size, x.shape[-1]) + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + else: + B, H, W, C = x.shape + x = x.view(B, H * W, C) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(B, H, W, C) + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None): + super().__init__() + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.proj(x) + + +class ViTDet(nn.Module): + def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24, + global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_heads = num_heads + self.global_att_blocks = set(global_att_blocks) + + self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations) + + num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype)) + + self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + grid_size = img_size // patch_size + pretrain_grid = pretrain_img_size // patch_size + + self.blocks = nn.ModuleList() + for i in range(depth): + is_global = i in self.global_att_blocks + self.blocks.append(Block( + embed_dim, num_heads, mlp_ratio, qkv_bias, + window_size=0 if is_global else window_size, + use_rope=use_rope, + device=device, dtype=dtype, operations=operations, + )) + + if use_rope: + rope_scale = pretrain_grid / grid_size + self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False) + self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False) + else: + self.freqs_cis = None + self.freqs_cis_window = None + + def _get_pos_embed(self, num_tokens): + pos = self.pos_embed + if pos.shape[1] == num_tokens: + return pos + cls_pos = pos[:, :1] + spatial_pos = pos[:, 1:] + old_size = int(math.sqrt(spatial_pos.shape[1])) + new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size + spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2) + tiles_h = new_size // old_size + 1 + tiles_w = new_size // old_size + 1 + tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size] + tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1) + return torch.cat([cls_pos, tiled], dim=1) + + def forward(self, x): + x = self.patch_embed(x) + B, C, Hp, Wp = x.shape + x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C) + + pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x) + x = x + pos[:, 1:Hp * Wp + 1] + + x = x.view(B, Hp, Wp, C) + x = self.ln_pre(x) + + freqs_cis_global = self.freqs_cis + freqs_cis_win = self.freqs_cis_window + if freqs_cis_global is not None: + freqs_cis_global = cast_to_input(freqs_cis_global, x) + if freqs_cis_win is not None: + freqs_cis_win = cast_to_input(freqs_cis_win, x) + + for block in self.blocks: + fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global + x = block(x, freqs_cis=fc) + + return x.permute(0, 3, 1, 2) + + +class FPNScaleConv(nn.Module): + def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None): + super().__init__() + if scale == 4.0: + self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 4 + elif scale == 2.0: + self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 2 + elif scale == 1.0: + proj_in = in_dim + elif scale == 0.5: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + proj_in = in_dim + self.scale = scale + self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype) + self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype) + + def forward(self, x): + if self.scale == 4.0: + x = F.gelu(self.dconv_2x2_0(x)) + x = self.dconv_2x2_1(x) + elif self.scale == 2.0: + x = self.dconv_2x2(x) + elif self.scale == 0.5: + x = self.pool(x) + x = self.conv_1x1(x) + x = self.conv_3x3(x) + return x + + +class PositionEmbeddingSine(nn.Module): + """2D sinusoidal position encoding (DETR-style) with result caching.""" + def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None): + super().__init__() + assert num_pos_feats % 2 == 0 + self.half_dim = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + self.scale = scale if scale is not None else 2 * math.pi + self._cache = {} + + def _sincos(self, vals): + """Encode 1D values to interleaved sin/cos features.""" + freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim) + raw = vals[..., None] * self.scale / freqs + return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2) + + def _encode_xy(self, x, y): + """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim].""" + dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim) + pos_x = x[:, None] * self.scale / dim_t + pos_y = y[:, None] * self.scale / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + def encode_boxes(self, cx, cy, w, h): + """Encode box center + size to [N, d_model+2] features.""" + pos_x, pos_y = self._encode_xy(cx, cy) + return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + + def forward(self, x): + B, C, H, W = x.shape + key = (H, W, x.device) + if key not in self._cache: + gy = torch.arange(H, dtype=torch.float32, device=x.device) + gx = torch.arange(W, dtype=torch.float32, device=x.device) + if self.normalize: + gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6) + yy, xx = torch.meshgrid(gy, gx, indexing="ij") + self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0) + return self._cache[key].expand(B, -1, -1, -1) + + +class SAM3VisionBackbone(nn.Module): + def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.multiplex = multiplex + + fpn_args = dict(device=device, dtype=dtype, operations=operations) + if multiplex: + scales = [4.0, 2.0, 1.0] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + else: + scales = [4.0, 2.0, 1.0, 0.5] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + + def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False): + backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images) + + if tracker_only: + # Skip detector FPN when only tracker features are needed (video tracking) + if self.multiplex: + tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs + else: + tracker_convs = self.sam2_convs + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return None, None, tracker_features, tracker_positions + + features = [conv(backbone_out) for conv in self.convs] + positions = [cast_to_input(self.position_encoding(f), f) for f in features] + + if self.multiplex: + if tracker_mode == "propagation": + tracker_convs = self.propagation_convs + elif tracker_mode == "interactive": + tracker_convs = self.interactive_convs + else: + return features, positions, None, None + elif need_tracker: + tracker_convs = self.sam2_convs + else: + return features, positions, None, None + + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return features, positions, tracker_features, tracker_positions diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py new file mode 100644 index 000000000..84f2a3e4c --- /dev/null +++ b/comfy/ldm/sam3/tracker.py @@ -0,0 +1,1786 @@ +# SAM3 video tracker: memory encoder, memory attention, SAM mask decoder/prompt encoder. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +try: + import cv2 + _HAS_CV2 = True +except ImportError: + from scipy import ndimage + _HAS_CV2 = False + +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.sam import rope_2d, PositionEmbeddingSine +from comfy.ops import cast_to_input +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.cascade.common import LayerNorm2d_op +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingRandom +from comfy.ldm.sam3.sam import TwoWayTransformer as SAMTwoWayTransformer + +NO_OBJ_SCORE = -1024.0 + + +def to_spatial(x, H, W): + """Reshape (B, H*W, C) → (B, C, H, W).""" + return x.view(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + +class MultiplexState: + """Tracks object-to-slot assignments for multiplex tracking. Provides mux/demux operations.""" + + def __init__(self, num_objects, multiplex_count, device, dtype): + self.multiplex_count = multiplex_count + self.device = device + self.dtype = dtype + self._build(num_objects) + + def mux(self, x): + """[N_obj, ...] -> [num_buckets, multiplex_count, ...]""" + out_shape = (self.num_buckets, self.multiplex_count) + x.shape[1:] + return (self.mux_matrix.to(device=x.device, dtype=x.dtype) @ x.reshape(self.total_valid_entries, -1)).view(out_shape) + + def demux(self, x): + """[num_buckets, multiplex_count, ...] -> [N_obj, ...]""" + out_shape = (self.total_valid_entries,) + x.shape[2:] + flat = x.reshape(self.num_buckets * self.multiplex_count, -1) + return (self.demux_matrix.to(device=x.device, dtype=x.dtype) @ flat).view(out_shape) + + def get_valid_object_mask(self): + """[num_buckets, multiplex_count] bool tensor, True for valid slots.""" + return (self.mux_matrix.sum(dim=1) > 0).reshape(self.num_buckets, self.multiplex_count) + + def _build(self, num_objects): + M = self.multiplex_count + self.num_buckets = (num_objects + M - 1) // M + self.total_valid_entries = num_objects + total_slots = self.num_buckets * M + self.mux_matrix = torch.zeros(total_slots, num_objects, device=self.device, dtype=self.dtype) + self.demux_matrix = torch.zeros(num_objects, total_slots, device=self.device, dtype=self.dtype) + oids = torch.arange(num_objects, device=self.device) + slots = (oids // M) * M + (oids % M) + self.mux_matrix[slots, oids] = 1.0 + self.demux_matrix[oids, slots] = 1.0 + + def add_objects(self, n_new): + """Grow multiplex state for n_new additional objects.""" + self._build(self.total_valid_entries + n_new) + +def _compute_mask_overlap(masks_a, masks_b): + """Max of IoU and IoM (intersection over minimum area). More robust to size differences.""" + a_flat = (masks_a > 0).float().flatten(1) + b_flat = (masks_b > 0).float().flatten(1) + intersection = a_flat @ b_flat.T + area_a = a_flat.sum(1, keepdim=True) + area_b = b_flat.sum(1, keepdim=True).T + iou = intersection / (area_a + area_b - intersection).clamp(min=1) + iom = intersection / torch.min(area_a.expand_as(iou), area_b.expand_as(iou)).clamp(min=1) + return torch.max(iou, iom) + + +def _nms_masks(masks, scores, thresh=0.5): + """Mask-based NMS using IoU+IoM overlap. Returns (filtered_masks, filtered_scores).""" + order = scores.argsort(descending=True) + masks, scores = masks[order], scores[order] + keep = [] + for i in range(masks.shape[0]): + if keep: + if _compute_mask_overlap(masks[i:i+1], masks[torch.tensor(keep, device=masks.device)]).max() >= thresh: + continue + keep.append(i) + return masks[keep], scores[keep] + + +def _get_connected_components(mask_bin): + """Get connected component labels and areas. mask_bin: [B, 1, H, W] uint8.""" + labels_list, areas_list = [], [] + for i in range(mask_bin.shape[0]): + m = mask_bin[i, 0].cpu().numpy() + if _HAS_CV2: + _, labeled, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8) + areas = stats[labeled, cv2.CC_STAT_AREA].astype('int32') + else: + labeled, num_features = ndimage.label(m) + areas = np.zeros_like(m, dtype=np.int32) + for c in range(1, num_features + 1): + component = labeled == c + areas[component] = component.sum() + labels_list.append(torch.from_numpy(labeled).to(mask_bin.device)) + areas_list.append(torch.from_numpy(areas).to(device=mask_bin.device, dtype=torch.int32)) + return torch.stack(labels_list).unsqueeze(1), torch.stack(areas_list).unsqueeze(1) + + +def fill_holes_in_mask_scores(mask, max_area=0): + """Remove small foreground sprinkles and fill small background holes using connected components.""" + if max_area <= 0: + return mask + + # Fill holes: small connected components in background → foreground + mask_bg = (mask <= 0).to(torch.uint8) + _, areas_bg = _get_connected_components(mask_bg) + small_bg = mask_bg.bool() & (areas_bg <= max_area) + mask = torch.where(small_bg, 0.1, mask) + + # Remove sprinkles: small connected components in foreground → background + # Only remove if area < min(max_area, half of total foreground area) + mask_fg = (mask > 0).to(torch.uint8) + fg_area_thresh = mask_fg.sum(dim=(2, 3), keepdim=True, dtype=torch.int32) + fg_area_thresh.floor_divide_(2).clamp_(max=max_area) + _, areas_fg = _get_connected_components(mask_fg) + small_fg = mask_fg.bool() & (areas_fg <= fg_area_thresh) + mask = torch.where(small_fg, -0.1, mask) + + return mask + + +def apply_rope_memory(q, k, freqs, num_heads, num_k_exclude_rope=0): + """Apply 2D axial RoPE to memory attention using flux rope format. + + Args: + q: [B, Nq, C] projected queries (current frame features) + k: [B, Nk, C] projected keys (memory tokens) + freqs: [1, Nq, dim//2, 2, 2] flux-format rotation matrices for one frame + num_heads: number of attention heads + num_k_exclude_rope: number of trailing k tokens to skip RoPE (object pointers) + """ + B, Nq, C = q.shape + head_dim = C // num_heads + + # freqs shape: [1, 1, Nq, dim//2, 2, 2] (heads broadcast dim already included) + q_h = q.view(B, Nq, num_heads, head_dim).transpose(1, 2) + q_h = apply_rope1(q_h, freqs) + q = q_h.transpose(1, 2).reshape(B, Nq, C) + + # Apply RoPE to k (excluding last num_k_exclude_rope tokens) + Nk = k.shape[1] + num_k_rope = Nk - num_k_exclude_rope + if num_k_rope > 0: + # Repeat freqs for multiple frames of spatial memory + Nf = freqs.shape[2] # spatial positions in one frame + if num_k_rope > Nf: + r = (num_k_rope + Nf - 1) // Nf + pe_k = freqs.repeat(1, 1, r, 1, 1, 1)[:, :, :num_k_rope] + else: + pe_k = freqs[:, :, :num_k_rope] + + k_h = k[:, :num_k_rope].view(B, num_k_rope, num_heads, head_dim).transpose(1, 2) + k_h = apply_rope1(k_h, pe_k) + k = k.clone() + k[:, :num_k_rope] = k_h.transpose(1, 2).reshape(B, num_k_rope, C) + + return q, k + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """1D sinusoidal positional encoding for temporal positions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + pos_embed = pos_inds.unsqueeze(-1) / dim_t + return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + + +def _pad_to_buckets(tensor, target_buckets): + """Pad a [num_buckets, ...] tensor to target_buckets along dim 0 if needed.""" + if tensor.shape[0] >= target_buckets: + return tensor + pad_shape = (target_buckets - tensor.shape[0],) + tensor.shape[1:] + return torch.cat([tensor, torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)], dim=0) + + +def pack_masks(masks): + """Pack binary masks [*, H, W] to bit-packed [*, H, W//8] uint8. W must be divisible by 8.""" + binary = masks > 0 + shifts = torch.arange(8, device=masks.device) + return (binary.view(*masks.shape[:-1], -1, 8) * (1 << shifts)).sum(-1).byte() + + +def unpack_masks(packed): + """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8].""" + shifts = torch.arange(8, device=packed.device) + return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool() + + +def _compute_backbone(backbone_fn, frame, frame_idx=None): + """Compute backbone features for a single frame. Returns (vision_feats, vision_pos, feat_sizes, features, trunk_out).""" + features, positions, trunk_out = backbone_fn(frame, frame_idx=frame_idx) + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in features] + vision_feats = [x.flatten(2).permute(0, 2, 1) for x in features] + vision_pos = [x.flatten(2).permute(0, 2, 1) for x in positions] + return vision_feats, vision_pos, feat_sizes, features, trunk_out + + +def collect_memory_tokens(output_dict, frame_idx, num_maskmem, maskmem_tpos_enc, device, + collect_image_feats=False, tpos_v2=False, num_buckets=None): + """Collect spatial memory, position encodings, and optionally image features from past frames.""" + to_cat_memory, to_cat_memory_pos = [], [] + to_cat_image_feat, to_cat_image_pos = [], [] + + def _append(out, tpos_idx): + feats = out["maskmem_features"].to(device) + if num_buckets is not None: + feats = _pad_to_buckets(feats, num_buckets) + to_cat_memory.append(feats.flatten(2).permute(0, 2, 1)) + enc = out["maskmem_pos_enc"][-1].to(device).flatten(2).permute(0, 2, 1) + if num_buckets is not None: + enc = _pad_to_buckets(enc, num_buckets) + tpos = cast_to_input(maskmem_tpos_enc[tpos_idx], enc) + to_cat_memory_pos.append(enc + tpos) + if collect_image_feats and "image_features" in out: + to_cat_image_feat.append(out["image_features"].to(device)) + to_cat_image_pos.append(out["image_pos_enc"].to(device) + tpos) + + cond_outputs = output_dict["cond_frame_outputs"] + for t, out in cond_outputs.items(): + if tpos_v2: + t_pos = frame_idx - t + tpos_idx = num_maskmem - t_pos - 1 if 0 < t_pos < num_maskmem else num_maskmem - 1 + else: + tpos_idx = num_maskmem - 1 + _append(out, tpos_idx) + + for t_pos in range(1, num_maskmem): + out = output_dict["non_cond_frame_outputs"].get(frame_idx - (num_maskmem - t_pos), None) + if out is None or out.get("maskmem_features") is None: + continue + _append(out, num_maskmem - t_pos - 1) + + return to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs + + +def compute_tpos_enc(rel_pos_list, device, d_model, proj_layer, dtype=None, max_abs_pos=None): + """Temporal position encoding for object pointers.""" + pos_enc = torch.tensor(rel_pos_list, dtype=torch.float32, device=device) / max((max_abs_pos or 2) - 1, 1) + pos_enc = get_1d_sine_pe(pos_enc, dim=d_model) + if dtype is not None: + pos_enc = pos_enc.to(dtype) + return proj_layer(pos_enc) + + +def forward_sam_heads(backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + """Shared SAM prompt encoder + mask decoder forward for both SAM3 and SAM3.1 trackers.""" + device = backbone_features.device + # Batch size from inputs (mask_inputs may have N_obj > 1 while backbone is batch 1) + if mask_inputs is not None: + B = mask_inputs.shape[0] + elif box_inputs is not None: + B = box_inputs.shape[0] + elif point_inputs is not None: + B = point_inputs["point_coords"].shape[0] + else: + B = backbone_features.shape[0] + + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + else: + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + if mask_inputs is not None: + prompt_size = (prompt_encoder.image_embedding_size[0] * 4, prompt_encoder.image_embedding_size[1] * 4) + if mask_inputs.shape[-2:] != prompt_size: + sam_mask_prompt = F.interpolate(mask_inputs, size=prompt_size, mode="bilinear", align_corners=False, antialias=True) + else: + sam_mask_prompt = mask_inputs + else: + sam_mask_prompt = None + + sparse, dense = prompt_encoder(points=(sam_point_coords, sam_point_labels), boxes=box_inputs, masks=sam_mask_prompt) + sparse = cast_to_input(sparse, backbone_features) + dense = cast_to_input(dense, backbone_features) + image_pe = cast_to_input(prompt_encoder.get_dense_pe(), backbone_features) + + low_res_multimasks, ious, sam_output_tokens, object_score_logits = mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=sparse, dense_prompt_embeddings=dense, + high_res_features=high_res_features, multimask_output=multimask_output, return_all=True, + ) + + is_obj_appearing = object_score_logits > 0 + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_multimasks.dtype)) + high_res_multimasks = F.interpolate(low_res_multimasks, size=(image_size, image_size), mode="bilinear", align_corners=False) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + obj_ptr = obj_ptr_proj(sam_output_token) + obj_ptr = no_obj_fn(obj_ptr, is_obj_appearing) + + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +def use_mask_as_output(backbone_features, high_res_features, mask_inputs, mask_downsample, + prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, image_size, backbone_stride): + """Shared mask-as-output for both SAM3 and SAM3.1 trackers.""" + out_scale, out_bias = 20.0, -10.0 + mask_inputs_float = cast_to_input(mask_inputs, backbone_features) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate(high_res_masks, size=(image_size // backbone_stride * 4,) * 2, + mode="bilinear", align_corners=False, antialias=True) + _, _, obj_ptr, _ = forward_sam_heads( + backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, mask_inputs=mask_downsample(mask_inputs_float), high_res_features=high_res_features, + ) + is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1)[..., None] + alpha = is_obj_appearing.to(obj_ptr.dtype) + object_score_logits = out_scale * alpha + out_bias + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +# Split attention with configurable input dims (for asymmetric cross-attention) +class SplitAttn(nn.Module): + def __init__(self, embed_dim, num_heads=1, kv_dim=None, internal_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + kv_dim = kv_dim or embed_dim + internal_dim = internal_dim or embed_dim + self.q_proj = operations.Linear(embed_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0): + if k is None: + k = q + if v is None: + v = k + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + out = optimized_attention(q, k, v, self.num_heads) + return self.out_proj(out) + + +class MemoryAttnLayer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.self_attn = SplitAttn(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitAttn(d_model, num_heads, kv_dim=kv_dim, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, memory, memory_pos=None, rope=None, num_k_exclude_rope=0): + x = x + self.self_attn(self.norm1(x), rope=rope) + mem_k = memory + memory_pos if memory_pos is not None else memory + x = x + self.cross_attn_image(self.norm2(x), mem_k, memory, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class MemoryAttnEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + MemoryAttnLayer(d_model, num_heads, kv_dim, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, src_pos=None, memory_pos=None, num_k_exclude_rope=0): + if src_pos is not None: + x = x + 0.1 * src_pos + + rope = self._rope.to(device=x.device) + for layer in self.layers: + x = layer(x, memory, memory_pos=memory_pos, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + return self.norm(x) + + +class MemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = MemoryAttnEncoder(d_model, num_heads, kv_dim, dim_ff, num_layers, device=device, dtype=dtype, operations=operations) + + +def _upscale_masks(output_upscaling, conv_s0, conv_s1, src_out, high_res_features): + """Shared upscaling for SAM mask decoders: deconv + high-res feature integration.""" + dc1, ln1, act1, dc2, act2 = output_upscaling + if high_res_features is not None: + upscaled = act1(ln1(dc1(src_out) + conv_s1(high_res_features[1]))) + upscaled = act2(dc2(upscaled) + conv_s0(high_res_features[0])) + else: + upscaled = act2(dc2(act1(ln1(dc1(src_out))))) + return upscaled + + +class SAMMaskDecoder(nn.Module): + def __init__(self, d_model=256, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_mask_tokens = num_multimask_outputs + 1 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.iou_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(self.num_mask_tokens, d_model, device=device, dtype=dtype) + self.obj_score_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + + # Output upscaling: d_model -> d_model//4 -> d_model//8 at 4x resolution + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + + # High-res feature integration + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Per-mask hypernetwork MLPs + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_tokens) + ]) + + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_tokens, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False): + B = sparse_prompt_embeddings.shape[0] + ref = sparse_prompt_embeddings + # Token order: [obj_score(1), iou(1), mask(num_mask_tokens)] + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), + cast_to_input(self.mask_tokens.weight, ref)], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + src_flat = src.flatten(2).permute(0, 2, 1) + pos_flat = pos_src.flatten(2).permute(0, 2, 1) + + hs, src_out = self.transformer(src_flat, pos_flat, tokens) + + obj_score_token_out = hs[:, 0, :] + iou_token_out = hs[:, 1, :] + mask_tokens_out = hs[:, 2:2 + self.num_mask_tokens, :] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + hyper_in = torch.stack([ + mlp(mask_tokens_out[:, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps) + ], dim=1) + + masks = (hyper_in @ upscaled.flatten(2)).view(B, self.num_mask_tokens, upscaled.shape[2], upscaled.shape[3]) + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) + + if multimask_output: + out_masks = masks[:, 1:] + out_iou = iou_pred[:, 1:] + out_tokens = mask_tokens_out[:, 1:] + else: + out_masks = masks[:, 0:1] + out_iou = iou_pred[:, 0:1] + out_tokens = mask_tokens_out[:, 0:1] + + if return_all: + return out_masks, out_iou, out_tokens, object_score_logits + return out_masks, out_iou + + +class SAMPromptEncoder(nn.Module): + def __init__(self, d_model=256, image_embedding_size=(72, 72), input_image_size=(1008, 1008), device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = d_model + self.image_embedding_size = image_embedding_size + self.input_image_size = input_image_size + + self.pe_layer = PositionEmbeddingRandom(d_model // 2) + self.point_embeddings = nn.ModuleList([ + operations.Embedding(1, d_model, device=device, dtype=dtype) for _ in range(4) + ]) + self.not_a_point_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.mask_downscaling = nn.Sequential( + operations.Conv2d(1, 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(4, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(4, 16, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(16, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(16, d_model, kernel_size=1, device=device, dtype=dtype), + ) + self.no_mask_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + def get_dense_pe(self): + return self.pe_layer(self.image_embedding_size) + + def forward(self, points=None, boxes=None, masks=None): + ref = points[0] if points is not None else boxes if boxes is not None else masks + B = 1 + sparse = torch.empty((B, 0, self.embed_dim), device=ref.device, dtype=ref.dtype) + + if points is not None: + coords, labels = points + B = coords.shape[0] + # Pad with an extra point (label=-1) when no boxes are provided (matching reference) + if boxes is None: + coords = torch.cat([coords, torch.zeros(B, 1, 2, device=coords.device, dtype=coords.dtype)], dim=1) + labels = torch.cat([labels, -torch.ones(B, 1, device=labels.device, dtype=labels.dtype)], dim=1) + pe = self.pe_layer.forward_with_coords(coords + 0.5, self.input_image_size) + for i in range(4): + pe[labels == i] += cast_to_input(self.point_embeddings[i].weight, ref) + invalid = (labels == -1) + pe[invalid] = 0.0 + pe[invalid] += cast_to_input(self.not_a_point_embed.weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), pe], dim=1) + + if boxes is not None: + B = boxes.shape[0] + corners = self.pe_layer.forward_with_coords((boxes.reshape(-1, 2, 2) + 0.5), self.input_image_size) + corners[:, 0] += cast_to_input(self.point_embeddings[2].weight, ref) + corners[:, 1] += cast_to_input(self.point_embeddings[3].weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), corners], dim=1) + + if masks is not None: + dense = self.mask_downscaling(masks) + else: + dense = cast_to_input(self.no_mask_embed.weight, ref).reshape(1, -1, 1, 1).expand( + B, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + + return sparse, dense + + +class CXBlock(nn.Module): + def __init__(self, dim=256, kernel_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.dwconv = operations.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, device=device, dtype=dtype) + self.pwconv1 = operations.Linear(dim, 4 * dim, device=device, dtype=dtype) + self.pwconv2 = operations.Linear(4 * dim, dim, device=device, dtype=dtype) + self.gamma = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) + + def forward(self, x): + residual = x + x = self.dwconv(x).permute(0, 2, 3, 1) + x = self.pwconv2(F.gelu(self.pwconv1(self.norm(x)))) + x.mul_(cast_to_input(self.gamma, x)) + return residual + x.permute(0, 3, 1, 2) + + +class MaskDownSampler(nn.Module): + def __init__(self, out_dim=256, in_chans=1, channels=None, interpol_size=(1152, 1152), device=None, dtype=None, operations=None): + super().__init__() + self.interpol_size = list(interpol_size) if interpol_size else None + if channels is None: + channels = [4, 16, 64, out_dim] # SAM3 default + LN2d = LayerNorm2d_op(operations) + layers = [] + prev = in_chans + for ch in channels: + layers += [operations.Conv2d(prev, ch, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + LN2d(ch, device=device, dtype=dtype), nn.GELU()] + prev = ch + layers.append(operations.Conv2d(prev, out_dim, kernel_size=1, device=device, dtype=dtype)) + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + if self.interpol_size is not None and list(x.shape[-2:]) != self.interpol_size: + x = F.interpolate(x, size=self.interpol_size, mode="bilinear", align_corners=False, antialias=True) + return self.encoder(x) + + +class Fuser(nn.Module): + def __init__(self, dim=256, num_layers=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.Sequential(*[CXBlock(dim, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +# --- SAM3.1 Multiplex components --- + +class DecoupledMemoryAttnLayer(nn.Module): + """Decoupled cross-attention layer for SAM3.1: fuses image and memory projections.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + # Self-attention projections (flat, not nested) + self.self_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Cross-attention projections + self.cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Image cross-attention (q/k only, fused with cross_attn) + self.image_cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.image_cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # FFN + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, image, x, memory_image, memory, memory_image_pos=None, + rope=None, num_k_exclude_rope=0): + # Self-attention with RoPE + normed = self.norm1(x) + q = self.self_attn_q_proj(normed) + k = self.self_attn_k_proj(normed) + v = self.self_attn_v_proj(normed) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + + # Decoupled cross-attention: fuse image and memory projections + normed = self.norm2(x) + q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(normed) + k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory) + if memory_image_pos is not None: + k = k + memory_image_pos + v = self.cross_attn_v_proj(memory) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + + # FFN + x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) + return image, x + + +class DecoupledMemoryEncoder(nn.Module): + """Memory attention encoder for SAM3.1 with decoupled cross-attention.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + DecoupledMemoryAttnLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, memory_pos=None, src_pos=None, num_k_exclude_rope=0, + memory_image=None, memory_image_pos=None): + image = x # constant residual for decoupled cross-attention + output = x + if src_pos is not None: + output = output + 0.1 * src_pos + + B, _, C = x.shape + rope = self._rope.to(device=x.device) + + # memory_image: raw backbone features from past frames for decoupled cross-attention + if memory_image is None: + # Fallback: use spatial portion of memory (without obj pointers) + num_spatial = memory.shape[1] - num_k_exclude_rope + memory_image = memory[:, :num_spatial] + memory_image_pos = memory_pos[:, :num_spatial] if memory_pos is not None else None + # Pad memory_image to match memory length (zeros for obj pointer tokens) + if memory_image.shape[1] < memory.shape[1]: + pad_len = memory.shape[1] - memory_image.shape[1] + pad = torch.zeros(B, pad_len, C, device=memory.device, dtype=memory.dtype) + memory_image = torch.cat([memory_image, pad], dim=1) + if memory_image_pos is not None: + ptr_pos = memory_pos[:, -pad_len:] if memory_pos is not None else torch.zeros_like(pad) + memory_image_pos = torch.cat([memory_image_pos, ptr_pos], dim=1) + + for layer in self.layers: + image, output = layer(image, output, memory_image, memory, + memory_image_pos=memory_image_pos, rope=rope, + num_k_exclude_rope=num_k_exclude_rope) + + return self.norm(output) + + +class DecoupledMemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = DecoupledMemoryEncoder(d_model, num_heads, dim_ff, num_layers, + device=device, dtype=dtype, operations=operations) + + +class MemoryBackbone(nn.Module): + """Memory encoder: downsamples mask, fuses with pixel features, optionally compresses.""" + + def __init__(self, d_model=256, out_dim=None, in_chans=1, channels=None, device=None, dtype=None, operations=None): + super().__init__() + self.mask_downsampler = MaskDownSampler(d_model, in_chans=in_chans, channels=channels, device=device, dtype=dtype, operations=operations) + self.pix_feat_proj = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.fuser = Fuser(d_model, num_layers=2, device=device, dtype=dtype, operations=operations) + self.has_out_proj = out_dim is not None and out_dim != d_model + if self.has_out_proj: + self.out_proj = operations.Conv2d(d_model, out_dim, kernel_size=1, device=device, dtype=dtype) + feat_dim = out_dim + else: + feat_dim = d_model + self.position_encoding = PositionEmbeddingSine(num_pos_feats=feat_dim, normalize=True) + + def forward(self, image_features, mask_for_mem, skip_mask_sigmoid=False): + if not skip_mask_sigmoid: + mask_for_mem = mask_for_mem.sigmoid() + mask_features = self.mask_downsampler(cast_to_input(mask_for_mem, image_features)) + if mask_features.shape[-2:] != image_features.shape[-2:]: + mask_features = F.interpolate(mask_features, size=image_features.shape[-2:], mode="bilinear", align_corners=False) + features = self.pix_feat_proj(image_features) + mask_features + features = self.fuser(features) + if self.has_out_proj: + features = self.out_proj(features) + pos = cast_to_input(self.position_encoding(features), features) + return {"vision_features": features, "vision_pos_enc": [pos]} + + +class MultiplexMaskDecoder(nn.Module): + """SAM mask decoder for SAM3.1 multiplex: predicts masks for num_multiplex objects simultaneously. + + Uses multimask_outputs_only=True: num_mask_output_per_object = num_multimask_outputs (no +1). + Hypernetwork MLPs are shared across multiplex objects. + Token order: [obj_score_token(M), iou_token(M), mask_tokens(M*T)]. + """ + + def __init__(self, d_model=256, num_multiplex=16, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_multiplex = num_multiplex + self.num_mask_output_per_object = num_multimask_outputs # 3 (multimask_outputs_only) + total_mask_tokens = num_multiplex * self.num_mask_output_per_object # 48 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.obj_score_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.iou_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(total_mask_tokens, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Shared across all multiplex objects (one per mask output) + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_output_per_object) + ]) + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_output_per_object, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False, extra_per_object_embeddings=None): + B = sparse_prompt_embeddings.shape[0] + M = self.num_multiplex + T = self.num_mask_output_per_object + + # Token order: [obj_score(M), iou(M), mask(M*T)] + ref = sparse_prompt_embeddings + mask_tokens = cast_to_input(self.mask_tokens.weight, ref) + if extra_per_object_embeddings is not None: + mask_tokens = mask_tokens.view(1, M, T, -1).expand(B, -1, -1, -1) + extra_per_object_embeddings.unsqueeze(2) + mask_tokens = mask_tokens.flatten(1, 2) # [B, M*T, C] + other_tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref)], dim=0).unsqueeze(0).expand(B, -1, -1) + tokens = torch.cat([other_tokens, mask_tokens, sparse_prompt_embeddings], dim=1) + else: + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), mask_tokens], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + hs, src_out = self.transformer(src.flatten(2).permute(0, 2, 1), pos_src.flatten(2).permute(0, 2, 1), tokens) + + # Parse output tokens + obj_score_token_out = hs[:, :M] + iou_token_out = hs[:, M:2 * M] + mask_tokens_out = hs[:, 2 * M:2 * M + M * T] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + # Reshape mask tokens to [B, M, T, C] and apply shared hypernetwork MLPs per mask output index + mask_tokens_2d = mask_tokens_out.view(B, M, T, -1) + hyper_in = torch.stack([ + self.output_hypernetworks_mlps[i](mask_tokens_2d[:, :, i, :]) # [B, M, C//8] + for i in range(T) + ], dim=2) # [B, M, T, C//8] + + # Generate masks: [B, M*T, H*W] -> [B, M, T, H, W] + masks = torch.bmm(hyper_in.flatten(1, 2), upscaled.flatten(2)).view(b, M, T, upscaled.shape[2], upscaled.shape[3]) + + # IoU and object scores + iou_pred = self.iou_prediction_head(iou_token_out).view(b, M, T) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) # [B, M, 1] + + # multimask_outputs_only: always output all T masks (no singlemask token) + sam_tokens_out = mask_tokens_2d[:, :, 0:1] # [B, M, 1, C] + + if return_all: + return masks, iou_pred, sam_tokens_out, object_score_logits + return masks, iou_pred + + +class SAM3Tracker(nn.Module): + def __init__(self, d_model=256, mem_dim=64, num_maskmem=7, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + + # Memory attention transformer + self.transformer = MemoryTransformer(d_model, num_heads=1, kv_dim=mem_dim, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + # SAM components + self.sam_mask_decoder = SAMMaskDecoder(d_model, device=device, dtype=dtype, operations=operations) + self.sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone + self.maskmem_backbone = MemoryBackbone(d_model, out_dim=mem_dim, device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + self.register_buffer("no_mem_pos_enc", torch.zeros(1, 1, d_model, device=device, dtype=dtype)) # checkpoint key, unused in forward + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(1, mem_dim, device=device, dtype=dtype)) + self.no_obj_ptr = nn.Parameter(torch.zeros(1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + + # Mask downsample: Conv2d stride 4 to reduce GT mask to SAM logit scale + self.mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Config + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 20.0 + self.sigmoid_bias_for_mem_enc = -10.0 + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(cast_to_input(self.no_obj_ptr, obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.mask_downsample, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames): + """Fuse current frame features with memory from previous frames.""" + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + # First conditioning frame: no memory yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, _, _, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx: + pos_and_ptrs.append(((frame_idx - t), out["obj_ptr"].to(device))) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device))) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [B, N, C=256] + + # Temporal position encoding for pointers + obj_pos = compute_tpos_enc( + list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype + ) # [N, mem_dim=64] + obj_pos = obj_pos.unsqueeze(0).expand(B, -1, -1) # [B, N, 64] + + # Split each 256-dim pointer into 4 x 64-dim tokens + if self.mem_dim < C: + N = obj_ptrs.shape[1] + obj_ptrs = obj_ptrs.view(B, N, C // self.mem_dim, self.mem_dim) # [B, N, 4, 64] + obj_ptrs = obj_ptrs.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, C // self.mem_dim, -1) + obj_pos = obj_pos.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + # No memory available yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + # Concatenate all memory and position encodings [B, total_mem, mem_dim=64] + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Run memory attention encoder + pix_feat = current_vision_feats[-1] # [B, HW, C] + src_pos = current_vision_pos_embeds[-1] # [B, HW, C] + + pix_feat_with_mem = self.transformer.encoder( + x=pix_feat, + memory=memory, + src_pos=src_pos, + memory_pos=memory_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False): + """Encode predicted mask into memory features.""" + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + maskmem_out = self.maskmem_backbone(pix_feat, mask_for_mem, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed for occluded objects + alpha = (object_score_logits > 0).to(maskmem_features.dtype)[..., None, None] + no_obj = cast_to_input(self.no_obj_embed_spatial, maskmem_features)[..., None, None].expand_as(maskmem_features) + return maskmem_features + (1 - alpha) * no_obj, maskmem_pos_enc + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, mask_inputs, output_dict, + num_frames, point_inputs=None): + """Track one frame: fuse with memory, predict mask, encode memory.""" + current_out = {} + + # High-res features for SAM head [stride-8, stride-4] + if len(current_vision_feats) > 1: + high_res_features = [ + x.view(x.shape[0], feat_sizes[i][0], feat_sizes[i][1], -1).permute(0, 3, 1, 2) + for i, x in enumerate(current_vision_feats[:-1]) + ] + else: + high_res_features = None + + # Top-level feature for memory + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use mask directly + pix_feat = to_spatial(current_vision_feats[-1], H, W) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # Track frame: fuse with memory, then SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + output_dict=output_dict, + num_frames=num_frames, + ) + # Use multimask for point prompts on init frames (picks best of 3 candidates) + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Clean low-res masks: remove sprinkles and fill holes + low_res_masks = fill_holes_in_mask_scores(low_res_masks, max_area=200) + high_res_masks = F.interpolate(low_res_masks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + # Encode memory + if self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, _, _ = _compute_backbone(backbone_fn, frame, frame_idx) + # SAM3: drop last FPN level + return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1] + + def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None): + """Track one object, computing backbone per frame to save VRAM.""" + N = images.shape[0] + device, dt = images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame( + backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx) + mask_input = None + if frame_idx == 0: + mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > 0.5).to(dt) + + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=(frame_idx == 0), + current_vision_feats=vision_feats, current_vision_pos_embeds=vision_pos, + feat_sizes=feat_sizes, mask_inputs=mask_input, output_dict=output_dict, num_frames=N) + + if frame_idx == 0: + output_dict["cond_frame_outputs"][frame_idx] = current_out + else: + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + # Move masks to CPU immediately to free VRAM + all_masks.append(current_out["pred_masks_high_res"].to(comfy.model_management.intermediate_device())) + if pbar is not None: + pbar.update(1) + + return torch.cat(all_masks, dim=0) # [N, 1, H, W] + + def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs): + """Track one or more objects across video frames. + + Args: + backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame + images: [N, 3, 1008, 1008] video frames + initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object) + pbar: optional progress bar + + Returns: + [N, N_obj, image_size, image_size] predicted mask logits per frame per object + """ + N_obj = initial_masks.shape[0] + per_object = [] + for obj_idx in range(N_obj): + obj_masks = self._track_single_object( + backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar) + per_object.append(obj_masks) + + return torch.cat(per_object, dim=1) # [N, N_obj, H, W] + + +class SAM31Tracker(nn.Module): + """SAM3.1 multiplex tracker: decoupled memory attention, dual decoder, 16-object multiplex.""" + + def __init__(self, d_model=256, mem_dim=256, num_maskmem=7, num_multiplex=16, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.num_multiplex = num_multiplex + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 2.0 + self.sigmoid_bias_for_mem_enc = -1.0 + + # Memory attention (decoupled cross-attention, 8 heads matching reference) + self.transformer = DecoupledMemoryTransformer(d_model, num_heads=8, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + + # Propagation decoder (multiplex: 16 objects, multimask_outputs_only) + self.sam_mask_decoder = MultiplexMaskDecoder(d_model, num_multiplex, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + # Interactive decoder (single object, same as SAM3) + self.interactive_sam_mask_decoder = SAMMaskDecoder(d_model, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + self.interactive_sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone (mem_dim=256, no out_proj compression) + self.maskmem_backbone = MemoryBackbone(d_model, in_chans=num_multiplex * 2, channels=[16, 64, 256, 1024], + device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(num_multiplex, mem_dim, device=device, dtype=dtype)) + self.interactivity_no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + self.no_obj_ptr_linear = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.interactive_obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + # Interactive mask downsample + self.interactive_mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Multiplex validity embeddings + self.output_valid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + self.output_invalid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + + # Position encoding for image (used by multiplex decoder) + self.image_pe_layer = PositionEmbeddingRandom(d_model // 2) + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(self.no_obj_ptr_linear(obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.interactive_sam_prompt_encoder, self.interactive_sam_mask_decoder, + self.interactive_obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.interactive_mask_downsample, self.interactive_sam_prompt_encoder, + self.interactive_sam_mask_decoder, self.interactive_obj_ptr_proj, + self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, + current_vision_pos_embeds, feat_sizes, output_dict, num_frames, + multiplex_state=None): + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + num_buc = multiplex_state.num_buckets if multiplex_state is not None else None + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device, + collect_image_feats=True, tpos_v2=True, num_buckets=num_buc) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append(((frame_idx - t), ptr)) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append((t_diff, ptr)) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [num_buckets, N, M, C] + B_ptr = obj_ptrs.shape[0] + N_ptrs = obj_ptrs.shape[1] + M = obj_ptrs.shape[2] + obj_ptrs = obj_ptrs.reshape(B_ptr, N_ptrs * M, -1) + obj_pos = compute_tpos_enc(list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype) + obj_pos = obj_pos.unsqueeze(0).expand(B_ptr, -1, -1) + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, M, -1).reshape(B_ptr, N_ptrs * M, -1) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Expand vision features to num_buckets if memory has more buckets than B + mem_B = memory.shape[0] + x = current_vision_feats[-1] + x_pos = current_vision_pos_embeds[-1] + if x.shape[0] < mem_B: + x = x.expand(mem_B, -1, -1) + x_pos = x_pos.expand(mem_B, -1, -1) + + if len(to_cat_image_feat) > 0: + # Decoupled cross-attention: separate image features from memory + memory_image = cast_to_input(torch.cat(to_cat_image_feat, dim=1), x) + memory_image_pos = cast_to_input(torch.cat(to_cat_image_pos, dim=1), x) + if memory_image.shape[0] < mem_B: + memory_image = memory_image.expand(mem_B, -1, -1) + memory_image_pos = memory_image_pos.expand(mem_B, -1, -1) + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=cast_to_input(memory, x), + memory_pos=cast_to_input(memory_pos, x), + src_pos=cast_to_input(x_pos, x), + num_k_exclude_rope=num_obj_ptr_tokens, + memory_image=memory_image, + memory_image_pos=memory_image_pos, + ) + else: + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=memory, + memory_pos=memory_pos, + src_pos=x_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False, + multiplex_state=None, is_conditioning=False, cond_obj_mask=None): + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + # Mux masks: [N_obj, 1, H, W] -> [num_buckets, M, H, W] + mux_masks = multiplex_state.mux(mask_for_mem[:, 0]) + + # Conditioning channel: 1.0 = clean mask (trust it), 0.0 = propagation (noisy) + N_obj = mask_for_mem.shape[0] + cond_values = torch.full((N_obj,), 0.0, device=mask_for_mem.device, dtype=mask_for_mem.dtype) + if is_conditioning: + cond_values[:] = 1.0 + elif cond_obj_mask is not None: + cond_values[cond_obj_mask] = 1.0 + cond_spatial = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem[:, 0:1, :, :]).squeeze(1) + mux_cond = multiplex_state.mux(cond_spatial) # [num_buckets, M, H, W] + mux_input = torch.cat([mux_masks, mux_cond], dim=1) # [num_buckets, 2*M, H, W] + + maskmem_out = self.maskmem_backbone(pix_feat, mux_input, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed_spatial for occluded objects + is_obj = (object_score_logits > 0).float() # [N_obj, 1] + mux_is_obj = multiplex_state.mux(is_obj) # [num_buckets, M, 1] + no_obj_embed = cast_to_input(self.no_obj_embed_spatial, maskmem_features) # [M, C] + no_obj_spatial = no_obj_embed.unsqueeze(0)[..., None, None] # [1, M, C, 1, 1] + # Expand and sum across multiplex slots weighted by (1 - is_obj) + alpha = mux_is_obj[..., None, None] # [num_buckets, M, 1, 1, 1] + per_slot_no_obj = ((1 - alpha) * no_obj_spatial).sum(dim=1) # [num_buckets, C, 1, 1] + maskmem_features = maskmem_features + per_slot_no_obj.expand_as(maskmem_features) + + return maskmem_features, maskmem_pos_enc + + def _forward_propagation(self, backbone_features, high_res_features=None, multiplex_state=None): + """Propagation path using the multiplex SAM decoder (no prompts).""" + B = backbone_features.shape[0] + device = backbone_features.device + + # Suppression embeddings from valid object mask + valid_mask = cast_to_input(multiplex_state.get_valid_object_mask().unsqueeze(-1).float(), backbone_features) + output_valid = cast_to_input(self.output_valid_embed, backbone_features).unsqueeze(0) + output_invalid = cast_to_input(self.output_invalid_embed, backbone_features).unsqueeze(0) + extra_embed = valid_mask * output_valid + (1 - valid_mask) * output_invalid + + image_pe = self.image_pe_layer((backbone_features.shape[-2], backbone_features.shape[-1]), device=backbone_features.device) + image_pe = cast_to_input(image_pe, backbone_features) + + masks, iou_pred, sam_tokens_out, object_score_logits = self.sam_mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=torch.empty(B, 0, self.d_model, device=device, dtype=backbone_features.dtype), + dense_prompt_embeddings=torch.zeros(B, self.d_model, *backbone_features.shape[-2:], device=device, dtype=backbone_features.dtype), + high_res_features=high_res_features, multimask_output=True, return_all=True, + extra_per_object_embeddings=extra_embed.expand(B, -1, -1), + ) + # masks: [B=num_buckets, M, T, H, W] + # Demux to per-object: [N_obj, T, H, W] + masks_obj = multiplex_state.demux(masks) + iou_obj = multiplex_state.demux(iou_pred) + score_obj = multiplex_state.demux(object_score_logits) + tokens_obj = multiplex_state.demux(sam_tokens_out) + + # Select best mask by IoU for each object + best_idx = torch.argmax(iou_obj, dim=-1) # [N_obj] + N_obj = masks_obj.shape[0] + obj_range = torch.arange(N_obj, device=device) + low_res_masks = masks_obj[obj_range, best_idx].unsqueeze(1) # [N_obj, 1, H, W] + # Suppress masks for objects with low confidence + is_obj = score_obj > 0 + low_res_masks = torch.where(is_obj[:, :, None, None], low_res_masks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_masks.dtype)) + high_res_masks = F.interpolate(low_res_masks.float(), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + # Object pointer: compute per-object, mux for storage as [num_buckets, M, C] + sam_token = tokens_obj[:, 0] # [N_obj, C] + obj_ptr = self.obj_ptr_proj(sam_token) + is_obj = (score_obj > 0).float() + no_obj = self.no_obj_ptr_linear(obj_ptr) + obj_ptr = is_obj * obj_ptr + (1 - is_obj) * no_obj + obj_ptr_muxed = multiplex_state.mux(obj_ptr) # [num_buckets, M, C] + + return low_res_masks, high_res_masks, obj_ptr_muxed, score_obj + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, + feat_sizes, mask_inputs, output_dict, num_frames, point_inputs=None, + interactive_high_res=None, interactive_backbone=None, propagation_high_res=None, + multiplex_state=None, run_mem_encoder=True): + current_out = {} + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use interactive features if available, else propagation + if interactive_backbone is not None: + pix_feat = interactive_backbone + # Add no_mem_embed for interactive path + pix_flat = pix_feat.flatten(2) + bf = pix_flat.permute(0, 2, 1) + cast_to_input(self.interactivity_no_mem_embed, pix_flat) + pix_feat = to_spatial(bf, H, W) + hi_res = interactive_high_res + else: + # Fallback: interactive backbone not available (e.g. called outside track_video). + # Propagation features work but may produce lower-quality conditioning. + pix_feat = to_spatial(current_vision_feats[-1], H, W) + hi_res = propagation_high_res + sam_outputs = self._use_mask_as_output(pix_feat, hi_res, mask_inputs) + elif point_inputs is not None: + # Interactive path: use interactive SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + hi_res = interactive_high_res if interactive_high_res is not None else propagation_high_res + num_pts = point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, point_inputs=point_inputs, + high_res_features=hi_res, multimask_output=multimask_output, + ) + else: + # Propagation path: use multiplex SAM decoder with propagation features + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + sam_outputs = self._forward_propagation(pix_feat_with_mem, propagation_high_res, + multiplex_state=multiplex_state) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Mux obj_ptr if it came from interactive path (shape [B, C]) vs propagation ([num_buckets, M, C]) + if multiplex_state is not None and obj_ptr.dim() == 2: + obj_ptr = multiplex_state.mux(obj_ptr) # [N_obj, C] -> [num_buckets, M, C] + + # Encode memory (can be deferred with run_mem_encoder=False) + if run_mem_encoder and self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + multiplex_state=multiplex_state, + is_conditioning=(mask_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + # Store propagation image features for decoupled memory attention + current_out["image_features"] = current_vision_feats[-1] # [B, HW, C] + current_out["image_pos_enc"] = current_vision_pos_embeds[-1] # [B, HW, C] + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, features, trunk_out = _compute_backbone(backbone_fn, frame, frame_idx) + return vision_feats, vision_pos, feat_sizes, list(features[:-1]), trunk_out + + @staticmethod + def _suppress_recently_occluded(low_res_masks, last_occluded, frame_idx, threshold=0.3): + """Suppress overlapping masks for objects that were most recently occluded. + Prevents corrupted masks from occluded objects from contaminating other objects.""" + N_obj = low_res_masks.shape[0] + if N_obj <= 1: + return low_res_masks + binary = low_res_masks[:, 0] > 0 # [N_obj, H, W] + iou = _compute_mask_overlap(low_res_masks[:, 0], low_res_masks[:, 0]) + overlapping = torch.triu(iou >= threshold, diagonal=1) # [N, N] upper triangle + last_occ_i = last_occluded.unsqueeze(1) # [N, 1] + last_occ_j = last_occluded.unsqueeze(0) # [1, N] + # Suppress the more recently occluded object in each overlapping pair + suppress_i = overlapping & (last_occ_i > last_occ_j) & (last_occ_j > -1) + suppress_j = overlapping & (last_occ_j > last_occ_i) & (last_occ_i > -1) + to_suppress = suppress_i.any(dim=1) | suppress_j.any(dim=0) + # Update last_occluded for occluded/suppressed objects + is_empty = ~binary.any(dim=(-1, -2)) + newly_occluded = is_empty | to_suppress + last_occluded[newly_occluded] = frame_idx + # Suppress masks + low_res_masks[to_suppress] = -10.0 + return low_res_masks + + def _deferred_memory_encode(self, current_out, N_obj, vision_feats, feat_sizes, mux_state, device, + cond_obj_mask=None): + """Deferred memory encoding for propagation frames. cond_obj_mask: per-object bool for conditioning.""" + low_res_masks = current_out["pred_masks"] # [N_obj, 1, H_low, W_low] + + if N_obj > 1: + lr = low_res_masks.squeeze(1) # [N_obj, H, W] + max_obj = torch.argmax(lr, dim=0, keepdim=True) + batch_inds = torch.arange(N_obj, device=device)[:, None, None] + pixel_nol = torch.where(max_obj == batch_inds, lr, torch.clamp(lr, max=-10.0)) + area_before = (lr > 0).sum(dim=(-1, -2)).float().clamp(min=1) + area_after = (pixel_nol > 0).sum(dim=(-1, -2)).float() + shrink_ok = (area_after / area_before) >= 0.3 + low_res_masks = torch.where( + shrink_ok[:, None, None, None].expand_as(low_res_masks), + low_res_masks, torch.clamp(low_res_masks, max=-10.0)) + + interpol_size = self.maskmem_backbone.mask_downsampler.interpol_size + mem_masks = F.interpolate(low_res_masks, size=interpol_size, + mode="bilinear", align_corners=False) + + obj_scores = torch.where( + (mem_masks > 0).any(dim=(-1, -2)), 10.0, -10.0) + + pix_feat = to_spatial(vision_feats[-1], feat_sizes[-1][0], feat_sizes[-1][1]) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=mem_masks, + object_score_logits=obj_scores, + multiplex_state=mux_state, cond_obj_mask=cond_obj_mask) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + + def _add_detected_objects(self, new_masks, mux_state, vision_feats, feat_sizes, current_out): + """Grow MultiplexState with new detections, merge masks, re-encode memory. Modifies current_out.""" + n_old = mux_state.total_valid_entries + mux_state.add_objects(new_masks.shape[0]) + N_obj = mux_state.total_valid_entries + # Stored memory with old bucket counts is padded at read time by _pad_to_buckets + for k in ("pred_masks", "pred_masks_high_res"): + det = F.interpolate(new_masks.unsqueeze(1), size=current_out[k].shape[-2:], + mode="bilinear", align_corners=False) + current_out[k] = torch.cat([current_out[k], det], dim=0) + if self.num_maskmem > 0: + # Mark new objects as conditioning (clean detection masks) so model trusts them + cond_mask = torch.zeros(N_obj, dtype=torch.bool, device=new_masks.device) + cond_mask[n_old:] = True + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, + mux_state, new_masks.device, cond_obj_mask=cond_mask) + + def _condition_with_masks(self, masks, frame_idx, vision_feats, vision_pos, feat_sizes, + high_res_prop, output_dict, N, mux_state, backbone_obj, frame, + trunk_out, threshold=0.5): + """Condition tracker with masks on a frame.""" + mask_input = F.interpolate(masks if masks.dim() == 4 else masks.unsqueeze(1), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > threshold).to(masks.dtype) + hi_res = lo_feat = None + if backbone_obj is not None and backbone_obj.multiplex: + _, _, itf, _ = backbone_obj(frame, tracker_mode="interactive", cached_trunk=trunk_out, tracker_only=True) + hi_res, lo_feat = itf[:-1], itf[-1] + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=mask_input, + output_dict=output_dict, num_frames=N, interactive_high_res=hi_res, + interactive_backbone=lo_feat, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=True) + output_dict["cond_frame_outputs"][frame_idx] = current_out + return current_out + + def _match_and_add_detections(self, det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects=0, + keep_alive=None): + """Match detections against tracked masks, add new objects, recondition degraded tracks. + Updates keep_alive counters: +1 for matched tracks, -1 for unmatched.""" + N_obj = mux_state.total_valid_entries + if det_masks.shape[0] == 0: + if keep_alive is not None: + for i in range(N_obj): + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + return [] + + # Match at low-res (like reference) + trk_masks = current_out["pred_masks"][:, 0] # [N_obj, H_low, W_low] + det_resized = F.interpolate(det_masks.unsqueeze(1), size=trk_masks.shape[-2:], + mode="bilinear", align_corners=False)[:, 0] + overlap = _compute_mask_overlap(det_resized, trk_masks) + + # Update keep_alive and find matched tracks + matched = set() + if overlap.shape[1] > 0: + matched = set((overlap >= 0.5).any(dim=0).nonzero(as_tuple=True)[0].tolist()) + if keep_alive is not None: + for i in range(N_obj): + if i in matched: + keep_alive[i] = min(8, keep_alive.get(i, 0) + 1) + else: + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + + # Recondition: high-confidence detections (>=0.8) with high overlap refresh tracked masks + reconditioned = False + if det_scores is not None and overlap.shape[1] > 0: + HIGH_CONF = 0.8 + for det_idx in range(overlap.shape[0]): + if det_scores[det_idx] < HIGH_CONF: + continue + best_trk = overlap[det_idx].argmax().item() + if overlap[det_idx, best_trk] >= 0.5: + # Replace tracked mask with fresh detection mask + current_out["pred_masks"][best_trk] = det_resized[det_idx].unsqueeze(0) + det_hr = F.interpolate(det_masks[det_idx:det_idx+1].unsqueeze(1), + size=current_out["pred_masks_high_res"].shape[-2:], + mode="bilinear", align_corners=False) + current_out["pred_masks_high_res"][best_trk] = det_hr[0] + reconditioned = True + + # Re-encode memory if any tracks were reconditioned + if reconditioned and self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + + # Add new detections (not matching any track) + if max_objects > 0 and N_obj >= max_objects: + return [] + max_overlap = overlap.max(dim=1)[0] if overlap.shape[1] > 0 else torch.zeros(overlap.shape[0], device=device) + new_dets = max_overlap < 0.5 + if new_dets.any(): + if max_objects > 0: + slots = max_objects - N_obj + new_dets = new_dets & (torch.cumsum(new_dets.int(), 0) <= slots) + self._add_detected_objects(det_masks[new_dets], mux_state, + vision_feats, feat_sizes, current_out) + if keep_alive is not None: + for i in range(N_obj, mux_state.total_valid_entries): + keep_alive[i] = 1 + return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item() + return [] + + def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1, + backbone_obj=None, pbar=None): + """Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits.""" + N, device, dt = images.shape[0], images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + idev = comfy.model_management.intermediate_device() + empty = lambda: torch.zeros(0, self.image_size, self.image_size, device=idev, dtype=dt) + mux_state = None + if initial_masks is not None: + mux_state = MultiplexState(initial_masks.shape[0], self.num_multiplex, device, dt) + obj_scores = [] # per-object detection score (1.0 for initial masks) + keep_alive = {} if detect_fn is not None else None + last_occluded = torch.empty(0, device=device, dtype=torch.long) # per-object last occluded frame + + # Prefetch next frame's backbone on a separate CUDA stream + prefetch = False + backbone_stream = None + if comfy.model_management.is_device_cuda(device): + try: + backbone_stream = torch.cuda.Stream(device=device) + prefetch = True + except RuntimeError: + pass + cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0) + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb + + # Start next frame's backbone on separate stream (overlaps with current frame's work) + if prefetch and frame_idx + 1 < N: + backbone_stream.wait_stream(torch.cuda.current_stream(device)) + with torch.cuda.stream(backbone_stream): + next_bb = self._compute_backbone_frame( + backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met) + det_masks = torch.empty(0, device=device) + det_scores = None + run_det = (detect_fn is not None + and frame_idx % max(detect_interval, 1) == 0 + and not (max_objects > 0 and mux_state is not None + and mux_state.total_valid_entries >= max_objects)) + if run_det: + det_out = detect_fn(trunk_out) + scores = det_out["scores"][0].sigmoid() + keep = scores > new_det_thresh + det_masks, det_scores = det_out["masks"][0][keep], scores[keep] + if det_masks.shape[0] > 1: + det_masks, det_scores = _nms_masks(det_masks, det_scores) + + if frame_idx == 0 and initial_masks is not None: + current_out = self._condition_with_masks( + initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos, + feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = [1.0] * mux_state.total_valid_entries + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 8 + elif mux_state is None or mux_state.total_valid_entries == 0: + if det_masks.shape[0] > 0: + if max_objects > 0: + det_scores = det_scores[:max_objects] + det_masks = det_masks[:max_objects] + mux_state = MultiplexState(det_masks.shape[0], self.num_multiplex, device, dt) + current_out = self._condition_with_masks( + det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, + output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = det_scores[:mux_state.total_valid_entries].tolist() + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 1 + else: + all_masks.append(empty()) + if pbar is not None: + pbar.update(1) + # Skip to backbone advance at end of loop + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + continue + else: + N_obj = mux_state.total_valid_entries + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=False, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=None, + output_dict=output_dict, num_frames=N, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=False) + current_out["pred_masks"] = fill_holes_in_mask_scores( + current_out["pred_masks"], max_area=16) + if last_occluded.shape[0] == N_obj and N_obj > 1: + self._suppress_recently_occluded( + current_out["pred_masks"], last_occluded, frame_idx) + if self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + n_before = mux_state.total_valid_entries + new_obj_scores = self._match_and_add_detections(det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects, + keep_alive if run_det else None) + n_added = mux_state.total_valid_entries - n_before + if n_added > 0: + last_occluded = torch.cat([last_occluded, + torch.full((n_added,), -1, device=device, dtype=torch.long)]) + obj_scores.extend(new_obj_scores) + + masks_out = current_out["pred_masks_high_res"][:, 0] + if keep_alive is not None: + for i in range(masks_out.shape[0]): + if keep_alive.get(i, 0) <= 0: + masks_out[i] = NO_OBJ_SCORE + N_obj_now = mux_state.total_valid_entries if mux_state is not None else 0 + if N_obj_now > 0: + all_masks.append(pack_masks(masks_out).to(idev)) + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + + # Next frame's backbone + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + if not all_masks or all(m is None for m in all_masks): + return {"packed_masks": None, "n_frames": N, "scores": []} + + max_obj = max(m.shape[0] for m in all_masks if m is not None) + sample = next(m for m in all_masks if m is not None) + empty_packed = torch.zeros(max_obj, *sample.shape[1:], dtype=torch.uint8, device=sample.device) + for i, m in enumerate(all_masks): + if m is None: + all_masks[i] = empty_packed + elif m.shape[0] < max_obj: + pad = torch.zeros(max_obj - m.shape[0], *m.shape[1:], dtype=torch.uint8, device=m.device) + all_masks[i] = torch.cat([m, pad], dim=0) + return {"packed_masks": torch.stack(all_masks, dim=0), "n_frames": N, "scores": obj_scores} diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..49bf45c32 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,7 @@ import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model +import comfy.ldm.sam3.detector import comfy.model_management import comfy.patcher_extension @@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class SAM3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..724a241bf 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "ernie" return dit_config + if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1 + if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "SAM3" + if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys: + dit_config["image_model"] = "SAM31" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal return model_config def unet_prefix_from_state_dict(state_dict): + # SAM3: detector.* and tracker.* at top level, no common prefix + if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict): + return "" + candidates = ["model.diffusion_model.", #ldm/sgm models "model.model.", #audio models "net.", #cosmos diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..8886f32d5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class SAM3(supported_models_base.BASE): + unet_config = {"image_model": "SAM3"} + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + text_encoder_key_prefix = ["detector.backbone.language_backbone."] + unet_extra_prefix = "" + + def process_clip_state_dict(self, state_dict): + clip_keys = getattr(self, "_clip_stash", {}) + clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True) + clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.") + return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")} + + def process_unet_state_dict(self, state_dict): + self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k} + # SAM3.1: remap tracker.model.* -> tracker.* + for k in list(state_dict.keys()): + if k.startswith("tracker.model."): + state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k) + # SAM3.1: remove per-block freqs_cis buffers (computed dynamically) + for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]: + state_dict.pop(k) + # Split fused QKV projections + for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]: + t = state_dict.pop(k) + base, suffix = k.rsplit(".in_proj_", 1) + s = ".weight" if suffix == "weight" else ".bias" + d = t.shape[0] // 3 + state_dict[base + ".q_proj" + s] = t[:d] + state_dict[base + ".k_proj" + s] = t[d:2*d] + state_dict[base + ".v_proj" + s] = t[2*d:] + # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer + for k in list(state_dict.keys()): + if "sam_mask_decoder.transformer." not in k: + continue + new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.") + if new_k != k: + state_dict[new_k] = state_dict.pop(k) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SAM3(self, device=device) + + def clip_target(self, state_dict={}): + import comfy.text_encoders.sam3_clip + return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper) + + +class SAM31(SAM3): + unet_config = {"image_model": "SAM31"} + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] models += [SVD_img2vid] diff --git a/comfy/text_encoders/sam3_clip.py b/comfy/text_encoders/sam3_clip.py new file mode 100644 index 000000000..0e4249ece --- /dev/null +++ b/comfy/text_encoders/sam3_clip.py @@ -0,0 +1,97 @@ +import re +from comfy import sd1_clip + +SAM3_CLIP_CONFIG = { + "architectures": ["CLIPTextModel"], + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "max_position_embeddings": 32, + "projection_dim": 512, + "vocab_size": 49408, + "layer_norm_eps": 1e-5, + "eos_token_id": 49407, +} + + +class SAM3ClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options) + + +class SAM3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data) + self.disable_weights = True + + +def _parse_prompts(text): + """Split comma-separated prompts with optional :N max detections per category""" + text = text.replace("(", "").replace(")", "") + parts = [p.strip() for p in text.split(",") if p.strip()] + result = [] + for part in parts: + m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part) + if m: + text_part = m.group(1).strip() + val = m.group(2) + max_det = max(1, round(float(val))) + result.append((text_part, max_det)) + else: + result.append((part, 1)) + return result + + +class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip") + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + parsed = _parse_prompts(text) + if len(parsed) <= 1 and parsed[0][1] == 1: + return super().tokenize_with_weights(text, return_word_ids, **kwargs) + # Tokenize each prompt part separately, store per-part batches and metadata + inner = getattr(self, self.clip) + per_prompt = [] + for prompt_text, max_det in parsed: + batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs) + per_prompt.append((batches, max_det)) + # Main output uses first prompt's tokens (for compatibility) + out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt} + return out + + +class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip") + + def encode_token_weights(self, token_weight_pairs): + per_prompt = token_weight_pairs.pop("sam3_per_prompt", None) + if per_prompt is None: + return super().encode_token_weights(token_weight_pairs) + + # Encode each prompt separately, pack into extra dict + inner = getattr(self, self.clip) + multi_cond = [] + first_pooled = None + for batches, max_det in per_prompt: + out = inner.encode_token_weights(batches) + cond, pooled = out[0], out[1] + extra = out[2] if len(out) > 2 else {} + if first_pooled is None: + first_pooled = pooled + multi_cond.append({ + "cond": cond, + "attention_mask": extra.get("attention_mask"), + "max_detections": max_det, + }) + + # Return first prompt as main (for non-SAM3 consumers), all prompts in metadata + main = multi_cond[0] + main_extra = {} + if main["attention_mask"] is not None: + main_extra["attention_mask"] = main["attention_mask"] + main_extra["sam3_multi_cond"] = multi_cond + return (main["cond"], first_pooled, main_extra) diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py new file mode 100644 index 000000000..7eee2c66e --- /dev/null +++ b/comfy_extras/nodes_sam3.py @@ -0,0 +1,514 @@ +""" +SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking. +""" + +from typing_extensions import override + +import json +import os +import torch +import torch.nn.functional as F +import comfy.model_management +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io, ui +import av +from fractions import Fraction + + +def _extract_text_prompts(conditioning, device, dtype): + """Extract list of (text_embeddings, text_mask) from conditioning.""" + cond_meta = conditioning[0][1] + multi = cond_meta.get("sam3_multi_cond") + prompts = [] + if multi is not None: + for entry in multi: + emb = entry["cond"].to(device=device, dtype=dtype) + mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None + if mask is None: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, entry.get("max_detections", 1))) + else: + emb = conditioning[0][0].to(device=device, dtype=dtype) + mask = cond_meta.get("attention_mask") + if mask is not None: + mask = mask.to(device) + else: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, 1)) + return prompts + + +def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations): + """Refine a coarse detector mask via SAM decoder, cropping to the detection box. + + Returns: [1, H, W] binary mask + """ + def _coarse_fallback(): + return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), + mode="bilinear", align_corners=False)[0] > 0).float() + + if iterations <= 0: + return _coarse_fallback() + + pad_frac = 0.1 + x1, y1, x2, y2 = box_xyxy.tolist() + bw, bh = x2 - x1, y2 - y1 + cx1 = max(0, int(x1 - bw * pad_frac)) + cy1 = max(0, int(y1 - bh * pad_frac)) + cx2 = min(W, int(x2 + bw * pad_frac)) + cy2 = min(H, int(y2 + bh * pad_frac)) + if cx2 <= cx1 or cy2 <= cy1: + return _coarse_fallback() + + crop = orig_image_hwc[cy1:cy2, cx1:cx2] + crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + crop_frame = crop_1008.to(device=device, dtype=dtype) + crop_h, crop_w = cy2 - cy1, cx2 - cx1 + + # Crop coarse mask and refine via SAM on the cropped image + mask_h, mask_w = coarse_mask.shape[-2:] + mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h) + mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h) + mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0) + for _ in range(iterations): + coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False) + mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input) + + refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False) + full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype) + full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop + coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False) + return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float() + + + +class SAM3_Detect(io.ComfyNode): + """Open-vocabulary detection and segmentation using text, box, or point prompts.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_Detect", + display_name="SAM3 Detect", + category="detection/", + search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], + inputs=[ + io.Model.Input("model", display_name="model"), + io.Image.Input("image", display_name="image"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"), + io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"), + io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01), + io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"), + io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"), + ], + outputs=[ + io.Mask.Output("masks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: + B, H, W, C = image.shape + + image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + + # Convert bboxes to normalized cxcywh format [1, N, 4] + # BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame) + boxes_tensor = None + if bboxes is not None: + # Flatten to list of dicts + if isinstance(bboxes, dict): + flat_boxes = [bboxes] + elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list): + flat_boxes = [d for frame in bboxes for d in frame] # per-frame list of lists + elif isinstance(bboxes, list): + flat_boxes = bboxes + else: + flat_boxes = [] + if flat_boxes: + coords = [] + for d in flat_boxes: + cx = (d["x"] + d["width"] / 2) / W + cy = (d["y"] + d["height"] / 2) / H + coords.append([cx, cy, d["width"] / W, d["height"] / H]) + boxes_tensor = torch.tensor([coords], dtype=torch.float32) # [1, N, 4] + + # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...]) + pos_pts = json.loads(positive_coords) if positive_coords else [] + neg_pts = json.loads(negative_coords) if negative_coords else [] + has_points = len(pos_pts) > 0 or len(neg_pts) > 0 + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + # Build point inputs for tracker SAM decoder path + point_inputs = None + if has_points: + all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \ + [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts] + all_labels = [1] * len(pos_pts) + [0] * len(neg_pts) + point_inputs = { + "point_coords": torch.tensor([all_coords], dtype=dtype, device=device), + "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device), + } + + cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else [] + has_text = len(cond_list) > 0 + + # Run per-image through detector (text/boxes) and/or tracker (points) + all_bbox_dicts = [] + all_masks = [] + pbar = comfy.utils.ProgressBar(B) + b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None + + for b in range(B): + frame = image_in[b:b+1].to(device=device, dtype=dtype) + + frame_bbox_dicts = [] + frame_masks = [] + + # Point prompts: tracker SAM decoder path with iterative refinement + if point_inputs is not None: + mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Box prompts: SAM decoder path (segment inside each box) + if b_boxes_tensor is not None and not has_text: + for box_cxcywh in b_boxes_tensor[0]: + cx, cy, bw, bh = box_cxcywh.tolist() + # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners + sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008], + [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]], + device=device, dtype=dtype) + mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Text prompts: run detector per text prompt (each detects one category) + for text_embeddings, text_mask, max_det in cond_list: + results = sam3_model( + frame, text_embeddings=text_embeddings, text_mask=text_mask, + boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W)) + + pred_boxes = results["boxes"][0] + scores = results["scores"][0] + masks = results["masks"][0] + + probs = scores.sigmoid() + keep = probs > threshold + kept_boxes = pred_boxes[keep].cpu() + kept_scores = probs[keep].cpu() + kept_masks = masks[keep] + + order = kept_scores.argsort(descending=True)[:max_det] + kept_boxes = kept_boxes[order] + kept_scores = kept_scores[order] + kept_masks = kept_masks[order] + + for box, score in zip(kept_boxes, kept_scores): + frame_bbox_dicts.append({ + "x": float(box[0]), "y": float(box[1]), + "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), + "score": float(score), + }) + for m, box in zip(kept_masks, kept_boxes): + frame_masks.append(_refine_mask( + sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations)) + + all_bbox_dicts.append(frame_bbox_dicts) + if len(frame_masks) > 0: + combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W] + if individual_masks: + all_masks.append(combined) + else: + all_masks.append((combined > 0).any(dim=0).float()) + else: + all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device())) + pbar.update(1) + + mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks) + return io.NodeOutput(mask_out, all_bbox_dicts) + + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + +class SAM3_VideoTrack(io.ComfyNode): + """Track objects across video frames using SAM3's memory-based tracker.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_VideoTrack", + display_name="SAM3 Video Track", + category="detection/", + search_aliases=["sam3", "video", "track", "propagate"], + inputs=[ + io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), + io.Model.Input("model", display_name="model"), + io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), + io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), + io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), + io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), + ], + outputs=[ + SAM3TrackData.Output("track_data", display_name="track_data"), + ], + ) + + @classmethod + def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput: + N, H, W, C = images.shape + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + frames = images.movedim(-1, 1) + frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) + + init_masks = None + if initial_mask is not None: + init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype) + + pbar = comfy.utils.ProgressBar(N) + + text_prompts = None + if conditioning is not None: + text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)] + elif initial_mask is None: + raise ValueError("Either initial_mask or conditioning must be provided") + + result = sam3_model.forward_video( + images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, + new_det_thresh=detection_threshold, max_objects=max_objects, + detect_interval=detect_interval) + result["orig_size"] = (H, W) + return io.NodeOutput(result) + + +class SAM3_TrackPreview(io.ComfyNode): + """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackPreview", + display_name="SAM3 Track Preview", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.Image.Input("images", display_name="images", optional=True), + io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05), + io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0), + ], + is_output_node=True, + ) + + COLORS = [ + (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16), + (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5), + (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84), + ] + + # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3] + _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow) + + @staticmethod + def _get_glyphs(device, scale=3): + key = (device, scale) + if key in SAM3_TrackPreview._glyph_cache: + return SAM3_TrackPreview._glyph_cache[key] + atlas = torch.tensor([ + [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]], + [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]], + [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]], + [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]], + ], dtype=torch.bool) + glyphs, outlines = [], [] + for d in range(10): + g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1) + padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1)) + o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0) + glyphs.append(g.to(device)) + outlines.append(o.to(device)) + gh, gw = glyphs[0].shape + oh, ow = outlines[0].shape + SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow) + return SAM3_TrackPreview._glyph_cache[key] + + @staticmethod + def _draw_number_gpu(frame, number, cx, cy, color, scale=3): + """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline.""" + H, W = frame.shape[:2] + device = frame.device + glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale) + color_t = torch.tensor(color, device=device, dtype=frame.dtype) + digs = [int(d) for d in str(number)] + total_w = len(digs) * (gw + scale) - scale + x0 = cx - total_w // 2 + y0 = cy - gh // 2 + for i, d in enumerate(digs): + dx = x0 + i * (gw + scale) + # Black outline + oy0, ox0 = y0 - 1, dx - 1 + osy1, osx1 = max(0, -oy0), max(0, -ox0) + osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0) + if osy2 > osy1 and osx2 > osx1: + fy1, fx1 = oy0 + osy1, ox0 + osx1 + frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0 + # Colored fill + sy1, sx1 = max(0, -y0), max(0, -dx) + sy2, sx2 = min(gh, H - y0), min(gw, W - dx) + if sy2 > sy1 and sx2 > sx1: + fy1, fx1 = y0 + sy1, dx + sx1 + frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t + + @classmethod + def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput: + + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + if images is not None: + H, W = images.shape[1], images.shape[2] + if packed is None: + N, N_obj = track_data["n_frames"], 0 + else: + N, N_obj = packed.shape[0], packed.shape[1] + + gpu = comfy.model_management.get_torch_device() + temp_dir = folder_paths.get_temp_directory() + filename = "sam3_track_preview.mp4" + filepath = os.path.join(temp_dir, filename) + with av.open(filepath, mode='w') as output: + stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000)) + stream.width = W + stream.height = H + stream.pix_fmt = 'yuv420p' + + frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8) + frame_np = frame_cpu.numpy() + if N_obj > 0: + colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)], + device=gpu, dtype=torch.float32) + grid_y = torch.arange(H, device=gpu).view(1, H, 1) + grid_x = torch.arange(W, device=gpu).view(1, 1, W) + for t in range(N): + if images is not None: + frame = images[t].clone() + else: + frame = torch.zeros(H, W, 3) + + if N_obj > 0: + frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool + frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0] + frame_gpu = frame.to(gpu) + bool_masks = frame_masks > 0.5 + any_mask = bool_masks.any(dim=0) + if any_mask.any(): + obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0) + color_overlay = colors_t[obj_idx_map] + mask_3d = any_mask.unsqueeze(-1) + frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu) + area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1) + cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area + cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area + has = area > 1 + scores = track_data.get("scores", []) + for obj_idx in range(N_obj): + if has[obj_idx]: + _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) + color = cls.COLORS[obj_idx % len(cls.COLORS)] + SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) + if obj_idx < len(scores) and scores[obj_idx] < 1.0: + SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), + _cx, _cy + 5 * 3 + 3, color, scale=2) + frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) + else: + frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) + + vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + output.mux(stream.encode(vframe.reformat(format='yuv420p'))) + output.mux(stream.encode(None)) + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)])) + + +class SAM3_TrackToMask(io.ComfyNode): + """Select tracked objects by index and output as mask.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackToMask", + display_name="SAM3 Track to Mask", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.String.Input("object_indices", display_name="object_indices", default="", + tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."), + ], + outputs=[ + io.Mask.Output("masks", display_name="masks"), + ], + ) + + @classmethod + def execute(cls, track_data, object_indices="") -> io.NodeOutput: + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + + if packed is None: + N = track_data["n_frames"] + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + N, N_obj = packed.shape[0], packed.shape[1] + + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + indices = [i for i in indices if 0 <= i < N_obj] + else: + indices = list(range(N_obj)) + + if not indices: + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + selected = packed[:, indices] + binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool + union = binary.any(dim=1, keepdim=True).float() + mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] + return io.NodeOutput(mask_out) + + +class SAM3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SAM3_Detect, + SAM3_VideoTrack, + SAM3_TrackPreview, + SAM3_TrackToMask, + ] + + +async def comfy_entrypoint() -> SAM3Extension: + return SAM3Extension() diff --git a/nodes.py b/nodes.py index 299b3d758..9eaa1ae18 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_sam3.py" ] import_failed = []