ComfyUI/comfy/ldm/sam3d_body/model/prompt.py
2026-05-26 02:15:15 +03:00

273 lines
11 KiB
Python

"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them
through a cross-attention transformer decoder over (token, image) pairs.
Both adapted from the SAM-style prompt path (Meta, Apache 2.0):
https://github.com/facebookresearch/segment-anything
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
from comfy.ldm.cascade.common import LayerNorm2d_op
from comfy.ldm.sam3.sam import PositionEmbeddingRandom
from .transformer import TransformerDecoderLayer
class PromptEncoder(nn.Module):
    """Encode keypoint and mask prompts into embeddings.

    Keypoints are turned into per-point ("sparse") embeddings built from a
    positional encoding plus a learned per-label embedding. Masks are turned
    into a dense embedding map by a strided-conv downscaler.
    """

    def __init__(
        self,
        embed_dim: int,
        num_body_joints: int,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        """
        Arguments:
          embed_dim: channel width of every embedding produced by this module.
          num_body_joints: number of distinct non-negative keypoint labels;
            one learned embedding is created per label.
          device, dtype: forwarded to every layer built through ``operations``.
          operations: namespace providing ``Embedding`` / ``Conv2d`` (and
            whatever ``LayerNorm2d_op`` needs); falls back to ``torch.nn``.
        """
        super().__init__()
        ops = operations if operations is not None else nn
        self.embed_dim = embed_dim
        self.num_body_joints = num_body_joints

        # Keypoint prompts
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        self.point_embeddings = nn.ModuleList(
            [ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
        )
        self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
        self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)

        # Mask prompt: four 2x2 stride-2 convs (channels 4 -> 16 -> 64 -> 256),
        # each followed by LayerNorm2d + GELU, then a 1x1 conv to embed_dim.
        LN2d = LayerNorm2d_op(ops)
        mask_in_chans = 256
        self.mask_downscaling = nn.Sequential(
            ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 64, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 16, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans // 4, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
            LN2d(mask_in_chans, device=device, dtype=dtype),
            nn.GELU(),
            ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
        )
        # Trained values for the gating conv and no_mask_embed are loaded from the state dict
        self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)

    def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
        """Positional encoding over the image-embedding grid; (1, C, H, W)."""
        return self.pe_layer(size)

    def _embed_keypoints(
        self, points: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds point prompts.

        Assumes points have been normalized to [0, 1].

        Label handling:
          -2      -> embedding replaced by ``invalid_point_embed``
          -1      -> embedding replaced by ``not_a_point_embed``
          i >= 0  -> ``point_embeddings[i]`` added on top of the positional
                     encoding (for i < num_body_joints)

        Returns:
          (point_embedding, point_mask): embeddings of shape [B, N, C] and a
          mask of shape [B, N] that is True where ``labels > -2``.
        """
        # An empty prompt set has nothing to reduce over (min()/max() raise on
        # zero-element tensors), so only range-check when coordinates exist.
        assert points.numel() == 0 or (points.min() >= 0 and points.max() <= 1), \
            "keypoint coordinates must be normalized to [0, 1]"

        # PE compute in fp32 for precision (sin/cos of large coords), then cast back to the embedding weight dtype
        weight_dtype = self.invalid_point_embed.weight.dtype
        point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype)

        # Invalid / not-a-point entries drop their positional encoding entirely
        # and carry only the corresponding learned embedding.
        point_embedding[labels == -2] = 0.0  # invalid points
        point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding)

        # Real joints keep their positional encoding and add the joint embedding.
        for i in range(self.num_body_joints):
            point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding)

        point_mask = labels > -2
        return point_embedding, point_mask

    def _get_batch_size(
        self,
        keypoints: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """Return the batch size of the first prompt that is present, else 1."""
        if keypoints is not None:
            return keypoints.shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def forward(
        self,
        keypoints: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds keypoint prompts, returning sparse embeddings and their mask.

        Arguments:
          keypoints (torch.Tensor or none): tensor indexed as [B, N, >=3];
            the first two channels are the point coordinates and the last
            channel is the point label.
          boxes (torch.Tensor or none): only consulted for batch size and
            device when ``keypoints`` is None; boxes are not embedded here.
          masks (torch.Tensor or none): only consulted for batch size and
            device when ``keypoints`` and ``boxes`` are None; use
            ``get_mask_embeddings`` to embed masks.

        Returns:
          torch.Tensor: sparse embeddings with shape BxNx(embed_dim), where N
            is the number of input keypoints (0 when ``keypoints`` is None).
          torch.Tensor: sparse mask with shape BxN marking entries whose label
            is greater than -2.
        """
        bs = self._get_batch_size(keypoints, boxes, masks)

        # Anchor device on the input prompts so we don't pull the offloaded
        # CPU embedding device under dynamic loading.
        ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
        device = ref.device if ref is not None else self.point_embeddings[0].weight.device
        weight_dtype = self.invalid_point_embed.weight.dtype

        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=device, dtype=weight_dtype)
        sparse_masks = torch.empty((bs, 0), device=device)
        if keypoints is not None:
            coords = keypoints[:, :, :2]
            labels = keypoints[:, :, -1]
            point_embeddings, point_mask = self._embed_keypoints(coords, labels)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
            sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)
        return sparse_embeddings, sparse_masks

    def get_mask_embeddings(
        self,
        masks: Optional[torch.Tensor] = None,
        bs: int = 1,
        size: Tuple[int, int] = (16, 16),  # [H, W]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embeds mask inputs.

        Arguments:
          masks: mask tensor fed to ``mask_downscaling``, or None.
          bs: batch size the "no mask" embedding is expanded to.
          size: (H, W) the "no mask" embedding is expanded to.

        Returns:
          (mask_embeddings, no_mask_embeddings): the downscaled mask embedding
          (or the "no mask" embedding when ``masks`` is None) and the
          "no mask" embedding expanded to (bs, embed_dim, H, W).
        """
        # masks is always on the active device when present; fall back to the
        # downscaling Conv's weight device when it isn't (rare callers).
        ref = masks if masks is not None else next(self.mask_downscaling.parameters())
        no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
            bs, -1, size[0], size[1]
        )
        if masks is not None:
            mask_embeddings = self.mask_downscaling(masks)
        else:
            mask_embeddings = no_mask_embeddings
        return mask_embeddings, no_mask_embeddings
class PromptableDecoder(nn.Module):
    """Cross-attention transformer decoder over (token, image) pairs.

    Runs ``depth`` ``TransformerDecoderLayer``s in sequence, optionally
    emitting an intermediate pose prediction after every layer and letting a
    callback rewrite the tokens between layers.
    """

    def __init__(
        self,
        dims: int,
        context_dims: int,
        depth: int,
        num_heads: int = 8,
        head_dims: int = 64,
        mlp_dims: int = 1024,
        repeat_pe: bool = False,
        do_interm_preds: bool = False,
        keypoint_token_update: bool = False,
        device=None,
        dtype=None,
        operations=None,
    ):
        """
        Args:
            dims: token channel width (also the width of the final LayerNorm).
            context_dims: channel width of the image/context embedding.
            depth: number of decoder layers.
            num_heads, head_dims, mlp_dims, repeat_pe: forwarded unchanged to
                every ``TransformerDecoderLayer``.
            do_interm_preds: if True, ``forward`` calls
                ``token_to_pose_output_fn`` after every layer and returns the
                collected outputs alongside the tokens.
            keypoint_token_update: if True (and ``do_interm_preds`` is True),
                ``forward`` calls ``keypoint_token_update_fn`` after every
                non-final layer.
            device, dtype, operations: forwarded to the layers; ``operations``
                falls back to ``torch.nn`` for the final LayerNorm.
        """
        super().__init__()
        ops = operations if operations is not None else nn
        self.layers = nn.ModuleList(
            TransformerDecoderLayer(
                token_dims=dims,
                context_dims=context_dims,
                num_heads=num_heads,
                head_dims=head_dims,
                mlp_dims=mlp_dims,
                repeat_pe=repeat_pe,
                skip_first_pe=(i == 0),
                device=device,
                dtype=dtype,
                operations=operations,
            )
            for i in range(depth)
        )
        self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
        self.do_interm_preds = do_interm_preds
        self.keypoint_token_update = keypoint_token_update

    @staticmethod
    def _match_batch(augment: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Broadcast a batch-1 augment to ``batch_size``; pass others through."""
        if augment.shape[0] == 1 and batch_size != 1:
            return augment.expand(batch_size, -1, -1)
        return augment

    def forward(
        self,
        token_embedding: torch.Tensor,
        image_embedding: torch.Tensor,
        token_augment: Optional[torch.Tensor] = None,
        image_augment: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
        token_to_pose_output_fn=None,
        keypoint_token_update_fn=None,
        hand_embeddings=None,
        hand_augment=None,
    ):
        """
        Args:
            token_embedding: [B, N, C]
            image_embedding: [B, C, H, W] -- flattened to [B, HW, C] inline
            token_augment: passed to every layer alongside the tokens; may be
                replaced between layers by ``keypoint_token_update_fn``.
            image_augment: optional, flattened the same way as
                ``image_embedding`` and passed to every layer.
            token_mask: passed through to every layer unchanged.
            token_to_pose_output_fn: required when ``do_interm_preds`` is set;
                called as ``fn(normed_tokens, prev_pose_output=..., layer_idx=...)``.
            keypoint_token_update_fn: required when ``keypoint_token_update``
                and ``do_interm_preds`` are set; called as
                ``fn(token_embedding, token_augment, pose_output, layer_idx)``
                and must return a 4-tuple whose first two items replace
                ``token_embedding`` and ``token_augment``.
            hand_embeddings / hand_augment: optional extra context, flattened
                like the image tensors and appended to the image tokens for
                every layer. Requires ``image_augment`` as well. A batch-1
                augment is broadcast to the batch size.

        Returns:
            ``out`` (final-normed tokens) when ``do_interm_preds`` is False,
            otherwise ``(out, all_pose_outputs)`` with one pose output per layer.
        """
        # Channels-last for the transformer.
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        if image_augment is not None:
            image_augment = image_augment.flatten(2).permute(0, 2, 1)

        context_augment = None
        if hand_embeddings is not None:
            assert image_augment is not None and hand_augment is not None, \
                "hand_embeddings requires both image_augment and hand_augment"
            hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
            hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
            # The augments do not change across layers, so concatenate once.
            # Either one may arrive with a batch of 1; inflate both so the
            # concatenation along the token axis is valid for any batch size.
            context_augment = torch.cat(
                [
                    self._match_batch(image_augment, image_embedding.shape[0]),
                    self._match_batch(hand_augment, hand_embeddings.shape[0]),
                ],
                dim=1,
            )

        all_pose_outputs = [] if self.do_interm_preds else None
        if self.do_interm_preds:
            assert token_to_pose_output_fn is not None

        # Stays 0 when there are no layers so the final prediction below still
        # has a defined layer index.
        layer_idx = 0
        for layer_idx, layer in enumerate(self.layers):
            if hand_embeddings is None:
                token_embedding, image_embedding = layer(
                    token_embedding, image_embedding,
                    token_augment, image_augment, token_mask,
                )
            else:
                token_embedding, image_embedding = layer(
                    token_embedding,
                    torch.cat([image_embedding, hand_embeddings], dim=1),
                    token_augment,
                    context_augment,
                    token_mask,
                )
                # Drop the hand tokens again; only the image tokens carry over.
                image_embedding = image_embedding[:, : image_augment.shape[1]]

            # Intermediate prediction for every layer except the last; the last
            # layer's prediction is made after the loop from the final output.
            if self.do_interm_preds and layer_idx < len(self.layers) - 1:
                curr = token_to_pose_output_fn(
                    self.norm_final(token_embedding),
                    prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                    layer_idx=layer_idx,
                )
                all_pose_outputs.append(curr)

                if self.keypoint_token_update:
                    assert keypoint_token_update_fn is not None
                    token_embedding, token_augment, _, _ = keypoint_token_update_fn(
                        token_embedding, token_augment, curr, layer_idx,
                    )

        out = self.norm_final(token_embedding)
        if self.do_interm_preds:
            curr = token_to_pose_output_fn(
                out,
                prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                layer_idx=layer_idx,
            )
            all_pose_outputs.append(curr)
            return out, all_pose_outputs
        return out