Big cleanup

This commit is contained in:
kijai 2026-06-16 20:47:15 +03:00
parent f1be65f914
commit ecbaefd8fc
13 changed files with 376 additions and 877 deletions

View File

@ -4,25 +4,15 @@ import torch
import torch.nn as nn
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask
from ..model.transformer import MLP
class MHRHead(nn.Module):
def __init__(
self,
input_dim: int,
mhr_rig,
mlp_depth: int = 1,
extra_joint_regressor: str = "",
mlp_channel_div_factor: int = 8,
enable_hand_model=False,
device=None,
dtype=None,
operations=None,
):
def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False,
device=None, dtype=None, operations=None):
super().__init__()
# Store the shared MHRRig as a non-registered Python attribute
object.__setattr__(self, "mhr", mhr_rig)
@ -48,9 +38,7 @@ class MHRHead(nn.Module):
hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.npose,
num_layers=mlp_depth,
device=device,
dtype=dtype,
operations=operations,
device=device, dtype=dtype, operations=operations,
)
# MHR Parameters
@ -75,28 +63,25 @@ class MHRHead(nn.Module):
self.local_to_world_wrist = _p(3, 3)
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
# Optional — loaded from the .safetensors if present, otherwise the
# render path falls back to a coarse geometric approximation.
self.register_buffer(
"face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
)
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
def canonical_vertices(self, device=None):
def canonical_vertices(self):
"""Return the T-pose vertices for the mean shape (scaled to meters).
Runs MHR with zero pose / shape / scale / expression so the returned
mesh is the canonical rest pose fixed per-model
"""
dev = device or self.scale_mean.device
dt = self.scale_mean.dtype
device = self.scale_mean.device
dtype = self.scale_mean.dtype
B = 1
global_trans = torch.zeros(B, 3, device=dev, dtype=dt)
global_rot = torch.zeros(B, 3, device=dev, dtype=dt)
body_pose = torch.zeros(B, 130, device=dev, dtype=dt)
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt)
scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt)
shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt)
expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt)
global_trans = torch.zeros(B, 3, device=device, dtype=dtype)
global_rot = torch.zeros(B, 3, device=device, dtype=dtype)
body_pose = torch.zeros(B, 130, device=device, dtype=dtype)
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype)
scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype)
shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype)
expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype)
verts = self.mhr_forward(
global_trans=global_trans,
global_rot=global_rot,
@ -108,20 +93,6 @@ class MHRHead(nn.Module):
) # single-tensor shape (1, N_v, 3) in meters
return verts[0]
def get_zero_pose_init(self, factor=1.0):
    """Build the initial value for the pose token: a zero pose in continuous form.

    The result is a (1, self.npose) tensor. Its first ``6 + self.body_cont_dim``
    entries are the 6-vector ``[1, 0, 0, 0, 1, 0]`` followed by the continuous
    encoding of an all-zero 133-dim body pose multiplied by ``factor``; every
    remaining entry is zero.

    Args:
        factor: multiplier applied to the continuous body-pose part only (the
            leading 6-vector is not scaled).

    Returns:
        torch.Tensor of shape (1, self.npose).
    """
    # Note: the initial value must be the zero pose expressed in the continuous
    # representation, which is not the same thing as an all-zeros vector.
    leading_6d = torch.FloatTensor([1, 0, 0, 0, 1, 0])
    zero_body_cont = compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
    prefix = torch.cat([leading_6d, zero_body_cont * factor], dim=0)

    init = torch.zeros(1, self.npose)
    init[:, : 6 + self.body_cont_dim] = prefix
    return init
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
assert full_pose_params.shape[1] == 136
@ -159,12 +130,9 @@ class MHRHead(nn.Module):
shape_params,
expr_params=None,
return_keypoints=False,
do_pcblend=True,
return_joint_coords=False,
return_model_params=False,
return_joint_rotations=False,
scale_offsets=None,
vertex_offsets=None,
):
# Align everything to the static buffers
dt = self.scale_mean.dtype
@ -206,14 +174,10 @@ class MHRHead(nn.Module):
shape_params = shape_params[None]
# Convert scale...
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
if scale_offsets is not None:
scales = scales + scale_offsets
# Now, figure out the pose.
## The factor of 10 on global_trans is there because it's more stable to optimize global translation in meters.
full_pose_params = torch.cat(
[global_trans * 10, global_rot, body_pose_params], dim=1
) # B x 127
full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127
## Put in hands
if hand_pose_params is not None:
full_pose_params = self.replace_hands_in_pose(
@ -268,14 +232,7 @@ class MHRHead(nn.Module):
else:
return tuple(to_return)
def forward(
self,
x: torch.Tensor,
init_estimate: Optional[torch.Tensor] = None,
do_pcblend=True,
slim_keypoints=False,
intermediate: bool = False,
):
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False):
"""
Args:
x: pose token with shape [B, C], usually C=DECODER.DIM
@ -331,7 +288,6 @@ class MHRHead(nn.Module):
scale_params=pred_scale,
shape_params=pred_shape,
expr_params=pred_face,
do_pcblend=do_pcblend,
return_keypoints=True,
return_joint_coords=True,
return_model_params=True,
@ -356,7 +312,7 @@ class MHRHead(nn.Module):
# Head-MLP outputs are promoted to fp32 here so the external
# pose_output["mhr"] contract has a stable dtype regardless of what
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
# already fp32 from MHR's math layers; the cast on them is a no-op.
# already fp32 from MHR's math layers.
output = {
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
"pred_pose_rotmat": None,

View File

@ -1,7 +1,7 @@
# Adapted from facebookresearch/MHR (Apache 2.0):
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
# verbatim from the upstream mhr_model.pt
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
@ -52,7 +52,7 @@ def _skel_multiply(s1, s2):
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
before composition. With many FK levels the previously-normalized quats
drift in ULPs; the JIT renormalizes defensively, so we do too to stay
drift in ULPs; upstream renormalizes defensively, so we do too to stay
bit-close to its outputs.
"""
t1, sc1 = s1[..., :3], s1[..., 7:8]
@ -78,7 +78,7 @@ def _skel_transform_points(skel_state, points):
def _global_skel_state_from_local(local, pmi_levels):
"""FK walk in fp64 (matches the JIT's use_double_precision=True path).
"""FK walk in fp64 (matches upstream's use_double_precision=True path).
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
one per BFS level. Avoids per-call torch.split + tolist() sync.
@ -95,7 +95,7 @@ def _global_skel_state_from_local(local, pmi_levels):
class MHRRig(nn.Module):
"""Plain-PyTorch reimplementation of Meta's MHR rig.
All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's
All math runs in fp32 (FK upcast to fp64 internally, matching upstream's
use_double_precision=True backend) regardless of the host model's dtype.
"""
@ -110,13 +110,11 @@ class MHRRig(nn.Module):
POSE_CORR_HIDDEN = 3000
POSE_CORR_SPARSE_NNZ = 53136
def __init__(self, device=None, dtype=None, operations=None):
def __init__(self, device=None):
super().__init__()
del dtype, operations
f32 = torch.float32
# All buffers are populated by load_state_dict from the `mhr.*` keys
def _p(*shape, dtype=f32):
def _p(*shape, dtype=torch.float32):
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
def _b(name, *shape, dtype):
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
@ -147,10 +145,10 @@ class MHRRig(nn.Module):
self._pmi_levels_cache = None
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
f32 = self.base_shape.dtype
identity_coeffs = identity_coeffs.to(f32)
model_parameters = model_parameters.to(f32)
expr_coeffs = expr_coeffs.to(f32)
dtype = self.base_shape.dtype
identity_coeffs = identity_coeffs.to(dtype)
model_parameters = model_parameters.to(dtype)
expr_coeffs = expr_coeffs.to(dtype)
B = identity_coeffs.shape[0]
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)

View File

@ -1,5 +1,5 @@
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
# (batch6DFromXYZ, batchXYZfrom6D) are the continuity
# representation from Zhou et al., "On the Continuity of Rotation
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
# implementations from papagina/RotationContinuity:
@ -158,18 +158,10 @@ def _hand_masks(device):
m = _HAND_MASK_CACHE.get(device)
if m is not None:
return m
mask_cont_threedofs = torch.cat(
[torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
).to(device)
mask_cont_onedofs = torch.cat(
[torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
mask_model_params_threedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
).to(device)
mask_model_params_onedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
m = dict(
mask_cont_threedofs=mask_cont_threedofs,
mask_cont_onedofs=mask_cont_onedofs,
@ -182,7 +174,6 @@ def _hand_masks(device):
def compact_cont_to_model_params_hand(hand_cont):
# These are ordered by joint, not model params ^^
assert hand_cont.shape[-1] == 54
m = _hand_masks(hand_cont.device)
mask_cont_threedofs = m["mask_cont_threedofs"]
mask_cont_onedofs = m["mask_cont_onedofs"]
@ -209,120 +200,6 @@ def compact_cont_to_model_params_hand(hand_cont):
return hand_model_params
def compact_model_params_to_cont_hand(hand_model_params):
    """Convert 27-dim hand model params (Euler angles) to the 54-dim continuous form.

    Joints with 3 DoFs are converted with ``batch6DFromXYZ`` (3 angles -> 6 values);
    every remaining angle (the 1-DoF joints and each angle of the 2-DoF joint) is
    encoded as a ``(sin, cos)`` pair. Both layouts are ordered by joint.

    Args:
        hand_model_params: tensor of shape (..., 27).

    Returns:
        Tensor of shape (..., 54) with the same dtype and device as the input.
    """
    # These are ordered by joint, not model params.
    assert hand_model_params.shape[-1] == 27
    hand_dofs_in_order = (3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1)
    assert sum(hand_dofs_in_order) == 27
    device = hand_model_params.device

    def _joint_mask(width_per_dof, wanted_dofs):
        # One flag per output slot: a joint with k DoFs owns width_per_dof * k slots.
        flags = [
            k in wanted_dofs
            for k in hand_dofs_in_order
            for _ in range(width_per_dof * k)
        ]
        # Built directly on the input's device so indexing never mixes devices.
        return torch.tensor(flags, dtype=torch.bool, device=device)

    # Masks into hand_cont (2 slots per DoF) and into hand_model_params (1 slot per DoF).
    mask_cont_threedofs = _joint_mask(2, (3,))
    mask_cont_onedofs = _joint_mask(2, (1, 2))  # 2-DoF joint is treated as two 1-DoFs
    mask_model_params_threedofs = _joint_mask(1, (3,))
    mask_model_params_onedofs = _joint_mask(1, (1, 2))

    # Convert Euler angles to hand_cont.
    ## First for 3DoFs: group into (..., n_joints, 3) and convert each triple to 6D.
    hand_model_params_threedofs = hand_model_params[
        ..., mask_model_params_threedofs
    ].unflatten(-1, (-1, 3))
    hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)
    ## Next for 1DoFs: each angle becomes a (sin, cos) pair.
    hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
    hand_cont_onedofs = torch.stack(
        [hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
    ).flatten(-2, -1)

    # Finally, assemble into a 54-dim vector, ordered by joint.
    hand_cont = hand_model_params.new_zeros(*hand_model_params.shape[:-1], 54)
    hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
    hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
    return hand_cont
def batch9Dfrom6D(poses):
    """Convert 6D rotation representations into flattened 3x3 rotation matrices.

    Args:
        poses: tensor of shape (..., 6); the first three values are the raw first
            column and the last three the raw second column of the rotation.

    Returns:
        Tensor of shape (..., 9): the matrix whose columns are the orthonormalized
        x, y, z axes, flattened over its last two dimensions.
    """
    raw_x, raw_y = poses[..., :3], poses[..., 3:]

    # Orthonormalize: x is normalized, z is perpendicular to x and raw y,
    # and y is recomputed so the three axes form a right-handed frame.
    axis_x = F.normalize(raw_x, dim=-1)
    axis_z = F.normalize(torch.cross(axis_x, raw_y, dim=-1), dim=-1)
    axis_y = torch.cross(axis_z, axis_x, dim=-1)

    columns = [axis_x, axis_y, axis_z]
    return torch.stack(columns, dim=-1).flatten(-2, -1)  # ... x 3 x 3 -> ... x 9
def batch4Dfrom2D(poses):
    """Expand (sin, cos) pairs into flattened 2x2 rotation (SO(2)) matrices.

    Args:
        poses: tensor of shape (..., 2) holding (sin, cos); it is normalized along
            the last dimension before use.

    Returns:
        Tensor of shape (..., 4) laid out as ``[cos, sin, -sin, cos]``.
    """
    unit = F.normalize(poses, dim=-1)
    sin_part = unit[..., 0]
    cos_part = unit[..., 1]
    # Flattened SO2.
    return torch.stack([cos_part, sin_part, -sin_part, cos_part], dim=-1)
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
    """Convert the compact continuous body pose into a rotation-matrix layout.

    The input is split into three consecutive segments: 6D encodings of the 3-DoF
    rotations, (sin, cos) pairs of the 1-DoF rotations, and the 1-DoF translations.
    The first segment is converted with ``batch9Dfrom6D`` (9 values per rotation),
    the second with ``batch4Dfrom2D`` (4 values per angle), and the translations
    are passed through unchanged.

    Args:
        body_pose_cont: tensor of shape (..., 260).
        inflate_trans: not implemented; must be False.

    Returns:
        Tensor with the three converted segments concatenated on the last dim.

    Raises:
        AssertionError: if the last dimension has the wrong size, or if
            ``inflate_trans`` is True.
    """
    # Only the table sizes are needed here, so plain tuples are used (no tensors).
    # fmt: off
    all_param_3dof_rot_idxs = ((0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132))
    all_param_1dof_rot_idxs = (1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123)
    all_param_1dof_trans_idxs = (124, 125, 126, 127, 128, 129)
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    if inflate_trans:
        # Raised explicitly (instead of `assert False`) so the guard survives
        # `python -O`; previously the stripped assert fell through to an unbound name.
        raise AssertionError(
            "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
        )

    # Get subsets
    end_3dofs = 2 * num_3dof_angles
    end_1dofs = end_3dofs + 2 * num_1dof_angles
    body_cont_3dofs = body_pose_cont[..., :end_3dofs]
    body_cont_1dofs = body_pose_cont[..., end_3dofs:end_1dofs]
    body_cont_trans = body_pose_cont[..., end_1dofs:]

    # Convert conts to rotation-matrix entries
    ## First for 3dofs
    body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
    body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1)
    ## Next for 1dofs
    body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2))  # (sincos)
    body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1)
    ## Nothing to do for trans
    body_rotmat_trans = body_cont_trans

    # Put them together
    return torch.cat(
        [body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1
    )
_BODY_IDX_CACHE: dict = {}
@ -349,8 +226,6 @@ def compact_cont_to_model_params_body(body_pose_cont):
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
# Get subsets
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
@ -372,42 +247,10 @@ def compact_cont_to_model_params_body(body_pose_cont):
return body_pose_params
def compact_model_params_to_cont_body(body_pose_params):
    """Convert 133-dim body model params into the compact continuous representation.

    The 3-DoF rotation triples are gathered and converted with ``batch6DFromXYZ``,
    the 1-DoF rotation angles become (sin, cos) pairs, and the 1-DoF translations
    are copied as-is. The three segments are concatenated in that order.

    Args:
        body_pose_params: tensor whose last dimension is 133.

    Returns:
        Tensor with the continuous encoding along the last dimension.
    """
    # fmt: off
    rot3_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
    rot1_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
    trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
    # fmt: on
    expected_dim = rot3_idxs.numel() + rot1_idxs.numel() + trans_idxs.numel()
    assert body_pose_params.shape[-1] == expected_dim

    # Take out params
    euler_3dofs = body_pose_params[..., rot3_idxs.flatten()].unflatten(-1, (-1, 3))
    angles_1dofs = body_pose_params[..., rot1_idxs]
    trans = body_pose_params[..., trans_idxs]

    # params to cont
    cont_3dofs = batch6DFromXYZ(euler_3dofs).flatten(-2, -1)
    sin_cos = [angles_1dofs.sin(), angles_1dofs.cos()]
    cont_1dofs = torch.stack(sin_cos, dim=-1).flatten(-2, -1)

    # Put them together (translations need no conversion)
    return torch.cat([cont_3dofs, cont_1dofs, trans], dim=-1)
# fmt: off
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
# Hand indices into the 133-dim param and 260-dim cont body-pose vectors.
mhr_param_hand_idxs = list(range(62, 116))
mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))
mhr_param_hand_mask = torch.zeros(133).bool()
mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260).bool()
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
# fmt: on

View File

@ -43,15 +43,6 @@ class FourierPositionEncoding(nn.Module):
self.num_bands = num_bands
self.max_resolution = [max_resolution] * n
@property
def channels(self):
    """Size of the encoding: per dimension, a sin and a cos value for each band plus one concatenated value."""
    num_dims = len(self.max_resolution)
    per_dim = 2 * self.num_bands + 1  # sin-cos for every band, plus the concat slot
    return num_dims * per_dim
def forward(self, pos: torch.Tensor):
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
return fourier_pos_enc
@ -118,9 +109,7 @@ class PerspectiveHead(nn.Module):
pred_cam: torch.Tensor,
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
bbox_size: torch.Tensor, # [N,], in original image space
img_size: torch.Tensor,
cam_int: torch.Tensor, # [B, 3, 3]
use_intrin_center: bool = False,
):
batch_size = points_3d.shape[0]
pred_cam = pred_cam.clone()
@ -133,12 +122,8 @@ class PerspectiveHead(nn.Module):
focal_length = cam_int[:, 0, 0]
tz = 2 * focal_length / bs
if not use_intrin_center:
cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
else:
cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs
cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

View File

@ -37,20 +37,15 @@ class SAM3DBody(nn.Module):
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
# `operations` falls back to torch.nn so the model is constructible
# without comfy.ops; matches the pattern in comfy/ldm/sam3/.
ops = operations if operations is not None else nn
# Per-batch state populated by `_initialize_batch`.
self._max_num_person = None
self._person_valid = None
self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False)
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
self.image_size = IMAGE_SIZE
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops)
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
embed_dims = self.backbone.embed_dims
# MHR rig shared between body + hand pose heads via a non-registered
@ -72,7 +67,7 @@ class SAM3DBody(nn.Module):
self.head_pose.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
)
self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
@ -81,7 +76,7 @@ class SAM3DBody(nn.Module):
self.head_pose_hand.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
)
self.init_pose_hand = ops.Embedding(
self.init_pose_hand = operations.Embedding(
1, self.head_pose_hand.npose, device=device, dtype=dtype
)
@ -93,25 +88,25 @@ class SAM3DBody(nn.Module):
device=device, dtype=dtype, operations=operations,
)
self.head_camera = PerspectiveHead(**camera_kwargs)
self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
self.init_camera = operations.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs)
self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
self.init_camera_hand = operations.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
cond_dim = 3
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
linear_kwargs = dict(device=device, dtype=dtype)
self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.init_to_token_mhr = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.init_to_token_mhr_hand = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr_hand = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.prompt_encoder = PromptEncoder(
embed_dim=embed_dims, # match backbone dims so PE adds directly
num_body_joints=N_KEYPOINTS,
device=device, dtype=dtype, operations=operations,
)
self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.prompt_to_token = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
decoder_kwargs = dict(
dims=DECODER_DIM,
@ -141,11 +136,10 @@ class SAM3DBody(nn.Module):
self.keypoint_embedding_idxs = list(range(N_KEYPOINTS))
self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS))
self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs)
self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs)
self.hand_box_embedding = operations.Embedding(2, DECODER_DIM, **linear_kwargs)
self.bbox_embed = MLP(
input_dim=DECODER_DIM, hidden_dim=DECODER_DIM,
output_dim=4, num_layers=3,
@ -158,13 +152,13 @@ class SAM3DBody(nn.Module):
)
self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs)
self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs)
self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint_feat_linear = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint_feat_linear_hand = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS))
self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS))
self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs)
self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs)
@ -183,11 +177,9 @@ class SAM3DBody(nn.Module):
def _initialize_batch(self, batch: Dict) -> None:
if batch["img"].dim() == 5:
self._batch_size, self._max_num_person = batch["img"].shape[:2]
self._person_valid = self._flatten_person(batch["person_valid"]) > 0
else:
self._batch_size = batch["img"].shape[0]
self._max_num_person = 0
self._person_valid = None
def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
assert self._max_num_person is not None, "No max_num_person initialized"
@ -258,11 +250,9 @@ class SAM3DBody(nn.Module):
if is_multi_image:
assert isinstance(img, list)
n = len(img)
H_src, W_src = img[0].shape[:2]
src_t = torch.stack(list(img), dim=0)
else:
n = int(left_xyxy.shape[0])
H_src, W_src = img.shape[:2]
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
H_out, W_out = int(self.image_size[0]), int(self.image_size[1])
@ -292,14 +282,12 @@ class SAM3DBody(nn.Module):
zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device)
person_valid = torch.ones((1, n), dtype=torch.float32, device=device)
img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous()
ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous()
cam_int_dev = cam_int.to(device).to(dtype=torch.float32)
def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy):
return {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size.unsqueeze(0),
"ori_img_size": ori_img_size.unsqueeze(0),
"bbox_center": centers_t.to(device).unsqueeze(0),
"bbox_scale": scales_t.to(device).unsqueeze(0),
"bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0),
@ -349,7 +337,6 @@ class SAM3DBody(nn.Module):
self,
branch: str,
image_embeddings: torch.Tensor,
init_estimate: Optional[torch.Tensor] = None,
keypoints: Optional[torch.Tensor] = None,
prev_estimate: Optional[torch.Tensor] = None,
condition_info: Optional[torch.Tensor] = None,
@ -359,7 +346,6 @@ class SAM3DBody(nn.Module):
of the pipeline is shared.
image_embeddings: (B, C, H, W) backbone features.
init_estimate: (B, 1, C) initial pose+cam estimate to refine.
keypoints: (B, N, 3) prompts as (x, y in [0, 1], label).
label: 0..K = joint, -1 = incorrect, -2 = invalid.
prev_estimate: (B, 1, C) previous estimate for pose refinement.
@ -402,15 +388,11 @@ class SAM3DBody(nn.Module):
# .to(image_embeddings) moves weights CPU→GPU under dynamic loading
# (they stay on CPU until first use).
if init_estimate is None:
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
init_input = (
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
if condition_info is not None else init_estimate
)
init_input = torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
token_embeddings = init_to_token(init_input).view(batch_size, 1, -1)
num_pose_token = token_embeddings.shape[1] # always 1
@ -495,9 +477,8 @@ class SAM3DBody(nn.Module):
def _get_mask_prompt(self, batch, image_embeddings):
x_mask = self._flatten_person(batch["mask"])
# batch tensors are fp32 from prepare_batch; mask_downscaling is in the
# Loader's dtype — cast once so the conv input matches.
x_mask = x_mask.to(image_embeddings.dtype)
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
)
@ -546,7 +527,6 @@ class SAM3DBody(nn.Module):
# expand+contiguous for the vertices branch.
bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx]
cam_int = self._flatten_person(
batch["cam_int"]
.unsqueeze(1)
@ -556,8 +536,7 @@ class SAM3DBody(nn.Module):
def _project(points_3d):
return head_camera.perspective_projection(
points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int,
use_intrin_center=True,
points_3d, pred_cam, bbox_center, bbox_scale, cam_int,
)
cam_out = _project(pose_output["pred_keypoints_3d"])
@ -632,7 +611,6 @@ class SAM3DBody(nn.Module):
tokens_output, pose_output = self.forward_decoder(
"body",
image_embeddings[self.body_batch_idx],
init_estimate=None,
keypoints=keypoints_prompt[self.body_batch_idx],
prev_estimate=None,
condition_info=condition_info[self.body_batch_idx],
@ -643,7 +621,6 @@ class SAM3DBody(nn.Module):
tokens_output_hand, pose_output_hand = self.forward_decoder(
"hand",
image_embeddings[self.hand_batch_idx],
init_estimate=None,
keypoints=keypoints_prompt[self.hand_batch_idx],
prev_estimate=None,
condition_info=condition_info[self.hand_batch_idx],
@ -661,10 +638,8 @@ class SAM3DBody(nn.Module):
# match the head-MLP external contract (_get_hand_box would .float() anyway).
if len(self.body_batch_idx):
output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float()
output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float()
if len(self.hand_batch_idx):
output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid()
output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand)
return output
@ -715,10 +690,10 @@ class SAM3DBody(nn.Module):
# Concat lhand+rhand along dim 0 so backbone+decoder run once on
# (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass.
batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand)
saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid)
saved_batch_state = (self._batch_size, self._max_num_person)
self._initialize_batch(batch_hands)
hands_output = self.forward_step(batch_hands, decoder_type="hand")
self._batch_size, self._max_num_person, self._person_valid = saved_batch_state
self._batch_size, self._max_num_person = saved_batch_state
n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1]
lhand_output, rhand_output = self._split_hand_output(hands_output, n_left)
# Free the batched image_embeddings/condition_info (unused downstream);
@ -808,9 +783,7 @@ class SAM3DBody(nn.Module):
# to get an updated body pose estimation.
self._set_active_branch("body")
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1
# right_kps_full / left_kps_full already computed above (unchanged since).
right_kps_crop = self._full_to_crop(batch, right_kps_full)
left_kps_crop = self._full_to_crop(batch, left_kps_full)
@ -1030,7 +1003,6 @@ class SAM3DBody(nn.Module):
_, pose_output = self.forward_decoder(
"body",
image_embeddings,
init_estimate=None, # use the default init, not the prev estimate
keypoints=keypoint_prompt,
prev_estimate=prev_estimate,
condition_info=condition_info,

View File

@ -29,38 +29,37 @@ class PromptEncoder(nn.Module):
Encodes prompts for input to SAM's mask decoder.
"""
super().__init__()
ops = operations if operations is not None else nn
self.embed_dim = embed_dim
self.num_body_joints = num_body_joints
# Keypoint prompts
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.point_embeddings = nn.ModuleList(
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
[operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
)
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
LN2d = LayerNorm2d_op(ops)
LN2d = LayerNorm2d_op(operations)
mask_in_chans = 256
self.mask_downscaling = nn.Sequential(
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
)
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
@ -120,8 +119,7 @@ class PromptEncoder(nn.Module):
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(keypoints, boxes, masks)
# Anchor device on the input prompts so we don't pull the offloaded
# CPU embedding device under dynamic loading.
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
weight_dtype = self.invalid_point_embed.weight.dtype
@ -136,23 +134,10 @@ class PromptEncoder(nn.Module):
return sparse_embeddings, sparse_masks
def get_mask_embeddings(
self,
masks: Optional[torch.Tensor] = None,
bs: int = 1,
size: Tuple[int, int] = (16, 16), # [H, W]
) -> torch.Tensor:
"""Embeds mask inputs."""
# masks is always on the active device when present; fall back to the
# downscaling Conv's weight device when it isn't (rare callers).
ref = masks if masks is not None else next(self.mask_downscaling.parameters())
no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
bs, -1, size[0], size[1]
)
if masks is not None:
mask_embeddings = self.mask_downscaling(masks)
else:
mask_embeddings = no_mask_embeddings
def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> Tuple[torch.Tensor, torch.Tensor]:
    """Embed mask inputs and build the matching "no mask" embedding grid.

    Args:
        masks: Mask tensor passed through ``self.mask_downscaling``.
        bs: Batch size the no-mask embedding is expanded to.
        size: ``(H, W)`` of the grid the no-mask embedding is expanded to.

    Returns:
        ``(mask_embeddings, no_mask_embeddings)``. Caller casts both outputs
        to its working dtype.
    """
    # Broadcast the single no-mask embedding row to (bs, C, H, W); expand()
    # returns a view, so no per-pixel copy is made here.
    no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1])
    mask_embeddings = self.mask_downscaling(masks)
    return mask_embeddings, no_mask_embeddings
@ -170,12 +155,9 @@ class PromptableDecoder(nn.Module):
repeat_pe: bool = False,
do_interm_preds: bool = False,
keypoint_token_update: bool = False,
device=None,
dtype=None,
operations=None,
device=None, dtype=None, operations=None,
):
super().__init__()
ops = operations if operations is not None else nn
self.layers = nn.ModuleList(
TransformerDecoderLayer(
@ -193,7 +175,7 @@ class PromptableDecoder(nn.Module):
for i in range(depth)
)
self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
self.do_interm_preds = do_interm_preds
self.keypoint_token_update = keypoint_token_update

View File

@ -166,12 +166,10 @@ def prepare_batch(
mask_score_t = torch.ones((n,), dtype=torch.float32)
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()
batch = {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
"ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2)
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)

View File

@ -1,11 +1,9 @@
"""BVH export for SAM 3D Body pose_data.
BVH stores explicit bone OFFSETs per joint, so any standard importer
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
intrinsic Z-X-Y Euler degrees.
BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct
anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0
(static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY
rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees.
"""
from __future__ import annotations
@ -49,13 +47,10 @@ def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
"""First child of the rig's world anchor so the static origin→body stick
bone gets left out. Falls back to the first root joint.
MHR's joint 0 is a static world anchor whose single child is the pelvis, so
skipping it is correct. External rigs (e.g. SOMA-77) whose root is already
the articulated body root with multiple child chains must keep the root —
descending into one child would drop the sibling limbs from the BVH."""
"""First child of the rig's world anchor, dropping the origin→body stick.
Falls back to the first root joint. External rigs whose root is already the
articulated body root with multiple child chains keep the root — descending
into one child would drop the sibling limbs."""
NJ = parents.shape[0]
world_anchors = [j for j in range(NJ)
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
@ -93,14 +88,11 @@ def build_bvh(
track_index: int = -1,
units: str = "cm",
) -> bytes:
"""Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.
"""Build a BVH file from pose_data. Returns UTF-8 text bytes.
`model` may be None when pose_data carries a `_skeleton_override` (external
rigs, e.g. Kimodo); the rig hierarchy/offsets/bind are read from the
override instead of the MHR model.
`units` is "cm" (default, standard mocap convention) or "m". Affects the
OFFSET and root-position values; rotations are independent of units.
rigs); the rig hierarchy/offsets/bind come from the override. `units` is
"cm" (default) or "m" — affects OFFSET/root-position, not rotations.
"""
if units not in ("cm", "m"):
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
@ -123,10 +115,8 @@ def build_bvh(
body_root = _find_bvh_root(parents, is_external)
children_map = _build_children_map(parents)
# Bone OFFSETs come from MHR's translation_offsets (joint position
# relative to parent in parent's local-bind frame). For the BVH root,
# we use its bind world position so the skeleton sits at the right
# spot when imported.
# Bone OFFSETs = translation_offsets (joint position relative to parent).
# The BVH root uses its bind world position so the skeleton imports in place.
bind_global = rig.bind_global_cm # (NJ, 8) cm
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
@ -139,9 +129,8 @@ def build_bvh(
_visit(c)
_visit(body_root)
# Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
# rather than re-running rig.forward, then derive locals with body_root
# treated as the hierarchy root in BVH-space.
# Stored pred_global_rots/pred_joint_coords (authoritative); derive locals
# with body_root as the BVH-space hierarchy root.
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
joint_coords_y_down=rig.per_frame_y_down,
@ -203,9 +192,8 @@ def build_bvh(
lines.append(f"Frames: {n_frames}")
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
# Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
# frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
# formatting beats Python f-strings by ~10× on long clips.
# Channel matrix per frame: root pos (3) + root rot (3) + non-root rots
# (3 each), columns in `bvh_order`. savetxt is far faster than f-strings.
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
motion = np.concatenate([
root_pos_m * unit_scale, # (N, 3)

View File

@ -1,12 +1,9 @@
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
"""3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look).
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps),
projected through the per-person camera (`pred_cam_t` + `focal_length` +
image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose
visual style. Self-contained: no dependency on the SCAIL-Pose package.
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
Each limb is a true 3D capsule (cylinder + hemispherical caps), projected
through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so
closer limbs appear thicker/brighter. Self-contained analytic ray-capsule
renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
"""
from typing import Any, Dict, List, Optional, Tuple
@ -41,14 +38,12 @@ def _build_specs_from_pose(
palette: str,
person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Flatten body + optional hand limbs for one frame into
(starts, ends, colors_rgba, is_hand) in camera coords (Y-down, +Z forward).
Drops endpoints that are non-finite or behind the camera. `is_hand` flags
the hand limbs so the renderer can draw them thinner.
"""Flatten body + optional hand limbs for one frame into (starts, ends,
colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite
or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner.
`person_brightness_falloff` mixes each per-person limb color toward white
by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the
mesh rasterizer and GLB exporters."""
`person_brightness_falloff` mixes each per-person color toward white by
`1 - falloff^k` for track k (track 0 stays vivid)."""
starts: List[np.ndarray] = []
ends: List[np.ndarray] = []
colors: List[np.ndarray] = []
@ -65,8 +60,7 @@ def _build_specs_from_pose(
if body_op is None or cam_t is None:
continue
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
# op-keypoints are camera frame (Y-down); add cam_t to place the
# subject in front of the camera.
# op-keypoints are camera frame; add cam_t to place the subject in front.
body_kp = body_op + cam_t_np[None, :]
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
@ -148,10 +142,9 @@ def _ray_capsule_t(
ba_len: torch.Tensor, # (M,) segment length
radius: torch.Tensor, # (M,) per-capsule radius
) -> torch.Tensor:
"""Closed-form ray-capsule intersection. Returns (K, M) tensor of ray
parameters t to the nearest valid hit per capsule, +inf where the ray
misses. A capsule is the union of (cylinder body, hemisphere at A,
hemisphere at B); each component is a quadratic root-find."""
"""Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest
valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere
at A, hemisphere at B), each a quadratic root-find."""
INF = float("inf")
r_sq = radius * radius # (M,)
@ -238,9 +231,8 @@ def _render_capsules_torch(
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
z_near = max(0.05, z_min - float(radius.max().item()))
# Union of per-capsule screen-space bboxes. Pixels outside this mask
# provably can't hit any capsule, so the analytic intersection only runs
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
# Union of per-capsule screen-space bboxes — pixels outside can't hit any
# capsule, so intersection only runs on the relevant subset of the canvas.
sz = starts[:, 2].clamp(min=z_near)
ez = ends[:, 2].clamp(min=z_near)
sx_p = starts[:, 0] * fx / sz + cx
@ -261,16 +253,13 @@ def _render_capsules_torch(
if xmax_i > xmin_i and ymax_i > ymin_i:
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
# Analytic ray-capsule intersection. One pass over the masked pixels —
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
# Analytic ray-capsule intersection, one pass over the masked pixels.
INF = float("inf")
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
if active_idx.numel() > 0:
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory
# manageable when both K (image pixels) and M (capsules) are large.
# Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory.
chunk_max = max(1, int(4_000_000 / max(M, 1)))
for i0 in range(0, active_idx.numel(), chunk_max):
sub = active_idx[i0 : i0 + chunk_max]
@ -284,7 +273,7 @@ def _render_capsules_torch(
flat_t[winners] = t_min[hit]
flat_m_idx[winners] = m_idx[hit]
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade.
# Shade via analytic normal (P - closest point on segment).
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
if background_rgb is not None:
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
@ -306,10 +295,10 @@ def _render_capsules_torch(
col = colors[m_h, :3]
if flat_shade:
# Solid per-limb color (OpenPose look) — no lighting/depth modulation.
# Solid per-limb color (OpenPose look) — no lighting/depth.
out[hit_idx] = col
return out.view(H, W, 3).clamp(0.0, 1.0)
# SCAIL Blinn-Phong (render_torch.py:290-331). Headlight: light = +Z.
# SCAIL Blinn-Phong, headlight along +Z.
diff = torch.clamp(-(normals[:, 2]), min=0.0)
diffuse = 0.45 + 0.55 * diff
@ -319,7 +308,7 @@ def _render_capsules_torch(
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
# Mild depth fade matches SCAIL's mm-scale ramp in our meter units.
# Mild depth fade.
z_vals = p_hit[:, 2]
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
if z_hi - z_lo > 1e-6:
@ -351,21 +340,18 @@ def render_pose_data_capsules(
hand_radius_scale: float = 0.4,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Render a frame's pose_data as 3D capsules projected through the per-
person camera. Returns (H, W, 3) fp32 in [0, 1].
"""Render a frame's pose_data as 3D capsules through the per-person camera.
Returns (H, W, 3) fp32 in [0, 1].
`composite='over'` paints over `background` (black if None);
`composite='mesh_only'` always uses a black canvas.
`radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
Hand limbs use `radius_m * hand_radius_scale` (their bones are far shorter
than body limbs). Camera fx/fy come from each person's `focal_length`.
`composite='over'` paints over `background` (black if None); 'mesh_only'
uses a black canvas. `radius_m` is in meters; hand limbs use
`radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`.
"""
persons = pose_data["frames"][frame_idx]
if device is None:
device = comfy.model_management.get_torch_device()
# SAM3DBody shares one camera across the clip — pick from the first valid person.
# SAM3DBody shares one camera across the clip — use the first valid person.
fx = fy = float(min(H, W))
for person in persons:
f = person.get("focal_length")

View File

@ -1,16 +1,10 @@
"""GLB export — OpenPose 18-keypoint visualization mode.
Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d`
(the model's regressed surface keypoints). Each track becomes an armature
with a sibling joint per keypoint; sphere markers + stick/capsule limbs are
skinned to those joints.
Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and
face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs)
extend the same armature.
OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are
imported below — they're also used by the 2D and 3D renderers in this package.
Sourced from pose_data's `pred_keypoints_3d`, independent of the MHR rig. Each
track becomes an armature with a joint per keypoint; sphere markers and limbs
are skinned to those joints. Optional hands (`pred_keypoints_3d` 21..62) and
face landmarks (`pred_vertices` at fixed vertex IDs) extend the same armature.
Shared tables/palettes/mappings live in `glb_shared.py`.
"""
from __future__ import annotations
@ -55,9 +49,8 @@ def _finalize_skinned_mesh(
joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray,
smooth_shade: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Apply smooth or flat shading to an indexed sphere/stick group mesh and
pack per-vertex colors. Smooth keeps the indexed mesh + per-vertex colors;
flat duplicates verts per face and gathers face-corner colors."""
"""Shade a skinned group mesh and pack per-vertex colors. Smooth keeps the
indexed mesh; flat duplicates verts per face and gathers face-corner colors."""
if smooth_shade:
v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights)
return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32)
@ -73,10 +66,8 @@ def _finalize_skinned_mesh(
def _pair_colors_from_kp(
pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1,
) -> np.ndarray:
"""Per-limb color = endpoint-vertex color from `kp_colors`. Default
`endpoint=1` picks the second (distal) vertex of each pair, which is
the OpenPose-canonical per-finger gradient when fingers go base→tip
(wrist=0 → thumb1=1 → thumb2=2 → …)."""
"""Per-limb color from `kp_colors`. `endpoint=1` (default) picks the distal
vertex of each pair — the OpenPose per-finger gradient for base→tip fingers."""
n = len(pairs)
out = np.zeros((n, 3), dtype=np.float32)
for i, (a, b) in enumerate(pairs):
@ -88,19 +79,13 @@ def _openpose_bind_at_rig_rest(
pose_data: Dict[str, Any], *,
include_hands: bool, face_vert_ids: Optional[np.ndarray],
) -> Optional[np.ndarray]:
"""OpenPose keypoint positions at the rig's REST pose (T-pose at authoring
origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest
TRS) and `rest_verts_m` (mesh rest verts for face landmarks).
"""OpenPose keypoint positions at the rig's REST pose, from the override's
`bind_global_m` (joint rest TRS) and `rest_verts_m` (face landmarks).
Used as the static-bind for openpose-mode geometry so the GLB's static
POSITION attribute sits at rig origin — matching skeletal mode's bind and
producing the same 'snap from rest to scene-frame-0' transition at the
start of playback. Without this, the static geometry is at scene-frame-0
(kp_seq[0]) and viewers that auto-fit on static POSITION will center on
the scene location, hiding the per-frame motion.
Returns None when the override is missing or doesn't carry all the needed
mappings — caller falls back to per-frame extraction (kp_seq[0])."""
Used as the static-bind so the GLB's static POSITION sits at rig origin,
matching skeletal mode and producing the same rest→scene-frame-0 transition.
Returns None when the override lacks the needed mappings — caller then falls
back to per-frame extraction (kp_seq[0])."""
override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None
if override is None or "bind_global_m" not in override:
return None
@ -141,19 +126,12 @@ def _openpose_bind_at_rig_rest(
def _extract_openpose_keypoints(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
) -> np.ndarray:
"""(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres.
"""(N, 18, 3) OpenPose keypoints in rig-native Y-up metres.
Two sources, in priority order:
1. **External-skeleton path** — when pose_data has `_skeleton_override`
with `openpose18_joint_indices` ((18, 2) int32, see
`_resolve_openpose_keypoints_from_joints`), synthesize from each
person's `pred_joint_coords` directly. The override frame is already
rig-native Y-up, so no axis flip.
2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the
first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18.
Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we
un-flip y/z to match rig-native Y-up.
External-skeleton path: when the override carries `openpose18_joint_indices`
((18, 2) int32), synthesize from each person's `pred_joint_coords` (already
Y-up, no flip). MHR70 path (default): re-index `pred_keypoints_3d` to COCO-18
and un-flip y/z (stored y-down by sam3d_body).
"""
frames = pose_data["frames"]
N = len(frame_indices)
@ -195,10 +173,8 @@ def _extract_openpose_keypoints(
for t_idx, t in enumerate(frame_indices):
person = frames[t][person_k]
if "pred_keypoints_3d" not in person:
# Diagnose the source: external-skeleton producers ship
# `_skeleton_override` instead of MHR70 keypoints. If the
# producer didn't populate `openpose18_joint_indices` either,
# we can't synthesize the 18-keypoint set.
# External-skeleton producer without `openpose18_joint_indices`:
# can't synthesize the 18-keypoint set.
if override is not None:
raise ValueError(
"build_glb_openpose: this pose_data carries "
@ -229,15 +205,11 @@ def _extract_openpose_keypoints(
def _extract_openpose_hand_keypoints(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
) -> np.ndarray:
"""(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native
Y-up frame.
"""(N, 42, 3) right+left OpenPose hand keypoints (21+21) in rig-native Y-up.
External-skeleton path: requires `openpose_hand21_r_joint_indices` AND
`openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override.
Resolved against per-frame `pred_joint_coords` like the body path.
MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21
(wrist + 5 fingers, thumb→pinky, base→tip)."""
External-skeleton path: needs `openpose_hand21_{r,l}_joint_indices` ((21, 2)
int32) in the override, resolved against `pred_joint_coords`. MHR70 path:
re-orders `pred_keypoints_3d` 21..62 to OpenPose-21 (wrist + 5 fingers)."""
frames = pose_data["frames"]
N = len(frame_indices)
out = np.zeros((N, 42, 3), dtype=np.float32)
@ -307,10 +279,8 @@ def _extract_face_landmarks_from_verts(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
vert_ids: np.ndarray,
) -> np.ndarray:
"""(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices`
at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up.
Each landmark inherits per-frame shape/expr/pose deformation for free
since `pred_vertices` already has it baked in."""
"""(N, K_face, 3) face landmarks sampled from `pred_vertices` at the given
vertex IDs, unflipped to Y-up. Per-frame deformation is already baked in."""
frames = pose_data["frames"]
N = len(frame_indices)
K = int(vert_ids.shape[0])
@ -335,18 +305,11 @@ def _build_openpose_spheres(
smooth_shade: bool = False,
joint_indices: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's
joint, vertex-colored from kp_colors. `base_joint_idx` is added to the
emitted JOINTS_0 indices so callers can place this group at any offset
in the shared skin (body=0, right hand=18, etc.). `joint_indices` (when
given) overrides that with explicit per-sphere joint indices, so callers
can skip keypoints (e.g. SCAIL head dots).
`smooth_shade=True` keeps the indexed mesh and writes per-vertex
normals via face-normal averaging — round shading on the spheres.
`smooth_shade=False` (default) flat-shades by duplicating verts per
face, matching the existing OpenPose-mode look. Returns
(verts, normals, faces, joints4, weights4, vert_colors)."""
"""UV sphere per keypoint, rigidly skinned to that keypoint's joint and
vertex-colored from kp_colors. `base_joint_idx` offsets the emitted JOINTS_0
indices (body=0, right hand=18, …); `joint_indices`, if given, sets explicit
per-sphere indices so callers can skip keypoints (e.g. SCAIL head dots).
Returns (verts, normals, faces, joints4, weights4, vert_colors)."""
sv, sf = uv_sphere_unit()
K = bind_kp_m.shape[0]
Nv = sv.shape[0]
@ -376,43 +339,23 @@ def _capsule_mesh_local(
end_width_frac: float = 0.3,
shape: str = "ellipsoid",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Build a per-limb mesh in limb-local frame along +Y from y=0 (head
pole) to y=L (tail pole).
"""Per-limb mesh in limb-local frame along +Y from y=0 (head) to y=L (tail).
`shape` selects the silhouette:
- 'ellipsoid' (default): tips are small hemispheres of radius
`W * end_width_frac`; body has ellipsoidal radius profile
sin(π*u) from w_end at the junctions to W at the middle. Gives
a fat-middle / narrow-end stretched-ellipse look.
- 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant
radius W with no hemisphere caps. Pair with sphere joint markers
of the same radius so the spheres seamlessly cap the open
cylinder ends (the cylinder cross-section ring at the endpoint
lies exactly on the sphere surface). Drawing hemisphere caps
inside the joint sphere creates a visible bump where the cap
pokes out unevenly when sphere radius ≠ cap radius — open
cylinders avoid that.
`shape`:
- 'ellipsoid' (default): hemisphere tips of radius `W * end_width_frac`,
ellipsoidal sin(π·u) body profile (fat middle, narrow ends).
- 'capsule': SCAIL "rig" limb — an OPEN cylinder of constant radius W,
no caps. Pair with same-radius sphere markers so they cap the ends
seamlessly (caps would bump out when sphere radius ≠ cap radius).
Per-limb mesh is required because the cap height (w_end) depends on
the limb width — a single canonical mesh can't produce true
hemispheres for arbitrary L:W ratios in ellipsoid mode.
A per-limb mesh is needed because cap height depends on width — one
canonical mesh can't give true hemispheres for arbitrary L:W in ellipsoid.
Returns:
verts: (Nv, 3) float32 limb-local positions in meters.
faces: (Nf, 3) uint32 triangle indices.
weights: (Nv, 2) float32 (head, tail) skinning weights, linearly
interpolated by axial position (sums to 1).
Returns (verts (Nv,3), faces (Nf,3), weights (Nv,2) head/tail, sums to 1).
"""
W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6))
if str(shape) == "capsule":
# SCAIL-style "rig" limb: an OPEN cylinder of constant radius W,
# no hemisphere caps. The sphere joint markers at each endpoint
# provide the rounded ends of the bone — when sphere_radius ==
# cylinder_radius, the cylinder cross-section ring at the bone
# endpoint lies exactly on the sphere surface, so silhouette is
# seamless. Hemisphere caps would create a visible bump where
# the cap pokes out of the sphere if cap_r ≠ marker_r, so we
# omit them entirely.
# Open cylinder, no caps — sphere markers cap the ends (see docstring).
cap_r = 0.0
body_r = W
if n_cap_lat is None:
@ -425,7 +368,7 @@ def _capsule_mesh_local(
end_frac = float(min(0.95, max(0.05, end_width_frac)))
cap_r = max(1e-7, W * end_frac)
body_r = W
# Ellipsoid defaults: more body rings to sample the sin(π·u) curve.
# More body rings to sample the sin(π·u) curve.
if n_cap_lat is None:
n_cap_lat = 3
if n_body is None:
@ -473,10 +416,7 @@ def _capsule_mesh_local(
phi = 2.0 * np.pi * k / n_lon
verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))])
# Body intermediate rings (between the cap junctions for capped meshes,
# between the two end rings for open cylinders). For 'capsule' mode
# n_body=0 by default — no intermediate rings needed for a constant-
# radius cylinder.
# Body intermediate rings (none for 'capsule', n_body=0 by default).
body_rings: List[int] = []
is_ellipsoid = str(shape) == "ellipsoid"
for j in range(1, n_body + 1):
@ -572,11 +512,8 @@ def _scail_redirect_neck_stub(body_kp: np.ndarray) -> np.ndarray:
def _openpose_limb_rest_trs(
bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...],
) -> Tuple[np.ndarray, np.ndarray]:
"""Per-limb rest TRS:
midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b].
rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate).
Caller uses `midpoints` as each limb joint's rest translation (rotation =
identity), and `rest_axes` to compute per-frame alignment rotations."""
"""Per-limb rest TRS: midpoints (K_pairs, 3) and unit a→b axes (or +Y if
degenerate). Caller uses midpoints as rest translation, axes for alignment."""
K_pairs = len(pairs)
mid = np.zeros((K_pairs, 3), dtype=np.float32)
axis = np.zeros((K_pairs, 3), dtype=np.float32)
@ -595,13 +532,10 @@ def _openpose_limb_rest_trs(
def _openpose_limb_anim_trs(
kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""Per-frame limb TRS:
anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]).
anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest
axis to its frame-t axis.
Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t *
T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to
track the limb's current direction — no LBS cross-section thinning."""
"""Per-frame limb TRS: anim_mid (N, K_pairs, 3) midpoints and anim_quat
(N, K_pairs, 4 xyzw) aligning each limb's rest axis to its frame-t axis.
Drives skin_matrix(t) = T(mid_t)·R_t·T(-mid_rest) — rigid rotation about
the rest midpoint, no LBS cross-section thinning."""
N = kp_seq.shape[0]
K_pairs = len(pairs)
anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32)
@ -616,7 +550,7 @@ def _openpose_limb_anim_trs(
n = float(np.linalg.norm(d))
if n > 1e-9:
R[t, k] = rotation_align(ax_rest, d / n)
quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw
quat = rotmat_to_quat_np(R).astype(np.float32)
return anim_mid, quat
@ -628,20 +562,14 @@ def _build_openpose_sticks(
smooth_shade: bool = False,
end_width_frac: float = 0.3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Capsule (cylinder + hemispherical caps) per limb pair (a, b).
"""Capsule per limb pair (a, b), each sized to its own length/width so caps
are true hemispheres regardless of L:W. Ellipsoid mode auto-clamps width to
`length * 0.1` so short limbs don't look chunky.
Each limb gets its own mesh sized to that limb's length and width so
the caps are TRUE hemispheres of radius `half_width_eff` — the limb
silhouette is rounded-rectangle-like, regardless of L:W ratio. Width
auto-clamped to `length * 0.1` so short limbs (face/ear) don't look
chunky next to long ones.
Skinning: rigid (weight=1) binding to a per-limb joint at
`limb_joint_base_idx + limb_idx` — the caller animates that joint with
midpoint translation + rest-to-current rotation so each capsule rotates
rigidly with its limb (avoids translation-only LBS cross-section
thinning). Returns flat-shaded (verts, normals, faces, joints4,
weights4, vert_colors)."""
Rigid (weight=1) binding to a per-limb joint at `limb_joint_base_idx +
limb_idx`, which the caller animates with midpoint translation + rotation
(avoids LBS thinning). Returns (verts, normals, faces, joints4, weights4,
vert_colors)."""
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
out_v_chunks: List[np.ndarray] = []
@ -663,13 +591,10 @@ def _build_openpose_sticks(
unit_dir = direction / length
R = rotation_align(canonical, unit_dir)
if is_capsule:
# SCAIL-style uniform radius — every bone gets the same width.
# `_capsule_mesh_local` clamps internally to L/2-eps so very
# short bones don't go degenerate.
# Uniform radius — every bone the same width (clamped internally).
half_width_eff = max(MIN_WIDTH, half_width_m)
else:
# Ellipsoid mode: original auto-thinning so short face/ear
# limbs don't look chunky next to long body limbs.
# Auto-thin so short face/ear limbs aren't chunky next to body limbs.
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
v_local, f_local, _weights_unused = _capsule_mesh_local(
@ -678,10 +603,8 @@ def _build_openpose_sticks(
v_world = v_local @ R.T + head
Nv = v_local.shape[0]
# Rigid binding to the per-limb joint. The 2-bone (head, tail) weights
# from `_capsule_mesh_local` are discarded — they're translation-only
# under glTF LBS and don't rotate the cross-section, causing visible
# thinning when the limb axis changes between rest and animated pose.
# Rigid binding to the per-limb joint; the 2-bone weights are discarded
# (translation-only under LBS, would thin the cross-section).
j_arr = np.zeros((Nv, 4), dtype=np.uint16)
j_arr[:, 0] = limb_idx + limb_joint_base_idx
w_arr = np.zeros((Nv, 4), dtype=np.float32)
@ -730,40 +653,24 @@ def build_glb_openpose(
stick_end_width_frac: float = 0.6,
bone_smooth_window: int = 0,
) -> bytes:
"""Build a GLB containing an OpenPose-style 3D skeleton — sphere markers
per keypoint plus rainbow-colored sticks between standard limb pairs.
Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig
forward needed). Optional hand keypoints (also from `pred_keypoints_3d`)
and face landmarks (sampled from `pred_vertices` at fixed head-mesh
vertex IDs) extend the same per-track armature.
"""Build a GLB of an OpenPose-style 3D skeleton — sphere markers per keypoint
plus colored sticks between limb pairs, one armature per track. Body from
`pred_keypoints_3d`; optional hands (same source) and face landmarks
(`pred_vertices`) extend each armature.
Args:
include_hands: append the standard 21+21 OpenPose hand keypoints to
each track's armature (right hand at MHR70 indices 21..41,
left at 42..62).
hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 ×
`marker_radius_m` (hand keypoints are anatomically smaller than
body joints; matches DWPose's smaller hand dots).
hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 ×
`stick_radius_m`.
hand_color_style: 'dwpose' (default) = solid-blue hand dots,
rainbow per-finger sticks (controlnet_aux/dwpose convention);
'openpose' = rainbow per-finger dots AND sticks (matches
poseParameters.cpp::HAND_COLORS_RENDER).
face_style: 'disabled' (default) | 'full' | 'eyes_mouth' — face
landmarks sampled from `pred_vertices` at vertex IDs picked from
`pose_data["canonical_colors"]["positions"]`. 'full' = all ~30
contour points; 'eyes_mouth' = the eyes + outer-lip subset.
face_marker_radius_m: per-face landmark sphere radius. 0 = auto =
0.3 × `marker_radius_m` — face landmarks are densely packed
around the eyes/mouth/jaw and need to be much smaller than
body keypoints to keep the layout legible. Face landmarks are
rendered as standalone dots (no contour lines), matching
DWPose's face_pose draw style.
palette: body color scheme. 'openpose' = standard rainbow gradient
per keypoint (canonical OpenPose convention); 'scail' =
SCAIL-Pose style — warm hues right side, cool hues left side,
grey neck-to-nose centerline, distinct per-limb colors.
include_hands: append the 21+21 OpenPose hand keypoints per track.
hand_marker_radius_m: hand sphere radius. 0 = auto = 0.4 × marker_radius_m.
hand_stick_radius_m: hand limb half-width. 0 = auto = 0.5 × stick_radius_m.
hand_color_style: 'dwpose' (default) = solid-blue dots + rainbow sticks;
'openpose' = rainbow dots AND sticks.
face_style: 'disabled' (default) | 'full' (~30 contour pts) | 'eyes_mouth'
(eyes + outer-lip subset); sampled at vertex IDs from
`canonical_colors["positions"]`.
face_marker_radius_m: face landmark sphere radius. 0 = auto = 0.3 ×
marker_radius_m. Rendered as dots only, no contour lines.
palette: 'openpose' = rainbow gradient per keypoint; 'scail' = warm right
/ cool left, grey centerline, distinct per-limb colors.
"""
is_scail = str(palette) == "scail"
# SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0,
@ -771,13 +678,11 @@ def build_glb_openpose(
body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS
body_sphere_kp = (np.arange(14, dtype=np.int64)
if is_scail else np.arange(18, dtype=np.int64))
if str(palette) == "scail":
if is_scail:
body_sphere_colors = SCAIL_KEYPOINT_COLORS_18
body_stick_colors = SCAIL_LIMB_COLORS_17
elif str(palette) == "openpose":
# Existing OpenPose behavior: same rainbow array used for both
# spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of
# the 18-element rainbow — yields a legible per-limb gradient).
# Same rainbow array drives both spheres and sticks.
body_sphere_colors = OPENPOSE_RAINBOW_18
body_stick_colors = OPENPOSE_RAINBOW_18
else:
@ -892,13 +797,9 @@ def build_glb_openpose(
if bone_smooth_window and bone_smooth_window > 1:
kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window))
# Static-bind = rig's REST pose when available (override path); else
# fall back to frame 0 of the motion. The rest-pose bind makes the
# GLB's static POSITION attribute sit at rig origin, so viewers
# auto-fit/center on rig origin and the animation visibly snaps from
# rest to scene-frame-0 — matching skeletal mode's behavior. Without
# this, openpose's static geometry is at scene-frame-0 and viewers
# mis-center on the scene location, masking the motion entirely.
# Static-bind = rig REST pose when available, else frame 0. The rest
# bind keeps static POSITION at rig origin so viewers auto-center there
# and the motion is visible (see _openpose_bind_at_rig_rest).
bind_kp_m_rest = _openpose_bind_at_rig_rest(
pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids,
)
@ -914,7 +815,7 @@ def build_glb_openpose(
person_root_idx = len(nodes) - 1
scene_root_indices.append(person_root_idx)
# K keypoint joint nodes (spheres bind here, rigid translation only).
# K keypoint joint nodes (spheres bind here, translation only).
joint_node_indices: List[int] = []
for j in range(K):
nodes.append({
@ -926,9 +827,7 @@ def build_glb_openpose(
joint_node_indices.append(len(nodes) - 1)
person_root["children"].extend(joint_node_indices)
# Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint +
# quaternion that aligns rest-axis → frame-t-axis). Sticks bind
# rigidly to these joints so each capsule rotates with its limb.
# Per-limb rest + per-frame TRS; sticks bind rigidly to these joints.
limb_rest_mids_list: List[np.ndarray] = []
limb_rest_axes_list: List[np.ndarray] = []
limb_anim_mids_list: List[np.ndarray] = []
@ -951,12 +850,10 @@ def build_glb_openpose(
limb_rest_axes_list.append(raxis_h)
limb_anim_mids_list.append(amid_h)
limb_anim_quats_list.append(aquat_h)
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3)
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3)
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4)
# Hemisphere-align consecutive quats per limb so LINEAR interpolation
# takes the short path (otherwise large per-frame rotations can flip
# signs and produce visible "twist back" artifacts mid-playback).
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0)
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1)
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1)
# Hemisphere-align consecutive quats so LINEAR interp takes the short path.
limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32)
limb_joint_indices: List[int] = []
@ -970,8 +867,8 @@ def build_glb_openpose(
limb_joint_indices.append(len(nodes) - 1)
person_root["children"].extend(limb_joint_indices)
# Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints
# (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest.
# Combined skin: keypoint joints then limb joints; IBM = T(-rest) for
# both, yielding identity skin_matrix at rest.
all_joint_indices = joint_node_indices + limb_joint_indices
ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1))
ibm[:K, :3, 3] = -bind_kp_m
@ -985,10 +882,8 @@ def build_glb_openpose(
})
skin_idx = len(skins) - 1
# Per-group geometry. Spheres bind to keypoint joints (base_joint_idx
# ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈
# [K, K + K_limbs)). Groups stack body → right hand → left hand →
# face for keypoint joints, and body → R-hand → L-hand for limbs.
# Per-group geometry. Spheres bind to keypoint joints [0, K); sticks to
# limb joints [K, K+K_limbs). Stacked body → R-hand → L-hand → face.
group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray]] = []
sp = _build_openpose_spheres(
@ -1008,9 +903,7 @@ def build_glb_openpose(
group_meshes.append(st)
if include_hands:
# Hand stick colors stay rainbow per-finger regardless of
# `hand_color_style` — only the sphere dots switch to solid
# blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py.
# Hand sticks stay rainbow per-finger; only dots switch under 'dwpose'.
hand_pair_colors = _pair_colors_from_kp(
OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1,
)
@ -1033,9 +926,7 @@ def build_glb_openpose(
if K_face > 0:
f_off = K_body + K_hands
f_bind = bind_kp_m[f_off:f_off + K_face]
# DWPose face = dots only, no contour lines
# (controlnet_aux/dwpose/util.py::draw_facepose draws white
# circles per landmark and never connects them).
# DWPose face = dots only, no contour lines.
group_meshes.append(_build_openpose_spheres(
f_bind, float(face_marker_radius_m),
FACE_LANDMARK_COLORS, base_joint_idx=f_off,
@ -1087,9 +978,8 @@ def build_glb_openpose(
"target": {"node": joint_node_indices[j], "path": "translation"},
})
# Per-limb-joint translation + rotation channels. Stationary limbs
# have their constant TRS baked into the node so they don't bloat the
# animation buffer.
# Per-limb-joint translation + rotation; stationary limbs bake their
# constant TRS into the node instead of an animation channel.
for k in range(K_limbs):
t_k = limb_anim_mids[:, k, :].astype(np.float32)
if (np.ptp(t_k, axis=0) < 1e-6).all():
@ -1103,9 +993,7 @@ def build_glb_openpose(
"target": {"node": limb_joint_indices[k], "path": "translation"},
})
q_k = limb_anim_quats[:, k, :].astype(np.float32)
# ptp on the absolute value handles the +q == -q ambiguity, but
# `quat_sign_fix_per_joint` already aligned signs so a plain ptp
# is fine here.
# Plain ptp is fine — signs already aligned by quat_sign_fix_per_joint.
if (np.ptp(q_k, axis=0) < 1e-6).all():
nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist()
else:

View File

@ -1,15 +1,11 @@
"""GLB export for SAM 3D Body pose_data.
"""Shared GLB export helpers for SAM 3D Body pose_data.
Mode: skeletal rebuilds the MHR 127-bone rig. Per-frame local TRS comes from
re-running param_transform on saved mhr_model_params; rest verts from a
zero-pose forward with the person's shape_params; sparse triplet skinning is
compacted to glTF's max-4-influences form; facial expression is re-exposed as
72 morph targets driven by expr_params.
pred_vertices/pred_cam_t are camera-y-down un-flipped here so the GLB lives
in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent
them); deformation at extreme joint angles will differ from the SAM3DBody
renderer by the corrective amount.
Skeletal mode rebuilds the MHR 127-bone rig: per-frame local TRS from
param_transform on mhr_model_params, rest verts from a zero-pose forward,
sparse skinning compacted to glTF's 4-influence form, expression re-exposed as
72 morph targets. Camera-y-down data is un-flipped to glTF Y-up. Pose
correctives are dropped (glTF skinning can't represent them), so extreme joint
angles differ from the SAM3DBody renderer by the corrective amount.
"""
from __future__ import annotations
@ -24,12 +20,11 @@ import torch
from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical
# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically
# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)`
# fp32-rounded ln(2); exp(x * _LN2) matches the rig's own 2**x bit-for-bit.
_LN2 = 0.6931471824645996
# Quaternion / rotation helpers (xyzw convention, matching MHR rig)
# Quaternion / rotation helpers (xyzw, matching MHR rig)
def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray:
"""(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat."""
@ -96,8 +91,7 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray:
def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
"""Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4.
Endpoints replicate so they aren't pulled toward zero. Returns float64."""
"""Edge-replicate Gaussian smoothing along time (sigma = window/4). float64."""
a = np.asarray(arr, dtype=np.float64)
n = a.shape[0]
half = window // 2
@ -117,9 +111,8 @@ def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
"""Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns
per joint first, convolves per-component, renormalizes. Suppresses multi-
frame bone spikes at extreme poses without needing the upstream Smooth node."""
"""Smooth a (N, NJ, 4) quaternion sequence along time: sign-align per joint,
convolve per-component, renormalize. Calms bone spikes at extreme poses."""
if window <= 1 or q_seq.shape[0] < 2:
return q_seq
out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window)
@ -128,18 +121,16 @@ def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray:
"""Gaussian-smooth a (N, K, 3) position sequence along time (edge-replicate
padding). Used to calm jittery keypoint tracks before the openpose rig
derives sphere translations + limb TRS from them."""
"""Smooth a (N, K, 3) position sequence along time. Calms jittery keypoint
tracks before the openpose rig derives sphere translations + limb TRS."""
if window <= 1 or seq.shape[0] < 2:
return seq
return _gaussian_smooth_time(seq, window).astype(np.float32)
def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
"""Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit
on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim
cartwheel flip). fp64 to avoid drift; normalizes input defensively."""
"""Walk (N, NJ, 4) along time, flipping sign when consecutive frames sit on
opposite hemispheres. Avoids long-path slerp glitches. fp64 internally."""
out = np.array(q_seq, dtype=np.float64, copy=True)
norms = np.linalg.norm(out, axis=-1, keepdims=True)
out = out / np.maximum(norms, 1e-12)
@ -151,11 +142,9 @@ def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray:
"""Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that
FK over (parents, bone_local) reproduces rig_global. local =
inverse(parent_global) ∘ child_global — makes this robust to hierarchy-
convention mismatches: glTF FK gives back exactly rig_global even if
`parents` doesn't match the rig's pmi-walk."""
"""Globals (N, NJ, 8) + parents -> per-bone local TRS so FK reproduces
rig_global. local = inverse(parent_global) ∘ child_global, robust to
hierarchy-convention mismatches in `parents`."""
N, NJ, _ = rig_global.shape
bone_local = np.zeros_like(rig_global)
for j in range(NJ):
@ -188,8 +177,7 @@ def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray:
def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]:
"""List of (person_index, frame_indices). track_index == -1 means every
present track; empty tracks are dropped. Same person index across frames
is assumed same subject (Smooth/Predict enforce this on tracked bboxes)."""
present track; empty tracks dropped. Same person index = same subject."""
frames = pose_data["frames"]
max_p = max((len(f) for f in frames), default=0)
if max_p == 0:
@ -257,8 +245,7 @@ class GLBWriter:
return len(self.accessors) - 1
def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int:
"""Morph-target POSITIONs: spec lets us skip min/max, avoiding a
per-frame delta bbox."""
"""Morph-target POSITIONs: spec lets us skip min/max."""
a = np.ascontiguousarray(arr, dtype=np.float32)
view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY)
self.accessors.append({
@ -288,9 +275,8 @@ class GLBWriter:
return len(self.accessors) - 1
def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int:
"""Animation-output scalars: `count` is keyframes, not floats. Morph-
target weight tracks store N_morph weights per keyframe as flat float32
with count=N_keyframes."""
"""Animation-output scalars: `count` is keyframes, not floats (morph
weight tracks store N_morph weights per keyframe)."""
a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1)
view_idx = self._add_view(a.tobytes())
self.accessors.append({
@ -382,9 +368,8 @@ def bake_vertex_colors(
rainbow_tilt_z_deg: float,
pastel_mix: float,
) -> Optional[np.ndarray]:
"""Per-vertex RGB matching the renderer's shader preset, on the canonical
mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the
viewer's default material handle shading)."""
"""Per-vertex RGB matching the renderer's shader preset. Returns (N_v, 3)
float32 in [0, 1], or None for `default` (use the viewer's material)."""
if shader == "default" or canonical_colors is None:
return None
@ -432,8 +417,8 @@ def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray:
def _parents_from_pmi(rig: Any) -> np.ndarray:
"""Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child,
row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1."""
"""Parent index per joint from skel_pmi ((2, 266): row 0 child, row 1
parent, split into BFS levels by skel_pmi_buffer_sizes). Roots = -1."""
NJ = int(rig.NUM_JOINTS)
pmi = rig.skel_pmi.cpu().numpy()
sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist()
@ -450,47 +435,29 @@ def _parents_from_pmi(rig: Any) -> np.ndarray:
def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply
this to bypass MHR rig extraction (see ComfyUI-Kimodo). Required keys:
parents: (NJ,) int32, -1 = root
bind_global_m: (NJ, 8) f32 [t.xyz | q.xyzw | scale], meters
lbs_compact_joints: (V, 8) uint16 pre-compacted skin influences
lbs_compact_weights: (V, 8) f32
lbs_compact_max_inf: int actual max influences (≤ 8)
rest_verts_m: (V, 3) f32
faces: (F, 3) uint32
Optional:
per_frame_y_down: bool set False if pred_joint_coords are already
rig-native Y-up (kimodo). Default True (MHR).
openpose18_joint_indices: (18, 2) int32 body OpenPose-18 joint
index pair, resolved against per-frame
`pred_joint_coords`. Each row is
(joint_a, joint_b); b == -1 = single
joint, else default midpoint of the two
(lets producers approximate keypoints
with no matching joint, e.g. Nose ≈
midpoint(LeftEye, RightEye)). Enables
`SAM3DBody_ToGLB(mode="openpose")` on
external rigs.
openpose18_joint_weights: (18,) f32 optional per-keypoint blend
weight for the (a, b) mapping above.
Position = w*joints[a] + (1-w)*joints[b]
when b ≥ 0 (default w=0.5 → midpoint).
Values outside [0, 1] EXTRAPOLATE past
the line segment — used to approximate
landmarks with no nearby joint pair
(e.g. ears: w=2.0 along the eye↔eye
axis puts each ear one eye-distance
outside the corresponding eye). Ignored
for single-joint rows (b = -1).
openpose_hand21_r_joint_indices: (21, 2) int32 right-hand OpenPose-21
(wrist + 5 fingers × 4 joints, base→tip)
joint index pair. Required (alongside
the L counterpart) for openpose mode
with include_hands=True.
openpose_hand21_l_joint_indices: (21, 2) int32 left-hand counterpart.
openpose_hand21_r_joint_weights: (21,) f32 optional, same semantics as
`openpose18_joint_weights`.
openpose_hand21_l_joint_weights: (21,) f32 optional, same as above.
this to bypass MHR rig extraction (see ComfyUI-Kimodo).
Required keys:
parents: (NJ,) int32, -1 = root
bind_global_m: (NJ, 8) f32 [t.xyz | q.xyzw | scale], meters
lbs_compact_joints: (V, 8) uint16 pre-compacted skin influences
lbs_compact_weights: (V, 8) f32
lbs_compact_max_inf: int actual max influences (≤ 8)
rest_verts_m: (V, 3) f32
faces: (F, 3) uint32
Optional (enable openpose mode on external rigs):
per_frame_y_down: bool False if pred_joint_coords are already Y-up
(kimodo). Default True (MHR).
openpose18_joint_indices: (18, 2) int32 body keypoint (a, b)
joints, resolved against `pred_joint_coords`.
b == -1 = single joint, else midpoint of (a, b).
openpose18_joint_weights: (18,) f32 blend w: w*a + (1-w)*b
(default 0.5; outside [0,1] extrapolates; ignored
when b == -1).
openpose_hand21_{r,l}_joint_indices: (21, 2) int32 per-hand keypoint
maps; both required for include_hands=True.
openpose_hand21_{r,l}_joint_weights: (21,) f32 optional, same as above.
"""
if pose_data is None:
return None
@ -502,12 +469,10 @@ def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) -
use that instead of MHR-specific `model.head_pose.mhr` buffers."""
override = _get_skeleton_override(pose_data)
if override is not None:
# External rig: caller pre-compacts skin and supplies bind global directly,
# so we don't need MHR's PCA pose / expression bases.
# External rig: skin pre-compacted, bind global supplied directly.
parents = np.asarray(override["parents"], dtype=np.int32)
rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32)
# BVH needs parent-relative bone OFFSETs (cm). MHR ships these directly;
# external rigs only give bind globals, so derive locals from them.
# BVH needs parent-relative bone offsets (cm); derive from bind globals.
bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32)
local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0]
joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32)
@ -560,29 +525,26 @@ def compact_skin_to_n(
skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
num_verts: int, max_inf: int = 8,
) -> Tuple[np.ndarray, np.ndarray, int]:
"""Sparse (joint, vert, weight) triplets -> dense (joints[V, max_inf],
weights[V, max_inf]). Keeps `max_inf` largest-magnitude influences,
renormalizes. `actual_max` lets the caller skip JOINTS_1/WEIGHTS_1 when
nothing exceeds 4 influences."""
"""Sparse (joint, vert, weight) triplets -> dense (joints, weights) of shape
(V, max_inf), keeping the largest influences and renormalizing. `actual_max`
lets the caller skip JOINTS_1/WEIGHTS_1 when nothing exceeds 4 influences."""
joints = np.zeros((num_verts, max_inf), dtype=np.uint16)
out_w = np.zeros((num_verts, max_inf), dtype=np.float32)
counts = np.zeros(num_verts, dtype=np.int32)
if vert_indices.size:
# lexsort secondary key first: groups by vert, weights descending within group.
# Group by vert, weights descending within each group.
order = np.lexsort((-weights, vert_indices))
vi_sorted = vert_indices[order]
sk_sorted = skin_indices[order]
w_sorted = weights[order]
# Per-row rank within its vertex group: 0 at each group start, +1 elsewhere.
# group_start[i] is True when vi_sorted[i] starts a new vertex.
# Per-row rank within its vertex group (0 at each group start).
n = vi_sorted.size
group_start = np.empty(n, dtype=bool)
group_start[0] = True
np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:])
pos = np.arange(n, dtype=np.int64)
# Position of each row's group start, broadcast forward.
group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0))
rank = pos - group_start_pos
@ -609,9 +571,8 @@ def zero_pose_rest_verts(
model: Any, shape_params: np.ndarray, expr_zero: bool = True,
pose_data: Optional[Dict[str, Any]] = None,
) -> np.ndarray:
"""Rig with zero pose + this subject's shape -> rest verts (V, 3) in
rig-native Y-up meters. External-skeleton path returns `rest_verts_m`
directly (no PCA shape space to expand)."""
"""Zero pose + this subject's shape -> rest verts (V, 3) in rig-native Y-up
meters. External path returns `rest_verts_m` directly."""
override = _get_skeleton_override(pose_data)
if override is not None:
return np.asarray(override["rest_verts_m"], dtype=np.float32)
@ -624,14 +585,11 @@ def zero_pose_rest_verts(
sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device)
if sp.ndim == 1:
sp = sp.unsqueeze(0)
# mhr.forward(identity_coeffs, model_parameters, expr_coeffs):
# identity_rest = base_shape + identity_basis @ shape;
# cat([model_params, zeros]) through param_transform; expr added.
# rig.forward(shape, model_params, expr); zero pose + zero expr.
model_params = torch.zeros(1, 204, device=device, dtype=dtype)
expr = torch.zeros(1, 72, device=device, dtype=dtype)
verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False)
# Rig outputs cm; mhr_head divides by 100 for meters. Match that.
verts_m = verts[0].cpu().float().numpy() / 100.0
verts_m = verts[0].cpu().float().numpy() / 100.0 # cm -> m
return verts_m.astype(np.float32)
@ -639,7 +597,7 @@ def global_skel_state_per_frame(
model: Any, mhr_model_params: np.ndarray,
) -> np.ndarray:
"""Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw,
scale). Bones are shape- and expression-independent so we pass zeros."""
scale). Bones are shape/expression-independent, so pass zeros."""
inner = model.model if hasattr(model, "model") else model
rig = inner.head_pose.mhr
device = next(rig.parameters()).device
@ -655,8 +613,8 @@ def global_skel_state_per_frame(
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978 branched, largest-component
pick for stability. Cross-frame sign-fixing is the caller's job."""
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978, largest-component pick.
Cross-frame sign-fixing is the caller's job."""
shape = R.shape[:-2]
Rf = R.reshape(-1, 3, 3).astype(np.float64)
M = Rf.shape[0]
@ -703,14 +661,12 @@ def global_skel_state_from_pose_data(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
NJ: int, *, joint_coords_y_down: bool = True,
) -> np.ndarray:
"""Build per-frame skel_state from stored pred_global_rots + pred_joint_coords,
bypassing rig.forward. Returns (N, NJ, 8) in METERS, MHR-native frame.
"""Per-frame skel_state from stored pred_global_rots + pred_joint_coords,
bypassing rig.forward. Returns (N, NJ, 8) in meters, MHR-native frame.
pred_global_rots are MHR-native (no y/z flip). For MHR, pred_joint_coords
are stored y-down (post-flip), so un-flip when `joint_coords_y_down=True`.
External skeletons (Kimodo) store y-up already — pass False. Scale
defaults to 1 (rig scale isn't preserved in pose_data; close to 1 for
typical body poses)."""
pred_global_rots are MHR-native. pred_joint_coords are y-down for MHR
(un-flipped when `joint_coords_y_down=True`); external rigs store y-up
(pass False). Scale defaults to 1 (not preserved in pose_data)."""
frames = pose_data["frames"]
N = len(frame_indices)
rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32)
@ -731,10 +687,8 @@ def global_skel_state_from_pose_data(
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm.
Inverse of `lbs_inverse_bind_pose` modulo precision; used as bones' static
TRS so the rest mesh looks correct with no animation playing. External
rig: convert override's `bind_global_m` from m → cm to match this contract."""
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm,
used as bones' static TRS. External rig: convert `bind_global_m` m -> cm."""
override = _get_skeleton_override(pose_data)
if override is not None:
bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy()
@ -746,13 +700,10 @@ def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> n
@dataclass
class Rig:
"""Normalized static rig for the GLB/BVH exporters, independent of where it
came from: an MHR model (`Rig.from_pose_data(pose_data, model)`) or an inline
`pose_data["_skeleton_override"]` (external rigs, e.g. ComfyUI-Kimodo).
Consumers read these fields and never branch on the source. The only
source-dependent operation is `rest_verts_m` MHR rest verts depend on the
subject's `shape_params`; external rigs ship fixed rest verts.
"""Normalized static rig for the GLB/BVH exporters, source-independent: MHR
model or inline `pose_data["_skeleton_override"]` (external rigs). Consumers
never branch on the source. Only `rest_verts_m` is source-dependent — MHR
expands it from `shape_params`; external rigs ship it fixed.
"""
parents: np.ndarray # (NJ,) int32, -1 = root
joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm
@ -816,9 +767,8 @@ class Rig:
def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray:
"""Inverse-bind MAT4 by inverting the rig's bind global (meters). Guarantees
IBP[j] = inverse(FK over bind local TRS) — exactly what glTF skinning
needs given bones default to the bind local TRS. Returns (NJ, 4, 4)
"""Inverse-bind MAT4 from the rig's bind global (meters). IBP[j] =
inverse(FK over bind local TRS), as glTF skinning needs. Returns (NJ, 4, 4)
column-major."""
NJ = bind_skel_state_m.shape[0]
t = bind_skel_state_m[:, :3].astype(np.float32)
@ -877,10 +827,8 @@ def _ibp_to_mat4(ibp_skel: np.ndarray) -> np.ndarray:
def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]:
"""Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring
lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon
matches the 16-segment cylinder used by capsule limbs AND the equator
ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush."""
"""Unit UV sphere, poles ±Y. `n_lat` odd so a ring lands at the equator;
n_lon=16 matches the capsule cylinder so end rings meet flush."""
verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0
for i in range(1, n_lat + 1):
lat = -0.5 * np.pi + np.pi * i / (n_lat + 1)
@ -924,8 +872,8 @@ def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndar
def flat_shade_mesh(
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Smooth -> flat by duplicating verts per face; each triangle gets 3
unique verts sharing its face normal. Skinning attrs duplicated alongside."""
"""Flat-shade by duplicating verts per face; each triangle gets 3 unique
verts sharing its face normal. Skinning attrs duplicated alongside."""
F = faces.shape[0]
new_v = np.zeros((F * 3, 3), dtype=np.float32)
new_n = np.zeros((F * 3, 3), dtype=np.float32)
@ -949,9 +897,8 @@ def flat_shade_mesh(
def smooth_shade_mesh(
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Area-weighted per-vertex normals (smooth shading). Geometry, skinning,
indexing pass through unchanged so vertex colors stay aligned. Orphan
verts get +Y fallback."""
"""Area-weighted per-vertex normals. Geometry/skinning/indexing pass through
unchanged so vertex colors stay aligned. Orphan verts get +Y fallback."""
Nv = int(verts.shape[0])
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
@ -994,11 +941,9 @@ def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray:
def make_lit_material(
roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0,
) -> dict:
"""Lit PBR material using vertex COLOR_0 multiplicatively. KHR_materials_unlit
is intentionally off so viewer lighting reveals surface form. metallic=0
keeps the surface dielectric so vertex colors stay readable. roughness=0.85
suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look.
opacity < 1 switches to alpha-blend (e.g. see-through body mesh over bones)."""
"""Lit PBR material using vertex COLOR_0. Dielectric (metallic=0) so colors
stay readable; roughness 0.85 suits rainbow body meshes, 0.3 the glossy
SCAIL rig. opacity < 1 switches to alpha-blend."""
a = float(max(0.0, min(1.0, opacity)))
mat = {
"pbrMetallicRoughness": {
@ -1182,14 +1127,12 @@ def openpose_render_keypoints(
person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str,
*, dim: int, H: int = 0, W: int = 0,
) -> Optional[np.ndarray]:
"""OpenPose keypoints for one person, in op-layout, CAMERA frame (Y-down).
"""OpenPose keypoints for one person, op-layout, camera frame (Y-down).
`part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add;
dim=2 -> (K, 2) image pixels. Returns None when the source data is missing.
dim=2 -> (K, 2) pixels. Returns None when source data is missing.
External rigs (override carries the joint-index map) resolve from per-frame
`pred_joint_coords` (rig-native Y-up -> flipped to camera Y-down, matching
the pred_vertices convention). MHR reindexes the stored
`pred_keypoints_{3d,2d}` via the MHR70 map."""
External rigs resolve from `pred_joint_coords` (Y-up -> flipped to Y-down);
MHR reindexes stored `pred_keypoints_{3d,2d}` via the MHR70 map."""
map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part]
override = _get_skeleton_override(pose_data)
ext_map = override.get(map_key) if override is not None else None
@ -1228,11 +1171,9 @@ def openpose_render_keypoints(
return kp_full[mhr_map]
# Face landmarks from the MHR rig (option `face_source="rig"`).
# MHR has no face bones — face deforms via expr_params morphs — so landmarks
# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against
# anatomically-plausible target xyz in canonical Y-up. Iterate visually in
# Blender and tweak targets if landmarks land off-surface.
# Face landmarks (face_source="rig"). MHR has no face bones, so landmarks are
# sourced from `pred_vertices` at vertex IDs picked by NN against the target xyz
# below. Tweak targets if landmarks land off-surface.
# (name, target_xyz) in MHR canonical Y-up meters.
FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = (
@ -1290,10 +1231,8 @@ def select_face_landmark_vert_ids(
face_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in
canonical positions. Filter: `face_mask` (verts that deform with any of
the 72 expression axes) if available — keeps chin/jaw search off the
neck. Otherwise a position bbox (less reliable; throat verts sometimes
pull chin targets)."""
canonical positions, restricted to `face_mask` verts (expression-deforming)
when available, else a position bbox (less reliable around the chin/jaw)."""
P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3)
if face_mask is not None and np.asarray(face_mask).any():
valid = np.where(np.asarray(face_mask).reshape(-1))[0]

View File

@ -1,19 +1,11 @@
"""GLB export — skeletal (real armature) mode.
Rebuilds an Armature with the MHR 127-bone rig:
- per-frame local TRS comes from re-running param_transform on the saved
`mhr_model_params`;
- rest verts come from a zero-pose forward with each person's `shape_params`;
- sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form;
- facial expression is re-exposed as 72 morph targets driven by `expr_params`
so face animation survives plain glTF skinning.
Optional bone visualization (octahedrons) is rigidly
skinned alongside the body mesh used to preview the armature in glTF
viewers that don't draw bones.
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from
param_transform on `mhr_model_params`, rest verts from a zero-pose forward,
sparse skinning compacted to glTF's 4-influence form, and facial expression as
72 morph targets driven by `expr_params`. Optional octahedron bone-vis is
rigidly skinned alongside for viewers that don't draw bones. Shared infra lives
in `glb_shared.py`.
"""
from __future__ import annotations
@ -44,8 +36,7 @@ from .glb_shared import (
from comfy_extras.sam3d_body.utils import jet_colormap
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white'
(no per-bone color — bone-vis mesh uses default unlit material)."""
"""Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material)."""
if scheme == "rainbow_y":
y = bind_pos_m[:, 1].astype(np.float32)
y_min, y_max = float(y.min()), float(y.max())
@ -55,9 +46,8 @@ def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y,
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
"""Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit
length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward."""
v = np.array([
[0.0, 0.0, 0.0], # 0: head
[0.0, 1.0, 0.0], # 1: tail
@ -78,18 +68,16 @@ def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
def _bone_edges(
joint_pos_m: np.ndarray, parents: np.ndarray,
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per
parent→child edge in the hierarchy, skipping edges whose PARENT is a
root joint (those typically anchor the skeleton at world origin and
just look like a stray stick from origin to the body). Zero-length
edges are skipped too."""
"""One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge.
Skips edges whose parent is a root (world-anchor sticks) and zero-length
edges."""
NJ = joint_pos_m.shape[0]
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
for c in range(NJ):
p = int(parents[c])
if not (0 <= p < NJ and p != c):
continue
# Skip if parent itself is a root — that bone is a world-anchor stick.
# Skip world-anchor sticks: parent itself is a root.
gp = int(parents[p])
if not (0 <= gp < NJ and gp != p):
continue
@ -104,9 +92,8 @@ def _bone_edges(
def _build_bone_octahedrons_mesh(
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""One Blender-style octahedron per parent→child edge. Returns
(verts, normals, faces, joints, weights, child_idx_per_vert);
child_idx feeds per-bone color lookup at the call site."""
"""One octahedron per parent→child edge. Returns (verts, normals, faces,
joints, weights, child_idx_per_vert); child_idx feeds per-bone color."""
base_v, base_f = _octahedron_unit()
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
@ -117,8 +104,7 @@ def _build_bone_octahedrons_mesh(
out_w: List[List[float]] = []
child_per_vert: List[int] = []
# Width scales with length so short bones (fingers, face) don't look chunky
# next to long ones (limbs, spine). `half_width_m` caps long bones.
# Width scales with length (capped by half_width_m) so short bones aren't chunky.
WIDTH_RATIO = 0.1
MIN_WIDTH = 0.001
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
@ -151,8 +137,8 @@ def _build_bone_octahedrons_mesh(
out_n.extend(n_world.tolist())
for face in base_f:
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
# Dual skin head→parent, tail→child, ridges blend by canonical Y so the
# bone stretches between joints instead of going rigid with one.
# Dual skin (head→parent, tail→child); ridges blend by canonical Y so
# the bone stretches between joints instead of going rigid with one.
for k in range(base_v.shape[0]):
y_canon = float(base_v[k, 1])
w_parent = max(0.0, 1.0 - y_canon)
@ -196,22 +182,17 @@ def build_glb_skeletal(
bone_vis_color: str = "white",
include_body_mesh: bool = True,
) -> bytes:
"""Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.
"""Build pose_data as a real Armature GLB with per-bone TRS keyframes. For
MHR, facial expression is exposed as 72 morph targets when
include_face_morphs=True.
For MHR (default) facial expression is exposed as 72 morph targets driven
by expr_params per frame when include_face_morphs=True.
External skeletons (e.g. ComfyUI-Kimodo) can supply a
``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
entirely. When present, ``model`` may be None and the rig data, bind pose,
skin weights, and rest verts come from the override. Per-frame skeletal
state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
person dict (kimodo populates these from its own FK output). See
``glb.shared._get_skeleton_override`` for the override schema.
External skeletons (e.g. ComfyUI-Kimodo) can supply
``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model``
may be None then); per-frame state still reads ``pred_global_rots`` /
``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema.
"""
frames = pose_data["frames"]
# Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
# faces are all rig-native (Y-up).
# Only `pred_cam_t` is camera-y-down; everything else is rig-native Y-up.
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
tracks = collect_tracks(pose_data, track_index)
if not tracks:
@ -219,17 +200,14 @@ def build_glb_skeletal(
rig = Rig.from_pose_data(pose_data, model)
NJ = rig.num_joints
# NV = rig.num_verts
NEXPR = rig.num_expr
parents = rig.parents
if not rig.can_rerun_fk:
# External rigs have no PCA pose params to re-run; only stored globals
# are available, and they store joint coords already Y-up.
# External rigs have no PCA pose params to re-run; use stored globals.
use_stored_global_rots = True
joint_coords_y_down = rig.per_frame_y_down
# Skinning is already compacted to ≤8 influences per vertex (MHR averages
# ~2.8 but some shoulder/hip verts hit 5-8; keeping only 4 there leaks
# per-bone rotation noise into the rendered mesh).
# Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts
# need >4, else per-bone rotation noise leaks into the mesh).
joints_8 = rig.lbs_joints
weights_8 = rig.lbs_weights
actual_max_inf = rig.lbs_max_inf
@ -238,14 +216,12 @@ def build_glb_skeletal(
use_set1 = actual_max_inf > 4
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
# Derive bone locals from the rig's bind globals rather than recomputing
# FK ourselves, so any mismatch between `parents` and the rig's actual FK
# is absorbed into the local TRS instead of producing wrong globals.
# Derive bone locals from bind globals so any `parents`/FK mismatch is
# absorbed into the local TRS instead of producing wrong globals.
bind_global_m = rig.bind_global_m
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
# IBP = inverse of bind global. With bone defaults set to bind_local and
# FK composed via `parents`, skin_matrix at rest = identity.
# IBP = inverse of bind global → skin_matrix at rest is identity.
ibp_mat4 = ibp_from_bind_global(bind_global_m)
w = GLBWriter()
@ -316,9 +292,7 @@ def build_glb_skeletal(
body_mesh_node_idx: Optional[int] = None
if include_body:
# MHR rest verts depend on the subject's shape_params; external rigs
# ship fixed rest verts and ignore the arg (so the empty external
# `shape_params` is harmless).
# MHR rest verts depend on shape_params; external rigs ignore the arg.
shape_params_arr = np.asarray(
frames[frame_indices[0]][person_k].get("shape_params", []),
dtype=np.float32,
@ -349,8 +323,8 @@ def build_glb_skeletal(
"indices": indices_acc,
"mode": 4,
}
# See-through body when bones are shown, else opaque (only when a
# vertex-color shader baked COLOR_0 — otherwise default material).
# See-through body when bones are shown, else opaque (only if a
# shader baked COLOR_0; otherwise default material).
if color_acc is not None or include_bones:
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
primitive["material"] = len(materials) - 1
@ -373,8 +347,7 @@ def build_glb_skeletal(
if include_bones:
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
# Indexes `bone_palette`: octahedrons use the bone's child joint so
# every bone has its own color regardless of skin target.
# Color by child joint so every bone has its own color.
color_idx_per_vert: Optional[np.ndarray] = None
hw = float(bone_vis_radius_m)
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
@ -422,8 +395,8 @@ def build_glb_skeletal(
nodes.append(bv_mesh_node)
person_root["children"].append(len(nodes) - 1)
# Per-frame GLOBAL skel state → bone locals via parent-inverse.
# Default uses the rig's stored output; the fallback re-runs FK.
# Per-frame global skel state → bone locals via parent-inverse. Stored
# output by default; fallback re-runs FK.
if use_stored_global_rots:
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
@ -437,11 +410,9 @@ def build_glb_skeletal(
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
rig_global_m = rig_global_cm.copy().astype(np.float32)
rig_global_m[..., :3] *= 0.01
# Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
# Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
# only fix locals, the parent's flip propagates into the child's
# local translation (t_local inherits parent sign via q_parent_inv)
# and produces visible "axis resets" mid-animation.
# Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip
# would otherwise propagate into the child's local translation and cause
# visible "axis resets" mid-animation.
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
local_t = bone_local_anim[..., :3].astype(np.float32)
@ -449,20 +420,17 @@ def build_glb_skeletal(
local_s = bone_local_anim[..., 7].astype(np.float32)
# Second pass on locals catches residual drift from the parent-inverse.
local_q = quat_sign_fix_per_joint(local_q)
# Hemisphere-align frame 0 with the bind quat so pause/play takes the
# short path; then re-propagate.
# Align frame 0 with the bind quat so pause/play takes the short path.
bind_q = bind_local[:, 3:7].astype(np.float32)
if local_q.shape[0] > 0:
d0 = (bind_q * local_q[0]).sum(axis=-1)
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
local_q[0] = local_q[0] * sign0
local_q = quat_sign_fix_per_joint(local_q)
# Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
# at handstand) that the upstream Smooth node may not catch.
# Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand).
if bone_smooth_window and bone_smooth_window > 1:
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
# fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
# drift into visible flips otherwise.
# fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift.
lq64 = local_q.astype(np.float64)
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
local_q = lq64.astype(np.float32)
@ -527,7 +495,7 @@ def build_glb_skeletal(
"target": {"node": person_root_idx, "path": "translation"},
})
# Body-mesh-only: bone-vis primitives have no morph targets.
# Body mesh only — bone-vis primitives have no morph targets.
if expr_morph_accs and body_mesh_node_idx is not None:
expr_per_frame = np.stack([
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)

View File

@ -34,8 +34,7 @@ def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
resolution. Returns (per_frame_bboxes, per_frame_masks) or
(None, None) when the track is empty / frame count doesn't match"""
resolution. Returns (None, None) on empty track / frame-count mismatch."""
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
if packed is None:
@ -100,7 +99,7 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
H: int, W: int) -> Dict[str, Any]:
"""Re-project every frame's pose through a Load3D 6DOF camera (position/
target/zoom + optional FOV). Returns a new mhr_pose_data; unchanged on
target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on
empty/invalid input."""
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
if not first_frame:
@ -158,16 +157,16 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
y_axis = np.cross(z_axis, x_axis)
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
# Eye: dolly along the given offset; for a rotation-only camera (position ==
# target) keep the predicted viewing distance so only orientation/roll changes.
# Eye: dolly along the offset; rotation-only camera keeps the predicted
# viewing distance so only orientation/roll changes.
if has_offset:
eye = target + offset / max(0.01, zoom)
else:
d = max(0.1, float(target[2]))
eye = target - z_axis * (d / max(0.01, zoom))
# Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint-
# only change). Three.js fov is vertical → focal from image height.
# Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov
# is vertical → focal from image height.
cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
if cam_fov > 0:
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
@ -178,10 +177,8 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
# External rigs (e.g. Kimodo) store pred_joint_coords rig-native Y-up; the
# render openpose/scail keypoint provider resolves from them and flips Y/Z.
# Transform them through the camera too (in camera space, then back to Y-up)
# so those keypoints follow the override instead of staying in the old frame.
# External rigs store pred_joint_coords Y-up; transform them through the
# camera too (in camera space, then back to Y-up) so they follow the override.
override = mhr_pose_data.get("_skeleton_override")
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
new_frames: List[List[Dict[str, Any]]] = []
@ -242,8 +239,7 @@ def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], p
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
if per_frame_masks is not None:
# Broadcast a single-mask bundle to per-bbox: when the user supplied one
# mask but multiple bboxes per frame, each bbox gets the same mask.
# One mask but multiple bboxes per frame → each bbox gets the same mask.
flat_masks = []
for f in range(N):
mf = per_frame_masks[f]