"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them
through a cross-attention transformer decoder over (token, image) pairs.

Both adapted from the SAM-style prompt path (Meta, Apache 2.0):
https://github.com/facebookresearch/segment-anything
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from comfy.ldm.cascade.common import LayerNorm2d_op
from comfy.ldm.sam3.sam import PositionEmbeddingRandom
from .transformer import TransformerDecoderLayer


class PromptEncoder(nn.Module):
    """Encodes keypoint and mask prompts for SAM 3D Body's decoder.

    Keypoint prompts become sparse [B, N, embed_dim] tokens: a positional
    encoding from ``PositionEmbeddingRandom`` plus a learned per-label
    embedding. Mask prompts become dense feature maps via ``mask_downscaling``
    (16x spatial downscale), with a learned ``no_mask_embed`` fallback.
    """

    def __init__(
        self,
        embed_dim: int,
        num_body_joints: int,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Args:
            embed_dim: channel width of every prompt embedding.
            num_body_joints: number of joint labels (0 .. num_body_joints - 1);
                each label gets its own learned embedding.
            device, dtype: forwarded to every submodule constructor.
            operations: layer factory namespace (e.g. ComfyUI ops); falls back
                to ``torch.nn`` when None.
        """
        super().__init__()
        ops = operations if operations is not None else nn
        self.embed_dim = embed_dim
        self.num_body_joints = num_body_joints

        # Keypoint prompts. The PE output must be embed_dim wide, because the
        # learned (1, embed_dim) label embeddings are added onto it; it is
        # built with embed_dim // 2, presumably one half each for sin and cos.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        self.point_embeddings = nn.ModuleList(
            [
                ops.Embedding(1, embed_dim, device=device, dtype=dtype)
                for _ in range(self.num_body_joints)
            ]
        )
        self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
        self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)

        # Mask prompt: four 2x2 stride-2 convs (16x spatial downscale overall),
        # each followed by LayerNorm2d + GELU, then a 1x1 conv projecting to
        # embed_dim. Layer order fixes the state-dict indices; keep it as is.
        LN2d = LayerNorm2d_op(ops)
        mask_in_chans = 256
        self.mask_downscaling = nn.Sequential(
            ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 64, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 16, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 4, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
        )
        # Dense embedding used when no mask prompt is given; its trained value
        # is loaded from the state dict.
        self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)

    def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
        """Positional encoding over the image-embedding grid; (1, C, H, W)."""
        return self.pe_layer(size)

    def _embed_keypoints(
        self, points: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed keypoint prompts.

        Args:
            points: [B, N, 2] coordinates, assumed normalized to [0, 1].
            labels: [B, N] labels. A value in [0, num_body_joints) is a joint
                index, -1 means "not a point", and -2 marks an invalid point.

        Returns:
            (embeddings, mask) with shapes [B, N, C] and [B, N]. The mask is
            True for every point whose label is not -2.
        """
        # Check the range only on non-empty input: .min()/.max() raise on an
        # empty tensor, so an N == 0 prompt would otherwise crash here.
        if points.numel() > 0:
            assert points.min() >= 0 and points.max() <= 1

        # PE compute in fp32 for precision (sin/cos of large coords), then cast
        # back to the embedding weight dtype.
        weight_dtype = self.invalid_point_embed.weight.dtype
        point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype)

        # Sentinel labels discard the positional encoding entirely and use
        # only their learned embedding...
        point_embedding[labels == -2] = 0.0  # invalid points
        point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding)
        # ...while joint labels keep the PE and add their joint embedding.
        for i in range(self.num_body_joints):
            point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding)

        point_mask = labels > -2
        return point_embedding, point_mask

    def _get_batch_size(
        self,
        keypoints: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """Batch size from the first prompt that is present; 1 if none are."""
        if keypoints is not None:
            return keypoints.shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def forward(
        self,
        keypoints: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed the sparse (keypoint) prompts.

        Only keypoints are embedded here. ``boxes`` and ``masks`` are used
        solely to infer the batch size and the output device; dense mask
        embeddings come from ``get_mask_embeddings``.

        Arguments:
            keypoints (torch.Tensor or None): [B, N, D] tensor with the
                coordinates (in [0, 1]) in channels 0-1 and the label in the
                last channel.
            boxes (torch.Tensor or None): not embedded; batch/device hint only.
            masks (torch.Tensor or None): not embedded; batch/device hint only.

        Returns:
            torch.Tensor: sparse embeddings, shape Bx N x(embed_dim), in the
                embedding weight dtype. N is 0 when keypoints is None.
            torch.Tensor: [B, N] validity mask. It is seeded from an empty
                default-dtype tensor, so its dtype follows torch.cat type
                promotion with the bool point mask.
        """
        bs = self._get_batch_size(keypoints, boxes, masks)
        # Anchor device on the input prompts so we don't pull the offloaded
        # CPU embedding device under dynamic loading.
        ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
        device = ref.device if ref is not None else self.point_embeddings[0].weight.device
        weight_dtype = self.invalid_point_embed.weight.dtype

        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=device, dtype=weight_dtype)
        sparse_masks = torch.empty((bs, 0), device=device)
        if keypoints is not None:
            coords = keypoints[:, :, :2]
            labels = keypoints[:, :, -1]
            point_embeddings, point_mask = self._embed_keypoints(coords, labels)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
            sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)

        return sparse_embeddings, sparse_masks

    def get_mask_embeddings(
        self,
        masks: Optional[torch.Tensor] = None,
        bs: int = 1,
        size: Tuple[int, int] = (16, 16),  # [H, W]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed an optional mask prompt into a dense feature map.

        Args:
            masks: single-channel mask prompt fed to ``mask_downscaling``, or
                None.
            bs: batch size of the no-mask map.
            size: (H, W) of the no-mask map.

        Returns:
            (mask_embeddings, no_mask_embeddings).
            ``no_mask_embeddings`` is always the learned no-mask vector
            broadcast (as an ``expand`` view, so do not write to it in place)
            to [bs, embed_dim, H, W]. Its dtype and device follow ``masks``
            when given, and the conv weights otherwise.
            ``mask_embeddings`` is the downscaled mask when ``masks`` is given;
            otherwise it is the same tensor as ``no_mask_embeddings``.
        """
        # masks is always on the active device when present; fall back to the
        # downscaling Conv's weight device when it isn't (rare callers).
        ref = masks if masks is not None else next(self.mask_downscaling.parameters())
        no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
            bs, -1, size[0], size[1]
        )
        if masks is not None:
            mask_embeddings = self.mask_downscaling(masks)
        else:
            mask_embeddings = no_mask_embeddings
        return mask_embeddings, no_mask_embeddings


class PromptableDecoder(nn.Module):
    """Cross-attention transformer decoder over (token, image) pairs.

    Runs ``depth`` ``TransformerDecoderLayer``s, each returning updated token
    and image embeddings, then applies a final LayerNorm to the tokens.
    Optionally it emits intermediate pose predictions between layers and uses
    them to refine the keypoint tokens.
    """

    def __init__(
        self,
        dims: int,
        context_dims: int,
        depth: int,
        num_heads: int = 8,
        head_dims: int = 64,
        mlp_dims: int = 1024,
        repeat_pe: bool = False,
        do_interm_preds: bool = False,
        keypoint_token_update: bool = False,
        device=None,
        dtype=None,
        operations=None,
    ):
        """
        Args:
            dims: token channel width (also the final LayerNorm width).
            context_dims: image/context channel width.
            depth: number of decoder layers.
            num_heads, head_dims, mlp_dims, repeat_pe: forwarded to every
                ``TransformerDecoderLayer``. ``skip_first_pe`` is set only
                on layer 0.
            do_interm_preds: also return per-layer pose predictions.
            keypoint_token_update: between layers, refine tokens with
                ``keypoint_token_update_fn``. This only takes effect together
                with ``do_interm_preds``.
            device, dtype, operations: forwarded to submodule constructors.
        """
        super().__init__()
        ops = operations if operations is not None else nn

        self.layers = nn.ModuleList(
            TransformerDecoderLayer(
                token_dims=dims,
                context_dims=context_dims,
                num_heads=num_heads,
                head_dims=head_dims,
                mlp_dims=mlp_dims,
                repeat_pe=repeat_pe,
                skip_first_pe=(i == 0),
                device=device,
                dtype=dtype,
                operations=operations,
            )
            for i in range(depth)
        )
        self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
        self.do_interm_preds = do_interm_preds
        self.keypoint_token_update = keypoint_token_update

    def forward(
        self,
        token_embedding: torch.Tensor,
        image_embedding: torch.Tensor,
        token_augment: Optional[torch.Tensor] = None,
        image_augment: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
        token_to_pose_output_fn=None,
        keypoint_token_update_fn=None,
        hand_embeddings=None,
        hand_augment=None,
    ):
        """
        Args:
            token_embedding: [B, N, C]
            image_embedding: [B, C, H, W] -- flattened to [B, HW, C] inline
            token_augment: positional term for the tokens, passed to every layer.
                It may be replaced between layers by ``keypoint_token_update_fn``.
            image_augment: 4-D positional term for the image, flattened like
                ``image_embedding``. It is required when ``hand_embeddings`` is
                given.
            token_mask: passed through to every layer.
            token_to_pose_output_fn: ``fn(tokens, prev_pose_output=..., layer_idx=...)``.
                It is required when ``do_interm_preds`` is set.
            keypoint_token_update_fn: ``fn(tokens, token_augment, pose_output, layer_idx)``
                returning a 4-tuple whose first two items replace the tokens and
                ``token_augment``. It is required when ``keypoint_token_update``
                is set (and ``do_interm_preds`` is on).
            hand_embeddings: optional 4-D extra context. Its flattened tokens are
                appended to the image sequence at every layer. They are
                re-appended unchanged each time; their layer outputs are
                discarded.
            hand_augment: positional term for ``hand_embeddings``. A batch of 1
                is repeated to match ``hand_embeddings``.

        Returns:
            The normed tokens when ``do_interm_preds`` is False. Otherwise
            ``(tokens, all_pose_outputs)``, with one pose output per layer; the
            last is computed from the final normed tokens.
        """
        # Channels-last for the transformer.
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        if image_augment is not None:
            image_augment = image_augment.flatten(2).permute(0, 2, 1)

        if hand_embeddings is not None:
            hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
            hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
            if len(hand_augment) == 1:
                # inflate batch dimension
                assert len(hand_augment.shape) == 3
                hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)
            # Loop-invariant: the combined image+hand positional term and the
            # image token count are identical for every layer, so build once.
            context_augment = torch.cat([image_augment, hand_augment], dim=1)
            num_image_tokens = image_augment.shape[1]

        all_pose_outputs = [] if self.do_interm_preds else None
        if self.do_interm_preds:
            assert token_to_pose_output_fn is not None

        layer_idx = 0
        for layer_idx, layer in enumerate(self.layers):
            if hand_embeddings is None:
                token_embedding, image_embedding = layer(
                    token_embedding,
                    image_embedding,
                    token_augment,
                    image_augment,
                    token_mask,
                )
            else:
                # Hand tokens join the image sequence as extra context; only
                # the image part is carried forward to the next layer.
                token_embedding, image_embedding = layer(
                    token_embedding,
                    torch.cat([image_embedding, hand_embeddings], dim=1),
                    token_augment,
                    context_augment,
                    token_mask,
                )
                image_embedding = image_embedding[:, :num_image_tokens]

            # Intermediate predictions after every layer except the last (the
            # last one is produced from the final normed output below).
            if self.do_interm_preds and layer_idx < len(self.layers) - 1:
                curr = token_to_pose_output_fn(
                    self.norm_final(token_embedding),
                    prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                    layer_idx=layer_idx,
                )
                all_pose_outputs.append(curr)

                if self.keypoint_token_update:
                    assert keypoint_token_update_fn is not None
                    token_embedding, token_augment, _, _ = keypoint_token_update_fn(
                        token_embedding,
                        token_augment,
                        curr,
                        layer_idx,
                    )

        out = self.norm_final(token_embedding)

        if self.do_interm_preds:
            curr = token_to_pose_output_fn(
                out,
                prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                layer_idx=layer_idx,
            )
            all_pose_outputs.append(curr)
            return out, all_pose_outputs

        return out