Initial SAM3.1 support

This commit is contained in:
kijai 2026-04-14 23:57:38 +03:00
parent c5569e8627
commit 4ba28caa8c
9 changed files with 3489 additions and 2 deletions

596
comfy/ldm/sam3/detector.py Normal file
View File

@ -0,0 +1,596 @@
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
from comfy.ops import cast_to_input
def box_cxcywh_to_xyxy(x):
cx, cy, w, h = x.unbind(-1)
return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
def gen_sineembed_for_position(pos_tensor, num_feats=256):
"""Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats)."""
assert num_feats % 2 == 0
hdim = num_feats // 2
freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim)
embeds = []
for c in range(pos_tensor.shape[-1]):
raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs
embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2))
return torch.cat(embeds, dim=-1).to(pos_tensor.dtype)
class SplitMHA(nn.Module):
"""Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def forward(self, q_input, k_input=None, v_input=None, mask=None):
q = self.q_proj(q_input)
if k_input is None:
k = self.k_proj(q_input)
v = self.v_proj(q_input)
else:
k = self.k_proj(k_input)
v = self.v_proj(v_input if v_input is not None else k_input)
if mask is not None and mask.ndim == 2:
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
dtype = q.dtype # manual_cast may produce mixed dtypes
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
return self.out_proj(out)
class MLPWithNorm(nn.Module):
"""MLP with residual connection and output LayerNorm."""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
for i in range(num_layers)
])
self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
self.residual = residual and (input_dim == output_dim)
def forward(self, x):
orig = x
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
x = F.relu(x)
if self.residual:
x = x + orig
return self.out_norm(x)
class EncoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x, pos, text_memory=None, text_mask=None):
normed = self.norm1(x)
q_k = normed + pos
x = x + self.self_attn(q_k, q_k, normed)
if text_memory is not None:
normed = self.norm2(x)
x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask)
normed = self.norm3(x)
x = x + self.linear2(F.relu(self.linear1(normed)))
return x
class TransformerEncoder(nn.Module):
"""Checkpoint: transformer.encoder.layers.N.*"""
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, x, pos, text_memory=None, text_mask=None):
for layer in self.layers:
x = layer(x, pos, text_memory, text_mask)
return x
class DecoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
q_k = x + x_pos
x = self.norm2(x + self.self_attn(q_k, q_k, x))
if text_memory is not None:
x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
x = self.norm3(x + self.linear2(F.relu(self.linear1(x))))
return x
class TransformerDecoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.num_queries = num_queries
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes
self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each)
self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
@staticmethod
def _inverse_sigmoid(x):
return torch.log(x / (1 - x + 1e-6) + 1e-6)
def _compute_box_rpb(self, ref_points, H, W):
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
B, Q, _ = boxes_xyxy.shape
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
log2_8 = float(math.log2(8))
def log_scale(d):
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
return torch.cat([pres_bias, bias], dim=2)
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
B = memory.shape[0]
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
for layer_idx, layer in enumerate(self.layers):
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
if layer_idx < len(self.layers) - 1:
ref_inv = self._inverse_sigmoid(ref_points)
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
query_out = self.norm(tgt)
ref_inv = self._inverse_sigmoid(ref_points)
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
class Transformer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations)
self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations)
class GeometryEncoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.roi_size = roi_size
self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.encode = nn.ModuleList([
EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def _encode_points(self, coords, labels, img_feat_2d):
"""Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
B, N, _ = coords.shape
embed = self.points_direct_project(coords)
# Pool features from backbone at point locations via grid_sample
grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1]
sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1]
embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C]
# Positional encoding of coordinates
x, y = coords[:, :, 0], coords[:, :, 1] # [B, N]
pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def _encode_boxes(self, boxes, labels, img_feat_2d):
"""Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
B, N, _ = boxes.shape
embed = self.boxes_direct_project(boxes)
# ROI align from backbone at box regions
H, W = img_feat_2d.shape[-2:]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
boxes_scaled = boxes_xyxy * scale
sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
embed = embed + proj
# Positional encoding of box center + size
cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
enc = enc.view(B, N, -1)
embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def forward(self, points=None, boxes=None, image_features=None):
"""Encode geometry prompts. image_features: [B, HW, C] flattened backbone features."""
# Prepare 2D image features for pooling
img_feat_2d = None
if image_features is not None:
B = image_features.shape[0]
HW, C = image_features.shape[1], image_features.shape[2]
hw = int(math.sqrt(HW))
img_normed = self.img_pre_norm(image_features)
img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
embeddings = []
if points is not None:
coords, labels = points
embeddings.append(self._encode_points(coords, labels, img_feat_2d))
if boxes is not None:
B = boxes.shape[0]
box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
if not embeddings:
return None
geo = torch.cat(embeddings, dim=1)
geo = self.norm(geo)
if image_features is not None:
for layer in self.encode:
geo = layer(geo, torch.zeros_like(geo), image_features)
geo = self.encode_norm(geo)
return self.final_proj(geo)
class PixelDecoder(nn.Module):
"""Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""
def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)])
self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)])
def forward(self, backbone_features):
prev = backbone_features[-1]
for i, feat in enumerate(backbone_features[:-1][::-1]):
prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest"))))
return prev
class MaskPredictor(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
def forward(self, query_embeddings, pixel_features):
mask_embed = self.mask_embed(query_embeddings)
return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features)
class SegmentationHead(nn.Module):
def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)
def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
if encoder_hidden_states is not None and prompt is not None:
enc_normed = self.cross_attn_norm(encoder_hidden_states)
enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
encoder_hidden_states = enc_cross + encoder_hidden_states
if encoder_hidden_states is not None:
B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
backbone_features = list(backbone_features)
backbone_features[-1] = encoder_visual
pixel_features = self.pixel_decoder(backbone_features)
instance_features = self.instance_seg_head(pixel_features)
masks = self.mask_predictor(query_embeddings, instance_features)
return masks
class DotProductScoring(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
self.scale = 1.0 / (d_model ** 0.5)
def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
prompt = self.prompt_mlp(prompt_embeddings)
if prompt_mask is not None:
weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1)
else:
pooled = prompt.mean(dim=1)
hs = self.hs_proj(query_embeddings)
pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
scores = torch.matmul(hs, pp)
return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1)
class SAM3Detector(nn.Module):
def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
image_model = kwargs.pop("image_model", "SAM3")
for k in ("num_heads", "num_head_channels"):
kwargs.pop(k, None)
multiplex = image_model == "SAM31"
# SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
self.scalp = 0 if multiplex else 1
self.backbone = nn.ModuleDict({
"vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
"language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
})
self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)
def _get_backbone_features(self, images):
"""Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
bb = self.backbone["vision_backbone"]
if bb.multiplex:
all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
else:
all_f, all_p, tf, tp = bb(images, need_tracker=True)
return all_f, all_p, tf, tp
@staticmethod
def _run_geo_layer(layer, x, memory, memory_pos):
x = x + layer.self_attn(layer.norm1(x))
x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
return x
def _detect(self, features, positions, text_embeddings=None, text_mask=None,
points=None, boxes=None):
"""Shared detection: geometry encoding, transformer, scoring, segmentation."""
B = features[0].shape[0]
# Scalp for encoder (use top-level feature), but keep all levels for segmentation head
seg_features = features
if self.scalp > 0:
features = features[:-self.scalp]
positions = positions[:-self.scalp]
enc_feat, enc_pos = features[-1], positions[-1]
_, _, H, W = enc_feat.shape
img_flat = enc_feat.flatten(2).permute(0, 2, 1)
pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
has_prompts = text_embeddings is not None or points is not None or boxes is not None
if has_prompts:
geo_enc = self.geometry_encoder
geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
for layer in geo_enc.encode:
geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
geo_cls = geo_enc.encode_norm(geo_cls)
if text_embeddings is not None and text_embeddings.shape[0] != B:
text_embeddings = text_embeddings.expand(B, -1, -1)
if text_mask is not None and text_mask.shape[0] != B:
text_mask = text_mask.expand(B, -1)
parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
text_embeddings = torch.cat(parts, dim=1)
n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
if text_mask is not None:
text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
else:
text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
if text_embeddings is not None:
scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
else:
scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out
def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
features, positions, _, _ = self._get_backbone_features(images)
if text_embeddings is not None:
text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, dec_out = self._detect(
features, positions, text_embeddings, text_mask, points, boxes)
if orig_size is not None:
oh, ow = orig_size
boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
return {
"boxes": boxes_xyxy,
"scores": scores,
"masks": masks,
"presence": dec_out.get("presence"),
}
def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
"""Run detection using a pre-computed ViTDet trunk output.
text_embeddings must already be resized through language_backbone.resizer.
Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
"""
bb = self.backbone["vision_backbone"]
features = [conv(trunk_out) for conv in bb.convs]
positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
class SAM3Model(nn.Module):
def __init__(self, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
image_model = kwargs.get("image_model", "SAM3")
tracker_cls = TRACKER_CLASSES[image_model]
self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)
def forward(self, images, **kwargs):
return self.detector(images, **kwargs)
def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
"""Interactive segmentation using SAM decoder with point/box/mask prompts.
Args:
images: [B, 3, 1008, 1008] preprocessed images
point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
mask_inputs: [B, 1, H, W] coarse mask logits to refine
Returns:
[B, 1, image_size, image_size] high-res mask logits
"""
bb = self.detector.backbone["vision_backbone"]
if bb.multiplex:
_, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
else:
_, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
if self.detector.scalp > 0:
tracker_features = tracker_features[:-self.detector.scalp]
tracker_positions = tracker_positions[:-self.detector.scalp]
high_res = list(tracker_features[:-1])
backbone_feat = tracker_features[-1]
B, C, H, W = backbone_feat.shape
# Add no-memory embedding (init frame path)
no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
if no_mem is None:
no_mem = getattr(self.tracker, 'no_mem_embed', None)
if no_mem is not None:
feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
_, high_res_masks, _, _ = self.tracker._forward_sam_heads(
backbone_features=backbone_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
box_inputs=box_inputs,
high_res_features=high_res,
multimask_output=(0 < num_pts <= 1),
)
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
def backbone_fn(frame, frame_idx=None):
trunk_out = bb.trunk(frame)
if bb.multiplex:
_, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
else:
_, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
return tf, tp, trunk_out
detect_fn = None
if text_prompts:
resizer = self.detector.backbone["language_backbone"]["resizer"]
resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
def detect_fn(trunk_out):
all_scores, all_masks = [], []
for emb, mask in resized:
det = self.detector.forward_from_trunk(trunk_out, emb, mask)
all_scores.append(det["scores"])
all_masks.append(det["masks"])
return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
if hasattr(self.tracker, 'track_video_with_detection'):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)

425
comfy/ldm/sam3/sam.py Normal file
View File

@ -0,0 +1,425 @@
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.layers import EmbedND
from comfy.ops import cast_to_input
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
return torch.sigmoid(x) if self.sigmoid_output else x
class SAMAttention(nn.Module):
def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
internal_dim = embedding_dim // downsample_rate
kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)
def forward(self, q, k, v):
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
return self.out_proj(optimized_attention(q, k, v, self.num_heads))
class TwoWayAttentionBlock(nn.Module):
def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
super().__init__()
self.skip_first_layer_pe = skip_first_layer_pe
self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, queries, keys, query_pe, key_pe):
if self.skip_first_layer_pe:
queries = self.norm1(self.self_attn(queries, queries, queries))
else:
q = queries + query_pe
queries = self.norm1(queries + self.self_attn(q, q, queries))
q, k = queries + query_pe, keys + key_pe
queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
queries = self.norm3(queries + self.mlp(queries))
q, k = queries + query_pe, keys + key_pe
keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
return queries, keys
class TwoWayTransformer(nn.Module):
def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
for i in range(depth)
])
self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, image_embedding, image_pe, point_embedding):
queries, keys = point_embedding, image_embedding
for layer in self.layers:
queries, keys = layer(queries, keys, point_embedding, image_pe)
q, k = queries + point_embedding, keys + image_pe
queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
return queries, keys
class PositionEmbeddingRandom(nn.Module):
"""Fourier feature positional encoding with random gaussian projection."""
def __init__(self, num_pos_feats=64, scale=None):
super().__init__()
self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))
def _encode(self, normalized_coords):
"""Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
orig_dtype = normalized_coords.dtype
proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)
def forward(self, size, device=None):
h, w = size
dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
ones = torch.ones((h, w), device=dev, dtype=torch.float32)
norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)
def forward_with_coords(self, pixel_coords, image_size):
norm = pixel_coords.clone()
norm[:, :, 0] /= image_size[1]
norm[:, :, 1] /= image_size[0]
return self._encode(norm)
# ViTDet backbone + FPN neck
def window_partition(x: torch.Tensor, window_size: int):
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
"""Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
t = torch.arange(end_x * end_y, dtype=torch.float32)
ids = torch.stack([(t % end_x) * scale_pos,
torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))
class _ViTMLP(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
super().__init__()
hidden = int(dim * mlp_ratio)
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class Attention(nn.Module):
"""ViTDet multi-head attention with fused QKV projection."""
def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.use_rope = use_rope
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, freqs_cis=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
if self.use_rope and freqs_cis is not None:
q, k = apply_rope(q, k, freqs_cis)
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
def forward(self, x, freqs_cis=None):
shortcut = x
x = self.norm1(x)
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = x.view(x.shape[0], self.window_size * self.window_size, -1)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
else:
B, H, W, C = x.shape
x = x.view(B, H * W, C)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(B, H, W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchEmbed(nn.Module):
def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
super().__init__()
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.proj(x)
class ViTDet(nn.Module):
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.global_att_blocks = set(global_att_blocks)
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
grid_size = img_size // patch_size
pretrain_grid = pretrain_img_size // patch_size
self.blocks = nn.ModuleList()
for i in range(depth):
is_global = i in self.global_att_blocks
self.blocks.append(Block(
embed_dim, num_heads, mlp_ratio, qkv_bias,
window_size=0 if is_global else window_size,
use_rope=use_rope,
device=device, dtype=dtype, operations=operations,
))
if use_rope:
rope_scale = pretrain_grid / grid_size
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
else:
self.freqs_cis = None
self.freqs_cis_window = None
def _get_pos_embed(self, num_tokens):
pos = self.pos_embed
if pos.shape[1] == num_tokens:
return pos
cls_pos = pos[:, :1]
spatial_pos = pos[:, 1:]
old_size = int(math.sqrt(spatial_pos.shape[1]))
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
tiles_h = new_size // old_size + 1
tiles_w = new_size // old_size + 1
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
return torch.cat([cls_pos, tiled], dim=1)
def forward(self, x):
x = self.patch_embed(x)
B, C, Hp, Wp = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
x = x + pos[:, 1:Hp * Wp + 1]
x = x.view(B, Hp, Wp, C)
x = self.ln_pre(x)
freqs_cis_global = self.freqs_cis
freqs_cis_win = self.freqs_cis_window
if freqs_cis_global is not None:
freqs_cis_global = cast_to_input(freqs_cis_global, x)
if freqs_cis_win is not None:
freqs_cis_win = cast_to_input(freqs_cis_win, x)
for block in self.blocks:
fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
x = block(x, freqs_cis=fc)
return x.permute(0, 3, 1, 2)
class FPNScaleConv(nn.Module):
def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
super().__init__()
if scale == 4.0:
self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 4
elif scale == 2.0:
self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 2
elif scale == 1.0:
proj_in = in_dim
elif scale == 0.5:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
proj_in = in_dim
self.scale = scale
self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)
def forward(self, x):
if self.scale == 4.0:
x = F.gelu(self.dconv_2x2_0(x))
x = self.dconv_2x2_1(x)
elif self.scale == 2.0:
x = self.dconv_2x2(x)
elif self.scale == 0.5:
x = self.pool(x)
x = self.conv_1x1(x)
x = self.conv_3x3(x)
return x
class PositionEmbeddingSine(nn.Module):
"""2D sinusoidal position encoding (DETR-style) with result caching."""
def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
super().__init__()
assert num_pos_feats % 2 == 0
self.half_dim = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
self.scale = scale if scale is not None else 2 * math.pi
self._cache = {}
def _sincos(self, vals):
"""Encode 1D values to interleaved sin/cos features."""
freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
raw = vals[..., None] * self.scale / freqs
return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)
def _encode_xy(self, x, y):
"""Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
pos_x = x[:, None] * self.scale / dim_t
pos_y = y[:, None] * self.scale / dim_t
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y
def encode_boxes(self, cx, cy, w, h):
"""Encode box center + size to [N, d_model+2] features."""
pos_x, pos_y = self._encode_xy(cx, cy)
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
def forward(self, x):
B, C, H, W = x.shape
key = (H, W, x.device)
if key not in self._cache:
gy = torch.arange(H, dtype=torch.float32, device=x.device)
gx = torch.arange(W, dtype=torch.float32, device=x.device)
if self.normalize:
gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
yy, xx = torch.meshgrid(gy, gx, indexing="ij")
self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
return self._cache[key].expand(B, -1, -1, -1)
class SAM3VisionBackbone(nn.Module):
def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.multiplex = multiplex
fpn_args = dict(device=device, dtype=dtype, operations=operations)
if multiplex:
scales = [4.0, 2.0, 1.0]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
else:
scales = [4.0, 2.0, 1.0, 0.5]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
if tracker_only:
# Skip detector FPN when only tracker features are needed (video tracking)
if self.multiplex:
tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
else:
tracker_convs = self.sam2_convs
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return None, None, tracker_features, tracker_positions
features = [conv(backbone_out) for conv in self.convs]
positions = [cast_to_input(self.position_encoding(f), f) for f in features]
if self.multiplex:
if tracker_mode == "propagation":
tracker_convs = self.propagation_convs
elif tracker_mode == "interactive":
tracker_convs = self.interactive_convs
else:
return features, positions, None, None
elif need_tracker:
tracker_convs = self.sam2_convs
else:
return features, positions, None, None
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return features, positions, tracker_features, tracker_positions

1786
comfy/ldm/sam3/tracker.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -54,6 +54,7 @@ import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.model_management
import comfy.patcher_extension
@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)

View File

@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "ernie"
return dit_config
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "SAM3"
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
dit_config["image_model"] = "SAM31"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config
def unet_prefix_from_state_dict(state_dict):
# SAM3: detector.* and tracker.* at top level, no common prefix
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
return ""
candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models
"net.", #cosmos

View File

@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
class SAM3(supported_models_base.BASE):
unet_config = {"image_model": "SAM3"}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
unet_extra_prefix = ""
def process_clip_state_dict(self, state_dict):
clip_keys = getattr(self, "_clip_stash", {})
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
def process_unet_state_dict(self, state_dict):
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
# SAM3.1: remap tracker.model.* -> tracker.*
for k in list(state_dict.keys()):
if k.startswith("tracker.model."):
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
state_dict.pop(k)
# Split fused QKV projections
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
t = state_dict.pop(k)
base, suffix = k.rsplit(".in_proj_", 1)
s = ".weight" if suffix == "weight" else ".bias"
d = t.shape[0] // 3
state_dict[base + ".q_proj" + s] = t[:d]
state_dict[base + ".k_proj" + s] = t[d:2*d]
state_dict[base + ".v_proj" + s] = t[2*d:]
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
for k in list(state_dict.keys()):
if "sam_mask_decoder.transformer." not in k:
continue
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
return state_dict
def get_model(self, state_dict, prefix="", device=None):
return model_base.SAM3(self, device=device)
def clip_target(self, state_dict={}):
import comfy.text_encoders.sam3_clip
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
class SAM31(SAM3):
unet_config = {"image_model": "SAM31"}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
models += [SVD_img2vid]

View File

@ -0,0 +1,97 @@
import re
from comfy import sd1_clip
SAM3_CLIP_CONFIG = {
"architectures": ["CLIPTextModel"],
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"max_position_embeddings": 32,
"projection_dim": 512,
"vocab_size": 49408,
"layer_norm_eps": 1e-5,
"eos_token_id": 49407,
}
class SAM3ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
class SAM3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
self.disable_weights = True
def _parse_prompts(text):
"""Split comma-separated prompts with optional :N max detections per category"""
text = text.replace("(", "").replace(")", "")
parts = [p.strip() for p in text.split(",") if p.strip()]
result = []
for part in parts:
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
if m:
text_part = m.group(1).strip()
val = m.group(2)
max_det = max(1, round(float(val)))
result.append((text_part, max_det))
else:
result.append((part, 1))
return result
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
parsed = _parse_prompts(text)
if len(parsed) <= 1 and parsed[0][1] == 1:
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
# Tokenize each prompt part separately, store per-part batches and metadata
inner = getattr(self, self.clip)
per_prompt = []
for prompt_text, max_det in parsed:
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
per_prompt.append((batches, max_det))
# Main output uses first prompt's tokens (for compatibility)
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
return out
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
def encode_token_weights(self, token_weight_pairs):
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
if per_prompt is None:
return super().encode_token_weights(token_weight_pairs)
# Encode each prompt separately, pack into extra dict
inner = getattr(self, self.clip)
multi_cond = []
first_pooled = None
for batches, max_det in per_prompt:
out = inner.encode_token_weights(batches)
cond, pooled = out[0], out[1]
extra = out[2] if len(out) > 2 else {}
if first_pooled is None:
first_pooled = pooled
multi_cond.append({
"cond": cond,
"attention_mask": extra.get("attention_mask"),
"max_detections": max_det,
})
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
main = multi_cond[0]
main_extra = {}
if main["attention_mask"] is not None:
main_extra["attention_mask"] = main["attention_mask"]
main_extra["sam3_multi_cond"] = multi_cond
return (main["cond"], first_pooled, main_extra)

514
comfy_extras/nodes_sam3.py Normal file
View File

@ -0,0 +1,514 @@
"""
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
"""
from typing_extensions import override
import json
import os
import torch
import torch.nn.functional as F
import comfy.model_management
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io, ui
import av
from fractions import Fraction
def _extract_text_prompts(conditioning, device, dtype):
"""Extract list of (text_embeddings, text_mask) from conditioning."""
cond_meta = conditioning[0][1]
multi = cond_meta.get("sam3_multi_cond")
prompts = []
if multi is not None:
for entry in multi:
emb = entry["cond"].to(device=device, dtype=dtype)
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
if mask is None:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, entry.get("max_detections", 1)))
else:
emb = conditioning[0][0].to(device=device, dtype=dtype)
mask = cond_meta.get("attention_mask")
if mask is not None:
mask = mask.to(device)
else:
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
prompts.append((emb, mask, 1))
return prompts
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
"""Refine a coarse detector mask via SAM decoder, cropping to the detection box.
Returns: [1, H, W] binary mask
"""
def _coarse_fallback():
return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
mode="bilinear", align_corners=False)[0] > 0).float()
if iterations <= 0:
return _coarse_fallback()
pad_frac = 0.1
x1, y1, x2, y2 = box_xyxy.tolist()
bw, bh = x2 - x1, y2 - y1
cx1 = max(0, int(x1 - bw * pad_frac))
cy1 = max(0, int(y1 - bh * pad_frac))
cx2 = min(W, int(x2 + bw * pad_frac))
cy2 = min(H, int(y2 + bh * pad_frac))
if cx2 <= cx1 or cy2 <= cy1:
return _coarse_fallback()
crop = orig_image_hwc[cy1:cy2, cx1:cx2]
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
crop_frame = crop_1008.to(device=device, dtype=dtype)
crop_h, crop_w = cy2 - cy1, cx2 - cx1
# Crop coarse mask and refine via SAM on the cropped image
mask_h, mask_w = coarse_mask.shape[-2:]
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
for _ in range(iterations):
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input)
refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop
coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float()
class SAM3_Detect(io.ComfyNode):
"""Open-vocabulary detection and segmentation using text, box, or point prompts."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_Detect",
display_name="SAM3 Detect",
category="detection/",
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
inputs=[
io.Model.Input("model", display_name="model"),
io.Image.Input("image", display_name="image"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
],
outputs=[
io.Mask.Output("masks"),
io.BoundingBox.Output("bboxes"),
],
)
@classmethod
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
B, H, W, C = image.shape
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
# Convert bboxes to normalized cxcywh format [1, N, 4]
# BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame)
boxes_tensor = None
if bboxes is not None:
# Flatten to list of dicts
if isinstance(bboxes, dict):
flat_boxes = [bboxes]
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
flat_boxes = [d for frame in bboxes for d in frame] # per-frame list of lists
elif isinstance(bboxes, list):
flat_boxes = bboxes
else:
flat_boxes = []
if flat_boxes:
coords = []
for d in flat_boxes:
cx = (d["x"] + d["width"] / 2) / W
cy = (d["y"] + d["height"] / 2) / H
coords.append([cx, cy, d["width"] / W, d["height"] / H])
boxes_tensor = torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
pos_pts = json.loads(positive_coords) if positive_coords else []
neg_pts = json.loads(negative_coords) if negative_coords else []
has_points = len(pos_pts) > 0 or len(neg_pts) > 0
comfy.model_management.load_model_gpu(model)
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
# Build point inputs for tracker SAM decoder path
point_inputs = None
if has_points:
all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
[[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
point_inputs = {
"point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
"point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
}
cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
has_text = len(cond_list) > 0
# Run per-image through detector (text/boxes) and/or tracker (points)
all_bbox_dicts = []
all_masks = []
pbar = comfy.utils.ProgressBar(B)
b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None
for b in range(B):
frame = image_in[b:b+1].to(device=device, dtype=dtype)
frame_bbox_dicts = []
frame_masks = []
# Point prompts: tracker SAM decoder path with iterative refinement
if point_inputs is not None:
mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
for _ in range(max(0, refine_iterations - 1)):
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
frame_masks.append((mask[0] > 0).float())
# Box prompts: SAM decoder path (segment inside each box)
if b_boxes_tensor is not None and not has_text:
for box_cxcywh in b_boxes_tensor[0]:
cx, cy, bw, bh = box_cxcywh.tolist()
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
[(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
device=device, dtype=dtype)
mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
for _ in range(max(0, refine_iterations - 1)):
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
frame_masks.append((mask[0] > 0).float())
# Text prompts: run detector per text prompt (each detects one category)
for text_embeddings, text_mask, max_det in cond_list:
results = sam3_model(
frame, text_embeddings=text_embeddings, text_mask=text_mask,
boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W))
pred_boxes = results["boxes"][0]
scores = results["scores"][0]
masks = results["masks"][0]
probs = scores.sigmoid()
keep = probs > threshold
kept_boxes = pred_boxes[keep].cpu()
kept_scores = probs[keep].cpu()
kept_masks = masks[keep]
order = kept_scores.argsort(descending=True)[:max_det]
kept_boxes = kept_boxes[order]
kept_scores = kept_scores[order]
kept_masks = kept_masks[order]
for box, score in zip(kept_boxes, kept_scores):
frame_bbox_dicts.append({
"x": float(box[0]), "y": float(box[1]),
"width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
"score": float(score),
})
for m, box in zip(kept_masks, kept_boxes):
frame_masks.append(_refine_mask(
sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
all_bbox_dicts.append(frame_bbox_dicts)
if len(frame_masks) > 0:
combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
if individual_masks:
all_masks.append(combined)
else:
all_masks.append((combined > 0).any(dim=0).float())
else:
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
pbar.update(1)
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
return io.NodeOutput(mask_out, all_bbox_dicts)
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
class SAM3_VideoTrack(io.ComfyNode):
"""Track objects across video frames using SAM3's memory-based tracker."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_VideoTrack",
display_name="SAM3 Video Track",
category="detection/",
search_aliases=["sam3", "video", "track", "propagate"],
inputs=[
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
io.Model.Input("model", display_name="model"),
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
],
outputs=[
SAM3TrackData.Output("track_data", display_name="track_data"),
],
)
@classmethod
def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
N, H, W, C = images.shape
comfy.model_management.load_model_gpu(model)
device = comfy.model_management.get_torch_device()
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
frames = images.movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
init_masks = None
if initial_mask is not None:
init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
pbar = comfy.utils.ProgressBar(N)
text_prompts = None
if conditioning is not None:
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
elif initial_mask is None:
raise ValueError("Either initial_mask or conditioning must be provided")
result = sam3_model.forward_video(
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
new_det_thresh=detection_threshold, max_objects=max_objects,
detect_interval=detect_interval)
result["orig_size"] = (H, W)
return io.NodeOutput(result)
class SAM3_TrackPreview(io.ComfyNode):
"""Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_TrackPreview",
display_name="SAM3 Track Preview",
category="detection/",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.Image.Input("images", display_name="images", optional=True),
io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
],
is_output_node=True,
)
COLORS = [
(0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
(0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
(0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
]
# 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
_glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
@staticmethod
def _get_glyphs(device, scale=3):
key = (device, scale)
if key in SAM3_TrackPreview._glyph_cache:
return SAM3_TrackPreview._glyph_cache[key]
atlas = torch.tensor([
[[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
[[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
[[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
[[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
[[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
[[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
[[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
[[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
[[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
[[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
], dtype=torch.bool)
glyphs, outlines = [], []
for d in range(10):
g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
glyphs.append(g.to(device))
outlines.append(o.to(device))
gh, gw = glyphs[0].shape
oh, ow = outlines[0].shape
SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
return SAM3_TrackPreview._glyph_cache[key]
@staticmethod
def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
"""Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
H, W = frame.shape[:2]
device = frame.device
glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
color_t = torch.tensor(color, device=device, dtype=frame.dtype)
digs = [int(d) for d in str(number)]
total_w = len(digs) * (gw + scale) - scale
x0 = cx - total_w // 2
y0 = cy - gh // 2
for i, d in enumerate(digs):
dx = x0 + i * (gw + scale)
# Black outline
oy0, ox0 = y0 - 1, dx - 1
osy1, osx1 = max(0, -oy0), max(0, -ox0)
osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
if osy2 > osy1 and osx2 > osx1:
fy1, fx1 = oy0 + osy1, ox0 + osx1
frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
# Colored fill
sy1, sx1 = max(0, -y0), max(0, -dx)
sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
if sy2 > sy1 and sx2 > sx1:
fy1, fx1 = y0 + sy1, dx + sx1
frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
@classmethod
def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
from comfy.ldm.sam3.tracker import unpack_masks
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
if images is not None:
H, W = images.shape[1], images.shape[2]
if packed is None:
N, N_obj = track_data["n_frames"], 0
else:
N, N_obj = packed.shape[0], packed.shape[1]
gpu = comfy.model_management.get_torch_device()
temp_dir = folder_paths.get_temp_directory()
filename = "sam3_track_preview.mp4"
filepath = os.path.join(temp_dir, filename)
with av.open(filepath, mode='w') as output:
stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
stream.width = W
stream.height = H
stream.pix_fmt = 'yuv420p'
frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
frame_np = frame_cpu.numpy()
if N_obj > 0:
colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
device=gpu, dtype=torch.float32)
grid_y = torch.arange(H, device=gpu).view(1, H, 1)
grid_x = torch.arange(W, device=gpu).view(1, 1, W)
for t in range(N):
if images is not None:
frame = images[t].clone()
else:
frame = torch.zeros(H, W, 3)
if N_obj > 0:
frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
frame_gpu = frame.to(gpu)
bool_masks = frame_masks > 0.5
any_mask = bool_masks.any(dim=0)
if any_mask.any():
obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
color_overlay = colors_t[obj_idx_map]
mask_3d = any_mask.unsqueeze(-1)
frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
has = area > 1
scores = track_data.get("scores", [])
for obj_idx in range(N_obj):
if has[obj_idx]:
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
color = cls.COLORS[obj_idx % len(cls.COLORS)]
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
_cx, _cy + 5 * 3 + 3, color, scale=2)
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
else:
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
output.mux(stream.encode(vframe.reformat(format='yuv420p')))
output.mux(stream.encode(None))
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
class SAM3_TrackToMask(io.ComfyNode):
"""Select tracked objects by index and output as mask."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SAM3_TrackToMask",
display_name="SAM3 Track to Mask",
category="detection/",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.String.Input("object_indices", display_name="object_indices", default="",
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
],
outputs=[
io.Mask.Output("masks", display_name="masks"),
],
)
@classmethod
def execute(cls, track_data, object_indices="") -> io.NodeOutput:
from comfy.ldm.sam3.tracker import unpack_masks
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
if packed is None:
N = track_data["n_frames"]
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
N, N_obj = packed.shape[0], packed.shape[1]
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
indices = [i for i in indices if 0 <= i < N_obj]
else:
indices = list(range(N_obj))
if not indices:
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
selected = packed[:, indices]
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
union = binary.any(dim=1, keepdim=True).float()
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
return io.NodeOutput(mask_out)
class SAM3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SAM3_Detect,
SAM3_VideoTrack,
SAM3_TrackPreview,
SAM3_TrackToMask,
]
async def comfy_entrypoint() -> SAM3Extension:
return SAM3Extension()

View File

@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_rtdetr.py"
"nodes_rtdetr.py",
"nodes_sam3.py"
]
import_failed = []