mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Big cleanup
This commit is contained in:
parent
f1be65f914
commit
ecbaefd8fc
@ -4,25 +4,15 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
|
||||
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
|
||||
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask
|
||||
|
||||
from ..model.transformer import MLP
|
||||
|
||||
|
||||
class MHRHead(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
mhr_rig,
|
||||
mlp_depth: int = 1,
|
||||
extra_joint_regressor: str = "",
|
||||
mlp_channel_div_factor: int = 8,
|
||||
enable_hand_model=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
# Store the shared MHRRig as a non-registered Python attribute
|
||||
object.__setattr__(self, "mhr", mhr_rig)
|
||||
@ -48,9 +38,7 @@ class MHRHead(nn.Module):
|
||||
hidden_dim=input_dim // mlp_channel_div_factor,
|
||||
output_dim=self.npose,
|
||||
num_layers=mlp_depth,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
# MHR Parameters
|
||||
@ -75,28 +63,25 @@ class MHRHead(nn.Module):
|
||||
self.local_to_world_wrist = _p(3, 3)
|
||||
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
|
||||
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
|
||||
# Optional — loaded from the .safetensors if present, otherwise the
|
||||
# render path falls back to a coarse geometric approximation.
|
||||
self.register_buffer(
|
||||
"face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
|
||||
)
|
||||
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
|
||||
|
||||
def canonical_vertices(self, device=None):
|
||||
def canonical_vertices(self):
|
||||
"""Return the T-pose vertices for the mean shape (scaled to meters).
|
||||
|
||||
Runs MHR with zero pose / shape / scale / expression so the returned
|
||||
mesh is the canonical rest pose — fixed per-model
|
||||
"""
|
||||
dev = device or self.scale_mean.device
|
||||
dt = self.scale_mean.dtype
|
||||
device = self.scale_mean.device
|
||||
dtype = self.scale_mean.dtype
|
||||
B = 1
|
||||
global_trans = torch.zeros(B, 3, device=dev, dtype=dt)
|
||||
global_rot = torch.zeros(B, 3, device=dev, dtype=dt)
|
||||
body_pose = torch.zeros(B, 130, device=dev, dtype=dt)
|
||||
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt)
|
||||
scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt)
|
||||
shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt)
|
||||
expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt)
|
||||
global_trans = torch.zeros(B, 3, device=device, dtype=dtype)
|
||||
global_rot = torch.zeros(B, 3, device=device, dtype=dtype)
|
||||
body_pose = torch.zeros(B, 130, device=device, dtype=dtype)
|
||||
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype)
|
||||
scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype)
|
||||
shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype)
|
||||
expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype)
|
||||
|
||||
verts = self.mhr_forward(
|
||||
global_trans=global_trans,
|
||||
global_rot=global_rot,
|
||||
@ -108,20 +93,6 @@ class MHRHead(nn.Module):
|
||||
) # single-tensor shape (1, N_v, 3) in meters
|
||||
return verts[0]
|
||||
|
||||
def get_zero_pose_init(self, factor=1.0):
    """Build the initial pose-token value that encodes the zero pose.

    The zero pose is not an all-zeros vector in the continuous
    representation, so the first ``6 + self.body_cont_dim`` entries are
    filled with the continuous encoding of a zero pose. Every remaining
    entry stays zero.

    Args:
        factor: Multiplier applied to the body-pose part only; the six
            leading global-rotation values are left unscaled.

    Returns:
        Tensor of shape ``(1, self.npose)``.
    """
    # Leading six values. NOTE(review): presumably the 6D encoding of the
    # identity global rotation -- confirm against rot6d_to_rotmat.
    global_rot_6d = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32)
    # Continuous encoding of an all-zero 133-dim body parameter vector.
    zero_body_cont = compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()

    weights = torch.zeros(1, self.npose)
    weights[:, : 6 + self.body_cont_dim] = torch.cat(
        [global_rot_6d, zero_body_cont * factor], dim=0
    )
    return weights
|
||||
|
||||
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
|
||||
assert full_pose_params.shape[1] == 136
|
||||
|
||||
@ -159,12 +130,9 @@ class MHRHead(nn.Module):
|
||||
shape_params,
|
||||
expr_params=None,
|
||||
return_keypoints=False,
|
||||
do_pcblend=True,
|
||||
return_joint_coords=False,
|
||||
return_model_params=False,
|
||||
return_joint_rotations=False,
|
||||
scale_offsets=None,
|
||||
vertex_offsets=None,
|
||||
):
|
||||
# Align everything to the static buffers
|
||||
dt = self.scale_mean.dtype
|
||||
@ -206,14 +174,10 @@ class MHRHead(nn.Module):
|
||||
shape_params = shape_params[None]
|
||||
# Convert scale...
|
||||
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
|
||||
if scale_offsets is not None:
|
||||
scales = scales + scale_offsets
|
||||
|
||||
# Now, figure out the pose.
|
||||
## 10 here is because it's more stable to optimize global translation in meters.
|
||||
full_pose_params = torch.cat(
|
||||
[global_trans * 10, global_rot, body_pose_params], dim=1
|
||||
) # B x 127
|
||||
full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127
|
||||
## Put in hands
|
||||
if hand_pose_params is not None:
|
||||
full_pose_params = self.replace_hands_in_pose(
|
||||
@ -268,14 +232,7 @@ class MHRHead(nn.Module):
|
||||
else:
|
||||
return tuple(to_return)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
init_estimate: Optional[torch.Tensor] = None,
|
||||
do_pcblend=True,
|
||||
slim_keypoints=False,
|
||||
intermediate: bool = False,
|
||||
):
|
||||
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False):
|
||||
"""
|
||||
Args:
|
||||
x: pose token with shape [B, C], usually C=DECODER.DIM
|
||||
@ -331,7 +288,6 @@ class MHRHead(nn.Module):
|
||||
scale_params=pred_scale,
|
||||
shape_params=pred_shape,
|
||||
expr_params=pred_face,
|
||||
do_pcblend=do_pcblend,
|
||||
return_keypoints=True,
|
||||
return_joint_coords=True,
|
||||
return_model_params=True,
|
||||
@ -356,7 +312,7 @@ class MHRHead(nn.Module):
|
||||
# Head-MLP outputs are promoted to fp32 here so the external
|
||||
# pose_output["mhr"] contract has a stable dtype regardless of what
|
||||
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
|
||||
# already fp32 from MHR's math layers; the cast on them is a no-op.
|
||||
# already fp32 from MHR's math layers.
|
||||
output = {
|
||||
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
|
||||
"pred_pose_rotmat": None,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Adapted from facebookresearch/MHR (Apache 2.0):
|
||||
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
|
||||
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
|
||||
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
|
||||
# verbatim from the upstream mhr_model.pt
|
||||
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
|
||||
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
|
||||
@ -52,7 +52,7 @@ def _skel_multiply(s1, s2):
|
||||
|
||||
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
|
||||
before composition. With many FK levels the previously-normalized quats
|
||||
drift in ULPs; the JIT renormalizes defensively, so we do too to stay
|
||||
drift in ULPs; upstream renormalizes defensively, so we do too to stay
|
||||
bit-close to its outputs.
|
||||
"""
|
||||
t1, sc1 = s1[..., :3], s1[..., 7:8]
|
||||
@ -78,7 +78,7 @@ def _skel_transform_points(skel_state, points):
|
||||
|
||||
|
||||
def _global_skel_state_from_local(local, pmi_levels):
|
||||
"""FK walk in fp64 (matches the JIT's use_double_precision=True path).
|
||||
"""FK walk in fp64 (matches upstream's use_double_precision=True path).
|
||||
|
||||
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
|
||||
one per BFS level. Avoids per-call torch.split + tolist() sync.
|
||||
@ -95,7 +95,7 @@ def _global_skel_state_from_local(local, pmi_levels):
|
||||
class MHRRig(nn.Module):
|
||||
"""Plain-PyTorch reimplementation of Meta's MHR rig.
|
||||
|
||||
All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's
|
||||
All math runs in fp32 (FK upcast to fp64 internally, matching upstream's
|
||||
use_double_precision=True backend) regardless of the host model's dtype.
|
||||
"""
|
||||
|
||||
@ -110,13 +110,11 @@ class MHRRig(nn.Module):
|
||||
POSE_CORR_HIDDEN = 3000
|
||||
POSE_CORR_SPARSE_NNZ = 53136
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
def __init__(self, device=None):
|
||||
super().__init__()
|
||||
del dtype, operations
|
||||
f32 = torch.float32
|
||||
|
||||
# All buffers are populated by load_state_dict from the `mhr.*` keys
|
||||
def _p(*shape, dtype=f32):
|
||||
def _p(*shape, dtype=torch.float32):
|
||||
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
|
||||
def _b(name, *shape, dtype):
|
||||
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
|
||||
@ -147,10 +145,10 @@ class MHRRig(nn.Module):
|
||||
self._pmi_levels_cache = None
|
||||
|
||||
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
|
||||
f32 = self.base_shape.dtype
|
||||
identity_coeffs = identity_coeffs.to(f32)
|
||||
model_parameters = model_parameters.to(f32)
|
||||
expr_coeffs = expr_coeffs.to(f32)
|
||||
dtype = self.base_shape.dtype
|
||||
identity_coeffs = identity_coeffs.to(dtype)
|
||||
model_parameters = model_parameters.to(dtype)
|
||||
expr_coeffs = expr_coeffs.to(dtype)
|
||||
B = identity_coeffs.shape[0]
|
||||
|
||||
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
|
||||
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
|
||||
# (batch6DFromXYZ, batchXYZfrom6D) are the continuity
|
||||
# representation from Zhou et al., "On the Continuity of Rotation
|
||||
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
|
||||
# implementations from papagina/RotationContinuity:
|
||||
@ -158,18 +158,10 @@ def _hand_masks(device):
|
||||
m = _HAND_MASK_CACHE.get(device)
|
||||
if m is not None:
|
||||
return m
|
||||
mask_cont_threedofs = torch.cat(
|
||||
[torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
|
||||
).to(device)
|
||||
mask_cont_onedofs = torch.cat(
|
||||
[torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
|
||||
).to(device)
|
||||
mask_model_params_threedofs = torch.cat(
|
||||
[torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
|
||||
).to(device)
|
||||
mask_model_params_onedofs = torch.cat(
|
||||
[torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
|
||||
).to(device)
|
||||
mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
|
||||
mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
|
||||
mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
|
||||
mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
|
||||
m = dict(
|
||||
mask_cont_threedofs=mask_cont_threedofs,
|
||||
mask_cont_onedofs=mask_cont_onedofs,
|
||||
@ -182,7 +174,6 @@ def _hand_masks(device):
|
||||
|
||||
def compact_cont_to_model_params_hand(hand_cont):
|
||||
# These are ordered by joint, not model params ^^
|
||||
assert hand_cont.shape[-1] == 54
|
||||
m = _hand_masks(hand_cont.device)
|
||||
mask_cont_threedofs = m["mask_cont_threedofs"]
|
||||
mask_cont_onedofs = m["mask_cont_onedofs"]
|
||||
@ -209,120 +200,6 @@ def compact_cont_to_model_params_hand(hand_cont):
|
||||
return hand_model_params
|
||||
|
||||
|
||||
def compact_model_params_to_cont_hand(hand_model_params):
    """Convert one hand's model parameters (``... x 27``) to continuous form (``... x 54``).

    Each 3-DoF joint is passed through ``batch6DFromXYZ`` (3 values -> 6).
    Each remaining parameter (1-DoF joints, and the 2-DoF joint treated as
    two independent parameters) becomes a ``(sin, cos)`` pair. The output is
    ordered by joint, not by model parameter.

    Args:
        hand_model_params: Tensor whose last dimension is 27.

    Returns:
        Tensor with the same leading dimensions, dtype and device as the
        input, and a last dimension of 54.
    """
    assert hand_model_params.shape[-1] == 27
    # Degrees of freedom of every hand joint, in joint order.
    hand_dofs_in_order = (3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1)
    assert sum(hand_dofs_in_order) == 27

    device = hand_model_params.device

    def _joint_mask(slots_per_dof, dofs):
        # One flag per slot: True where the owning joint's DoF count is in `dofs`.
        flags = [k in dofs for k in hand_dofs_in_order for _ in range(slots_per_dof * k)]
        return torch.tensor(flags, dtype=torch.bool, device=device)

    # Masks into the 54-wide continuous vector (two slots per DoF) ...
    mask_cont_threedofs = _joint_mask(2, (3,))
    mask_cont_onedofs = _joint_mask(2, (1, 2))
    # ... and into the 27-wide model-parameter vector (one slot per DoF).
    mask_model_params_threedofs = _joint_mask(1, (3,))
    mask_model_params_onedofs = _joint_mask(1, (1, 2))

    # 3-DoF joints: group the angles in triples and convert to 6D.
    params_threedofs = hand_model_params[..., mask_model_params_threedofs].unflatten(-1, (-1, 3))
    hand_cont_threedofs = batch6DFromXYZ(params_threedofs).flatten(-2, -1)

    # 1-DoF parameters (2-DoF joint included): interleaved (sin, cos) pairs.
    params_onedofs = hand_model_params[..., mask_model_params_onedofs]
    hand_cont_onedofs = torch.stack(
        [params_onedofs.sin(), params_onedofs.cos()], dim=-1
    ).flatten(-2, -1)

    # Assemble the 54-dim continuous vector, ordered by joint.
    hand_cont = hand_model_params.new_zeros((*hand_model_params.shape[:-1], 54))
    hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
    hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
    return hand_cont
|
||||
|
||||
|
||||
def batch9Dfrom6D(poses):
    """Expand a 6D rotation encoding into a flattened 3x3 matrix.

    Args:
        poses: Tensor of shape ``... x 6``; the first three values are the
            raw first column and the last three the raw second column.

    Returns:
        Tensor of shape ``... x 9``: the row-major flattening of the matrix
        whose columns are the orthonormalised ``x``, ``y`` and ``z`` axes.
    """
    x_raw = poses[..., :3]
    y_raw = poses[..., 3:]

    # Gram-Schmidt style construction: normalise x, derive z orthogonal to
    # x and the raw y, then rebuild y so the three axes are orthonormal.
    x = F.normalize(x_raw, dim=-1)
    z = F.normalize(torch.cross(x, y_raw, dim=-1), dim=-1)
    y = torch.cross(z, x, dim=-1)

    # Stack the axes as columns (... x 3 x 3), then flatten to ... x 9.
    return torch.stack([x, y, z], dim=-1).flatten(-2, -1)
|
||||
|
||||
|
||||
def batch4Dfrom2D(poses):
    """Expand a (sin, cos) pair into a flattened 2x2 rotation block.

    Args:
        poses: Tensor of shape ``... x 2`` holding ``(sin, cos)``; the pair
            is L2-normalised before use.

    Returns:
        Tensor of shape ``... x 4`` laid out as ``[cos, sin, -sin, cos]``.
    """
    poses_norm = F.normalize(poses, dim=-1)
    sin = poses_norm[..., 0]
    cos = poses_norm[..., 1]
    # Flattened SO(2) element.
    return torch.stack([cos, sin, -sin, cos], dim=-1)
|
||||
|
||||
|
||||
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
    """Convert the continuous body-pose vector into flattened rotation blocks.

    The input is laid out as ``[3-DoF joints as 6D | 1-DoF angles as
    (sin, cos) | 1-DoF translations]``. The 6D groups are expanded with
    ``batch9Dfrom6D`` (6 -> 9 values), the sin/cos pairs with
    ``batch4Dfrom2D`` (2 -> 4 values), and the translations are passed
    through unchanged.

    Args:
        body_pose_cont: Tensor whose last dimension equals
            ``2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans``.
        inflate_trans: Reserved option; only ``False`` is implemented.

    Returns:
        Tensor ``[rotmat_3dofs | rotmat_1dofs | trans]`` concatenated along
        the last dimension.

    Raises:
        NotImplementedError: If ``inflate_trans`` is true.
    """
    # fmt: off
    # Model-parameter index tables; only their lengths are needed here.
    all_param_3dof_rot_idxs = ((0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132))
    all_param_1dof_rot_idxs = (1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123)
    all_param_1dof_trans_idxs = (124, 125, 126, 127, 128, 129)
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    if inflate_trans:
        # An explicit raise survives `python -O`; an `assert False` would be
        # stripped and the function would then fail on an unbound local.
        raise NotImplementedError(
            "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
        )

    # Split the input into its three sections.
    end_3dofs = 2 * num_3dof_angles
    end_1dofs = end_3dofs + 2 * num_1dof_angles
    body_cont_3dofs = body_pose_cont[..., :end_3dofs]
    body_cont_1dofs = body_pose_cont[..., end_3dofs:end_1dofs]
    body_cont_trans = body_pose_cont[..., end_1dofs:]

    # 3-DoF joints: one 6D group per joint -> flattened 3x3.
    body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs.unflatten(-1, (-1, 6))).flatten(-2, -1)
    # 1-DoF angles: one (sin, cos) pair per angle -> flattened 2x2.
    body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs.unflatten(-1, (-1, 2))).flatten(-2, -1)
    # Translations need no conversion.
    body_rotmat_trans = body_cont_trans

    return torch.cat([body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1)
|
||||
|
||||
|
||||
_BODY_IDX_CACHE: dict = {}
|
||||
|
||||
|
||||
@ -349,8 +226,6 @@ def compact_cont_to_model_params_body(body_pose_cont):
|
||||
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
|
||||
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
||||
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
||||
num_1dof_trans = len(all_param_1dof_trans_idxs)
|
||||
assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
|
||||
# Get subsets
|
||||
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
|
||||
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
|
||||
@ -372,42 +247,10 @@ def compact_cont_to_model_params_body(body_pose_cont):
|
||||
return body_pose_params
|
||||
|
||||
|
||||
def compact_model_params_to_cont_body(body_pose_params):
    """Convert body model parameters to the continuous body-pose vector.

    The parameters are gathered by index into three groups and encoded as
    ``[3-DoF joints via batch6DFromXYZ | 1-DoF angles as (sin, cos) |
    1-DoF translations unchanged]``, concatenated along the last dimension.

    Args:
        body_pose_params: Tensor whose last dimension equals
            ``num_3dof_angles + num_1dof_angles + num_1dof_trans``.

    Returns:
        Tensor with the same leading dimensions and a last dimension of
        ``2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans``.
    """
    # fmt: off
    all_param_3dof_rot_idxs = ((0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132))
    all_param_1dof_rot_idxs = (1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123)
    all_param_1dof_trans_idxs = (124, 125, 126, 127, 128, 129)
    # fmt: on
    num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
    num_1dof_angles = len(all_param_1dof_rot_idxs)
    num_1dof_trans = len(all_param_1dof_trans_idxs)
    assert body_pose_params.shape[-1] == (
        num_3dof_angles + num_1dof_angles + num_1dof_trans
    )

    # Build the gather indices on the input's device.
    device = body_pose_params.device
    idxs_3dof = torch.tensor(all_param_3dof_rot_idxs, dtype=torch.long, device=device).flatten()
    idxs_1dof = torch.tensor(all_param_1dof_rot_idxs, dtype=torch.long, device=device)
    idxs_trans = torch.tensor(all_param_1dof_trans_idxs, dtype=torch.long, device=device)

    # Pull each parameter group out of the packed vector.
    body_params_3dofs = body_pose_params[..., idxs_3dof]
    body_params_1dofs = body_pose_params[..., idxs_1dof]
    body_params_trans = body_pose_params[..., idxs_trans]

    # 3-DoF joints: triples of angles -> 6D, flattened back to one axis.
    body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(-2, -1)
    # 1-DoF angles: interleaved (sin, cos) pairs.
    body_cont_1dofs = torch.stack(
        [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
    ).flatten(-2, -1)
    # Translations are already continuous.
    body_cont_trans = body_params_trans

    return torch.cat([body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1)
|
||||
|
||||
|
||||
# fmt: off
|
||||
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
|
||||
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
|
||||
# Hand indices into the 133-dim param and 260-dim cont body-pose vectors.
|
||||
mhr_param_hand_idxs = list(range(62, 116))
|
||||
mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))
|
||||
mhr_param_hand_mask = torch.zeros(133).bool()
|
||||
mhr_param_hand_mask[mhr_param_hand_idxs] = True
|
||||
mhr_cont_hand_mask = torch.zeros(260).bool()
|
||||
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
|
||||
# fmt: on
|
||||
|
||||
@ -43,15 +43,6 @@ class FourierPositionEncoding(nn.Module):
|
||||
self.num_bands = num_bands
|
||||
self.max_resolution = [max_resolution] * n
|
||||
|
||||
@property
def channels(self):
    """Number of values in the encoding for one position.

    For each entry of ``self.max_resolution`` there are ``self.num_bands``
    bands, each counted twice (sin and cos), plus one extra value per entry
    for the concatenated term.
    """
    num_dims = len(self.max_resolution)
    return num_dims * (2 * self.num_bands + 1)
|
||||
|
||||
def forward(self, pos: torch.Tensor):
|
||||
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
|
||||
return fourier_pos_enc
|
||||
@ -118,9 +109,7 @@ class PerspectiveHead(nn.Module):
|
||||
pred_cam: torch.Tensor,
|
||||
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
|
||||
bbox_size: torch.Tensor, # [N,], in original image space
|
||||
img_size: torch.Tensor,
|
||||
cam_int: torch.Tensor, # [B, 3, 3]
|
||||
use_intrin_center: bool = False,
|
||||
):
|
||||
batch_size = points_3d.shape[0]
|
||||
pred_cam = pred_cam.clone()
|
||||
@ -133,12 +122,8 @@ class PerspectiveHead(nn.Module):
|
||||
focal_length = cam_int[:, 0, 0]
|
||||
tz = 2 * focal_length / bs
|
||||
|
||||
if not use_intrin_center:
|
||||
cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
|
||||
cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
|
||||
else:
|
||||
cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
|
||||
cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
|
||||
cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs
|
||||
cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs
|
||||
|
||||
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
|
||||
|
||||
|
||||
@ -37,20 +37,15 @@ class SAM3DBody(nn.Module):
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
# `operations` falls back to torch.nn so the model is constructible
|
||||
# without comfy.ops; matches the pattern in comfy/ldm/sam3/.
|
||||
ops = operations if operations is not None else nn
|
||||
|
||||
# Per-batch state populated by `_initialize_batch`.
|
||||
self._max_num_person = None
|
||||
self._person_valid = None
|
||||
|
||||
self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False)
|
||||
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
|
||||
|
||||
self.image_size = IMAGE_SIZE
|
||||
|
||||
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops)
|
||||
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
|
||||
embed_dims = self.backbone.embed_dims
|
||||
|
||||
# MHR rig shared between body + hand pose heads via a non-registered
|
||||
@ -72,7 +67,7 @@ class SAM3DBody(nn.Module):
|
||||
self.head_pose.hand_pose_comps.data = (
|
||||
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
||||
)
|
||||
self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
||||
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
||||
|
||||
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
|
||||
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
|
||||
@ -81,7 +76,7 @@ class SAM3DBody(nn.Module):
|
||||
self.head_pose_hand.hand_pose_comps.data = (
|
||||
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
||||
)
|
||||
self.init_pose_hand = ops.Embedding(
|
||||
self.init_pose_hand = operations.Embedding(
|
||||
1, self.head_pose_hand.npose, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
@ -93,25 +88,25 @@ class SAM3DBody(nn.Module):
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.head_camera = PerspectiveHead(**camera_kwargs)
|
||||
self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
|
||||
self.init_camera = operations.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
|
||||
|
||||
self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs)
|
||||
self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
|
||||
self.init_camera_hand = operations.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
|
||||
|
||||
cond_dim = 3
|
||||
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
|
||||
linear_kwargs = dict(device=device, dtype=dtype)
|
||||
self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.init_to_token_mhr = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.prev_to_token_mhr = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.init_to_token_mhr_hand = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||
self.prev_to_token_mhr_hand = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||
|
||||
self.prompt_encoder = PromptEncoder(
|
||||
embed_dim=embed_dims, # match backbone dims so PE adds directly
|
||||
num_body_joints=N_KEYPOINTS,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
self.prompt_to_token = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
|
||||
decoder_kwargs = dict(
|
||||
dims=DECODER_DIM,
|
||||
@ -141,11 +136,10 @@ class SAM3DBody(nn.Module):
|
||||
|
||||
self.keypoint_embedding_idxs = list(range(N_KEYPOINTS))
|
||||
self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
||||
self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
|
||||
self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs)
|
||||
self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs)
|
||||
self.hand_box_embedding = operations.Embedding(2, DECODER_DIM, **linear_kwargs)
|
||||
self.bbox_embed = MLP(
|
||||
input_dim=DECODER_DIM, hidden_dim=DECODER_DIM,
|
||||
output_dim=4, num_layers=3,
|
||||
@ -158,13 +152,13 @@ class SAM3DBody(nn.Module):
|
||||
)
|
||||
self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs)
|
||||
self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs)
|
||||
self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_feat_linear = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint_feat_linear_hand = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||
|
||||
self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS))
|
||||
self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
||||
self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint3d_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint3d_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||
self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs)
|
||||
self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs)
|
||||
|
||||
@ -183,11 +177,9 @@ class SAM3DBody(nn.Module):
|
||||
def _initialize_batch(self, batch: Dict) -> None:
|
||||
if batch["img"].dim() == 5:
|
||||
self._batch_size, self._max_num_person = batch["img"].shape[:2]
|
||||
self._person_valid = self._flatten_person(batch["person_valid"]) > 0
|
||||
else:
|
||||
self._batch_size = batch["img"].shape[0]
|
||||
self._max_num_person = 0
|
||||
self._person_valid = None
|
||||
|
||||
def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert self._max_num_person is not None, "No max_num_person initialized"
|
||||
@ -258,11 +250,9 @@ class SAM3DBody(nn.Module):
|
||||
if is_multi_image:
|
||||
assert isinstance(img, list)
|
||||
n = len(img)
|
||||
H_src, W_src = img[0].shape[:2]
|
||||
src_t = torch.stack(list(img), dim=0)
|
||||
else:
|
||||
n = int(left_xyxy.shape[0])
|
||||
H_src, W_src = img.shape[:2]
|
||||
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
|
||||
|
||||
H_out, W_out = int(self.image_size[0]), int(self.image_size[1])
|
||||
@ -292,14 +282,12 @@ class SAM3DBody(nn.Module):
|
||||
zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device)
|
||||
person_valid = torch.ones((1, n), dtype=torch.float32, device=device)
|
||||
img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous()
|
||||
ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous()
|
||||
cam_int_dev = cam_int.to(device).to(dtype=torch.float32)
|
||||
|
||||
def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy):
|
||||
return {
|
||||
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
||||
"img_size": img_size.unsqueeze(0),
|
||||
"ori_img_size": ori_img_size.unsqueeze(0),
|
||||
"bbox_center": centers_t.to(device).unsqueeze(0),
|
||||
"bbox_scale": scales_t.to(device).unsqueeze(0),
|
||||
"bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0),
|
||||
@ -349,7 +337,6 @@ class SAM3DBody(nn.Module):
|
||||
self,
|
||||
branch: str,
|
||||
image_embeddings: torch.Tensor,
|
||||
init_estimate: Optional[torch.Tensor] = None,
|
||||
keypoints: Optional[torch.Tensor] = None,
|
||||
prev_estimate: Optional[torch.Tensor] = None,
|
||||
condition_info: Optional[torch.Tensor] = None,
|
||||
@ -359,7 +346,6 @@ class SAM3DBody(nn.Module):
|
||||
of the pipeline is shared.
|
||||
|
||||
image_embeddings: (B, C, H, W) backbone features.
|
||||
init_estimate: (B, 1, C) initial pose+cam estimate to refine.
|
||||
keypoints: (B, N, 3) prompts as (x, y in [0, 1], label).
|
||||
label: 0..K = joint, -1 = incorrect, -2 = invalid.
|
||||
prev_estimate: (B, 1, C) previous estimate for pose refinement.
|
||||
@ -402,15 +388,11 @@ class SAM3DBody(nn.Module):
|
||||
|
||||
# .to(image_embeddings) moves weights CPU→GPU under dynamic loading
|
||||
# (they stay on CPU until first use).
|
||||
if init_estimate is None:
|
||||
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
|
||||
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
|
||||
|
||||
init_input = (
|
||||
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
||||
if condition_info is not None else init_estimate
|
||||
)
|
||||
init_input = torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
||||
token_embeddings = init_to_token(init_input).view(batch_size, 1, -1)
|
||||
num_pose_token = token_embeddings.shape[1] # always 1
|
||||
|
||||
@ -495,9 +477,8 @@ class SAM3DBody(nn.Module):
|
||||
|
||||
def _get_mask_prompt(self, batch, image_embeddings):
|
||||
x_mask = self._flatten_person(batch["mask"])
|
||||
# batch tensors are fp32 from prepare_batch; mask_downscaling is in the
|
||||
# Loader's dtype — cast once so the conv input matches.
|
||||
x_mask = x_mask.to(image_embeddings.dtype)
|
||||
|
||||
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
|
||||
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
|
||||
)
|
||||
@ -546,7 +527,6 @@ class SAM3DBody(nn.Module):
|
||||
# expand+contiguous for the vertices branch.
|
||||
bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
|
||||
bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
|
||||
ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx]
|
||||
cam_int = self._flatten_person(
|
||||
batch["cam_int"]
|
||||
.unsqueeze(1)
|
||||
@ -556,8 +536,7 @@ class SAM3DBody(nn.Module):
|
||||
|
||||
def _project(points_3d):
|
||||
return head_camera.perspective_projection(
|
||||
points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int,
|
||||
use_intrin_center=True,
|
||||
points_3d, pred_cam, bbox_center, bbox_scale, cam_int,
|
||||
)
|
||||
|
||||
cam_out = _project(pose_output["pred_keypoints_3d"])
|
||||
@ -632,7 +611,6 @@ class SAM3DBody(nn.Module):
|
||||
tokens_output, pose_output = self.forward_decoder(
|
||||
"body",
|
||||
image_embeddings[self.body_batch_idx],
|
||||
init_estimate=None,
|
||||
keypoints=keypoints_prompt[self.body_batch_idx],
|
||||
prev_estimate=None,
|
||||
condition_info=condition_info[self.body_batch_idx],
|
||||
@ -643,7 +621,6 @@ class SAM3DBody(nn.Module):
|
||||
tokens_output_hand, pose_output_hand = self.forward_decoder(
|
||||
"hand",
|
||||
image_embeddings[self.hand_batch_idx],
|
||||
init_estimate=None,
|
||||
keypoints=keypoints_prompt[self.hand_batch_idx],
|
||||
prev_estimate=None,
|
||||
condition_info=condition_info[self.hand_batch_idx],
|
||||
@ -661,10 +638,8 @@ class SAM3DBody(nn.Module):
|
||||
# match the head-MLP external contract (_get_hand_box would .float() anyway).
|
||||
if len(self.body_batch_idx):
|
||||
output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float()
|
||||
output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float()
|
||||
if len(self.hand_batch_idx):
|
||||
output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid()
|
||||
output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand)
|
||||
|
||||
return output
|
||||
|
||||
@ -715,10 +690,10 @@ class SAM3DBody(nn.Module):
|
||||
# Concat lhand+rhand along dim 0 so backbone+decoder run once on
|
||||
# (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass.
|
||||
batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand)
|
||||
saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid)
|
||||
saved_batch_state = (self._batch_size, self._max_num_person)
|
||||
self._initialize_batch(batch_hands)
|
||||
hands_output = self.forward_step(batch_hands, decoder_type="hand")
|
||||
self._batch_size, self._max_num_person, self._person_valid = saved_batch_state
|
||||
self._batch_size, self._max_num_person = saved_batch_state
|
||||
n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1]
|
||||
lhand_output, rhand_output = self._split_hand_output(hands_output, n_left)
|
||||
# Free the batched image_embeddings/condition_info (unused downstream);
|
||||
@ -808,9 +783,7 @@ class SAM3DBody(nn.Module):
|
||||
# to get an updated body pose estimation.
|
||||
self._set_active_branch("body")
|
||||
|
||||
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
||||
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
||||
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1
|
||||
# right_kps_full / left_kps_full already computed above (unchanged since).
|
||||
right_kps_crop = self._full_to_crop(batch, right_kps_full)
|
||||
left_kps_crop = self._full_to_crop(batch, left_kps_full)
|
||||
|
||||
@ -1030,7 +1003,6 @@ class SAM3DBody(nn.Module):
|
||||
_, pose_output = self.forward_decoder(
|
||||
"body",
|
||||
image_embeddings,
|
||||
init_estimate=None, # use the default init, not the prev estimate
|
||||
keypoints=keypoint_prompt,
|
||||
prev_estimate=prev_estimate,
|
||||
condition_info=condition_info,
|
||||
|
||||
@ -29,38 +29,37 @@ class PromptEncoder(nn.Module):
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
"""
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
self.embed_dim = embed_dim
|
||||
self.num_body_joints = num_body_joints
|
||||
|
||||
# Keypoint prompts
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
self.point_embeddings = nn.ModuleList(
|
||||
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
|
||||
[operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
|
||||
)
|
||||
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
|
||||
LN2d = LayerNorm2d_op(ops)
|
||||
LN2d = LayerNorm2d_op(operations)
|
||||
mask_in_chans = 256
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
|
||||
operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
|
||||
)
|
||||
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
|
||||
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
|
||||
@ -120,8 +119,7 @@ class PromptEncoder(nn.Module):
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(keypoints, boxes, masks)
|
||||
# Anchor device on the input prompts so we don't pull the offloaded
|
||||
# CPU embedding device under dynamic loading.
|
||||
|
||||
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
|
||||
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
|
||||
weight_dtype = self.invalid_point_embed.weight.dtype
|
||||
@ -136,23 +134,10 @@ class PromptEncoder(nn.Module):
|
||||
|
||||
return sparse_embeddings, sparse_masks
|
||||
|
||||
def get_mask_embeddings(
|
||||
self,
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
bs: int = 1,
|
||||
size: Tuple[int, int] = (16, 16), # [H, W]
|
||||
) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
# masks is always on the active device when present; fall back to the
|
||||
# downscaling Conv's weight device when it isn't (rare callers).
|
||||
ref = masks if masks is not None else next(self.mask_downscaling.parameters())
|
||||
no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, size[0], size[1]
|
||||
)
|
||||
if masks is not None:
|
||||
mask_embeddings = self.mask_downscaling(masks)
|
||||
else:
|
||||
mask_embeddings = no_mask_embeddings
|
||||
def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> torch.Tensor:
|
||||
"""Embeds mask inputs. Caller casts both outputs to its working dtype."""
|
||||
no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1])
|
||||
mask_embeddings = self.mask_downscaling(masks)
|
||||
return mask_embeddings, no_mask_embeddings
|
||||
|
||||
|
||||
@ -170,12 +155,9 @@ class PromptableDecoder(nn.Module):
|
||||
repeat_pe: bool = False,
|
||||
do_interm_preds: bool = False,
|
||||
keypoint_token_update: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
TransformerDecoderLayer(
|
||||
@ -193,7 +175,7 @@ class PromptableDecoder(nn.Module):
|
||||
for i in range(depth)
|
||||
)
|
||||
|
||||
self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
|
||||
self.do_interm_preds = do_interm_preds
|
||||
self.keypoint_token_update = keypoint_token_update
|
||||
|
||||
|
||||
@ -166,12 +166,10 @@ def prepare_batch(
|
||||
mask_score_t = torch.ones((n,), dtype=torch.float32)
|
||||
|
||||
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
|
||||
ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()
|
||||
|
||||
batch = {
|
||||
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
||||
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
|
||||
"ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2)
|
||||
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
|
||||
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
|
||||
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
"""BVH export for SAM 3D Body pose_data.
|
||||
|
||||
BVH stores explicit bone OFFSETs per joint, so any standard importer
|
||||
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
|
||||
directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0
|
||||
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
|
||||
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
|
||||
intrinsic Z-X-Y Euler degrees.
|
||||
BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct
|
||||
anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0
|
||||
(static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY
|
||||
rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -49,13 +47,10 @@ def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
|
||||
"""First child of the rig's world anchor so the static origin→body stick
|
||||
bone gets left out. Falls back to the first root joint.
|
||||
|
||||
MHR's joint 0 is a static world anchor whose single child is the pelvis, so
|
||||
skipping it is correct. External rigs (e.g. SOMA-77) whose root is already
|
||||
the articulated body root with multiple child chains must keep the root —
|
||||
descending into one child would drop the sibling limbs from the BVH."""
|
||||
"""First child of the rig's world anchor, dropping the origin→body stick.
|
||||
Falls back to the first root joint. External rigs whose root is already the
|
||||
articulated body root with multiple child chains keep the root — descending
|
||||
into one child would drop the sibling limbs."""
|
||||
NJ = parents.shape[0]
|
||||
world_anchors = [j for j in range(NJ)
|
||||
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
|
||||
@ -93,14 +88,11 @@ def build_bvh(
|
||||
track_index: int = -1,
|
||||
units: str = "cm",
|
||||
) -> bytes:
|
||||
"""Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.
|
||||
"""Build a BVH file from pose_data. Returns UTF-8 text bytes.
|
||||
|
||||
`model` may be None when pose_data carries a `_skeleton_override` (external
|
||||
rigs, e.g. Kimodo); the rig hierarchy/offsets/bind are read from the
|
||||
override instead of the MHR model.
|
||||
|
||||
`units` is "cm" (default, standard mocap convention) or "m". Affects the
|
||||
OFFSET and root-position values; rotations are independent of units.
|
||||
rigs); the rig hierarchy/offsets/bind come from the override. `units` is
|
||||
"cm" (default) or "m" — affects OFFSET/root-position, not rotations.
|
||||
"""
|
||||
if units not in ("cm", "m"):
|
||||
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
|
||||
@ -123,10 +115,8 @@ def build_bvh(
|
||||
body_root = _find_bvh_root(parents, is_external)
|
||||
children_map = _build_children_map(parents)
|
||||
|
||||
# Bone OFFSETs come from MHR's translation_offsets (joint position
|
||||
# relative to parent in parent's local-bind frame). For the BVH root,
|
||||
# we use its bind world position so the skeleton sits at the right
|
||||
# spot when imported.
|
||||
# Bone OFFSETs = translation_offsets (joint position relative to parent).
|
||||
# The BVH root uses its bind world position so the skeleton imports in place.
|
||||
bind_global = rig.bind_global_cm # (NJ, 8) cm
|
||||
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
|
||||
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
|
||||
@ -139,9 +129,8 @@ def build_bvh(
|
||||
_visit(c)
|
||||
_visit(body_root)
|
||||
|
||||
# Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
|
||||
# rather than re-running rig.forward, then derive locals with body_root
|
||||
# treated as the hierarchy root in BVH-space.
|
||||
# Stored pred_global_rots/pred_joint_coords (authoritative); derive locals
|
||||
# with body_root as the BVH-space hierarchy root.
|
||||
rig_global_m = global_skel_state_from_pose_data(
|
||||
pose_data, frame_indices, person_k, NJ,
|
||||
joint_coords_y_down=rig.per_frame_y_down,
|
||||
@ -203,9 +192,8 @@ def build_bvh(
|
||||
lines.append(f"Frames: {n_frames}")
|
||||
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
|
||||
|
||||
# Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
|
||||
# frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
|
||||
# formatting beats Python f-strings by ~10× on long clips.
|
||||
# Channel matrix per frame: root pos (3) + root rot (3) + non-root rots
|
||||
# (3 each), columns in `bvh_order`. savetxt is far faster than f-strings.
|
||||
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
|
||||
motion = np.concatenate([
|
||||
root_pos_m * unit_scale, # (N, 3)
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent
|
||||
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
|
||||
"""3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look).
|
||||
|
||||
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps),
|
||||
projected through the per-person camera (`pred_cam_t` + `focal_length` +
|
||||
image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose
|
||||
visual style. Self-contained: no dependency on the SCAIL-Pose package.
|
||||
|
||||
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
||||
Each limb is a true 3D capsule (cylinder + hemispherical caps), projected
|
||||
through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so
|
||||
closer limbs appear thicker/brighter. Self-contained analytic ray-capsule
|
||||
renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@ -41,14 +38,12 @@ def _build_specs_from_pose(
|
||||
palette: str,
|
||||
person_brightness_falloff: float = 0.0,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Flatten body + optional hand limbs for one frame into
|
||||
(starts, ends, colors_rgba, is_hand) in camera coords (Y-down, +Z forward).
|
||||
Drops endpoints that are non-finite or behind the camera. `is_hand` flags
|
||||
the hand limbs so the renderer can draw them thinner.
|
||||
"""Flatten body + optional hand limbs for one frame into (starts, ends,
|
||||
colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite
|
||||
or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner.
|
||||
|
||||
`person_brightness_falloff` mixes each per-person limb color toward white
|
||||
by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the
|
||||
mesh rasterizer and GLB exporters."""
|
||||
`person_brightness_falloff` mixes each per-person color toward white by
|
||||
`1 - falloff^k` for track k (track 0 stays vivid)."""
|
||||
starts: List[np.ndarray] = []
|
||||
ends: List[np.ndarray] = []
|
||||
colors: List[np.ndarray] = []
|
||||
@ -65,8 +60,7 @@ def _build_specs_from_pose(
|
||||
if body_op is None or cam_t is None:
|
||||
continue
|
||||
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
|
||||
# op-keypoints are camera frame (Y-down); add cam_t to place the
|
||||
# subject in front of the camera.
|
||||
# op-keypoints are camera frame; add cam_t to place the subject in front.
|
||||
body_kp = body_op + cam_t_np[None, :]
|
||||
|
||||
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
|
||||
@ -148,10 +142,9 @@ def _ray_capsule_t(
|
||||
ba_len: torch.Tensor, # (M,) segment length
|
||||
radius: torch.Tensor, # (M,) per-capsule radius
|
||||
) -> torch.Tensor:
|
||||
"""Closed-form ray-capsule intersection. Returns (K, M) tensor of ray
|
||||
parameters t to the nearest valid hit per capsule, +inf where the ray
|
||||
misses. A capsule is the union of (cylinder body, hemisphere at A,
|
||||
hemisphere at B); each component is a quadratic root-find."""
|
||||
"""Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest
|
||||
valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere
|
||||
at A, hemisphere at B), each a quadratic root-find."""
|
||||
INF = float("inf")
|
||||
r_sq = radius * radius # (M,)
|
||||
|
||||
@ -238,9 +231,8 @@ def _render_capsules_torch(
|
||||
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
|
||||
z_near = max(0.05, z_min - float(radius.max().item()))
|
||||
|
||||
# Union of per-capsule screen-space bboxes. Pixels outside this mask
|
||||
# provably can't hit any capsule, so the analytic intersection only runs
|
||||
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
|
||||
# Union of per-capsule screen-space bboxes — pixels outside can't hit any
|
||||
# capsule, so intersection only runs on the relevant subset of the canvas.
|
||||
sz = starts[:, 2].clamp(min=z_near)
|
||||
ez = ends[:, 2].clamp(min=z_near)
|
||||
sx_p = starts[:, 0] * fx / sz + cx
|
||||
@ -261,16 +253,13 @@ def _render_capsules_torch(
|
||||
if xmax_i > xmin_i and ymax_i > ymin_i:
|
||||
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
|
||||
|
||||
# Analytic ray-capsule intersection. One pass over the masked pixels —
|
||||
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
|
||||
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
|
||||
# Analytic ray-capsule intersection, one pass over the masked pixels.
|
||||
INF = float("inf")
|
||||
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
|
||||
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
|
||||
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
|
||||
if active_idx.numel() > 0:
|
||||
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory
|
||||
# manageable when both K (image pixels) and M (capsules) are large.
|
||||
# Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory.
|
||||
chunk_max = max(1, int(4_000_000 / max(M, 1)))
|
||||
for i0 in range(0, active_idx.numel(), chunk_max):
|
||||
sub = active_idx[i0 : i0 + chunk_max]
|
||||
@ -284,7 +273,7 @@ def _render_capsules_torch(
|
||||
flat_t[winners] = t_min[hit]
|
||||
flat_m_idx[winners] = m_idx[hit]
|
||||
|
||||
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade.
|
||||
# Shade via analytic normal (P - closest point on segment).
|
||||
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
|
||||
if background_rgb is not None:
|
||||
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
|
||||
@ -306,10 +295,10 @@ def _render_capsules_torch(
|
||||
|
||||
col = colors[m_h, :3]
|
||||
if flat_shade:
|
||||
# Solid per-limb color (OpenPose look) — no lighting/depth modulation.
|
||||
# Solid per-limb color (OpenPose look) — no lighting/depth.
|
||||
out[hit_idx] = col
|
||||
return out.view(H, W, 3).clamp(0.0, 1.0)
|
||||
# SCAIL Blinn-Phong (render_torch.py:290-331). Headlight: light = +Z.
|
||||
# SCAIL Blinn-Phong, headlight along +Z.
|
||||
diff = torch.clamp(-(normals[:, 2]), min=0.0)
|
||||
diffuse = 0.45 + 0.55 * diff
|
||||
|
||||
@ -319,7 +308,7 @@ def _render_capsules_torch(
|
||||
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
|
||||
|
||||
# Mild depth fade matches SCAIL's mm-scale ramp in our meter units.
|
||||
# Mild depth fade.
|
||||
z_vals = p_hit[:, 2]
|
||||
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
|
||||
if z_hi - z_lo > 1e-6:
|
||||
@ -351,21 +340,18 @@ def render_pose_data_capsules(
|
||||
hand_radius_scale: float = 0.4,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Render a frame's pose_data as 3D capsules projected through the per-
|
||||
person camera. Returns (H, W, 3) fp32 in [0, 1].
|
||||
"""Render a frame's pose_data as 3D capsules through the per-person camera.
|
||||
Returns (H, W, 3) fp32 in [0, 1].
|
||||
|
||||
`composite='over'` paints over `background` (black if None);
|
||||
`composite='mesh_only'` always uses a black canvas.
|
||||
|
||||
`radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
|
||||
Hand limbs use `radius_m * hand_radius_scale` (their bones are far shorter
|
||||
than body limbs). Camera fx/fy come from each person's `focal_length`.
|
||||
`composite='over'` paints over `background` (black if None); 'mesh_only'
|
||||
uses a black canvas. `radius_m` is in meters; hand limbs use
|
||||
`radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`.
|
||||
"""
|
||||
persons = pose_data["frames"][frame_idx]
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# SAM3DBody shares one camera across the clip — pick from the first valid person.
|
||||
# SAM3DBody shares one camera across the clip — use the first valid person.
|
||||
fx = fy = float(min(H, W))
|
||||
for person in persons:
|
||||
f = person.get("focal_length")
|
||||
|
||||
@ -1,16 +1,10 @@
|
||||
"""GLB export — OpenPose 18-keypoint visualization mode.
|
||||
|
||||
Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d`
|
||||
(the model's regressed surface keypoints). Each track becomes an armature
|
||||
with a sibling joint per keypoint; sphere markers + stick/capsule limbs are
|
||||
skinned to those joints.
|
||||
|
||||
Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and
|
||||
face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs)
|
||||
extend the same armature.
|
||||
|
||||
OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are
|
||||
imported below — they're also used by the 2D and 3D renderers in this package.
|
||||
Sourced from pose_data's `pred_keypoints_3d`, independent of the MHR rig. Each
|
||||
track becomes an armature with a joint per keypoint; sphere markers and limbs
|
||||
are skinned to those joints. Optional hands (`pred_keypoints_3d` 21..62) and
|
||||
face landmarks (`pred_vertices` at fixed vertex IDs) extend the same armature.
|
||||
Shared tables/palettes/mappings live in `glb_shared.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -55,9 +49,8 @@ def _finalize_skinned_mesh(
|
||||
joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray,
|
||||
smooth_shade: bool,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Apply smooth or flat shading to an indexed sphere/stick group mesh and
|
||||
pack per-vertex colors. Smooth keeps the indexed mesh + per-vertex colors;
|
||||
flat duplicates verts per face and gathers face-corner colors."""
|
||||
"""Shade a skinned group mesh and pack per-vertex colors. Smooth keeps the
|
||||
indexed mesh; flat duplicates verts per face and gathers face-corner colors."""
|
||||
if smooth_shade:
|
||||
v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights)
|
||||
return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32)
|
||||
@ -73,10 +66,8 @@ def _finalize_skinned_mesh(
|
||||
def _pair_colors_from_kp(
|
||||
pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1,
|
||||
) -> np.ndarray:
|
||||
"""Per-limb color = endpoint-vertex color from `kp_colors`. Default
|
||||
`endpoint=1` picks the second (distal) vertex of each pair, which is
|
||||
the OpenPose-canonical per-finger gradient when fingers go base→tip
|
||||
(wrist=0 → thumb1=1 → thumb2=2 …)."""
|
||||
"""Per-limb color from `kp_colors`. `endpoint=1` (default) picks the distal
|
||||
vertex of each pair — the OpenPose per-finger gradient for base→tip fingers."""
|
||||
n = len(pairs)
|
||||
out = np.zeros((n, 3), dtype=np.float32)
|
||||
for i, (a, b) in enumerate(pairs):
|
||||
@ -88,19 +79,13 @@ def _openpose_bind_at_rig_rest(
|
||||
pose_data: Dict[str, Any], *,
|
||||
include_hands: bool, face_vert_ids: Optional[np.ndarray],
|
||||
) -> Optional[np.ndarray]:
|
||||
"""OpenPose keypoint positions at the rig's REST pose (T-pose at authoring
|
||||
origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest
|
||||
TRS) and `rest_verts_m` (mesh rest verts for face landmarks).
|
||||
"""OpenPose keypoint positions at the rig's REST pose, from the override's
|
||||
`bind_global_m` (joint rest TRS) and `rest_verts_m` (face landmarks).
|
||||
|
||||
Used as the static-bind for openpose-mode geometry so the GLB's static
|
||||
POSITION attribute sits at rig origin — matching skeletal mode's bind and
|
||||
producing the same 'snap from rest to scene-frame-0' transition at the
|
||||
start of playback. Without this, the static geometry is at scene-frame-0
|
||||
(kp_seq[0]) and viewers that auto-fit on static POSITION will center on
|
||||
the scene location, hiding the per-frame motion.
|
||||
|
||||
Returns None when the override is missing or doesn't carry all the needed
|
||||
mappings — caller falls back to per-frame extraction (kp_seq[0])."""
|
||||
Used as the static-bind so the GLB's static POSITION sits at rig origin,
|
||||
matching skeletal mode and producing the same rest→scene-frame-0 transition.
|
||||
Returns None when the override lacks the needed mappings — caller then falls
|
||||
back to per-frame extraction (kp_seq[0])."""
|
||||
override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None
|
||||
if override is None or "bind_global_m" not in override:
|
||||
return None
|
||||
@ -141,19 +126,12 @@ def _openpose_bind_at_rig_rest(
|
||||
def _extract_openpose_keypoints(
|
||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||
) -> np.ndarray:
|
||||
"""(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres.
|
||||
"""(N, 18, 3) OpenPose keypoints in rig-native Y-up metres.
|
||||
|
||||
Two sources, in priority order:
|
||||
|
||||
1. **External-skeleton path** — when pose_data has `_skeleton_override`
|
||||
with `openpose18_joint_indices` ((18, 2) int32, see
|
||||
`_resolve_openpose_keypoints_from_joints`), synthesize from each
|
||||
person's `pred_joint_coords` directly. The override frame is already
|
||||
rig-native Y-up, so no axis flip.
|
||||
2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the
|
||||
first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18.
|
||||
Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we
|
||||
un-flip y/z to match rig-native Y-up.
|
||||
External-skeleton path: when the override carries `openpose18_joint_indices`
|
||||
((18, 2) int32), synthesize from each person's `pred_joint_coords` (already
|
||||
Y-up, no flip). MHR70 path (default): re-index `pred_keypoints_3d` to COCO-18
|
||||
and un-flip y/z (stored y-down by sam3d_body).
|
||||
"""
|
||||
frames = pose_data["frames"]
|
||||
N = len(frame_indices)
|
||||
@ -195,10 +173,8 @@ def _extract_openpose_keypoints(
|
||||
for t_idx, t in enumerate(frame_indices):
|
||||
person = frames[t][person_k]
|
||||
if "pred_keypoints_3d" not in person:
|
||||
# Diagnose the source: external-skeleton producers ship
|
||||
# `_skeleton_override` instead of MHR70 keypoints. If the
|
||||
# producer didn't populate `openpose18_joint_indices` either,
|
||||
# we can't synthesize the 18-keypoint set.
|
||||
# External-skeleton producer without `openpose18_joint_indices`:
|
||||
# can't synthesize the 18-keypoint set.
|
||||
if override is not None:
|
||||
raise ValueError(
|
||||
"build_glb_openpose: this pose_data carries "
|
||||
@ -229,15 +205,11 @@ def _extract_openpose_keypoints(
|
||||
def _extract_openpose_hand_keypoints(
|
||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||
) -> np.ndarray:
|
||||
"""(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native
|
||||
Y-up frame.
|
||||
"""(N, 42, 3) right+left OpenPose hand keypoints (21+21) in rig-native Y-up.
|
||||
|
||||
External-skeleton path: requires `openpose_hand21_r_joint_indices` AND
|
||||
`openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override.
|
||||
Resolved against per-frame `pred_joint_coords` like the body path.
|
||||
|
||||
MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21
|
||||
(wrist + 5 fingers, thumb→pinky, base→tip)."""
|
||||
External-skeleton path: needs `openpose_hand21_{r,l}_joint_indices` ((21, 2)
|
||||
int32) in the override, resolved against `pred_joint_coords`. MHR70 path:
|
||||
re-orders `pred_keypoints_3d` 21..62 to OpenPose-21 (wrist + 5 fingers)."""
|
||||
frames = pose_data["frames"]
|
||||
N = len(frame_indices)
|
||||
out = np.zeros((N, 42, 3), dtype=np.float32)
|
||||
@ -307,10 +279,8 @@ def _extract_face_landmarks_from_verts(
|
||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||
vert_ids: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices`
|
||||
at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up.
|
||||
Each landmark inherits per-frame shape/expr/pose deformation for free
|
||||
since `pred_vertices` already has it baked in."""
|
||||
"""(N, K_face, 3) face landmarks sampled from `pred_vertices` at the given
|
||||
vertex IDs, unflipped to Y-up. Per-frame deformation is already baked in."""
|
||||
frames = pose_data["frames"]
|
||||
N = len(frame_indices)
|
||||
K = int(vert_ids.shape[0])
|
||||
@ -335,18 +305,11 @@ def _build_openpose_spheres(
|
||||
smooth_shade: bool = False,
|
||||
joint_indices: Optional[np.ndarray] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's
|
||||
joint, vertex-colored from kp_colors. `base_joint_idx` is added to the
|
||||
emitted JOINTS_0 indices so callers can place this group at any offset
|
||||
in the shared skin (body=0, right hand=18, etc.). `joint_indices` (when
|
||||
given) overrides that with explicit per-sphere joint indices, so callers
|
||||
can skip keypoints (e.g. SCAIL head dots).
|
||||
|
||||
`smooth_shade=True` keeps the indexed mesh and writes per-vertex
|
||||
normals via face-normal averaging — round shading on the spheres.
|
||||
`smooth_shade=False` (default) flat-shades by duplicating verts per
|
||||
face, matching the existing OpenPose-mode look. Returns
|
||||
(verts, normals, faces, joints4, weights4, vert_colors)."""
|
||||
"""UV sphere per keypoint, rigidly skinned to that keypoint's joint and
|
||||
vertex-colored from kp_colors. `base_joint_idx` offsets the emitted JOINTS_0
|
||||
indices (body=0, right hand=18, …); `joint_indices`, if given, sets explicit
|
||||
per-sphere indices so callers can skip keypoints (e.g. SCAIL head dots).
|
||||
Returns (verts, normals, faces, joints4, weights4, vert_colors)."""
|
||||
sv, sf = uv_sphere_unit()
|
||||
K = bind_kp_m.shape[0]
|
||||
Nv = sv.shape[0]
|
||||
@ -376,43 +339,23 @@ def _capsule_mesh_local(
|
||||
end_width_frac: float = 0.3,
|
||||
shape: str = "ellipsoid",
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Build a per-limb mesh in limb-local frame along +Y from y=0 (head
|
||||
pole) to y=L (tail pole).
|
||||
"""Per-limb mesh in limb-local frame along +Y from y=0 (head) to y=L (tail).
|
||||
|
||||
`shape` selects the silhouette:
|
||||
- 'ellipsoid' (default): tips are small hemispheres of radius
|
||||
`W * end_width_frac`; body has ellipsoidal radius profile
|
||||
sin(π*u) from w_end at the junctions to W at the middle. Gives
|
||||
a fat-middle / narrow-end stretched-ellipse look.
|
||||
- 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant
|
||||
radius W with no hemisphere caps. Pair with sphere joint markers
|
||||
of the same radius so the spheres seamlessly cap the open
|
||||
cylinder ends (the cylinder cross-section ring at the endpoint
|
||||
lies exactly on the sphere surface). Drawing hemisphere caps
|
||||
inside the joint sphere creates a visible bump where the cap
|
||||
pokes out unevenly when sphere radius ≠ cap radius — open
|
||||
cylinders avoid that.
|
||||
`shape`:
|
||||
- 'ellipsoid' (default): hemisphere tips of radius `W * end_width_frac`,
|
||||
ellipsoidal sin(π·u) body profile (fat middle, narrow ends).
|
||||
- 'capsule': SCAIL "rig" limb — an OPEN cylinder of constant radius W,
|
||||
no caps. Pair with same-radius sphere markers so they cap the ends
|
||||
seamlessly (caps would bump out when sphere radius ≠ cap radius).
|
||||
|
||||
Per-limb mesh is required because the cap height (w_end) depends on
|
||||
the limb width — a single canonical mesh can't produce true
|
||||
hemispheres for arbitrary L:W ratios in ellipsoid mode.
|
||||
A per-limb mesh is needed because cap height depends on width — one
|
||||
canonical mesh can't give true hemispheres for arbitrary L:W in ellipsoid.
|
||||
|
||||
Returns:
|
||||
verts: (Nv, 3) float32 — limb-local positions in meters.
|
||||
faces: (Nf, 3) uint32 — triangle indices.
|
||||
weights: (Nv, 2) float32 — (head, tail) skinning weights, linearly
|
||||
interpolated by axial position (sums to 1).
|
||||
Returns (verts (Nv,3), faces (Nf,3), weights (Nv,2) head/tail, sums to 1).
|
||||
"""
|
||||
W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6))
|
||||
if str(shape) == "capsule":
|
||||
# SCAIL-style "rig" limb: an OPEN cylinder of constant radius W,
|
||||
# no hemisphere caps. The sphere joint markers at each endpoint
|
||||
# provide the rounded ends of the bone — when sphere_radius ==
|
||||
# cylinder_radius, the cylinder cross-section ring at the bone
|
||||
# endpoint lies exactly on the sphere surface, so silhouette is
|
||||
# seamless. Hemisphere caps would create a visible bump where
|
||||
# the cap pokes out of the sphere if cap_r ≠ marker_r, so we
|
||||
# omit them entirely.
|
||||
# Open cylinder, no caps — sphere markers cap the ends (see docstring).
|
||||
cap_r = 0.0
|
||||
body_r = W
|
||||
if n_cap_lat is None:
|
||||
@ -425,7 +368,7 @@ def _capsule_mesh_local(
|
||||
end_frac = float(min(0.95, max(0.05, end_width_frac)))
|
||||
cap_r = max(1e-7, W * end_frac)
|
||||
body_r = W
|
||||
# Ellipsoid defaults: more body rings to sample the sin(π·u) curve.
|
||||
# More body rings to sample the sin(π·u) curve.
|
||||
if n_cap_lat is None:
|
||||
n_cap_lat = 3
|
||||
if n_body is None:
|
||||
@ -473,10 +416,7 @@ def _capsule_mesh_local(
|
||||
phi = 2.0 * np.pi * k / n_lon
|
||||
verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))])
|
||||
|
||||
# Body intermediate rings (between the cap junctions for capped meshes,
|
||||
# between the two end rings for open cylinders). For 'capsule' mode
|
||||
# n_body=0 by default — no intermediate rings needed for a constant-
|
||||
# radius cylinder.
|
||||
# Body intermediate rings (none for 'capsule', n_body=0 by default).
|
||||
body_rings: List[int] = []
|
||||
is_ellipsoid = str(shape) == "ellipsoid"
|
||||
for j in range(1, n_body + 1):
|
||||
@ -572,11 +512,8 @@ def _scail_redirect_neck_stub(body_kp: np.ndarray) -> np.ndarray:
|
||||
def _openpose_limb_rest_trs(
|
||||
bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...],
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Per-limb rest TRS:
|
||||
midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b].
|
||||
rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate).
|
||||
Caller uses `midpoints` as each limb joint's rest translation (rotation =
|
||||
identity), and `rest_axes` to compute per-frame alignment rotations."""
|
||||
"""Per-limb rest TRS: midpoints (K_pairs, 3) and unit a→b axes (or +Y if
|
||||
degenerate). Caller uses midpoints as rest translation, axes for alignment."""
|
||||
K_pairs = len(pairs)
|
||||
mid = np.zeros((K_pairs, 3), dtype=np.float32)
|
||||
axis = np.zeros((K_pairs, 3), dtype=np.float32)
|
||||
@ -595,13 +532,10 @@ def _openpose_limb_rest_trs(
|
||||
def _openpose_limb_anim_trs(
|
||||
kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Per-frame limb TRS:
|
||||
anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]).
|
||||
anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest
|
||||
axis to its frame-t axis.
|
||||
Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t *
|
||||
T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to
|
||||
track the limb's current direction — no LBS cross-section thinning."""
|
||||
"""Per-frame limb TRS: anim_mid (N, K_pairs, 3) midpoints and anim_quat
|
||||
(N, K_pairs, 4 xyzw) aligning each limb's rest axis to its frame-t axis.
|
||||
Drives skin_matrix(t) = T(mid_t)·R_t·T(-mid_rest) — rigid rotation about
|
||||
the rest midpoint, no LBS cross-section thinning."""
|
||||
N = kp_seq.shape[0]
|
||||
K_pairs = len(pairs)
|
||||
anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32)
|
||||
@ -616,7 +550,7 @@ def _openpose_limb_anim_trs(
|
||||
n = float(np.linalg.norm(d))
|
||||
if n > 1e-9:
|
||||
R[t, k] = rotation_align(ax_rest, d / n)
|
||||
quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw
|
||||
quat = rotmat_to_quat_np(R).astype(np.float32)
|
||||
return anim_mid, quat
|
||||
|
||||
|
||||
@ -628,20 +562,14 @@ def _build_openpose_sticks(
|
||||
smooth_shade: bool = False,
|
||||
end_width_frac: float = 0.3,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Capsule (cylinder + hemispherical caps) per limb pair (a, b).
|
||||
"""Capsule per limb pair (a, b), each sized to its own length/width so caps
|
||||
are true hemispheres regardless of L:W. Ellipsoid mode auto-clamps width to
|
||||
`length * 0.1` so short limbs don't look chunky.
|
||||
|
||||
Each limb gets its own mesh sized to that limb's length and width so
|
||||
the caps are TRUE hemispheres of radius `half_width_eff` — the limb
|
||||
silhouette is rounded-rectangle-like, regardless of L:W ratio. Width
|
||||
auto-clamped to `length * 0.1` so short limbs (face/ear) don't look
|
||||
chunky next to long ones.
|
||||
|
||||
Skinning: rigid (weight=1) binding to a per-limb joint at
|
||||
`limb_joint_base_idx + limb_idx` — the caller animates that joint with
|
||||
midpoint translation + rest-to-current rotation so each capsule rotates
|
||||
rigidly with its limb (avoids translation-only LBS cross-section
|
||||
thinning). Returns flat-shaded (verts, normals, faces, joints4,
|
||||
weights4, vert_colors)."""
|
||||
Rigid (weight=1) binding to a per-limb joint at `limb_joint_base_idx +
|
||||
limb_idx`, which the caller animates with midpoint translation + rotation
|
||||
(avoids LBS thinning). Returns (verts, normals, faces, joints4, weights4,
|
||||
vert_colors)."""
|
||||
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||
|
||||
out_v_chunks: List[np.ndarray] = []
|
||||
@ -663,13 +591,10 @@ def _build_openpose_sticks(
|
||||
unit_dir = direction / length
|
||||
R = rotation_align(canonical, unit_dir)
|
||||
if is_capsule:
|
||||
# SCAIL-style uniform radius — every bone gets the same width.
|
||||
# `_capsule_mesh_local` clamps internally to L/2-eps so very
|
||||
# short bones don't go degenerate.
|
||||
# Uniform radius — every bone gets the same width (`_capsule_mesh_local` clamps it to just under L/2 so very short bones don't go degenerate).
|
||||
half_width_eff = max(MIN_WIDTH, half_width_m)
|
||||
else:
|
||||
# Ellipsoid mode: original auto-thinning so short face/ear
|
||||
# limbs don't look chunky next to long body limbs.
|
||||
# Auto-thin so short face/ear limbs aren't chunky next to body limbs.
|
||||
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
|
||||
|
||||
v_local, f_local, _weights_unused = _capsule_mesh_local(
|
||||
@ -678,10 +603,8 @@ def _build_openpose_sticks(
|
||||
v_world = v_local @ R.T + head
|
||||
Nv = v_local.shape[0]
|
||||
|
||||
# Rigid binding to the per-limb joint. The 2-bone (head, tail) weights
|
||||
# from `_capsule_mesh_local` are discarded — they're translation-only
|
||||
# under glTF LBS and don't rotate the cross-section, causing visible
|
||||
# thinning when the limb axis changes between rest and animated pose.
|
||||
# Rigid binding to the per-limb joint; the 2-bone weights are discarded
|
||||
# (translation-only under LBS, would thin the cross-section).
|
||||
j_arr = np.zeros((Nv, 4), dtype=np.uint16)
|
||||
j_arr[:, 0] = limb_idx + limb_joint_base_idx
|
||||
w_arr = np.zeros((Nv, 4), dtype=np.float32)
|
||||
@ -730,40 +653,24 @@ def build_glb_openpose(
|
||||
stick_end_width_frac: float = 0.6,
|
||||
bone_smooth_window: int = 0,
|
||||
) -> bytes:
|
||||
"""Build a GLB containing an OpenPose-style 3D skeleton — sphere markers
|
||||
per keypoint plus rainbow-colored sticks between standard limb pairs.
|
||||
Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig
|
||||
forward needed). Optional hand keypoints (also from `pred_keypoints_3d`)
|
||||
and face landmarks (sampled from `pred_vertices` at fixed head-mesh
|
||||
vertex IDs) extend the same per-track armature.
|
||||
"""Build a GLB of an OpenPose-style 3D skeleton — sphere markers per keypoint
|
||||
plus colored sticks between limb pairs, one armature per track. Body from
|
||||
`pred_keypoints_3d`; optional hands (same source) and face landmarks
|
||||
(`pred_vertices`) extend each armature.
|
||||
|
||||
Args:
|
||||
include_hands: append the standard 21+21 OpenPose hand keypoints to
|
||||
each track's armature (right hand at MHR70 indices 21..41,
|
||||
left at 42..62).
|
||||
hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 ×
|
||||
`marker_radius_m` (hand keypoints are anatomically smaller than
|
||||
body joints; matches DWPose's smaller hand dots).
|
||||
hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 ×
|
||||
`stick_radius_m`.
|
||||
hand_color_style: 'dwpose' (default) = solid-blue hand dots,
|
||||
rainbow per-finger sticks (controlnet_aux/dwpose convention);
|
||||
'openpose' = rainbow per-finger dots AND sticks (matches
|
||||
poseParameters.cpp::HAND_COLORS_RENDER).
|
||||
face_style: 'disabled' (default) | 'full' | 'eyes_mouth' — face
|
||||
landmarks sampled from `pred_vertices` at vertex IDs picked from
|
||||
`pose_data["canonical_colors"]["positions"]`. 'full' = all ~30
|
||||
contour points; 'eyes_mouth' = the eyes + outer-lip subset.
|
||||
face_marker_radius_m: per-face landmark sphere radius. 0 = auto =
|
||||
0.3 × `marker_radius_m` — face landmarks are densely packed
|
||||
around the eyes/mouth/jaw and need to be much smaller than
|
||||
body keypoints to keep the layout legible. Face landmarks are
|
||||
rendered as standalone dots (no contour lines), matching
|
||||
DWPose's face_pose draw style.
|
||||
palette: body color scheme. 'openpose' = standard rainbow gradient
|
||||
per keypoint (canonical OpenPose convention); 'scail' =
|
||||
SCAIL-Pose style — warm hues right side, cool hues left side,
|
||||
grey neck-to-nose centerline, distinct per-limb colors.
|
||||
include_hands: append the 21+21 OpenPose hand keypoints per track.
|
||||
hand_marker_radius_m: hand sphere radius. 0 = auto = 0.4 × marker_radius_m.
|
||||
hand_stick_radius_m: hand limb half-width. 0 = auto = 0.5 × stick_radius_m.
|
||||
hand_color_style: 'dwpose' (default) = solid-blue dots + rainbow sticks;
|
||||
'openpose' = rainbow dots AND sticks.
|
||||
face_style: 'disabled' (default) | 'full' (~30 contour pts) | 'eyes_mouth'
|
||||
(eyes + outer-lip subset); sampled at vertex IDs from
|
||||
`canonical_colors["positions"]`.
|
||||
face_marker_radius_m: face landmark sphere radius. 0 = auto = 0.3 ×
|
||||
marker_radius_m. Rendered as dots only, no contour lines.
|
||||
palette: 'openpose' = rainbow gradient per keypoint; 'scail' = warm right
|
||||
/ cool left, grey centerline, distinct per-limb colors.
|
||||
"""
|
||||
is_scail = str(palette) == "scail"
|
||||
# SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0,
|
||||
@ -771,13 +678,11 @@ def build_glb_openpose(
|
||||
body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS
|
||||
body_sphere_kp = (np.arange(14, dtype=np.int64)
|
||||
if is_scail else np.arange(18, dtype=np.int64))
|
||||
if str(palette) == "scail":
|
||||
if is_scail:
|
||||
body_sphere_colors = SCAIL_KEYPOINT_COLORS_18
|
||||
body_stick_colors = SCAIL_LIMB_COLORS_17
|
||||
elif str(palette) == "openpose":
|
||||
# Existing OpenPose behavior: same rainbow array used for both
|
||||
# spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of
|
||||
# the 18-element rainbow — yields a legible per-limb gradient).
|
||||
# Same rainbow array drives both spheres and sticks.
|
||||
body_sphere_colors = OPENPOSE_RAINBOW_18
|
||||
body_stick_colors = OPENPOSE_RAINBOW_18
|
||||
else:
|
||||
@ -892,13 +797,9 @@ def build_glb_openpose(
|
||||
if bone_smooth_window and bone_smooth_window > 1:
|
||||
kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window))
|
||||
|
||||
# Static-bind = rig's REST pose when available (override path); else
|
||||
# fall back to frame 0 of the motion. The rest-pose bind makes the
|
||||
# GLB's static POSITION attribute sit at rig origin, so viewers
|
||||
# auto-fit/center on rig origin and the animation visibly snaps from
|
||||
# rest to scene-frame-0 — matching skeletal mode's behavior. Without
|
||||
# this, openpose's static geometry is at scene-frame-0 and viewers
|
||||
# mis-center on the scene location, masking the motion entirely.
|
||||
# Static-bind = rig REST pose when available, else frame 0. The rest
|
||||
# bind keeps static POSITION at rig origin so viewers auto-center there
|
||||
# and the motion is visible (see _openpose_bind_at_rig_rest).
|
||||
bind_kp_m_rest = _openpose_bind_at_rig_rest(
|
||||
pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids,
|
||||
)
|
||||
@ -914,7 +815,7 @@ def build_glb_openpose(
|
||||
person_root_idx = len(nodes) - 1
|
||||
scene_root_indices.append(person_root_idx)
|
||||
|
||||
# K keypoint joint nodes (spheres bind here, rigid translation only).
|
||||
# K keypoint joint nodes (spheres bind here, translation only).
|
||||
joint_node_indices: List[int] = []
|
||||
for j in range(K):
|
||||
nodes.append({
|
||||
@ -926,9 +827,7 @@ def build_glb_openpose(
|
||||
joint_node_indices.append(len(nodes) - 1)
|
||||
person_root["children"].extend(joint_node_indices)
|
||||
|
||||
# Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint +
|
||||
# quaternion that aligns rest-axis → frame-t-axis). Sticks bind
|
||||
# rigidly to these joints so each capsule rotates with its limb.
|
||||
# Per-limb rest + per-frame TRS; sticks bind rigidly to these joints.
|
||||
limb_rest_mids_list: List[np.ndarray] = []
|
||||
limb_rest_axes_list: List[np.ndarray] = []
|
||||
limb_anim_mids_list: List[np.ndarray] = []
|
||||
@ -951,12 +850,10 @@ def build_glb_openpose(
|
||||
limb_rest_axes_list.append(raxis_h)
|
||||
limb_anim_mids_list.append(amid_h)
|
||||
limb_anim_quats_list.append(aquat_h)
|
||||
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3)
|
||||
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3)
|
||||
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4)
|
||||
# Hemisphere-align consecutive quats per limb so LINEAR interpolation
|
||||
# takes the short path (otherwise large per-frame rotations can flip
|
||||
# signs and produce visible "twist back" artifacts mid-playback).
|
||||
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0)
|
||||
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1)
|
||||
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1)
|
||||
# Hemisphere-align consecutive quats so LINEAR interp takes the short path.
|
||||
limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32)
|
||||
|
||||
limb_joint_indices: List[int] = []
|
||||
@ -970,8 +867,8 @@ def build_glb_openpose(
|
||||
limb_joint_indices.append(len(nodes) - 1)
|
||||
person_root["children"].extend(limb_joint_indices)
|
||||
|
||||
# Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints
|
||||
# (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest.
|
||||
# Combined skin: keypoint joints then limb joints; IBM = T(-rest) for
|
||||
# both, yielding identity skin_matrix at rest.
|
||||
all_joint_indices = joint_node_indices + limb_joint_indices
|
||||
ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1))
|
||||
ibm[:K, :3, 3] = -bind_kp_m
|
||||
@ -985,10 +882,8 @@ def build_glb_openpose(
|
||||
})
|
||||
skin_idx = len(skins) - 1
|
||||
|
||||
# Per-group geometry. Spheres bind to keypoint joints (base_joint_idx
|
||||
# ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈
|
||||
# [K, K + K_limbs)). Groups stack body → right hand → left hand →
|
||||
# face for keypoint joints, and body → R-hand → L-hand for limbs.
|
||||
# Per-group geometry. Spheres bind to keypoint joints [0, K); sticks to
|
||||
# limb joints [K, K+K_limbs). Stacked body → R-hand → L-hand → face.
|
||||
group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray,
|
||||
np.ndarray, np.ndarray, np.ndarray]] = []
|
||||
sp = _build_openpose_spheres(
|
||||
@ -1008,9 +903,7 @@ def build_glb_openpose(
|
||||
group_meshes.append(st)
|
||||
|
||||
if include_hands:
|
||||
# Hand stick colors stay rainbow per-finger regardless of
|
||||
# `hand_color_style` — only the sphere dots switch to solid
|
||||
# blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py.
|
||||
# Hand sticks stay rainbow per-finger; only dots switch under 'dwpose'.
|
||||
hand_pair_colors = _pair_colors_from_kp(
|
||||
OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1,
|
||||
)
|
||||
@ -1033,9 +926,7 @@ def build_glb_openpose(
|
||||
if K_face > 0:
|
||||
f_off = K_body + K_hands
|
||||
f_bind = bind_kp_m[f_off:f_off + K_face]
|
||||
# DWPose face = dots only, no contour lines
|
||||
# (controlnet_aux/dwpose/util.py::draw_facepose draws white
|
||||
# circles per landmark and never connects them).
|
||||
# DWPose face = dots only, no contour lines.
|
||||
group_meshes.append(_build_openpose_spheres(
|
||||
f_bind, float(face_marker_radius_m),
|
||||
FACE_LANDMARK_COLORS, base_joint_idx=f_off,
|
||||
@ -1087,9 +978,8 @@ def build_glb_openpose(
|
||||
"target": {"node": joint_node_indices[j], "path": "translation"},
|
||||
})
|
||||
|
||||
# Per-limb-joint translation + rotation channels. Stationary limbs
|
||||
# have their constant TRS baked into the node so they don't bloat the
|
||||
# animation buffer.
|
||||
# Per-limb-joint translation + rotation; stationary limbs bake their
|
||||
# constant TRS into the node instead of an animation channel.
|
||||
for k in range(K_limbs):
|
||||
t_k = limb_anim_mids[:, k, :].astype(np.float32)
|
||||
if (np.ptp(t_k, axis=0) < 1e-6).all():
|
||||
@ -1103,9 +993,7 @@ def build_glb_openpose(
|
||||
"target": {"node": limb_joint_indices[k], "path": "translation"},
|
||||
})
|
||||
q_k = limb_anim_quats[:, k, :].astype(np.float32)
|
||||
# ptp on the absolute value handles the +q == -q ambiguity, but
|
||||
# `quat_sign_fix_per_joint` already aligned signs so a plain ptp
|
||||
# is fine here.
|
||||
# Plain ptp is fine — signs already aligned by quat_sign_fix_per_joint.
|
||||
if (np.ptp(q_k, axis=0) < 1e-6).all():
|
||||
nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist()
|
||||
else:
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
"""GLB export for SAM 3D Body pose_data.
|
||||
"""Shared GLB export helpers for SAM 3D Body pose_data.
|
||||
|
||||
Mode: skeletal — rebuilds the MHR 127-bone rig. Per-frame local TRS comes from
|
||||
re-running param_transform on saved mhr_model_params; rest verts from a
|
||||
zero-pose forward with the person's shape_params; sparse triplet skinning is
|
||||
compacted to glTF's max-4-influences form; facial expression is re-exposed as
|
||||
72 morph targets driven by expr_params.
|
||||
|
||||
pred_vertices/pred_cam_t are camera-y-down — un-flipped here so the GLB lives
|
||||
in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent
|
||||
them); deformation at extreme joint angles will differ from the SAM3DBody
|
||||
renderer by the corrective amount.
|
||||
Skeletal mode rebuilds the MHR 127-bone rig: per-frame local TRS from
|
||||
param_transform on mhr_model_params, rest verts from a zero-pose forward,
|
||||
sparse skinning compacted to glTF's 4-influence form, expression re-exposed as
|
||||
72 morph targets. Camera-y-down data is un-flipped to glTF Y-up. Pose
|
||||
correctives are dropped (glTF skinning can't represent them), so extreme joint
|
||||
angles differ from the SAM3DBody renderer by the corrective amount.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -24,12 +20,11 @@ import torch
|
||||
|
||||
from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical
|
||||
|
||||
# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically
|
||||
# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)`
|
||||
# fp32-rounded ln(2); exp(x * _LN2) matches the rig's own 2**x bit-for-bit.
|
||||
_LN2 = 0.6931471824645996
|
||||
|
||||
|
||||
# Quaternion / rotation helpers (xyzw convention, matching MHR rig)
|
||||
# Quaternion / rotation helpers (xyzw, matching MHR rig)
|
||||
|
||||
def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray:
|
||||
"""(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat."""
|
||||
@ -96,8 +91,7 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
|
||||
"""Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4.
|
||||
Endpoints replicate so they aren't pulled toward zero. Returns float64."""
|
||||
"""Edge-replicate Gaussian smoothing along time (sigma = window/4). float64."""
|
||||
a = np.asarray(arr, dtype=np.float64)
|
||||
n = a.shape[0]
|
||||
half = window // 2
|
||||
@ -117,9 +111,8 @@ def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
|
||||
|
||||
|
||||
def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
|
||||
"""Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns
|
||||
per joint first, convolves per-component, renormalizes. Suppresses multi-
|
||||
frame bone spikes at extreme poses without needing the upstream Smooth node."""
|
||||
"""Smooth a (N, NJ, 4) quaternion sequence along time: sign-align per joint,
|
||||
convolve per-component, renormalize. Calms bone spikes at extreme poses."""
|
||||
if window <= 1 or q_seq.shape[0] < 2:
|
||||
return q_seq
|
||||
out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window)
|
||||
@ -128,18 +121,16 @@ def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
|
||||
|
||||
|
||||
def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray:
|
||||
"""Gaussian-smooth a (N, K, 3) position sequence along time (edge-replicate
|
||||
padding). Used to calm jittery keypoint tracks before the openpose rig
|
||||
derives sphere translations + limb TRS from them."""
|
||||
"""Smooth a (N, K, 3) position sequence along time. Calms jittery keypoint
|
||||
tracks before the openpose rig derives sphere translations + limb TRS."""
|
||||
if window <= 1 or seq.shape[0] < 2:
|
||||
return seq
|
||||
return _gaussian_smooth_time(seq, window).astype(np.float32)
|
||||
|
||||
|
||||
def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
|
||||
"""Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit
|
||||
on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim
|
||||
cartwheel flip). fp64 to avoid drift; normalizes input defensively."""
|
||||
"""Walk (N, NJ, 4) along time, flipping sign when consecutive frames sit on
|
||||
opposite hemispheres. Avoids long-path slerp glitches. fp64 internally."""
|
||||
out = np.array(q_seq, dtype=np.float64, copy=True)
|
||||
norms = np.linalg.norm(out, axis=-1, keepdims=True)
|
||||
out = out / np.maximum(norms, 1e-12)
|
||||
@ -151,11 +142,9 @@ def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray:
|
||||
"""Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that
|
||||
FK over (parents, bone_local) reproduces rig_global. local =
|
||||
inverse(parent_global) ∘ child_global makes this robust to hierarchy-
|
||||
convention mismatches: glTF FK gives back exactly rig_global even if
|
||||
`parents` doesn't match the rig's pmi-walk."""
|
||||
"""Globals (N, NJ, 8) + parents -> per-bone local TRS so FK reproduces
|
||||
rig_global. local = inverse(parent_global) ∘ child_global, robust to
|
||||
hierarchy-convention mismatches in `parents`."""
|
||||
N, NJ, _ = rig_global.shape
|
||||
bone_local = np.zeros_like(rig_global)
|
||||
for j in range(NJ):
|
||||
@ -188,8 +177,7 @@ def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray:
|
||||
|
||||
def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]:
|
||||
"""List of (person_index, frame_indices). track_index == -1 means every
|
||||
present track; empty tracks are dropped. Same person index across frames
|
||||
is assumed same subject (Smooth/Predict enforce this on tracked bboxes)."""
|
||||
present track; empty tracks are dropped. The same person index across frames is assumed to be the same subject."""
|
||||
frames = pose_data["frames"]
|
||||
max_p = max((len(f) for f in frames), default=0)
|
||||
if max_p == 0:
|
||||
@ -257,8 +245,7 @@ class GLBWriter:
|
||||
return len(self.accessors) - 1
|
||||
|
||||
def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int:
|
||||
"""Morph-target POSITIONs: spec lets us skip min/max, avoiding a
|
||||
per-frame delta bbox."""
|
||||
"""Morph-target POSITIONs: spec lets us skip min/max."""
|
||||
a = np.ascontiguousarray(arr, dtype=np.float32)
|
||||
view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY)
|
||||
self.accessors.append({
|
||||
@ -288,9 +275,8 @@ class GLBWriter:
|
||||
return len(self.accessors) - 1
|
||||
|
||||
def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int:
|
||||
"""Animation-output scalars: `count` is keyframes, not floats. Morph-
|
||||
target weight tracks store N_morph weights per keyframe as flat float32
|
||||
with count=N_keyframes."""
|
||||
"""Animation-output scalars: `count` is keyframes, not floats (morph
|
||||
weight tracks store N_morph weights per keyframe)."""
|
||||
a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1)
|
||||
view_idx = self._add_view(a.tobytes())
|
||||
self.accessors.append({
|
||||
@ -382,9 +368,8 @@ def bake_vertex_colors(
|
||||
rainbow_tilt_z_deg: float,
|
||||
pastel_mix: float,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""Per-vertex RGB matching the renderer's shader preset, on the canonical
|
||||
mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the
|
||||
viewer's default material handle shading)."""
|
||||
"""Per-vertex RGB matching the renderer's shader preset. Returns (N_v, 3)
|
||||
float32 in [0, 1], or None for `default` (use the viewer's material)."""
|
||||
if shader == "default" or canonical_colors is None:
|
||||
return None
|
||||
|
||||
@ -432,8 +417,8 @@ def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def _parents_from_pmi(rig: Any) -> np.ndarray:
|
||||
"""Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child,
|
||||
row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1."""
|
||||
"""Parent index per joint from skel_pmi ((2, 266): row 0 child, row 1
|
||||
parent, split into BFS levels by skel_pmi_buffer_sizes). Roots = -1."""
|
||||
NJ = int(rig.NUM_JOINTS)
|
||||
pmi = rig.skel_pmi.cpu().numpy()
|
||||
sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist()
|
||||
@ -450,47 +435,29 @@ def _parents_from_pmi(rig: Any) -> np.ndarray:
|
||||
|
||||
def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply
|
||||
this to bypass MHR rig extraction (see ComfyUI-Kimodo). Required keys:
|
||||
parents: (NJ,) int32, -1 = root
|
||||
bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters
|
||||
lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences
|
||||
lbs_compact_weights: (V, 8) f32
|
||||
lbs_compact_max_inf: int — actual max influences (≤ 8)
|
||||
rest_verts_m: (V, 3) f32
|
||||
faces: (F, 3) uint32
|
||||
Optional:
|
||||
per_frame_y_down: bool — set False if pred_joint_coords are already
|
||||
rig-native Y-up (kimodo). Default True (MHR).
|
||||
openpose18_joint_indices: (18, 2) int32 — body OpenPose-18 → joint
|
||||
index pair, resolved against per-frame
|
||||
`pred_joint_coords`. Each row is
|
||||
(joint_a, joint_b); b == -1 = single
|
||||
joint, else default midpoint of the two
|
||||
(lets producers approximate keypoints
|
||||
with no matching joint, e.g. Nose ≈
|
||||
midpoint(LeftEye, RightEye)). Enables
|
||||
`SAM3DBody_ToGLB(mode="openpose")` on
|
||||
external rigs.
|
||||
openpose18_joint_weights: (18,) f32 — optional per-keypoint blend
|
||||
weight for the (a, b) mapping above.
|
||||
Position = w*joints[a] + (1-w)*joints[b]
|
||||
when b ≥ 0 (default w=0.5 → midpoint).
|
||||
Values outside [0, 1] EXTRAPOLATE past
|
||||
the line segment — used to approximate
|
||||
landmarks with no nearby joint pair
|
||||
(e.g. ears: w=2.0 along the eye→eye
|
||||
axis puts each ear one eye-distance
|
||||
outside the corresponding eye). Ignored
|
||||
for single-joint rows (b = -1).
|
||||
openpose_hand21_r_joint_indices: (21, 2) int32 — right-hand OpenPose-21
|
||||
(wrist + 5 fingers × 4 joints, base→tip)
|
||||
→ joint index pair. Required (alongside
|
||||
the L counterpart) for openpose mode
|
||||
with include_hands=True.
|
||||
openpose_hand21_l_joint_indices: (21, 2) int32 — left-hand counterpart.
|
||||
openpose_hand21_r_joint_weights: (21,) f32 — optional, same semantics as
|
||||
`openpose18_joint_weights`.
|
||||
openpose_hand21_l_joint_weights: (21,) f32 — optional, same as above.
|
||||
this to bypass MHR rig extraction (see ComfyUI-Kimodo).
|
||||
|
||||
Required keys:
|
||||
parents: (NJ,) int32, -1 = root
|
||||
bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters
|
||||
lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences
|
||||
lbs_compact_weights: (V, 8) f32
|
||||
lbs_compact_max_inf: int — actual max influences (≤ 8)
|
||||
rest_verts_m: (V, 3) f32
|
||||
faces: (F, 3) uint32
|
||||
|
||||
Optional (the openpose* keys enable openpose mode on external rigs):
|
||||
per_frame_y_down: bool — False if pred_joint_coords are already Y-up
|
||||
(kimodo). Default True (MHR).
|
||||
openpose18_joint_indices: (18, 2) int32 — body keypoint → (a, b)
|
||||
joints, resolved against `pred_joint_coords`.
|
||||
b == -1 = single joint, else midpoint of (a, b) by default.
|
||||
openpose18_joint_weights: (18,) f32 — blend w: w*a + (1-w)*b
|
||||
(default 0.5; outside [0,1] extrapolates; ignored
|
||||
when b == -1).
|
||||
openpose_hand21_{r,l}_joint_indices: (21, 2) int32 — per-hand keypoint
|
||||
maps; both required for include_hands=True.
|
||||
openpose_hand21_{r,l}_joint_weights: (21,) f32 — optional, same as above.
|
||||
"""
|
||||
if pose_data is None:
|
||||
return None
|
||||
@ -502,12 +469,10 @@ def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) -
|
||||
use that instead of MHR-specific `model.head_pose.mhr` buffers."""
|
||||
override = _get_skeleton_override(pose_data)
|
||||
if override is not None:
|
||||
# External rig: caller pre-compacts skin and supplies bind global directly,
|
||||
# so we don't need MHR's PCA pose / expression bases.
|
||||
# External rig: skin pre-compacted, bind global supplied directly.
|
||||
parents = np.asarray(override["parents"], dtype=np.int32)
|
||||
rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32)
|
||||
# BVH needs parent-relative bone OFFSETs (cm). MHR ships these directly;
|
||||
# external rigs only give bind globals, so derive locals from them.
|
||||
# BVH needs parent-relative bone offsets (cm); derive from bind globals.
|
||||
bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32)
|
||||
local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
||||
joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32)
|
||||
@ -560,29 +525,26 @@ def compact_skin_to_n(
|
||||
skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
|
||||
num_verts: int, max_inf: int = 8,
|
||||
) -> Tuple[np.ndarray, np.ndarray, int]:
|
||||
"""Sparse (joint, vert, weight) triplets -> dense (joints[V, max_inf],
|
||||
weights[V, max_inf]). Keeps `max_inf` largest-magnitude influences,
|
||||
renormalizes. `actual_max` lets the caller skip JOINTS_1/WEIGHTS_1 when
|
||||
nothing exceeds 4 influences."""
|
||||
"""Sparse (joint, vert, weight) triplets -> dense (joints, weights) of shape
|
||||
(V, max_inf), keeping the largest influences and renormalizing. `actual_max`
|
||||
lets the caller skip JOINTS_1/WEIGHTS_1 when nothing exceeds 4 influences."""
|
||||
joints = np.zeros((num_verts, max_inf), dtype=np.uint16)
|
||||
out_w = np.zeros((num_verts, max_inf), dtype=np.float32)
|
||||
counts = np.zeros(num_verts, dtype=np.int32)
|
||||
|
||||
if vert_indices.size:
|
||||
# lexsort secondary key first: groups by vert, weights descending within group.
|
||||
# Group by vert, weights descending within each group.
|
||||
order = np.lexsort((-weights, vert_indices))
|
||||
vi_sorted = vert_indices[order]
|
||||
sk_sorted = skin_indices[order]
|
||||
w_sorted = weights[order]
|
||||
|
||||
# Per-row rank within its vertex group: 0 at each group start, +1 elsewhere.
|
||||
# group_start[i] is True when vi_sorted[i] starts a new vertex.
|
||||
# Per-row rank within its vertex group (0 at each group start).
|
||||
n = vi_sorted.size
|
||||
group_start = np.empty(n, dtype=bool)
|
||||
group_start[0] = True
|
||||
np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:])
|
||||
pos = np.arange(n, dtype=np.int64)
|
||||
# Position of each row's group start, broadcast forward.
|
||||
group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0))
|
||||
rank = pos - group_start_pos
|
||||
|
||||
@ -609,9 +571,8 @@ def zero_pose_rest_verts(
|
||||
model: Any, shape_params: np.ndarray, expr_zero: bool = True,
|
||||
pose_data: Optional[Dict[str, Any]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Rig with zero pose + this subject's shape -> rest verts (V, 3) in
|
||||
rig-native Y-up meters. External-skeleton path returns `rest_verts_m`
|
||||
directly (no PCA shape space to expand)."""
|
||||
"""Zero pose + this subject's shape -> rest verts (V, 3) in rig-native Y-up
|
||||
meters. External path returns `rest_verts_m` directly."""
|
||||
override = _get_skeleton_override(pose_data)
|
||||
if override is not None:
|
||||
return np.asarray(override["rest_verts_m"], dtype=np.float32)
|
||||
@ -624,14 +585,11 @@ def zero_pose_rest_verts(
|
||||
sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device)
|
||||
if sp.ndim == 1:
|
||||
sp = sp.unsqueeze(0)
|
||||
# mhr.forward(identity_coeffs, model_parameters, expr_coeffs):
|
||||
# identity_rest = base_shape + identity_basis @ shape;
|
||||
# cat([model_params, zeros]) through param_transform; expr added.
|
||||
# rig.forward(shape, model_params, expr); zero pose + zero expr.
|
||||
model_params = torch.zeros(1, 204, device=device, dtype=dtype)
|
||||
expr = torch.zeros(1, 72, device=device, dtype=dtype)
|
||||
verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False)
|
||||
# Rig outputs cm; mhr_head divides by 100 for meters. Match that.
|
||||
verts_m = verts[0].cpu().float().numpy() / 100.0
|
||||
verts_m = verts[0].cpu().float().numpy() / 100.0 # cm -> m
|
||||
return verts_m.astype(np.float32)
|
||||
|
||||
|
||||
@ -639,7 +597,7 @@ def global_skel_state_per_frame(
|
||||
model: Any, mhr_model_params: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw,
|
||||
scale). Bones are shape- and expression-independent so we pass zeros."""
|
||||
scale). Bones are shape/expression-independent, so pass zeros."""
|
||||
inner = model.model if hasattr(model, "model") else model
|
||||
rig = inner.head_pose.mhr
|
||||
device = next(rig.parameters()).device
|
||||
@ -655,8 +613,8 @@ def global_skel_state_per_frame(
|
||||
|
||||
|
||||
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
|
||||
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978 branched, largest-component
|
||||
pick for stability. Cross-frame sign-fixing is the caller's job."""
|
||||
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978, largest-component pick.
|
||||
Cross-frame sign-fixing is the caller's job."""
|
||||
shape = R.shape[:-2]
|
||||
Rf = R.reshape(-1, 3, 3).astype(np.float64)
|
||||
M = Rf.shape[0]
|
||||
@ -703,14 +661,12 @@ def global_skel_state_from_pose_data(
|
||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||
NJ: int, *, joint_coords_y_down: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""Build per-frame skel_state from stored pred_global_rots + pred_joint_coords,
|
||||
bypassing rig.forward. Returns (N, NJ, 8) in METERS, MHR-native frame.
|
||||
"""Per-frame skel_state from stored pred_global_rots + pred_joint_coords,
|
||||
bypassing rig.forward. Returns (N, NJ, 8) in meters, MHR-native frame.
|
||||
|
||||
pred_global_rots are MHR-native (no y/z flip). For MHR, pred_joint_coords
|
||||
are stored y-down (post-flip), so un-flip when `joint_coords_y_down=True`.
|
||||
External skeletons (Kimodo) store y-up already → pass False. Scale
|
||||
defaults to 1 (rig scale isn't preserved in pose_data; close to 1 for
|
||||
typical body poses)."""
|
||||
pred_global_rots are MHR-native. pred_joint_coords are y-down for MHR
|
||||
(un-flipped when `joint_coords_y_down=True`); external rigs store y-up
|
||||
(pass False). Scale defaults to 1 (not preserved in pose_data)."""
|
||||
frames = pose_data["frames"]
|
||||
N = len(frame_indices)
|
||||
rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32)
|
||||
@ -731,10 +687,8 @@ def global_skel_state_from_pose_data(
|
||||
|
||||
|
||||
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
|
||||
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm.
|
||||
Inverse of `lbs_inverse_bind_pose` modulo precision; used as bones' static
|
||||
TRS so the rest mesh looks correct with no animation playing. External
|
||||
rig: convert override's `bind_global_m` from m → cm to match this contract."""
|
||||
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm,
|
||||
used as bones' static TRS. External rig: convert `bind_global_m` m -> cm."""
|
||||
override = _get_skeleton_override(pose_data)
|
||||
if override is not None:
|
||||
bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy()
|
||||
@ -746,13 +700,10 @@ def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> n
|
||||
|
||||
@dataclass
|
||||
class Rig:
|
||||
"""Normalized static rig for the GLB/BVH exporters, independent of where it
|
||||
came from: an MHR model (`Rig.from_pose_data(pose_data, model)`) or an inline
|
||||
`pose_data["_skeleton_override"]` (external rigs, e.g. ComfyUI-Kimodo).
|
||||
|
||||
Consumers read these fields and never branch on the source. The only
|
||||
source-dependent operation is `rest_verts_m` — MHR rest verts depend on the
|
||||
subject's `shape_params`; external rigs ship fixed rest verts.
|
||||
"""Normalized static rig for the GLB/BVH exporters, source-independent: MHR
|
||||
model or inline `pose_data["_skeleton_override"]` (external rigs). Consumers
|
||||
never branch on the source. Only `rest_verts_m` is source-dependent — MHR
|
||||
expands it from `shape_params`; external rigs ship it fixed.
|
||||
"""
|
||||
parents: np.ndarray # (NJ,) int32, -1 = root
|
||||
joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm
|
||||
@ -816,9 +767,8 @@ class Rig:
|
||||
|
||||
|
||||
def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray:
|
||||
"""Inverse-bind MAT4 by inverting the rig's bind global (meters). Guarantees
|
||||
IBP[j] = inverse(FK over bind local TRS) — exactly what glTF skinning
|
||||
needs given bones default to the bind local TRS. Returns (NJ, 4, 4)
|
||||
"""Inverse-bind MAT4 from the rig's bind global (meters). IBP[j] =
|
||||
inverse(FK over bind local TRS), as glTF skinning needs. Returns (NJ, 4, 4)
|
||||
column-major."""
|
||||
NJ = bind_skel_state_m.shape[0]
|
||||
t = bind_skel_state_m[:, :3].astype(np.float32)
|
||||
@ -877,10 +827,8 @@ def _ibp_to_mat4(ibp_skel: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring
|
||||
lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon
|
||||
matches the 16-segment cylinder used by capsule limbs AND the equator
|
||||
ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush."""
|
||||
"""Unit UV sphere, poles ±Y. `n_lat` odd so a ring lands at the equator;
|
||||
n_lon=16 matches the capsule cylinder so end rings meet flush."""
|
||||
verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0
|
||||
for i in range(1, n_lat + 1):
|
||||
lat = -0.5 * np.pi + np.pi * i / (n_lat + 1)
|
||||
@ -924,8 +872,8 @@ def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndar
|
||||
def flat_shade_mesh(
|
||||
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Smooth -> flat by duplicating verts per face; each triangle gets 3
|
||||
unique verts sharing its face normal. Skinning attrs duplicated alongside."""
|
||||
"""Flat-shade by duplicating verts per face; each triangle gets 3 unique
|
||||
verts sharing its face normal. Skinning attrs duplicated alongside."""
|
||||
F = faces.shape[0]
|
||||
new_v = np.zeros((F * 3, 3), dtype=np.float32)
|
||||
new_n = np.zeros((F * 3, 3), dtype=np.float32)
|
||||
@ -949,9 +897,8 @@ def flat_shade_mesh(
|
||||
def smooth_shade_mesh(
|
||||
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Area-weighted per-vertex normals (smooth shading). Geometry, skinning,
|
||||
indexing pass through unchanged so vertex colors stay aligned. Orphan
|
||||
verts get +Y fallback."""
|
||||
"""Area-weighted per-vertex normals. Geometry/skinning/indexing pass through
|
||||
unchanged so vertex colors stay aligned. Orphan verts get +Y fallback."""
|
||||
Nv = int(verts.shape[0])
|
||||
v0 = verts[faces[:, 0]]
|
||||
v1 = verts[faces[:, 1]]
|
||||
@ -994,11 +941,9 @@ def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray:
|
||||
def make_lit_material(
|
||||
roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0,
|
||||
) -> dict:
|
||||
"""Lit PBR material using vertex COLOR_0 multiplicatively. KHR_materials_unlit
|
||||
is intentionally off so viewer lighting reveals surface form. metallic=0
|
||||
keeps the surface dielectric so vertex colors stay readable. roughness=0.85
|
||||
suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look.
|
||||
opacity < 1 switches to alpha-blend (e.g. see-through body mesh over bones)."""
|
||||
"""Lit PBR material using vertex COLOR_0. Dielectric (metallic=0) so colors
|
||||
stay readable; roughness 0.85 suits rainbow body meshes, 0.3 the glossy
|
||||
SCAIL rig. opacity < 1 switches to alpha-blend."""
|
||||
a = float(max(0.0, min(1.0, opacity)))
|
||||
mat = {
|
||||
"pbrMetallicRoughness": {
|
||||
@ -1182,14 +1127,12 @@ def openpose_render_keypoints(
|
||||
person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str,
|
||||
*, dim: int, H: int = 0, W: int = 0,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""OpenPose keypoints for one person, in op-layout, CAMERA frame (Y-down).
|
||||
"""OpenPose keypoints for one person, op-layout, camera frame (Y-down).
|
||||
`part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) meters, before cam_t is added;
|
||||
dim=2 -> (K, 2) image pixels. Returns None when the source data is missing.
|
||||
dim=2 -> (K, 2) pixels. Returns None when source data is missing.
|
||||
|
||||
External rigs (override carries the joint-index map) resolve from per-frame
|
||||
`pred_joint_coords` (rig-native Y-up -> flipped to camera Y-down, matching
|
||||
the pred_vertices convention). MHR reindexes the stored
|
||||
`pred_keypoints_{3d,2d}` via the MHR70 map."""
|
||||
External rigs resolve from `pred_joint_coords` (Y-up -> flipped to Y-down);
|
||||
MHR reindexes stored `pred_keypoints_{3d,2d}` via the MHR70 map."""
|
||||
map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part]
|
||||
override = _get_skeleton_override(pose_data)
|
||||
ext_map = override.get(map_key) if override is not None else None
|
||||
@ -1228,11 +1171,9 @@ def openpose_render_keypoints(
|
||||
return kp_full[mhr_map]
|
||||
|
||||
|
||||
# Face landmarks from the MHR rig (option `face_source="rig"`).
|
||||
# MHR has no face bones — face deforms via expr_params morphs — so landmarks
|
||||
# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against
|
||||
# anatomically-plausible target xyz in canonical Y-up. Iterate visually in
|
||||
# Blender and tweak targets if landmarks land off-surface.
|
||||
# Face landmarks (face_source="rig"). MHR has no face bones, so landmarks are
|
||||
# sourced from `pred_vertices` at vertex IDs picked by nearest neighbor against the target xyz
|
||||
# below. Tweak targets if landmarks land off-surface.
|
||||
|
||||
# (name, target_xyz) in MHR canonical Y-up meters.
|
||||
FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = (
|
||||
@ -1290,10 +1231,8 @@ def select_face_landmark_vert_ids(
|
||||
face_mask: Optional[np.ndarray] = None,
|
||||
) -> np.ndarray:
|
||||
"""Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in
|
||||
canonical positions. Filter: `face_mask` (verts that deform with any of
|
||||
the 72 expression axes) if available — keeps chin/jaw search off the
|
||||
neck. Otherwise a position bbox (less reliable; throat verts sometimes
|
||||
pull chin targets)."""
|
||||
canonical positions, restricted to `face_mask` verts (expression-deforming)
|
||||
when available, else a position bbox (less reliable around the chin/jaw)."""
|
||||
P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3)
|
||||
if face_mask is not None and np.asarray(face_mask).any():
|
||||
valid = np.where(np.asarray(face_mask).reshape(-1))[0]
|
||||
|
||||
@ -1,19 +1,11 @@
|
||||
"""GLB export — skeletal (real armature) mode.
|
||||
|
||||
Rebuilds an Armature with the MHR 127-bone rig:
|
||||
- per-frame local TRS comes from re-running param_transform on the saved
|
||||
`mhr_model_params`;
|
||||
- rest verts come from a zero-pose forward with each person's `shape_params`;
|
||||
- sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form;
|
||||
- facial expression is re-exposed as 72 morph targets driven by `expr_params`
|
||||
so face animation survives plain glTF skinning.
|
||||
|
||||
Optional bone visualization (octahedrons) is rigidly
|
||||
skinned alongside the body mesh — used to preview the armature in glTF
|
||||
viewers that don't draw bones.
|
||||
|
||||
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
|
||||
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
|
||||
Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from
|
||||
param_transform on `mhr_model_params`, rest verts from a zero-pose forward,
|
||||
sparse skinning compacted to glTF 4-influence sets (up to 8 per vertex), and facial expression as
|
||||
72 morph targets driven by `expr_params`. Optional octahedron bone-vis is
|
||||
rigidly skinned alongside for viewers that don't draw bones. Shared infra lives
|
||||
in `glb_shared.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -44,8 +36,7 @@ from .glb_shared import (
|
||||
from comfy_extras.sam3d_body.utils import jet_colormap
|
||||
|
||||
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
|
||||
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white'
|
||||
(no per-bone color → bone-vis mesh uses default unlit material)."""
|
||||
"""Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material)."""
|
||||
if scheme == "rainbow_y":
|
||||
y = bind_pos_m[:, 1].astype(np.float32)
|
||||
y_min, y_max = float(y.min()), float(y.max())
|
||||
@ -55,9 +46,8 @@ def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray
|
||||
|
||||
|
||||
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y,
|
||||
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound
|
||||
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
|
||||
"""Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit
|
||||
length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward."""
|
||||
v = np.array([
|
||||
[0.0, 0.0, 0.0], # 0: head
|
||||
[0.0, 1.0, 0.0], # 1: tail
|
||||
@ -78,18 +68,16 @@ def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
||||
def _bone_edges(
|
||||
joint_pos_m: np.ndarray, parents: np.ndarray,
|
||||
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
|
||||
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per
|
||||
parent→child edge in the hierarchy, skipping edges whose PARENT is a
|
||||
root joint (those typically anchor the skeleton at world origin and
|
||||
just look like a stray stick from origin to the body). Zero-length
|
||||
edges are skipped too."""
|
||||
"""One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge.
|
||||
Skips edges whose parent is a root (world-anchor sticks) and zero-length
|
||||
edges."""
|
||||
NJ = joint_pos_m.shape[0]
|
||||
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
|
||||
for c in range(NJ):
|
||||
p = int(parents[c])
|
||||
if not (0 <= p < NJ and p != c):
|
||||
continue
|
||||
# Skip if parent itself is a root — that bone is a world-anchor stick.
|
||||
# Skip world-anchor sticks: parent itself is a root.
|
||||
gp = int(parents[p])
|
||||
if not (0 <= gp < NJ and gp != p):
|
||||
continue
|
||||
@ -104,9 +92,8 @@ def _bone_edges(
|
||||
def _build_bone_octahedrons_mesh(
|
||||
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""One Blender-style octahedron per parent→child edge. Returns
|
||||
(verts, normals, faces, joints, weights, child_idx_per_vert);
|
||||
child_idx feeds per-bone color lookup at the call site."""
|
||||
"""One octahedron per parent→child edge. Returns (verts, normals, faces,
|
||||
joints, weights, child_idx_per_vert); child_idx feeds per-bone color."""
|
||||
base_v, base_f = _octahedron_unit()
|
||||
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||
|
||||
@ -117,8 +104,7 @@ def _build_bone_octahedrons_mesh(
|
||||
out_w: List[List[float]] = []
|
||||
child_per_vert: List[int] = []
|
||||
|
||||
# Width scales with length so short bones (fingers, face) don't look chunky
|
||||
# next to long ones (limbs, spine). `half_width_m` caps long bones.
|
||||
# Width scales with length (capped by half_width_m) so short bones aren't chunky.
|
||||
WIDTH_RATIO = 0.1
|
||||
MIN_WIDTH = 0.001
|
||||
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
|
||||
@ -151,8 +137,8 @@ def _build_bone_octahedrons_mesh(
|
||||
out_n.extend(n_world.tolist())
|
||||
for face in base_f:
|
||||
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
|
||||
# Dual skin head→parent, tail→child, ridges blend by canonical Y so the
|
||||
# bone stretches between joints instead of going rigid with one.
|
||||
# Dual skin (head→parent, tail→child); ridges blend by canonical Y so
|
||||
# the bone stretches between joints instead of going rigid with one.
|
||||
for k in range(base_v.shape[0]):
|
||||
y_canon = float(base_v[k, 1])
|
||||
w_parent = max(0.0, 1.0 - y_canon)
|
||||
@ -196,22 +182,17 @@ def build_glb_skeletal(
|
||||
bone_vis_color: str = "white",
|
||||
include_body_mesh: bool = True,
|
||||
) -> bytes:
|
||||
"""Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.
|
||||
"""Build pose_data as a real Armature GLB with per-bone TRS keyframes. For
|
||||
MHR, facial expression is exposed as 72 morph targets when
|
||||
include_face_morphs=True.
|
||||
|
||||
For MHR (default) facial expression is exposed as 72 morph targets driven
|
||||
by expr_params per frame when include_face_morphs=True.
|
||||
|
||||
External skeletons (e.g. ComfyUI-Kimodo) can supply a
|
||||
``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
|
||||
entirely. When present, ``model`` may be None and the rig data, bind pose,
|
||||
skin weights, and rest verts come from the override. Per-frame skeletal
|
||||
state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
|
||||
person dict (kimodo populates these from its own FK output). See
|
||||
``glb.shared._get_skeleton_override`` for the override schema.
|
||||
External skeletons (e.g. ComfyUI-Kimodo) can supply
|
||||
``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model``
|
||||
may be None then); per-frame state still reads ``pred_global_rots`` /
|
||||
``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema.
|
||||
"""
|
||||
frames = pose_data["frames"]
|
||||
# Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
|
||||
# faces are all rig-native (Y-up).
|
||||
# Only `pred_cam_t` is camera-y-down; rig params, skin, expr basis and faces are rig-native Y-up.
|
||||
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
|
||||
tracks = collect_tracks(pose_data, track_index)
|
||||
if not tracks:
|
||||
@ -219,17 +200,14 @@ def build_glb_skeletal(
|
||||
|
||||
rig = Rig.from_pose_data(pose_data, model)
|
||||
NJ = rig.num_joints
|
||||
# NV = rig.num_verts
|
||||
NEXPR = rig.num_expr
|
||||
parents = rig.parents
|
||||
if not rig.can_rerun_fk:
|
||||
# External rigs have no PCA pose params to re-run; only stored globals
|
||||
# are available, and they store joint coords already Y-up.
|
||||
# External rigs have no PCA pose params to re-run; use stored globals.
|
||||
use_stored_global_rots = True
|
||||
joint_coords_y_down = rig.per_frame_y_down
|
||||
# Skinning is already compacted to ≤8 influences per vertex (MHR averages
|
||||
# ~2.8 but some shoulder/hip verts hit 5-8; keeping only 4 there leaks
|
||||
# per-bone rotation noise into the rendered mesh).
|
||||
# Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts
|
||||
# need >4, else per-bone rotation noise leaks into the mesh).
|
||||
joints_8 = rig.lbs_joints
|
||||
weights_8 = rig.lbs_weights
|
||||
actual_max_inf = rig.lbs_max_inf
|
||||
@ -238,14 +216,12 @@ def build_glb_skeletal(
|
||||
use_set1 = actual_max_inf > 4
|
||||
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
|
||||
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
|
||||
# Derive bone locals from the rig's bind globals rather than recomputing
|
||||
# FK ourselves, so any mismatch between `parents` and the rig's actual FK
|
||||
# is absorbed into the local TRS instead of producing wrong globals.
|
||||
# Derive bone locals from bind globals so any `parents`/FK mismatch is
|
||||
# absorbed into the local TRS instead of producing wrong globals.
|
||||
bind_global_m = rig.bind_global_m
|
||||
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
||||
|
||||
# IBP = inverse of bind global. With bone defaults set to bind_local and
|
||||
# FK composed via `parents`, skin_matrix at rest = identity.
|
||||
# IBP = inverse of bind global → skin_matrix at rest is identity.
|
||||
ibp_mat4 = ibp_from_bind_global(bind_global_m)
|
||||
|
||||
w = GLBWriter()
|
||||
@ -316,9 +292,7 @@ def build_glb_skeletal(
|
||||
body_mesh_node_idx: Optional[int] = None
|
||||
|
||||
if include_body:
|
||||
# MHR rest verts depend on the subject's shape_params; external rigs
|
||||
# ship fixed rest verts and ignore the arg (so the empty external
|
||||
# `shape_params` is harmless).
|
||||
# MHR rest verts depend on shape_params; external rigs ignore the arg.
|
||||
shape_params_arr = np.asarray(
|
||||
frames[frame_indices[0]][person_k].get("shape_params", []),
|
||||
dtype=np.float32,
|
||||
@ -349,8 +323,8 @@ def build_glb_skeletal(
|
||||
"indices": indices_acc,
|
||||
"mode": 4,
|
||||
}
|
||||
# See-through body when bones are shown, else opaque (only when a
|
||||
# vertex-color shader baked COLOR_0 — otherwise default material).
|
||||
# See-through body when bones are shown, else opaque (only if a
|
||||
# shader baked COLOR_0; otherwise default material).
|
||||
if color_acc is not None or include_bones:
|
||||
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
|
||||
primitive["material"] = len(materials) - 1
|
||||
@ -373,8 +347,7 @@ def build_glb_skeletal(
|
||||
if include_bones:
|
||||
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
|
||||
|
||||
# Indexes `bone_palette`: octahedrons use the bone's child joint so
|
||||
# every bone has its own color regardless of skin target.
|
||||
# Color by child joint so every bone has its own color.
|
||||
color_idx_per_vert: Optional[np.ndarray] = None
|
||||
hw = float(bone_vis_radius_m)
|
||||
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
|
||||
@ -422,8 +395,8 @@ def build_glb_skeletal(
|
||||
nodes.append(bv_mesh_node)
|
||||
person_root["children"].append(len(nodes) - 1)
|
||||
|
||||
# Per-frame GLOBAL skel state → bone locals via parent-inverse.
|
||||
# Default uses the rig's stored output; the fallback re-runs FK.
|
||||
# Per-frame global skel state → bone locals via parent-inverse. Stored
|
||||
# output by default; fallback re-runs FK.
|
||||
if use_stored_global_rots:
|
||||
rig_global_m = global_skel_state_from_pose_data(
|
||||
pose_data, frame_indices, person_k, NJ,
|
||||
@ -437,11 +410,9 @@ def build_glb_skeletal(
|
||||
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
|
||||
rig_global_m = rig_global_cm.copy().astype(np.float32)
|
||||
rig_global_m[..., :3] *= 0.01
|
||||
# Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
|
||||
# Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
|
||||
# only fix locals, the parent's flip propagates into the child's
|
||||
# local translation (t_local inherits parent sign via q_parent_inv)
|
||||
# and produces visible "axis resets" mid-animation.
|
||||
# Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip
|
||||
# would otherwise propagate into the child's local translation and cause
|
||||
# visible "axis resets" mid-animation.
|
||||
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
|
||||
bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
|
||||
local_t = bone_local_anim[..., :3].astype(np.float32)
|
||||
@ -449,20 +420,17 @@ def build_glb_skeletal(
|
||||
local_s = bone_local_anim[..., 7].astype(np.float32)
|
||||
# Second pass on locals catches residual drift from the parent-inverse.
|
||||
local_q = quat_sign_fix_per_joint(local_q)
|
||||
# Hemisphere-align frame 0 with the bind quat so pause/play takes the
|
||||
# short path; then re-propagate.
|
||||
# Align frame 0 with the bind quat so pause/play takes the short path.
|
||||
bind_q = bind_local[:, 3:7].astype(np.float32)
|
||||
if local_q.shape[0] > 0:
|
||||
d0 = (bind_q * local_q[0]).sum(axis=-1)
|
||||
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
|
||||
local_q[0] = local_q[0] * sign0
|
||||
local_q = quat_sign_fix_per_joint(local_q)
|
||||
# Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
|
||||
# at handstand) that the upstream Smooth node may not catch.
|
||||
# Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand).
|
||||
if bone_smooth_window and bone_smooth_window > 1:
|
||||
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
|
||||
# fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
|
||||
# drift into visible flips otherwise.
|
||||
# fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift.
|
||||
lq64 = local_q.astype(np.float64)
|
||||
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
|
||||
local_q = lq64.astype(np.float32)
|
||||
@ -527,7 +495,7 @@ def build_glb_skeletal(
|
||||
"target": {"node": person_root_idx, "path": "translation"},
|
||||
})
|
||||
|
||||
# Body-mesh-only: bone-vis primitives have no morph targets.
|
||||
# Body mesh only — bone-vis primitives have no morph targets.
|
||||
if expr_morph_accs and body_mesh_node_idx is not None:
|
||||
expr_per_frame = np.stack([
|
||||
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
|
||||
|
||||
@ -34,8 +34,7 @@ def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
|
||||
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
||||
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
|
||||
resolution. Returns (per_frame_bboxes, per_frame_masks) or
|
||||
(None, None) when the track is empty / frame count doesn't match"""
|
||||
resolution. Returns (None, None) on empty track / frame-count mismatch."""
|
||||
|
||||
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
||||
if packed is None:
|
||||
@ -100,7 +99,7 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to
|
||||
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
|
||||
H: int, W: int) -> Dict[str, Any]:
|
||||
"""Re-project every frame's pose through a Load3D 6DOF camera (position/
|
||||
target/zoom + optional FOV). Returns a new mhr_pose_data; unchanged on
|
||||
target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on
|
||||
empty/invalid input."""
|
||||
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
|
||||
if not first_frame:
|
||||
@ -158,16 +157,16 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
|
||||
y_axis = np.cross(z_axis, x_axis)
|
||||
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
|
||||
|
||||
# Eye: dolly along the given offset; for a rotation-only camera (position ==
|
||||
# target) keep the predicted viewing distance so only orientation/roll changes.
|
||||
# Eye: dolly along the offset; rotation-only camera keeps the predicted
|
||||
# viewing distance so only orientation/roll changes.
|
||||
if has_offset:
|
||||
eye = target + offset / max(0.01, zoom)
|
||||
else:
|
||||
d = max(0.1, float(target[2]))
|
||||
eye = target - z_axis * (d / max(0.01, zoom))
|
||||
|
||||
# Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint-
|
||||
# only change). Three.js fov is vertical → focal from image height.
|
||||
# Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov
|
||||
# is vertical → focal from image height.
|
||||
cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
|
||||
if cam_fov > 0:
|
||||
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
|
||||
@ -178,10 +177,8 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
|
||||
|
||||
center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
|
||||
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
|
||||
# External rigs (e.g. Kimodo) store pred_joint_coords rig-native Y-up; the
|
||||
# render openpose/scail keypoint provider resolves from them and flips Y/Z.
|
||||
# Transform them through the camera too (in camera space, then back to Y-up)
|
||||
# so those keypoints follow the override instead of staying in the old frame.
|
||||
# External rigs store pred_joint_coords Y-up; transform them through the
|
||||
# camera too (in camera space, then back to Y-up) so they follow the override.
|
||||
override = mhr_pose_data.get("_skeleton_override")
|
||||
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
|
||||
new_frames: List[List[Dict[str, Any]]] = []
|
||||
@ -242,8 +239,7 @@ def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], p
|
||||
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
|
||||
|
||||
if per_frame_masks is not None:
|
||||
# Broadcast a single-mask bundle to per-bbox: when the user supplied one
|
||||
# mask but multiple bboxes per frame, each bbox gets the same mask.
|
||||
# One mask but multiple bboxes per frame → each bbox gets the same mask.
|
||||
flat_masks = []
|
||||
for f in range(N):
|
||||
mf = per_frame_masks[f]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user