Big cleanup

This commit is contained in:
kijai 2026-06-16 20:47:15 +03:00
parent f1be65f914
commit ecbaefd8fc
13 changed files with 376 additions and 877 deletions

View File

@ -4,25 +4,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask
from ..model.transformer import MLP from ..model.transformer import MLP
class MHRHead(nn.Module): class MHRHead(nn.Module):
def __init__( def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False,
self, device=None, dtype=None, operations=None):
input_dim: int,
mhr_rig,
mlp_depth: int = 1,
extra_joint_regressor: str = "",
mlp_channel_div_factor: int = 8,
enable_hand_model=False,
device=None,
dtype=None,
operations=None,
):
super().__init__() super().__init__()
# Store the shared MHRRig as a non-registered Python attribute # Store the shared MHRRig as a non-registered Python attribute
object.__setattr__(self, "mhr", mhr_rig) object.__setattr__(self, "mhr", mhr_rig)
@ -48,9 +38,7 @@ class MHRHead(nn.Module):
hidden_dim=input_dim // mlp_channel_div_factor, hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.npose, output_dim=self.npose,
num_layers=mlp_depth, num_layers=mlp_depth,
device=device, device=device, dtype=dtype, operations=operations,
dtype=dtype,
operations=operations,
) )
# MHR Parameters # MHR Parameters
@ -75,28 +63,25 @@ class MHRHead(nn.Module):
self.local_to_world_wrist = _p(3, 3) self.local_to_world_wrist = _p(3, 3)
self.nonhand_param_idxs = _p(145, dtype=torch.int64) self.nonhand_param_idxs = _p(145, dtype=torch.int64)
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader). # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
# Optional — loaded from the .safetensors if present, otherwise the self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
# render path falls back to a coarse geometric approximation.
self.register_buffer(
"face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
)
def canonical_vertices(self, device=None): def canonical_vertices(self):
"""Return the T-pose vertices for the mean shape (scaled to meters). """Return the T-pose vertices for the mean shape (scaled to meters).
Runs MHR with zero pose / shape / scale / expression so the returned Runs MHR with zero pose / shape / scale / expression so the returned
mesh is the canonical rest pose fixed per-model mesh is the canonical rest pose fixed per-model
""" """
dev = device or self.scale_mean.device device = self.scale_mean.device
dt = self.scale_mean.dtype dtype = self.scale_mean.dtype
B = 1 B = 1
global_trans = torch.zeros(B, 3, device=dev, dtype=dt) global_trans = torch.zeros(B, 3, device=device, dtype=dtype)
global_rot = torch.zeros(B, 3, device=dev, dtype=dt) global_rot = torch.zeros(B, 3, device=device, dtype=dtype)
body_pose = torch.zeros(B, 130, device=dev, dtype=dt) body_pose = torch.zeros(B, 130, device=device, dtype=dtype)
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt) hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype)
scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt) scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype)
shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt) shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype)
expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt) expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype)
verts = self.mhr_forward( verts = self.mhr_forward(
global_trans=global_trans, global_trans=global_trans,
global_rot=global_rot, global_rot=global_rot,
@ -108,20 +93,6 @@ class MHRHead(nn.Module):
) # single-tensor shape (1, N_v, 3) in meters ) # single-tensor shape (1, N_v, 3) in meters
return verts[0] return verts[0]
def get_zero_pose_init(self, factor=1.0):
    """Build the initial value for the pose token.

    An all-zero vector is not the zero pose in the continuous
    representation, so the leading ``6 + self.body_cont_dim`` entries are
    filled with the continuous encoding of the zero pose. All remaining
    entries stay zero.

    Args:
        factor: Multiplier applied to the continuous body-pose part only
            (the leading six values are not scaled).

    Returns:
        Tensor of shape ``(1, self.npose)``.
    """
    weights = torch.zeros(1, self.npose)

    # Leading six values; presumably the 6D encoding of the identity global
    # rotation (first two columns of I) -- confirm against the decoder.
    leading = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32)

    # Continuous encoding of the 133 all-zero body model parameters.
    body_cont = compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()

    weights[:, : 6 + self.body_cont_dim] = torch.cat(
        [leading, body_cont * factor], dim=0
    )
    return weights
def replace_hands_in_pose(self, full_pose_params, hand_pose_params): def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
assert full_pose_params.shape[1] == 136 assert full_pose_params.shape[1] == 136
@ -159,12 +130,9 @@ class MHRHead(nn.Module):
shape_params, shape_params,
expr_params=None, expr_params=None,
return_keypoints=False, return_keypoints=False,
do_pcblend=True,
return_joint_coords=False, return_joint_coords=False,
return_model_params=False, return_model_params=False,
return_joint_rotations=False, return_joint_rotations=False,
scale_offsets=None,
vertex_offsets=None,
): ):
# Align everything to the static buffers # Align everything to the static buffers
dt = self.scale_mean.dtype dt = self.scale_mean.dtype
@ -206,14 +174,10 @@ class MHRHead(nn.Module):
shape_params = shape_params[None] shape_params = shape_params[None]
# Convert scale... # Convert scale...
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
if scale_offsets is not None:
scales = scales + scale_offsets
# Now, figure out the pose. # Now, figure out the pose.
## 10 here is because it's more stable to optimize global translation in meters. ## 10 here is because it's more stable to optimize global translation in meters.
full_pose_params = torch.cat( full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127
[global_trans * 10, global_rot, body_pose_params], dim=1
) # B x 127
## Put in hands ## Put in hands
if hand_pose_params is not None: if hand_pose_params is not None:
full_pose_params = self.replace_hands_in_pose( full_pose_params = self.replace_hands_in_pose(
@ -268,14 +232,7 @@ class MHRHead(nn.Module):
else: else:
return tuple(to_return) return tuple(to_return)
def forward( def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False):
self,
x: torch.Tensor,
init_estimate: Optional[torch.Tensor] = None,
do_pcblend=True,
slim_keypoints=False,
intermediate: bool = False,
):
""" """
Args: Args:
x: pose token with shape [B, C], usually C=DECODER.DIM x: pose token with shape [B, C], usually C=DECODER.DIM
@ -331,7 +288,6 @@ class MHRHead(nn.Module):
scale_params=pred_scale, scale_params=pred_scale,
shape_params=pred_shape, shape_params=pred_shape,
expr_params=pred_face, expr_params=pred_face,
do_pcblend=do_pcblend,
return_keypoints=True, return_keypoints=True,
return_joint_coords=True, return_joint_coords=True,
return_model_params=True, return_model_params=True,
@ -356,7 +312,7 @@ class MHRHead(nn.Module):
# Head-MLP outputs are promoted to fp32 here so the external # Head-MLP outputs are promoted to fp32 here so the external
# pose_output["mhr"] contract has a stable dtype regardless of what # pose_output["mhr"] contract has a stable dtype regardless of what
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are # the head ran at (fp16/bf16 for speed). MHR-derived outputs are
# already fp32 from MHR's math layers; the cast on them is a no-op. # already fp32 from MHR's math layers.
output = { output = {
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(), "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
"pred_pose_rotmat": None, "pred_pose_rotmat": None,

View File

@ -1,7 +1,7 @@
# Adapted from facebookresearch/MHR (Apache 2.0): # Adapted from facebookresearch/MHR (Apache 2.0):
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py # https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas # Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt # verbatim from the upstream mhr_model.pt
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}). # (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
# Original Copyright (c) Meta Platforms, Inc. and affiliates. # Original Copyright (c) Meta Platforms, Inc. and affiliates.
@ -52,7 +52,7 @@ def _skel_multiply(s1, s2):
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
before composition. With many FK levels the previously-normalized quats before composition. With many FK levels the previously-normalized quats
drift in ULPs; the JIT renormalizes defensively, so we do too to stay drift in ULPs; upstream renormalizes defensively, so we do too to stay
bit-close to its outputs. bit-close to its outputs.
""" """
t1, sc1 = s1[..., :3], s1[..., 7:8] t1, sc1 = s1[..., :3], s1[..., 7:8]
@ -78,7 +78,7 @@ def _skel_transform_points(skel_state, points):
def _global_skel_state_from_local(local, pmi_levels): def _global_skel_state_from_local(local, pmi_levels):
"""FK walk in fp64 (matches the JIT's use_double_precision=True path). """FK walk in fp64 (matches upstream's use_double_precision=True path).
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs, `pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
one per BFS level. Avoids per-call torch.split + tolist() sync. one per BFS level. Avoids per-call torch.split + tolist() sync.
@ -95,7 +95,7 @@ def _global_skel_state_from_local(local, pmi_levels):
class MHRRig(nn.Module): class MHRRig(nn.Module):
"""Plain-PyTorch reimplementation of Meta's MHR rig. """Plain-PyTorch reimplementation of Meta's MHR rig.
All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's All math runs in fp32 (FK upcast to fp64 internally, matching upstream's
use_double_precision=True backend) regardless of the host model's dtype. use_double_precision=True backend) regardless of the host model's dtype.
""" """
@ -110,13 +110,11 @@ class MHRRig(nn.Module):
POSE_CORR_HIDDEN = 3000 POSE_CORR_HIDDEN = 3000
POSE_CORR_SPARSE_NNZ = 53136 POSE_CORR_SPARSE_NNZ = 53136
def __init__(self, device=None, dtype=None, operations=None): def __init__(self, device=None):
super().__init__() super().__init__()
del dtype, operations
f32 = torch.float32
# All buffers are populated by load_state_dict from the `mhr.*` keys # All buffers are populated by load_state_dict from the `mhr.*` keys
def _p(*shape, dtype=f32): def _p(*shape, dtype=torch.float32):
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False) return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
def _b(name, *shape, dtype): def _b(name, *shape, dtype):
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device)) self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
@ -147,10 +145,10 @@ class MHRRig(nn.Module):
self._pmi_levels_cache = None self._pmi_levels_cache = None
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True): def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
f32 = self.base_shape.dtype dtype = self.base_shape.dtype
identity_coeffs = identity_coeffs.to(f32) identity_coeffs = identity_coeffs.to(dtype)
model_parameters = model_parameters.to(f32) model_parameters = model_parameters.to(dtype)
expr_coeffs = expr_coeffs.to(f32) expr_coeffs = expr_coeffs.to(dtype)
B = identity_coeffs.shape[0] B = identity_coeffs.shape[0]
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs) identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)

View File

@ -1,5 +1,5 @@
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers # MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity # (batch6DFromXYZ, batchXYZfrom6D) are the continuity
# representation from Zhou et al., "On the Continuity of Rotation # representation from Zhou et al., "On the Continuity of Rotation
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035), # Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
# implementations from papagina/RotationContinuity: # implementations from papagina/RotationContinuity:
@ -158,18 +158,10 @@ def _hand_masks(device):
m = _HAND_MASK_CACHE.get(device) m = _HAND_MASK_CACHE.get(device)
if m is not None: if m is not None:
return m return m
mask_cont_threedofs = torch.cat( mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
[torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS] mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
).to(device) mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
mask_cont_onedofs = torch.cat( mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
[torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
mask_model_params_threedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
).to(device)
mask_model_params_onedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
m = dict( m = dict(
mask_cont_threedofs=mask_cont_threedofs, mask_cont_threedofs=mask_cont_threedofs,
mask_cont_onedofs=mask_cont_onedofs, mask_cont_onedofs=mask_cont_onedofs,
@ -182,7 +174,6 @@ def _hand_masks(device):
def compact_cont_to_model_params_hand(hand_cont): def compact_cont_to_model_params_hand(hand_cont):
# These are ordered by joint, not model params ^^ # These are ordered by joint, not model params ^^
assert hand_cont.shape[-1] == 54
m = _hand_masks(hand_cont.device) m = _hand_masks(hand_cont.device)
mask_cont_threedofs = m["mask_cont_threedofs"] mask_cont_threedofs = m["mask_cont_threedofs"]
mask_cont_onedofs = m["mask_cont_onedofs"] mask_cont_onedofs = m["mask_cont_onedofs"]
@ -209,120 +200,6 @@ def compact_cont_to_model_params_hand(hand_cont):
return hand_model_params return hand_model_params
def compact_model_params_to_cont_hand(hand_model_params):
    """Convert 27 hand model parameters to the 54-dim continuous form.

    Every 3-DoF joint (three Euler angles) is encoded through
    ``batch6DFromXYZ`` into six values; every other angle (the 1-DoF joints
    and the single 2-DoF joint) is encoded as a ``(sin, cos)`` pair. Both
    the input and the output are ordered by joint.

    Args:
        hand_model_params: Tensor of shape ``(..., 27)``.

    Returns:
        Tensor of shape ``(..., 54)`` with the dtype and device of the input.
    """
    assert hand_model_params.shape[-1] == 27
    # Degrees of freedom per hand joint, in joint order.
    hand_dofs_in_order = (3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1)
    assert sum(hand_dofs_in_order) == 27

    device = hand_model_params.device

    # Per-entry flags: a joint with k DoFs owns k model parameters and 2*k
    # continuous values. Every joint has 1, 2 or 3 DoFs, so the "one-dof"
    # masks (which include the 2-DoF joint) are the complement of the
    # 3-DoF masks.
    mask_model_params_threedofs = torch.tensor(
        [k == 3 for k in hand_dofs_in_order for _ in range(k)],
        dtype=torch.bool,
        device=device,
    )
    mask_model_params_onedofs = ~mask_model_params_threedofs
    mask_cont_threedofs = torch.tensor(
        [k == 3 for k in hand_dofs_in_order for _ in range(2 * k)],
        dtype=torch.bool,
        device=device,
    )
    mask_cont_onedofs = ~mask_cont_threedofs

    # 3-DoF joints: Euler triplets -> 6 values each.
    hand_model_params_threedofs = hand_model_params[
        ..., mask_model_params_threedofs
    ].unflatten(-1, (-1, 3))
    hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)

    # Remaining angles: angle -> (sin, cos).
    hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
    hand_cont_onedofs = torch.stack(
        [hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
    ).flatten(-2, -1)

    # Assemble the 54-dim vector, ordered by joint.
    hand_cont = hand_model_params.new_zeros(*hand_model_params.shape[:-1], 54)
    hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
    hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
    return hand_cont
def batch9Dfrom6D(poses):
    """Expand 6D rotation representations into flattened 3x3 rotation matrices.

    Args:
        poses: Tensor of shape ``(..., 6)``; the first three values are the
            raw first column of the rotation matrix and the last three the
            raw second column.

    Returns:
        Tensor of shape ``(..., 9)``: the 3x3 matrix whose columns are the
        orthonormalized ``x``, ``y``, ``z`` axes, flattened row by row.
    """
    x_raw = poses[..., :3]
    y_raw = poses[..., 3:]

    # Orthonormalize: x is the normalized first column, z is perpendicular
    # to x and the raw second column, and y completes the frame.
    x = F.normalize(x_raw, dim=-1)
    z = F.normalize(torch.cross(x, y_raw, dim=-1), dim=-1)
    y = torch.cross(z, x, dim=-1)

    # Stack the axes as columns (... x 3 x 3), then flatten to ... x 9.
    return torch.stack([x, y, z], dim=-1).flatten(-2, -1)
def batch4Dfrom2D(poses):
    """Expand (sin, cos) pairs into flattened 2x2 rotation blocks.

    Args:
        poses: Tensor of shape ``(..., 2)`` holding ``(sin, cos)``; the pair
            is normalized to unit length before use.

    Returns:
        Tensor of shape ``(..., 4)`` laid out as ``[cos, sin, -sin, cos]``
        (a flattened SO(2) element).
    """
    sin, cos = F.normalize(poses, dim=-1).unbind(dim=-1)
    return torch.stack([cos, sin, -sin, cos], dim=-1)
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
    """Convert the continuous body pose into rotation-matrix entries.

    The input is laid out as three consecutive groups: 6D rotations for the
    3-DoF joints, ``(sin, cos)`` pairs for the 1-DoF rotations, and the
    1-DoF translations. Each 6D rotation is expanded to nine values with
    ``batch9Dfrom6D``, each ``(sin, cos)`` pair to four values with
    ``batch4Dfrom2D``, and the translations are passed through unchanged.

    Args:
        body_pose_cont: Tensor whose last dimension is
            ``2 * 69 + 2 * 58 + 6 = 260``.
        inflate_trans: Not implemented; must be ``False``.

    Returns:
        Tensor whose last dimension is ``23 * 9 + 58 * 4 + 6 = 445``.

    Raises:
        AssertionError: If the last dimension is not 260 or if
            ``inflate_trans`` is true.
    """
    # fmt: off
    # Model-parameter indices per group; only their counts are used here, so
    # they are plain tuples rather than tensors allocated on every call.
    all_param_3dof_rot_idxs = ((0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132))
    all_param_1dof_rot_idxs = (1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123)
    all_param_1dof_trans_idxs = (124, 125, 126, 127, 128, 129)
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    if inflate_trans:
        # Raised explicitly (same exception type as the former `assert False`)
        # so the guard is not stripped when Python runs with -O.
        raise AssertionError(
            "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
        )

    # Split the input into its three groups.
    end_3dofs = 2 * num_3dof_angles
    end_1dofs = end_3dofs + 2 * num_1dof_angles
    body_cont_3dofs = body_pose_cont[..., :end_3dofs]
    body_cont_1dofs = body_pose_cont[..., end_3dofs:end_1dofs]
    body_cont_trans = body_pose_cont[..., end_1dofs:]

    # 3-DoF joints: 6D -> 9 matrix entries each.
    body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs.unflatten(-1, (-1, 6))).flatten(-2, -1)
    # 1-DoF rotations: (sin, cos) -> 4 matrix entries each.
    body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs.unflatten(-1, (-1, 2))).flatten(-2, -1)
    # Translations need no conversion.
    body_rotmat_trans = body_cont_trans

    return torch.cat([body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1)
_BODY_IDX_CACHE: dict = {} _BODY_IDX_CACHE: dict = {}
@ -349,8 +226,6 @@ def compact_cont_to_model_params_body(body_pose_cont):
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device) (all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs) num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
# Get subsets # Get subsets
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles] body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles] body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
@ -372,42 +247,10 @@ def compact_cont_to_model_params_body(body_pose_cont):
return body_pose_params return body_pose_params
def compact_model_params_to_cont_body(body_pose_params): # Hand indices into the 133-dim param and 260-dim cont body-pose vectors.
# fmt: off mhr_param_hand_idxs = list(range(62, 116))
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]) mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
# fmt: on
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_params.shape[-1] == (
num_3dof_angles + num_1dof_angles + num_1dof_trans
)
# Take out params
body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()]
body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs]
body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs]
# params to cont
body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
-2, -1
)
body_cont_1dofs = torch.stack(
[body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
).flatten(-2, -1)
body_cont_trans = body_params_trans
# Put them together
body_pose_cont = torch.cat(
[body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1
)
return body_pose_cont
# fmt: off
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
mhr_param_hand_mask = torch.zeros(133).bool() mhr_param_hand_mask = torch.zeros(133).bool()
mhr_param_hand_mask[mhr_param_hand_idxs] = True mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260).bool() mhr_cont_hand_mask = torch.zeros(260).bool()
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
# fmt: on

View File

@ -43,15 +43,6 @@ class FourierPositionEncoding(nn.Module):
self.num_bands = num_bands self.num_bands = num_bands
self.max_resolution = [max_resolution] * n self.max_resolution = [max_resolution] * n
@property
def channels(self) -> int:
    """Number of output channels of the Fourier position encoding.

    Each of the ``len(self.max_resolution)`` position dimensions contributes
    ``self.num_bands`` frequency bands, each as a sin and a cos term, plus
    one channel for the concatenated raw coordinate.
    """
    num_dims = len(self.max_resolution)
    return 2 * self.num_bands * num_dims + num_dims
def forward(self, pos: torch.Tensor): def forward(self, pos: torch.Tensor):
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution) fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
return fourier_pos_enc return fourier_pos_enc
@ -118,9 +109,7 @@ class PerspectiveHead(nn.Module):
pred_cam: torch.Tensor, pred_cam: torch.Tensor,
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h) bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
bbox_size: torch.Tensor, # [N,], in original image space bbox_size: torch.Tensor, # [N,], in original image space
img_size: torch.Tensor,
cam_int: torch.Tensor, # [B, 3, 3] cam_int: torch.Tensor, # [B, 3, 3]
use_intrin_center: bool = False,
): ):
batch_size = points_3d.shape[0] batch_size = points_3d.shape[0]
pred_cam = pred_cam.clone() pred_cam = pred_cam.clone()
@ -133,12 +122,8 @@ class PerspectiveHead(nn.Module):
focal_length = cam_int[:, 0, 0] focal_length = cam_int[:, 0, 0]
tz = 2 * focal_length / bs tz = 2 * focal_length / bs
if not use_intrin_center: cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs
cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs
cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
else:
cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

View File

@ -37,20 +37,15 @@ class SAM3DBody(nn.Module):
def __init__(self, device=None, dtype=None, operations=None): def __init__(self, device=None, dtype=None, operations=None):
super().__init__() super().__init__()
# `operations` falls back to torch.nn so the model is constructible
# without comfy.ops; matches the pattern in comfy/ldm/sam3/.
ops = operations if operations is not None else nn
# Per-batch state populated by `_initialize_batch`. # Per-batch state populated by `_initialize_batch`.
self._max_num_person = None self._max_num_person = None
self._person_valid = None
self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False) self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False)
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False) self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
self.image_size = IMAGE_SIZE self.image_size = IMAGE_SIZE
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops) self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
embed_dims = self.backbone.embed_dims embed_dims = self.backbone.embed_dims
# MHR rig shared between body + hand pose heads via a non-registered # MHR rig shared between body + hand pose heads via a non-registered
@ -72,7 +67,7 @@ class SAM3DBody(nn.Module):
self.head_pose.hand_pose_comps.data = ( self.head_pose.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float() torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
) )
self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs) self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter( self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
@ -81,7 +76,7 @@ class SAM3DBody(nn.Module):
self.head_pose_hand.hand_pose_comps.data = ( self.head_pose_hand.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float() torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
) )
self.init_pose_hand = ops.Embedding( self.init_pose_hand = operations.Embedding(
1, self.head_pose_hand.npose, device=device, dtype=dtype 1, self.head_pose_hand.npose, device=device, dtype=dtype
) )
@ -93,25 +88,25 @@ class SAM3DBody(nn.Module):
device=device, dtype=dtype, operations=operations, device=device, dtype=dtype, operations=operations,
) )
self.head_camera = PerspectiveHead(**camera_kwargs) self.head_camera = PerspectiveHead(**camera_kwargs)
self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype) self.init_camera = operations.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs) self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs)
self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype) self.init_camera_hand = operations.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
cond_dim = 3 cond_dim = 3
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
linear_kwargs = dict(device=device, dtype=dtype) linear_kwargs = dict(device=device, dtype=dtype)
self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) self.init_to_token_mhr = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) self.prev_to_token_mhr = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) self.init_to_token_mhr_hand = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) self.prev_to_token_mhr_hand = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
self.prompt_encoder = PromptEncoder( self.prompt_encoder = PromptEncoder(
embed_dim=embed_dims, # match backbone dims so PE adds directly embed_dim=embed_dims, # match backbone dims so PE adds directly
num_body_joints=N_KEYPOINTS, num_body_joints=N_KEYPOINTS,
device=device, dtype=dtype, operations=operations, device=device, dtype=dtype, operations=operations,
) )
self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.prompt_to_token = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
decoder_kwargs = dict( decoder_kwargs = dict(
dims=DECODER_DIM, dims=DECODER_DIM,
@ -141,11 +136,10 @@ class SAM3DBody(nn.Module):
self.keypoint_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint_embedding_idxs = list(range(N_KEYPOINTS))
self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS)) self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS))
self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs) self.hand_box_embedding = operations.Embedding(2, DECODER_DIM, **linear_kwargs)
self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs)
self.bbox_embed = MLP( self.bbox_embed = MLP(
input_dim=DECODER_DIM, hidden_dim=DECODER_DIM, input_dim=DECODER_DIM, hidden_dim=DECODER_DIM,
output_dim=4, num_layers=3, output_dim=4, num_layers=3,
@ -158,13 +152,13 @@ class SAM3DBody(nn.Module):
) )
self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs) self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs)
self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs) self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs)
self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.keypoint_feat_linear = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.keypoint_feat_linear_hand = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS))
self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS)) self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS))
self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint3d_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint3d_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs) self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs)
self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs) self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs)
@ -183,11 +177,9 @@ class SAM3DBody(nn.Module):
def _initialize_batch(self, batch: Dict) -> None: def _initialize_batch(self, batch: Dict) -> None:
if batch["img"].dim() == 5: if batch["img"].dim() == 5:
self._batch_size, self._max_num_person = batch["img"].shape[:2] self._batch_size, self._max_num_person = batch["img"].shape[:2]
self._person_valid = self._flatten_person(batch["person_valid"]) > 0
else: else:
self._batch_size = batch["img"].shape[0] self._batch_size = batch["img"].shape[0]
self._max_num_person = 0 self._max_num_person = 0
self._person_valid = None
def _flatten_person(self, x: torch.Tensor) -> torch.Tensor: def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
assert self._max_num_person is not None, "No max_num_person initialized" assert self._max_num_person is not None, "No max_num_person initialized"
@ -258,11 +250,9 @@ class SAM3DBody(nn.Module):
if is_multi_image: if is_multi_image:
assert isinstance(img, list) assert isinstance(img, list)
n = len(img) n = len(img)
H_src, W_src = img[0].shape[:2]
src_t = torch.stack(list(img), dim=0) src_t = torch.stack(list(img), dim=0)
else: else:
n = int(left_xyxy.shape[0]) n = int(left_xyxy.shape[0])
H_src, W_src = img.shape[:2]
src_t = img.unsqueeze(0).expand(n, -1, -1, -1) src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
H_out, W_out = int(self.image_size[0]), int(self.image_size[1]) H_out, W_out = int(self.image_size[0]), int(self.image_size[1])
@ -292,14 +282,12 @@ class SAM3DBody(nn.Module):
zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device) zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device)
person_valid = torch.ones((1, n), dtype=torch.float32, device=device) person_valid = torch.ones((1, n), dtype=torch.float32, device=device)
img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous() img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous()
ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous()
cam_int_dev = cam_int.to(device).to(dtype=torch.float32) cam_int_dev = cam_int.to(device).to(dtype=torch.float32)
def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy): def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy):
return { return {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size.unsqueeze(0), "img_size": img_size.unsqueeze(0),
"ori_img_size": ori_img_size.unsqueeze(0),
"bbox_center": centers_t.to(device).unsqueeze(0), "bbox_center": centers_t.to(device).unsqueeze(0),
"bbox_scale": scales_t.to(device).unsqueeze(0), "bbox_scale": scales_t.to(device).unsqueeze(0),
"bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0), "bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0),
@ -349,7 +337,6 @@ class SAM3DBody(nn.Module):
self, self,
branch: str, branch: str,
image_embeddings: torch.Tensor, image_embeddings: torch.Tensor,
init_estimate: Optional[torch.Tensor] = None,
keypoints: Optional[torch.Tensor] = None, keypoints: Optional[torch.Tensor] = None,
prev_estimate: Optional[torch.Tensor] = None, prev_estimate: Optional[torch.Tensor] = None,
condition_info: Optional[torch.Tensor] = None, condition_info: Optional[torch.Tensor] = None,
@ -359,7 +346,6 @@ class SAM3DBody(nn.Module):
of the pipeline is shared. of the pipeline is shared.
image_embeddings: (B, C, H, W) backbone features. image_embeddings: (B, C, H, W) backbone features.
init_estimate: (B, 1, C) initial pose+cam estimate to refine.
keypoints: (B, N, 3) prompts as (x, y in [0, 1], label). keypoints: (B, N, 3) prompts as (x, y in [0, 1], label).
label: 0..K = joint, -1 = incorrect, -2 = invalid. label: 0..K = joint, -1 = incorrect, -2 = invalid.
prev_estimate: (B, 1, C) previous estimate for pose refinement. prev_estimate: (B, 1, C) previous estimate for pose refinement.
@ -402,15 +388,11 @@ class SAM3DBody(nn.Module):
# .to(image_embeddings) moves weights CPU→GPU under dynamic loading # .to(image_embeddings) moves weights CPU→GPU under dynamic loading
# (they stay on CPU until first use). # (they stay on CPU until first use).
if init_estimate is None:
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3) init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
init_input = ( init_input = torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
if condition_info is not None else init_estimate
)
token_embeddings = init_to_token(init_input).view(batch_size, 1, -1) token_embeddings = init_to_token(init_input).view(batch_size, 1, -1)
num_pose_token = token_embeddings.shape[1] # always 1 num_pose_token = token_embeddings.shape[1] # always 1
@ -495,9 +477,8 @@ class SAM3DBody(nn.Module):
def _get_mask_prompt(self, batch, image_embeddings): def _get_mask_prompt(self, batch, image_embeddings):
x_mask = self._flatten_person(batch["mask"]) x_mask = self._flatten_person(batch["mask"])
# batch tensors are fp32 from prepare_batch; mask_downscaling is in the
# Loader's dtype — cast once so the conv input matches.
x_mask = x_mask.to(image_embeddings.dtype) x_mask = x_mask.to(image_embeddings.dtype)
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings( mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:] x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
) )
@ -546,7 +527,6 @@ class SAM3DBody(nn.Module):
# expand+contiguous for the vertices branch. # expand+contiguous for the vertices branch.
bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx] bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0] bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx]
cam_int = self._flatten_person( cam_int = self._flatten_person(
batch["cam_int"] batch["cam_int"]
.unsqueeze(1) .unsqueeze(1)
@ -556,8 +536,7 @@ class SAM3DBody(nn.Module):
def _project(points_3d): def _project(points_3d):
return head_camera.perspective_projection( return head_camera.perspective_projection(
points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int, points_3d, pred_cam, bbox_center, bbox_scale, cam_int,
use_intrin_center=True,
) )
cam_out = _project(pose_output["pred_keypoints_3d"]) cam_out = _project(pose_output["pred_keypoints_3d"])
@ -632,7 +611,6 @@ class SAM3DBody(nn.Module):
tokens_output, pose_output = self.forward_decoder( tokens_output, pose_output = self.forward_decoder(
"body", "body",
image_embeddings[self.body_batch_idx], image_embeddings[self.body_batch_idx],
init_estimate=None,
keypoints=keypoints_prompt[self.body_batch_idx], keypoints=keypoints_prompt[self.body_batch_idx],
prev_estimate=None, prev_estimate=None,
condition_info=condition_info[self.body_batch_idx], condition_info=condition_info[self.body_batch_idx],
@ -643,7 +621,6 @@ class SAM3DBody(nn.Module):
tokens_output_hand, pose_output_hand = self.forward_decoder( tokens_output_hand, pose_output_hand = self.forward_decoder(
"hand", "hand",
image_embeddings[self.hand_batch_idx], image_embeddings[self.hand_batch_idx],
init_estimate=None,
keypoints=keypoints_prompt[self.hand_batch_idx], keypoints=keypoints_prompt[self.hand_batch_idx],
prev_estimate=None, prev_estimate=None,
condition_info=condition_info[self.hand_batch_idx], condition_info=condition_info[self.hand_batch_idx],
@ -661,10 +638,8 @@ class SAM3DBody(nn.Module):
# match the head-MLP external contract (_get_hand_box would .float() anyway). # match the head-MLP external contract (_get_hand_box would .float() anyway).
if len(self.body_batch_idx): if len(self.body_batch_idx):
output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float() output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float()
output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float()
if len(self.hand_batch_idx): if len(self.hand_batch_idx):
output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid() output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid()
output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand)
return output return output
@ -715,10 +690,10 @@ class SAM3DBody(nn.Module):
# Concat lhand+rhand along dim 0 so backbone+decoder run once on # Concat lhand+rhand along dim 0 so backbone+decoder run once on
# (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass. # (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass.
batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand) batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand)
saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid) saved_batch_state = (self._batch_size, self._max_num_person)
self._initialize_batch(batch_hands) self._initialize_batch(batch_hands)
hands_output = self.forward_step(batch_hands, decoder_type="hand") hands_output = self.forward_step(batch_hands, decoder_type="hand")
self._batch_size, self._max_num_person, self._person_valid = saved_batch_state self._batch_size, self._max_num_person = saved_batch_state
n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1] n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1]
lhand_output, rhand_output = self._split_hand_output(hands_output, n_left) lhand_output, rhand_output = self._split_hand_output(hands_output, n_left)
# Free the batched image_embeddings/condition_info (unused downstream); # Free the batched image_embeddings/condition_info (unused downstream);
@ -808,9 +783,7 @@ class SAM3DBody(nn.Module):
# to get an updated body pose estimation. # to get an updated body pose estimation.
self._set_active_branch("body") self._set_active_branch("body")
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() # right_kps_full / left_kps_full already computed above (unchanged since).
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1
right_kps_crop = self._full_to_crop(batch, right_kps_full) right_kps_crop = self._full_to_crop(batch, right_kps_full)
left_kps_crop = self._full_to_crop(batch, left_kps_full) left_kps_crop = self._full_to_crop(batch, left_kps_full)
@ -1030,7 +1003,6 @@ class SAM3DBody(nn.Module):
_, pose_output = self.forward_decoder( _, pose_output = self.forward_decoder(
"body", "body",
image_embeddings, image_embeddings,
init_estimate=None, # use the default init, not the prev estimate
keypoints=keypoint_prompt, keypoints=keypoint_prompt,
prev_estimate=prev_estimate, prev_estimate=prev_estimate,
condition_info=condition_info, condition_info=condition_info,

View File

@ -29,38 +29,37 @@ class PromptEncoder(nn.Module):
Encodes prompts for input to SAM's mask decoder. Encodes prompts for input to SAM's mask decoder.
""" """
super().__init__() super().__init__()
ops = operations if operations is not None else nn
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_body_joints = num_body_joints self.num_body_joints = num_body_joints
# Keypoint prompts # Keypoint prompts
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.point_embeddings = nn.ModuleList( self.point_embeddings = nn.ModuleList(
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)] [operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
) )
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim. # Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
LN2d = LayerNorm2d_op(ops) LN2d = LayerNorm2d_op(operations)
mask_in_chans = 256 mask_in_chans = 256
self.mask_downscaling = nn.Sequential( self.mask_downscaling = nn.Sequential(
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype), operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 64, device=device, dtype=dtype), LN2d(mask_in_chans // 64, device=device, dtype=dtype),
nn.GELU(), nn.GELU(),
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype), operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 16, device=device, dtype=dtype), LN2d(mask_in_chans // 16, device=device, dtype=dtype),
nn.GELU(), nn.GELU(),
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype), operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 4, device=device, dtype=dtype), LN2d(mask_in_chans // 4, device=device, dtype=dtype),
nn.GELU(), nn.GELU(),
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype), operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans, device=device, dtype=dtype), LN2d(mask_in_chans, device=device, dtype=dtype),
nn.GELU(), nn.GELU(),
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype), operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
) )
# Trained values for the gating conv and no_mask_embed are loaded from the state dict # Trained values for the gating conv and no_mask_embed are loaded from the state dict
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor: def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
"""Positional encoding over the image-embedding grid; (1, C, H, W).""" """Positional encoding over the image-embedding grid; (1, C, H, W)."""
@ -120,8 +119,7 @@ class PromptEncoder(nn.Module):
Bx(embed_dim)x(embed_H)x(embed_W) Bx(embed_dim)x(embed_H)x(embed_W)
""" """
bs = self._get_batch_size(keypoints, boxes, masks) bs = self._get_batch_size(keypoints, boxes, masks)
# Anchor device on the input prompts so we don't pull the offloaded
# CPU embedding device under dynamic loading.
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
device = ref.device if ref is not None else self.point_embeddings[0].weight.device device = ref.device if ref is not None else self.point_embeddings[0].weight.device
weight_dtype = self.invalid_point_embed.weight.dtype weight_dtype = self.invalid_point_embed.weight.dtype
@ -136,23 +134,10 @@ class PromptEncoder(nn.Module):
return sparse_embeddings, sparse_masks return sparse_embeddings, sparse_masks
def get_mask_embeddings( def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> torch.Tensor:
self, """Embeds mask inputs. Caller casts both outputs to its working dtype."""
masks: Optional[torch.Tensor] = None, no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1])
bs: int = 1,
size: Tuple[int, int] = (16, 16), # [H, W]
) -> torch.Tensor:
"""Embeds mask inputs."""
# masks is always on the active device when present; fall back to the
# downscaling Conv's weight device when it isn't (rare callers).
ref = masks if masks is not None else next(self.mask_downscaling.parameters())
no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
bs, -1, size[0], size[1]
)
if masks is not None:
mask_embeddings = self.mask_downscaling(masks) mask_embeddings = self.mask_downscaling(masks)
else:
mask_embeddings = no_mask_embeddings
return mask_embeddings, no_mask_embeddings return mask_embeddings, no_mask_embeddings
@ -170,12 +155,9 @@ class PromptableDecoder(nn.Module):
repeat_pe: bool = False, repeat_pe: bool = False,
do_interm_preds: bool = False, do_interm_preds: bool = False,
keypoint_token_update: bool = False, keypoint_token_update: bool = False,
device=None, device=None, dtype=None, operations=None,
dtype=None,
operations=None,
): ):
super().__init__() super().__init__()
ops = operations if operations is not None else nn
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
TransformerDecoderLayer( TransformerDecoderLayer(
@ -193,7 +175,7 @@ class PromptableDecoder(nn.Module):
for i in range(depth) for i in range(depth)
) )
self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype) self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
self.do_interm_preds = do_interm_preds self.do_interm_preds = do_interm_preds
self.keypoint_token_update = keypoint_token_update self.keypoint_token_update = keypoint_token_update

View File

@ -166,12 +166,10 @@ def prepare_batch(
mask_score_t = torch.ones((n,), dtype=torch.float32) mask_score_t = torch.ones((n,), dtype=torch.float32)
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous() img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()
batch = { batch = {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size_t.unsqueeze(0), # (1, N, 2) "img_size": img_size_t.unsqueeze(0), # (1, N, 2)
"ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2)
"bbox_center": centers.unsqueeze(0), # (1, N, 2) "bbox_center": centers.unsqueeze(0), # (1, N, 2)
"bbox_scale": scales.unsqueeze(0), # (1, N, 2) "bbox_scale": scales.unsqueeze(0), # (1, N, 2)
"bbox": boxes_t.unsqueeze(0), # (1, N, 4) "bbox": boxes_t.unsqueeze(0), # (1, N, 4)

View File

@ -1,11 +1,9 @@
"""BVH export for SAM 3D Body pose_data. """BVH export for SAM 3D Body pose_data.
BVH stores explicit bone OFFSETs per joint, so any standard importer BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0
directly no heuristic guessing as needed for glTF. We skip the rig's joint 0 (static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos + rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees.
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
intrinsic Z-X-Y Euler degrees.
""" """
from __future__ import annotations from __future__ import annotations
@ -49,13 +47,10 @@ def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int: def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
"""First child of the rig's world anchor so the static origin→body stick """First child of the rig's world anchor, dropping the origin→body stick.
bone gets left out. Falls back to the first root joint. Falls back to the first root joint. External rigs whose root is already the
articulated body root with multiple child chains keep the root descending
MHR's joint 0 is a static world anchor whose single child is the pelvis, so into one child would drop the sibling limbs."""
skipping it is correct. External rigs (e.g. SOMA-77) whose root is already
the articulated body root with multiple child chains must keep the root
descending into one child would drop the sibling limbs from the BVH."""
NJ = parents.shape[0] NJ = parents.shape[0]
world_anchors = [j for j in range(NJ) world_anchors = [j for j in range(NJ)
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)] if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
@ -93,14 +88,11 @@ def build_bvh(
track_index: int = -1, track_index: int = -1,
units: str = "cm", units: str = "cm",
) -> bytes: ) -> bytes:
"""Build a BVH file from pose_data. Returns UTF-8 encoded text bytes. """Build a BVH file from pose_data. Returns UTF-8 text bytes.
`model` may be None when pose_data carries a `_skeleton_override` (external `model` may be None when pose_data carries a `_skeleton_override` (external
rigs, e.g. Kimodo); the rig hierarchy/offsets/bind are read from the rigs); the rig hierarchy/offsets/bind come from the override. `units` is
override instead of the MHR model. "cm" (default) or "m" affects OFFSET/root-position, not rotations.
`units` is "cm" (default, standard mocap convention) or "m". Affects the
OFFSET and root-position values; rotations are independent of units.
""" """
if units not in ("cm", "m"): if units not in ("cm", "m"):
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}") raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
@ -123,10 +115,8 @@ def build_bvh(
body_root = _find_bvh_root(parents, is_external) body_root = _find_bvh_root(parents, is_external)
children_map = _build_children_map(parents) children_map = _build_children_map(parents)
# Bone OFFSETs come from MHR's translation_offsets (joint position # Bone OFFSETs = translation_offsets (joint position relative to parent).
# relative to parent in parent's local-bind frame). For the BVH root, # The BVH root uses its bind world position so the skeleton imports in place.
# we use its bind world position so the skeleton sits at the right
# spot when imported.
bind_global = rig.bind_global_cm # (NJ, 8) cm bind_global = rig.bind_global_cm # (NJ, 8) cm
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01 offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
@ -139,9 +129,8 @@ def build_bvh(
_visit(c) _visit(c)
_visit(body_root) _visit(body_root)
# Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative) # Stored pred_global_rots/pred_joint_coords (authoritative); derive locals
# rather than re-running rig.forward, then derive locals with body_root # with body_root as the BVH-space hierarchy root.
# treated as the hierarchy root in BVH-space.
rig_global_m = global_skel_state_from_pose_data( rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ, pose_data, frame_indices, person_k, NJ,
joint_coords_y_down=rig.per_frame_y_down, joint_coords_y_down=rig.per_frame_y_down,
@ -203,9 +192,8 @@ def build_bvh(
lines.append(f"Frames: {n_frames}") lines.append(f"Frames: {n_frames}")
lines.append(f"Frame Time: {1.0 / float(fps):.6f}") lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
# Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per # Channel matrix per frame: root pos (3) + root rot (3) + non-root rots
# frame, columns in `bvh_order` order. Vectorized — savetxt's C-side # (3 each), columns in `bvh_order`. savetxt is far faster than f-strings.
# formatting beats Python f-strings by ~10× on long clips.
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64) non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
motion = np.concatenate([ motion = np.concatenate([
root_pos_m * unit_scale, # (N, 3) root_pos_m * unit_scale, # (N, 3)

View File

@ -1,12 +1,9 @@
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent """3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look).
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps), Each limb is a true 3D capsule (cylinder + hemispherical caps), projected
projected through the per-person camera (`pred_cam_t` + `focal_length` + through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so
image_size) so closer limbs appear thicker/brighter the SCAIL-Pose closer limbs appear thicker/brighter. Self-contained analytic ray-capsule
visual style. Self-contained: no dependency on the SCAIL-Pose package. renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
""" """
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -41,14 +38,12 @@ def _build_specs_from_pose(
palette: str, palette: str,
person_brightness_falloff: float = 0.0, person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Flatten body + optional hand limbs for one frame into """Flatten body + optional hand limbs for one frame into (starts, ends,
(starts, ends, colors_rgba, is_hand) in camera coords (Y-down, +Z forward). colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite
Drops endpoints that are non-finite or behind the camera. `is_hand` flags or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner.
the hand limbs so the renderer can draw them thinner.
`person_brightness_falloff` mixes each per-person limb color toward white `person_brightness_falloff` mixes each per-person color toward white by
by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the `1 - falloff^k` for track k (track 0 stays vivid)."""
mesh rasterizer and GLB exporters."""
starts: List[np.ndarray] = [] starts: List[np.ndarray] = []
ends: List[np.ndarray] = [] ends: List[np.ndarray] = []
colors: List[np.ndarray] = [] colors: List[np.ndarray] = []
@ -65,8 +60,7 @@ def _build_specs_from_pose(
if body_op is None or cam_t is None: if body_op is None or cam_t is None:
continue continue
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3) cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
# op-keypoints are camera frame (Y-down); add cam_t to place the # op-keypoints are camera frame; add cam_t to place the subject in front.
# subject in front of the camera.
body_kp = body_op + cam_t_np[None, :] body_kp = body_op + cam_t_np[None, :]
pastel = 0.0 if k == 0 else (1.0 - falloff ** k) pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
@ -148,10 +142,9 @@ def _ray_capsule_t(
ba_len: torch.Tensor, # (M,) segment length ba_len: torch.Tensor, # (M,) segment length
radius: torch.Tensor, # (M,) per-capsule radius radius: torch.Tensor, # (M,) per-capsule radius
) -> torch.Tensor: ) -> torch.Tensor:
"""Closed-form ray-capsule intersection. Returns (K, M) tensor of ray """Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest
parameters t to the nearest valid hit per capsule, +inf where the ray valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere
misses. A capsule is the union of (cylinder body, hemisphere at A, at A, hemisphere at B), each a quadratic root-find."""
hemisphere at B); each component is a quadratic root-find."""
INF = float("inf") INF = float("inf")
r_sq = radius * radius # (M,) r_sq = radius * radius # (M,)
@ -238,9 +231,8 @@ def _render_capsules_torch(
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item())) z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
z_near = max(0.05, z_min - float(radius.max().item())) z_near = max(0.05, z_min - float(radius.max().item()))
# Union of per-capsule screen-space bboxes. Pixels outside this mask # Union of per-capsule screen-space bboxes — pixels outside can't hit any
# provably can't hit any capsule, so the analytic intersection only runs # capsule, so intersection only runs on the relevant subset of the canvas.
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
sz = starts[:, 2].clamp(min=z_near) sz = starts[:, 2].clamp(min=z_near)
ez = ends[:, 2].clamp(min=z_near) ez = ends[:, 2].clamp(min=z_near)
sx_p = starts[:, 0] * fx / sz + cx sx_p = starts[:, 0] * fx / sz + cx
@ -261,16 +253,13 @@ def _render_capsules_torch(
if xmax_i > xmin_i and ymax_i > ymin_i: if xmax_i > xmin_i and ymax_i > ymin_i:
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
# Analytic ray-capsule intersection. One pass over the masked pixels — # Analytic ray-capsule intersection, one pass over the masked pixels.
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
INF = float("inf") INF = float("inf")
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32) flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long) flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1) active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
if active_idx.numel() > 0: if active_idx.numel() > 0:
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory # Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory.
# manageable when both K (image pixels) and M (capsules) are large.
chunk_max = max(1, int(4_000_000 / max(M, 1))) chunk_max = max(1, int(4_000_000 / max(M, 1)))
for i0 in range(0, active_idx.numel(), chunk_max): for i0 in range(0, active_idx.numel(), chunk_max):
sub = active_idx[i0 : i0 + chunk_max] sub = active_idx[i0 : i0 + chunk_max]
@ -284,7 +273,7 @@ def _render_capsules_torch(
flat_t[winners] = t_min[hit] flat_t[winners] = t_min[hit]
flat_m_idx[winners] = m_idx[hit] flat_m_idx[winners] = m_idx[hit]
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade. # Shade via analytic normal (P - closest point on segment).
out = torch.zeros((N, 3), dtype=torch.float32, device=device) out = torch.zeros((N, 3), dtype=torch.float32, device=device)
if background_rgb is not None: if background_rgb is not None:
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone() out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
@ -306,10 +295,10 @@ def _render_capsules_torch(
col = colors[m_h, :3] col = colors[m_h, :3]
if flat_shade: if flat_shade:
# Solid per-limb color (OpenPose look) — no lighting/depth modulation. # Solid per-limb color (OpenPose look) — no lighting/depth.
out[hit_idx] = col out[hit_idx] = col
return out.view(H, W, 3).clamp(0.0, 1.0) return out.view(H, W, 3).clamp(0.0, 1.0)
# SCAIL Blinn-Phong (render_torch.py:290-331). Headlight: light = +Z. # SCAIL Blinn-Phong, headlight along +Z.
diff = torch.clamp(-(normals[:, 2]), min=0.0) diff = torch.clamp(-(normals[:, 2]), min=0.0)
diffuse = 0.45 + 0.55 * diff diffuse = 0.45 + 0.55 * diff
@ -319,7 +308,7 @@ def _render_capsules_torch(
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8) half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32) spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
# Mild depth fade matches SCAIL's mm-scale ramp in our meter units. # Mild depth fade.
z_vals = p_hit[:, 2] z_vals = p_hit[:, 2]
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item()) z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
if z_hi - z_lo > 1e-6: if z_hi - z_lo > 1e-6:
@ -351,21 +340,18 @@ def render_pose_data_capsules(
hand_radius_scale: float = 0.4, hand_radius_scale: float = 0.4,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Render a frame's pose_data as 3D capsules projected through the per- """Render a frame's pose_data as 3D capsules through the per-person camera.
person camera. Returns (H, W, 3) fp32 in [0, 1]. Returns (H, W, 3) fp32 in [0, 1].
`composite='over'` paints over `background` (black if None); `composite='over'` paints over `background` (black if None); 'mesh_only'
`composite='mesh_only'` always uses a black canvas. uses a black canvas. `radius_m` is in meters; hand limbs use
`radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`.
`radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
Hand limbs use `radius_m * hand_radius_scale` (their bones are far shorter
than body limbs). Camera fx/fy come from each person's `focal_length`.
""" """
persons = pose_data["frames"][frame_idx] persons = pose_data["frames"][frame_idx]
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
# SAM3DBody shares one camera across the clip — pick from the first valid person. # SAM3DBody shares one camera across the clip — use the first valid person.
fx = fy = float(min(H, W)) fx = fy = float(min(H, W))
for person in persons: for person in persons:
f = person.get("focal_length") f = person.get("focal_length")

View File

@ -1,16 +1,10 @@
"""GLB export — OpenPose 18-keypoint visualization mode. """GLB export — OpenPose 18-keypoint visualization mode.
Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d` Sourced from pose_data's `pred_keypoints_3d`, independent of the MHR rig. Each
(the model's regressed surface keypoints). Each track becomes an armature track becomes an armature with a joint per keypoint; sphere markers and limbs
with a sibling joint per keypoint; sphere markers + stick/capsule limbs are are skinned to those joints. Optional hands (`pred_keypoints_3d` 21..62) and
skinned to those joints. face landmarks (`pred_vertices` at fixed vertex IDs) extend the same armature.
Shared tables/palettes/mappings live in `glb_shared.py`.
Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and
face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs)
extend the same armature.
OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are
imported below — they're also used by the 2D and 3D renderers in this package.
""" """
from __future__ import annotations from __future__ import annotations
@ -55,9 +49,8 @@ def _finalize_skinned_mesh(
joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray, joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray,
smooth_shade: bool, smooth_shade: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Apply smooth or flat shading to an indexed sphere/stick group mesh and """Shade a skinned group mesh and pack per-vertex colors. Smooth keeps the
pack per-vertex colors. Smooth keeps the indexed mesh + per-vertex colors; indexed mesh; flat duplicates verts per face and gathers face-corner colors."""
flat duplicates verts per face and gathers face-corner colors."""
if smooth_shade: if smooth_shade:
v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights) v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights)
return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32) return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32)
@ -73,10 +66,8 @@ def _finalize_skinned_mesh(
def _pair_colors_from_kp( def _pair_colors_from_kp(
pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1, pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1,
) -> np.ndarray: ) -> np.ndarray:
"""Per-limb color = endpoint-vertex color from `kp_colors`. Default """Per-limb color from `kp_colors`. `endpoint=1` (default) picks the distal
`endpoint=1` picks the second (distal) vertex of each pair, which is vertex of each pair — the OpenPose per-finger gradient for base→tip fingers."""
the OpenPose-canonical per-finger gradient when fingers go base→tip
(wrist=0 → thumb1=1 → thumb2=2 → …)."""
n = len(pairs) n = len(pairs)
out = np.zeros((n, 3), dtype=np.float32) out = np.zeros((n, 3), dtype=np.float32)
for i, (a, b) in enumerate(pairs): for i, (a, b) in enumerate(pairs):
@ -88,19 +79,13 @@ def _openpose_bind_at_rig_rest(
pose_data: Dict[str, Any], *, pose_data: Dict[str, Any], *,
include_hands: bool, face_vert_ids: Optional[np.ndarray], include_hands: bool, face_vert_ids: Optional[np.ndarray],
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
"""OpenPose keypoint positions at the rig's REST pose (T-pose at authoring """OpenPose keypoint positions at the rig's REST pose, from the override's
origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest `bind_global_m` (joint rest TRS) and `rest_verts_m` (face landmarks).
TRS) and `rest_verts_m` (mesh rest verts for face landmarks).
Used as the static-bind for openpose-mode geometry so the GLB's static Used as the static-bind so the GLB's static POSITION sits at rig origin,
POSITION attribute sits at rig origin — matching skeletal mode's bind and matching skeletal mode and producing the same rest→scene-frame-0 transition.
producing the same 'snap from rest to scene-frame-0' transition at the Returns None when the override lacks the needed mappings — caller then falls
start of playback. Without this, the static geometry is at scene-frame-0 back to per-frame extraction (kp_seq[0])."""
(kp_seq[0]) and viewers that auto-fit on static POSITION will center on
the scene location, hiding the per-frame motion.
Returns None when the override is missing or doesn't carry all the needed
mappings — caller falls back to per-frame extraction (kp_seq[0])."""
override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None
if override is None or "bind_global_m" not in override: if override is None or "bind_global_m" not in override:
return None return None
@ -141,19 +126,12 @@ def _openpose_bind_at_rig_rest(
def _extract_openpose_keypoints( def _extract_openpose_keypoints(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
) -> np.ndarray: ) -> np.ndarray:
"""(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres. """(N, 18, 3) OpenPose keypoints in rig-native Y-up metres.
Two sources, in priority order: External-skeleton path: when the override carries `openpose18_joint_indices`
((18, 2) int32), synthesize from each person's `pred_joint_coords` (already
1. **External-skeleton path** — when pose_data has `_skeleton_override` Y-up, no flip). MHR70 path (default): re-index `pred_keypoints_3d` to COCO-18
with `openpose18_joint_indices` ((18, 2) int32, see and un-flip y/z (stored y-down by sam3d_body).
`_resolve_openpose_keypoints_from_joints`), synthesize from each
person's `pred_joint_coords` directly. The override frame is already
rig-native Y-up, so no axis flip.
2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the
first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18.
Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we
un-flip y/z to match rig-native Y-up.
""" """
frames = pose_data["frames"] frames = pose_data["frames"]
N = len(frame_indices) N = len(frame_indices)
@ -195,10 +173,8 @@ def _extract_openpose_keypoints(
for t_idx, t in enumerate(frame_indices): for t_idx, t in enumerate(frame_indices):
person = frames[t][person_k] person = frames[t][person_k]
if "pred_keypoints_3d" not in person: if "pred_keypoints_3d" not in person:
# Diagnose the source: external-skeleton producers ship # External-skeleton producer without `openpose18_joint_indices`:
# `_skeleton_override` instead of MHR70 keypoints. If the # can't synthesize the 18-keypoint set.
# producer didn't populate `openpose18_joint_indices` either,
# we can't synthesize the 18-keypoint set.
if override is not None: if override is not None:
raise ValueError( raise ValueError(
"build_glb_openpose: this pose_data carries " "build_glb_openpose: this pose_data carries "
@ -229,15 +205,11 @@ def _extract_openpose_keypoints(
def _extract_openpose_hand_keypoints( def _extract_openpose_hand_keypoints(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
) -> np.ndarray: ) -> np.ndarray:
"""(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native """(N, 42, 3) right+left OpenPose hand keypoints (21+21) in rig-native Y-up.
Y-up frame.
External-skeleton path: requires `openpose_hand21_r_joint_indices` AND External-skeleton path: needs `openpose_hand21_{r,l}_joint_indices` ((21, 2)
`openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override. int32) in the override, resolved against `pred_joint_coords`. MHR70 path:
Resolved against per-frame `pred_joint_coords` like the body path. re-orders `pred_keypoints_3d` 21..62 to OpenPose-21 (wrist + 5 fingers)."""
MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21
(wrist + 5 fingers, thumb→pinky, base→tip)."""
frames = pose_data["frames"] frames = pose_data["frames"]
N = len(frame_indices) N = len(frame_indices)
out = np.zeros((N, 42, 3), dtype=np.float32) out = np.zeros((N, 42, 3), dtype=np.float32)
@ -307,10 +279,8 @@ def _extract_face_landmarks_from_verts(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
vert_ids: np.ndarray, vert_ids: np.ndarray,
) -> np.ndarray: ) -> np.ndarray:
"""(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices` """(N, K_face, 3) face landmarks sampled from `pred_vertices` at the given
at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up. vertex IDs, unflipped to Y-up. Per-frame deformation is already baked in."""
Each landmark inherits per-frame shape/expr/pose deformation for free
since `pred_vertices` already has it baked in."""
frames = pose_data["frames"] frames = pose_data["frames"]
N = len(frame_indices) N = len(frame_indices)
K = int(vert_ids.shape[0]) K = int(vert_ids.shape[0])
@ -335,18 +305,11 @@ def _build_openpose_spheres(
smooth_shade: bool = False, smooth_shade: bool = False,
joint_indices: Optional[np.ndarray] = None, joint_indices: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's """UV sphere per keypoint, rigidly skinned to that keypoint's joint and
joint, vertex-colored from kp_colors. `base_joint_idx` is added to the vertex-colored from kp_colors. `base_joint_idx` offsets the emitted JOINTS_0
emitted JOINTS_0 indices so callers can place this group at any offset indices (body=0, right hand=18, ); `joint_indices`, if given, sets explicit
in the shared skin (body=0, right hand=18, etc.). `joint_indices` (when per-sphere indices so callers can skip keypoints (e.g. SCAIL head dots).
given) overrides that with explicit per-sphere joint indices, so callers Returns (verts, normals, faces, joints4, weights4, vert_colors)."""
can skip keypoints (e.g. SCAIL head dots).
`smooth_shade=True` keeps the indexed mesh and writes per-vertex
normals via face-normal averaging round shading on the spheres.
`smooth_shade=False` (default) flat-shades by duplicating verts per
face, matching the existing OpenPose-mode look. Returns
(verts, normals, faces, joints4, weights4, vert_colors)."""
sv, sf = uv_sphere_unit() sv, sf = uv_sphere_unit()
K = bind_kp_m.shape[0] K = bind_kp_m.shape[0]
Nv = sv.shape[0] Nv = sv.shape[0]
@ -376,43 +339,23 @@ def _capsule_mesh_local(
end_width_frac: float = 0.3, end_width_frac: float = 0.3,
shape: str = "ellipsoid", shape: str = "ellipsoid",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Build a per-limb mesh in limb-local frame along +Y from y=0 (head """Per-limb mesh in limb-local frame along +Y from y=0 (head) to y=L (tail).
pole) to y=L (tail pole).
`shape` selects the silhouette: `shape`:
- 'ellipsoid' (default): tips are small hemispheres of radius - 'ellipsoid' (default): hemisphere tips of radius `W * end_width_frac`,
`W * end_width_frac`; body has ellipsoidal radius profile ellipsoidal sin(π·u) body profile (fat middle, narrow ends).
sin(π*u) from w_end at the junctions to W at the middle. Gives - 'capsule': SCAIL "rig" limb — an OPEN cylinder of constant radius W,
a fat-middle / narrow-end stretched-ellipse look. no caps. Pair with same-radius sphere markers so they cap the ends
- 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant seamlessly (caps would bump out when sphere radius ≠ cap radius).
radius W with no hemisphere caps. Pair with sphere joint markers
of the same radius so the spheres seamlessly cap the open
cylinder ends (the cylinder cross-section ring at the endpoint
lies exactly on the sphere surface). Drawing hemisphere caps
inside the joint sphere creates a visible bump where the cap
pokes out unevenly when sphere radius ≠ cap radius — open
cylinders avoid that.
Per-limb mesh is required because the cap height (w_end) depends on A per-limb mesh is needed because cap height depends on width — one
the limb width — a single canonical mesh can't produce true canonical mesh can't give true hemispheres for arbitrary L:W in ellipsoid.
hemispheres for arbitrary L:W ratios in ellipsoid mode.
Returns: Returns (verts (Nv,3), faces (Nf,3), weights (Nv,2) head/tail, sums to 1).
verts: (Nv, 3) float32 limb-local positions in meters.
faces: (Nf, 3) uint32 triangle indices.
weights: (Nv, 2) float32 (head, tail) skinning weights, linearly
interpolated by axial position (sums to 1).
""" """
W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6)) W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6))
if str(shape) == "capsule": if str(shape) == "capsule":
# SCAIL-style "rig" limb: an OPEN cylinder of constant radius W, # Open cylinder, no caps — sphere markers cap the ends (see docstring).
# no hemisphere caps. The sphere joint markers at each endpoint
# provide the rounded ends of the bone — when sphere_radius ==
# cylinder_radius, the cylinder cross-section ring at the bone
# endpoint lies exactly on the sphere surface, so silhouette is
# seamless. Hemisphere caps would create a visible bump where
# the cap pokes out of the sphere if cap_r ≠ marker_r, so we
# omit them entirely.
cap_r = 0.0 cap_r = 0.0
body_r = W body_r = W
if n_cap_lat is None: if n_cap_lat is None:
@ -425,7 +368,7 @@ def _capsule_mesh_local(
end_frac = float(min(0.95, max(0.05, end_width_frac))) end_frac = float(min(0.95, max(0.05, end_width_frac)))
cap_r = max(1e-7, W * end_frac) cap_r = max(1e-7, W * end_frac)
body_r = W body_r = W
# Ellipsoid defaults: more body rings to sample the sin(π·u) curve. # More body rings to sample the sin(π·u) curve.
if n_cap_lat is None: if n_cap_lat is None:
n_cap_lat = 3 n_cap_lat = 3
if n_body is None: if n_body is None:
@ -473,10 +416,7 @@ def _capsule_mesh_local(
phi = 2.0 * np.pi * k / n_lon phi = 2.0 * np.pi * k / n_lon
verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))]) verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))])
# Body intermediate rings (between the cap junctions for capped meshes, # Body intermediate rings (none for 'capsule', n_body=0 by default).
# between the two end rings for open cylinders). For 'capsule' mode
# n_body=0 by default — no intermediate rings needed for a constant-
# radius cylinder.
body_rings: List[int] = [] body_rings: List[int] = []
is_ellipsoid = str(shape) == "ellipsoid" is_ellipsoid = str(shape) == "ellipsoid"
for j in range(1, n_body + 1): for j in range(1, n_body + 1):
@ -572,11 +512,8 @@ def _scail_redirect_neck_stub(body_kp: np.ndarray) -> np.ndarray:
def _openpose_limb_rest_trs( def _openpose_limb_rest_trs(
bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...], bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...],
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Per-limb rest TRS: """Per-limb rest TRS: midpoints (K_pairs, 3) and unit a→b axes (or +Y if
midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b]. degenerate). Caller uses midpoints as rest translation, axes for alignment."""
rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate).
Caller uses `midpoints` as each limb joint's rest translation (rotation =
identity), and `rest_axes` to compute per-frame alignment rotations."""
K_pairs = len(pairs) K_pairs = len(pairs)
mid = np.zeros((K_pairs, 3), dtype=np.float32) mid = np.zeros((K_pairs, 3), dtype=np.float32)
axis = np.zeros((K_pairs, 3), dtype=np.float32) axis = np.zeros((K_pairs, 3), dtype=np.float32)
@ -595,13 +532,10 @@ def _openpose_limb_rest_trs(
def _openpose_limb_anim_trs( def _openpose_limb_anim_trs(
kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray, kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
"""Per-frame limb TRS: """Per-frame limb TRS: anim_mid (N, K_pairs, 3) midpoints and anim_quat
anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]). (N, K_pairs, 4 xyzw) aligning each limb's rest axis to its frame-t axis.
anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest Drives skin_matrix(t) = T(mid_t)·R_t·T(-mid_rest) — rigid rotation about
axis to its frame-t axis. the rest midpoint, no LBS cross-section thinning."""
Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t *
T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to
track the limb's current direction — no LBS cross-section thinning."""
N = kp_seq.shape[0] N = kp_seq.shape[0]
K_pairs = len(pairs) K_pairs = len(pairs)
anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32) anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32)
@ -616,7 +550,7 @@ def _openpose_limb_anim_trs(
n = float(np.linalg.norm(d)) n = float(np.linalg.norm(d))
if n > 1e-9: if n > 1e-9:
R[t, k] = rotation_align(ax_rest, d / n) R[t, k] = rotation_align(ax_rest, d / n)
quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw quat = rotmat_to_quat_np(R).astype(np.float32)
return anim_mid, quat return anim_mid, quat
@ -628,20 +562,14 @@ def _build_openpose_sticks(
smooth_shade: bool = False, smooth_shade: bool = False,
end_width_frac: float = 0.3, end_width_frac: float = 0.3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Capsule (cylinder + hemispherical caps) per limb pair (a, b). """Capsule per limb pair (a, b), each sized to its own length/width so caps
are true hemispheres regardless of L:W. Ellipsoid mode auto-clamps width to
`length * 0.1` so short limbs don't look chunky.
Each limb gets its own mesh sized to that limb's length and width so Rigid (weight=1) binding to a per-limb joint at `limb_joint_base_idx +
the caps are TRUE hemispheres of radius `half_width_eff` — the limb limb_idx`, which the caller animates with midpoint translation + rotation
silhouette is rounded-rectangle-like, regardless of L:W ratio. Width (avoids LBS thinning). Returns (verts, normals, faces, joints4, weights4,
auto-clamped to `length * 0.1` so short limbs (face/ear) don't look vert_colors)."""
chunky next to long ones.
Skinning: rigid (weight=1) binding to a per-limb joint at
`limb_joint_base_idx + limb_idx` — the caller animates that joint with
midpoint translation + rest-to-current rotation so each capsule rotates
rigidly with its limb (avoids translation-only LBS cross-section
thinning). Returns flat-shaded (verts, normals, faces, joints4,
weights4, vert_colors)."""
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
out_v_chunks: List[np.ndarray] = [] out_v_chunks: List[np.ndarray] = []
@ -663,13 +591,10 @@ def _build_openpose_sticks(
unit_dir = direction / length unit_dir = direction / length
R = rotation_align(canonical, unit_dir) R = rotation_align(canonical, unit_dir)
if is_capsule: if is_capsule:
# SCAIL-style uniform radius — every bone gets the same width. # Uniform radius — every bone the same width (clamped internally).
# `_capsule_mesh_local` clamps internally to L/2-eps so very
# short bones don't go degenerate.
half_width_eff = max(MIN_WIDTH, half_width_m) half_width_eff = max(MIN_WIDTH, half_width_m)
else: else:
# Ellipsoid mode: original auto-thinning so short face/ear # Auto-thin so short face/ear limbs aren't chunky next to body limbs.
# limbs don't look chunky next to long body limbs.
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m)) half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
v_local, f_local, _weights_unused = _capsule_mesh_local( v_local, f_local, _weights_unused = _capsule_mesh_local(
@ -678,10 +603,8 @@ def _build_openpose_sticks(
v_world = v_local @ R.T + head v_world = v_local @ R.T + head
Nv = v_local.shape[0] Nv = v_local.shape[0]
# Rigid binding to the per-limb joint. The 2-bone (head, tail) weights # Rigid binding to the per-limb joint; the 2-bone weights are discarded
# from `_capsule_mesh_local` are discarded — they're translation-only # (translation-only under LBS, would thin the cross-section).
# under glTF LBS and don't rotate the cross-section, causing visible
# thinning when the limb axis changes between rest and animated pose.
j_arr = np.zeros((Nv, 4), dtype=np.uint16) j_arr = np.zeros((Nv, 4), dtype=np.uint16)
j_arr[:, 0] = limb_idx + limb_joint_base_idx j_arr[:, 0] = limb_idx + limb_joint_base_idx
w_arr = np.zeros((Nv, 4), dtype=np.float32) w_arr = np.zeros((Nv, 4), dtype=np.float32)
@ -730,40 +653,24 @@ def build_glb_openpose(
stick_end_width_frac: float = 0.6, stick_end_width_frac: float = 0.6,
bone_smooth_window: int = 0, bone_smooth_window: int = 0,
) -> bytes: ) -> bytes:
"""Build a GLB containing an OpenPose-style 3D skeleton — sphere markers """Build a GLB of an OpenPose-style 3D skeleton — sphere markers per keypoint
per keypoint plus rainbow-colored sticks between standard limb pairs. plus colored sticks between limb pairs, one armature per track. Body from
Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig `pred_keypoints_3d`; optional hands (same source) and face landmarks
forward needed). Optional hand keypoints (also from `pred_keypoints_3d`) (`pred_vertices`) extend each armature.
and face landmarks (sampled from `pred_vertices` at fixed head-mesh
vertex IDs) extend the same per-track armature.
Args: Args:
include_hands: append the standard 21+21 OpenPose hand keypoints to include_hands: append the 21+21 OpenPose hand keypoints per track.
each track's armature (right hand at MHR70 indices 21..41, hand_marker_radius_m: hand sphere radius. 0 = auto = 0.4 × marker_radius_m.
left at 42..62). hand_stick_radius_m: hand limb half-width. 0 = auto = 0.5 × stick_radius_m.
hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 × hand_color_style: 'dwpose' (default) = solid-blue dots + rainbow sticks;
`marker_radius_m` (hand keypoints are anatomically smaller than 'openpose' = rainbow dots AND sticks.
body joints; matches DWPose's smaller hand dots). face_style: 'disabled' (default) | 'full' (~30 contour pts) | 'eyes_mouth'
hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 × (eyes + outer-lip subset); sampled at vertex IDs from
`stick_radius_m`. `canonical_colors["positions"]`.
hand_color_style: 'dwpose' (default) = solid-blue hand dots, face_marker_radius_m: face landmark sphere radius. 0 = auto = 0.3 ×
rainbow per-finger sticks (controlnet_aux/dwpose convention); marker_radius_m. Rendered as dots only, no contour lines.
'openpose' = rainbow per-finger dots AND sticks (matches palette: 'openpose' = rainbow gradient per keypoint; 'scail' = warm right
poseParameters.cpp::HAND_COLORS_RENDER). / cool left, grey centerline, distinct per-limb colors.
face_style: 'disabled' (default) | 'full' | 'eyes_mouth' — face
landmarks sampled from `pred_vertices` at vertex IDs picked from
`pose_data["canonical_colors"]["positions"]`. 'full' = all ~30
contour points; 'eyes_mouth' = the eyes + outer-lip subset.
face_marker_radius_m: per-face landmark sphere radius. 0 = auto =
0.3 × `marker_radius_m` — face landmarks are densely packed
around the eyes/mouth/jaw and need to be much smaller than
body keypoints to keep the layout legible. Face landmarks are
rendered as standalone dots (no contour lines), matching
DWPose's face_pose draw style.
palette: body color scheme. 'openpose' = standard rainbow gradient
per keypoint (canonical OpenPose convention); 'scail' =
SCAIL-Pose style — warm hues right side, cool hues left side,
grey neck-to-nose centerline, distinct per-limb colors.
""" """
is_scail = str(palette) == "scail" is_scail = str(palette) == "scail"
# SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0, # SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0,
@ -771,13 +678,11 @@ def build_glb_openpose(
body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS
body_sphere_kp = (np.arange(14, dtype=np.int64) body_sphere_kp = (np.arange(14, dtype=np.int64)
if is_scail else np.arange(18, dtype=np.int64)) if is_scail else np.arange(18, dtype=np.int64))
if str(palette) == "scail": if is_scail:
body_sphere_colors = SCAIL_KEYPOINT_COLORS_18 body_sphere_colors = SCAIL_KEYPOINT_COLORS_18
body_stick_colors = SCAIL_LIMB_COLORS_17 body_stick_colors = SCAIL_LIMB_COLORS_17
elif str(palette) == "openpose": elif str(palette) == "openpose":
# Existing OpenPose behavior: same rainbow array used for both # Same rainbow array drives both spheres and sticks.
# spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of
# the 18-element rainbow — yields a legible per-limb gradient).
body_sphere_colors = OPENPOSE_RAINBOW_18 body_sphere_colors = OPENPOSE_RAINBOW_18
body_stick_colors = OPENPOSE_RAINBOW_18 body_stick_colors = OPENPOSE_RAINBOW_18
else: else:
@ -892,13 +797,9 @@ def build_glb_openpose(
if bone_smooth_window and bone_smooth_window > 1: if bone_smooth_window and bone_smooth_window > 1:
kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window)) kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window))
# Static-bind = rig's REST pose when available (override path); else # Static-bind = rig REST pose when available, else frame 0. The rest
# fall back to frame 0 of the motion. The rest-pose bind makes the # bind keeps static POSITION at rig origin so viewers auto-center there
# GLB's static POSITION attribute sit at rig origin, so viewers # and the motion is visible (see _openpose_bind_at_rig_rest).
# auto-fit/center on rig origin and the animation visibly snaps from
# rest to scene-frame-0 — matching skeletal mode's behavior. Without
# this, openpose's static geometry is at scene-frame-0 and viewers
# mis-center on the scene location, masking the motion entirely.
bind_kp_m_rest = _openpose_bind_at_rig_rest( bind_kp_m_rest = _openpose_bind_at_rig_rest(
pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids, pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids,
) )
@ -914,7 +815,7 @@ def build_glb_openpose(
person_root_idx = len(nodes) - 1 person_root_idx = len(nodes) - 1
scene_root_indices.append(person_root_idx) scene_root_indices.append(person_root_idx)
# K keypoint joint nodes (spheres bind here, rigid translation only). # K keypoint joint nodes (spheres bind here, translation only).
joint_node_indices: List[int] = [] joint_node_indices: List[int] = []
for j in range(K): for j in range(K):
nodes.append({ nodes.append({
@ -926,9 +827,7 @@ def build_glb_openpose(
joint_node_indices.append(len(nodes) - 1) joint_node_indices.append(len(nodes) - 1)
person_root["children"].extend(joint_node_indices) person_root["children"].extend(joint_node_indices)
# Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint + # Per-limb rest + per-frame TRS; sticks bind rigidly to these joints.
# quaternion that aligns rest-axis → frame-t-axis). Sticks bind
# rigidly to these joints so each capsule rotates with its limb.
limb_rest_mids_list: List[np.ndarray] = [] limb_rest_mids_list: List[np.ndarray] = []
limb_rest_axes_list: List[np.ndarray] = [] limb_rest_axes_list: List[np.ndarray] = []
limb_anim_mids_list: List[np.ndarray] = [] limb_anim_mids_list: List[np.ndarray] = []
@ -951,12 +850,10 @@ def build_glb_openpose(
limb_rest_axes_list.append(raxis_h) limb_rest_axes_list.append(raxis_h)
limb_anim_mids_list.append(amid_h) limb_anim_mids_list.append(amid_h)
limb_anim_quats_list.append(aquat_h) limb_anim_quats_list.append(aquat_h)
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3) limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0)
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3) limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1)
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4) limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1)
# Hemisphere-align consecutive quats per limb so LINEAR interpolation # Hemisphere-align consecutive quats so LINEAR interp takes the short path.
# takes the short path (otherwise large per-frame rotations can flip
# signs and produce visible "twist back" artifacts mid-playback).
limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32) limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32)
limb_joint_indices: List[int] = [] limb_joint_indices: List[int] = []
@ -970,8 +867,8 @@ def build_glb_openpose(
limb_joint_indices.append(len(nodes) - 1) limb_joint_indices.append(len(nodes) - 1)
person_root["children"].extend(limb_joint_indices) person_root["children"].extend(limb_joint_indices)
# Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints # Combined skin: keypoint joints then limb joints; IBM = T(-rest) for
# (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest. # both, yielding identity skin_matrix at rest.
all_joint_indices = joint_node_indices + limb_joint_indices all_joint_indices = joint_node_indices + limb_joint_indices
ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1)) ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1))
ibm[:K, :3, 3] = -bind_kp_m ibm[:K, :3, 3] = -bind_kp_m
@ -985,10 +882,8 @@ def build_glb_openpose(
}) })
skin_idx = len(skins) - 1 skin_idx = len(skins) - 1
# Per-group geometry. Spheres bind to keypoint joints (base_joint_idx # Per-group geometry. Spheres bind to keypoint joints [0, K); sticks to
# ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈ # limb joints [K, K+K_limbs). Stacked body → R-hand → L-hand → face.
# [K, K + K_limbs)). Groups stack body → right hand → left hand →
# face for keypoint joints, and body → R-hand → L-hand for limbs.
group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray, group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray]] = [] np.ndarray, np.ndarray, np.ndarray]] = []
sp = _build_openpose_spheres( sp = _build_openpose_spheres(
@ -1008,9 +903,7 @@ def build_glb_openpose(
group_meshes.append(st) group_meshes.append(st)
if include_hands: if include_hands:
# Hand stick colors stay rainbow per-finger regardless of # Hand sticks stay rainbow per-finger; only dots switch under 'dwpose'.
# `hand_color_style` — only the sphere dots switch to solid
# blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py.
hand_pair_colors = _pair_colors_from_kp( hand_pair_colors = _pair_colors_from_kp(
OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1, OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1,
) )
@ -1033,9 +926,7 @@ def build_glb_openpose(
if K_face > 0: if K_face > 0:
f_off = K_body + K_hands f_off = K_body + K_hands
f_bind = bind_kp_m[f_off:f_off + K_face] f_bind = bind_kp_m[f_off:f_off + K_face]
# DWPose face = dots only, no contour lines # DWPose face = dots only, no contour lines.
# (controlnet_aux/dwpose/util.py::draw_facepose draws white
# circles per landmark and never connects them).
group_meshes.append(_build_openpose_spheres( group_meshes.append(_build_openpose_spheres(
f_bind, float(face_marker_radius_m), f_bind, float(face_marker_radius_m),
FACE_LANDMARK_COLORS, base_joint_idx=f_off, FACE_LANDMARK_COLORS, base_joint_idx=f_off,
@ -1087,9 +978,8 @@ def build_glb_openpose(
"target": {"node": joint_node_indices[j], "path": "translation"}, "target": {"node": joint_node_indices[j], "path": "translation"},
}) })
# Per-limb-joint translation + rotation channels. Stationary limbs # Per-limb-joint translation + rotation; stationary limbs bake their
# have their constant TRS baked into the node so they don't bloat the # constant TRS into the node instead of an animation channel.
# animation buffer.
for k in range(K_limbs): for k in range(K_limbs):
t_k = limb_anim_mids[:, k, :].astype(np.float32) t_k = limb_anim_mids[:, k, :].astype(np.float32)
if (np.ptp(t_k, axis=0) < 1e-6).all(): if (np.ptp(t_k, axis=0) < 1e-6).all():
@ -1103,9 +993,7 @@ def build_glb_openpose(
"target": {"node": limb_joint_indices[k], "path": "translation"}, "target": {"node": limb_joint_indices[k], "path": "translation"},
}) })
q_k = limb_anim_quats[:, k, :].astype(np.float32) q_k = limb_anim_quats[:, k, :].astype(np.float32)
# ptp on the absolute value handles the +q == -q ambiguity, but # Plain ptp is fine — signs already aligned by quat_sign_fix_per_joint.
# `quat_sign_fix_per_joint` already aligned signs so a plain ptp
# is fine here.
if (np.ptp(q_k, axis=0) < 1e-6).all(): if (np.ptp(q_k, axis=0) < 1e-6).all():
nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist() nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist()
else: else:

View File

@ -1,15 +1,11 @@
"""GLB export for SAM 3D Body pose_data. """Shared GLB export helpers for SAM 3D Body pose_data.
Mode: skeletal rebuilds the MHR 127-bone rig. Per-frame local TRS comes from Skeletal mode rebuilds the MHR 127-bone rig: per-frame local TRS from
re-running param_transform on saved mhr_model_params; rest verts from a param_transform on mhr_model_params, rest verts from a zero-pose forward,
zero-pose forward with the person's shape_params; sparse triplet skinning is sparse skinning compacted to glTF's 4-influence form, expression re-exposed as
compacted to glTF's max-4-influences form; facial expression is re-exposed as 72 morph targets. Camera-y-down data is un-flipped to glTF Y-up. Pose
72 morph targets driven by expr_params. correctives are dropped (glTF skinning can't represent them), so extreme joint
angles differ from the SAM3DBody renderer by the corrective amount.
pred_vertices/pred_cam_t are camera-y-down un-flipped here so the GLB lives
in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent
them); deformation at extreme joint angles will differ from the SAM3DBody
renderer by the corrective amount.
""" """
from __future__ import annotations from __future__ import annotations
@ -24,12 +20,11 @@ import torch
from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical
# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically # fp32-rounded ln(2); exp(x * _LN2) matches the rig's own 2**x bit-for-bit.
# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)`
_LN2 = 0.6931471824645996 _LN2 = 0.6931471824645996
# Quaternion / rotation helpers (xyzw convention, matching MHR rig) # Quaternion / rotation helpers (xyzw, matching MHR rig)
def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray: def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray:
"""(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat.""" """(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat."""
@ -96,8 +91,7 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray:
def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray: def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
"""Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4. """Edge-replicate Gaussian smoothing along time (sigma = window/4). float64."""
Endpoints replicate so they aren't pulled toward zero. Returns float64."""
a = np.asarray(arr, dtype=np.float64) a = np.asarray(arr, dtype=np.float64)
n = a.shape[0] n = a.shape[0]
half = window // 2 half = window // 2
@ -117,9 +111,8 @@ def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray: def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
"""Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns """Smooth a (N, NJ, 4) quaternion sequence along time: sign-align per joint,
per joint first, convolves per-component, renormalizes. Suppresses multi- convolve per-component, renormalize. Calms bone spikes at extreme poses."""
frame bone spikes at extreme poses without needing the upstream Smooth node."""
if window <= 1 or q_seq.shape[0] < 2: if window <= 1 or q_seq.shape[0] < 2:
return q_seq return q_seq
out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window) out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window)
@ -128,18 +121,16 @@ def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray: def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray:
"""Gaussian-smooth a (N, K, 3) position sequence along time (edge-replicate """Smooth a (N, K, 3) position sequence along time. Calms jittery keypoint
padding). Used to calm jittery keypoint tracks before the openpose rig tracks before the openpose rig derives sphere translations + limb TRS."""
derives sphere translations + limb TRS from them."""
if window <= 1 or seq.shape[0] < 2: if window <= 1 or seq.shape[0] < 2:
return seq return seq
return _gaussian_smooth_time(seq, window).astype(np.float32) return _gaussian_smooth_time(seq, window).astype(np.float32)
def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray: def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
"""Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit """Walk (N, NJ, 4) along time, flipping sign when consecutive frames sit on
on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim opposite hemispheres. Avoids long-path slerp glitches. fp64 internally."""
cartwheel flip). fp64 to avoid drift; normalizes input defensively."""
out = np.array(q_seq, dtype=np.float64, copy=True) out = np.array(q_seq, dtype=np.float64, copy=True)
norms = np.linalg.norm(out, axis=-1, keepdims=True) norms = np.linalg.norm(out, axis=-1, keepdims=True)
out = out / np.maximum(norms, 1e-12) out = out / np.maximum(norms, 1e-12)
@ -151,11 +142,9 @@ def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray: def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray:
"""Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that """Globals (N, NJ, 8) + parents -> per-bone local TRS so FK reproduces
FK over (parents, bone_local) reproduces rig_global. local = rig_global. local = inverse(parent_global) ∘ child_global, robust to
inverse(parent_global) ∘ child_global makes this robust to hierarchy- hierarchy-convention mismatches in `parents`."""
convention mismatches: glTF FK gives back exactly rig_global even if
`parents` doesn't match the rig's pmi-walk."""
N, NJ, _ = rig_global.shape N, NJ, _ = rig_global.shape
bone_local = np.zeros_like(rig_global) bone_local = np.zeros_like(rig_global)
for j in range(NJ): for j in range(NJ):
@ -188,8 +177,7 @@ def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray:
def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]: def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]:
"""List of (person_index, frame_indices). track_index == -1 means every """List of (person_index, frame_indices). track_index == -1 means every
present track; empty tracks are dropped. Same person index across frames present track; empty tracks dropped. Same person index = same subject."""
is assumed same subject (Smooth/Predict enforce this on tracked bboxes)."""
frames = pose_data["frames"] frames = pose_data["frames"]
max_p = max((len(f) for f in frames), default=0) max_p = max((len(f) for f in frames), default=0)
if max_p == 0: if max_p == 0:
@ -257,8 +245,7 @@ class GLBWriter:
return len(self.accessors) - 1 return len(self.accessors) - 1
def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int: def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int:
"""Morph-target POSITIONs: spec lets us skip min/max, avoiding a """Morph-target POSITIONs: spec lets us skip min/max."""
per-frame delta bbox."""
a = np.ascontiguousarray(arr, dtype=np.float32) a = np.ascontiguousarray(arr, dtype=np.float32)
view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY) view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY)
self.accessors.append({ self.accessors.append({
@ -288,9 +275,8 @@ class GLBWriter:
return len(self.accessors) - 1 return len(self.accessors) - 1
def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int: def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int:
"""Animation-output scalars: `count` is keyframes, not floats. Morph- """Animation-output scalars: `count` is keyframes, not floats (morph
target weight tracks store N_morph weights per keyframe as flat float32 weight tracks store N_morph weights per keyframe)."""
with count=N_keyframes."""
a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1) a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1)
view_idx = self._add_view(a.tobytes()) view_idx = self._add_view(a.tobytes())
self.accessors.append({ self.accessors.append({
@ -382,9 +368,8 @@ def bake_vertex_colors(
rainbow_tilt_z_deg: float, rainbow_tilt_z_deg: float,
pastel_mix: float, pastel_mix: float,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
"""Per-vertex RGB matching the renderer's shader preset, on the canonical """Per-vertex RGB matching the renderer's shader preset. Returns (N_v, 3)
mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the float32 in [0, 1], or None for `default` (use the viewer's material)."""
viewer's default material handle shading)."""
if shader == "default" or canonical_colors is None: if shader == "default" or canonical_colors is None:
return None return None
@ -432,8 +417,8 @@ def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray:
def _parents_from_pmi(rig: Any) -> np.ndarray: def _parents_from_pmi(rig: Any) -> np.ndarray:
"""Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child, """Parent index per joint from skel_pmi ((2, 266): row 0 child, row 1
row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1.""" parent, split into BFS levels by skel_pmi_buffer_sizes). Roots = -1."""
NJ = int(rig.NUM_JOINTS) NJ = int(rig.NUM_JOINTS)
pmi = rig.skel_pmi.cpu().numpy() pmi = rig.skel_pmi.cpu().numpy()
sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist() sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist()
@ -450,7 +435,9 @@ def _parents_from_pmi(rig: Any) -> np.ndarray:
def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply """Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply
this to bypass MHR rig extraction (see ComfyUI-Kimodo). Required keys: this to bypass MHR rig extraction (see ComfyUI-Kimodo).
Required keys:
parents: (NJ,) int32, -1 = root parents: (NJ,) int32, -1 = root
bind_global_m: (NJ, 8) f32 [t.xyz | q.xyzw | scale], meters bind_global_m: (NJ, 8) f32 [t.xyz | q.xyzw | scale], meters
lbs_compact_joints: (V, 8) uint16 pre-compacted skin influences lbs_compact_joints: (V, 8) uint16 pre-compacted skin influences
@ -458,39 +445,19 @@ def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict
lbs_compact_max_inf: int actual max influences (≤ 8) lbs_compact_max_inf: int actual max influences (≤ 8)
rest_verts_m: (V, 3) f32 rest_verts_m: (V, 3) f32
faces: (F, 3) uint32 faces: (F, 3) uint32
Optional:
per_frame_y_down: bool set False if pred_joint_coords are already Optional (enable openpose mode on external rigs):
rig-native Y-up (kimodo). Default True (MHR). per_frame_y_down: bool False if pred_joint_coords are already Y-up
openpose18_joint_indices: (18, 2) int32 body OpenPose-18 joint (kimodo). Default True (MHR).
index pair, resolved against per-frame openpose18_joint_indices: (18, 2) int32 body keypoint (a, b)
`pred_joint_coords`. Each row is joints, resolved against `pred_joint_coords`.
(joint_a, joint_b); b == -1 = single b == -1 = single joint, else midpoint of (a, b).
joint, else default midpoint of the two openpose18_joint_weights: (18,) f32 blend w: w*a + (1-w)*b
(lets producers approximate keypoints (default 0.5; outside [0,1] extrapolates; ignored
with no matching joint, e.g. Nose when b == -1).
midpoint(LeftEye, RightEye)). Enables openpose_hand21_{r,l}_joint_indices: (21, 2) int32 per-hand keypoint
`SAM3DBody_ToGLB(mode="openpose")` on maps; both required for include_hands=True.
external rigs. openpose_hand21_{r,l}_joint_weights: (21,) f32 optional, same as above.
openpose18_joint_weights: (18,) f32 optional per-keypoint blend
weight for the (a, b) mapping above.
Position = w*joints[a] + (1-w)*joints[b]
when b ≥ 0 (default w=0.5 midpoint).
Values outside [0, 1] EXTRAPOLATE past
the line segment used to approximate
landmarks with no nearby joint pair
(e.g. ears: w=2.0 along the eye–eye
axis puts each ear one eye-distance
outside the corresponding eye). Ignored
for single-joint rows (b = -1).
openpose_hand21_r_joint_indices: (21, 2) int32 right-hand OpenPose-21
(wrist + 5 fingers × 4 joints, base–tip)
joint index pair. Required (alongside
the L counterpart) for openpose mode
with include_hands=True.
openpose_hand21_l_joint_indices: (21, 2) int32 left-hand counterpart.
openpose_hand21_r_joint_weights: (21,) f32 optional, same semantics as
`openpose18_joint_weights`.
openpose_hand21_l_joint_weights: (21,) f32 optional, same as above.
""" """
if pose_data is None: if pose_data is None:
return None return None
@ -502,12 +469,10 @@ def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) -
use that instead of MHR-specific `model.head_pose.mhr` buffers.""" use that instead of MHR-specific `model.head_pose.mhr` buffers."""
override = _get_skeleton_override(pose_data) override = _get_skeleton_override(pose_data)
if override is not None: if override is not None:
# External rig: caller pre-compacts skin and supplies bind global directly, # External rig: skin pre-compacted, bind global supplied directly.
# so we don't need MHR's PCA pose / expression bases.
parents = np.asarray(override["parents"], dtype=np.int32) parents = np.asarray(override["parents"], dtype=np.int32)
rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32) rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32)
# BVH needs parent-relative bone OFFSETs (cm). MHR ships these directly; # BVH needs parent-relative bone offsets (cm); derive from bind globals.
# external rigs only give bind globals, so derive locals from them.
bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32) bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32)
local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0] local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0]
joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32) joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32)
@ -560,29 +525,26 @@ def compact_skin_to_n(
skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray, skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
num_verts: int, max_inf: int = 8, num_verts: int, max_inf: int = 8,
) -> Tuple[np.ndarray, np.ndarray, int]: ) -> Tuple[np.ndarray, np.ndarray, int]:
"""Sparse (joint, vert, weight) triplets -> dense (joints[V, max_inf], """Sparse (joint, vert, weight) triplets -> dense (joints, weights) of shape
weights[V, max_inf]). Keeps `max_inf` largest-magnitude influences, (V, max_inf), keeping the largest influences and renormalizing. `actual_max`
renormalizes. `actual_max` lets the caller skip JOINTS_1/WEIGHTS_1 when lets the caller skip JOINTS_1/WEIGHTS_1 when nothing exceeds 4 influences."""
nothing exceeds 4 influences."""
joints = np.zeros((num_verts, max_inf), dtype=np.uint16) joints = np.zeros((num_verts, max_inf), dtype=np.uint16)
out_w = np.zeros((num_verts, max_inf), dtype=np.float32) out_w = np.zeros((num_verts, max_inf), dtype=np.float32)
counts = np.zeros(num_verts, dtype=np.int32) counts = np.zeros(num_verts, dtype=np.int32)
if vert_indices.size: if vert_indices.size:
# lexsort secondary key first: groups by vert, weights descending within group. # Group by vert, weights descending within each group.
order = np.lexsort((-weights, vert_indices)) order = np.lexsort((-weights, vert_indices))
vi_sorted = vert_indices[order] vi_sorted = vert_indices[order]
sk_sorted = skin_indices[order] sk_sorted = skin_indices[order]
w_sorted = weights[order] w_sorted = weights[order]
# Per-row rank within its vertex group: 0 at each group start, +1 elsewhere. # Per-row rank within its vertex group (0 at each group start).
# group_start[i] is True when vi_sorted[i] starts a new vertex.
n = vi_sorted.size n = vi_sorted.size
group_start = np.empty(n, dtype=bool) group_start = np.empty(n, dtype=bool)
group_start[0] = True group_start[0] = True
np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:]) np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:])
pos = np.arange(n, dtype=np.int64) pos = np.arange(n, dtype=np.int64)
# Position of each row's group start, broadcast forward.
group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0)) group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0))
rank = pos - group_start_pos rank = pos - group_start_pos
@ -609,9 +571,8 @@ def zero_pose_rest_verts(
model: Any, shape_params: np.ndarray, expr_zero: bool = True, model: Any, shape_params: np.ndarray, expr_zero: bool = True,
pose_data: Optional[Dict[str, Any]] = None, pose_data: Optional[Dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Rig with zero pose + this subject's shape -> rest verts (V, 3) in """Zero pose + this subject's shape -> rest verts (V, 3) in rig-native Y-up
rig-native Y-up meters. External-skeleton path returns `rest_verts_m` meters. External path returns `rest_verts_m` directly."""
directly (no PCA shape space to expand)."""
override = _get_skeleton_override(pose_data) override = _get_skeleton_override(pose_data)
if override is not None: if override is not None:
return np.asarray(override["rest_verts_m"], dtype=np.float32) return np.asarray(override["rest_verts_m"], dtype=np.float32)
@ -624,14 +585,11 @@ def zero_pose_rest_verts(
sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device) sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device)
if sp.ndim == 1: if sp.ndim == 1:
sp = sp.unsqueeze(0) sp = sp.unsqueeze(0)
# mhr.forward(identity_coeffs, model_parameters, expr_coeffs): # rig.forward(shape, model_params, expr); zero pose + zero expr.
# identity_rest = base_shape + identity_basis @ shape;
# cat([model_params, zeros]) through param_transform; expr added.
model_params = torch.zeros(1, 204, device=device, dtype=dtype) model_params = torch.zeros(1, 204, device=device, dtype=dtype)
expr = torch.zeros(1, 72, device=device, dtype=dtype) expr = torch.zeros(1, 72, device=device, dtype=dtype)
verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False) verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False)
# Rig outputs cm; mhr_head divides by 100 for meters. Match that. verts_m = verts[0].cpu().float().numpy() / 100.0 # cm -> m
verts_m = verts[0].cpu().float().numpy() / 100.0
return verts_m.astype(np.float32) return verts_m.astype(np.float32)
@ -639,7 +597,7 @@ def global_skel_state_per_frame(
model: Any, mhr_model_params: np.ndarray, model: Any, mhr_model_params: np.ndarray,
) -> np.ndarray: ) -> np.ndarray:
"""Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw, """Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw,
scale). Bones are shape- and expression-independent so we pass zeros.""" scale). Bones are shape/expression-independent, so pass zeros."""
inner = model.model if hasattr(model, "model") else model inner = model.model if hasattr(model, "model") else model
rig = inner.head_pose.mhr rig = inner.head_pose.mhr
device = next(rig.parameters()).device device = next(rig.parameters()).device
@ -655,8 +613,8 @@ def global_skel_state_per_frame(
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray: def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978 branched, largest-component """(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978, largest-component pick.
pick for stability. Cross-frame sign-fixing is the caller's job.""" Cross-frame sign-fixing is the caller's job."""
shape = R.shape[:-2] shape = R.shape[:-2]
Rf = R.reshape(-1, 3, 3).astype(np.float64) Rf = R.reshape(-1, 3, 3).astype(np.float64)
M = Rf.shape[0] M = Rf.shape[0]
@ -703,14 +661,12 @@ def global_skel_state_from_pose_data(
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
NJ: int, *, joint_coords_y_down: bool = True, NJ: int, *, joint_coords_y_down: bool = True,
) -> np.ndarray: ) -> np.ndarray:
"""Build per-frame skel_state from stored pred_global_rots + pred_joint_coords, """Per-frame skel_state from stored pred_global_rots + pred_joint_coords,
bypassing rig.forward. Returns (N, NJ, 8) in METERS, MHR-native frame. bypassing rig.forward. Returns (N, NJ, 8) in meters, MHR-native frame.
pred_global_rots are MHR-native (no y/z flip). For MHR, pred_joint_coords pred_global_rots are MHR-native. pred_joint_coords are y-down for MHR
are stored y-down (post-flip), so un-flip when `joint_coords_y_down=True`. (un-flipped when `joint_coords_y_down=True`); external rigs store y-up
External skeletons (Kimodo) store y-up already — pass False. Scale (pass False). Scale defaults to 1 (not preserved in pose_data)."""
defaults to 1 (rig scale isn't preserved in pose_data; close to 1 for
typical body poses)."""
frames = pose_data["frames"] frames = pose_data["frames"]
N = len(frame_indices) N = len(frame_indices)
rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32) rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32)
@ -731,10 +687,8 @@ def global_skel_state_from_pose_data(
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray: def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm. """Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm,
Inverse of `lbs_inverse_bind_pose` modulo precision; used as bones' static used as bones' static TRS. External rig: convert `bind_global_m` m -> cm."""
TRS so the rest mesh looks correct with no animation playing. External
rig: convert override's `bind_global_m` from m → cm to match this contract."""
override = _get_skeleton_override(pose_data) override = _get_skeleton_override(pose_data)
if override is not None: if override is not None:
bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy() bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy()
@ -746,13 +700,10 @@ def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> n
@dataclass @dataclass
class Rig: class Rig:
"""Normalized static rig for the GLB/BVH exporters, independent of where it """Normalized static rig for the GLB/BVH exporters, source-independent: MHR
came from: an MHR model (`Rig.from_pose_data(pose_data, model)`) or an inline model or inline `pose_data["_skeleton_override"]` (external rigs). Consumers
`pose_data["_skeleton_override"]` (external rigs, e.g. ComfyUI-Kimodo). never branch on the source. Only `rest_verts_m` is source-dependent — MHR
expands it from `shape_params`; external rigs ship it fixed.
Consumers read these fields and never branch on the source. The only
source-dependent operation is `rest_verts_m` — MHR rest verts depend on the
subject's `shape_params`; external rigs ship fixed rest verts.
""" """
parents: np.ndarray # (NJ,) int32, -1 = root parents: np.ndarray # (NJ,) int32, -1 = root
joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm
@ -816,9 +767,8 @@ class Rig:
def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray: def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray:
"""Inverse-bind MAT4 by inverting the rig's bind global (meters). Guarantees """Inverse-bind MAT4 from the rig's bind global (meters). IBP[j] =
IBP[j] = inverse(FK over bind local TRS) — exactly what glTF skinning inverse(FK over bind local TRS), as glTF skinning needs. Returns (NJ, 4, 4)
needs given bones default to the bind local TRS. Returns (NJ, 4, 4)
column-major.""" column-major."""
NJ = bind_skel_state_m.shape[0] NJ = bind_skel_state_m.shape[0]
t = bind_skel_state_m[:, :3].astype(np.float32) t = bind_skel_state_m[:, :3].astype(np.float32)
@ -877,10 +827,8 @@ def _ibp_to_mat4(ibp_skel: np.ndarray) -> np.ndarray:
def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]: def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]:
"""Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring """Unit UV sphere, poles ±Y. `n_lat` odd so a ring lands at the equator;
lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon n_lon=16 matches the capsule cylinder so end rings meet flush."""
matches the 16-segment cylinder used by capsule limbs AND the equator
ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush."""
verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0 verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0
for i in range(1, n_lat + 1): for i in range(1, n_lat + 1):
lat = -0.5 * np.pi + np.pi * i / (n_lat + 1) lat = -0.5 * np.pi + np.pi * i / (n_lat + 1)
@ -924,8 +872,8 @@ def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndar
def flat_shade_mesh( def flat_shade_mesh(
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray, verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Smooth -> flat by duplicating verts per face; each triangle gets 3 """Flat-shade by duplicating verts per face; each triangle gets 3 unique
unique verts sharing its face normal. Skinning attrs duplicated alongside.""" verts sharing its face normal. Skinning attrs duplicated alongside."""
F = faces.shape[0] F = faces.shape[0]
new_v = np.zeros((F * 3, 3), dtype=np.float32) new_v = np.zeros((F * 3, 3), dtype=np.float32)
new_n = np.zeros((F * 3, 3), dtype=np.float32) new_n = np.zeros((F * 3, 3), dtype=np.float32)
@ -949,9 +897,8 @@ def flat_shade_mesh(
def smooth_shade_mesh( def smooth_shade_mesh(
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray, verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Area-weighted per-vertex normals (smooth shading). Geometry, skinning, """Area-weighted per-vertex normals. Geometry/skinning/indexing pass through
indexing pass through unchanged so vertex colors stay aligned. Orphan unchanged so vertex colors stay aligned. Orphan verts get +Y fallback."""
verts get +Y fallback."""
Nv = int(verts.shape[0]) Nv = int(verts.shape[0])
v0 = verts[faces[:, 0]] v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]] v1 = verts[faces[:, 1]]
@ -994,11 +941,9 @@ def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray:
def make_lit_material( def make_lit_material(
roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0, roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0,
) -> dict: ) -> dict:
"""Lit PBR material using vertex COLOR_0 multiplicatively. KHR_materials_unlit """Lit PBR material using vertex COLOR_0. Dielectric (metallic=0) so colors
is intentionally off so viewer lighting reveals surface form. metallic=0 stay readable; roughness 0.85 suits rainbow body meshes, 0.3 the glossy
keeps the surface dielectric so vertex colors stay readable. roughness=0.85 SCAIL rig. opacity < 1 switches to alpha-blend."""
suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look.
opacity < 1 switches to alpha-blend (e.g. see-through body mesh over bones)."""
a = float(max(0.0, min(1.0, opacity))) a = float(max(0.0, min(1.0, opacity)))
mat = { mat = {
"pbrMetallicRoughness": { "pbrMetallicRoughness": {
@ -1182,14 +1127,12 @@ def openpose_render_keypoints(
person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str, person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str,
*, dim: int, H: int = 0, W: int = 0, *, dim: int, H: int = 0, W: int = 0,
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
"""OpenPose keypoints for one person, in op-layout, CAMERA frame (Y-down). """OpenPose keypoints for one person, op-layout, camera frame (Y-down).
`part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add; `part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add;
dim=2 -> (K, 2) image pixels. Returns None when the source data is missing. dim=2 -> (K, 2) pixels. Returns None when source data is missing.
External rigs (override carries the joint-index map) resolve from per-frame External rigs resolve from `pred_joint_coords` (Y-up -> flipped to Y-down);
`pred_joint_coords` (rig-native Y-up -> flipped to camera Y-down, matching MHR reindexes stored `pred_keypoints_{3d,2d}` via the MHR70 map."""
the pred_vertices convention). MHR reindexes the stored
`pred_keypoints_{3d,2d}` via the MHR70 map."""
map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part] map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part]
override = _get_skeleton_override(pose_data) override = _get_skeleton_override(pose_data)
ext_map = override.get(map_key) if override is not None else None ext_map = override.get(map_key) if override is not None else None
@ -1228,11 +1171,9 @@ def openpose_render_keypoints(
return kp_full[mhr_map] return kp_full[mhr_map]
# Face landmarks from the MHR rig (option `face_source="rig"`). # Face landmarks (face_source="rig"). MHR has no face bones, so landmarks are
# MHR has no face bones — face deforms via expr_params morphs — so landmarks # sourced from `pred_vertices` at vertex IDs picked by NN against the target xyz
# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against # below. Tweak targets if landmarks land off-surface.
# anatomically-plausible target xyz in canonical Y-up. Iterate visually in
# Blender and tweak targets if landmarks land off-surface.
# (name, target_xyz) in MHR canonical Y-up meters. # (name, target_xyz) in MHR canonical Y-up meters.
FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = ( FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = (
@ -1290,10 +1231,8 @@ def select_face_landmark_vert_ids(
face_mask: Optional[np.ndarray] = None, face_mask: Optional[np.ndarray] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in """Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in
canonical positions. Filter: `face_mask` (verts that deform with any of canonical positions, restricted to `face_mask` verts (expression-deforming)
the 72 expression axes) if available — keeps chin/jaw search off the when available, else a position bbox (less reliable around the chin/jaw)."""
neck. Otherwise a position bbox (less reliable; throat verts sometimes
pull chin targets)."""
P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3) P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3)
if face_mask is not None and np.asarray(face_mask).any(): if face_mask is not None and np.asarray(face_mask).any():
valid = np.where(np.asarray(face_mask).reshape(-1))[0] valid = np.where(np.asarray(face_mask).reshape(-1))[0]

View File

@ -1,19 +1,11 @@
"""GLB export — skeletal (real armature) mode. """GLB export — skeletal (real armature) mode.
Rebuilds an Armature with the MHR 127-bone rig: Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from
- per-frame local TRS comes from re-running param_transform on the saved param_transform on `mhr_model_params`, rest verts from a zero-pose forward,
`mhr_model_params`; sparse skinning compacted to glTF's 4-influence form, and facial expression as
- rest verts come from a zero-pose forward with each person's `shape_params`; 72 morph targets driven by `expr_params`. Optional octahedron bone-vis is
- sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form; rigidly skinned alongside for viewers that don't draw bones. Shared infra lives
- facial expression is re-exposed as 72 morph targets driven by `expr_params` in `glb_shared.py`.
so face animation survives plain glTF skinning.
Optional bone visualization (octahedrons) is rigidly
skinned alongside the body mesh used to preview the armature in glTF
viewers that don't draw bones.
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
""" """
from __future__ import annotations from __future__ import annotations
@ -44,8 +36,7 @@ from .glb_shared import (
from comfy_extras.sam3d_body.utils import jet_colormap from comfy_extras.sam3d_body.utils import jet_colormap
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]: def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white' """Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material)."""
(no per-bone color — bone-vis mesh uses default unlit material)."""
if scheme == "rainbow_y": if scheme == "rainbow_y":
y = bind_pos_m[:, 1].astype(np.float32) y = bind_pos_m[:, 1].astype(np.float32)
y_min, y_max = float(y.min()), float(y.max()) y_min, y_max = float(y.min()), float(y.max())
@ -55,9 +46,8 @@ def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]: def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y, """Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward."""
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
v = np.array([ v = np.array([
[0.0, 0.0, 0.0], # 0: head [0.0, 0.0, 0.0], # 0: head
[0.0, 1.0, 0.0], # 1: tail [0.0, 1.0, 0.0], # 1: tail
@ -78,18 +68,16 @@ def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
def _bone_edges( def _bone_edges(
joint_pos_m: np.ndarray, parents: np.ndarray, joint_pos_m: np.ndarray, parents: np.ndarray,
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]: ) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per """One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge.
parent→child edge in the hierarchy, skipping edges whose PARENT is a Skips edges whose parent is a root (world-anchor sticks) and zero-length
root joint (those typically anchor the skeleton at world origin and edges."""
just look like a stray stick from origin to the body). Zero-length
edges are skipped too."""
NJ = joint_pos_m.shape[0] NJ = joint_pos_m.shape[0]
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = [] out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
for c in range(NJ): for c in range(NJ):
p = int(parents[c]) p = int(parents[c])
if not (0 <= p < NJ and p != c): if not (0 <= p < NJ and p != c):
continue continue
# Skip if parent itself is a root — that bone is a world-anchor stick. # Skip world-anchor sticks: parent itself is a root.
gp = int(parents[p]) gp = int(parents[p])
if not (0 <= gp < NJ and gp != p): if not (0 <= gp < NJ and gp != p):
continue continue
@ -104,9 +92,8 @@ def _bone_edges(
def _build_bone_octahedrons_mesh( def _build_bone_octahedrons_mesh(
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02, bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""One Blender-style octahedron per parent→child edge. Returns """One octahedron per parent→child edge. Returns (verts, normals, faces,
(verts, normals, faces, joints, weights, child_idx_per_vert); joints, weights, child_idx_per_vert); child_idx feeds per-bone color."""
child_idx feeds per-bone color lookup at the call site."""
base_v, base_f = _octahedron_unit() base_v, base_f = _octahedron_unit()
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
@ -117,8 +104,7 @@ def _build_bone_octahedrons_mesh(
out_w: List[List[float]] = [] out_w: List[List[float]] = []
child_per_vert: List[int] = [] child_per_vert: List[int] = []
# Width scales with length so short bones (fingers, face) don't look chunky # Width scales with length (capped by half_width_m) so short bones aren't chunky.
# next to long ones (limbs, spine). `half_width_m` caps long bones.
WIDTH_RATIO = 0.1 WIDTH_RATIO = 0.1
MIN_WIDTH = 0.001 MIN_WIDTH = 0.001
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents): for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
@ -151,8 +137,8 @@ def _build_bone_octahedrons_mesh(
out_n.extend(n_world.tolist()) out_n.extend(n_world.tolist())
for face in base_f: for face in base_f:
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off]) out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
# Dual skin head→parent, tail→child, ridges blend by canonical Y so the # Dual skin (head→parent, tail→child); ridges blend by canonical Y so
# bone stretches between joints instead of going rigid with one. # the bone stretches between joints instead of going rigid with one.
for k in range(base_v.shape[0]): for k in range(base_v.shape[0]):
y_canon = float(base_v[k, 1]) y_canon = float(base_v[k, 1])
w_parent = max(0.0, 1.0 - y_canon) w_parent = max(0.0, 1.0 - y_canon)
@ -196,22 +182,17 @@ def build_glb_skeletal(
bone_vis_color: str = "white", bone_vis_color: str = "white",
include_body_mesh: bool = True, include_body_mesh: bool = True,
) -> bytes: ) -> bytes:
"""Build pose_data as a real Armature GLB blob with per-bone TRS keyframes. """Build pose_data as a real Armature GLB with per-bone TRS keyframes. For
MHR, facial expression is exposed as 72 morph targets when
include_face_morphs=True.
For MHR (default) facial expression is exposed as 72 morph targets driven External skeletons (e.g. ComfyUI-Kimodo) can supply
by expr_params per frame when include_face_morphs=True. ``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model``
may be None then); per-frame state still reads ``pred_global_rots`` /
External skeletons (e.g. ComfyUI-Kimodo) can supply a ``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema.
``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
entirely. When present, ``model`` may be None and the rig data, bind pose,
skin weights, and rest verts come from the override. Per-frame skeletal
state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
person dict (kimodo populates these from its own FK output). See
``glb.shared._get_skeleton_override`` for the override schema.
""" """
frames = pose_data["frames"] frames = pose_data["frames"]
# Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis, # Only `pred_cam_t` is camera-y-down; everything else is rig-native Y-up.
# faces are all rig-native (Y-up).
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32) faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
tracks = collect_tracks(pose_data, track_index) tracks = collect_tracks(pose_data, track_index)
if not tracks: if not tracks:
@ -219,17 +200,14 @@ def build_glb_skeletal(
rig = Rig.from_pose_data(pose_data, model) rig = Rig.from_pose_data(pose_data, model)
NJ = rig.num_joints NJ = rig.num_joints
# NV = rig.num_verts
NEXPR = rig.num_expr NEXPR = rig.num_expr
parents = rig.parents parents = rig.parents
if not rig.can_rerun_fk: if not rig.can_rerun_fk:
# External rigs have no PCA pose params to re-run; only stored globals # External rigs have no PCA pose params to re-run; use stored globals.
# are available, and they store joint coords already Y-up.
use_stored_global_rots = True use_stored_global_rots = True
joint_coords_y_down = rig.per_frame_y_down joint_coords_y_down = rig.per_frame_y_down
# Skinning is already compacted to ≤8 influences per vertex (MHR averages # Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts
# ~2.8 but some shoulder/hip verts hit 5-8; keeping only 4 there leaks # need >4, else per-bone rotation noise leaks into the mesh).
# per-bone rotation noise into the rendered mesh).
joints_8 = rig.lbs_joints joints_8 = rig.lbs_joints
weights_8 = rig.lbs_weights weights_8 = rig.lbs_weights
actual_max_inf = rig.lbs_max_inf actual_max_inf = rig.lbs_max_inf
@ -238,14 +216,12 @@ def build_glb_skeletal(
use_set1 = actual_max_inf > 4 use_set1 = actual_max_inf > 4
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
# Derive bone locals from the rig's bind globals rather than recomputing # Derive bone locals from bind globals so any `parents`/FK mismatch is
# FK ourselves, so any mismatch between `parents` and the rig's actual FK # absorbed into the local TRS instead of producing wrong globals.
# is absorbed into the local TRS instead of producing wrong globals.
bind_global_m = rig.bind_global_m bind_global_m = rig.bind_global_m
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0] bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
# IBP = inverse of bind global. With bone defaults set to bind_local and # IBP = inverse of bind global → skin_matrix at rest is identity.
# FK composed via `parents`, skin_matrix at rest = identity.
ibp_mat4 = ibp_from_bind_global(bind_global_m) ibp_mat4 = ibp_from_bind_global(bind_global_m)
w = GLBWriter() w = GLBWriter()
@ -316,9 +292,7 @@ def build_glb_skeletal(
body_mesh_node_idx: Optional[int] = None body_mesh_node_idx: Optional[int] = None
if include_body: if include_body:
# MHR rest verts depend on the subject's shape_params; external rigs # MHR rest verts depend on shape_params; external rigs ignore the arg.
# ship fixed rest verts and ignore the arg (so the empty external
# `shape_params` is harmless).
shape_params_arr = np.asarray( shape_params_arr = np.asarray(
frames[frame_indices[0]][person_k].get("shape_params", []), frames[frame_indices[0]][person_k].get("shape_params", []),
dtype=np.float32, dtype=np.float32,
@ -349,8 +323,8 @@ def build_glb_skeletal(
"indices": indices_acc, "indices": indices_acc,
"mode": 4, "mode": 4,
} }
# See-through body when bones are shown, else opaque (only when a # See-through body when bones are shown, else opaque (only if a
# vertex-color shader baked COLOR_0 — otherwise default material). # shader baked COLOR_0; otherwise default material).
if color_acc is not None or include_bones: if color_acc is not None or include_bones:
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0)) materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
primitive["material"] = len(materials) - 1 primitive["material"] = len(materials) - 1
@ -373,8 +347,7 @@ def build_glb_skeletal(
if include_bones: if include_bones:
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color) bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
# Indexes `bone_palette`: octahedrons use the bone's child joint so # Color by child joint so every bone has its own color.
# every bone has its own color regardless of skin target.
color_idx_per_vert: Optional[np.ndarray] = None color_idx_per_vert: Optional[np.ndarray] = None
hw = float(bone_vis_radius_m) hw = float(bone_vis_radius_m)
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh( bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
@ -422,8 +395,8 @@ def build_glb_skeletal(
nodes.append(bv_mesh_node) nodes.append(bv_mesh_node)
person_root["children"].append(len(nodes) - 1) person_root["children"].append(len(nodes) - 1)
# Per-frame GLOBAL skel state → bone locals via parent-inverse. # Per-frame global skel state → bone locals via parent-inverse. Stored
# Default uses the rig's stored output; the fallback re-runs FK. # output by default; fallback re-runs FK.
if use_stored_global_rots: if use_stored_global_rots:
rig_global_m = global_skel_state_from_pose_data( rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ, pose_data, frame_indices, person_k, NJ,
@ -437,11 +410,9 @@ def build_glb_skeletal(
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame) rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
rig_global_m = rig_global_cm.copy().astype(np.float32) rig_global_m = rig_global_cm.copy().astype(np.float32)
rig_global_m[..., :3] *= 0.01 rig_global_m[..., :3] *= 0.01
# Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's # Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip
# Euler-XYZ parametrization wraps at ±180° for spinning joints; if we # would otherwise propagate into the child's local translation and cause
# only fix locals, the parent's flip propagates into the child's # visible "axis resets" mid-animation.
# local translation (t_local inherits parent sign via q_parent_inv)
# and produces visible "axis resets" mid-animation.
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7]) rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bone_local_anim = bone_locals_from_globals(rig_global_m, parents) bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
local_t = bone_local_anim[..., :3].astype(np.float32) local_t = bone_local_anim[..., :3].astype(np.float32)
@ -449,20 +420,17 @@ def build_glb_skeletal(
local_s = bone_local_anim[..., 7].astype(np.float32) local_s = bone_local_anim[..., 7].astype(np.float32)
# Second pass on locals catches residual drift from the parent-inverse. # Second pass on locals catches residual drift from the parent-inverse.
local_q = quat_sign_fix_per_joint(local_q) local_q = quat_sign_fix_per_joint(local_q)
# Hemisphere-align frame 0 with the bind quat so pause/play takes the # Align frame 0 with the bind quat so pause/play takes the short path.
# short path; then re-propagate.
bind_q = bind_local[:, 3:7].astype(np.float32) bind_q = bind_local[:, 3:7].astype(np.float32)
if local_q.shape[0] > 0: if local_q.shape[0] > 0:
d0 = (bind_q * local_q[0]).sum(axis=-1) d0 = (bind_q * local_q[0]).sum(axis=-1)
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None] sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
local_q[0] = local_q[0] * sign0 local_q[0] = local_q[0] * sign0
local_q = quat_sign_fix_per_joint(local_q) local_q = quat_sign_fix_per_joint(local_q)
# Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity # Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand).
# at handstand) that the upstream Smooth node may not catch.
if bone_smooth_window and bone_smooth_window > 1: if bone_smooth_window and bone_smooth_window > 1:
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window)) local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
# fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit # fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift.
# drift into visible flips otherwise.
lq64 = local_q.astype(np.float64) lq64 = local_q.astype(np.float64)
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12) lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
local_q = lq64.astype(np.float32) local_q = lq64.astype(np.float32)
@ -527,7 +495,7 @@ def build_glb_skeletal(
"target": {"node": person_root_idx, "path": "translation"}, "target": {"node": person_root_idx, "path": "translation"},
}) })
# Body-mesh-only: bone-vis primitives have no morph targets. # Body mesh only — bone-vis primitives have no morph targets.
if expr_morph_accs and body_mesh_node_idx is not None: if expr_morph_accs and body_mesh_node_idx is not None:
expr_per_frame = np.stack([ expr_per_frame = np.stack([
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32) np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)

View File

@ -34,8 +34,7 @@ def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
def inputs_from_sam3_track(track_data, B: int, H: int, W: int): def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image """Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
resolution. Returns (per_frame_bboxes, per_frame_masks) or resolution. Returns (None, None) on empty track / frame-count mismatch."""
(None, None) when the track is empty / frame count doesn't match"""
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
if packed is None: if packed is None:
@ -100,7 +99,7 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any], def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
H: int, W: int) -> Dict[str, Any]: H: int, W: int) -> Dict[str, Any]:
"""Re-project every frame's pose through a Load3D 6DOF camera (position/ """Re-project every frame's pose through a Load3D 6DOF camera (position/
target/zoom + optional FOV). Returns a new mhr_pose_data; unchanged on target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on
empty/invalid input.""" empty/invalid input."""
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else [] first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
if not first_frame: if not first_frame:
@ -158,16 +157,16 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
y_axis = np.cross(z_axis, x_axis) y_axis = np.cross(z_axis, x_axis)
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32) R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
# Eye: dolly along the given offset; for a rotation-only camera (position == # Eye: dolly along the offset; rotation-only camera keeps the predicted
# target) keep the predicted viewing distance so only orientation/roll changes. # viewing distance so only orientation/roll changes.
if has_offset: if has_offset:
eye = target + offset / max(0.01, zoom) eye = target + offset / max(0.01, zoom)
else: else:
d = max(0.1, float(target[2])) d = max(0.1, float(target[2]))
eye = target - z_axis * (d / max(0.01, zoom)) eye = target - z_axis * (d / max(0.01, zoom))
# Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint- # Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov
# only change). Three.js fov is vertical → focal from image height. # is vertical → focal from image height.
cam_fov = float(camera_info.get("fov", 0.0) or 0.0) cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
if cam_fov > 0: if cam_fov > 0:
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0))) new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
@ -178,10 +177,8 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
center = np.array([W * 0.5, H * 0.5], dtype=np.float32) center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"} reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
# External rigs (e.g. Kimodo) store pred_joint_coords rig-native Y-up; the # External rigs store pred_joint_coords Y-up; transform them through the
# render openpose/scail keypoint provider resolves from them and flips Y/Z. # camera too (in camera space, then back to Y-up) so they follow the override.
# Transform them through the camera too (in camera space, then back to Y-up)
# so those keypoints follow the override instead of staying in the old frame.
override = mhr_pose_data.get("_skeleton_override") override = mhr_pose_data.get("_skeleton_override")
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False)) joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
new_frames: List[List[Dict[str, Any]]] = [] new_frames: List[List[Dict[str, Any]]] = []
@ -242,8 +239,7 @@ def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], p
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)] img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
if per_frame_masks is not None: if per_frame_masks is not None:
# Broadcast a single-mask bundle to per-bbox: when the user supplied one # One mask but multiple bboxes per frame → each bbox gets the same mask.
# mask but multiple bboxes per frame, each bbox gets the same mask.
flat_masks = [] flat_masks = []
for f in range(N): for f in range(N):
mf = per_frame_masks[f] mf = per_frame_masks[f]