ComfyUI/comfy/ldm/sam3d_body/mhr/mhr_head.py
2026-05-26 02:15:15 +03:00

378 lines
14 KiB
Python

from typing import Optional
import torch
import torch.nn as nn
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
from ..model.transformer import MLP
class MHRHead(nn.Module):
    """Regress MHR body-model parameters from a pose token and run the rig.

    An MLP maps the decoder pose token to a flat vector of length ``npose``
    laid out as: 6D global rotation, continuous body pose, shape PCA,
    scale PCA, left+right hand PCA, face expression. ``forward`` decodes
    that vector and evaluates the shared MHR rig to get vertices,
    keypoints, joint coordinates and joint rotations.
    """

    def __init__(
        self,
        input_dim: int,
        mhr_rig,
        mlp_depth: int = 1,
        extra_joint_regressor: str = "",
        mlp_channel_div_factor: int = 8,
        enable_hand_model=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        # Store the shared MHRRig via object.__setattr__ so nn.Module does not
        # register it as a submodule of this head.
        object.__setattr__(self, "mhr", mhr_rig)
        # NOTE(review): ``extra_joint_regressor`` is accepted but unused here.
        self.num_shape_comps = 45
        self.num_scale_comps = 28
        self.num_hand_comps = 54
        self.num_face_comps = 72
        self.enable_hand_model = enable_hand_model
        self.body_cont_dim = 260
        self.npose = (
            6  # Global Rotation
            + self.body_cont_dim  # then body
            + self.num_shape_comps
            + self.num_scale_comps
            + self.num_hand_comps * 2
            + self.num_face_comps
        )
        self.proj = MLP(
            input_dim=input_dim,
            hidden_dim=input_dim // mlp_channel_div_factor,
            output_dim=self.npose,
            num_layers=mlp_depth,
            device=device,
            dtype=dtype,
            operations=operations,
        )
        # MHR Parameters
        self.num_hand_scale_comps = self.num_scale_comps - 18
        self.num_hand_pose_comps = self.num_hand_comps

        # Frozen placeholders (requires_grad=False) whose contents are filled
        # in by load_state_dict from the safetensors checkpoint.
        def _p(*shape, dtype=torch.float32):
            return nn.Parameter(torch.empty(*shape, dtype=dtype), requires_grad=False)

        self.joint_rotation = _p(127, 3, 3)
        self.scale_mean = _p(68)
        self.scale_comps = _p(28, 68)
        self.faces = _p(36874, 3, dtype=torch.int64)
        self.hand_pose_mean = _p(54)
        # Defaults to identity (coefficients == continuous hand pose) if not loaded.
        self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
        self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
        self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
        # Linear map from (vertices + 127 joints) to 308 keypoints.
        self.keypoint_mapping = _p(308, 18439 + 127)
        # Some special buffers for the hand-version
        self.right_wrist_coords = _p(3)
        self.root_coords = _p(3)
        self.local_to_world_wrist = _p(3, 3)
        self.nonhand_param_idxs = _p(145, dtype=torch.int64)
        # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
        # Optional — loaded from the .safetensors if present, otherwise the
        # render path falls back to a coarse geometric approximation.
        self.register_buffer(
            "face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
        )

    def canonical_vertices(self, device=None):
        """Return the rest-pose vertices for the mean shape, in meters.

        Runs MHR with zero global pose, body pose, hand PCA coefficients,
        scale, shape and expression. Zero hand coefficients decode to
        ``hand_pose_mean``, not to a literally zero hand pose.

        Args:
            device: Optional device for the returned tensor. The computation
                always runs on the device of the head's buffers, because MHR
                inputs are combined with those buffers. The result is moved
                to ``device`` afterwards.

        Returns:
            Tensor of shape ``(N_v, 3)``.
        """
        buf_device = self.scale_mean.device
        buf_dtype = self.scale_mean.dtype

        def zeros(width):
            # Single-sample batch of zeros matching the buffers.
            return torch.zeros(1, width, device=buf_device, dtype=buf_dtype)

        verts = self.mhr_forward(
            global_trans=zeros(3),
            global_rot=zeros(3),
            body_pose_params=zeros(130),
            hand_pose_params=zeros(self.num_hand_comps * 2),
            scale_params=zeros(self.num_scale_comps),
            shape_params=zeros(self.num_shape_comps),
            expr_params=zeros(self.num_face_comps),
        )  # single-tensor shape (1, N_v, 3) in meters
        verts = verts[0]
        if device is not None:
            verts = verts.to(device)
        return verts

    def get_zero_pose_init(self, factor=1.0):
        """Return a ``(1, npose)`` initial estimate encoding the zero pose.

        The first 6 entries hold the 6D rotation ``[1, 0, 0, 0, 1, 0]``. This
        is the identity under the usual two-column 6D encoding; assumed to
        match ``rot6d_to_rotmat``. The next ``body_cont_dim`` entries hold
        the continuous encoding of an all-zero body pose, scaled by
        ``factor``. Everything else is zero. A zero pose is not all-zeros in
        the continuous representation, hence the explicit conversion.
        """
        weights = torch.zeros(1, self.npose)
        identity_rot_6d = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32)
        # NOTE(review): 133 model params are converted here (vs. 130 body pose
        # params elsewhere) — kept as in the original; confirm intent.
        zero_body_cont = (
            compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze() * factor
        )
        weights[:, : 6 + self.body_cont_dim] = torch.cat(
            [identity_rot_6d, zero_body_cont], dim=0
        )
        return weights

    def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
        """Decode hand PCA coefficients and write them into the pose in place.

        Args:
            full_pose_params: ``(B, 136)`` pose vector (3 translation,
                3 global rotation, 130 body pose). Modified in place.
            hand_pose_params: ``(B, 2 * num_hand_pose_comps)`` PCA
                coefficients, left hand first, then right hand.

        Returns:
            ``full_pose_params`` itself, shape ``(B, 136)``.
        """
        assert full_pose_params.shape[1] == 136

        def decode(coeffs):
            # PCA coefficients -> continuous hand pose -> model params.
            return compact_cont_to_model_params_hand(
                self.hand_pose_mean
                + torch.einsum("da,ab->db", coeffs, self.hand_pose_comps)
            )

        left_coeffs, right_coeffs = torch.split(
            hand_pose_params,
            [self.num_hand_pose_comps, self.num_hand_pose_comps],
            dim=1,
        )
        full_pose_params[:, self.hand_joint_idxs_left] = decode(left_coeffs)
        full_pose_params[:, self.hand_joint_idxs_right] = decode(right_coeffs)
        return full_pose_params  # B x 136

    def mhr_forward(
        self,
        global_trans,
        global_rot,
        body_pose_params,
        hand_pose_params,
        scale_params,
        shape_params,
        expr_params=None,
        return_keypoints=False,
        do_pcblend=True,
        return_joint_coords=False,
        return_model_params=False,
        return_joint_rotations=False,
        scale_offsets=None,
        vertex_offsets=None,
    ):
        """Assemble MHR model parameters and evaluate the rig.

        Args:
            global_trans: ``(B, 3)`` global translation.
            global_rot: ``(B, 3)`` global rotation as euler angles.
            body_pose_params: ``(B, 130)`` body pose model parameters. In
                hand mode only the first 130 entries are used.
            hand_pose_params: ``(B, 2 * num_hand_comps)`` hand PCA
                coefficients, or None to keep the hand entries of
                ``body_pose_params``.
            scale_params: ``(B, num_scale_comps)``; a 1-D input gets a batch dim.
            shape_params: ``(B, num_shape_comps)``; a 1-D input gets a batch dim.
            expr_params: Expression parameters passed through to the rig.
            return_keypoints: Also return keypoints regressed via
                ``keypoint_mapping``.
            do_pcblend: Currently unused; kept for API compatibility.
            return_joint_coords: Also return joint coordinates.
            return_model_params: Also return the assembled model parameters.
            return_joint_rotations: Also return joint rotation matrices.
            scale_offsets: Optional additive offsets to the decoded scales.
            vertex_offsets: NOTE(review): currently accepted but ignored.

        Returns:
            Vertices ``(B, N_v, 3)`` alone when nothing extra is requested.
            Otherwise a tuple in the order: vertices, keypoints, joint
            coords, model params, joint rotations (requested items only).
        """
        # Cast every input to the dtype of the static buffers.
        dt = self.scale_mean.dtype
        global_trans = global_trans.to(dt)
        global_rot = global_rot.to(dt)
        body_pose_params = body_pose_params.to(dt)
        if hand_pose_params is not None:
            hand_pose_params = hand_pose_params.to(dt)
        scale_params = scale_params.to(dt)
        shape_params = shape_params.to(dt)
        if expr_params is not None:
            expr_params = expr_params.to(dt)

        if self.enable_hand_model:
            # Transfer wrist-centric predictions to the body.
            global_rot_ori = global_rot.clone()
            global_trans_ori = global_trans.clone()
            global_rot = rotmat_to_euler(
                "xyz",
                euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
            )
            global_trans = (
                -(
                    euler_to_rotmat("xyz", global_rot)
                    @ (self.right_wrist_coords - self.root_coords)
                    + self.root_coords
                )
                + global_trans_ori
            )
            body_pose_params = body_pose_params[..., :130]

        # Add a singleton batch dim to unbatched scale/shape inputs.
        if scale_params.dim() == 1:
            scale_params = scale_params[None]
        if shape_params.dim() == 1:
            shape_params = shape_params[None]

        # Decode scale PCA coefficients into per-parameter scales.
        scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
        if scale_offsets is not None:
            scales = scales + scale_offsets

        # 10 here is because it's more stable to optimize global translation in meters.
        full_pose_params = torch.cat(
            [global_trans * 10, global_rot, body_pose_params], dim=1
        )  # B x 136 (3 + 3 + 130)
        if hand_pose_params is not None:
            full_pose_params = self.replace_hands_in_pose(
                full_pose_params, hand_pose_params
            )
        model_params = torch.cat([full_pose_params, scales], dim=1)
        if self.enable_hand_model:
            # Zero out non-hand parameters
            model_params[:, self.nonhand_param_idxs] = 0

        skinned_verts, skel_state = self.mhr(shape_params, model_params, expr_params)
        joint_coords, joint_quats, _ = torch.split(skel_state, [3, 4, 1], dim=2)
        # Divide by 100 to get meters (rig output presumably in centimeters).
        skinned_verts = skinned_verts / 100
        joint_coords = joint_coords / 100

        outputs = [skinned_verts]
        if return_keypoints:
            # Keypoints (308 Sapiens keypoints for the full mapping) are a fixed
            # linear combination of vertices and joints.
            verts_and_joints = torch.cat(
                [skinned_verts, joint_coords], dim=1
            )  # B x (num_verts + 127) x 3
            kp_map = self.keypoint_mapping.to(verts_and_joints.dtype)
            batch = verts_and_joints.shape[0]
            keypoints = (
                (kp_map @ verts_and_joints.permute(1, 0, 2).flatten(1, 2))
                .reshape(-1, batch, 3)
                .permute(1, 0, 2)
            )  # B x K x 3
            if self.enable_hand_model:
                # Zero out everything except for the right hand
                keypoints[:, :21] = 0
                keypoints[:, 42:] = 0
            outputs.append(keypoints)
        if return_joint_coords:
            outputs.append(joint_coords)
        if return_model_params:
            outputs.append(model_params)
        if return_joint_rotations:
            # Converted only on request so other callers skip this work.
            outputs.append(unitquat_to_rotmat(joint_quats))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def forward(
        self,
        x: torch.Tensor,
        init_estimate: Optional[torch.Tensor] = None,
        do_pcblend=True,
        slim_keypoints=False,
        intermediate: bool = False,
    ):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.npose], added to the MLP output.
            do_pcblend: forwarded to ``mhr_forward`` (which currently ignores it).
            slim_keypoints: currently unused.
            intermediate: when True, the caller only needs the keypoints/pose
                outputs needed by the per-layer keypoint-token update path —
                vertex output is suppressed so `camera_project` skips the
                18439-vertex perspective projection on intermediate decoder
                layers. The final layer must call with intermediate=False.

        Returns:
            Dict of decoded parameters and MHR outputs; tensor entries are fp32.
        """
        batch_size = x.shape[0]
        pred = self.proj(x)
        if init_estimate is not None:
            pred = pred + init_estimate

        # Unpack the flat prediction in the layout defined by ``npose``.
        count = 6
        global_rot_6d = pred[:, :count]
        global_rot_rotmat = rot6d_to_rotmat(global_rot_6d)  # B x 3 x 3
        global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat)  # B x 3
        # Global translation is not predicted by this head; feed zeros.
        global_trans = torch.zeros_like(global_rot_euler)

        # Body pose: keep the raw continuous form for iterative correction.
        pred_pose_cont = pred[:, count : count + self.body_cont_dim]
        count += self.body_cont_dim
        pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
        # Zero the body-pose hand entries (hand pose comes from pred_hand).
        pred_pose_euler[:, mhr_param_hand_mask] = 0
        # Zero-out jaw
        pred_pose_euler[:, -3:] = 0

        pred_shape = pred[:, count : count + self.num_shape_comps]
        count += self.num_shape_comps
        pred_scale = pred[:, count : count + self.num_scale_comps]
        count += self.num_scale_comps
        pred_hand = pred[:, count : count + self.num_hand_comps * 2]
        count += self.num_hand_comps * 2
        # Face expression is predicted but deliberately zeroed.
        pred_face = pred[:, count : count + self.num_face_comps] * 0
        count += self.num_face_comps

        verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.mhr_forward(
            global_trans=global_trans,
            global_rot=global_rot_euler,
            body_pose_params=pred_pose_euler,
            hand_pose_params=pred_hand,
            scale_params=pred_scale,
            shape_params=pred_shape,
            expr_params=pred_face,
            do_pcblend=do_pcblend,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )
        j3d = j3d[:, :70]  # 308 --> 70 keypoints

        # Intermediate decoder layers only consume pred_keypoints_3d via the
        # keypoint-token update path; suppress verts so camera_project skips
        # the 18439-vertex perspective projection.
        if intermediate:
            verts = None

        # Camera system difference: flip the y and z axes.
        if verts is not None:
            verts[..., [1, 2]] *= -1
        j3d[..., [1, 2]] *= -1
        if jcoords is not None:
            jcoords[..., [1, 2]] *= -1

        # All tensor outputs are promoted to fp32 so the external
        # pose_output["mhr"] contract has a stable dtype regardless of what
        # the head ran at (fp16/bf16 for speed). For MHR-derived outputs
        # that are already fp32, .float() returns the tensor unchanged.
        return {
            "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
            "pred_pose_rotmat": None,
            "global_rot": global_rot_euler.float(),
            "body_pose": pred_pose_euler.float(),
            "shape": pred_shape.float(),
            "scale": pred_scale.float(),
            "hand": pred_hand.float(),
            "face": pred_face.float(),
            "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3).float(),
            "pred_vertices": (
                verts.reshape(batch_size, -1, 3).float() if verts is not None else None
            ),
            "pred_joint_coords": (
                jcoords.reshape(batch_size, -1, 3).float() if jcoords is not None else None
            ),
            "faces": self.faces.cpu().numpy(),
            "joint_global_rots": joint_global_rots.float(),
            "mhr_model_params": mhr_model_params.float(),
        }