from typing import Optional

import torch
import torch.nn as nn

from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import (
    compact_cont_to_model_params_body,
    compact_cont_to_model_params_hand,
    compact_model_params_to_cont_body,
    mhr_param_hand_mask,
)
from ..model.transformer import MLP


class MHRHead(nn.Module):
    """Regress MHR body-model parameters from a pose token and run the MHR rig.

    ``self.proj`` (an MLP) maps a ``[B, input_dim]`` token to a flat
    ``[B, self.npose]`` vector. ``forward`` slices that vector into:

    - global rotation (6D),
    - continuous body pose,
    - shape, scale, hand-PCA and face coefficients.

    These are then run through the shared MHR rig to obtain vertices,
    keypoints and joints.

    The large constant tensors (scale / hand PCA bases, keypoint regressor,
    mesh faces, ...) are allocated empty here. They are expected to be filled
    by ``load_state_dict`` from the checkpoint.
    """

    def __init__(
        self,
        input_dim: int,
        mhr_rig,
        mlp_depth: int = 1,
        extra_joint_regressor: str = "",
        mlp_channel_div_factor: int = 8,
        enable_hand_model=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        # NOTE(review): `extra_joint_regressor` is accepted but not used here.
        super().__init__()
        # Store the shared MHRRig as a plain Python attribute. Bypassing
        # nn.Module.__setattr__ keeps it from being registered as a submodule
        # of this head.
        object.__setattr__(self, "mhr", mhr_rig)

        self.num_shape_comps = 45
        self.num_scale_comps = 28
        self.num_hand_comps = 54
        self.num_face_comps = 72
        self.enable_hand_model = enable_hand_model

        # Width of the continuous body-pose representation predicted by the MLP.
        self.body_cont_dim = 260
        # Layout of the MLP output vector; `forward` slices it in this order.
        self.npose = (
            6  # Global rotation (6D)
            + self.body_cont_dim  # then body pose (continuous)
            + self.num_shape_comps
            + self.num_scale_comps
            + self.num_hand_comps * 2  # left hand, then right hand
            + self.num_face_comps
        )
        self.proj = MLP(
            input_dim=input_dim,
            hidden_dim=input_dim // mlp_channel_div_factor,
            output_dim=self.npose,
            num_layers=mlp_depth,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        # MHR parameters
        self.num_hand_scale_comps = self.num_scale_comps - 18
        self.num_hand_pose_comps = self.num_hand_comps

        # Frozen tensors populated by load_state_dict from the safetensors
        # checkpoint. The shapes below must match the checkpoint.
        def _p(*shape, dtype=torch.float32):
            return nn.Parameter(torch.empty(*shape, dtype=dtype), requires_grad=False)

        self.joint_rotation = _p(127, 3, 3)
        self.scale_mean = _p(68)
        self.scale_comps = _p(28, 68)
        self.faces = _p(36874, 3, dtype=torch.int64)
        self.hand_pose_mean = _p(54)
        self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
        self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
        self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
        self.keypoint_mapping = _p(308, 18439 + 127)

        # Some special tensors for the hand-only variant (enable_hand_model)
        self.right_wrist_coords = _p(3)
        self.root_coords = _p(3)
        self.local_to_world_wrist = _p(3, 3)
        self.nonhand_param_idxs = _p(145, dtype=torch.int64)

        # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
        # Optional — loaded from the .safetensors if present, otherwise the
        # render path falls back to a coarse geometric approximation.
        self.register_buffer(
            "face_region_rgb",
            torch.zeros(18439, 3, dtype=torch.float32),
        )

    def canonical_vertices(self, device=None):
        """Return the rest-pose vertices of the mean body, in meters.

        Runs :meth:`mhr_forward` with zero pose / shape / scale / expression
        inputs. Scales therefore equal ``scale_mean``, and zero hand-PCA
        coefficients select ``hand_pose_mean``. The result is fixed for a
        given checkpoint.

        Args:
            device: device for the zero inputs. Defaults to the device of
                ``scale_mean``.

        Returns:
            Tensor of shape ``[N_v, 3]`` (the batch dimension of 1 is dropped).
        """
        dev = device or self.scale_mean.device
        dt = self.scale_mean.dtype
        batch = 1

        def _zeros(width):
            return torch.zeros(batch, width, device=dev, dtype=dt)

        verts = self.mhr_forward(
            global_trans=_zeros(3),
            global_rot=_zeros(3),
            body_pose_params=_zeros(130),
            hand_pose_params=_zeros(self.num_hand_comps * 2),
            scale_params=_zeros(self.num_scale_comps),
            shape_params=_zeros(self.num_shape_comps),
            expr_params=_zeros(self.num_face_comps),
        )
        # With no extra outputs requested, mhr_forward returns a single
        # tensor of shape (1, N_v, 3) in meters.
        return verts[0]

    def get_zero_pose_init(self, factor=1.0):
        """Build an initial MLP-output vector that encodes the zero pose.

        The zero pose is *not* all-zeros in the continuous parametrisation.

        - The 6D global rotation is set to ``[1, 0, 0, 0, 1, 0]``, presumably
          the identity under the ``rot6d_to_rotmat`` convention.
        - The body slice is the continuous encoding of all-zero body model
          params (133 wide), multiplied by ``factor``.
        - Shape, scale, hand and face entries stay zero.

        Returns:
            Tensor of shape ``[1, self.npose]``.
        """
        weights = torch.zeros(1, self.npose)
        rot_6d = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32)
        body_cont = (
            compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze() * factor
        )
        weights[:, : 6 + self.body_cont_dim] = torch.cat([rot_6d, body_cont], dim=0)
        return weights

    def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
        """Write hand poses given as PCA coefficients into ``full_pose_params``.

        Args:
            full_pose_params: ``[B, 136]`` pose params, laid out as
                3 translation + 3 global rotation + 130 body (as assembled by
                ``mhr_forward``). **Modified in place.**
            hand_pose_params: ``[B, 2 * num_hand_pose_comps]`` PCA
                coefficients, left hand first, then right hand.

        Returns:
            The same ``full_pose_params`` tensor, shape ``[B, 136]``.
        """
        assert full_pose_params.shape[1] == 136

        # Split into left and right hands
        left_hand_params, right_hand_params = torch.split(
            hand_pose_params,
            [self.num_hand_pose_comps, self.num_hand_pose_comps],
            dim=1,
        )

        # PCA coefficients -> compact continuous hand pose -> model params.
        left_hand_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
        )
        right_hand_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
        )

        # Drop them into the hand joint slots.
        full_pose_params[:, self.hand_joint_idxs_left] = left_hand_model_params
        full_pose_params[:, self.hand_joint_idxs_right] = right_hand_model_params
        return full_pose_params  # B x 136

    def mhr_forward(
        self,
        global_trans,
        global_rot,
        body_pose_params,
        hand_pose_params,
        scale_params,
        shape_params,
        expr_params=None,
        return_keypoints=False,
        do_pcblend=True,
        return_joint_coords=False,
        return_model_params=False,
        return_joint_rotations=False,
        scale_offsets=None,
        vertex_offsets=None,
    ):
        """Run the MHR rig on decomposed parameters.

        Args:
            global_trans: ``[B, 3]`` global translation. It is multiplied by
                10 before being handed to the rig (see the inline note).
            global_rot: ``[B, 3]`` global rotation as Euler angles.
            body_pose_params: ``[B, >=130]`` body model params. Only the first
                130 are used.
            hand_pose_params: ``[B, 2 * num_hand_pose_comps]`` hand PCA
                coefficients, or ``None`` to keep the hand entries already in
                ``body_pose_params``.
            scale_params: scale PCA coefficients, ``[B, num_scale_comps]`` or
                unbatched.
            shape_params: shape coefficients, ``[B, num_shape_comps]`` or
                unbatched.
            expr_params: face expression coefficients, or ``None``.
            return_keypoints: also return the 308 regressed keypoints.
            do_pcblend: NOTE(review) accepted but currently unused.
            return_joint_coords: also return joint positions (meters).
            return_model_params: also return the assembled rig model params.
            return_joint_rotations: also return joint rotation matrices.
            scale_offsets: optional additive offsets on the decoded scales.
            vertex_offsets: NOTE(review) accepted but currently unused.

        Returns:
            The skinned vertices ``[B, N_v, 3]`` (meters) when no extra
            output is requested. Otherwise a tuple
            ``(verts, keypoints, joint_coords, model_params, joint_rots)``
            containing only the requested items, in that order.
        """
        # Align everything to the dtype of the static tensors.
        dt = self.scale_mean.dtype
        global_trans = global_trans.to(dt)
        global_rot = global_rot.to(dt)
        body_pose_params = body_pose_params.to(dt)
        if hand_pose_params is not None:
            hand_pose_params = hand_pose_params.to(dt)
        scale_params = scale_params.to(dt)
        shape_params = shape_params.to(dt)
        if expr_params is not None:
            expr_params = expr_params.to(dt)

        if self.enable_hand_model:
            # Transfer wrist-centric predictions to the body.
            global_rot_ori = global_rot.clone()
            global_trans_ori = global_trans.clone()
            global_rot = rotmat_to_euler(
                "xyz",
                euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
            )
            global_trans = (
                -(
                    euler_to_rotmat("xyz", global_rot)
                    @ (self.right_wrist_coords - self.root_coords)
                    + self.root_coords
                )
                + global_trans_ori
            )

        # The rig takes 130 body params. Callers may pass wider vectors:
        # `get_zero_pose_init` treats full body model params as 133 wide.
        body_pose_params = body_pose_params[..., :130]

        # Add a batch dimension to unbatched scale / shape inputs.
        if scale_params.dim() == 1:
            scale_params = scale_params[None]
        if shape_params.dim() == 1:
            shape_params = shape_params[None]

        # Decode scale PCA coefficients into per-parameter scales.
        scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
        if scale_offsets is not None:
            scales = scales + scale_offsets

        # Assemble the pose. The 10 is because it's more stable to optimize
        # global translation in meters.
        full_pose_params = torch.cat(
            [global_trans * 10, global_rot, body_pose_params], dim=1
        )  # B x 136 (3 + 3 + 130)
        if hand_pose_params is not None:
            full_pose_params = self.replace_hands_in_pose(
                full_pose_params, hand_pose_params
            )

        model_params = torch.cat([full_pose_params, scales], dim=1)
        if self.enable_hand_model:
            # Zero out non-hand parameters
            model_params[:, self.nonhand_param_idxs] = 0

        curr_skinned_verts, curr_skel_state = self.mhr(
            shape_params, model_params, expr_params
        )
        # Per-joint skeleton state: 3 position + 4 quaternion + 1 (ignored).
        curr_joint_coords, curr_joint_quats, _ = torch.split(
            curr_skel_state, [3, 4, 1], dim=2
        )
        # Rig outputs are divided by 100 to get meters.
        curr_skinned_verts = curr_skinned_verts / 100
        curr_joint_coords = curr_joint_coords / 100
        curr_joint_rots = unitquat_to_rotmat(curr_joint_quats)

        to_return = [curr_skinned_verts]
        if return_keypoints:
            # Get Sapiens 308 keypoints: a fixed linear map over the
            # concatenated vertices and joints.
            model_vert_joints = torch.cat(
                [curr_skinned_verts, curr_joint_coords], dim=1
            )  # B x (num_verts + 127) x 3
            kp_map = self.keypoint_mapping.to(model_vert_joints.dtype)
            model_keypoints_pred = (
                (kp_map @ model_vert_joints.permute(1, 0, 2).flatten(1, 2))
                .reshape(-1, model_vert_joints.shape[0], 3)
                .permute(1, 0, 2)
            )  # B x 308 x 3
            if self.enable_hand_model:
                # Zero out everything except for the right hand
                model_keypoints_pred[:, :21] = 0
                model_keypoints_pred[:, 42:] = 0
            to_return.append(model_keypoints_pred)
        if return_joint_coords:
            to_return.append(curr_joint_coords)
        if return_model_params:
            to_return.append(model_params)
        if return_joint_rotations:
            to_return.append(curr_joint_rots)

        if len(to_return) == 1:
            return to_return[0]
        return tuple(to_return)

    def forward(
        self,
        x: torch.Tensor,
        init_estimate: Optional[torch.Tensor] = None,
        do_pcblend=True,
        slim_keypoints=False,
        intermediate: bool = False,
    ):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.npose], added to the MLP output.
            do_pcblend: forwarded to `mhr_forward`.
            slim_keypoints: NOTE(review) accepted but currently unused.
            intermediate: when True, the caller only needs the keypoints/pose
                outputs needed by the per-layer keypoint-token update path.
                Vertex output is suppressed so `camera_project` skips the
                18439-vertex perspective projection on intermediate decoder
                layers. The final layer must call with intermediate=False.

        Returns:
            dict with the raw and decoded MHR parameters, 70 keypoints,
            vertices (``None`` when ``intermediate``), joint coords, faces,
            joint rotations and rig model params.
        """
        batch_size = x.shape[0]
        pred = self.proj(x)
        if init_estimate is not None:
            pred = pred + init_estimate

        # Slice the flat prediction in the order laid out by `self.npose`.
        ## Global rotation is first 6 (6D representation).
        count = 6
        global_rot_6d = pred[:, :count]
        global_rot_rotmat = rot6d_to_rotmat(global_rot_6d)  # B x 3 x 3
        global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat)  # B x 3
        global_trans = torch.zeros_like(global_rot_euler)

        ## Body pose: hold onto the raw continuous version for iterative correction.
        pred_pose_cont = pred[:, count : count + self.body_cont_dim]
        count += self.body_cont_dim
        pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
        pred_pose_euler[:, mhr_param_hand_mask] = 0  # Zero-out hands
        pred_pose_euler[:, -3:] = 0  # Zero-out jaw

        ## Remaining parameters
        pred_shape = pred[:, count : count + self.num_shape_comps]
        count += self.num_shape_comps
        pred_scale = pred[:, count : count + self.num_scale_comps]
        count += self.num_scale_comps
        pred_hand = pred[:, count : count + self.num_hand_comps * 2]
        count += self.num_hand_comps * 2
        # Face expression is predicted but deliberately zeroed.
        pred_face = pred[:, count : count + self.num_face_comps] * 0

        # Run everything through MHR
        verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.mhr_forward(
            global_trans=global_trans,
            global_rot=global_rot_euler,
            body_pose_params=pred_pose_euler,
            hand_pose_params=pred_hand,
            scale_params=pred_scale,
            shape_params=pred_shape,
            expr_params=pred_face,
            do_pcblend=do_pcblend,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )
        j3d = j3d[:, :70]  # 308 --> 70 keypoints

        # Intermediate decoder layers only consume pred_keypoints_3d via the
        # keypoint-token update path; suppress verts so camera_project skips
        # the 18439-vertex perspective projection.
        if intermediate:
            verts = None

        # Camera system difference: negate the y and z components.
        if verts is not None:
            verts[..., [1, 2]] *= -1
        j3d[..., [1, 2]] *= -1
        if jcoords is not None:
            jcoords[..., [1, 2]] *= -1

        # Head-MLP outputs are promoted to fp32 here so the external
        # pose_output["mhr"] contract has a stable dtype regardless of what
        # the head ran at (fp16/bf16 for speed). MHR-derived outputs are
        # already fp32 from MHR's math layers; the cast on them is a no-op.
        return {
            "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
            "pred_pose_rotmat": None,
            "global_rot": global_rot_euler.float(),
            "body_pose": pred_pose_euler.float(),
            "shape": pred_shape.float(),
            "scale": pred_scale.float(),
            "hand": pred_hand.float(),
            "face": pred_face.float(),
            "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
            "pred_vertices": (
                verts.reshape(batch_size, -1, 3) if verts is not None else None
            ),
            "pred_joint_coords": (
                jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None
            ),
            "faces": self.faces.cpu().numpy(),
            "joint_global_rots": joint_global_rots,
            "mhr_model_params": mhr_model_params,
        }