# Source file: ComfyUI/comfy/ldm/sam3d_body/mhr/mhr_utils.py
# Snapshot: 2026-05-26 02:15:15 +03:00 (414 lines, 17 KiB, Python)

# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
# representation from Zhou et al., "On the Continuity of Rotation
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
# implementations from papagina/RotationContinuity:
# https://github.com/papagina/RotationContinuity/blob/758b0ce5/shapenet/code/tools.py
# The compact_cont_to_model_params_* functions are MHR-rig-specific glue.
import torch
import torch.nn.functional as F
def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the rotation angle separating paired SO(3) matrices.

    The relative rotation ``A @ B^T`` satisfies ``trace = 1 + 2 * cos(theta)``;
    the angle magnitude is recovered from that identity.

    Args:
        A: Tensor of shape (*, 3, 3), batch of rotation matrices.
        B: Tensor of shape (*, 3, 3), batch of rotation matrices.

    Returns:
        Tensor of shape (*,), angle differences in radians.
    """
    relative = A @ B.transpose(-2, -1)  # (*, 3, 3)
    diag_sum = relative[..., 0, 0] + relative[..., 1, 1] + relative[..., 2, 2]
    # Round-off can push the cosine slightly outside [-1, 1], where acos
    # would return NaN, so clamp before inverting.
    cosine = ((diag_sum - 1) / 2).clamp(-1.0, 1.0)
    return cosine.acos()
def fix_wrist_euler(
    wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
):
    """Pick, per wrist, the Euler triple that violates the joint limits least.

    For each (X, Z, Y) triple an alternative triple
    ``(X + pi, -(Z + pi), Y + pi)`` is formed, with every angle wrapped into
    [-pi, pi]. Presumably this is the second Euler solution for the same
    rotation -- confirm against the rig convention. The alternative replaces
    the input wherever its summed squared limit violation is strictly smaller.

    Args:
        wrist_xzy: Tensor whose last axis holds (X, Z, Y) angles in radians;
            the original documentation describes it as B x 2 x 3.
        limits_x: (min, max) allowed range for X.
        limits_z: (min, max) allowed range for Z.
        limits_y: (min, max) allowed range for Y.

    Returns:
        Tensor shaped like ``wrist_xzy`` holding the selected angles.
    """

    def wrap(angle):
        # atan2(sin, cos) maps any angle to its representative in [-pi, pi].
        return torch.atan2(torch.sin(angle), torch.cos(angle))

    def excess_sq(value, bounds):
        # Squared distance by which `value` falls outside [bounds[0], bounds[1]].
        under = torch.clamp(bounds[0] - value, min=0.0)
        over = torch.clamp(value - bounds[1], min=0.0)
        return under**2 + over**2

    ang_x = wrist_xzy[..., 0]
    ang_z = wrist_xzy[..., 1]
    ang_y = wrist_xzy[..., 2]

    flipped_x = wrap(ang_x + torch.pi)
    flipped_z = wrap(-(ang_z + torch.pi))
    flipped_y = wrap(ang_y + torch.pi)

    cost_input = (
        excess_sq(ang_x, limits_x)
        + excess_sq(ang_z, limits_z)
        + excess_sq(ang_y, limits_y)
    )
    cost_flipped = (
        excess_sq(flipped_x, limits_x)
        + excess_sq(flipped_z, limits_z)
        + excess_sq(flipped_y, limits_y)
    )

    # Ties keep the input triple: the alternative must be strictly better.
    prefer_flipped = cost_flipped < cost_input
    flipped = torch.stack([flipped_x, flipped_z, flipped_y], dim=-1)
    return torch.where(prefer_flipped[..., None], flipped, wrist_xzy)
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py
def batch6DFromXYZ(r, return_9D=False):
    """Convert XYZ Euler angles to the 6D continuous rotation representation.

    The rotation matrix assembled here equals ``Rz(z) @ Ry(y) @ Rx(x)``. The
    6D code is the first matrix column followed by the second column.

    Args:
        r: ... x 3 tensor of Euler angles (x, y, z) in radians.
        return_9D: when True, return the full rotation matrix instead of the
            6D code.

    Returns:
        ... x 6 tensor (first column, then second column), or a ... x 3 x 3
        rotation matrix when ``return_9D`` is True. dtype and device follow
        ``r``.
    """
    cos = torch.cos(r)
    sin = torch.sin(r)
    cx, cy, cz = cos[..., 0], cos[..., 1], cos[..., 2]
    sx, sy, sz = sin[..., 0], sin[..., 1], sin[..., 2]

    # Assemble the matrix row by row from tensors that already live on r's
    # device, instead of filling an uninitialised host buffer and copying it.
    row0 = torch.stack(
        [cy * cz, -cx * sz + sx * sy * cz, sx * sz + cx * sy * cz], dim=-1
    )
    row1 = torch.stack(
        [cy * sz, cx * cz + sx * sy * sz, -sx * cz + cx * sy * sz], dim=-1
    )
    row2 = torch.stack([-sy, sx * cy, cx * cy], dim=-1)
    matrix = torch.stack([row0, row1, row2], dim=-2)  # ... x 3 x 3

    if return_9D:
        return matrix
    return torch.cat([matrix[..., :, 0], matrix[..., :, 1]], dim=-1)
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
def batchXYZfrom6D(poses):
    """Recover XYZ Euler angles from the 6D continuous rotation representation.

    The 6D code is first orthonormalised into a rotation matrix (normalise the
    first column, derive the third by a cross product, then re-derive the
    second). Euler angles are then read from the matrix. Near gimbal lock,
    where ``sqrt(R00^2 + R10^2) < 1e-6``, the z angle is fixed to zero and x
    is taken from the middle row.

    Args:
        poses: ... x 6, where "6" is the combined first and second columns.

    Returns:
        ... x 3 tensor of (x, y, z) Euler angles in radians, with the dtype
        and device of ``poses``.
    """
    # First, get the rotation matrix.
    col_x = F.normalize(poses[..., :3], dim=-1)
    col_z = F.normalize(torch.cross(col_x, poses[..., 3:], dim=-1), dim=-1)
    col_y = torch.cross(col_z, col_x, dim=-1)
    rot = torch.stack([col_x, col_y, col_z], dim=-1)  # ... x 3 x 3

    # Now get it into euler
    # https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
    r00 = rot[..., 0, 0]
    r10 = rot[..., 1, 0]
    cos_y = torch.sqrt(r00 * r00 + r10 * r10)
    gimbal_lock = cos_y < 1e-6

    # Select per element with a boolean mask: no float32 mask promotion and no
    # arithmetic mixing of the two branches.
    angle_x = torch.where(
        gimbal_lock,
        torch.atan2(-rot[..., 1, 2], rot[..., 1, 1]),
        torch.atan2(rot[..., 2, 1], rot[..., 2, 2]),
    )
    # The y angle has the same formula in both regimes.
    angle_y = torch.atan2(-rot[..., 2, 0], cos_y)
    angle_z = torch.where(gimbal_lock, torch.zeros_like(r10), torch.atan2(r10, r00))
    return torch.stack([angle_x, angle_y, angle_z], dim=-1)
# Degrees of freedom of each hand joint, in joint order (sums to 27).
_HAND_DOFS = [3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]
_HAND_MASK_CACHE: dict = {}  # device -> dict of masks


def _hand_masks(device):
    """Return the boolean hand-layout masks for ``device``, building them once.

    Each joint with ``k`` DoFs occupies ``k`` slots in the 27-dim model-param
    vector and ``2 * k`` slots in the 54-dim continuous vector. A
    ``*_threedofs`` mask is True on the slots of joints with exactly 3 DoFs;
    a ``*_onedofs`` mask is True on the slots of joints with 1 or 2 DoFs.

    Args:
        device: device the mask tensors are placed on; also the cache key.

    Returns:
        dict with keys ``mask_cont_threedofs``, ``mask_cont_onedofs``,
        ``mask_model_params_threedofs`` and ``mask_model_params_onedofs``.
        The same dict object is returned on later calls for the same device.
    """
    if device in _HAND_MASK_CACHE:
        return _HAND_MASK_CACHE[device]

    def slot_mask(selected_dofs, slots_per_dof):
        # One flag per slot: each joint contributes slots_per_dof * k entries.
        flags = [
            k in selected_dofs
            for k in _HAND_DOFS
            for _ in range(slots_per_dof * k)
        ]
        return torch.tensor(flags, dtype=torch.bool, device=device)

    masks = {
        "mask_cont_threedofs": slot_mask((3,), 2),
        "mask_cont_onedofs": slot_mask((1, 2), 2),
        "mask_model_params_threedofs": slot_mask((3,), 1),
        "mask_model_params_onedofs": slot_mask((1, 2), 1),
    }
    _HAND_MASK_CACHE[device] = masks
    return masks
def compact_cont_to_model_params_hand(hand_cont):
    """Decode a 54-dim continuous hand pose into 27 Euler-angle model params.

    Joints with 3 DoFs are stored as 6D rotation codes and decoded with
    ``batchXYZfrom6D``; every remaining DoF is stored as a (sin, cos) pair and
    decoded with ``atan2``.

    Args:
        hand_cont: (..., 54) tensor, ordered by joint.

    Returns:
        (..., 27) tensor with the dtype and device of ``hand_cont``, ordered by
        joint, then XYZ.
    """
    # These are ordered by joint, not model params ^^
    assert hand_cont.shape[-1] == 54
    masks = _hand_masks(hand_cont.device)

    # 3-DoF joints: gather their 6D codes and turn each into three angles.
    six_d = hand_cont[..., masks["mask_cont_threedofs"]].unflatten(-1, (-1, 6))
    euler_3dof = batchXYZfrom6D(six_d).flatten(-2, -1)

    # Remaining DoFs: each is a (sin, cos) pair.
    sin_cos = hand_cont[..., masks["mask_cont_onedofs"]].unflatten(-1, (-1, 2))
    angle_1dof = torch.atan2(sin_cos[..., 0], sin_cos[..., 1])

    # Scatter both groups back into joint order.
    model_params = torch.zeros(
        *hand_cont.shape[:-1], 27, dtype=hand_cont.dtype, device=hand_cont.device
    )
    model_params[..., masks["mask_model_params_threedofs"]] = euler_3dof
    model_params[..., masks["mask_model_params_onedofs"]] = angle_1dof
    return model_params
def compact_model_params_to_cont_hand(hand_model_params):
    """Encode 27 Euler-angle hand model params as the 54-dim continuous pose.

    Inverse layout of ``compact_cont_to_model_params_hand``: joints with
    3 DoFs become 6D rotation codes via ``batch6DFromXYZ``, and every
    remaining DoF becomes a (sin, cos) pair.

    The slot masks come from the per-device ``_hand_masks`` cache, so they are
    built once and live on the same device as the input instead of being
    rebuilt on the CPU on every call.

    Args:
        hand_model_params: (..., 27) tensor, ordered by joint, then XYZ.

    Returns:
        (..., 54) tensor with the dtype and device of ``hand_model_params``,
        ordered by joint.
    """
    # These are ordered by joint, not model params ^^
    assert hand_model_params.shape[-1] == 27
    masks = _hand_masks(hand_model_params.device)

    # Convert eulers to hand_cont
    ## First for 3DoFs
    euler_3dof = hand_model_params[
        ..., masks["mask_model_params_threedofs"]
    ].unflatten(-1, (-1, 3))
    cont_3dof = batch6DFromXYZ(euler_3dof).flatten(-2, -1)

    ## Next for 1DoFs (including the 2DoF joint, one pair per angle)
    angle_1dof = hand_model_params[..., masks["mask_model_params_onedofs"]]
    cont_1dof = torch.stack([angle_1dof.sin(), angle_1dof.cos()], dim=-1).flatten(
        -2, -1
    )

    # Finally, assemble into a 54-dim vector, ordered by joint.
    hand_cont = torch.zeros(
        *hand_model_params.shape[:-1],
        54,
        dtype=hand_model_params.dtype,
        device=hand_model_params.device,
    )
    hand_cont[..., masks["mask_cont_threedofs"]] = cont_3dof
    hand_cont[..., masks["mask_cont_onedofs"]] = cont_1dof
    return hand_cont
def batch9Dfrom6D(poses):
    """Expand 6D rotation codes into flattened 3x3 rotation matrices.

    The first three entries are normalised into the first matrix column; the
    third column is the normalised cross product of that column with the last
    three entries; the second column is the cross product of the third and
    first columns.

    Args:
        poses: ... x 6, where "6" is the combined first and second columns.

    Returns:
        ... x 9 tensor: the 3x3 matrix flattened row by row.
    """
    first_col = F.normalize(poses[..., :3], dim=-1)
    third_col = F.normalize(
        torch.cross(first_col, poses[..., 3:], dim=-1), dim=-1
    )
    second_col = torch.cross(third_col, first_col, dim=-1)
    rotmat = torch.stack([first_col, second_col, third_col], dim=-1)  # ... x 3 x 3
    return rotmat.flatten(-2, -1)  # ... x 9
def batch4Dfrom2D(poses):
    """Turn (sin, cos) pairs into flattened 2x2 rotation blocks.

    Each pair is normalised to unit length and laid out as
    ``[cos, sin, -sin, cos]``.

    Args:
        poses: ... x 2, where "2" is sincos.

    Returns:
        ... x 4 tensor (flattened SO2).
    """
    unit = F.normalize(poses, dim=-1)
    sin_t = unit[..., 0]
    cos_t = unit[..., 1]
    return torch.stack((cos_t, sin_t, -sin_t, cos_t), dim=-1)  # ... x 4
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
    """Convert the continuous body pose into flattened rotation-matrix params.

    The input is laid out as three consecutive segments: 6D codes for the
    3-DoF joints, (sin, cos) pairs for the 1-DoF rotations, then the
    translation params. The 6D codes become flattened 3x3 matrices
    (``batch9Dfrom6D``), the pairs become flattened 2x2 blocks
    (``batch4Dfrom2D``), and the translations pass through unchanged.

    Segment sizes are taken from the cached ``_body_idxs`` tables rather than
    rebuilding index tensors on every call.

    Args:
        body_pose_cont: (..., C) tensor where
            ``C = 2 * 3 * n_3dof_joints + 2 * n_1dof_rots + n_trans``.
        inflate_trans: reserved; not implemented.

    Returns:
        (..., 9 * n_3dof_joints + 4 * n_1dof_rots + n_trans) tensor.

    Raises:
        AssertionError: if the last dimension has the wrong size, or if
            ``inflate_trans`` is True.
    """
    rot3_idxs, rot1_idxs, trans_idxs, _ = _body_idxs(body_pose_cont.device)
    num_3dof_angles = len(rot3_idxs) * 3
    num_1dof_angles = len(rot1_idxs)
    num_1dof_trans = len(trans_idxs)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    if inflate_trans:
        # Raised explicitly (not `assert False`) so the guard survives
        # `python -O`; the exception type and message are unchanged.
        raise AssertionError(
            "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
        )

    # Get subsets
    end_3dofs = 2 * num_3dof_angles
    end_1dofs = end_3dofs + 2 * num_1dof_angles
    body_cont_3dofs = body_pose_cont[..., :end_3dofs]
    body_cont_1dofs = body_pose_cont[..., end_3dofs:end_1dofs]
    body_cont_trans = body_pose_cont[..., end_1dofs:]

    # Convert conts to rotation-matrix params
    ## First for 3dofs
    body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs.unflatten(-1, (-1, 6))).flatten(
        -2, -1
    )
    ## Next for 1dofs (sincos)
    body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs.unflatten(-1, (-1, 2))).flatten(
        -2, -1
    )
    ## Nothing to do for trans
    return torch.cat([body_rotmat_3dofs, body_rotmat_1dofs, body_cont_trans], dim=-1)
_BODY_IDX_CACHE: dict = {}  # device -> tuple of index tensors


def _body_idxs(device):
    """Return the body-parameter index tables for ``device``, built once.

    Args:
        device: device the index tensors are placed on; also the cache key.

    Returns:
        Tuple ``(idxs_3dof, idxs_1dof_rot, idxs_1dof_trans, idxs_3dof_flat)``
        of int64 tensors: a (23, 3) table with one row of three parameter
        indices per 3-DoF joint, the 58 indices of the 1-DoF rotations, the
        6 indices of the translation params, and the first table flattened.
        The same tuple object is returned on later calls for the same device.
    """
    if device in _BODY_IDX_CACHE:
        return _BODY_IDX_CACHE[device]

    def as_index(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    # fmt: off
    idxs_3dof = as_index([
        (0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20),
        (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39),
        (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79),
        (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106),
        (114, 98, 109), (115, 99, 103), (130, 131, 132),
    ])
    idxs_1dof_rot = as_index([
        1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52,
        56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84,
        89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111,
        116, 117, 118, 119, 120, 121, 122, 123,
    ])
    idxs_1dof_trans = as_index([124, 125, 126, 127, 128, 129])
    # fmt: on

    tables = (idxs_3dof, idxs_1dof_rot, idxs_1dof_trans, idxs_3dof.flatten())
    _BODY_IDX_CACHE[device] = tables
    return tables
def compact_cont_to_model_params_body(body_pose_cont):
    """Decode the continuous body pose into the 133-dim model-param vector.

    The input holds three consecutive segments: 6D codes for the 3-DoF joints,
    (sin, cos) pairs for the 1-DoF rotations, then the translation params.
    The 6D codes are decoded to Euler angles with ``batchXYZfrom6D``, the
    pairs with ``atan2``, and the translations are copied as-is. Each value is
    then scattered to its parameter index from ``_body_idxs``.

    Args:
        body_pose_cont: (..., C) tensor where
            ``C = 2 * 3 * n_3dof_joints + 2 * n_1dof_rots + n_trans``.

    Returns:
        (..., 133) tensor with the dtype and device of ``body_pose_cont``.
    """
    rot3_idxs, rot1_idxs, trans_idxs, rot3_flat_idxs = _body_idxs(
        body_pose_cont.device
    )
    width_3dofs = 2 * (len(rot3_idxs) * 3)
    width_1dofs = 2 * len(rot1_idxs)
    width_trans = len(trans_idxs)
    assert body_pose_cont.shape[-1] == width_3dofs + width_1dofs + width_trans

    # Cut the input into its three segments.
    cont_3dofs, cont_1dofs, cont_trans = body_pose_cont.split(
        [width_3dofs, width_1dofs, width_trans], dim=-1
    )

    # 3-DoF joints: one 6D code -> three Euler angles.
    euler_3dofs = batchXYZfrom6D(cont_3dofs.unflatten(-1, (-1, 6))).flatten(-2, -1)
    # 1-DoF rotations: (sin, cos) -> angle.
    sin_cos = cont_1dofs.unflatten(-1, (-1, 2))
    angle_1dofs = torch.atan2(sin_cos[..., 0], sin_cos[..., 1])

    # Scatter every group to its slot in the model-param vector.
    params = body_pose_cont.new_zeros((*body_pose_cont.shape[:-1], 133))
    params[..., rot3_flat_idxs] = euler_3dofs
    params[..., rot1_idxs] = angle_1dofs
    params[..., trans_idxs] = cont_trans
    return params
def compact_model_params_to_cont_body(body_pose_params):
    """Encode the body model-param vector as the continuous body pose.

    Inverse layout of ``compact_cont_to_model_params_body``: Euler triples of
    the 3-DoF joints become 6D codes via ``batch6DFromXYZ``, 1-DoF rotation
    angles become (sin, cos) pairs, and translation params are copied as-is.
    The three groups are concatenated in that order.

    The index tables come from the per-device ``_body_idxs`` cache, so they
    are built once and live on the same device as the input instead of being
    rebuilt on the CPU on every call.

    Args:
        body_pose_params: (..., P) tensor where
            ``P = 3 * n_3dof_joints + n_1dof_rots + n_trans``.

    Returns:
        (..., 2 * 3 * n_3dof_joints + 2 * n_1dof_rots + n_trans) tensor.
    """
    rot3_idxs, rot1_idxs, trans_idxs, rot3_flat_idxs = _body_idxs(
        body_pose_params.device
    )
    num_3dof_angles = len(rot3_idxs) * 3
    num_1dof_angles = len(rot1_idxs)
    num_1dof_trans = len(trans_idxs)
    assert body_pose_params.shape[-1] == (
        num_3dof_angles + num_1dof_angles + num_1dof_trans
    )

    # Take out params
    body_params_3dofs = body_pose_params[..., rot3_flat_idxs]
    body_params_1dofs = body_pose_params[..., rot1_idxs]
    body_params_trans = body_pose_params[..., trans_idxs]

    # params to cont
    body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
        -2, -1
    )
    body_cont_1dofs = torch.stack(
        [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
    ).flatten(-2, -1)

    # Put them together (translations need no conversion)
    return torch.cat([body_cont_3dofs, body_cont_1dofs, body_params_trans], dim=-1)
# Index lists and boolean masks over the 133-dim model-param vector and the
# 260-dim continuous vector. Presumably they select the hand-related entries
# -- confirm against the rig definition.
# Model-param indices: the consecutive run 62..115 (54 entries).
mhr_param_hand_idxs = list(range(62, 116))
# Continuous-vector indices: the runs 72..131 and 190..237 (108 entries).
mhr_cont_hand_idxs = [*range(72, 132), *range(190, 238)]

mhr_param_hand_mask = torch.zeros(133, dtype=torch.bool)
mhr_param_hand_mask[mhr_param_hand_idxs] = True

mhr_cont_hand_mask = torch.zeros(260, dtype=torch.bool)
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True