from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import comfy.model_management from comfy.ldm.sam3.sam import PositionEmbeddingRandom from .dinov3 import Dinov3Backbone from .prompt import PromptEncoder, PromptableDecoder from ..mhr.mhr_head import MHRHead from ..mhr.mhr_rig import MHRRig from ..mhr.mhr_utils import fix_wrist_euler, rotation_angle_difference from .transformer import MLP from .camera_modules import CameraEncoder, PerspectiveHead from comfy_extras.mediapipe.face_landmarker import FaceLandmarker from ..utils import bbox_xyxy2cs, fix_aspect_ratio, get_warp_matrices, warp_affine_batched, euler_to_rotmat, rotmat_to_euler # Architecture constants for the released `dinov3-h+` SAM 3D Body checkpoint. IMAGE_SIZE = (512, 512) IMAGE_MEAN = (0.485, 0.456, 0.406) IMAGE_STD = (0.229, 0.224, 0.225) DECODER_DIM = 1024 DECODER_DEPTH = 6 DECODER_HEADS = 8 DECODER_DIM_HEAD = 64 DECODER_MLP_DIM = 1024 MHR_MLP_DEPTH = 2 CAMERA_MLP_DEPTH = 2 CAMERA_DEFAULT_SCALE_FACTOR_HAND = 10.0 N_KEYPOINTS = 70 # mhr70 class SAM3DBody(nn.Module): pelvis_idx = [9, 10] # left_hip, right_hip def __init__(self, device=None, dtype=None, operations=None): super().__init__() # `operations` falls back to torch.nn so the model is constructible # without comfy.ops; matches the pattern in comfy/ldm/sam3/. ops = operations if operations is not None else nn # Per-batch state populated by `_initialize_batch`. self._max_num_person = None self._person_valid = None self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False) self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False) self.image_size = IMAGE_SIZE self.backbone = Dinov3Backbone(device=device, dtype=dtype, operations=operations) embed_dims = self.backbone.embed_dims # MHR rig shared between body + hand pose heads via a non-registered # Python ref, so state_dict has one top-level `mhr.*` key tree (not # duplicated under `head_pose.mhr.*` AND `head_pose_hand.mhr.*`). self.mhr = MHRRig(device=device) head_kwargs = dict( input_dim=DECODER_DIM, mlp_depth=MHR_MLP_DEPTH, mhr_rig=self.mhr, mlp_channel_div_factor=1, device=device, dtype=dtype, operations=operations, ) self.head_pose = MHRHead(**head_kwargs) self.head_pose.hand_pose_comps_ori = nn.Parameter( self.head_pose.hand_pose_comps.clone(), requires_grad=False ) self.head_pose.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose.hand_pose_comps.data).float() ) self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs) self.head_pose_hand.hand_pose_comps_ori = nn.Parameter( self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False ) self.head_pose_hand.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float() ) self.init_pose_hand = ops.Embedding( 1, self.head_pose_hand.npose, device=device, dtype=dtype ) camera_kwargs = dict( input_dim=DECODER_DIM, img_size=IMAGE_SIZE, mlp_depth=CAMERA_MLP_DEPTH, mlp_channel_div_factor=1, device=device, dtype=dtype, operations=operations, ) self.head_camera = PerspectiveHead(**camera_kwargs) self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype) self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs) self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype) cond_dim = 3 init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim linear_kwargs = dict(device=device, dtype=dtype) self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) self.prompt_encoder = PromptEncoder( embed_dim=embed_dims, # match backbone dims so PE adds directly num_body_joints=N_KEYPOINTS, device=device, dtype=dtype, operations=operations, ) self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) decoder_kwargs = dict( dims=DECODER_DIM, context_dims=embed_dims, depth=DECODER_DEPTH, num_heads=DECODER_HEADS, head_dims=DECODER_DIM_HEAD, mlp_dims=DECODER_MLP_DIM, repeat_pe=True, do_interm_preds=True, keypoint_token_update="v2", device=device, dtype=dtype, operations=operations, ) self.decoder = PromptableDecoder(**decoder_kwargs) self.decoder_hand = PromptableDecoder(**decoder_kwargs) self.hand_pe_layer = PositionEmbeddingRandom(embed_dims // 2) # Inference-time dtype set by the Loader via model.backbone.to(dtype). self.backbone_dtype = torch.float32 ray_kwargs = dict( embed_dim=embed_dims, patch_size=self.backbone.patch_size, device=device, dtype=dtype, operations=operations, ) self.ray_cond_emb = CameraEncoder(**ray_kwargs) self.ray_cond_emb_hand = CameraEncoder(**ray_kwargs) self.keypoint_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS)) self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs) self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs) self.bbox_embed = MLP( input_dim=DECODER_DIM, hidden_dim=DECODER_DIM, output_dim=4, num_layers=3, device=device, dtype=dtype, operations=operations, ) posemb_kwargs = dict( hidden_dim=DECODER_DIM, output_dim=DECODER_DIM, num_layers=2, device=device, dtype=dtype, operations=operations, ) self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs) self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs) self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS)) self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs) self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs) self.face_landmarker = FaceLandmarker( device=device, dtype=torch.float32, operations=None, detector_variant="both", # short+full, picks whichever found more faces per frame. ) def data_preprocess(self, inputs: torch.Tensor) -> torch.Tensor: if inputs.max() > 1 and self.image_mean.max() <= 1.0: inputs = inputs / 255.0 elif inputs.max() <= 1.0 and self.image_mean.max() > 1: inputs = inputs * 255.0 return (inputs - self.image_mean) / self.image_std def _initialize_batch(self, batch: Dict) -> None: if batch["img"].dim() == 5: self._batch_size, self._max_num_person = batch["img"].shape[:2] self._person_valid = self._flatten_person(batch["person_valid"]) > 0 else: self._batch_size = batch["img"].shape[0] self._max_num_person = 0 self._person_valid = None def _flatten_person(self, x: torch.Tensor) -> torch.Tensor: assert self._max_num_person is not None, "No max_num_person initialized" if self._max_num_person: x = x.view(self._batch_size * self._max_num_person, *x.shape[2:]) return x def _set_active_branch(self, kind: str) -> None: """Route subsequent calls through the body or hand decoder by switching which batch indices are active.""" n = self._batch_size * self._max_num_person all_idx = list(range(n)) if kind == "body": self.body_batch_idx, self.hand_batch_idx = all_idx, [] elif kind == "hand": self.body_batch_idx, self.hand_batch_idx = [], all_idx else: raise ValueError(f"Invalid branch kind: {kind!r}") @staticmethod def _concat_hand_batches(a: Dict, b: Dict) -> Dict: """Merge two prepare_batch dicts along dim 0 for a single hand pass. Tensors cat, lists extend, scalars/metadata taken from `a`.""" out = {} for k, va in a.items(): vb = b.get(k) if isinstance(va, torch.Tensor) and isinstance(vb, torch.Tensor): out[k] = torch.cat([va, vb], dim=0) elif isinstance(va, list) and isinstance(vb, list): out[k] = va + vb else: out[k] = va return out @staticmethod def _split_hand_output(batched: Dict, n_left: int) -> Tuple[Dict, Dict]: """Inverse of `_concat_hand_batches`. Only `mhr_hand` needs splitting; condition_info / image_embeddings aren't consumed downstream.""" batched_mhr = batched["mhr_hand"] lhand_mhr: Dict[str, Any] = {} rhand_mhr: Dict[str, Any] = {} for k, v in batched_mhr.items(): if isinstance(v, torch.Tensor): lhand_mhr[k] = v[:n_left] rhand_mhr[k] = v[n_left:] else: # numpy `faces`, `pred_pose_rotmat=None`, etc. -- shared. lhand_mhr[k] = v rhand_mhr[k] = v return ( {"mhr": None, "mhr_hand": lhand_mhr}, {"mhr": None, "mhr_hand": rhand_mhr}, ) def _prepare_hand_batches_gpu( self, img, left_xyxy: torch.Tensor, right_xyxy: torch.Tensor, cam_int: torch.Tensor, is_multi_image: bool, ) -> Tuple[Dict, Dict]: """Build batch_lhand + batch_rhand directly on GPU. Bit-exact match for the CPU `prepare_batch` × 2 path, with the source uploaded once and both warps issued through one batched grid_sample.""" device = comfy.model_management.get_torch_device() if is_multi_image: assert isinstance(img, list) n = len(img) H_src, W_src = img[0].shape[:2] src_t = torch.stack(list(img), dim=0) else: n = int(left_xyxy.shape[0]) H_src, W_src = img.shape[:2] src_t = img.unsqueeze(0).expand(n, -1, -1, -1) H_out, W_out = int(self.image_size[0]), int(self.image_size[1]) bbox_padding = 0.9 # matches transform_hand aspect = 0.75 def _meta(boxes_xyxy): centers, scales = bbox_xyxy2cs(boxes_xyxy, padding=bbox_padding) scales = fix_aspect_ratio(scales, aspect) scales = fix_aspect_ratio(scales, W_out / H_out) mats = get_warp_matrices(centers, scales, (H_out, W_out)) return centers, scales, mats l_centers, l_scales, l_mats = _meta(left_xyxy) r_centers, r_scales, r_mats = _meta(right_xyxy) src_t = src_t.to(device, non_blocking=True).permute(0, 3, 1, 2).float() warped_l = warp_affine_batched(torch.flip(src_t, dims=[3]), l_mats, (H_out, W_out)) warped_r = warp_affine_batched(src_t, r_mats, (H_out, W_out)) # floor -> /255 matches the per-item uint8 round-trip path. l_img = (torch.floor(warped_l).clamp_(0.0, 255.0) / 255.0).contiguous() r_img = (torch.floor(warped_r).clamp_(0.0, 255.0) / 255.0).contiguous() # All-zero mask + score 0 (matches prepare_batch's masks=None path). zero_mask = torch.zeros((n, 1, H_out, W_out), dtype=torch.float32, device=device) zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device) person_valid = torch.ones((1, n), dtype=torch.float32, device=device) img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous() ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous() cam_int_dev = cam_int.to(device).to(dtype=torch.float32) def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy): return { "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) "img_size": img_size.unsqueeze(0), "ori_img_size": ori_img_size.unsqueeze(0), "bbox_center": centers_t.to(device).unsqueeze(0), "bbox_scale": scales_t.to(device).unsqueeze(0), "bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0), "affine_trans": mats_t.to(device).unsqueeze(0), "mask": zero_mask.unsqueeze(0), # (1, N, 1, H_out, W_out) "mask_score": zero_mask_score.unsqueeze(0), # (1, N) "person_valid": person_valid, # (1, N) shared OK "cam_int": cam_int_dev, } return ( _build(l_centers, l_scales, l_mats, l_img, left_xyxy), _build(r_centers, r_scales, r_mats, r_img, right_xyxy), ) # Forward path def _get_decoder_condition(self, batch: Dict) -> torch.Tensor: """CLIFF-style condition: ((cx-img_cx)/f, (cy-img_cy)/f, b/f), all in [-1, 1].""" num_person = batch["img"].shape[1] cx, cy = torch.chunk(self._flatten_person(batch["bbox_center"]), chunks=2, dim=-1) b = self._flatten_person(batch["bbox_scale"])[:, [0]] cam_int_per_person = self._flatten_person( batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous() ) focal_length = cam_int_per_person[:, 0, 0] full_img_cxy = cam_int_per_person[:, [0, 1], [2, 2]] condition_info = torch.cat( [cx - full_img_cxy[:, [0]], cy - full_img_cxy[:, [1]], b], dim=-1, ) condition_info[:, :2] = condition_info[:, :2] / focal_length.unsqueeze(-1) condition_info[:, 2] = condition_info[:, 2] / focal_length return condition_info.type(batch["img"].dtype) @staticmethod def _append_token_block(token_embeddings, token_augment, embedding_weight, batch_size): """Append a token block from `embedding_weight` (+ zero-block in token_augment). Returns (token_embeddings, token_augment, start_idx).""" start_idx = token_embeddings.shape[1] block = embedding_weight.to(token_embeddings)[None, :, :].repeat(batch_size, 1, 1) token_embeddings = torch.cat([token_embeddings, block], dim=1) token_augment = torch.cat([token_augment, torch.zeros_like(block)], dim=1) return token_embeddings, token_augment, start_idx def forward_decoder( self, branch: str, image_embeddings: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, keypoints: Optional[torch.Tensor] = None, prev_estimate: Optional[torch.Tensor] = None, condition_info: Optional[torch.Tensor] = None, batch=None, ): """`branch` selects body or hand decoder + paired attribute set; rest of the pipeline is shared. image_embeddings: (B, C, H, W) backbone features. init_estimate: (B, 1, C) initial pose+cam estimate to refine. keypoints: (B, N, 3) prompts as (x, y in [0, 1], label). label: 0..K = joint, -1 = incorrect, -2 = invalid. prev_estimate: (B, 1, C) previous estimate for pose refinement. condition_info: (B, c) extra condition concatenated to input tokens. """ if branch == "body": init_pose_emb = self.init_pose init_camera_emb = self.init_camera init_to_token = self.init_to_token_mhr prev_to_token = self.prev_to_token_mhr ray_cond_emb = self.ray_cond_emb ray_cond_key = "ray_cond" head_pose = self.head_pose head_camera = self.head_camera keypoint_embedding = self.keypoint_embedding keypoint3d_embedding = self.keypoint3d_embedding decoder = self.decoder batch_idx = self.body_batch_idx # Body shares the prompt encoder's PE. image_augment_fn = self.prompt_encoder.get_dense_pe elif branch == "hand": init_pose_emb = self.init_pose_hand init_camera_emb = self.init_camera_hand init_to_token = self.init_to_token_mhr_hand prev_to_token = self.prev_to_token_mhr_hand ray_cond_emb = self.ray_cond_emb_hand ray_cond_key = "ray_cond_hand" head_pose = self.head_pose_hand head_camera = self.head_camera_hand keypoint_embedding = self.keypoint_embedding_hand keypoint3d_embedding = self.keypoint3d_embedding_hand decoder = self.decoder_hand batch_idx = self.hand_batch_idx # Hand decoder has its own PE layer (not the prompt encoder's). image_augment_fn = self.hand_pe_layer else: raise ValueError(f"Invalid branch: {branch!r}") batch_size = image_embeddings.shape[0] # .to(image_embeddings) moves weights CPU→GPU under dynamic loading # (they stay on CPU until first use). if init_estimate is None: init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3) init_input = ( torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1) if condition_info is not None else init_estimate ) token_embeddings = init_to_token(init_input).view(batch_size, 1, -1) num_pose_token = token_embeddings.shape[1] # always 1 image_augment, token_augment, token_mask = None, None, None if keypoints is not None: if prev_estimate is None: prev_estimate = init_estimate prev_embeddings = prev_to_token(prev_estimate).view(batch_size, 1, -1) # PE generated in fp32; cast back to decoder dtype. image_augment = image_augment_fn(image_embeddings.shape[-2:]).to(image_embeddings.dtype) # ray_cond is fp32 from get_ray_condition; cast so CameraEncoder's # internal cat doesn't silently promote everything back to fp32. image_embeddings = ray_cond_emb( image_embeddings, batch[ray_cond_key].type(image_embeddings.dtype), ) # Keypoints start as [0, 0, -2]. Labels select the embedding # weight (special for -2, -1, then per joint). prompt_embeddings, _ = self.prompt_encoder(keypoints=keypoints) prompt_embeddings = self.prompt_to_token(prompt_embeddings) # Pin dtypes so a silent fp16→fp32 promotion in any branch # (init/prev/prompt) doesn't break the index_put assigns below. token_embeddings = torch.cat( [token_embeddings, prev_embeddings, prompt_embeddings], dim=1, ).to(image_embeddings.dtype) prev_embeddings = prev_embeddings.to(image_embeddings.dtype) prompt_embeddings = prompt_embeddings.to(image_embeddings.dtype) token_augment = torch.zeros_like(token_embeddings) token_augment[:, [num_pose_token]] = prev_embeddings token_augment[:, (num_pose_token + 1):] = prompt_embeddings token_embeddings, token_augment, hand_det_emb_start_idx = self._append_token_block( token_embeddings, token_augment, self.hand_box_embedding.weight, batch_size, ) token_embeddings, token_augment, kps_emb_start_idx = self._append_token_block( token_embeddings, token_augment, keypoint_embedding.weight, batch_size, ) token_embeddings, token_augment, kps3d_emb_start_idx = self._append_token_block( token_embeddings, token_augment, keypoint3d_embedding.weight, batch_size, ) last_layer_idx = len(decoder.layers) - 1 def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx): pose_token = tokens[:, 0] prev_pose = init_pose.view(batch_size, -1) prev_camera = init_camera.view(batch_size, -1) # Suppress vertices on non-final layers — kp-token updates only # need keypoints, so we skip the 18439-vertex perspective projection. is_intermediate = layer_idx != last_layer_idx pose_output = head_pose(pose_token, prev_pose, intermediate=is_intermediate) pose_output["pred_cam"] = head_camera(pose_token, prev_camera) pose_output = self.camera_project(pose_output, batch, branch=branch) pose_output["pred_keypoints_2d_cropped"] = self._full_to_crop( batch, pose_output["pred_keypoints_2d"], batch_idx, ) return pose_output def keypoint_token_update_fn_comb(*args): args = self._keypoint_token_update(branch, kps_emb_start_idx, image_embeddings, *args) args = self._keypoint3d_token_update(branch, kps3d_emb_start_idx, *args) return args pose_token, pose_output = decoder( token_embeddings, image_embeddings, token_augment, image_augment, token_mask, token_to_pose_output_fn=token_to_pose_output_fn, keypoint_token_update_fn=keypoint_token_update_fn_comb, ) return ( pose_token[:, hand_det_emb_start_idx:hand_det_emb_start_idx + 2], pose_output, ) def _get_mask_prompt(self, batch, image_embeddings): x_mask = self._flatten_person(batch["mask"]) # batch tensors are fp32 from prepare_batch; mask_downscaling is in the # Loader's dtype — cast once so the conv input matches. x_mask = x_mask.to(image_embeddings.dtype) mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings( x_mask, image_embeddings.shape[0], image_embeddings.shape[2:] ) mask_score = self._flatten_person(batch["mask_score"]).view(-1, 1, 1, 1).to(image_embeddings.dtype) mask_embeddings = torch.where( mask_score > 0, mask_score * mask_embeddings.to(image_embeddings), no_mask_embeddings.to(image_embeddings), ) return mask_embeddings def _full_to_crop( self, batch: Dict, pred_keypoints_2d: torch.Tensor, batch_idx: torch.Tensor = None, ) -> torch.Tensor: """Full-image kp coords → crop-normalized [-0.5, 0.5].""" pred_keypoints_2d_cropped = torch.cat( [pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1 ) if batch_idx is not None: affine_trans = self._flatten_person(batch["affine_trans"])[batch_idx].to( pred_keypoints_2d_cropped ) img_size = self._flatten_person(batch["img_size"])[batch_idx].unsqueeze(1) else: affine_trans = self._flatten_person(batch["affine_trans"]).to( pred_keypoints_2d_cropped ) img_size = self._flatten_person(batch["img_size"]).unsqueeze(1) pred_keypoints_2d_cropped = pred_keypoints_2d_cropped @ affine_trans.mT pred_keypoints_2d_cropped = pred_keypoints_2d_cropped[..., :2] / img_size - 0.5 return pred_keypoints_2d_cropped def camera_project(self, pose_output: Dict, batch: Dict, branch: str = "body") -> Dict: """Project 3D keypoints (+ optional vertices) to 2D. `branch` selects the body or hand attribute set + batch slice.""" head_camera = self.head_camera_hand if branch == "hand" else self.head_camera batch_idx = self.hand_batch_idx if branch == "hand" else self.body_batch_idx pred_cam = pose_output["pred_cam"] # Hoist the shared bbox/intrinsics slice so we don't recompute the # expand+contiguous for the vertices branch. bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx] bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0] ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx] cam_int = self._flatten_person( batch["cam_int"] .unsqueeze(1) .expand(-1, batch["img"].shape[1], -1, -1) .contiguous() )[batch_idx] def _project(points_3d): return head_camera.perspective_projection( points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int, use_intrin_center=True, ) cam_out = _project(pose_output["pred_keypoints_3d"]) if pose_output.get("pred_vertices") is not None: pose_output["pred_keypoints_2d_verts"] = _project( pose_output["pred_vertices"] )["pred_keypoints_2d"] pose_output.update(cam_out) return pose_output def get_ray_condition(self, batch): B, N, _, H, W = batch["img"].shape meshgrid_xy = ( torch.stack( torch.meshgrid(torch.arange(H), torch.arange(W), indexing="xy"), dim=2 )[None, None, :, :, :] .repeat(B, N, 1, 1, 1) .to(batch["affine_trans"].device) ) # B x N x H x W x 2 meshgrid_xy = ( meshgrid_xy / batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]] ) meshgrid_xy = ( meshgrid_xy - batch["affine_trans"][:, :, None, None, [0, 1], [2, 2]] / batch["affine_trans"][:, :, None, None, [0, 1], [0, 1]] ) # Subtract out center & normalize to be rays meshgrid_xy = ( meshgrid_xy - batch["cam_int"][:, None, None, None, [0, 1], [2, 2]] ) meshgrid_xy = ( meshgrid_xy / batch["cam_int"][:, None, None, None, [0, 1], [0, 1]] ) return meshgrid_xy.permute(0, 1, 4, 2, 3).to( batch["img"].dtype ) # This is B x num_person x 2 x H x W def forward_pose_branch(self, batch: Dict) -> Dict: """One pose-decoder pass over the crop batch (body and/or hand).""" batch_size, num_person = batch["img"].shape[:2] x = self.data_preprocess(self._flatten_person(batch["img"])) ray_cond = self._flatten_person(self.get_ray_condition(batch)) if len(self.body_batch_idx): batch["ray_cond"] = ray_cond[self.body_batch_idx].clone() if len(self.hand_batch_idx): batch["ray_cond_hand"] = ray_cond[self.hand_batch_idx].clone() ray_cond = None image_embeddings = self.backbone(x.type(self.backbone_dtype)) # bf16 mantissa too lossy for the heads — promote back. fp16 survives. if self.backbone_dtype != torch.float16: image_embeddings = image_embeddings.type(x.dtype) image_embeddings = image_embeddings + self._get_mask_prompt(batch, image_embeddings) # condition_info is fp32 from `_get_decoder_condition`; align to # decoder dtype so the downstream cat doesn't auto-promote. condition_info = self._get_decoder_condition(batch).type(image_embeddings.dtype) # Seed prompt: all-invalid keypoints (label = -2). keypoints_prompt = torch.zeros((batch_size * num_person, 1, 3)).to(batch["img"]) keypoints_prompt[:, :, -1] = -2 pose_output, pose_output_hand = None, None if len(self.body_batch_idx): tokens_output, pose_output = self.forward_decoder( "body", image_embeddings[self.body_batch_idx], init_estimate=None, keypoints=keypoints_prompt[self.body_batch_idx], prev_estimate=None, condition_info=condition_info[self.body_batch_idx], batch=batch, ) pose_output = pose_output[-1] if len(self.hand_batch_idx): tokens_output_hand, pose_output_hand = self.forward_decoder( "hand", image_embeddings[self.hand_batch_idx], init_estimate=None, keypoints=keypoints_prompt[self.hand_batch_idx], prev_estimate=None, condition_info=condition_info[self.hand_batch_idx], batch=batch, ) pose_output_hand = pose_output_hand[-1] output = { "mhr": pose_output, "mhr_hand": pose_output_hand, "condition_info": condition_info, "image_embeddings": image_embeddings, } # hand_box is (x1, y1, w, h) ∈ [0, 1]. Body path promotes to fp32 to # match the head-MLP external contract (_get_hand_box would .float() anyway). if len(self.body_batch_idx): output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float() output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float() if len(self.hand_batch_idx): output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid() output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand) return output def forward_step(self, batch: Dict, decoder_type: str = "body") -> Tuple[Dict, Dict]: self._set_active_branch(decoder_type) return self.forward_pose_branch(batch) def run_inference( self, img, batch: Dict, inference_type: str = "full", thresh_wrist_angle=1.4, ): """3DB inference. inference_type: 'full' (body + hand-refined), 'body' (body decoder only), 'hand' (hand decoder only).""" is_multi_image = isinstance(img, list) ref_img = img[0] if is_multi_image else img height, width = ref_img.shape[:2] cam_int = batch["cam_int"].clone() if inference_type == "body": return self.forward_step(batch, decoder_type="body") if inference_type == "hand": return self.forward_step(batch, decoder_type="hand") if inference_type != "full": raise ValueError(f"Invalid inference type: {inference_type!r}") # 1. Body decoder pass. pose_output = self.forward_step(batch, decoder_type="body") left_xyxy, right_xyxy = self._get_hand_box(pose_output, batch) ori_local_wrist_rotmat = euler_to_rotmat( "XZY", pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]].unflatten(1, (2, 3)), ) # 2. Hand re-run. Flip the left box's x so it indexes into the # to-be-flipped source image (the flip itself happens on GPU inside # _prepare_hand_batches_gpu — no CPU copy of the frames needed). tmp = left_xyxy.clone() left_xyxy[:, 0] = width - tmp[:, 2] - 1 left_xyxy[:, 2] = width - tmp[:, 0] - 1 batch_lhand, batch_rhand = self._prepare_hand_batches_gpu( img, left_xyxy, right_xyxy, cam_int.clone(), is_multi_image, ) # Concat lhand+rhand along dim 0 so backbone+decoder run once on # (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass. batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand) saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid) self._initialize_batch(batch_hands) hands_output = self.forward_step(batch_hands, decoder_type="hand") self._batch_size, self._max_num_person, self._person_valid = saved_batch_state n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1] lhand_output, rhand_output = self._split_hand_output(hands_output, n_left) # Free the batched image_embeddings/condition_info (unused downstream); # mhr_hand views into the underlying tensors stay alive via l/rhand_output. del hands_output, batch_hands # Unflip left-hand output. Keep MHR consts as 0-d on-device tensors — # `.item()` would force four hard CPU<->GPU syncs in the hot path. _lhand_scale = lhand_output["mhr_hand"]["scale"] scale_r_hands_mean = self.head_pose.scale_mean[8].to(_lhand_scale) scale_l_hands_mean = self.head_pose.scale_mean[9].to(_lhand_scale) scale_r_hands_std = self.head_pose.scale_comps[8, 8].to(_lhand_scale) scale_l_hands_std = self.head_pose.scale_comps[9, 9].to(_lhand_scale) lhand_output["mhr_hand"]["scale"][:, 9] = ( (scale_r_hands_mean + scale_r_hands_std * lhand_output["mhr_hand"]["scale"][:, 8]) - scale_l_hands_mean ) / scale_l_hands_std # Right-hand global rotation flipped → used as left. lhand_output["mhr_hand"]["joint_global_rots"][:, 78] = \ lhand_output["mhr_hand"]["joint_global_rots"][:, 42].clone() lhand_output["mhr_hand"]["joint_global_rots"][:, 78, [1, 2], :] *= -1 lhand_output["mhr_hand"]["hand"][:, :54] = lhand_output["mhr_hand"]["hand"][:, 54:] batch_lhand["bbox_center"][:, :, 0] = width - batch_lhand["bbox_center"][:, :, 0] - 1 # 3. Validity criteria for replacing body-decoder hand pose. # (a) local wrist pose difference: hand vs body wrist rotations joint_rotations = pose_output["mhr"]["joint_global_rots"] _dev = joint_rotations.device lowarm_joint_idxs = torch.LongTensor([76, 40]).to(_dev) # left, right lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] wrist_twist_joint_idxs = torch.LongTensor([77, 41]).to(_dev) wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs] pred_global_wrist_rotmat = torch.stack( [lhand_output["mhr_hand"]["joint_global_rots"][:, 78], rhand_output["mhr_hand"]["joint_global_rots"][:, 42]], dim=1, ) fused_local_wrist_rotmat = torch.einsum( "kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose, ) angle_difference_valid_mask = rotation_angle_difference( ori_local_wrist_rotmat, fused_local_wrist_rotmat, ) < thresh_wrist_angle # (b) hand box big enough to give the decoder useful pixels hand_box_size_thresh = 64 hand_box_size_valid_mask = torch.stack( [(batch_lhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1), (batch_rhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1)], dim=1, ) # (c) all hand 2D keypoints inside the crop box hand_kps2d_thresh = 0.5 hand_kps2d_valid_mask = torch.stack( [lhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh, rhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh], dim=1, ) # (d) hand-decoder wrist close to body-decoder wrist in 2D hand_wrist_kps2d_thresh = 0.25 kps_right_wrist_idx, kps_left_wrist_idx = 41, 62 right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 body_right_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() body_left_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_wrist_idx]].clone() right_kps_dist = (right_kps_full - body_right_kps_full).flatten(0, 1).norm(dim=-1) \ / batch_lhand["bbox_scale"].flatten(0, 1)[:, 0] left_kps_dist = (left_kps_full - body_left_kps_full).flatten(0, 1).norm(dim=-1) \ / batch_rhand["bbox_scale"].flatten(0, 1)[:, 0] hand_wrist_kps2d_valid_mask = torch.stack( [left_kps_dist < hand_wrist_kps2d_thresh, right_kps_dist < hand_wrist_kps2d_thresh], dim=1, ) hand_valid_mask = ( angle_difference_valid_mask & hand_box_size_valid_mask & hand_kps2d_valid_mask & hand_wrist_kps2d_valid_mask ) # Re-prompt body decoder with hand-decoder wrists + body-decoder elbows # to get an updated body pose estimation. self._set_active_branch("body") right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 right_kps_crop = self._full_to_crop(batch, right_kps_full) left_kps_crop = self._full_to_crop(batch, left_kps_full) kps_right_elbow_idx, kps_left_elbow_idx = 8, 7 right_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_elbow_idx]].clone() left_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_elbow_idx]].clone() right_kps_elbow_crop = self._full_to_crop(batch, right_kps_elbow_full) left_kps_elbow_crop = self._full_to_crop(batch, left_kps_elbow_full) keypoint_prompt = torch.cat( [right_kps_crop, left_kps_crop, right_kps_elbow_crop, left_kps_elbow_crop], dim=1, ) keypoint_prompt = torch.cat([keypoint_prompt, keypoint_prompt[..., [-1]]], dim=-1) keypoint_prompt[:, 0, -1] = kps_right_wrist_idx keypoint_prompt[:, 1, -1] = kps_left_wrist_idx keypoint_prompt[:, 2, -1] = kps_right_elbow_idx keypoint_prompt[:, 3, -1] = kps_left_elbow_idx if keypoint_prompt.shape[0] > 1: invalid_prompt = ( (keypoint_prompt[..., 0] < -0.5) | (keypoint_prompt[..., 0] > 0.5) | (keypoint_prompt[..., 1] < -0.5) | (keypoint_prompt[..., 1] > 0.5) | (~hand_valid_mask[..., [1, 0, 1, 0]]) ).unsqueeze(-1) dummy_prompt = torch.zeros((1, 1, 3)).to(keypoint_prompt) dummy_prompt[:, :, -1] = -2 # Shift [-0.5, 0.5] → [0, 1] for the prompt encoder. keypoint_prompt[:, :, :2] = torch.clamp(keypoint_prompt[:, :, :2] + 0.5, 0.0, 1.0) keypoint_prompt = torch.where(invalid_prompt, dummy_prompt, keypoint_prompt) else: valid_keypoint = ( torch.all( (keypoint_prompt[:, :, :2] > -0.5) & (keypoint_prompt[:, :, :2] < 0.5), dim=2, ) & hand_valid_mask[..., [1, 0, 1, 0]] ).squeeze() keypoint_prompt = keypoint_prompt[:, valid_keypoint] keypoint_prompt[:, :, :2] = torch.clamp(keypoint_prompt[:, :, :2] + 0.5, 0.0, 1.0) if keypoint_prompt.numel() != 0: pose_output, _ = self.run_keypoint_prompt(batch, pose_output, keypoint_prompt) # 4. Drop hand pose / scale / shape from the hand decoder into the body output. updated_hand_pose = torch.cat( [lhand_output["mhr_hand"]["hand"][:, :54], rhand_output["mhr_hand"]["hand"][:, 54:]], dim=1, ) updated_scale = pose_output["mhr"]["scale"].clone() updated_scale[:, 9] = lhand_output["mhr_hand"]["scale"][:, 9] updated_scale[:, 8] = rhand_output["mhr_hand"]["scale"][:, 8] updated_scale[:, 18:] = ( lhand_output["mhr_hand"]["scale"][:, 18:] + rhand_output["mhr_hand"]["scale"][:, 18:] ) / 2 updated_shape = pose_output["mhr"]["shape"].clone() updated_shape[:, 40:] = ( lhand_output["mhr_hand"]["shape"][:, 40:] + rhand_output["mhr_hand"]["shape"][:, 40:] ) / 2 # 5. IK: solve local wrist Euler from the (updated) global wrist rotmat. joint_rotations = self.head_pose.mhr_forward( global_trans=pose_output["mhr"]["global_rot"] * 0, global_rot=pose_output["mhr"]["global_rot"], body_pose_params=pose_output["mhr"]["body_pose"], hand_pose_params=updated_hand_pose, scale_params=updated_scale, shape_params=updated_shape, expr_params=pose_output["mhr"]["face"], return_joint_rotations=True, )[1] _dev = joint_rotations.device lowarm_joint_idxs = torch.LongTensor([76, 40]).to(_dev) lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] # joint_rotation is a static buffer at head dtype; cast to MHR's fp32 # to keep the rotation matmul in fp32. wrist_twist_joint_idxs = torch.LongTensor([77, 41]).to(_dev) wrist_zero_rot_pose = lowarm_joint_rotations @ \ self.head_pose.joint_rotation[wrist_twist_joint_idxs].to(joint_rotations.dtype) pred_global_wrist_rotmat = torch.stack( [lhand_output["mhr_hand"]["joint_global_rots"][:, 78], rhand_output["mhr_hand"]["joint_global_rots"][:, 42]], dim=1, ) fused_local_wrist_rotmat = torch.einsum( "kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose, ) wrist_xzy = fix_wrist_euler(rotmat_to_euler("XZY", fused_local_wrist_rotmat)) valid_angle = ( (rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat) < thresh_wrist_angle) & hand_valid_mask ).unsqueeze(-1) body_pose = pose_output["mhr"]["body_pose"][ :, [41, 43, 42, 31, 33, 32] ].unflatten(1, (2, 3)) updated_body_pose = torch.where(valid_angle, wrist_xzy, body_pose) pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]] = ( updated_body_pose.flatten(1, 2) ) hand_pose = pose_output["mhr"]["hand"].unflatten(1, (2, 54)) pose_output["mhr"]["hand"] = torch.where( valid_angle, updated_hand_pose.unflatten(1, (2, 54)), hand_pose ).flatten(1, 2) hand_scale = torch.stack( [pose_output["mhr"]["scale"][:, 9], pose_output["mhr"]["scale"][:, 8]], dim=1, ) updated_hand_scale = torch.stack( [updated_scale[:, 9], updated_scale[:, 8]], dim=1 ) masked_hand_scale = torch.where( valid_angle.squeeze(-1), updated_hand_scale, hand_scale ) pose_output["mhr"]["scale"][:, 9] = masked_hand_scale[:, 0] pose_output["mhr"]["scale"][:, 8] = masked_hand_scale[:, 1] # Replace shared shape and scale pose_output["mhr"]["scale"][:, 18:] = torch.where( valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0, ( lhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [0]] + rhand_output["mhr_hand"]["scale"][:, 18:] * valid_angle.squeeze(-1)[:, [1]] ) / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8), pose_output["mhr"]["scale"][:, 18:], ) pose_output["mhr"]["shape"][:, 40:] = torch.where( valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0, ( lhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [0]] + rhand_output["mhr_hand"]["shape"][:, 40:] * valid_angle.squeeze(-1)[:, [1]] ) / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8), pose_output["mhr"]["shape"][:, 40:], ) # Re-run MHR forward with the updated parameters. verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.head_pose.mhr_forward( global_trans=pose_output["mhr"]["global_rot"] * 0, global_rot=pose_output["mhr"]["global_rot"], body_pose_params=pose_output["mhr"]["body_pose"], hand_pose_params=pose_output["mhr"]["hand"], scale_params=pose_output["mhr"]["scale"], shape_params=pose_output["mhr"]["shape"], expr_params=pose_output["mhr"]["face"], return_keypoints=True, return_joint_coords=True, return_model_params=True, return_joint_rotations=True, ) # j3d: 308 → 70 body/hand kps + 238 face landmarks. All four buffers # need the same y/z flip so they share a coordinate system. j3d_face = j3d[:, 70:].clone() j3d = j3d[:, :70] verts[..., [1, 2]] *= -1 j3d[..., [1, 2]] *= -1 j3d_face[..., [1, 2]] *= -1 jcoords[..., [1, 2]] *= -1 pose_output["mhr"]["pred_keypoints_3d"] = j3d pose_output["mhr"]["pred_face_keypoints_3d"] = j3d_face pose_output["mhr"]["pred_vertices"] = verts pose_output["mhr"]["pred_joint_coords"] = jcoords pose_output["mhr"]["pred_pose_raw"][...] = 0 # invalidated by the IK update pose_output["mhr"]["mhr_model_params"] = mhr_model_params def _project_kp3d(kp3d: torch.Tensor) -> torch.Tensor: proj = kp3d + pose_output["mhr"]["pred_cam_t"][:, None, :] proj[:, :, [0, 1]] = proj[:, :, [0, 1]] * pose_output["mhr"]["focal_length"][:, None, None] proj[:, :, [0, 1]] = ( proj[:, :, [0, 1]] + torch.FloatTensor([width / 2, height / 2]).to(proj)[None, None, :] * proj[:, :, [2]] ) proj[:, :, :2] = proj[:, :, :2] / proj[:, :, [2]] return proj[:, :, :2] pose_output["mhr"]["pred_keypoints_2d"] = _project_kp3d( pose_output["mhr"]["pred_keypoints_3d"].clone() ) pose_output["mhr"]["pred_face_keypoints_2d"] = _project_kp3d( pose_output["mhr"]["pred_face_keypoints_3d"].clone() ) return pose_output, batch_lhand, batch_rhand, lhand_output, rhand_output def run_keypoint_prompt(self, batch, output, keypoint_prompt): image_embeddings = output["image_embeddings"] condition_info = output["condition_info"] pose_output = output["mhr"] prev_estimate = torch.cat( [ pose_output["pred_pose_raw"], pose_output["shape"], pose_output["scale"], pose_output["hand"], pose_output["face"], ], dim=1, ).unsqueeze(1) prev_estimate = torch.cat( [prev_estimate, pose_output["pred_cam"].unsqueeze(1)], dim=-1, ) _, pose_output = self.forward_decoder( "body", image_embeddings, init_estimate=None, # use the default init, not the prev estimate keypoints=keypoint_prompt, prev_estimate=prev_estimate, condition_info=condition_info, batch=batch, ) pose_output = pose_output[-1] output.update({"mhr": pose_output}) return output, keypoint_prompt def _get_hand_box(self, pose_output, batch): """Hand bbox from the detector → full-image coords (xyxy). Stays on device throughout.""" hand_box = pose_output["mhr"]["hand_box"] # (B, 2, 4) fp32 pred_left_hand_box = hand_box[:, 0] * self.image_size[0] pred_right_hand_box = hand_box[:, 1] * self.image_size[0] # Square the boxes (long side wins). batch["left_center"] = pred_left_hand_box[:, :2] batch["left_scale"] = pred_left_hand_box[:, 2:].amax(dim=1, keepdim=True).repeat(1, 2) batch["right_center"] = pred_right_hand_box[:, :2] batch["right_scale"] = pred_right_hand_box[:, 2:].amax(dim=1, keepdim=True).repeat(1, 2) # Invert the crop's full→crop affine. rot=0 makes it diagonal: # divide-by-scale and subtract translation offset. affine_trans = batch["affine_trans"][0] affine_scale = affine_trans[:, 0, 0] affine_offset = affine_trans[:, :2, 2] batch["left_scale"] = batch["left_scale"] / affine_scale[:, None] batch["right_scale"] = batch["right_scale"] / affine_scale[:, None] batch["left_center"] = (batch["left_center"] - affine_offset) / affine_scale[:, None] batch["right_center"] = (batch["right_center"] - affine_offset) / affine_scale[:, None] left_xyxy = torch.stack( [ batch["left_center"][:, 0] - batch["left_scale"][:, 0] / 2, batch["left_center"][:, 1] - batch["left_scale"][:, 1] / 2, batch["left_center"][:, 0] + batch["left_scale"][:, 0] / 2, batch["left_center"][:, 1] + batch["left_scale"][:, 1] / 2, ], dim=1, ) right_xyxy = torch.stack( [ batch["right_center"][:, 0] - batch["right_scale"][:, 0] / 2, batch["right_center"][:, 1] - batch["right_scale"][:, 1] / 2, batch["right_center"][:, 0] + batch["right_scale"][:, 0] / 2, batch["right_center"][:, 1] + batch["right_scale"][:, 1] / 2, ], dim=1, ) return left_xyxy, right_xyxy # Shared 2D-keypoint-driven token update. `branch` picks body/hand attrs; # rest is identical. Called via keypoint_token_update_fn_comb in # forward_decoder. def _keypoint_token_update( self, branch: str, kps_emb_start_idx, image_embeddings, token_embeddings, token_augment, pose_output, layer_idx, ): if branch == "body": decoder_layers = self.decoder.layers kp_emb_w = self.keypoint_embedding.weight kp_idxs = self.keypoint_embedding_idxs posemb_linear = self.keypoint_posemb_linear feat_linear = self.keypoint_feat_linear else: decoder_layers = self.decoder_hand.layers kp_emb_w = self.keypoint_embedding_hand.weight kp_idxs = self.keypoint_embedding_idxs_hand posemb_linear = self.keypoint_posemb_linear_hand feat_linear = self.keypoint_feat_linear_hand # Last layer's pose output is final — nothing to inject back. if layer_idx == len(decoder_layers) - 1: return token_embeddings, token_augment, pose_output, layer_idx token_augment = token_augment.clone() num_keypoints = kp_emb_w.shape[0] # kp comes from fp32 MHR/cam projection; cast once to decoder dtype # so posemb / grid_sample match. pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone()[:, kp_idxs] pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()[:, kp_idxs] pred_keypoints_2d_cropped = pred_keypoints_2d_cropped.to(image_embeddings.dtype) # Mask out-of-frame OR behind-camera keypoints' contributions. pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5 invalid_mask = ( (pred_keypoints_2d_cropped_01[:, :, 0] < 0) | (pred_keypoints_2d_cropped_01[:, :, 0] > 1) | (pred_keypoints_2d_cropped_01[:, :, 1] < 0) | (pred_keypoints_2d_cropped_01[:, :, 1] > 1) | (pred_keypoints_2d_depth[:, :] < 1e-5) ) token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = ( posemb_linear(pred_keypoints_2d_cropped) * (~invalid_mask[:, :, None]) ) # Bilinear-sample image features at each kp's projected location. # grid_sample wants -1..1; cropped form is -0.5..0.5, so ×2. sample_points = pred_keypoints_2d_cropped * 2 feats = F.grid_sample( image_embeddings, sample_points[:, :, None, :], mode="bilinear", padding_mode="zeros", align_corners=False, ).squeeze(3).permute(0, 2, 1) feats = feats * (~invalid_mask[:, :, None]) token_embeddings = token_embeddings.clone() token_embeddings[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] += ( feat_linear(feats) ) return token_embeddings, token_augment, pose_output, layer_idx def _keypoint3d_token_update( self, branch: str, kps3d_emb_start_idx, token_embeddings, token_augment, pose_output, layer_idx, ): if branch == "body": decoder_layers = self.decoder.layers kp3d_emb_w = self.keypoint3d_embedding.weight kp3d_idxs = self.keypoint3d_embedding_idxs posemb_linear = self.keypoint3d_posemb_linear else: decoder_layers = self.decoder_hand.layers kp3d_emb_w = self.keypoint3d_embedding_hand.weight kp3d_idxs = self.keypoint3d_embedding_idxs_hand posemb_linear = self.keypoint3d_posemb_linear_hand if layer_idx == len(decoder_layers) - 1: return token_embeddings, token_augment, pose_output, layer_idx num_keypoints3d = kp3d_emb_w.shape[0] # Pelvis-normalize so 3D kps live in subject-centric coords (don't # leak global cam translation into the token signal). Cast to decoder # dtype before posemb_linear writes back into token_augment. pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone() pred_keypoints_3d = pred_keypoints_3d - ( pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + pred_keypoints_3d[:, [self.pelvis_idx[1]], :] ) / 2 pred_keypoints_3d = pred_keypoints_3d[:, kp3d_idxs].to(token_augment.dtype) token_augment = token_augment.clone() token_augment[:, kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d, :] = ( posemb_linear(pred_keypoints_3d) ) return token_embeddings, token_augment, pose_output, layer_idx