mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 01:09:24 +08:00
236 lines
9.6 KiB
Python
236 lines
9.6 KiB
Python
# Adapted from facebookresearch/MHR (Apache 2.0):
|
|
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
|
|
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
|
|
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
|
|
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
|
|
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# Explicitly opt out of PyTorch's sparse-tensor invariant checking when this
# module is imported, instead of leaving the setting at its implicit default.
# NOTE(review): this looks like a process-wide switch rather than one scoped
# to this module — confirm that other sparse-tensor users are fine with it.
torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning.
|
|
|
|
from .mhr_utils import batch6DFromXYZ
|
|
|
|
# ln(2) rounded to float32 precision. Used as exp(x * _LN2), i.e. 2 ** x, to
# turn a per-joint log2-scale parameter into a scale factor.
_LN2 = 0.6931471824645996
|
|
|
|
def _euler_xyz_to_quat(angles):
    """Convert XYZ Euler angles to quaternions.

    `angles` holds (roll, pitch, yaw) along its last axis. The result keeps
    the leading shape and carries (x, y, z, w) along its last axis.
    Matches pymomentum.quaternion.euler_xyz_to_quaternion.
    """
    # Split into components first, then halve each one.
    half_roll, half_pitch, half_yaw = (a * 0.5 for a in angles.unbind(-1))

    cos_y, sin_y = torch.cos(half_yaw), torch.sin(half_yaw)
    cos_p, sin_p = torch.cos(half_pitch), torch.sin(half_pitch)
    cos_r, sin_r = torch.cos(half_roll), torch.sin(half_roll)

    components = [
        sin_r * cos_p * cos_y - cos_r * sin_p * sin_y,  # x
        cos_r * sin_p * cos_y + sin_r * cos_p * sin_y,  # y
        cos_r * cos_p * sin_y - sin_r * sin_p * cos_y,  # z
        cos_r * cos_p * cos_y + sin_r * sin_p * sin_y,  # w
    ]
    return torch.stack(components, dim=-1)
|
|
|
|
|
|
def _quat_multiply(q1, q2):
    """Hamilton product q1 * q2 of quaternions stored as (x, y, z, w).

    Both inputs carry the quaternion along their last axis; the output uses
    the same (x, y, z, w) layout.
    """
    ax, ay, az, aw = q1.unbind(-1)
    bx, by, bz, bw = q2.unbind(-1)
    return torch.stack(
        [
            aw * bx + ax * bw + ay * bz - az * by,  # x
            aw * by - ax * bz + ay * bw + az * bx,  # y
            aw * bz + ax * by - ay * bx + az * bw,  # z
            aw * bw - ax * bx - ay * by - az * bz,  # w
        ],
        dim=-1,
    )
|
|
|
|
|
|
def _quat_rotate(q, v):
    """Rotate vectors `v` (..., 3) by unit quaternions `q` (..., 4), xyzw layout.

    Evaluates v + 2 * (w * (a x v) + a x (a x v)), where `a` is the vector
    part of `q` and `w` its scalar part.
    """
    vec_part = q[..., :3]
    scalar_part = q[..., 3:4]  # kept as a trailing size-1 axis so it broadcasts
    first_cross = torch.cross(vec_part, v, dim=-1)
    second_cross = torch.cross(vec_part, first_cross, dim=-1)
    return v + 2.0 * (first_cross * scalar_part + second_cross)
|
|
|
|
|
|
def _skel_multiply(s1, s2):
    """Compose two skel states of shape (..., 8); the result is parent ∘ child.

    A skel state is laid out as translation [0:3], quaternion xyzw [3:7] and
    scale [7:8]. `s1` is the parent and `s2` the child.

    Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
    before they are composed. Across many FK levels, quaternions that were
    normalized earlier drift by a few ULPs; the JIT renormalizes defensively,
    and doing the same here keeps the outputs bit-close to it.
    """
    parent_t, parent_scale = s1[..., :3], s1[..., 7:8]
    child_t, child_scale = s2[..., :3], s2[..., 7:8]

    parent_q = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12)
    child_q = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12)

    # The child translation is rotated and scaled into the parent frame.
    translation = parent_t + parent_scale * _quat_rotate(parent_q, child_t)
    rotation = _quat_multiply(parent_q, child_q)
    scale = parent_scale * child_scale
    return torch.cat([translation, rotation, scale], dim=-1)
|
|
|
|
|
|
def _skel_transform_points(skel_state, points):
    """Transform `points` (..., 3) by `skel_state` (..., 8): t + q * (s * points).

    The quaternion stored in `skel_state` is used as-is and is expected to be
    unit-norm already. Callers that cannot guarantee that should normalize it
    before calling.
    """
    translation = skel_state[..., :3]
    rotation = skel_state[..., 3:7]
    scale = skel_state[..., 7:8]
    scaled_points = scale * points
    return translation + _quat_rotate(rotation, scaled_points)
|
|
|
|
|
|
def _global_skel_state_from_local(local, pmi_levels):
    """FK walk in fp64 (matches the JIT's use_double_precision=True path).

    Args:
        local: local skel states; joints are indexed along the second-to-last
            axis. The tensor is never modified.
        pmi_levels: precomputed list of (source_idx, target_idx) index-tensor
            pairs, one per BFS level. Precomputing them avoids a per-call
            torch.split + tolist() sync.

    Returns:
        Global skel states with the same shape and dtype as `local`.
    """
    orig_dtype = local.dtype
    # copy=True always yields a fresh fp64 tensor in a single pass, so the
    # in-place writes below can never alias the caller's tensor, even when
    # `local` is already fp64.
    g = local.to(torch.float64, copy=True)
    for source, target in pmi_levels:
        parent = g.index_select(-2, target)
        child = g.index_select(-2, source)
        # Overwrite each `source` joint with parent ∘ child. Levels are
        # applied in list order, so parents read here already hold the
        # results of earlier levels.
        g.index_copy_(-2, source, _skel_multiply(parent, child))
    return g.to(orig_dtype)
|
|
|
|
|
|
class MHRRig(nn.Module):
    """Plain-PyTorch reimplementation of Meta's MHR rig.

    All math runs in fp32 (FK is upcast to fp64 internally, matching the JIT's
    use_double_precision=True backend) regardless of the host model's dtype.
    """

    NUM_VERTS = 18439
    NUM_JOINTS = 127
    NUM_LBS_TRIPLETS = 51337
    NUM_IDENTITY = 45
    NUM_EXPR = 72
    PARAM_TRANSFORM_IN = 249  # = model_parameters(204) + identity_coeffs(45)
    PARAM_TRANSFORM_OUT = 889  # = NUM_JOINTS * 7
    POSE_CORR_IN = 750  # = (NUM_JOINTS - 2) * 6
    POSE_CORR_HIDDEN = 3000
    POSE_CORR_SPARSE_NNZ = 53136

    def __init__(self, device=None, dtype=None, operations=None):
        """Allocate uninitialized, frozen storage for every rig tensor.

        `dtype` and `operations` are accepted but ignored. Every tensor is
        created with torch.empty and is expected to be filled afterwards by
        load_state_dict from the `mhr.*` keys.
        """
        super().__init__()
        del dtype, operations

        def frozen(*shape):
            # fp32 parameter that never receives gradients.
            storage = torch.empty(*shape, dtype=torch.float32, device=device)
            return nn.Parameter(storage, requires_grad=False)

        def add_buffer(name, buf_dtype, *shape):
            self.register_buffer(name, torch.empty(*shape, dtype=buf_dtype, device=device))

        # Shape model: base mesh plus identity and expression blend shapes.
        self.base_shape = frozen(self.NUM_VERTS, 3)
        self.identity_basis = frozen(self.NUM_IDENTITY, self.NUM_VERTS, 3)
        self.expr_basis = frozen(self.NUM_EXPR, self.NUM_VERTS, 3)
        self.param_transform = frozen(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN)

        # Skeleton.
        self.skel_joint_translation_offsets = frozen(self.NUM_JOINTS, 3)
        self.skel_joint_prerotations = frozen(self.NUM_JOINTS, 4)
        add_buffer("skel_joint_parents", torch.int32, self.NUM_JOINTS)
        add_buffer("skel_pmi", torch.int64, 2, 266)
        add_buffer("skel_pmi_buffer_sizes", torch.int64, 4)

        # Linear blend skinning.
        self.lbs_inverse_bind_pose = frozen(self.NUM_JOINTS, 8)
        self.lbs_skin_weights = frozen(self.NUM_LBS_TRIPLETS)
        add_buffer("lbs_skin_indices", torch.int32, self.NUM_LBS_TRIPLETS)
        add_buffer("lbs_vert_indices", torch.int64, self.NUM_LBS_TRIPLETS)

        # Pose correctives: sparse first layer followed by a dense second layer.
        add_buffer("pose_corr_sparse_indices", torch.int64, 2, self.POSE_CORR_SPARSE_NNZ)
        self.pose_corr_sparse_weight = frozen(self.POSE_CORR_SPARSE_NNZ)

        sparse_shape = torch.tensor(
            [self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device
        )
        self.register_buffer("pose_corr_sparse_shape", sparse_shape)
        self.pose_corr_weight = frozen(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN)
        self.pose_corr_bias = None

        # Lazily built by _sparse_w() / _pmi_levels().
        self._pose_corr_sparse_cache = None
        self._pmi_levels_cache = None

    def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
        """Pose the rig and return `(verts, skel_state)`.

        Inputs are batched along axis 0 and are cast to the dtype of
        `base_shape` before use. `verts` is (B, NUM_VERTS, 3) and
        `skel_state` is (B, NUM_JOINTS, 8). When `apply_correctives` is
        false, the pose-corrective offsets are not added.
        """
        work_dtype = self.base_shape.dtype
        identity_coeffs = identity_coeffs.to(work_dtype)
        model_parameters = model_parameters.to(work_dtype)
        expr_coeffs = expr_coeffs.to(work_dtype)
        batch = identity_coeffs.shape[0]

        # Joint parameters: the identity slots of the transform input are zero-filled.
        transform_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1)
        joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, transform_in)

        # Local skel state per joint: translation(3) | quaternion(4) | scale(1).
        per_joint = joint_parameters.view(batch, self.NUM_JOINTS, 7)
        translation = per_joint[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0)
        rotation = _quat_multiply(
            self.skel_joint_prerotations.unsqueeze(0),
            _euler_xyz_to_quat(per_joint[..., 3:6]),
        )
        scale = torch.exp(per_joint[..., 6:7] * _LN2)
        local_state = torch.cat([translation, rotation, scale], dim=-1)

        skel_state = _global_skel_state_from_local(local_state, self._pmi_levels())

        # Rest-pose mesh: base + identity blend shapes + expression blend shapes.
        identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
        unposed = identity_rest + torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs)
        if apply_correctives:
            unposed = unposed + self._pose_correctives(joint_parameters)

        return self._skin(skel_state, unposed), skel_state

    def _pose_correctives(self, joint_parameters):
        """Per-vertex corrective offsets (B, NUM_VERTS, 3) from joint parameters."""
        batch = joint_parameters.shape[0]
        per_joint = joint_parameters.view(batch, self.NUM_JOINTS, 7)

        # Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6).
        features = batch6DFromXYZ(per_joint[:, 2:, 3:6], return_9D=False)  # (B, 125, 6)
        features[..., 0].sub_(1.0)
        features[..., 4].sub_(1.0)
        features = features.flatten(1, 2)  # (B, 750)

        hidden = F.relu((self._sparse_w() @ features.T).T)  # (B, 3000)
        offsets = F.linear(hidden, self.pose_corr_weight, self.pose_corr_bias)  # (B, 55317)
        return offsets.view(batch, self.NUM_VERTS, 3)

    def _pmi_levels(self):
        """Return the cached per-level (source_idx, target_idx) pairs.

        The list is rebuilt when nothing is cached yet or when the cached
        tensors live on a different device than `skel_pmi`.
        """
        pmi = self.skel_pmi
        levels = self._pmi_levels_cache
        if levels is None or levels[0][0].device != pmi.device:
            level_sizes = self.skel_pmi_buffer_sizes.tolist()
            # Each chunk is (2, level_size); row 0 is the source, row 1 the target.
            levels = [chunk.unbind(0) for chunk in torch.split(pmi, level_sizes, dim=1)]
            self._pmi_levels_cache = levels
        return levels

    def _sparse_w(self):
        """Return the cached, coalesced sparse first-layer corrective matrix.

        The matrix is rebuilt when nothing is cached yet or when the cached
        one no longer matches the device or dtype of `pose_corr_sparse_weight`.
        """
        values = self.pose_corr_sparse_weight
        matrix = self._pose_corr_sparse_cache
        stale = (
            matrix is None
            or matrix.device != values.device
            or matrix.dtype != values.dtype
        )
        if stale:
            dense_shape = tuple(self.pose_corr_sparse_shape.tolist())
            matrix = torch.sparse_coo_tensor(
                self.pose_corr_sparse_indices,
                values,
                dense_shape,
                check_invariants=False,
            ).coalesce()
            self._pose_corr_sparse_cache = matrix
        return matrix

    def _skin(self, skel_state, rest_verts):
        """Linear blend skinning of `rest_verts` by the global `skel_state`.

        Each (joint, vertex, weight) triplet contributes the vertex
        transformed by that joint's skinning transform, scaled by the weight.
        Contributions are summed per vertex into a (B, NUM_VERTS, 3) tensor.
        """
        batch = skel_state.shape[0]
        inv_bind = self.lbs_inverse_bind_pose.unsqueeze(0).expand(batch, self.NUM_JOINTS, 8)
        posed = _skel_multiply(skel_state, inv_bind)

        # _skel_transform_points expects a unit quaternion, so renormalize here.
        unit_q = F.normalize(posed[..., 3:7], p=2, dim=-1, eps=1e-12)
        posed = torch.cat([posed[..., :3], unit_q, posed[..., 7:8]], dim=-1)

        joint_of_triplet = self.lbs_skin_indices.long()
        vert_of_triplet = self.lbs_vert_indices
        weights = self.lbs_skin_weights[None, :, None]

        triplet_xforms = posed.index_select(-2, joint_of_triplet)  # (B, 51337, 8)
        triplet_rest = rest_verts.index_select(-2, vert_of_triplet)  # (B, 51337, 3)
        weighted = _skel_transform_points(triplet_xforms, triplet_rest) * weights

        skinned = rest_verts.new_zeros(batch, self.NUM_VERTS, 3)
        skinned.index_add_(-2, vert_of_triplet, weighted)
        return skinned
|