ComfyUI/comfy/ldm/sam3/sam.py
Jedrzej Kosinski 1b96430c60
Merge master into worksplit-multigpu (#13546)
* fix: pin SQLAlchemy>=2.0 in requirements.txt (fixes #13036) (#13316)

* Refactor io to IO in nodes_ace.py (#13485)

* Bump comfyui-frontend-package to 1.42.12 (#13489)

* Make the ltx audio vae more native. (#13486)

* feat(api-nodes): add automatic downscaling of videos for ByteDance 2 nodes (#13465)

* Support standalone LTXV audio VAEs (#13499)

* [Partner Nodes]  added 4K resolution for Veo models; added Veo 3 Lite model (#13330)

* feat(api nodes): added 4K resolution for Veo models; added Veo 3 Lite model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* increase poll_interval from 5 to 9

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* Bump comfyui-frontend-package to 1.42.14 (#13493)

* Add gpt-image-2 as version option (#13501)

* Allow logging in comfy app files. (#13505)

* chore: update workflow templates to v0.9.59 (#13507)

* fix(veo): reject 4K resolution for veo-3.0 models in Veo3VideoGenerationNode (#13504)

The tooltip on the resolution input states that 4K is not available for
veo-3.1-lite or veo-3.0 models, but the execute guard only rejected the
lite combination. Selecting 4K with veo-3.0-generate-001 or
veo-3.0-fast-generate-001 would fall through and hit the upstream API
with an invalid request.

Broaden the guard to match the documented behavior and update the error
message accordingly.

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* feat: RIFE and FILM frame interpolation model support (CORE-29) (#13258)

* initial RIFE support

* Also support FILM

* Better RAM usage, reduce FILM VRAM peak

* Add model folder placeholder

* Fix oom fallback frame loss

* Remove torch.compile for now

* Rename model input

* Shorter input type name

---------

* fix: use Parameter assignment for Stable_Zero123 cc_projection weights (fixes #13492) (#13518)

On Windows with aimdo enabled, disable_weight_init.Linear uses lazy
initialization that sets weight and bias to None to avoid unnecessary
memory allocation. This caused a crash when copy_() was called on the
None weight attribute in Stable_Zero123.__init__.

Replace copy_() with direct torch.nn.Parameter assignment, which works
correctly on both Windows (aimdo enabled) and other platforms.

* Derive InterruptProcessingException from BaseException (#13523)

* bump manager version to 4.2.1 (#13516)

* ModelPatcherDynamic: force cast stray weights on comfy layers (#13487)

The mixed_precision ops can have input_scale parameters that are used
in tensor math but aren't a weight or bias, so they don't get proper VRAM
management. Treat these as force-castable parameters, like the non-comfy
weights, random params and buffers already are.

* Update logging level for invalid version format (#13526)

* [Partner Nodes] add SD2 real human support (#13509)

* feat(api-nodes): add SD2 real human support

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* fix: add validation before uploading Assets

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* Add asset_id and group_id displaying on the node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* extend poll_op to use instead of custom async cycle

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* added the polling for the "Active" status after asset creation

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* updated tooltip for group_id

* allow usage of real human in the ByteDance2FirstLastFrame node

* add reference count limits

* corrected price in status when input assets contain video

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: SAM (segment anything) 3.1 support (CORE-34) (#13408)

* [Partner Nodes] GPTImage: fix price badges, add new resolutions (#13519)

* fix(api-nodes): fixed price badges, add new resolutions

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* properly calculate the total run cost when "n > 1"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* chore: update workflow templates to v0.9.61 (#13533)

* chore: update embedded docs to v0.4.4 (#13535)

* add 4K resolution to Kling nodes (#13536)

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* Fix LTXV Reference Audio node (#13531)

* comfy-aimdo 0.2.14: Hotfix async allocator estimations (#13534)

This was over-estimating the VRAM used by the async allocator when lots
of small tensors were in play.

Also change the versioning scheme to == so we can roll forward aimdo without
worrying about stable regressions downstream in comfyUI core.

* Disable sageattention for SAM3 (#13529)

Causes NaNs.

* execution: Add anti-cycle validation (#13169)

Currently, if the graph contains a cycle, validation just recurses infinitely,
hits a catch-all, then throws a generic error against the output node
that seeded the validation. Instead, fail the offending cycling node
chain and handle it as an error in its own right.

Co-authored-by: guill <jacob.e.segal@gmail.com>

* chore: update workflow templates to v0.9.62 (#13539)

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Octopus <liyuan851277048@icloud.com>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Comfy Org PR Bot <snomiao+comfy-pr@gmail.com>
Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Co-authored-by: AustinMroz <austin@comfy.org>
Co-authored-by: Daxiong (Lin) <contact@comfyui-wiki.com>
Co-authored-by: Matt Miller <matt@miller-media.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>
Co-authored-by: Dr.Lt.Data <128333288+ltdrdata@users.noreply.github.com>
Co-authored-by: rattus <46076784+rattus128@users.noreply.github.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-04-23 19:20:14 -07:00

426 lines
20 KiB
Python

# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.layers import EmbedND
from comfy.ops import cast_to_input


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
        return torch.sigmoid(x) if self.sigmoid_output else x


class SAMAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        internal_dim = embedding_dim // downsample_rate
        kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
        self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
        self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
        self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)

    def forward(self, q, k, v):
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)
        return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))


class TwoWayAttentionBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.skip_first_layer_pe = skip_first_layer_pe
        self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
        self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
        self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
        self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)

    def forward(self, queries, keys, query_pe, key_pe):
        if self.skip_first_layer_pe:
            queries = self.norm1(self.self_attn(queries, queries, queries))
        else:
            q = queries + query_pe
            queries = self.norm1(queries + self.self_attn(q, q, queries))
        q, k = queries + query_pe, keys + key_pe
        queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
        queries = self.norm3(queries + self.mlp(queries))
        q, k = queries + query_pe, keys + key_pe
        keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
        return queries, keys


class TwoWayTransformer(nn.Module):
    def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
        super().__init__()
        self.layers = nn.ModuleList([
            TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
                                 skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
            for i in range(depth)
        ])
        self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
        self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)

    def forward(self, image_embedding, image_pe, point_embedding):
        queries, keys = point_embedding, image_embedding
        for layer in self.layers:
            queries, keys = layer(queries, keys, point_embedding, image_pe)
        q, k = queries + point_embedding, keys + image_pe
        queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
        return queries, keys


class PositionEmbeddingRandom(nn.Module):
    """Fourier feature positional encoding with random gaussian projection."""
    def __init__(self, num_pos_feats=64, scale=None):
        super().__init__()
        self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))

    def _encode(self, normalized_coords):
        """Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
        orig_dtype = normalized_coords.dtype
        proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
        projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
        return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)

    def forward(self, size, device=None):
        h, w = size
        dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((h, w), device=dev, dtype=torch.float32)
        norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
        return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)

    def forward_with_coords(self, pixel_coords, image_size):
        norm = pixel_coords.clone()
        norm[:, :, 0] /= image_size[1]
        norm[:, :, 1] /= image_size[0]
        return self._encode(norm)


# ViTDet backbone + FPN neck


def window_partition(x: torch.Tensor, window_size: int):
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
    """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    ids = torch.stack([(t % end_x) * scale_pos,
                       torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
    return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))


class _ViTMLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
        self.act = nn.GELU()
        self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class Attention(nn.Module):
    """ViTDet multi-head attention with fused QKV projection."""
    def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.use_rope = use_rope
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
        self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)

    def forward(self, x, freqs_cis=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
        if self.use_rope and freqs_cis is not None:
            q, k = apply_rope(q, k, freqs_cis)
        return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.window_size = window_size
        self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
        self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
        self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
        self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)

    def forward(self, x, freqs_cis=None):
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
            x = x.view(x.shape[0], self.window_size * self.window_size, -1)
            x = self.attn(x, freqs_cis=freqs_cis)
            x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        else:
            B, H, W, C = x.shape
            x = x.view(B, H * W, C)
            x = self.attn(x, freqs_cis=freqs_cis)
            x = x.view(B, H, W, C)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x


class PatchEmbed(nn.Module):
    def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
        super().__init__()
        self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj(x)


class ViTDet(nn.Module):
    def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
                 global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.global_att_blocks = set(global_att_blocks)
        self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
        num_patches = (pretrain_img_size // patch_size) ** 2 + 1  # +1 for cls token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
        self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
        grid_size = img_size // patch_size
        pretrain_grid = pretrain_img_size // patch_size
        self.blocks = nn.ModuleList()
        for i in range(depth):
            is_global = i in self.global_att_blocks
            self.blocks.append(Block(
                embed_dim, num_heads, mlp_ratio, qkv_bias,
                window_size=0 if is_global else window_size,
                use_rope=use_rope,
                device=device, dtype=dtype, operations=operations,
            ))
        if use_rope:
            rope_scale = pretrain_grid / grid_size
            self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
            self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
        else:
            self.freqs_cis = None
            self.freqs_cis_window = None

    def _get_pos_embed(self, num_tokens):
        pos = self.pos_embed
        if pos.shape[1] == num_tokens:
            return pos
        cls_pos = pos[:, :1]
        spatial_pos = pos[:, 1:]
        old_size = int(math.sqrt(spatial_pos.shape[1]))
        new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
        spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
        tiles_h = new_size // old_size + 1
        tiles_w = new_size // old_size + 1
        tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
        tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
        return torch.cat([cls_pos, tiled], dim=1)

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, Hp, Wp = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
        pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
        x = x + pos[:, 1:Hp * Wp + 1]
        x = x.view(B, Hp, Wp, C)
        x = self.ln_pre(x)
        freqs_cis_global = self.freqs_cis
        freqs_cis_win = self.freqs_cis_window
        if freqs_cis_global is not None:
            freqs_cis_global = cast_to_input(freqs_cis_global, x)
        if freqs_cis_win is not None:
            freqs_cis_win = cast_to_input(freqs_cis_win, x)
        for block in self.blocks:
            fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
            x = block(x, freqs_cis=fc)
        return x.permute(0, 3, 1, 2)


class FPNScaleConv(nn.Module):
    def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
        super().__init__()
        if scale == 4.0:
            self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
            self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
            proj_in = in_dim // 4
        elif scale == 2.0:
            self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
            proj_in = in_dim // 2
        elif scale == 1.0:
            proj_in = in_dim
        elif scale == 0.5:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            proj_in = in_dim
        self.scale = scale
        self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
        self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)

    def forward(self, x):
        if self.scale == 4.0:
            x = F.gelu(self.dconv_2x2_0(x))
            x = self.dconv_2x2_1(x)
        elif self.scale == 2.0:
            x = self.dconv_2x2(x)
        elif self.scale == 0.5:
            x = self.pool(x)
        x = self.conv_1x1(x)
        x = self.conv_3x3(x)
        return x


class PositionEmbeddingSine(nn.Module):
    """2D sinusoidal position encoding (DETR-style) with result caching."""
    def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
        super().__init__()
        assert num_pos_feats % 2 == 0
        self.half_dim = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi
        self._cache = {}

    def _sincos(self, vals):
        """Encode 1D values to interleaved sin/cos features."""
        freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
        raw = vals[..., None] * self.scale / freqs
        return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)

    def _encode_xy(self, x, y):
        """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
        dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
        pos_x = x[:, None] * self.scale / dim_t
        pos_y = y[:, None] * self.scale / dim_t
        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
        return pos_x, pos_y

    def encode_boxes(self, cx, cy, w, h):
        """Encode box center + size to [N, d_model+2] features."""
        pos_x, pos_y = self._encode_xy(cx, cy)
        return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)

    def forward(self, x):
        B, C, H, W = x.shape
        key = (H, W, x.device)
        if key not in self._cache:
            gy = torch.arange(H, dtype=torch.float32, device=x.device)
            gx = torch.arange(W, dtype=torch.float32, device=x.device)
            if self.normalize:
                gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
            yy, xx = torch.meshgrid(gy, gx, indexing="ij")
            self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
        return self._cache[key].expand(B, -1, -1, -1)


class SAM3VisionBackbone(nn.Module):
    def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
        self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
        self.multiplex = multiplex
        fpn_args = dict(device=device, dtype=dtype, operations=operations)
        if multiplex:
            scales = [4.0, 2.0, 1.0]
            self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
            self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
            self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
        else:
            scales = [4.0, 2.0, 1.0, 0.5]
            self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
            self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])

    def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
        backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
        if tracker_only:
            # Skip detector FPN when only tracker features are needed (video tracking)
            if self.multiplex:
                tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
            else:
                tracker_convs = self.sam2_convs
            tracker_features = [conv(backbone_out) for conv in tracker_convs]
            tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
            return None, None, tracker_features, tracker_positions
        features = [conv(backbone_out) for conv in self.convs]
        positions = [cast_to_input(self.position_encoding(f), f) for f in features]
        if self.multiplex:
            if tracker_mode == "propagation":
                tracker_convs = self.propagation_convs
            elif tracker_mode == "interactive":
                tracker_convs = self.interactive_convs
            else:
                return features, positions, None, None
        elif need_tracker:
            tracker_convs = self.sam2_convs
        else:
            return features, positions, None, None
        tracker_features = [conv(backbone_out) for conv in tracker_convs]
        tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
        return features, positions, tracker_features, tracker_positions
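Much of the work in the MLP and SAMAttention constructors above is dimension bookkeeping. The plain-Python sketch below recomputes the same shapes without torch so the arithmetic is easy to check. The helper names mlp_layer_shapes and sam_attention_dims are hypothetical and not part of ComfyUI. The example values assume the TwoWayTransformer defaults (embedding_dim=256, num_heads=8, attention_downsample_rate=2).

```python
def mlp_layer_shapes(input_dim, hidden_dim, output_dim, num_layers):
    # Same dims list as MLP.__init__: input, (num_layers - 1) hidden widths, output
    dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
    return [(dims[i], dims[i + 1]) for i in range(num_layers)]


def sam_attention_dims(embedding_dim, num_heads, downsample_rate=1):
    # SAMAttention projects q/k/v to embedding_dim // downsample_rate,
    # which the attention call then splits evenly across num_heads
    internal_dim = embedding_dim // downsample_rate
    return internal_dim, internal_dim // num_heads


print(mlp_layer_shapes(256, 256, 4, 3))  # [(256, 256), (256, 256), (256, 4)]
print(sam_attention_dims(256, 8, 2))     # cross-attention: (128, 16)
print(sam_attention_dims(256, 8))        # self-attention: (256, 32)
```

In TwoWayAttentionBlock, self-attention keeps the full width. Both cross-attention directions halve it, so each head works in 16 dimensions instead of 32.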
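PositionEmbeddingRandom maps normalized coordinates from [0, 1] to [-1, 1] and projects them through a fixed 2 x num_pos_feats Gaussian matrix. It then scales by 2*pi and concatenates sin and cos. The sketch below mirrors that in plain Python. As an assumption for illustration, a seeded random.Random stands in for the registered torch.randn buffer, and the helper names are hypothetical. The sketch also checks that the dense grid (forward) and the point path (forward_with_coords) agree for a pixel center. This holds because both divide x by the width and y by the height.

```python
import math
import random


def gaussian_matrix(num_pos_feats, scale=1.0, seed=0):
    # Stand-in for positional_encoding_gaussian_matrix: shape 2 x num_pos_feats
    rng = random.Random(seed)
    return [[scale * rng.gauss(0.0, 1.0) for _ in range(num_pos_feats)] for _ in range(2)]


def encode(norm_xy, matrix):
    # [0, 1] -> [-1, 1], project, scale by 2*pi, then concatenate sin and cos features
    x, y = 2 * norm_xy[0] - 1, 2 * norm_xy[1] - 1
    proj = [2 * math.pi * (x * matrix[0][j] + y * matrix[1][j]) for j in range(len(matrix[0]))]
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]


def dense_pe(h, w, matrix):
    # Pixel centers (c + 0.5) / w and (r + 0.5) / h, as built with cumsum in forward()
    return [[encode(((c + 0.5) / w, (r + 0.5) / h), matrix) for c in range(w)] for r in range(h)]


def point_pe(points, image_size, matrix):
    # points are (x, y) in pixels and image_size is (H, W), as in forward_with_coords()
    h, w = image_size
    return [encode((x / w, y / h), matrix) for x, y in points]


m = gaussian_matrix(8)
print(len(encode((0.25, 0.75), m)))  # 16
```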
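window_partition pads the bottom and right edges up to a multiple of window_size and cuts the grid into non-overlapping tiles in row-major order. window_unpartition stitches the tiles back together and crops the padding. The sketch below reproduces that round trip on a single-channel list-of-lists grid (batch size 1) so the padding arithmetic can be checked by hand. The helper names are hypothetical.

```python
def pad_amount(size, window):
    # Same formula as window_partition: smallest pad that makes size a multiple of window
    return (window - size % window) % window


def partition(grid, window):
    """Split an H x W grid into window x window tiles, zero-padding bottom/right."""
    h, w = len(grid), len(grid[0])
    hp, wp = h + pad_amount(h, window), w + pad_amount(w, window)
    padded = [[grid[r][c] if r < h and c < w else 0 for c in range(wp)] for r in range(hp)]
    tiles = []
    for wr in range(hp // window):          # window row
        for wc in range(wp // window):      # window column
            tiles.append([row[wc * window:(wc + 1) * window]
                          for row in padded[wr * window:(wr + 1) * window]])
    return tiles, (hp, wp)


def unpartition(tiles, window, pad_hw, hw):
    """Reassemble row-major tiles into the padded grid, then crop back to hw."""
    hp, wp = pad_hw
    h, w = hw
    per_row = wp // window
    out = [[0] * wp for _ in range(hp)]
    for i, tile in enumerate(tiles):
        wr, wc = divmod(i, per_row)
        for r in range(window):
            for c in range(window):
                out[wr * window + r][wc * window + c] = tile[r][c]
    return [row[:w] for row in out[:h]]


grid = [[r * 7 + c for c in range(7)] for r in range(5)]
tiles, pad_hw = partition(grid, 3)
print(pad_hw, len(tiles))  # (6, 9) 6
```

With the ViTDet defaults (1008 / 14 = 72 patches per side, window_size=24), pad_amount(72, 24) is 0. Each windowed block therefore sees exactly 3 x 3 = 9 windows and no padding.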
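rope_2d assigns token t the position (t % end_x, t // end_x) and multiplies both by scale_pos. It then hands the ids to EmbedND with half the head dimension per axis. The sketch below reproduces the ids and the per-axis rotation angles in plain Python. It assumes the flux rope uses one frequency theta ** (-2i / axis_dim) per rotation pair, and the helper names are hypothetical. For the ViTDet defaults, ViTDet.__init__ computes rope_scale = pretrain_grid / grid_size = 24 / 72, so the 72 x 72 global positions are compressed into roughly the 0 to 24 range seen at pretraining. The windowed table (freqs_cis_window) uses unscaled positions 0 to 23.

```python
def rope_ids(end_x, end_y, scale_pos=1.0):
    # Mirrors rope_2d: token t sits at column t % end_x and row t // end_x (row-major)
    return [((t % end_x) * scale_pos, (t // end_x) * scale_pos) for t in range(end_x * end_y)]


def axis_freqs(axis_dim, theta=10000.0):
    # Assumed flux-style schedule: one frequency per rotation pair
    return [theta ** (-(2 * i) / axis_dim) for i in range(axis_dim // 2)]


def rope_angles(pos_xy, head_dim, theta=10000.0):
    # Each axis gets head_dim // 2 channels (head_dim // 4 pairs); x angles first, then y
    freqs = axis_freqs(head_dim // 2, theta)
    x, y = pos_xy
    return [x * f for f in freqs] + [y * f for f in freqs]


head_dim = 1024 // 16                       # embed_dim // num_heads in ViTDet
print(len(rope_angles((1.0, 2.0), head_dim)))  # 32 rotation pairs == head_dim // 2
```

The 32 angles line up with the dim//2 axis of the [1, 1, HW, dim//2, 2, 2] shape in the rope_2d docstring.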
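ViTDet stores an absolute pos_embed for the pretraining resolution: (336 // 14) ** 2 + 1 = 577 entries, including the cls token. When the patch grid is larger, _get_pos_embed repeats the square spatial grid and crops it. It does not interpolate, so new position (r, c) reuses old position (r % old, c % old). The sketch below shows the spatial part on a list-of-lists grid. tile_pos_embed is a hypothetical name, and the cls entry, which the real method keeps unchanged in front, is left out.

```python
def tile_pos_embed(old_grid, new_size):
    # Over-tile the square grid (new_size // old + 1 copies per axis), then crop,
    # mirroring spatial_2d.tile(...)[:, :, :new_size, :new_size] in _get_pos_embed
    old = len(old_grid)
    reps = new_size // old + 1
    rows = [row * reps for row in old_grid] * reps
    return [row[:new_size] for row in rows[:new_size]]


print(tile_pos_embed([[1, 2], [3, 4]], 3))  # [[1, 2, 1], [3, 4, 3], [1, 2, 1]]
```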
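With the ViTDet defaults, only the blocks listed in global_att_blocks (window_size=0) attend over the whole patch grid. The other blocks attend within 24 x 24 windows. The arithmetic sketch below counts query-key pairs per head and per image to show what that buys. vitdet_attention_cost is a hypothetical helper, and it assumes a square input, which is what ViTDet's img_size implies.

```python
def vitdet_attention_cost(img_size=1008, patch_size=14, window_size=24, depth=32,
                          global_att_blocks=(7, 15, 23, 31)):
    grid = img_size // patch_size                            # patches per side
    pad = (window_size - grid % window_size) % window_size   # same padding as window_partition
    windows = ((grid + pad) // window_size) ** 2             # square grid assumed
    n_global = len(set(global_att_blocks) & set(range(depth)))
    return {
        "grid": grid,
        "windows": windows,
        "global_pairs": (grid * grid) ** 2,                  # every token attends to every token
        "window_pairs": windows * (window_size * window_size) ** 2,
        "n_global": n_global,
        "n_window": depth - n_global,
    }


print(vitdet_attention_cost())
```

A global block scores 5184 x 5184 = 26,873,856 pairs. A windowed block scores 9 x 576 x 576 = 2,985,984, which is 9x fewer, and 28 of the 32 blocks are windowed.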
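Each FPNScaleConv branch changes the spatial size only in its first stage. Scale 4.0 applies two stride-2 transposed convs, scale 2.0 applies one, scale 1.0 is the identity, and scale 0.5 applies a 2 x 2 max pool. The following conv_1x1 and conv_3x3 (padding=1) keep the size and project to out_dim. The sketch below applies the standard PyTorch output-size formulas to the 72 x 72, 1024-channel ViTDet output. The helper names are hypothetical.

```python
def conv_transpose_out(n, kernel=2, stride=2, padding=0):
    # PyTorch ConvTranspose2d output size with dilation=1 and output_padding=0
    return (n - 1) * stride - 2 * padding + kernel


def max_pool_out(n, kernel=2, stride=2):
    # PyTorch MaxPool2d output size with padding=0, dilation=1, ceil_mode=False
    return (n - kernel) // stride + 1


def fpn_level(n, in_dim, scale):
    """(spatial size, channels entering conv_1x1) for one FPNScaleConv branch."""
    if scale == 4.0:
        return conv_transpose_out(conv_transpose_out(n)), in_dim // 4
    if scale == 2.0:
        return conv_transpose_out(n), in_dim // 2
    if scale == 1.0:
        return n, in_dim
    if scale == 0.5:
        return max_pool_out(n), in_dim
    raise ValueError(f"unsupported scale {scale}")


print([fpn_level(72, 1024, s) for s in (4.0, 2.0, 1.0, 0.5)])
```

After conv_1x1, every level has d_model (256) channels. The non-multiplex backbone therefore yields 288, 144, 72 and 36 pixel feature maps. The multiplex variant drops the 0.5 level.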
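PositionEmbeddingSine._sincos uses a frequency schedule in which channels 2k and 2k+1 share the frequency temperature ** (2k / half_dim). Even channels take sin and odd channels take cos, interleaved. forward normalizes row and column indices by (H - 1 + 1e-6) and (W - 1 + 1e-6) and concatenates the y features before the x features. The sketch below computes one pixel's feature vector in plain Python. The helper names are hypothetical, and it assumes half_dim is even, as it is for the default num_pos_feats=256.

```python
import math


def sincos(val, half_dim, temperature=10000.0, scale=2 * math.pi):
    # Channel i uses frequency temperature ** (2 * (i // 2) / half_dim)
    freqs = [temperature ** (2 * (i // 2) / half_dim) for i in range(half_dim)]
    raw = [val * scale / f for f in freqs]
    out = []
    for i in range(0, half_dim, 2):
        out += [math.sin(raw[i]), math.cos(raw[i + 1])]  # interleave sin/cos
    return out


def sine_pe(row, col, h, w, num_pos_feats=256, normalize=True):
    y, x = float(row), float(col)
    if normalize:
        y, x = y / (h - 1 + 1e-6), x / (w - 1 + 1e-6)
    half = num_pos_feats // 2
    return sincos(y, half) + sincos(x, half)  # y features first, as in forward()


print(sincos(0.0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

forward computes the full H x W map once per (H, W, device) key and serves later calls from self._cache, using expand to add the batch dimension without copying.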
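SAM3VisionBackbone.forward chooses which FPN conv sets to run from multiplex, need_tracker, tracker_mode and tracker_only. The sketch below mirrors that branching and returns attribute names as strings, so the routing can be checked without building the model. select_fpn is a hypothetical helper. Two consequences are easy to miss. In multiplex mode, need_tracker has no effect and only tracker_mode matters. With tracker_only=True in multiplex mode, any tracker_mode other than "propagation" selects interactive_convs.

```python
def select_fpn(multiplex, need_tracker=False, tracker_mode=None, tracker_only=False):
    """Return (detector conv set, tracker conv set) that forward() would run."""
    if tracker_only:
        # Detector FPN is skipped entirely
        if multiplex:
            return None, ("propagation_convs" if tracker_mode == "propagation" else "interactive_convs")
        return None, "sam2_convs"
    if multiplex:
        # need_tracker is ignored here; an unrecognized tracker_mode means no tracker features
        modes = {"propagation": "propagation_convs", "interactive": "interactive_convs"}
        return "convs", modes.get(tracker_mode)
    return "convs", ("sam2_convs" if need_tracker else None)


print(select_fpn(True, need_tracker=True))  # ('convs', None)
```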