ComfyUI/comfy/ldm/sam3d_body/mhr/mhr_rig.py
2026-05-26 02:15:15 +03:00

236 lines
9.6 KiB
Python

# Adapted from facebookresearch/MHR (Apache 2.0):
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import-time side effect: this toggles torch's sparse-tensor invariant-check
# flag, which (per the torch.sparse docs) is a global setting, so it presumably
# also affects sparse constructors outside this module.
torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning.
from .mhr_utils import batch6DFromXYZ
# ln(2) rounded to float32 (bit pattern 0x3F317218), held as a Python float.
# MHRRig.forward evaluates exp(x * _LN2), i.e. 2 ** x, for the per-joint scale
# channel. Presumably the fp32-rounded value (rather than math.log(2)) was kept
# to stay bit-close to the upstream TorchScript constant.
_LN2: float = 0.6931471824645996
def _euler_xyz_to_quat(angles):
"""(roll, pitch, yaw) -> quaternion (x, y, z, w). Matches pymomentum.quaternion.euler_xyz_to_quaternion."""
roll, pitch, yaw = angles.unbind(-1)
cy, sy = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5)
cp, sp = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5)
cr, sr = torch.cos(roll * 0.5), torch.sin(roll * 0.5)
x = sr * cp * cy - cr * sp * sy
y = cr * sp * cy + sr * cp * sy
z = cr * cp * sy - sr * sp * cy
w = cr * cp * cy + sr * sp * sy
return torch.stack([x, y, z, w], dim=-1)
def _quat_multiply(q1, q2):
x1, y1, z1, w1 = q1.unbind(-1)
x2, y2, z2, w2 = q2.unbind(-1)
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
return torch.stack([x, y, z, w], dim=-1)
def _quat_rotate(q, v):
"""Rotate v by unit quaternion q (xyzw). v + 2 * (axis x v * w + axis x (axis x v))."""
axis = q[..., :3]
r = q[..., 3:4]
av = torch.cross(axis, v, dim=-1)
aav = torch.cross(axis, av, dim=-1)
return v + 2.0 * (av * r + aav)
def _skel_multiply(s1, s2):
    """Compose two skel states and return ``s1 ∘ s2`` (parent ∘ child).

    Skel states are (..., 8) tensors laid out as [translation(3),
    quaternion xyzw(4), uniform scale(1)]. The result applies ``s2`` first,
    then ``s1``:

        t = t1 + s1 * rotate(q1, t2),   q = q1 * q2,   s = s1 * s2

    As in pymomentum.skel_state.multiply, both quaternions are renormalized
    before composing. Quats that started unit-norm drift by a few ULPs over many
    FK levels, and the JIT renormalizes defensively. Doing the same keeps the
    results bit-close to the JIT's outputs.
    """
    parent_q = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12)
    child_q = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12)
    parent_scale = s1[..., 7:8]

    translation = s1[..., :3] + parent_scale * _quat_rotate(parent_q, s2[..., :3])
    rotation = _quat_multiply(parent_q, child_q)
    scale = parent_scale * s2[..., 7:8]
    return torch.cat([translation, rotation, scale], dim=-1)
def _skel_transform_points(skel_state, points):
    """Map ``points`` (..., 3) through ``skel_state`` (..., 8): t + q * (s * p).

    The quaternion slot of ``skel_state`` is used as-is, so it must already be
    unit-norm. Callers that cannot guarantee this should normalize it first.
    """
    scaled = skel_state[..., 7:8] * points
    rotated = _quat_rotate(skel_state[..., 3:7], scaled)
    return skel_state[..., :3] + rotated
def _global_skel_state_from_local(local, pmi_levels):
"""FK walk in fp64 (matches the JIT's use_double_precision=True path).
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
one per BFS level. Avoids per-call torch.split + tolist() sync.
"""
orig_dtype = local.dtype
g = local.to(torch.float64).clone()
for source, target in pmi_levels:
parent = g.index_select(-2, target)
child = g.index_select(-2, source)
g.index_copy_(-2, source, _skel_multiply(parent, child))
return g.to(orig_dtype)
class MHRRig(nn.Module):
    """Plain-PyTorch reimplementation of Meta's MHR rig.

    All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's
    use_double_precision=True backend) regardless of the host model's dtype.
    ``forward`` casts its inputs to ``base_shape.dtype``, which is fp32 unless
    the module itself has been cast.

    Two derived tensors are cached lazily: the per-level FK index pairs
    (``_pmi_levels``) and the coalesced sparse pose-corrective weight
    (``_sparse_w``). A cached entry is reused only while its source tensors
    keep the same device, dtype and storage pointer. Both caches are also
    dropped after every ``load_state_dict``, so weights copied in place are
    picked up as well.
    """

    NUM_VERTS = 18439
    NUM_JOINTS = 127
    NUM_LBS_TRIPLETS = 51337  # (vertex, joint, weight) skinning entries
    NUM_IDENTITY = 45
    NUM_EXPR = 72
    PARAM_TRANSFORM_IN = 249  # = model_parameters(204) + identity_coeffs(45)
    PARAM_TRANSFORM_OUT = 889  # = NUM_JOINTS * 7
    POSE_CORR_IN = 750  # = (NUM_JOINTS - 2) * 6
    POSE_CORR_HIDDEN = 3000
    POSE_CORR_SPARSE_NNZ = 53136

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        # `dtype` and `operations` are accepted but unused (presumably to match
        # the constructor signature of sibling modules). Every tensor below has
        # a fixed dtype.
        del dtype, operations
        f32 = torch.float32

        # All buffers are populated by load_state_dict from the `mhr.*` keys.
        # torch.empty only allocates placeholders of the right shape/dtype.
        def _p(*shape, dtype=f32):
            # Frozen parameter (requires_grad=False).
            return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)

        def _b(name, *shape, dtype):
            self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))

        self.base_shape = _p(self.NUM_VERTS, 3)
        self.identity_basis = _p(self.NUM_IDENTITY, self.NUM_VERTS, 3)
        self.expr_basis = _p(self.NUM_EXPR, self.NUM_VERTS, 3)
        self.param_transform = _p(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN)
        self.skel_joint_translation_offsets = _p(self.NUM_JOINTS, 3)
        self.skel_joint_prerotations = _p(self.NUM_JOINTS, 4)
        _b("skel_joint_parents", self.NUM_JOINTS, dtype=torch.int32)
        _b("skel_pmi", 2, 266, dtype=torch.int64)
        _b("skel_pmi_buffer_sizes", 4, dtype=torch.int64)
        self.lbs_inverse_bind_pose = _p(self.NUM_JOINTS, 8)
        self.lbs_skin_weights = _p(self.NUM_LBS_TRIPLETS)
        _b("lbs_skin_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int32)
        _b("lbs_vert_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int64)
        _b("pose_corr_sparse_indices", 2, self.POSE_CORR_SPARSE_NNZ, dtype=torch.int64)
        self.pose_corr_sparse_weight = _p(self.POSE_CORR_SPARSE_NNZ)
        self.register_buffer("pose_corr_sparse_shape", torch.tensor([self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device))
        self.pose_corr_weight = _p(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN)
        self.pose_corr_bias = None  # the dense corrective layer has no bias
        # Lazily built caches. Each holds (key, value), where key is the
        # _cache_key fingerprint of the tensors the value was derived from.
        self._pose_corr_sparse_cache = None
        self._pmi_levels_cache = None

    def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
        """Pose and skin the MHR mesh.

        Args:
            identity_coeffs: (B, NUM_IDENTITY) identity blend-shape coefficients.
            model_parameters: (B, PARAM_TRANSFORM_IN - NUM_IDENTITY), i.e. (B, 204),
                parameters fed through ``param_transform``.
            expr_coeffs: (B, NUM_EXPR) facial-expression coefficients.
            apply_correctives: if True, add the pose-dependent corrective
                offsets to the rest mesh before skinning.

        Returns:
            ``(verts, skel_state)``. ``verts`` holds the skinned vertices
            (B, NUM_VERTS, 3). ``skel_state`` holds the global joint states
            (B, NUM_JOINTS, 8), laid out as [translation(3), quaternion xyzw(4),
            scale(1)].
        """
        f32 = self.base_shape.dtype
        identity_coeffs = identity_coeffs.to(f32)
        model_parameters = model_parameters.to(f32)
        expr_coeffs = expr_coeffs.to(f32)
        B = identity_coeffs.shape[0]
        identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
        # The identity slots of the parameter transform are fed zeros. Within
        # this method, identity only enters through the blend shapes above.
        cat_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1)
        joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, cat_in)
        # Per joint: translation(3), Euler XYZ rotation(3), log2 scale(1).
        jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
        local_t = jp[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0)
        local_q = _euler_xyz_to_quat(jp[..., 3:6])
        local_q = _quat_multiply(self.skel_joint_prerotations.unsqueeze(0), local_q)
        local_s = torch.exp(jp[..., 6:7] * _LN2)  # 2 ** scale
        local_state = torch.cat([local_t, local_q, local_s], dim=-1)
        skel_state = _global_skel_state_from_local(local_state, self._pmi_levels())
        face_expr = torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs)
        unposed = identity_rest + face_expr
        if apply_correctives:
            unposed = unposed + self._pose_correctives(joint_parameters)
        verts = self._skin(skel_state, unposed)
        return verts, skel_state

    def _pose_correctives(self, joint_parameters):
        """Return pose-dependent vertex offsets of shape (B, NUM_VERTS, 3).

        Features are the 6D rotation encodings of joints [2:]. Components 0 and
        4 are shifted by -1, presumably so the identity rotation maps to a zero
        feature (this depends on batch6DFromXYZ's layout). The features then go
        through a sparse linear layer, a ReLU, and a dense linear layer.
        """
        B = joint_parameters.shape[0]
        jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
        # Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6).
        feat = batch6DFromXYZ(jp[:, 2:, 3:6], return_9D=False)  # (B, 125, 6)
        feat[..., 0] -= 1.0
        feat[..., 4] -= 1.0
        feat = feat.flatten(1, 2)  # (B, 750)
        h = (self._sparse_w() @ feat.T).T  # (B, 3000)
        h = F.relu(h)
        out = F.linear(h, self.pose_corr_weight, self.pose_corr_bias)  # (B, 55317)
        return out.view(B, self.NUM_VERTS, 3)

    @staticmethod
    def _cache_key(*tensors):
        """Return a cheap staleness fingerprint: (device, dtype, data_ptr) per tensor.

        ``data_ptr()`` reads host-side metadata only, so building the key needs
        no device sync.
        """
        return tuple((t.device, t.dtype, t.data_ptr()) for t in tensors)

    def _clear_caches(self):
        """Forget all lazily derived tensors; they are rebuilt on next use."""
        self._pose_corr_sparse_cache = None
        self._pmi_levels_cache = None

    def _load_from_state_dict(self, *args, **kwargs):
        """Load this module's tensors, then invalidate the derived caches.

        A plain ``load_state_dict`` copies values in place, which leaves the
        storage pointers (and therefore the cache keys) unchanged. Clearing
        the caches here makes reloaded weights take effect.
        """
        super()._load_from_state_dict(*args, **kwargs)
        self._clear_caches()

    def _pmi_levels(self):
        """Return the FK schedule as a list of (source_idx, target_idx) pairs.

        ``skel_pmi`` (2, N) is split column-wise by ``skel_pmi_buffer_sizes``.
        Row 0 of each chunk is the source (child) index and row 1 the target
        (parent) index, as consumed by ``_global_skel_state_from_local``. The
        ``tolist()`` host sync happens only when the cache is (re)built.
        """
        pmi = self.skel_pmi
        sizes_buf = self.skel_pmi_buffer_sizes
        key = self._cache_key(pmi, sizes_buf)
        cached = self._pmi_levels_cache
        if cached is not None and cached[0] == key:
            return cached[1]
        parts = torch.split(pmi, sizes_buf.tolist(), dim=1)
        levels = [(p[0], p[1]) for p in parts]
        self._pmi_levels_cache = (key, levels)
        return levels

    def _sparse_w(self):
        """Return the first pose-corrective layer as a coalesced sparse COO matrix.

        The matrix is built from ``pose_corr_sparse_indices`` and
        ``pose_corr_sparse_weight``. Its shape comes from
        ``pose_corr_sparse_shape``, which is initialised to
        (POSE_CORR_HIDDEN, POSE_CORR_IN).
        """
        w = self.pose_corr_sparse_weight
        idx = self.pose_corr_sparse_indices
        shape = self.pose_corr_sparse_shape
        key = self._cache_key(w, idx, shape)
        cached = self._pose_corr_sparse_cache
        if cached is not None and cached[0] == key:
            return cached[1]
        sparse = torch.sparse_coo_tensor(
            idx,
            w,
            tuple(shape.tolist()),
            check_invariants=False,
        ).coalesce()
        self._pose_corr_sparse_cache = (key, sparse)
        return sparse

    def _skin(self, skel_state, rest_verts):
        """Linear-blend-skin ``rest_verts`` (B, NUM_VERTS, 3) with ``skel_state``.

        Each joint's global state is composed with its inverse bind pose, and its
        quaternion is renormalized. Every (vertex, joint, weight) triplet then
        contributes ``weight * transformed vertex``. The contributions are summed
        per vertex with ``index_add_``.
        """
        B = skel_state.shape[0]
        ibp = self.lbs_inverse_bind_pose.unsqueeze(0).expand(B, self.NUM_JOINTS, 8)
        joint_xform = _skel_multiply(skel_state, ibp)
        norm_q = F.normalize(joint_xform[..., 3:7], p=2, dim=-1, eps=1e-12)
        joint_xform = torch.cat([joint_xform[..., :3], norm_q, joint_xform[..., 7:8]], dim=-1)
        sk_idx = self.lbs_skin_indices.long()
        v_idx = self.lbs_vert_indices
        w = self.lbs_skin_weights
        per_triplet_xform = joint_xform.index_select(-2, sk_idx)  # (B, 51337, 8)
        per_triplet_rest = rest_verts.index_select(-2, v_idx)  # (B, 51337, 3)
        contrib = _skel_transform_points(per_triplet_xform, per_triplet_rest) * w.unsqueeze(0).unsqueeze(-1)
        out = torch.zeros(B, self.NUM_VERTS, 3, dtype=rest_verts.dtype, device=rest_verts.device)
        out.index_add_(-2, v_idx, contrib)
        return out