mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Big cleanup
This commit is contained in:
parent
f1be65f914
commit
ecbaefd8fc
@ -4,25 +4,15 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
|
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
|
||||||
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
|
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask
|
||||||
|
|
||||||
from ..model.transformer import MLP
|
from ..model.transformer import MLP
|
||||||
|
|
||||||
|
|
||||||
class MHRHead(nn.Module):
|
class MHRHead(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False,
|
||||||
self,
|
device=None, dtype=None, operations=None):
|
||||||
input_dim: int,
|
|
||||||
mhr_rig,
|
|
||||||
mlp_depth: int = 1,
|
|
||||||
extra_joint_regressor: str = "",
|
|
||||||
mlp_channel_div_factor: int = 8,
|
|
||||||
enable_hand_model=False,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
operations=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Store the shared MHRRig as a non-registered Python attribute
|
# Store the shared MHRRig as a non-registered Python attribute
|
||||||
object.__setattr__(self, "mhr", mhr_rig)
|
object.__setattr__(self, "mhr", mhr_rig)
|
||||||
@ -48,9 +38,7 @@ class MHRHead(nn.Module):
|
|||||||
hidden_dim=input_dim // mlp_channel_div_factor,
|
hidden_dim=input_dim // mlp_channel_div_factor,
|
||||||
output_dim=self.npose,
|
output_dim=self.npose,
|
||||||
num_layers=mlp_depth,
|
num_layers=mlp_depth,
|
||||||
device=device,
|
device=device, dtype=dtype, operations=operations,
|
||||||
dtype=dtype,
|
|
||||||
operations=operations,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# MHR Parameters
|
# MHR Parameters
|
||||||
@ -75,28 +63,25 @@ class MHRHead(nn.Module):
|
|||||||
self.local_to_world_wrist = _p(3, 3)
|
self.local_to_world_wrist = _p(3, 3)
|
||||||
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
|
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
|
||||||
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
|
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
|
||||||
# Optional — loaded from the .safetensors if present, otherwise the
|
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
|
||||||
# render path falls back to a coarse geometric approximation.
|
|
||||||
self.register_buffer(
|
|
||||||
"face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
|
|
||||||
)
|
|
||||||
|
|
||||||
def canonical_vertices(self, device=None):
|
def canonical_vertices(self):
|
||||||
"""Return the T-pose vertices for the mean shape (scaled to meters).
|
"""Return the T-pose vertices for the mean shape (scaled to meters).
|
||||||
|
|
||||||
Runs MHR with zero pose / shape / scale / expression so the returned
|
Runs MHR with zero pose / shape / scale / expression so the returned
|
||||||
mesh is the canonical rest pose — fixed per-model
|
mesh is the canonical rest pose — fixed per-model
|
||||||
"""
|
"""
|
||||||
dev = device or self.scale_mean.device
|
device = self.scale_mean.device
|
||||||
dt = self.scale_mean.dtype
|
dtype = self.scale_mean.dtype
|
||||||
B = 1
|
B = 1
|
||||||
global_trans = torch.zeros(B, 3, device=dev, dtype=dt)
|
global_trans = torch.zeros(B, 3, device=device, dtype=dtype)
|
||||||
global_rot = torch.zeros(B, 3, device=dev, dtype=dt)
|
global_rot = torch.zeros(B, 3, device=device, dtype=dtype)
|
||||||
body_pose = torch.zeros(B, 130, device=dev, dtype=dt)
|
body_pose = torch.zeros(B, 130, device=device, dtype=dtype)
|
||||||
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt)
|
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype)
|
||||||
scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt)
|
scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype)
|
||||||
shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt)
|
shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype)
|
||||||
expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt)
|
expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype)
|
||||||
|
|
||||||
verts = self.mhr_forward(
|
verts = self.mhr_forward(
|
||||||
global_trans=global_trans,
|
global_trans=global_trans,
|
||||||
global_rot=global_rot,
|
global_rot=global_rot,
|
||||||
@ -108,20 +93,6 @@ class MHRHead(nn.Module):
|
|||||||
) # single-tensor shape (1, N_v, 3) in meters
|
) # single-tensor shape (1, N_v, 3) in meters
|
||||||
return verts[0]
|
return verts[0]
|
||||||
|
|
||||||
def get_zero_pose_init(self, factor=1.0):
|
|
||||||
# Initialize pose token with zero-initialized learnable params
|
|
||||||
# Note: bias/initial value should be zero-pose in cont, not all-zeros
|
|
||||||
weights = torch.zeros(1, self.npose)
|
|
||||||
weights[:, : 6 + self.body_cont_dim] = torch.cat(
|
|
||||||
[
|
|
||||||
torch.FloatTensor([1, 0, 0, 0, 1, 0]),
|
|
||||||
compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
|
|
||||||
* factor,
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
return weights
|
|
||||||
|
|
||||||
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
|
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
|
||||||
assert full_pose_params.shape[1] == 136
|
assert full_pose_params.shape[1] == 136
|
||||||
|
|
||||||
@ -159,12 +130,9 @@ class MHRHead(nn.Module):
|
|||||||
shape_params,
|
shape_params,
|
||||||
expr_params=None,
|
expr_params=None,
|
||||||
return_keypoints=False,
|
return_keypoints=False,
|
||||||
do_pcblend=True,
|
|
||||||
return_joint_coords=False,
|
return_joint_coords=False,
|
||||||
return_model_params=False,
|
return_model_params=False,
|
||||||
return_joint_rotations=False,
|
return_joint_rotations=False,
|
||||||
scale_offsets=None,
|
|
||||||
vertex_offsets=None,
|
|
||||||
):
|
):
|
||||||
# Align everything to the static buffers
|
# Align everything to the static buffers
|
||||||
dt = self.scale_mean.dtype
|
dt = self.scale_mean.dtype
|
||||||
@ -206,14 +174,10 @@ class MHRHead(nn.Module):
|
|||||||
shape_params = shape_params[None]
|
shape_params = shape_params[None]
|
||||||
# Convert scale...
|
# Convert scale...
|
||||||
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
|
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
|
||||||
if scale_offsets is not None:
|
|
||||||
scales = scales + scale_offsets
|
|
||||||
|
|
||||||
# Now, figure out the pose.
|
# Now, figure out the pose.
|
||||||
## 10 here is because it's more stable to optimize global translation in meters.
|
## 10 here is because it's more stable to optimize global translation in meters.
|
||||||
full_pose_params = torch.cat(
|
full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127
|
||||||
[global_trans * 10, global_rot, body_pose_params], dim=1
|
|
||||||
) # B x 127
|
|
||||||
## Put in hands
|
## Put in hands
|
||||||
if hand_pose_params is not None:
|
if hand_pose_params is not None:
|
||||||
full_pose_params = self.replace_hands_in_pose(
|
full_pose_params = self.replace_hands_in_pose(
|
||||||
@ -268,14 +232,7 @@ class MHRHead(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return tuple(to_return)
|
return tuple(to_return)
|
||||||
|
|
||||||
def forward(
|
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False):
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
init_estimate: Optional[torch.Tensor] = None,
|
|
||||||
do_pcblend=True,
|
|
||||||
slim_keypoints=False,
|
|
||||||
intermediate: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: pose token with shape [B, C], usually C=DECODER.DIM
|
x: pose token with shape [B, C], usually C=DECODER.DIM
|
||||||
@ -331,7 +288,6 @@ class MHRHead(nn.Module):
|
|||||||
scale_params=pred_scale,
|
scale_params=pred_scale,
|
||||||
shape_params=pred_shape,
|
shape_params=pred_shape,
|
||||||
expr_params=pred_face,
|
expr_params=pred_face,
|
||||||
do_pcblend=do_pcblend,
|
|
||||||
return_keypoints=True,
|
return_keypoints=True,
|
||||||
return_joint_coords=True,
|
return_joint_coords=True,
|
||||||
return_model_params=True,
|
return_model_params=True,
|
||||||
@ -356,7 +312,7 @@ class MHRHead(nn.Module):
|
|||||||
# Head-MLP outputs are promoted to fp32 here so the external
|
# Head-MLP outputs are promoted to fp32 here so the external
|
||||||
# pose_output["mhr"] contract has a stable dtype regardless of what
|
# pose_output["mhr"] contract has a stable dtype regardless of what
|
||||||
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
|
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
|
||||||
# already fp32 from MHR's math layers; the cast on them is a no-op.
|
# already fp32 from MHR's math layers.
|
||||||
output = {
|
output = {
|
||||||
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
|
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
|
||||||
"pred_pose_rotmat": None,
|
"pred_pose_rotmat": None,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# Adapted from facebookresearch/MHR (Apache 2.0):
|
# Adapted from facebookresearch/MHR (Apache 2.0):
|
||||||
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
|
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
|
||||||
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
|
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
|
||||||
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
|
# verbatim from the upstream mhr_model.pt
|
||||||
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
|
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
|
||||||
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ def _skel_multiply(s1, s2):
|
|||||||
|
|
||||||
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
|
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
|
||||||
before composition. With many FK levels the previously-normalized quats
|
before composition. With many FK levels the previously-normalized quats
|
||||||
drift in ULPs; the JIT renormalizes defensively, so we do too to stay
|
drift in ULPs; upstream renormalizes defensively, so we do too to stay
|
||||||
bit-close to its outputs.
|
bit-close to its outputs.
|
||||||
"""
|
"""
|
||||||
t1, sc1 = s1[..., :3], s1[..., 7:8]
|
t1, sc1 = s1[..., :3], s1[..., 7:8]
|
||||||
@ -78,7 +78,7 @@ def _skel_transform_points(skel_state, points):
|
|||||||
|
|
||||||
|
|
||||||
def _global_skel_state_from_local(local, pmi_levels):
|
def _global_skel_state_from_local(local, pmi_levels):
|
||||||
"""FK walk in fp64 (matches the JIT's use_double_precision=True path).
|
"""FK walk in fp64 (matches upstream's use_double_precision=True path).
|
||||||
|
|
||||||
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
|
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
|
||||||
one per BFS level. Avoids per-call torch.split + tolist() sync.
|
one per BFS level. Avoids per-call torch.split + tolist() sync.
|
||||||
@ -95,7 +95,7 @@ def _global_skel_state_from_local(local, pmi_levels):
|
|||||||
class MHRRig(nn.Module):
|
class MHRRig(nn.Module):
|
||||||
"""Plain-PyTorch reimplementation of Meta's MHR rig.
|
"""Plain-PyTorch reimplementation of Meta's MHR rig.
|
||||||
|
|
||||||
All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's
|
All math runs in fp32 (FK upcast to fp64 internally, matching upstream's
|
||||||
use_double_precision=True backend) regardless of the host model's dtype.
|
use_double_precision=True backend) regardless of the host model's dtype.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -110,13 +110,11 @@ class MHRRig(nn.Module):
|
|||||||
POSE_CORR_HIDDEN = 3000
|
POSE_CORR_HIDDEN = 3000
|
||||||
POSE_CORR_SPARSE_NNZ = 53136
|
POSE_CORR_SPARSE_NNZ = 53136
|
||||||
|
|
||||||
def __init__(self, device=None, dtype=None, operations=None):
|
def __init__(self, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
del dtype, operations
|
|
||||||
f32 = torch.float32
|
|
||||||
|
|
||||||
# All buffers are populated by load_state_dict from the `mhr.*` keys
|
# All buffers are populated by load_state_dict from the `mhr.*` keys
|
||||||
def _p(*shape, dtype=f32):
|
def _p(*shape, dtype=torch.float32):
|
||||||
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
|
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
|
||||||
def _b(name, *shape, dtype):
|
def _b(name, *shape, dtype):
|
||||||
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
|
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
|
||||||
@ -147,10 +145,10 @@ class MHRRig(nn.Module):
|
|||||||
self._pmi_levels_cache = None
|
self._pmi_levels_cache = None
|
||||||
|
|
||||||
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
|
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
|
||||||
f32 = self.base_shape.dtype
|
dtype = self.base_shape.dtype
|
||||||
identity_coeffs = identity_coeffs.to(f32)
|
identity_coeffs = identity_coeffs.to(dtype)
|
||||||
model_parameters = model_parameters.to(f32)
|
model_parameters = model_parameters.to(dtype)
|
||||||
expr_coeffs = expr_coeffs.to(f32)
|
expr_coeffs = expr_coeffs.to(dtype)
|
||||||
B = identity_coeffs.shape[0]
|
B = identity_coeffs.shape[0]
|
||||||
|
|
||||||
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
|
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
|
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
|
||||||
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
|
# (batch6DFromXYZ, batchXYZfrom6D) are the continuity
|
||||||
# representation from Zhou et al., "On the Continuity of Rotation
|
# representation from Zhou et al., "On the Continuity of Rotation
|
||||||
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
|
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
|
||||||
# implementations from papagina/RotationContinuity:
|
# implementations from papagina/RotationContinuity:
|
||||||
@ -158,18 +158,10 @@ def _hand_masks(device):
|
|||||||
m = _HAND_MASK_CACHE.get(device)
|
m = _HAND_MASK_CACHE.get(device)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
return m
|
return m
|
||||||
mask_cont_threedofs = torch.cat(
|
mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
|
||||||
[torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
|
mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
|
||||||
).to(device)
|
mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
|
||||||
mask_cont_onedofs = torch.cat(
|
mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
|
||||||
[torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
|
|
||||||
).to(device)
|
|
||||||
mask_model_params_threedofs = torch.cat(
|
|
||||||
[torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
|
|
||||||
).to(device)
|
|
||||||
mask_model_params_onedofs = torch.cat(
|
|
||||||
[torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
|
|
||||||
).to(device)
|
|
||||||
m = dict(
|
m = dict(
|
||||||
mask_cont_threedofs=mask_cont_threedofs,
|
mask_cont_threedofs=mask_cont_threedofs,
|
||||||
mask_cont_onedofs=mask_cont_onedofs,
|
mask_cont_onedofs=mask_cont_onedofs,
|
||||||
@ -182,7 +174,6 @@ def _hand_masks(device):
|
|||||||
|
|
||||||
def compact_cont_to_model_params_hand(hand_cont):
|
def compact_cont_to_model_params_hand(hand_cont):
|
||||||
# These are ordered by joint, not model params ^^
|
# These are ordered by joint, not model params ^^
|
||||||
assert hand_cont.shape[-1] == 54
|
|
||||||
m = _hand_masks(hand_cont.device)
|
m = _hand_masks(hand_cont.device)
|
||||||
mask_cont_threedofs = m["mask_cont_threedofs"]
|
mask_cont_threedofs = m["mask_cont_threedofs"]
|
||||||
mask_cont_onedofs = m["mask_cont_onedofs"]
|
mask_cont_onedofs = m["mask_cont_onedofs"]
|
||||||
@ -209,120 +200,6 @@ def compact_cont_to_model_params_hand(hand_cont):
|
|||||||
return hand_model_params
|
return hand_model_params
|
||||||
|
|
||||||
|
|
||||||
def compact_model_params_to_cont_hand(hand_model_params):
|
|
||||||
# These are ordered by joint, not model params ^^
|
|
||||||
assert hand_model_params.shape[-1] == 27
|
|
||||||
hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
|
|
||||||
assert sum(hand_dofs_in_order) == 27
|
|
||||||
# Mask of 3DoFs into hand_cont
|
|
||||||
mask_cont_threedofs = torch.cat(
|
|
||||||
[torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
|
||||||
)
|
|
||||||
# Mask of 1DoFs (including 2DoF) into hand_cont
|
|
||||||
mask_cont_onedofs = torch.cat(
|
|
||||||
[torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
|
||||||
)
|
|
||||||
# Mask of 3DoFs into hand_model_params
|
|
||||||
mask_model_params_threedofs = torch.cat(
|
|
||||||
[torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
|
|
||||||
)
|
|
||||||
# Mask of 1DoFs (including 2DoF) into hand_model_params
|
|
||||||
mask_model_params_onedofs = torch.cat(
|
|
||||||
[torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert Euler angles to hand_cont
|
|
||||||
## First for 3DoFs
|
|
||||||
hand_model_params_threedofs = hand_model_params[
|
|
||||||
..., mask_model_params_threedofs
|
|
||||||
].unflatten(-1, (-1, 3))
|
|
||||||
hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)
|
|
||||||
## Next for 1DoFs
|
|
||||||
hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
|
|
||||||
hand_cont_onedofs = torch.stack(
|
|
||||||
[hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
|
|
||||||
).flatten(-2, -1)
|
|
||||||
|
|
||||||
# Finally, assemble into a 54-dim continuous vector, ordered by joint (6D per 3-DoF joint, sin/cos per 1-DoF value).
|
|
||||||
hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params)
|
|
||||||
hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
|
|
||||||
hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
|
|
||||||
|
|
||||||
return hand_cont
|
|
||||||
|
|
||||||
|
|
||||||
def batch9Dfrom6D(poses):
|
|
||||||
# Args: poses: ... x 6, where "6" is the combined first and second columns
|
|
||||||
# First, get the rotation matrix
|
|
||||||
x_raw = poses[..., :3]
|
|
||||||
y_raw = poses[..., 3:]
|
|
||||||
|
|
||||||
x = F.normalize(x_raw, dim=-1)
|
|
||||||
z = torch.cross(x, y_raw, dim=-1)
|
|
||||||
z = F.normalize(z, dim=-1)
|
|
||||||
y = torch.cross(z, x, dim=-1)
|
|
||||||
|
|
||||||
matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9
|
|
||||||
|
|
||||||
return matrix
|
|
||||||
|
|
||||||
|
|
||||||
def batch4Dfrom2D(poses):
|
|
||||||
# Args: poses: ... x 2, where "2" is sincos
|
|
||||||
poses_norm = F.normalize(poses, dim=-1)
|
|
||||||
|
|
||||||
poses_4d = torch.stack(
|
|
||||||
[
|
|
||||||
poses_norm[..., 1],
|
|
||||||
poses_norm[..., 0],
|
|
||||||
-poses_norm[..., 0],
|
|
||||||
poses_norm[..., 1],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
) # Flattened SO2.
|
|
||||||
|
|
||||||
return poses_4d # .... x 4
|
|
||||||
|
|
||||||
|
|
||||||
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
|
|
||||||
# fmt: off
|
|
||||||
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
|
|
||||||
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
|
|
||||||
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
|
|
||||||
# fmt: on
|
|
||||||
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
|
||||||
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
|
||||||
num_1dof_trans = len(all_param_1dof_trans_idxs)
|
|
||||||
assert body_pose_cont.shape[-1] == (
|
|
||||||
2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
|
|
||||||
)
|
|
||||||
# Get subsets
|
|
||||||
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
|
|
||||||
body_cont_1dofs = body_pose_cont[
|
|
||||||
..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
|
|
||||||
]
|
|
||||||
body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
|
|
||||||
# Convert conts to model params
|
|
||||||
## First for 3dofs
|
|
||||||
body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
|
|
||||||
body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1)
|
|
||||||
## Next for 1dofs
|
|
||||||
body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
|
|
||||||
body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1)
|
|
||||||
if inflate_trans:
|
|
||||||
assert (
|
|
||||||
False
|
|
||||||
), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
|
|
||||||
else:
|
|
||||||
## Nothing to do for trans
|
|
||||||
body_rotmat_trans = body_cont_trans
|
|
||||||
# Put them together
|
|
||||||
body_rotmat_params = torch.cat(
|
|
||||||
[body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1
|
|
||||||
)
|
|
||||||
return body_rotmat_params
|
|
||||||
|
|
||||||
|
|
||||||
_BODY_IDX_CACHE: dict = {}
|
_BODY_IDX_CACHE: dict = {}
|
||||||
|
|
||||||
|
|
||||||
@ -349,8 +226,6 @@ def compact_cont_to_model_params_body(body_pose_cont):
|
|||||||
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
|
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
|
||||||
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
||||||
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
||||||
num_1dof_trans = len(all_param_1dof_trans_idxs)
|
|
||||||
assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
|
|
||||||
# Get subsets
|
# Get subsets
|
||||||
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
|
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
|
||||||
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
|
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
|
||||||
@ -372,42 +247,10 @@ def compact_cont_to_model_params_body(body_pose_cont):
|
|||||||
return body_pose_params
|
return body_pose_params
|
||||||
|
|
||||||
|
|
||||||
def compact_model_params_to_cont_body(body_pose_params):
|
# Hand indices into the 133-dim param and 260-dim cont body-pose vectors.
|
||||||
# fmt: off
|
mhr_param_hand_idxs = list(range(62, 116))
|
||||||
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
|
mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))
|
||||||
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
|
|
||||||
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
|
|
||||||
# fmt: on
|
|
||||||
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
|
|
||||||
num_1dof_angles = len(all_param_1dof_rot_idxs)
|
|
||||||
num_1dof_trans = len(all_param_1dof_trans_idxs)
|
|
||||||
assert body_pose_params.shape[-1] == (
|
|
||||||
num_3dof_angles + num_1dof_angles + num_1dof_trans
|
|
||||||
)
|
|
||||||
# Take out params
|
|
||||||
body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()]
|
|
||||||
body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs]
|
|
||||||
body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs]
|
|
||||||
# params to cont
|
|
||||||
body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
|
|
||||||
-2, -1
|
|
||||||
)
|
|
||||||
body_cont_1dofs = torch.stack(
|
|
||||||
[body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
|
|
||||||
).flatten(-2, -1)
|
|
||||||
body_cont_trans = body_params_trans
|
|
||||||
# Put them together
|
|
||||||
body_pose_cont = torch.cat(
|
|
||||||
[body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1
|
|
||||||
)
|
|
||||||
return body_pose_cont
|
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
|
|
||||||
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
|
|
||||||
mhr_param_hand_mask = torch.zeros(133).bool()
|
mhr_param_hand_mask = torch.zeros(133).bool()
|
||||||
mhr_param_hand_mask[mhr_param_hand_idxs] = True
|
mhr_param_hand_mask[mhr_param_hand_idxs] = True
|
||||||
mhr_cont_hand_mask = torch.zeros(260).bool()
|
mhr_cont_hand_mask = torch.zeros(260).bool()
|
||||||
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
|
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
|
||||||
# fmt: on
|
|
||||||
|
|||||||
@ -43,15 +43,6 @@ class FourierPositionEncoding(nn.Module):
|
|||||||
self.num_bands = num_bands
|
self.num_bands = num_bands
|
||||||
self.max_resolution = [max_resolution] * n
|
self.max_resolution = [max_resolution] * n
|
||||||
|
|
||||||
@property
|
|
||||||
def channels(self):
|
|
||||||
num_dims = len(self.max_resolution)
|
|
||||||
encoding_size = self.num_bands * num_dims
|
|
||||||
encoding_size *= 2 # sin-cos
|
|
||||||
encoding_size += num_dims # concat
|
|
||||||
|
|
||||||
return encoding_size
|
|
||||||
|
|
||||||
def forward(self, pos: torch.Tensor):
|
def forward(self, pos: torch.Tensor):
|
||||||
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
|
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
|
||||||
return fourier_pos_enc
|
return fourier_pos_enc
|
||||||
@ -118,9 +109,7 @@ class PerspectiveHead(nn.Module):
|
|||||||
pred_cam: torch.Tensor,
|
pred_cam: torch.Tensor,
|
||||||
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
|
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
|
||||||
bbox_size: torch.Tensor, # [N,], in original image space
|
bbox_size: torch.Tensor, # [N,], in original image space
|
||||||
img_size: torch.Tensor,
|
|
||||||
cam_int: torch.Tensor, # [B, 3, 3]
|
cam_int: torch.Tensor, # [B, 3, 3]
|
||||||
use_intrin_center: bool = False,
|
|
||||||
):
|
):
|
||||||
batch_size = points_3d.shape[0]
|
batch_size = points_3d.shape[0]
|
||||||
pred_cam = pred_cam.clone()
|
pred_cam = pred_cam.clone()
|
||||||
@ -133,12 +122,8 @@ class PerspectiveHead(nn.Module):
|
|||||||
focal_length = cam_int[:, 0, 0]
|
focal_length = cam_int[:, 0, 0]
|
||||||
tz = 2 * focal_length / bs
|
tz = 2 * focal_length / bs
|
||||||
|
|
||||||
if not use_intrin_center:
|
cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs
|
||||||
cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
|
cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs
|
||||||
cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
|
|
||||||
else:
|
|
||||||
cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
|
|
||||||
cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
|
|
||||||
|
|
||||||
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
|
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
|
||||||
|
|
||||||
|
|||||||
@ -37,20 +37,15 @@ class SAM3DBody(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, device=None, dtype=None, operations=None):
|
def __init__(self, device=None, dtype=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# `operations` falls back to torch.nn so the model is constructible
|
|
||||||
# without comfy.ops; matches the pattern in comfy/ldm/sam3/.
|
|
||||||
ops = operations if operations is not None else nn
|
|
||||||
|
|
||||||
# Per-batch state populated by `_initialize_batch`.
|
# Per-batch state populated by `_initialize_batch`.
|
||||||
self._max_num_person = None
|
self._max_num_person = None
|
||||||
self._person_valid = None
|
|
||||||
|
|
||||||
self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False)
|
self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False)
|
||||||
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
|
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
|
||||||
|
|
||||||
self.image_size = IMAGE_SIZE
|
self.image_size = IMAGE_SIZE
|
||||||
|
|
||||||
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops)
|
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
|
||||||
embed_dims = self.backbone.embed_dims
|
embed_dims = self.backbone.embed_dims
|
||||||
|
|
||||||
# MHR rig shared between body + hand pose heads via a non-registered
|
# MHR rig shared between body + hand pose heads via a non-registered
|
||||||
@ -72,7 +67,7 @@ class SAM3DBody(nn.Module):
|
|||||||
self.head_pose.hand_pose_comps.data = (
|
self.head_pose.hand_pose_comps.data = (
|
||||||
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
||||||
)
|
)
|
||||||
self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
|
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
|
||||||
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
|
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
|
||||||
@ -81,7 +76,7 @@ class SAM3DBody(nn.Module):
|
|||||||
self.head_pose_hand.hand_pose_comps.data = (
|
self.head_pose_hand.hand_pose_comps.data = (
|
||||||
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
||||||
)
|
)
|
||||||
self.init_pose_hand = ops.Embedding(
|
self.init_pose_hand = operations.Embedding(
|
||||||
1, self.head_pose_hand.npose, device=device, dtype=dtype
|
1, self.head_pose_hand.npose, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -93,25 +88,25 @@ class SAM3DBody(nn.Module):
|
|||||||
device=device, dtype=dtype, operations=operations,
|
device=device, dtype=dtype, operations=operations,
|
||||||
)
|
)
|
||||||
self.head_camera = PerspectiveHead(**camera_kwargs)
|
self.head_camera = PerspectiveHead(**camera_kwargs)
|
||||||
self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
|
self.init_camera = operations.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs)
|
self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs)
|
||||||
self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
|
self.init_camera_hand = operations.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype)
|
||||||
|
|
||||||
cond_dim = 3
|
cond_dim = 3
|
||||||
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
|
init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim
|
||||||
linear_kwargs = dict(device=device, dtype=dtype)
|
linear_kwargs = dict(device=device, dtype=dtype)
|
||||||
self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
self.init_to_token_mhr = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||||
self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
self.prev_to_token_mhr = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||||
self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
self.init_to_token_mhr_hand = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs)
|
||||||
self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
self.prev_to_token_mhr_hand = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs)
|
||||||
|
|
||||||
self.prompt_encoder = PromptEncoder(
|
self.prompt_encoder = PromptEncoder(
|
||||||
embed_dim=embed_dims, # match backbone dims so PE adds directly
|
embed_dim=embed_dims, # match backbone dims so PE adds directly
|
||||||
num_body_joints=N_KEYPOINTS,
|
num_body_joints=N_KEYPOINTS,
|
||||||
device=device, dtype=dtype, operations=operations,
|
device=device, dtype=dtype, operations=operations,
|
||||||
)
|
)
|
||||||
self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
self.prompt_to_token = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||||
|
|
||||||
decoder_kwargs = dict(
|
decoder_kwargs = dict(
|
||||||
dims=DECODER_DIM,
|
dims=DECODER_DIM,
|
||||||
@ -141,11 +136,10 @@ class SAM3DBody(nn.Module):
|
|||||||
|
|
||||||
self.keypoint_embedding_idxs = list(range(N_KEYPOINTS))
|
self.keypoint_embedding_idxs = list(range(N_KEYPOINTS))
|
||||||
self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
||||||
self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
self.keypoint_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||||
self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
self.keypoint_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||||
|
|
||||||
self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs)
|
self.hand_box_embedding = operations.Embedding(2, DECODER_DIM, **linear_kwargs)
|
||||||
self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs)
|
|
||||||
self.bbox_embed = MLP(
|
self.bbox_embed = MLP(
|
||||||
input_dim=DECODER_DIM, hidden_dim=DECODER_DIM,
|
input_dim=DECODER_DIM, hidden_dim=DECODER_DIM,
|
||||||
output_dim=4, num_layers=3,
|
output_dim=4, num_layers=3,
|
||||||
@ -158,13 +152,13 @@ class SAM3DBody(nn.Module):
|
|||||||
)
|
)
|
||||||
self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs)
|
self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs)
|
||||||
self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs)
|
self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs)
|
||||||
self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
self.keypoint_feat_linear = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||||
self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
self.keypoint_feat_linear_hand = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs)
|
||||||
|
|
||||||
self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS))
|
self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS))
|
||||||
self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS))
|
||||||
self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
self.keypoint3d_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||||
self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
self.keypoint3d_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs)
|
||||||
self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs)
|
self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs)
|
||||||
self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs)
|
self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs)
|
||||||
|
|
||||||
@ -183,11 +177,9 @@ class SAM3DBody(nn.Module):
|
|||||||
def _initialize_batch(self, batch: Dict) -> None:
|
def _initialize_batch(self, batch: Dict) -> None:
|
||||||
if batch["img"].dim() == 5:
|
if batch["img"].dim() == 5:
|
||||||
self._batch_size, self._max_num_person = batch["img"].shape[:2]
|
self._batch_size, self._max_num_person = batch["img"].shape[:2]
|
||||||
self._person_valid = self._flatten_person(batch["person_valid"]) > 0
|
|
||||||
else:
|
else:
|
||||||
self._batch_size = batch["img"].shape[0]
|
self._batch_size = batch["img"].shape[0]
|
||||||
self._max_num_person = 0
|
self._max_num_person = 0
|
||||||
self._person_valid = None
|
|
||||||
|
|
||||||
def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
|
def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
assert self._max_num_person is not None, "No max_num_person initialized"
|
assert self._max_num_person is not None, "No max_num_person initialized"
|
||||||
@ -258,11 +250,9 @@ class SAM3DBody(nn.Module):
|
|||||||
if is_multi_image:
|
if is_multi_image:
|
||||||
assert isinstance(img, list)
|
assert isinstance(img, list)
|
||||||
n = len(img)
|
n = len(img)
|
||||||
H_src, W_src = img[0].shape[:2]
|
|
||||||
src_t = torch.stack(list(img), dim=0)
|
src_t = torch.stack(list(img), dim=0)
|
||||||
else:
|
else:
|
||||||
n = int(left_xyxy.shape[0])
|
n = int(left_xyxy.shape[0])
|
||||||
H_src, W_src = img.shape[:2]
|
|
||||||
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
|
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
|
||||||
|
|
||||||
H_out, W_out = int(self.image_size[0]), int(self.image_size[1])
|
H_out, W_out = int(self.image_size[0]), int(self.image_size[1])
|
||||||
@ -292,14 +282,12 @@ class SAM3DBody(nn.Module):
|
|||||||
zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device)
|
zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device)
|
||||||
person_valid = torch.ones((1, n), dtype=torch.float32, device=device)
|
person_valid = torch.ones((1, n), dtype=torch.float32, device=device)
|
||||||
img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous()
|
img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous()
|
||||||
ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous()
|
|
||||||
cam_int_dev = cam_int.to(device).to(dtype=torch.float32)
|
cam_int_dev = cam_int.to(device).to(dtype=torch.float32)
|
||||||
|
|
||||||
def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy):
|
def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy):
|
||||||
return {
|
return {
|
||||||
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
||||||
"img_size": img_size.unsqueeze(0),
|
"img_size": img_size.unsqueeze(0),
|
||||||
"ori_img_size": ori_img_size.unsqueeze(0),
|
|
||||||
"bbox_center": centers_t.to(device).unsqueeze(0),
|
"bbox_center": centers_t.to(device).unsqueeze(0),
|
||||||
"bbox_scale": scales_t.to(device).unsqueeze(0),
|
"bbox_scale": scales_t.to(device).unsqueeze(0),
|
||||||
"bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0),
|
"bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0),
|
||||||
@ -349,7 +337,6 @@ class SAM3DBody(nn.Module):
|
|||||||
self,
|
self,
|
||||||
branch: str,
|
branch: str,
|
||||||
image_embeddings: torch.Tensor,
|
image_embeddings: torch.Tensor,
|
||||||
init_estimate: Optional[torch.Tensor] = None,
|
|
||||||
keypoints: Optional[torch.Tensor] = None,
|
keypoints: Optional[torch.Tensor] = None,
|
||||||
prev_estimate: Optional[torch.Tensor] = None,
|
prev_estimate: Optional[torch.Tensor] = None,
|
||||||
condition_info: Optional[torch.Tensor] = None,
|
condition_info: Optional[torch.Tensor] = None,
|
||||||
@ -359,7 +346,6 @@ class SAM3DBody(nn.Module):
|
|||||||
of the pipeline is shared.
|
of the pipeline is shared.
|
||||||
|
|
||||||
image_embeddings: (B, C, H, W) backbone features.
|
image_embeddings: (B, C, H, W) backbone features.
|
||||||
init_estimate: (B, 1, C) initial pose+cam estimate to refine.
|
|
||||||
keypoints: (B, N, 3) prompts as (x, y in [0, 1], label).
|
keypoints: (B, N, 3) prompts as (x, y in [0, 1], label).
|
||||||
label: 0..K = joint, -1 = incorrect, -2 = invalid.
|
label: 0..K = joint, -1 = incorrect, -2 = invalid.
|
||||||
prev_estimate: (B, 1, C) previous estimate for pose refinement.
|
prev_estimate: (B, 1, C) previous estimate for pose refinement.
|
||||||
@ -402,15 +388,11 @@ class SAM3DBody(nn.Module):
|
|||||||
|
|
||||||
# .to(image_embeddings) moves weights CPU→GPU under dynamic loading
|
# .to(image_embeddings) moves weights CPU→GPU under dynamic loading
|
||||||
# (they stay on CPU until first use).
|
# (they stay on CPU until first use).
|
||||||
if init_estimate is None:
|
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||||
init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
||||||
init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1)
|
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
|
||||||
init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3)
|
|
||||||
|
|
||||||
init_input = (
|
init_input = torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
||||||
torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1)
|
|
||||||
if condition_info is not None else init_estimate
|
|
||||||
)
|
|
||||||
token_embeddings = init_to_token(init_input).view(batch_size, 1, -1)
|
token_embeddings = init_to_token(init_input).view(batch_size, 1, -1)
|
||||||
num_pose_token = token_embeddings.shape[1] # always 1
|
num_pose_token = token_embeddings.shape[1] # always 1
|
||||||
|
|
||||||
@ -495,9 +477,8 @@ class SAM3DBody(nn.Module):
|
|||||||
|
|
||||||
def _get_mask_prompt(self, batch, image_embeddings):
|
def _get_mask_prompt(self, batch, image_embeddings):
|
||||||
x_mask = self._flatten_person(batch["mask"])
|
x_mask = self._flatten_person(batch["mask"])
|
||||||
# batch tensors are fp32 from prepare_batch; mask_downscaling is in the
|
|
||||||
# Loader's dtype — cast once so the conv input matches.
|
|
||||||
x_mask = x_mask.to(image_embeddings.dtype)
|
x_mask = x_mask.to(image_embeddings.dtype)
|
||||||
|
|
||||||
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
|
mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
|
||||||
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
|
x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
|
||||||
)
|
)
|
||||||
@ -546,7 +527,6 @@ class SAM3DBody(nn.Module):
|
|||||||
# expand+contiguous for the vertices branch.
|
# expand+contiguous for the vertices branch.
|
||||||
bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
|
bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
|
||||||
bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
|
bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
|
||||||
ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx]
|
|
||||||
cam_int = self._flatten_person(
|
cam_int = self._flatten_person(
|
||||||
batch["cam_int"]
|
batch["cam_int"]
|
||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
@ -556,8 +536,7 @@ class SAM3DBody(nn.Module):
|
|||||||
|
|
||||||
def _project(points_3d):
|
def _project(points_3d):
|
||||||
return head_camera.perspective_projection(
|
return head_camera.perspective_projection(
|
||||||
points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int,
|
points_3d, pred_cam, bbox_center, bbox_scale, cam_int,
|
||||||
use_intrin_center=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cam_out = _project(pose_output["pred_keypoints_3d"])
|
cam_out = _project(pose_output["pred_keypoints_3d"])
|
||||||
@ -632,7 +611,6 @@ class SAM3DBody(nn.Module):
|
|||||||
tokens_output, pose_output = self.forward_decoder(
|
tokens_output, pose_output = self.forward_decoder(
|
||||||
"body",
|
"body",
|
||||||
image_embeddings[self.body_batch_idx],
|
image_embeddings[self.body_batch_idx],
|
||||||
init_estimate=None,
|
|
||||||
keypoints=keypoints_prompt[self.body_batch_idx],
|
keypoints=keypoints_prompt[self.body_batch_idx],
|
||||||
prev_estimate=None,
|
prev_estimate=None,
|
||||||
condition_info=condition_info[self.body_batch_idx],
|
condition_info=condition_info[self.body_batch_idx],
|
||||||
@ -643,7 +621,6 @@ class SAM3DBody(nn.Module):
|
|||||||
tokens_output_hand, pose_output_hand = self.forward_decoder(
|
tokens_output_hand, pose_output_hand = self.forward_decoder(
|
||||||
"hand",
|
"hand",
|
||||||
image_embeddings[self.hand_batch_idx],
|
image_embeddings[self.hand_batch_idx],
|
||||||
init_estimate=None,
|
|
||||||
keypoints=keypoints_prompt[self.hand_batch_idx],
|
keypoints=keypoints_prompt[self.hand_batch_idx],
|
||||||
prev_estimate=None,
|
prev_estimate=None,
|
||||||
condition_info=condition_info[self.hand_batch_idx],
|
condition_info=condition_info[self.hand_batch_idx],
|
||||||
@ -661,10 +638,8 @@ class SAM3DBody(nn.Module):
|
|||||||
# match the head-MLP external contract (_get_hand_box would .float() anyway).
|
# match the head-MLP external contract (_get_hand_box would .float() anyway).
|
||||||
if len(self.body_batch_idx):
|
if len(self.body_batch_idx):
|
||||||
output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float()
|
output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float()
|
||||||
output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float()
|
|
||||||
if len(self.hand_batch_idx):
|
if len(self.hand_batch_idx):
|
||||||
output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid()
|
output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid()
|
||||||
output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand)
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -715,10 +690,10 @@ class SAM3DBody(nn.Module):
|
|||||||
# Concat lhand+rhand along dim 0 so backbone+decoder run once on
|
# Concat lhand+rhand along dim 0 so backbone+decoder run once on
|
||||||
# (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass.
|
# (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass.
|
||||||
batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand)
|
batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand)
|
||||||
saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid)
|
saved_batch_state = (self._batch_size, self._max_num_person)
|
||||||
self._initialize_batch(batch_hands)
|
self._initialize_batch(batch_hands)
|
||||||
hands_output = self.forward_step(batch_hands, decoder_type="hand")
|
hands_output = self.forward_step(batch_hands, decoder_type="hand")
|
||||||
self._batch_size, self._max_num_person, self._person_valid = saved_batch_state
|
self._batch_size, self._max_num_person = saved_batch_state
|
||||||
n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1]
|
n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1]
|
||||||
lhand_output, rhand_output = self._split_hand_output(hands_output, n_left)
|
lhand_output, rhand_output = self._split_hand_output(hands_output, n_left)
|
||||||
# Free the batched image_embeddings/condition_info (unused downstream);
|
# Free the batched image_embeddings/condition_info (unused downstream);
|
||||||
@ -808,9 +783,7 @@ class SAM3DBody(nn.Module):
|
|||||||
# to get an updated body pose estimation.
|
# to get an updated body pose estimation.
|
||||||
self._set_active_branch("body")
|
self._set_active_branch("body")
|
||||||
|
|
||||||
right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
# right_kps_full / left_kps_full already computed above (unchanged since).
|
||||||
left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone()
|
|
||||||
left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1
|
|
||||||
right_kps_crop = self._full_to_crop(batch, right_kps_full)
|
right_kps_crop = self._full_to_crop(batch, right_kps_full)
|
||||||
left_kps_crop = self._full_to_crop(batch, left_kps_full)
|
left_kps_crop = self._full_to_crop(batch, left_kps_full)
|
||||||
|
|
||||||
@ -1030,7 +1003,6 @@ class SAM3DBody(nn.Module):
|
|||||||
_, pose_output = self.forward_decoder(
|
_, pose_output = self.forward_decoder(
|
||||||
"body",
|
"body",
|
||||||
image_embeddings,
|
image_embeddings,
|
||||||
init_estimate=None, # use the default init, not the prev estimate
|
|
||||||
keypoints=keypoint_prompt,
|
keypoints=keypoint_prompt,
|
||||||
prev_estimate=prev_estimate,
|
prev_estimate=prev_estimate,
|
||||||
condition_info=condition_info,
|
condition_info=condition_info,
|
||||||
|
|||||||
@ -29,38 +29,37 @@ class PromptEncoder(nn.Module):
|
|||||||
Encodes prompts for input to SAM's mask decoder.
|
Encodes prompts for input to SAM's mask decoder.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
ops = operations if operations is not None else nn
|
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
self.num_body_joints = num_body_joints
|
self.num_body_joints = num_body_joints
|
||||||
|
|
||||||
# Keypoint prompts
|
# Keypoint prompts
|
||||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||||
self.point_embeddings = nn.ModuleList(
|
self.point_embeddings = nn.ModuleList(
|
||||||
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
|
[operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
|
||||||
)
|
)
|
||||||
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||||
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
|
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
|
||||||
LN2d = LayerNorm2d_op(ops)
|
LN2d = LayerNorm2d_op(operations)
|
||||||
mask_in_chans = 256
|
mask_in_chans = 256
|
||||||
self.mask_downscaling = nn.Sequential(
|
self.mask_downscaling = nn.Sequential(
|
||||||
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
|
operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||||
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
|
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
|
operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||||
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
|
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
|
operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||||
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
|
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
|
operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||||
LN2d(mask_in_chans, device=device, dtype=dtype),
|
LN2d(mask_in_chans, device=device, dtype=dtype),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
|
operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
|
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
|
||||||
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
|
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||||
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
|
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
|
||||||
@ -120,8 +119,7 @@ class PromptEncoder(nn.Module):
|
|||||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||||
"""
|
"""
|
||||||
bs = self._get_batch_size(keypoints, boxes, masks)
|
bs = self._get_batch_size(keypoints, boxes, masks)
|
||||||
# Anchor device on the input prompts so we don't pull the offloaded
|
|
||||||
# CPU embedding device under dynamic loading.
|
|
||||||
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
|
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
|
||||||
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
|
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
|
||||||
weight_dtype = self.invalid_point_embed.weight.dtype
|
weight_dtype = self.invalid_point_embed.weight.dtype
|
||||||
@ -136,23 +134,10 @@ class PromptEncoder(nn.Module):
|
|||||||
|
|
||||||
return sparse_embeddings, sparse_masks
|
return sparse_embeddings, sparse_masks
|
||||||
|
|
||||||
def get_mask_embeddings(
|
def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> torch.Tensor:
|
||||||
self,
|
"""Embeds mask inputs. Caller casts both outputs to its working dtype."""
|
||||||
masks: Optional[torch.Tensor] = None,
|
no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1])
|
||||||
bs: int = 1,
|
mask_embeddings = self.mask_downscaling(masks)
|
||||||
size: Tuple[int, int] = (16, 16), # [H, W]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Embeds mask inputs."""
|
|
||||||
# masks is always on the active device when present; fall back to the
|
|
||||||
# downscaling Conv's weight device when it isn't (rare callers).
|
|
||||||
ref = masks if masks is not None else next(self.mask_downscaling.parameters())
|
|
||||||
no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
|
|
||||||
bs, -1, size[0], size[1]
|
|
||||||
)
|
|
||||||
if masks is not None:
|
|
||||||
mask_embeddings = self.mask_downscaling(masks)
|
|
||||||
else:
|
|
||||||
mask_embeddings = no_mask_embeddings
|
|
||||||
return mask_embeddings, no_mask_embeddings
|
return mask_embeddings, no_mask_embeddings
|
||||||
|
|
||||||
|
|
||||||
@ -170,12 +155,9 @@ class PromptableDecoder(nn.Module):
|
|||||||
repeat_pe: bool = False,
|
repeat_pe: bool = False,
|
||||||
do_interm_preds: bool = False,
|
do_interm_preds: bool = False,
|
||||||
keypoint_token_update: bool = False,
|
keypoint_token_update: bool = False,
|
||||||
device=None,
|
device=None, dtype=None, operations=None,
|
||||||
dtype=None,
|
|
||||||
operations=None,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
ops = operations if operations is not None else nn
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
TransformerDecoderLayer(
|
TransformerDecoderLayer(
|
||||||
@ -193,7 +175,7 @@ class PromptableDecoder(nn.Module):
|
|||||||
for i in range(depth)
|
for i in range(depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
|
self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
|
||||||
self.do_interm_preds = do_interm_preds
|
self.do_interm_preds = do_interm_preds
|
||||||
self.keypoint_token_update = keypoint_token_update
|
self.keypoint_token_update = keypoint_token_update
|
||||||
|
|
||||||
|
|||||||
@ -166,12 +166,10 @@ def prepare_batch(
|
|||||||
mask_score_t = torch.ones((n,), dtype=torch.float32)
|
mask_score_t = torch.ones((n,), dtype=torch.float32)
|
||||||
|
|
||||||
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
|
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
|
||||||
ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()
|
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
|
||||||
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
|
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
|
||||||
"ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2)
|
|
||||||
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
|
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
|
||||||
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
|
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
|
||||||
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)
|
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
"""BVH export for SAM 3D Body pose_data.
|
"""BVH export for SAM 3D Body pose_data.
|
||||||
|
|
||||||
BVH stores explicit bone OFFSETs per joint, so any standard importer
|
BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct
|
||||||
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
|
anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0
|
||||||
directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0
|
(static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY
|
||||||
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
|
rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees.
|
||||||
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
|
|
||||||
intrinsic Z-X-Y Euler degrees.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -49,13 +47,10 @@ def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
|
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
|
||||||
"""First child of the rig's world anchor so the static origin→body stick
|
"""First child of the rig's world anchor, dropping the origin→body stick.
|
||||||
bone gets left out. Falls back to the first root joint.
|
Falls back to the first root joint. External rigs whose root is already the
|
||||||
|
articulated body root with multiple child chains keep the root — descending
|
||||||
MHR's joint 0 is a static world anchor whose single child is the pelvis, so
|
into one child would drop the sibling limbs."""
|
||||||
skipping it is correct. External rigs (e.g. SOMA-77) whose root is already
|
|
||||||
the articulated body root with multiple child chains must keep the root —
|
|
||||||
descending into one child would drop the sibling limbs from the BVH."""
|
|
||||||
NJ = parents.shape[0]
|
NJ = parents.shape[0]
|
||||||
world_anchors = [j for j in range(NJ)
|
world_anchors = [j for j in range(NJ)
|
||||||
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
|
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
|
||||||
@ -93,14 +88,11 @@ def build_bvh(
|
|||||||
track_index: int = -1,
|
track_index: int = -1,
|
||||||
units: str = "cm",
|
units: str = "cm",
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.
|
"""Build a BVH file from pose_data. Returns UTF-8 text bytes.
|
||||||
|
|
||||||
`model` may be None when pose_data carries a `_skeleton_override` (external
|
`model` may be None when pose_data carries a `_skeleton_override` (external
|
||||||
rigs, e.g. Kimodo); the rig hierarchy/offsets/bind are read from the
|
rigs); the rig hierarchy/offsets/bind come from the override. `units` is
|
||||||
override instead of the MHR model.
|
"cm" (default) or "m" — affects OFFSET/root-position, not rotations.
|
||||||
|
|
||||||
`units` is "cm" (default, standard mocap convention) or "m". Affects the
|
|
||||||
OFFSET and root-position values; rotations are independent of units.
|
|
||||||
"""
|
"""
|
||||||
if units not in ("cm", "m"):
|
if units not in ("cm", "m"):
|
||||||
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
|
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
|
||||||
@ -123,10 +115,8 @@ def build_bvh(
|
|||||||
body_root = _find_bvh_root(parents, is_external)
|
body_root = _find_bvh_root(parents, is_external)
|
||||||
children_map = _build_children_map(parents)
|
children_map = _build_children_map(parents)
|
||||||
|
|
||||||
# Bone OFFSETs come from MHR's translation_offsets (joint position
|
# Bone OFFSETs = translation_offsets (joint position relative to parent).
|
||||||
# relative to parent in parent's local-bind frame). For the BVH root,
|
# The BVH root uses its bind world position so the skeleton imports in place.
|
||||||
# we use its bind world position so the skeleton sits at the right
|
|
||||||
# spot when imported.
|
|
||||||
bind_global = rig.bind_global_cm # (NJ, 8) cm
|
bind_global = rig.bind_global_cm # (NJ, 8) cm
|
||||||
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
|
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
|
||||||
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
|
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
|
||||||
@ -139,9 +129,8 @@ def build_bvh(
|
|||||||
_visit(c)
|
_visit(c)
|
||||||
_visit(body_root)
|
_visit(body_root)
|
||||||
|
|
||||||
# Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
|
# Stored pred_global_rots/pred_joint_coords (authoritative); derive locals
|
||||||
# rather than re-running rig.forward, then derive locals with body_root
|
# with body_root as the BVH-space hierarchy root.
|
||||||
# treated as the hierarchy root in BVH-space.
|
|
||||||
rig_global_m = global_skel_state_from_pose_data(
|
rig_global_m = global_skel_state_from_pose_data(
|
||||||
pose_data, frame_indices, person_k, NJ,
|
pose_data, frame_indices, person_k, NJ,
|
||||||
joint_coords_y_down=rig.per_frame_y_down,
|
joint_coords_y_down=rig.per_frame_y_down,
|
||||||
@ -203,9 +192,8 @@ def build_bvh(
|
|||||||
lines.append(f"Frames: {n_frames}")
|
lines.append(f"Frames: {n_frames}")
|
||||||
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
|
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
|
||||||
|
|
||||||
# Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
|
# Channel matrix per frame: root pos (3) + root rot (3) + non-root rots
|
||||||
# frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
|
# (3 each), columns in `bvh_order`. savetxt is far faster than f-strings.
|
||||||
# formatting beats Python f-strings by ~10× on long clips.
|
|
||||||
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
|
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
|
||||||
motion = np.concatenate([
|
motion = np.concatenate([
|
||||||
root_pos_m * unit_scale, # (N, 3)
|
root_pos_m * unit_scale, # (N, 3)
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent
|
"""3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look).
|
||||||
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
|
|
||||||
|
|
||||||
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps),
|
Each limb is a true 3D capsule (cylinder + hemispherical caps), projected
|
||||||
projected through the per-person camera (`pred_cam_t` + `focal_length` +
|
through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so
|
||||||
image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose
|
closer limbs appear thicker/brighter. Self-contained analytic ray-capsule
|
||||||
visual style. Self-contained: no dependency on the SCAIL-Pose package.
|
renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
||||||
|
|
||||||
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
@ -41,14 +38,12 @@ def _build_specs_from_pose(
|
|||||||
palette: str,
|
palette: str,
|
||||||
person_brightness_falloff: float = 0.0,
|
person_brightness_falloff: float = 0.0,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Flatten body + optional hand limbs for one frame into
|
"""Flatten body + optional hand limbs for one frame into (starts, ends,
|
||||||
(starts, ends, colors_rgba, is_hand) in camera coords (Y-down, +Z forward).
|
colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite
|
||||||
Drops endpoints that are non-finite or behind the camera. `is_hand` flags
|
or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner.
|
||||||
the hand limbs so the renderer can draw them thinner.
|
|
||||||
|
|
||||||
`person_brightness_falloff` mixes each per-person limb color toward white
|
`person_brightness_falloff` mixes each per-person color toward white by
|
||||||
by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the
|
`1 - falloff^k` for track k (track 0 stays vivid)."""
|
||||||
mesh rasterizer and GLB exporters."""
|
|
||||||
starts: List[np.ndarray] = []
|
starts: List[np.ndarray] = []
|
||||||
ends: List[np.ndarray] = []
|
ends: List[np.ndarray] = []
|
||||||
colors: List[np.ndarray] = []
|
colors: List[np.ndarray] = []
|
||||||
@ -65,8 +60,7 @@ def _build_specs_from_pose(
|
|||||||
if body_op is None or cam_t is None:
|
if body_op is None or cam_t is None:
|
||||||
continue
|
continue
|
||||||
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
|
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
|
||||||
# op-keypoints are camera frame (Y-down); add cam_t to place the
|
# op-keypoints are camera frame; add cam_t to place the subject in front.
|
||||||
# subject in front of the camera.
|
|
||||||
body_kp = body_op + cam_t_np[None, :]
|
body_kp = body_op + cam_t_np[None, :]
|
||||||
|
|
||||||
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
|
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
|
||||||
@ -148,10 +142,9 @@ def _ray_capsule_t(
|
|||||||
ba_len: torch.Tensor, # (M,) segment length
|
ba_len: torch.Tensor, # (M,) segment length
|
||||||
radius: torch.Tensor, # (M,) per-capsule radius
|
radius: torch.Tensor, # (M,) per-capsule radius
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Closed-form ray-capsule intersection. Returns (K, M) tensor of ray
|
"""Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest
|
||||||
parameters t to the nearest valid hit per capsule, +inf where the ray
|
valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere
|
||||||
misses. A capsule is the union of (cylinder body, hemisphere at A,
|
at A, hemisphere at B), each a quadratic root-find."""
|
||||||
hemisphere at B); each component is a quadratic root-find."""
|
|
||||||
INF = float("inf")
|
INF = float("inf")
|
||||||
r_sq = radius * radius # (M,)
|
r_sq = radius * radius # (M,)
|
||||||
|
|
||||||
@ -238,9 +231,8 @@ def _render_capsules_torch(
|
|||||||
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
|
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
|
||||||
z_near = max(0.05, z_min - float(radius.max().item()))
|
z_near = max(0.05, z_min - float(radius.max().item()))
|
||||||
|
|
||||||
# Union of per-capsule screen-space bboxes. Pixels outside this mask
|
# Union of per-capsule screen-space bboxes — pixels outside can't hit any
|
||||||
# provably can't hit any capsule, so the analytic intersection only runs
|
# capsule, so intersection only runs on the relevant subset of the canvas.
|
||||||
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
|
|
||||||
sz = starts[:, 2].clamp(min=z_near)
|
sz = starts[:, 2].clamp(min=z_near)
|
||||||
ez = ends[:, 2].clamp(min=z_near)
|
ez = ends[:, 2].clamp(min=z_near)
|
||||||
sx_p = starts[:, 0] * fx / sz + cx
|
sx_p = starts[:, 0] * fx / sz + cx
|
||||||
@ -261,16 +253,13 @@ def _render_capsules_torch(
|
|||||||
if xmax_i > xmin_i and ymax_i > ymin_i:
|
if xmax_i > xmin_i and ymax_i > ymin_i:
|
||||||
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
|
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
|
||||||
|
|
||||||
# Analytic ray-capsule intersection. One pass over the masked pixels —
|
# Analytic ray-capsule intersection, one pass over the masked pixels.
|
||||||
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
|
|
||||||
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
|
|
||||||
INF = float("inf")
|
INF = float("inf")
|
||||||
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
|
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
|
||||||
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
|
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
|
||||||
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
|
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
|
||||||
if active_idx.numel() > 0:
|
if active_idx.numel() > 0:
|
||||||
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory
|
# Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory.
|
||||||
# manageable when both K (image pixels) and M (capsules) are large.
|
|
||||||
chunk_max = max(1, int(4_000_000 / max(M, 1)))
|
chunk_max = max(1, int(4_000_000 / max(M, 1)))
|
||||||
for i0 in range(0, active_idx.numel(), chunk_max):
|
for i0 in range(0, active_idx.numel(), chunk_max):
|
||||||
sub = active_idx[i0 : i0 + chunk_max]
|
sub = active_idx[i0 : i0 + chunk_max]
|
||||||
@ -284,7 +273,7 @@ def _render_capsules_torch(
|
|||||||
flat_t[winners] = t_min[hit]
|
flat_t[winners] = t_min[hit]
|
||||||
flat_m_idx[winners] = m_idx[hit]
|
flat_m_idx[winners] = m_idx[hit]
|
||||||
|
|
||||||
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade.
|
# Shade via analytic normal (P - closest point on segment).
|
||||||
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
|
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
|
||||||
if background_rgb is not None:
|
if background_rgb is not None:
|
||||||
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
|
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
|
||||||
@ -306,10 +295,10 @@ def _render_capsules_torch(
|
|||||||
|
|
||||||
col = colors[m_h, :3]
|
col = colors[m_h, :3]
|
||||||
if flat_shade:
|
if flat_shade:
|
||||||
# Solid per-limb color (OpenPose look) — no lighting/depth modulation.
|
# Solid per-limb color (OpenPose look) — no lighting/depth.
|
||||||
out[hit_idx] = col
|
out[hit_idx] = col
|
||||||
return out.view(H, W, 3).clamp(0.0, 1.0)
|
return out.view(H, W, 3).clamp(0.0, 1.0)
|
||||||
# SCAIL Blinn-Phong (render_torch.py:290-331). Headlight: light = +Z.
|
# SCAIL Blinn-Phong, headlight along +Z.
|
||||||
diff = torch.clamp(-(normals[:, 2]), min=0.0)
|
diff = torch.clamp(-(normals[:, 2]), min=0.0)
|
||||||
diffuse = 0.45 + 0.55 * diff
|
diffuse = 0.45 + 0.55 * diff
|
||||||
|
|
||||||
@ -319,7 +308,7 @@ def _render_capsules_torch(
|
|||||||
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||||
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
|
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
|
||||||
|
|
||||||
# Mild depth fade matches SCAIL's mm-scale ramp in our meter units.
|
# Mild depth fade.
|
||||||
z_vals = p_hit[:, 2]
|
z_vals = p_hit[:, 2]
|
||||||
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
|
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
|
||||||
if z_hi - z_lo > 1e-6:
|
if z_hi - z_lo > 1e-6:
|
||||||
@ -351,21 +340,18 @@ def render_pose_data_capsules(
|
|||||||
hand_radius_scale: float = 0.4,
|
hand_radius_scale: float = 0.4,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Render a frame's pose_data as 3D capsules projected through the per-
|
"""Render a frame's pose_data as 3D capsules through the per-person camera.
|
||||||
person camera. Returns (H, W, 3) fp32 in [0, 1].
|
Returns (H, W, 3) fp32 in [0, 1].
|
||||||
|
|
||||||
`composite='over'` paints over `background` (black if None);
|
`composite='over'` paints over `background` (black if None); 'mesh_only'
|
||||||
`composite='mesh_only'` always uses a black canvas.
|
uses a black canvas. `radius_m` is in meters; hand limbs use
|
||||||
|
`radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`.
|
||||||
`radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
|
|
||||||
Hand limbs use `radius_m * hand_radius_scale` (their bones are far shorter
|
|
||||||
than body limbs). Camera fx/fy come from each person's `focal_length`.
|
|
||||||
"""
|
"""
|
||||||
persons = pose_data["frames"][frame_idx]
|
persons = pose_data["frames"][frame_idx]
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
# SAM3DBody shares one camera across the clip — pick from the first valid person.
|
# SAM3DBody shares one camera across the clip — use the first valid person.
|
||||||
fx = fy = float(min(H, W))
|
fx = fy = float(min(H, W))
|
||||||
for person in persons:
|
for person in persons:
|
||||||
f = person.get("focal_length")
|
f = person.get("focal_length")
|
||||||
|
|||||||
@ -1,16 +1,10 @@
|
|||||||
"""GLB export — OpenPose 18-keypoint visualization mode.
|
"""GLB export — OpenPose 18-keypoint visualization mode.
|
||||||
|
|
||||||
Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d`
|
Sourced from pose_data's `pred_keypoints_3d`, independent of the MHR rig. Each
|
||||||
(the model's regressed surface keypoints). Each track becomes an armature
|
track becomes an armature with a joint per keypoint; sphere markers and limbs
|
||||||
with a sibling joint per keypoint; sphere markers + stick/capsule limbs are
|
are skinned to those joints. Optional hands (`pred_keypoints_3d` 21..62) and
|
||||||
skinned to those joints.
|
face landmarks (`pred_vertices` at fixed vertex IDs) extend the same armature.
|
||||||
|
Shared tables/palettes/mappings live in `glb_shared.py`.
|
||||||
Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and
|
|
||||||
face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs)
|
|
||||||
extend the same armature.
|
|
||||||
|
|
||||||
OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are
|
|
||||||
imported below — they're also used by the 2D and 3D renderers in this package.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -55,9 +49,8 @@ def _finalize_skinned_mesh(
|
|||||||
joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray,
|
joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray,
|
||||||
smooth_shade: bool,
|
smooth_shade: bool,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Apply smooth or flat shading to an indexed sphere/stick group mesh and
|
"""Shade a skinned group mesh and pack per-vertex colors. Smooth keeps the
|
||||||
pack per-vertex colors. Smooth keeps the indexed mesh + per-vertex colors;
|
indexed mesh; flat duplicates verts per face and gathers face-corner colors."""
|
||||||
flat duplicates verts per face and gathers face-corner colors."""
|
|
||||||
if smooth_shade:
|
if smooth_shade:
|
||||||
v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights)
|
v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights)
|
||||||
return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32)
|
return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32)
|
||||||
@ -73,10 +66,8 @@ def _finalize_skinned_mesh(
|
|||||||
def _pair_colors_from_kp(
|
def _pair_colors_from_kp(
|
||||||
pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1,
|
pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Per-limb color = endpoint-vertex color from `kp_colors`. Default
|
"""Per-limb color from `kp_colors`. `endpoint=1` (default) picks the distal
|
||||||
`endpoint=1` picks the second (distal) vertex of each pair, which is
|
vertex of each pair — the OpenPose per-finger gradient for base→tip fingers."""
|
||||||
the OpenPose-canonical per-finger gradient when fingers go base→tip
|
|
||||||
(wrist=0 → thumb1=1 → thumb2=2 …)."""
|
|
||||||
n = len(pairs)
|
n = len(pairs)
|
||||||
out = np.zeros((n, 3), dtype=np.float32)
|
out = np.zeros((n, 3), dtype=np.float32)
|
||||||
for i, (a, b) in enumerate(pairs):
|
for i, (a, b) in enumerate(pairs):
|
||||||
@ -88,19 +79,13 @@ def _openpose_bind_at_rig_rest(
|
|||||||
pose_data: Dict[str, Any], *,
|
pose_data: Dict[str, Any], *,
|
||||||
include_hands: bool, face_vert_ids: Optional[np.ndarray],
|
include_hands: bool, face_vert_ids: Optional[np.ndarray],
|
||||||
) -> Optional[np.ndarray]:
|
) -> Optional[np.ndarray]:
|
||||||
"""OpenPose keypoint positions at the rig's REST pose (T-pose at authoring
|
"""OpenPose keypoint positions at the rig's REST pose, from the override's
|
||||||
origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest
|
`bind_global_m` (joint rest TRS) and `rest_verts_m` (face landmarks).
|
||||||
TRS) and `rest_verts_m` (mesh rest verts for face landmarks).
|
|
||||||
|
|
||||||
Used as the static-bind for openpose-mode geometry so the GLB's static
|
Used as the static-bind so the GLB's static POSITION sits at rig origin,
|
||||||
POSITION attribute sits at rig origin — matching skeletal mode's bind and
|
matching skeletal mode and producing the same rest→scene-frame-0 transition.
|
||||||
producing the same 'snap from rest to scene-frame-0' transition at the
|
Returns None when the override lacks the needed mappings — caller then falls
|
||||||
start of playback. Without this, the static geometry is at scene-frame-0
|
back to per-frame extraction (kp_seq[0])."""
|
||||||
(kp_seq[0]) and viewers that auto-fit on static POSITION will center on
|
|
||||||
the scene location, hiding the per-frame motion.
|
|
||||||
|
|
||||||
Returns None when the override is missing or doesn't carry all the needed
|
|
||||||
mappings — caller falls back to per-frame extraction (kp_seq[0])."""
|
|
||||||
override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None
|
override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None
|
||||||
if override is None or "bind_global_m" not in override:
|
if override is None or "bind_global_m" not in override:
|
||||||
return None
|
return None
|
||||||
@ -141,19 +126,12 @@ def _openpose_bind_at_rig_rest(
|
|||||||
def _extract_openpose_keypoints(
|
def _extract_openpose_keypoints(
|
||||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres.
|
"""(N, 18, 3) OpenPose keypoints in rig-native Y-up metres.
|
||||||
|
|
||||||
Two sources, in priority order:
|
External-skeleton path: when the override carries `openpose18_joint_indices`
|
||||||
|
((18, 2) int32), synthesize from each person's `pred_joint_coords` (already
|
||||||
1. **External-skeleton path** — when pose_data has `_skeleton_override`
|
Y-up, no flip). MHR70 path (default): re-index `pred_keypoints_3d` to COCO-18
|
||||||
with `openpose18_joint_indices` ((18, 2) int32, see
|
and un-flip y/z (stored y-down by sam3d_body).
|
||||||
`_resolve_openpose_keypoints_from_joints`), synthesize from each
|
|
||||||
person's `pred_joint_coords` directly. The override frame is already
|
|
||||||
rig-native Y-up, so no axis flip.
|
|
||||||
2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the
|
|
||||||
first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18.
|
|
||||||
Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we
|
|
||||||
un-flip y/z to match rig-native Y-up.
|
|
||||||
"""
|
"""
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
N = len(frame_indices)
|
N = len(frame_indices)
|
||||||
@ -195,10 +173,8 @@ def _extract_openpose_keypoints(
|
|||||||
for t_idx, t in enumerate(frame_indices):
|
for t_idx, t in enumerate(frame_indices):
|
||||||
person = frames[t][person_k]
|
person = frames[t][person_k]
|
||||||
if "pred_keypoints_3d" not in person:
|
if "pred_keypoints_3d" not in person:
|
||||||
# Diagnose the source: external-skeleton producers ship
|
# External-skeleton producer without `openpose18_joint_indices`:
|
||||||
# `_skeleton_override` instead of MHR70 keypoints. If the
|
# can't synthesize the 18-keypoint set.
|
||||||
# producer didn't populate `openpose18_joint_indices` either,
|
|
||||||
# we can't synthesize the 18-keypoint set.
|
|
||||||
if override is not None:
|
if override is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"build_glb_openpose: this pose_data carries "
|
"build_glb_openpose: this pose_data carries "
|
||||||
@ -229,15 +205,11 @@ def _extract_openpose_keypoints(
|
|||||||
def _extract_openpose_hand_keypoints(
|
def _extract_openpose_hand_keypoints(
|
||||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native
|
"""(N, 42, 3) right+left OpenPose hand keypoints (21+21) in rig-native Y-up.
|
||||||
Y-up frame.
|
|
||||||
|
|
||||||
External-skeleton path: requires `openpose_hand21_r_joint_indices` AND
|
External-skeleton path: needs `openpose_hand21_{r,l}_joint_indices` ((21, 2)
|
||||||
`openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override.
|
int32) in the override, resolved against `pred_joint_coords`. MHR70 path:
|
||||||
Resolved against per-frame `pred_joint_coords` like the body path.
|
re-orders `pred_keypoints_3d` 21..62 to OpenPose-21 (wrist + 5 fingers)."""
|
||||||
|
|
||||||
MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21
|
|
||||||
(wrist + 5 fingers, thumb→pinky, base→tip)."""
|
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
N = len(frame_indices)
|
N = len(frame_indices)
|
||||||
out = np.zeros((N, 42, 3), dtype=np.float32)
|
out = np.zeros((N, 42, 3), dtype=np.float32)
|
||||||
@ -307,10 +279,8 @@ def _extract_face_landmarks_from_verts(
|
|||||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||||
vert_ids: np.ndarray,
|
vert_ids: np.ndarray,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices`
|
"""(N, K_face, 3) face landmarks sampled from `pred_vertices` at the given
|
||||||
at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up.
|
vertex IDs, unflipped to Y-up. Per-frame deformation is already baked in."""
|
||||||
Each landmark inherits per-frame shape/expr/pose deformation for free
|
|
||||||
since `pred_vertices` already has it baked in."""
|
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
N = len(frame_indices)
|
N = len(frame_indices)
|
||||||
K = int(vert_ids.shape[0])
|
K = int(vert_ids.shape[0])
|
||||||
@ -335,18 +305,11 @@ def _build_openpose_spheres(
|
|||||||
smooth_shade: bool = False,
|
smooth_shade: bool = False,
|
||||||
joint_indices: Optional[np.ndarray] = None,
|
joint_indices: Optional[np.ndarray] = None,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's
|
"""UV sphere per keypoint, rigidly skinned to that keypoint's joint and
|
||||||
joint, vertex-colored from kp_colors. `base_joint_idx` is added to the
|
vertex-colored from kp_colors. `base_joint_idx` offsets the emitted JOINTS_0
|
||||||
emitted JOINTS_0 indices so callers can place this group at any offset
|
indices (body=0, right hand=18, …); `joint_indices`, if given, sets explicit
|
||||||
in the shared skin (body=0, right hand=18, etc.). `joint_indices` (when
|
per-sphere indices so callers can skip keypoints (e.g. SCAIL head dots).
|
||||||
given) overrides that with explicit per-sphere joint indices, so callers
|
Returns (verts, normals, faces, joints4, weights4, vert_colors)."""
|
||||||
can skip keypoints (e.g. SCAIL head dots).
|
|
||||||
|
|
||||||
`smooth_shade=True` keeps the indexed mesh and writes per-vertex
|
|
||||||
normals via face-normal averaging — round shading on the spheres.
|
|
||||||
`smooth_shade=False` (default) flat-shades by duplicating verts per
|
|
||||||
face, matching the existing OpenPose-mode look. Returns
|
|
||||||
(verts, normals, faces, joints4, weights4, vert_colors)."""
|
|
||||||
sv, sf = uv_sphere_unit()
|
sv, sf = uv_sphere_unit()
|
||||||
K = bind_kp_m.shape[0]
|
K = bind_kp_m.shape[0]
|
||||||
Nv = sv.shape[0]
|
Nv = sv.shape[0]
|
||||||
@ -376,43 +339,23 @@ def _capsule_mesh_local(
|
|||||||
end_width_frac: float = 0.3,
|
end_width_frac: float = 0.3,
|
||||||
shape: str = "ellipsoid",
|
shape: str = "ellipsoid",
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Build a per-limb mesh in limb-local frame along +Y from y=0 (head
|
"""Per-limb mesh in limb-local frame along +Y from y=0 (head) to y=L (tail).
|
||||||
pole) to y=L (tail pole).
|
|
||||||
|
|
||||||
`shape` selects the silhouette:
|
`shape`:
|
||||||
- 'ellipsoid' (default): tips are small hemispheres of radius
|
- 'ellipsoid' (default): hemisphere tips of radius `W * end_width_frac`,
|
||||||
`W * end_width_frac`; body has ellipsoidal radius profile
|
ellipsoidal sin(π·u) body profile (fat middle, narrow ends).
|
||||||
sin(π*u) from w_end at the junctions to W at the middle. Gives
|
- 'capsule': SCAIL "rig" limb — an OPEN cylinder of constant radius W,
|
||||||
a fat-middle / narrow-end stretched-ellipse look.
|
no caps. Pair with same-radius sphere markers so they cap the ends
|
||||||
- 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant
|
seamlessly (caps would bump out when sphere radius ≠ cap radius).
|
||||||
radius W with no hemisphere caps. Pair with sphere joint markers
|
|
||||||
of the same radius so the spheres seamlessly cap the open
|
|
||||||
cylinder ends (the cylinder cross-section ring at the endpoint
|
|
||||||
lies exactly on the sphere surface). Drawing hemisphere caps
|
|
||||||
inside the joint sphere creates a visible bump where the cap
|
|
||||||
pokes out unevenly when sphere radius ≠ cap radius — open
|
|
||||||
cylinders avoid that.
|
|
||||||
|
|
||||||
Per-limb mesh is required because the cap height (w_end) depends on
|
A per-limb mesh is needed because cap height depends on width — one
|
||||||
the limb width — a single canonical mesh can't produce true
|
canonical mesh can't give true hemispheres for arbitrary L:W in ellipsoid.
|
||||||
hemispheres for arbitrary L:W ratios in ellipsoid mode.
|
|
||||||
|
|
||||||
Returns:
|
Returns (verts (Nv,3), faces (Nf,3), weights (Nv,2) head/tail, sums to 1).
|
||||||
verts: (Nv, 3) float32 — limb-local positions in meters.
|
|
||||||
faces: (Nf, 3) uint32 — triangle indices.
|
|
||||||
weights: (Nv, 2) float32 — (head, tail) skinning weights, linearly
|
|
||||||
interpolated by axial position (sums to 1).
|
|
||||||
"""
|
"""
|
||||||
W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6))
|
W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6))
|
||||||
if str(shape) == "capsule":
|
if str(shape) == "capsule":
|
||||||
# SCAIL-style "rig" limb: an OPEN cylinder of constant radius W,
|
# Open cylinder, no caps — sphere markers cap the ends (see docstring).
|
||||||
# no hemisphere caps. The sphere joint markers at each endpoint
|
|
||||||
# provide the rounded ends of the bone — when sphere_radius ==
|
|
||||||
# cylinder_radius, the cylinder cross-section ring at the bone
|
|
||||||
# endpoint lies exactly on the sphere surface, so silhouette is
|
|
||||||
# seamless. Hemisphere caps would create a visible bump where
|
|
||||||
# the cap pokes out of the sphere if cap_r ≠ marker_r, so we
|
|
||||||
# omit them entirely.
|
|
||||||
cap_r = 0.0
|
cap_r = 0.0
|
||||||
body_r = W
|
body_r = W
|
||||||
if n_cap_lat is None:
|
if n_cap_lat is None:
|
||||||
@ -425,7 +368,7 @@ def _capsule_mesh_local(
|
|||||||
end_frac = float(min(0.95, max(0.05, end_width_frac)))
|
end_frac = float(min(0.95, max(0.05, end_width_frac)))
|
||||||
cap_r = max(1e-7, W * end_frac)
|
cap_r = max(1e-7, W * end_frac)
|
||||||
body_r = W
|
body_r = W
|
||||||
# Ellipsoid defaults: more body rings to sample the sin(π·u) curve.
|
# More body rings to sample the sin(π·u) curve.
|
||||||
if n_cap_lat is None:
|
if n_cap_lat is None:
|
||||||
n_cap_lat = 3
|
n_cap_lat = 3
|
||||||
if n_body is None:
|
if n_body is None:
|
||||||
@ -473,10 +416,7 @@ def _capsule_mesh_local(
|
|||||||
phi = 2.0 * np.pi * k / n_lon
|
phi = 2.0 * np.pi * k / n_lon
|
||||||
verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))])
|
verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))])
|
||||||
|
|
||||||
# Body intermediate rings (between the cap junctions for capped meshes,
|
# Body intermediate rings (none for 'capsule', n_body=0 by default).
|
||||||
# between the two end rings for open cylinders). For 'capsule' mode
|
|
||||||
# n_body=0 by default — no intermediate rings needed for a constant-
|
|
||||||
# radius cylinder.
|
|
||||||
body_rings: List[int] = []
|
body_rings: List[int] = []
|
||||||
is_ellipsoid = str(shape) == "ellipsoid"
|
is_ellipsoid = str(shape) == "ellipsoid"
|
||||||
for j in range(1, n_body + 1):
|
for j in range(1, n_body + 1):
|
||||||
@ -572,11 +512,8 @@ def _scail_redirect_neck_stub(body_kp: np.ndarray) -> np.ndarray:
|
|||||||
def _openpose_limb_rest_trs(
|
def _openpose_limb_rest_trs(
|
||||||
bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...],
|
bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...],
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
"""Per-limb rest TRS:
|
"""Per-limb rest TRS: midpoints (K_pairs, 3) and unit a→b axes (or +Y if
|
||||||
midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b].
|
degenerate). Caller uses midpoints as rest translation, axes for alignment."""
|
||||||
rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate).
|
|
||||||
Caller uses `midpoints` as each limb joint's rest translation (rotation =
|
|
||||||
identity), and `rest_axes` to compute per-frame alignment rotations."""
|
|
||||||
K_pairs = len(pairs)
|
K_pairs = len(pairs)
|
||||||
mid = np.zeros((K_pairs, 3), dtype=np.float32)
|
mid = np.zeros((K_pairs, 3), dtype=np.float32)
|
||||||
axis = np.zeros((K_pairs, 3), dtype=np.float32)
|
axis = np.zeros((K_pairs, 3), dtype=np.float32)
|
||||||
@ -595,13 +532,10 @@ def _openpose_limb_rest_trs(
|
|||||||
def _openpose_limb_anim_trs(
|
def _openpose_limb_anim_trs(
|
||||||
kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray,
|
kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray,
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
"""Per-frame limb TRS:
|
"""Per-frame limb TRS: anim_mid (N, K_pairs, 3) midpoints and anim_quat
|
||||||
anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]).
|
(N, K_pairs, 4 xyzw) aligning each limb's rest axis to its frame-t axis.
|
||||||
anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest
|
Drives skin_matrix(t) = T(mid_t)·R_t·T(-mid_rest) — rigid rotation about
|
||||||
axis to its frame-t axis.
|
the rest midpoint, no LBS cross-section thinning."""
|
||||||
Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t *
|
|
||||||
T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to
|
|
||||||
track the limb's current direction — no LBS cross-section thinning."""
|
|
||||||
N = kp_seq.shape[0]
|
N = kp_seq.shape[0]
|
||||||
K_pairs = len(pairs)
|
K_pairs = len(pairs)
|
||||||
anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32)
|
anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32)
|
||||||
@ -616,7 +550,7 @@ def _openpose_limb_anim_trs(
|
|||||||
n = float(np.linalg.norm(d))
|
n = float(np.linalg.norm(d))
|
||||||
if n > 1e-9:
|
if n > 1e-9:
|
||||||
R[t, k] = rotation_align(ax_rest, d / n)
|
R[t, k] = rotation_align(ax_rest, d / n)
|
||||||
quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw
|
quat = rotmat_to_quat_np(R).astype(np.float32)
|
||||||
return anim_mid, quat
|
return anim_mid, quat
|
||||||
|
|
||||||
|
|
||||||
@ -628,20 +562,14 @@ def _build_openpose_sticks(
|
|||||||
smooth_shade: bool = False,
|
smooth_shade: bool = False,
|
||||||
end_width_frac: float = 0.3,
|
end_width_frac: float = 0.3,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Capsule (cylinder + hemispherical caps) per limb pair (a, b).
|
"""Capsule per limb pair (a, b), each sized to its own length/width so caps
|
||||||
|
are true hemispheres regardless of L:W. Ellipsoid mode auto-clamps width to
|
||||||
|
`length * 0.1` so short limbs don't look chunky.
|
||||||
|
|
||||||
Each limb gets its own mesh sized to that limb's length and width so
|
Rigid (weight=1) binding to a per-limb joint at `limb_joint_base_idx +
|
||||||
the caps are TRUE hemispheres of radius `half_width_eff` — the limb
|
limb_idx`, which the caller animates with midpoint translation + rotation
|
||||||
silhouette is rounded-rectangle-like, regardless of L:W ratio. Width
|
(avoids LBS thinning). Returns (verts, normals, faces, joints4, weights4,
|
||||||
auto-clamped to `length * 0.1` so short limbs (face/ear) don't look
|
vert_colors)."""
|
||||||
chunky next to long ones.
|
|
||||||
|
|
||||||
Skinning: rigid (weight=1) binding to a per-limb joint at
|
|
||||||
`limb_joint_base_idx + limb_idx` — the caller animates that joint with
|
|
||||||
midpoint translation + rest-to-current rotation so each capsule rotates
|
|
||||||
rigidly with its limb (avoids translation-only LBS cross-section
|
|
||||||
thinning). Returns flat-shaded (verts, normals, faces, joints4,
|
|
||||||
weights4, vert_colors)."""
|
|
||||||
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
out_v_chunks: List[np.ndarray] = []
|
out_v_chunks: List[np.ndarray] = []
|
||||||
@ -663,13 +591,10 @@ def _build_openpose_sticks(
|
|||||||
unit_dir = direction / length
|
unit_dir = direction / length
|
||||||
R = rotation_align(canonical, unit_dir)
|
R = rotation_align(canonical, unit_dir)
|
||||||
if is_capsule:
|
if is_capsule:
|
||||||
# SCAIL-style uniform radius — every bone gets the same width.
|
# Uniform radius — every bone the same width (clamped internally).
|
||||||
# `_capsule_mesh_local` clamps internally to L/2-eps so very
|
|
||||||
# short bones don't go degenerate.
|
|
||||||
half_width_eff = max(MIN_WIDTH, half_width_m)
|
half_width_eff = max(MIN_WIDTH, half_width_m)
|
||||||
else:
|
else:
|
||||||
# Ellipsoid mode: original auto-thinning so short face/ear
|
# Auto-thin so short face/ear limbs aren't chunky next to body limbs.
|
||||||
# limbs don't look chunky next to long body limbs.
|
|
||||||
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
|
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
|
||||||
|
|
||||||
v_local, f_local, _weights_unused = _capsule_mesh_local(
|
v_local, f_local, _weights_unused = _capsule_mesh_local(
|
||||||
@ -678,10 +603,8 @@ def _build_openpose_sticks(
|
|||||||
v_world = v_local @ R.T + head
|
v_world = v_local @ R.T + head
|
||||||
Nv = v_local.shape[0]
|
Nv = v_local.shape[0]
|
||||||
|
|
||||||
# Rigid binding to the per-limb joint. The 2-bone (head, tail) weights
|
# Rigid binding to the per-limb joint; the 2-bone weights are discarded
|
||||||
# from `_capsule_mesh_local` are discarded — they're translation-only
|
# (translation-only under LBS, would thin the cross-section).
|
||||||
# under glTF LBS and don't rotate the cross-section, causing visible
|
|
||||||
# thinning when the limb axis changes between rest and animated pose.
|
|
||||||
j_arr = np.zeros((Nv, 4), dtype=np.uint16)
|
j_arr = np.zeros((Nv, 4), dtype=np.uint16)
|
||||||
j_arr[:, 0] = limb_idx + limb_joint_base_idx
|
j_arr[:, 0] = limb_idx + limb_joint_base_idx
|
||||||
w_arr = np.zeros((Nv, 4), dtype=np.float32)
|
w_arr = np.zeros((Nv, 4), dtype=np.float32)
|
||||||
@ -730,40 +653,24 @@ def build_glb_openpose(
|
|||||||
stick_end_width_frac: float = 0.6,
|
stick_end_width_frac: float = 0.6,
|
||||||
bone_smooth_window: int = 0,
|
bone_smooth_window: int = 0,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Build a GLB containing an OpenPose-style 3D skeleton — sphere markers
|
"""Build a GLB of an OpenPose-style 3D skeleton — sphere markers per keypoint
|
||||||
per keypoint plus rainbow-colored sticks between standard limb pairs.
|
plus colored sticks between limb pairs, one armature per track. Body from
|
||||||
Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig
|
`pred_keypoints_3d`; optional hands (same source) and face landmarks
|
||||||
forward needed). Optional hand keypoints (also from `pred_keypoints_3d`)
|
(`pred_vertices`) extend each armature.
|
||||||
and face landmarks (sampled from `pred_vertices` at fixed head-mesh
|
|
||||||
vertex IDs) extend the same per-track armature.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
include_hands: append the standard 21+21 OpenPose hand keypoints to
|
include_hands: append the 21+21 OpenPose hand keypoints per track.
|
||||||
each track's armature (right hand at MHR70 indices 21..41,
|
hand_marker_radius_m: hand sphere radius. 0 = auto = 0.4 × marker_radius_m.
|
||||||
left at 42..62).
|
hand_stick_radius_m: hand limb half-width. 0 = auto = 0.5 × stick_radius_m.
|
||||||
hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 ×
|
hand_color_style: 'dwpose' (default) = solid-blue dots + rainbow sticks;
|
||||||
`marker_radius_m` (hand keypoints are anatomically smaller than
|
'openpose' = rainbow dots AND sticks.
|
||||||
body joints; matches DWPose's smaller hand dots).
|
face_style: 'disabled' (default) | 'full' (~30 contour pts) | 'eyes_mouth'
|
||||||
hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 ×
|
(eyes + outer-lip subset); sampled at vertex IDs from
|
||||||
`stick_radius_m`.
|
`canonical_colors["positions"]`.
|
||||||
hand_color_style: 'dwpose' (default) = solid-blue hand dots,
|
face_marker_radius_m: face landmark sphere radius. 0 = auto = 0.3 ×
|
||||||
rainbow per-finger sticks (controlnet_aux/dwpose convention);
|
marker_radius_m. Rendered as dots only, no contour lines.
|
||||||
'openpose' = rainbow per-finger dots AND sticks (matches
|
palette: 'openpose' = rainbow gradient per keypoint; 'scail' = warm right
|
||||||
poseParameters.cpp::HAND_COLORS_RENDER).
|
/ cool left, grey centerline, distinct per-limb colors.
|
||||||
face_style: 'disabled' (default) | 'full' | 'eyes_mouth' — face
|
|
||||||
landmarks sampled from `pred_vertices` at vertex IDs picked from
|
|
||||||
`pose_data["canonical_colors"]["positions"]`. 'full' = all ~30
|
|
||||||
contour points; 'eyes_mouth' = the eyes + outer-lip subset.
|
|
||||||
face_marker_radius_m: per-face landmark sphere radius. 0 = auto =
|
|
||||||
0.3 × `marker_radius_m` — face landmarks are densely packed
|
|
||||||
around the eyes/mouth/jaw and need to be much smaller than
|
|
||||||
body keypoints to keep the layout legible. Face landmarks are
|
|
||||||
rendered as standalone dots (no contour lines), matching
|
|
||||||
DWPose's face_pose draw style.
|
|
||||||
palette: body color scheme. 'openpose' = standard rainbow gradient
|
|
||||||
per keypoint (canonical OpenPose convention); 'scail' =
|
|
||||||
SCAIL-Pose style — warm hues right side, cool hues left side,
|
|
||||||
grey neck-to-nose centerline, distinct per-limb colors.
|
|
||||||
"""
|
"""
|
||||||
is_scail = str(palette) == "scail"
|
is_scail = str(palette) == "scail"
|
||||||
# SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0,
|
# SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0,
|
||||||
@ -771,13 +678,11 @@ def build_glb_openpose(
|
|||||||
body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS
|
body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS
|
||||||
body_sphere_kp = (np.arange(14, dtype=np.int64)
|
body_sphere_kp = (np.arange(14, dtype=np.int64)
|
||||||
if is_scail else np.arange(18, dtype=np.int64))
|
if is_scail else np.arange(18, dtype=np.int64))
|
||||||
if str(palette) == "scail":
|
if is_scail:
|
||||||
body_sphere_colors = SCAIL_KEYPOINT_COLORS_18
|
body_sphere_colors = SCAIL_KEYPOINT_COLORS_18
|
||||||
body_stick_colors = SCAIL_LIMB_COLORS_17
|
body_stick_colors = SCAIL_LIMB_COLORS_17
|
||||||
elif str(palette) == "openpose":
|
elif str(palette) == "openpose":
|
||||||
# Existing OpenPose behavior: same rainbow array used for both
|
# Same rainbow array drives both spheres and sticks.
|
||||||
# spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of
|
|
||||||
# the 18-element rainbow — yields a legible per-limb gradient).
|
|
||||||
body_sphere_colors = OPENPOSE_RAINBOW_18
|
body_sphere_colors = OPENPOSE_RAINBOW_18
|
||||||
body_stick_colors = OPENPOSE_RAINBOW_18
|
body_stick_colors = OPENPOSE_RAINBOW_18
|
||||||
else:
|
else:
|
||||||
@ -892,13 +797,9 @@ def build_glb_openpose(
|
|||||||
if bone_smooth_window and bone_smooth_window > 1:
|
if bone_smooth_window and bone_smooth_window > 1:
|
||||||
kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window))
|
kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window))
|
||||||
|
|
||||||
# Static-bind = rig's REST pose when available (override path); else
|
# Static-bind = rig REST pose when available, else frame 0. The rest
|
||||||
# fall back to frame 0 of the motion. The rest-pose bind makes the
|
# bind keeps static POSITION at rig origin so viewers auto-center there
|
||||||
# GLB's static POSITION attribute sit at rig origin, so viewers
|
# and the motion is visible (see _openpose_bind_at_rig_rest).
|
||||||
# auto-fit/center on rig origin and the animation visibly snaps from
|
|
||||||
# rest to scene-frame-0 — matching skeletal mode's behavior. Without
|
|
||||||
# this, openpose's static geometry is at scene-frame-0 and viewers
|
|
||||||
# mis-center on the scene location, masking the motion entirely.
|
|
||||||
bind_kp_m_rest = _openpose_bind_at_rig_rest(
|
bind_kp_m_rest = _openpose_bind_at_rig_rest(
|
||||||
pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids,
|
pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids,
|
||||||
)
|
)
|
||||||
@ -914,7 +815,7 @@ def build_glb_openpose(
|
|||||||
person_root_idx = len(nodes) - 1
|
person_root_idx = len(nodes) - 1
|
||||||
scene_root_indices.append(person_root_idx)
|
scene_root_indices.append(person_root_idx)
|
||||||
|
|
||||||
# K keypoint joint nodes (spheres bind here, rigid translation only).
|
# K keypoint joint nodes (spheres bind here, translation only).
|
||||||
joint_node_indices: List[int] = []
|
joint_node_indices: List[int] = []
|
||||||
for j in range(K):
|
for j in range(K):
|
||||||
nodes.append({
|
nodes.append({
|
||||||
@ -926,9 +827,7 @@ def build_glb_openpose(
|
|||||||
joint_node_indices.append(len(nodes) - 1)
|
joint_node_indices.append(len(nodes) - 1)
|
||||||
person_root["children"].extend(joint_node_indices)
|
person_root["children"].extend(joint_node_indices)
|
||||||
|
|
||||||
# Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint +
|
# Per-limb rest + per-frame TRS; sticks bind rigidly to these joints.
|
||||||
# quaternion that aligns rest-axis → frame-t-axis). Sticks bind
|
|
||||||
# rigidly to these joints so each capsule rotates with its limb.
|
|
||||||
limb_rest_mids_list: List[np.ndarray] = []
|
limb_rest_mids_list: List[np.ndarray] = []
|
||||||
limb_rest_axes_list: List[np.ndarray] = []
|
limb_rest_axes_list: List[np.ndarray] = []
|
||||||
limb_anim_mids_list: List[np.ndarray] = []
|
limb_anim_mids_list: List[np.ndarray] = []
|
||||||
@ -951,12 +850,10 @@ def build_glb_openpose(
|
|||||||
limb_rest_axes_list.append(raxis_h)
|
limb_rest_axes_list.append(raxis_h)
|
||||||
limb_anim_mids_list.append(amid_h)
|
limb_anim_mids_list.append(amid_h)
|
||||||
limb_anim_quats_list.append(aquat_h)
|
limb_anim_quats_list.append(aquat_h)
|
||||||
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3)
|
limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0)
|
||||||
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3)
|
limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1)
|
||||||
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4)
|
limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1)
|
||||||
# Hemisphere-align consecutive quats per limb so LINEAR interpolation
|
# Hemisphere-align consecutive quats so LINEAR interp takes the short path.
|
||||||
# takes the short path (otherwise large per-frame rotations can flip
|
|
||||||
# signs and produce visible "twist back" artifacts mid-playback).
|
|
||||||
limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32)
|
limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32)
|
||||||
|
|
||||||
limb_joint_indices: List[int] = []
|
limb_joint_indices: List[int] = []
|
||||||
@ -970,8 +867,8 @@ def build_glb_openpose(
|
|||||||
limb_joint_indices.append(len(nodes) - 1)
|
limb_joint_indices.append(len(nodes) - 1)
|
||||||
person_root["children"].extend(limb_joint_indices)
|
person_root["children"].extend(limb_joint_indices)
|
||||||
|
|
||||||
# Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints
|
# Combined skin: keypoint joints then limb joints; IBM = T(-rest) for
|
||||||
# (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest.
|
# both, yielding identity skin_matrix at rest.
|
||||||
all_joint_indices = joint_node_indices + limb_joint_indices
|
all_joint_indices = joint_node_indices + limb_joint_indices
|
||||||
ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1))
|
ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1))
|
||||||
ibm[:K, :3, 3] = -bind_kp_m
|
ibm[:K, :3, 3] = -bind_kp_m
|
||||||
@ -985,10 +882,8 @@ def build_glb_openpose(
|
|||||||
})
|
})
|
||||||
skin_idx = len(skins) - 1
|
skin_idx = len(skins) - 1
|
||||||
|
|
||||||
# Per-group geometry. Spheres bind to keypoint joints (base_joint_idx
|
# Per-group geometry. Spheres bind to keypoint joints [0, K); sticks to
|
||||||
# ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈
|
# limb joints [K, K+K_limbs). Stacked body → R-hand → L-hand → face.
|
||||||
# [K, K + K_limbs)). Groups stack body → right hand → left hand →
|
|
||||||
# face for keypoint joints, and body → R-hand → L-hand for limbs.
|
|
||||||
group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray,
|
group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray,
|
||||||
np.ndarray, np.ndarray, np.ndarray]] = []
|
np.ndarray, np.ndarray, np.ndarray]] = []
|
||||||
sp = _build_openpose_spheres(
|
sp = _build_openpose_spheres(
|
||||||
@ -1008,9 +903,7 @@ def build_glb_openpose(
|
|||||||
group_meshes.append(st)
|
group_meshes.append(st)
|
||||||
|
|
||||||
if include_hands:
|
if include_hands:
|
||||||
# Hand stick colors stay rainbow per-finger regardless of
|
# Hand sticks stay rainbow per-finger; only dots switch under 'dwpose'.
|
||||||
# `hand_color_style` — only the sphere dots switch to solid
|
|
||||||
# blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py.
|
|
||||||
hand_pair_colors = _pair_colors_from_kp(
|
hand_pair_colors = _pair_colors_from_kp(
|
||||||
OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1,
|
OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1,
|
||||||
)
|
)
|
||||||
@ -1033,9 +926,7 @@ def build_glb_openpose(
|
|||||||
if K_face > 0:
|
if K_face > 0:
|
||||||
f_off = K_body + K_hands
|
f_off = K_body + K_hands
|
||||||
f_bind = bind_kp_m[f_off:f_off + K_face]
|
f_bind = bind_kp_m[f_off:f_off + K_face]
|
||||||
# DWPose face = dots only, no contour lines
|
# DWPose face = dots only, no contour lines.
|
||||||
# (controlnet_aux/dwpose/util.py::draw_facepose draws white
|
|
||||||
# circles per landmark and never connects them).
|
|
||||||
group_meshes.append(_build_openpose_spheres(
|
group_meshes.append(_build_openpose_spheres(
|
||||||
f_bind, float(face_marker_radius_m),
|
f_bind, float(face_marker_radius_m),
|
||||||
FACE_LANDMARK_COLORS, base_joint_idx=f_off,
|
FACE_LANDMARK_COLORS, base_joint_idx=f_off,
|
||||||
@ -1087,9 +978,8 @@ def build_glb_openpose(
|
|||||||
"target": {"node": joint_node_indices[j], "path": "translation"},
|
"target": {"node": joint_node_indices[j], "path": "translation"},
|
||||||
})
|
})
|
||||||
|
|
||||||
# Per-limb-joint translation + rotation channels. Stationary limbs
|
# Per-limb-joint translation + rotation; stationary limbs bake their
|
||||||
# have their constant TRS baked into the node so they don't bloat the
|
# constant TRS into the node instead of an animation channel.
|
||||||
# animation buffer.
|
|
||||||
for k in range(K_limbs):
|
for k in range(K_limbs):
|
||||||
t_k = limb_anim_mids[:, k, :].astype(np.float32)
|
t_k = limb_anim_mids[:, k, :].astype(np.float32)
|
||||||
if (np.ptp(t_k, axis=0) < 1e-6).all():
|
if (np.ptp(t_k, axis=0) < 1e-6).all():
|
||||||
@ -1103,9 +993,7 @@ def build_glb_openpose(
|
|||||||
"target": {"node": limb_joint_indices[k], "path": "translation"},
|
"target": {"node": limb_joint_indices[k], "path": "translation"},
|
||||||
})
|
})
|
||||||
q_k = limb_anim_quats[:, k, :].astype(np.float32)
|
q_k = limb_anim_quats[:, k, :].astype(np.float32)
|
||||||
# ptp on the absolute value handles the +q == -q ambiguity, but
|
# Plain ptp is fine — signs already aligned by quat_sign_fix_per_joint.
|
||||||
# `quat_sign_fix_per_joint` already aligned signs so a plain ptp
|
|
||||||
# is fine here.
|
|
||||||
if (np.ptp(q_k, axis=0) < 1e-6).all():
|
if (np.ptp(q_k, axis=0) < 1e-6).all():
|
||||||
nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist()
|
nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,15 +1,11 @@
|
|||||||
"""GLB export for SAM 3D Body pose_data.
|
"""Shared GLB export helpers for SAM 3D Body pose_data.
|
||||||
|
|
||||||
Mode: skeletal — rebuilds the MHR 127-bone rig. Per-frame local TRS comes from
|
Skeletal mode rebuilds the MHR 127-bone rig: per-frame local TRS from
|
||||||
re-running param_transform on saved mhr_model_params; rest verts from a
|
param_transform on mhr_model_params, rest verts from a zero-pose forward,
|
||||||
zero-pose forward with the person's shape_params; sparse triplet skinning is
|
sparse skinning compacted to glTF's 4-influence form, expression re-exposed as
|
||||||
compacted to glTF's max-4-influences form; facial expression is re-exposed as
|
72 morph targets. Camera-y-down data is un-flipped to glTF Y-up. Pose
|
||||||
72 morph targets driven by expr_params.
|
correctives are dropped (glTF skinning can't represent them), so extreme joint
|
||||||
|
angles differ from the SAM3DBody renderer by the corrective amount.
|
||||||
pred_vertices/pred_cam_t are camera-y-down — un-flipped here so the GLB lives
|
|
||||||
in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent
|
|
||||||
them); deformation at extreme joint angles will differ from the SAM3DBody
|
|
||||||
renderer by the corrective amount.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -24,12 +20,11 @@ import torch
|
|||||||
|
|
||||||
from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical
|
from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical
|
||||||
|
|
||||||
# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically
|
# fp32-rounded ln(2); exp(x * _LN2) matches the rig's own 2**x bit-for-bit.
|
||||||
# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)`
|
|
||||||
_LN2 = 0.6931471824645996
|
_LN2 = 0.6931471824645996
|
||||||
|
|
||||||
|
|
||||||
# Quaternion / rotation helpers (xyzw convention, matching MHR rig)
|
# Quaternion / rotation helpers (xyzw, matching MHR rig)
|
||||||
|
|
||||||
def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray:
|
def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray:
|
||||||
"""(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat."""
|
"""(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat."""
|
||||||
@ -96,8 +91,7 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
|
def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
|
||||||
"""Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4.
|
"""Edge-replicate Gaussian smoothing along time (sigma = window/4). float64."""
|
||||||
Endpoints replicate so they aren't pulled toward zero. Returns float64."""
|
|
||||||
a = np.asarray(arr, dtype=np.float64)
|
a = np.asarray(arr, dtype=np.float64)
|
||||||
n = a.shape[0]
|
n = a.shape[0]
|
||||||
half = window // 2
|
half = window // 2
|
||||||
@ -117,9 +111,8 @@ def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
|
def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
|
||||||
"""Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns
|
"""Smooth a (N, NJ, 4) quaternion sequence along time: sign-align per joint,
|
||||||
per joint first, convolves per-component, renormalizes. Suppresses multi-
|
convolve per-component, renormalize. Calms bone spikes at extreme poses."""
|
||||||
frame bone spikes at extreme poses without needing the upstream Smooth node."""
|
|
||||||
if window <= 1 or q_seq.shape[0] < 2:
|
if window <= 1 or q_seq.shape[0] < 2:
|
||||||
return q_seq
|
return q_seq
|
||||||
out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window)
|
out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window)
|
||||||
@ -128,18 +121,16 @@ def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray:
|
def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray:
|
||||||
"""Gaussian-smooth a (N, K, 3) position sequence along time (edge-replicate
|
"""Smooth a (N, K, 3) position sequence along time. Calms jittery keypoint
|
||||||
padding). Used to calm jittery keypoint tracks before the openpose rig
|
tracks before the openpose rig derives sphere translations + limb TRS."""
|
||||||
derives sphere translations + limb TRS from them."""
|
|
||||||
if window <= 1 or seq.shape[0] < 2:
|
if window <= 1 or seq.shape[0] < 2:
|
||||||
return seq
|
return seq
|
||||||
return _gaussian_smooth_time(seq, window).astype(np.float32)
|
return _gaussian_smooth_time(seq, window).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
|
def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
|
||||||
"""Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit
|
"""Walk (N, NJ, 4) along time, flipping sign when consecutive frames sit on
|
||||||
on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim
|
opposite hemispheres. Avoids long-path slerp glitches. fp64 internally."""
|
||||||
cartwheel flip). fp64 to avoid drift; normalizes input defensively."""
|
|
||||||
out = np.array(q_seq, dtype=np.float64, copy=True)
|
out = np.array(q_seq, dtype=np.float64, copy=True)
|
||||||
norms = np.linalg.norm(out, axis=-1, keepdims=True)
|
norms = np.linalg.norm(out, axis=-1, keepdims=True)
|
||||||
out = out / np.maximum(norms, 1e-12)
|
out = out / np.maximum(norms, 1e-12)
|
||||||
@ -151,11 +142,9 @@ def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray:
|
def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray:
|
||||||
"""Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that
|
"""Globals (N, NJ, 8) + parents -> per-bone local TRS so FK reproduces
|
||||||
FK over (parents, bone_local) reproduces rig_global. local =
|
rig_global. local = inverse(parent_global) ∘ child_global, robust to
|
||||||
inverse(parent_global) ∘ child_global makes this robust to hierarchy-
|
hierarchy-convention mismatches in `parents`."""
|
||||||
convention mismatches: glTF FK gives back exactly rig_global even if
|
|
||||||
`parents` doesn't match the rig's pmi-walk."""
|
|
||||||
N, NJ, _ = rig_global.shape
|
N, NJ, _ = rig_global.shape
|
||||||
bone_local = np.zeros_like(rig_global)
|
bone_local = np.zeros_like(rig_global)
|
||||||
for j in range(NJ):
|
for j in range(NJ):
|
||||||
@ -188,8 +177,7 @@ def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]:
|
def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]:
|
||||||
"""List of (person_index, frame_indices). track_index == -1 means every
|
"""List of (person_index, frame_indices). track_index == -1 means every
|
||||||
present track; empty tracks are dropped. Same person index across frames
|
present track; empty tracks dropped. Same person index = same subject."""
|
||||||
is assumed same subject (Smooth/Predict enforce this on tracked bboxes)."""
|
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
max_p = max((len(f) for f in frames), default=0)
|
max_p = max((len(f) for f in frames), default=0)
|
||||||
if max_p == 0:
|
if max_p == 0:
|
||||||
@ -257,8 +245,7 @@ class GLBWriter:
|
|||||||
return len(self.accessors) - 1
|
return len(self.accessors) - 1
|
||||||
|
|
||||||
def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int:
|
def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int:
|
||||||
"""Morph-target POSITIONs: spec lets us skip min/max, avoiding a
|
"""Morph-target POSITIONs: spec lets us skip min/max."""
|
||||||
per-frame delta bbox."""
|
|
||||||
a = np.ascontiguousarray(arr, dtype=np.float32)
|
a = np.ascontiguousarray(arr, dtype=np.float32)
|
||||||
view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY)
|
view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY)
|
||||||
self.accessors.append({
|
self.accessors.append({
|
||||||
@ -288,9 +275,8 @@ class GLBWriter:
|
|||||||
return len(self.accessors) - 1
|
return len(self.accessors) - 1
|
||||||
|
|
||||||
def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int:
|
def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int:
|
||||||
"""Animation-output scalars: `count` is keyframes, not floats. Morph-
|
"""Animation-output scalars: `count` is keyframes, not floats (morph
|
||||||
target weight tracks store N_morph weights per keyframe as flat float32
|
weight tracks store N_morph weights per keyframe)."""
|
||||||
with count=N_keyframes."""
|
|
||||||
a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1)
|
a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1)
|
||||||
view_idx = self._add_view(a.tobytes())
|
view_idx = self._add_view(a.tobytes())
|
||||||
self.accessors.append({
|
self.accessors.append({
|
||||||
@ -382,9 +368,8 @@ def bake_vertex_colors(
|
|||||||
rainbow_tilt_z_deg: float,
|
rainbow_tilt_z_deg: float,
|
||||||
pastel_mix: float,
|
pastel_mix: float,
|
||||||
) -> Optional[np.ndarray]:
|
) -> Optional[np.ndarray]:
|
||||||
"""Per-vertex RGB matching the renderer's shader preset, on the canonical
|
"""Per-vertex RGB matching the renderer's shader preset. Returns (N_v, 3)
|
||||||
mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the
|
float32 in [0, 1], or None for `default` (use the viewer's material)."""
|
||||||
viewer's default material handle shading)."""
|
|
||||||
if shader == "default" or canonical_colors is None:
|
if shader == "default" or canonical_colors is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -432,8 +417,8 @@ def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def _parents_from_pmi(rig: Any) -> np.ndarray:
|
def _parents_from_pmi(rig: Any) -> np.ndarray:
|
||||||
"""Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child,
|
"""Parent index per joint from skel_pmi ((2, 266): row 0 child, row 1
|
||||||
row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1."""
|
parent, split into BFS levels by skel_pmi_buffer_sizes). Roots = -1."""
|
||||||
NJ = int(rig.NUM_JOINTS)
|
NJ = int(rig.NUM_JOINTS)
|
||||||
pmi = rig.skel_pmi.cpu().numpy()
|
pmi = rig.skel_pmi.cpu().numpy()
|
||||||
sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist()
|
sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist()
|
||||||
@ -450,47 +435,29 @@ def _parents_from_pmi(rig: Any) -> np.ndarray:
|
|||||||
|
|
||||||
def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||||
"""Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply
|
"""Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply
|
||||||
this to bypass MHR rig extraction (see ComfyUI-Kimodo). Required keys:
|
this to bypass MHR rig extraction (see ComfyUI-Kimodo).
|
||||||
parents: (NJ,) int32, -1 = root
|
|
||||||
bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters
|
Required keys:
|
||||||
lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences
|
parents: (NJ,) int32, -1 = root
|
||||||
lbs_compact_weights: (V, 8) f32
|
bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters
|
||||||
lbs_compact_max_inf: int — actual max influences (≤ 8)
|
lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences
|
||||||
rest_verts_m: (V, 3) f32
|
lbs_compact_weights: (V, 8) f32
|
||||||
faces: (F, 3) uint32
|
lbs_compact_max_inf: int — actual max influences (≤ 8)
|
||||||
Optional:
|
rest_verts_m: (V, 3) f32
|
||||||
per_frame_y_down: bool — set False if pred_joint_coords are already
|
faces: (F, 3) uint32
|
||||||
rig-native Y-up (kimodo). Default True (MHR).
|
|
||||||
openpose18_joint_indices: (18, 2) int32 — body OpenPose-18 → joint
|
Optional (enable openpose mode on external rigs):
|
||||||
index pair, resolved against per-frame
|
per_frame_y_down: bool — False if pred_joint_coords are already Y-up
|
||||||
`pred_joint_coords`. Each row is
|
(kimodo). Default True (MHR).
|
||||||
(joint_a, joint_b); b == -1 = single
|
openpose18_joint_indices: (18, 2) int32 — body keypoint → (a, b)
|
||||||
joint, else default midpoint of the two
|
joints, resolved against `pred_joint_coords`.
|
||||||
(lets producers approximate keypoints
|
b == -1 = single joint, else midpoint of (a, b).
|
||||||
with no matching joint, e.g. Nose ≈
|
openpose18_joint_weights: (18,) f32 — blend w: w*a + (1-w)*b
|
||||||
midpoint(LeftEye, RightEye)). Enables
|
(default 0.5; outside [0,1] extrapolates; ignored
|
||||||
`SAM3DBody_ToGLB(mode="openpose")` on
|
when b == -1).
|
||||||
external rigs.
|
openpose_hand21_{r,l}_joint_indices: (21, 2) int32 — per-hand keypoint
|
||||||
openpose18_joint_weights: (18,) f32 — optional per-keypoint blend
|
maps; both required for include_hands=True.
|
||||||
weight for the (a, b) mapping above.
|
openpose_hand21_{r,l}_joint_weights: (21,) f32 — optional, same as above.
|
||||||
Position = w*joints[a] + (1-w)*joints[b]
|
|
||||||
when b ≥ 0 (default w=0.5 → midpoint).
|
|
||||||
Values outside [0, 1] EXTRAPOLATE past
|
|
||||||
the line segment — used to approximate
|
|
||||||
landmarks with no nearby joint pair
|
|
||||||
(e.g. ears: w=2.0 along the eye→eye
|
|
||||||
axis puts each ear one eye-distance
|
|
||||||
outside the corresponding eye). Ignored
|
|
||||||
for single-joint rows (b = -1).
|
|
||||||
openpose_hand21_r_joint_indices: (21, 2) int32 — right-hand OpenPose-21
|
|
||||||
(wrist + 5 fingers × 4 joints, base→tip)
|
|
||||||
→ joint index pair. Required (alongside
|
|
||||||
the L counterpart) for openpose mode
|
|
||||||
with include_hands=True.
|
|
||||||
openpose_hand21_l_joint_indices: (21, 2) int32 — left-hand counterpart.
|
|
||||||
openpose_hand21_r_joint_weights: (21,) f32 — optional, same semantics as
|
|
||||||
`openpose18_joint_weights`.
|
|
||||||
openpose_hand21_l_joint_weights: (21,) f32 — optional, same as above.
|
|
||||||
"""
|
"""
|
||||||
if pose_data is None:
|
if pose_data is None:
|
||||||
return None
|
return None
|
||||||
@ -502,12 +469,10 @@ def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) -
|
|||||||
use that instead of MHR-specific `model.head_pose.mhr` buffers."""
|
use that instead of MHR-specific `model.head_pose.mhr` buffers."""
|
||||||
override = _get_skeleton_override(pose_data)
|
override = _get_skeleton_override(pose_data)
|
||||||
if override is not None:
|
if override is not None:
|
||||||
# External rig: caller pre-compacts skin and supplies bind global directly,
|
# External rig: skin pre-compacted, bind global supplied directly.
|
||||||
# so we don't need MHR's PCA pose / expression bases.
|
|
||||||
parents = np.asarray(override["parents"], dtype=np.int32)
|
parents = np.asarray(override["parents"], dtype=np.int32)
|
||||||
rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32)
|
rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32)
|
||||||
# BVH needs parent-relative bone OFFSETs (cm). MHR ships these directly;
|
# BVH needs parent-relative bone offsets (cm); derive from bind globals.
|
||||||
# external rigs only give bind globals, so derive locals from them.
|
|
||||||
bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32)
|
bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32)
|
||||||
local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
||||||
joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32)
|
joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32)
|
||||||
@ -560,29 +525,26 @@ def compact_skin_to_n(
|
|||||||
skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
|
skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
|
||||||
num_verts: int, max_inf: int = 8,
|
num_verts: int, max_inf: int = 8,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, int]:
|
) -> Tuple[np.ndarray, np.ndarray, int]:
|
||||||
"""Sparse (joint, vert, weight) triplets -> dense (joints[V, max_inf],
|
"""Sparse (joint, vert, weight) triplets -> dense (joints, weights) of shape
|
||||||
weights[V, max_inf]). Keeps `max_inf` largest-magnitude influences,
|
(V, max_inf), keeping the largest influences and renormalizing. `actual_max`
|
||||||
renormalizes. `actual_max` lets the caller skip JOINTS_1/WEIGHTS_1 when
|
lets the caller skip JOINTS_1/WEIGHTS_1 when nothing exceeds 4 influences."""
|
||||||
nothing exceeds 4 influences."""
|
|
||||||
joints = np.zeros((num_verts, max_inf), dtype=np.uint16)
|
joints = np.zeros((num_verts, max_inf), dtype=np.uint16)
|
||||||
out_w = np.zeros((num_verts, max_inf), dtype=np.float32)
|
out_w = np.zeros((num_verts, max_inf), dtype=np.float32)
|
||||||
counts = np.zeros(num_verts, dtype=np.int32)
|
counts = np.zeros(num_verts, dtype=np.int32)
|
||||||
|
|
||||||
if vert_indices.size:
|
if vert_indices.size:
|
||||||
# lexsort secondary key first: groups by vert, weights descending within group.
|
# Group by vert, weights descending within each group.
|
||||||
order = np.lexsort((-weights, vert_indices))
|
order = np.lexsort((-weights, vert_indices))
|
||||||
vi_sorted = vert_indices[order]
|
vi_sorted = vert_indices[order]
|
||||||
sk_sorted = skin_indices[order]
|
sk_sorted = skin_indices[order]
|
||||||
w_sorted = weights[order]
|
w_sorted = weights[order]
|
||||||
|
|
||||||
# Per-row rank within its vertex group: 0 at each group start, +1 elsewhere.
|
# Per-row rank within its vertex group (0 at each group start).
|
||||||
# group_start[i] is True when vi_sorted[i] starts a new vertex.
|
|
||||||
n = vi_sorted.size
|
n = vi_sorted.size
|
||||||
group_start = np.empty(n, dtype=bool)
|
group_start = np.empty(n, dtype=bool)
|
||||||
group_start[0] = True
|
group_start[0] = True
|
||||||
np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:])
|
np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:])
|
||||||
pos = np.arange(n, dtype=np.int64)
|
pos = np.arange(n, dtype=np.int64)
|
||||||
# Position of each row's group start, broadcast forward.
|
|
||||||
group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0))
|
group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0))
|
||||||
rank = pos - group_start_pos
|
rank = pos - group_start_pos
|
||||||
|
|
||||||
@ -609,9 +571,8 @@ def zero_pose_rest_verts(
|
|||||||
model: Any, shape_params: np.ndarray, expr_zero: bool = True,
|
model: Any, shape_params: np.ndarray, expr_zero: bool = True,
|
||||||
pose_data: Optional[Dict[str, Any]] = None,
|
pose_data: Optional[Dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Rig with zero pose + this subject's shape -> rest verts (V, 3) in
|
"""Zero pose + this subject's shape -> rest verts (V, 3) in rig-native Y-up
|
||||||
rig-native Y-up meters. External-skeleton path returns `rest_verts_m`
|
meters. External path returns `rest_verts_m` directly."""
|
||||||
directly (no PCA shape space to expand)."""
|
|
||||||
override = _get_skeleton_override(pose_data)
|
override = _get_skeleton_override(pose_data)
|
||||||
if override is not None:
|
if override is not None:
|
||||||
return np.asarray(override["rest_verts_m"], dtype=np.float32)
|
return np.asarray(override["rest_verts_m"], dtype=np.float32)
|
||||||
@ -624,14 +585,11 @@ def zero_pose_rest_verts(
|
|||||||
sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device)
|
sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device)
|
||||||
if sp.ndim == 1:
|
if sp.ndim == 1:
|
||||||
sp = sp.unsqueeze(0)
|
sp = sp.unsqueeze(0)
|
||||||
# mhr.forward(identity_coeffs, model_parameters, expr_coeffs):
|
# rig.forward(shape, model_params, expr); zero pose + zero expr.
|
||||||
# identity_rest = base_shape + identity_basis @ shape;
|
|
||||||
# cat([model_params, zeros]) through param_transform; expr added.
|
|
||||||
model_params = torch.zeros(1, 204, device=device, dtype=dtype)
|
model_params = torch.zeros(1, 204, device=device, dtype=dtype)
|
||||||
expr = torch.zeros(1, 72, device=device, dtype=dtype)
|
expr = torch.zeros(1, 72, device=device, dtype=dtype)
|
||||||
verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False)
|
verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False)
|
||||||
# Rig outputs cm; mhr_head divides by 100 for meters. Match that.
|
verts_m = verts[0].cpu().float().numpy() / 100.0 # cm -> m
|
||||||
verts_m = verts[0].cpu().float().numpy() / 100.0
|
|
||||||
return verts_m.astype(np.float32)
|
return verts_m.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
@ -639,7 +597,7 @@ def global_skel_state_per_frame(
|
|||||||
model: Any, mhr_model_params: np.ndarray,
|
model: Any, mhr_model_params: np.ndarray,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw,
|
"""Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw,
|
||||||
scale). Bones are shape- and expression-independent so we pass zeros."""
|
scale). Bones are shape/expression-independent, so pass zeros."""
|
||||||
inner = model.model if hasattr(model, "model") else model
|
inner = model.model if hasattr(model, "model") else model
|
||||||
rig = inner.head_pose.mhr
|
rig = inner.head_pose.mhr
|
||||||
device = next(rig.parameters()).device
|
device = next(rig.parameters()).device
|
||||||
@ -655,8 +613,8 @@ def global_skel_state_per_frame(
|
|||||||
|
|
||||||
|
|
||||||
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
|
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
|
||||||
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978 branched, largest-component
|
"""(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978, largest-component pick.
|
||||||
pick for stability. Cross-frame sign-fixing is the caller's job."""
|
Cross-frame sign-fixing is the caller's job."""
|
||||||
shape = R.shape[:-2]
|
shape = R.shape[:-2]
|
||||||
Rf = R.reshape(-1, 3, 3).astype(np.float64)
|
Rf = R.reshape(-1, 3, 3).astype(np.float64)
|
||||||
M = Rf.shape[0]
|
M = Rf.shape[0]
|
||||||
@ -703,14 +661,12 @@ def global_skel_state_from_pose_data(
|
|||||||
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
pose_data: Dict[str, Any], frame_indices: List[int], person_k: int,
|
||||||
NJ: int, *, joint_coords_y_down: bool = True,
|
NJ: int, *, joint_coords_y_down: bool = True,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Build per-frame skel_state from stored pred_global_rots + pred_joint_coords,
|
"""Per-frame skel_state from stored pred_global_rots + pred_joint_coords,
|
||||||
bypassing rig.forward. Returns (N, NJ, 8) in METERS, MHR-native frame.
|
bypassing rig.forward. Returns (N, NJ, 8) in meters, MHR-native frame.
|
||||||
|
|
||||||
pred_global_rots are MHR-native (no y/z flip). For MHR, pred_joint_coords
|
pred_global_rots are MHR-native. pred_joint_coords are y-down for MHR
|
||||||
are stored y-down (post-flip), so un-flip when `joint_coords_y_down=True`.
|
(un-flipped when `joint_coords_y_down=True`); external rigs store y-up
|
||||||
External skeletons (Kimodo) store y-up already → pass False. Scale
|
(pass False). Scale defaults to 1 (not preserved in pose_data)."""
|
||||||
defaults to 1 (rig scale isn't preserved in pose_data; close to 1 for
|
|
||||||
typical body poses)."""
|
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
N = len(frame_indices)
|
N = len(frame_indices)
|
||||||
rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32)
|
rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32)
|
||||||
@ -731,10 +687,8 @@ def global_skel_state_from_pose_data(
|
|||||||
|
|
||||||
|
|
||||||
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
|
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
|
||||||
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm.
|
"""Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm,
|
||||||
Inverse of `lbs_inverse_bind_pose` modulo precision; used as bones' static
|
used as bones' static TRS. External rig: convert `bind_global_m` m -> cm."""
|
||||||
TRS so the rest mesh looks correct with no animation playing. External
|
|
||||||
rig: convert override's `bind_global_m` from m → cm to match this contract."""
|
|
||||||
override = _get_skeleton_override(pose_data)
|
override = _get_skeleton_override(pose_data)
|
||||||
if override is not None:
|
if override is not None:
|
||||||
bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy()
|
bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy()
|
||||||
@ -746,13 +700,10 @@ def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> n
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Rig:
|
class Rig:
|
||||||
"""Normalized static rig for the GLB/BVH exporters, independent of where it
|
"""Normalized static rig for the GLB/BVH exporters, source-independent: MHR
|
||||||
came from: an MHR model (`Rig.from_pose_data(pose_data, model)`) or an inline
|
model or inline `pose_data["_skeleton_override"]` (external rigs). Consumers
|
||||||
`pose_data["_skeleton_override"]` (external rigs, e.g. ComfyUI-Kimodo).
|
never branch on the source. Only `rest_verts_m` is source-dependent — MHR
|
||||||
|
expands it from `shape_params`; external rigs ship it fixed.
|
||||||
Consumers read these fields and never branch on the source. The only
|
|
||||||
source-dependent operation is `rest_verts_m` — MHR rest verts depend on the
|
|
||||||
subject's `shape_params`; external rigs ship fixed rest verts.
|
|
||||||
"""
|
"""
|
||||||
parents: np.ndarray # (NJ,) int32, -1 = root
|
parents: np.ndarray # (NJ,) int32, -1 = root
|
||||||
joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm
|
joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm
|
||||||
@ -816,9 +767,8 @@ class Rig:
|
|||||||
|
|
||||||
|
|
||||||
def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray:
|
def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray:
|
||||||
"""Inverse-bind MAT4 by inverting the rig's bind global (meters). Guarantees
|
"""Inverse-bind MAT4 from the rig's bind global (meters). IBP[j] =
|
||||||
IBP[j] = inverse(FK over bind local TRS) — exactly what glTF skinning
|
inverse(FK over bind local TRS), as glTF skinning needs. Returns (NJ, 4, 4)
|
||||||
needs given bones default to the bind local TRS. Returns (NJ, 4, 4)
|
|
||||||
column-major."""
|
column-major."""
|
||||||
NJ = bind_skel_state_m.shape[0]
|
NJ = bind_skel_state_m.shape[0]
|
||||||
t = bind_skel_state_m[:, :3].astype(np.float32)
|
t = bind_skel_state_m[:, :3].astype(np.float32)
|
||||||
@ -877,10 +827,8 @@ def _ibp_to_mat4(ibp_skel: np.ndarray) -> np.ndarray:
|
|||||||
|
|
||||||
|
|
||||||
def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]:
|
def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
"""Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring
|
"""Unit UV sphere, poles ±Y. `n_lat` odd so a ring lands at the equator;
|
||||||
lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon
|
n_lon=16 matches the capsule cylinder so end rings meet flush."""
|
||||||
matches the 16-segment cylinder used by capsule limbs AND the equator
|
|
||||||
ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush."""
|
|
||||||
verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0
|
verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0
|
||||||
for i in range(1, n_lat + 1):
|
for i in range(1, n_lat + 1):
|
||||||
lat = -0.5 * np.pi + np.pi * i / (n_lat + 1)
|
lat = -0.5 * np.pi + np.pi * i / (n_lat + 1)
|
||||||
@ -924,8 +872,8 @@ def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndar
|
|||||||
def flat_shade_mesh(
|
def flat_shade_mesh(
|
||||||
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Smooth -> flat by duplicating verts per face; each triangle gets 3
|
"""Flat-shade by duplicating verts per face; each triangle gets 3 unique
|
||||||
unique verts sharing its face normal. Skinning attrs duplicated alongside."""
|
verts sharing its face normal. Skinning attrs duplicated alongside."""
|
||||||
F = faces.shape[0]
|
F = faces.shape[0]
|
||||||
new_v = np.zeros((F * 3, 3), dtype=np.float32)
|
new_v = np.zeros((F * 3, 3), dtype=np.float32)
|
||||||
new_n = np.zeros((F * 3, 3), dtype=np.float32)
|
new_n = np.zeros((F * 3, 3), dtype=np.float32)
|
||||||
@ -949,9 +897,8 @@ def flat_shade_mesh(
|
|||||||
def smooth_shade_mesh(
|
def smooth_shade_mesh(
|
||||||
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Area-weighted per-vertex normals (smooth shading). Geometry, skinning,
|
"""Area-weighted per-vertex normals. Geometry/skinning/indexing pass through
|
||||||
indexing pass through unchanged so vertex colors stay aligned. Orphan
|
unchanged so vertex colors stay aligned. Orphan verts get +Y fallback."""
|
||||||
verts get +Y fallback."""
|
|
||||||
Nv = int(verts.shape[0])
|
Nv = int(verts.shape[0])
|
||||||
v0 = verts[faces[:, 0]]
|
v0 = verts[faces[:, 0]]
|
||||||
v1 = verts[faces[:, 1]]
|
v1 = verts[faces[:, 1]]
|
||||||
@ -994,11 +941,9 @@ def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray:
|
|||||||
def make_lit_material(
|
def make_lit_material(
|
||||||
roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0,
|
roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Lit PBR material using vertex COLOR_0 multiplicatively. KHR_materials_unlit
|
"""Lit PBR material using vertex COLOR_0. Dielectric (metallic=0) so colors
|
||||||
is intentionally off so viewer lighting reveals surface form. metallic=0
|
stay readable; roughness 0.85 suits rainbow body meshes, 0.3 the glossy
|
||||||
keeps the surface dielectric so vertex colors stay readable. roughness=0.85
|
SCAIL rig. opacity < 1 switches to alpha-blend."""
|
||||||
suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look.
|
|
||||||
opacity < 1 switches to alpha-blend (e.g. see-through body mesh over bones)."""
|
|
||||||
a = float(max(0.0, min(1.0, opacity)))
|
a = float(max(0.0, min(1.0, opacity)))
|
||||||
mat = {
|
mat = {
|
||||||
"pbrMetallicRoughness": {
|
"pbrMetallicRoughness": {
|
||||||
@ -1182,14 +1127,12 @@ def openpose_render_keypoints(
|
|||||||
person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str,
|
person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str,
|
||||||
*, dim: int, H: int = 0, W: int = 0,
|
*, dim: int, H: int = 0, W: int = 0,
|
||||||
) -> Optional[np.ndarray]:
|
) -> Optional[np.ndarray]:
|
||||||
"""OpenPose keypoints for one person, in op-layout, CAMERA frame (Y-down).
|
"""OpenPose keypoints for one person, op-layout, camera frame (Y-down).
|
||||||
`part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add;
|
`part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add;
|
||||||
dim=2 -> (K, 2) image pixels. Returns None when the source data is missing.
|
dim=2 -> (K, 2) pixels. Returns None when source data is missing.
|
||||||
|
|
||||||
External rigs (override carries the joint-index map) resolve from per-frame
|
External rigs resolve from `pred_joint_coords` (Y-up -> flipped to Y-down);
|
||||||
`pred_joint_coords` (rig-native Y-up -> flipped to camera Y-down, matching
|
MHR reindexes stored `pred_keypoints_{3d,2d}` via the MHR70 map."""
|
||||||
the pred_vertices convention). MHR reindexes the stored
|
|
||||||
`pred_keypoints_{3d,2d}` via the MHR70 map."""
|
|
||||||
map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part]
|
map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part]
|
||||||
override = _get_skeleton_override(pose_data)
|
override = _get_skeleton_override(pose_data)
|
||||||
ext_map = override.get(map_key) if override is not None else None
|
ext_map = override.get(map_key) if override is not None else None
|
||||||
@ -1228,11 +1171,9 @@ def openpose_render_keypoints(
|
|||||||
return kp_full[mhr_map]
|
return kp_full[mhr_map]
|
||||||
|
|
||||||
|
|
||||||
# Face landmarks from the MHR rig (option `face_source="rig"`).
|
# Face landmarks (face_source="rig"). MHR has no face bones, so landmarks are
|
||||||
# MHR has no face bones — face deforms via expr_params morphs — so landmarks
|
# sourced from `pred_vertices` at vertex IDs picked by NN against the target xyz
|
||||||
# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against
|
# below. Tweak targets if landmarks land off-surface.
|
||||||
# anatomically-plausible target xyz in canonical Y-up. Iterate visually in
|
|
||||||
# Blender and tweak targets if landmarks land off-surface.
|
|
||||||
|
|
||||||
# (name, target_xyz) in MHR canonical Y-up meters.
|
# (name, target_xyz) in MHR canonical Y-up meters.
|
||||||
FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = (
|
FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = (
|
||||||
@ -1290,10 +1231,8 @@ def select_face_landmark_vert_ids(
|
|||||||
face_mask: Optional[np.ndarray] = None,
|
face_mask: Optional[np.ndarray] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in
|
"""Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in
|
||||||
canonical positions. Filter: `face_mask` (verts that deform with any of
|
canonical positions, restricted to `face_mask` verts (expression-deforming)
|
||||||
the 72 expression axes) if available — keeps chin/jaw search off the
|
when available, else a position bbox (less reliable around the chin/jaw)."""
|
||||||
neck. Otherwise a position bbox (less reliable; throat verts sometimes
|
|
||||||
pull chin targets)."""
|
|
||||||
P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3)
|
P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3)
|
||||||
if face_mask is not None and np.asarray(face_mask).any():
|
if face_mask is not None and np.asarray(face_mask).any():
|
||||||
valid = np.where(np.asarray(face_mask).reshape(-1))[0]
|
valid = np.where(np.asarray(face_mask).reshape(-1))[0]
|
||||||
|
|||||||
@ -1,19 +1,11 @@
|
|||||||
"""GLB export — skeletal (real armature) mode.
|
"""GLB export — skeletal (real armature) mode.
|
||||||
|
|
||||||
Rebuilds an Armature with the MHR 127-bone rig:
|
Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from
|
||||||
- per-frame local TRS comes from re-running param_transform on the saved
|
param_transform on `mhr_model_params`, rest verts from a zero-pose forward,
|
||||||
`mhr_model_params`;
|
sparse skinning compacted to glTF's 4-influence form, and facial expression as
|
||||||
- rest verts come from a zero-pose forward with each person's `shape_params`;
|
72 morph targets driven by `expr_params`. Optional octahedron bone-vis is
|
||||||
- sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form;
|
rigidly skinned alongside for viewers that don't draw bones. Shared infra lives
|
||||||
- facial expression is re-exposed as 72 morph targets driven by `expr_params`
|
in `glb_shared.py`.
|
||||||
so face animation survives plain glTF skinning.
|
|
||||||
|
|
||||||
Optional bone visualization (octahedrons) is rigidly
|
|
||||||
skinned alongside the body mesh — used to preview the armature in glTF
|
|
||||||
viewers that don't draw bones.
|
|
||||||
|
|
||||||
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
|
|
||||||
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -44,8 +36,7 @@ from .glb_shared import (
|
|||||||
from comfy_extras.sam3d_body.utils import jet_colormap
|
from comfy_extras.sam3d_body.utils import jet_colormap
|
||||||
|
|
||||||
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
|
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
|
||||||
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white'
|
"""Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material)."""
|
||||||
(no per-bone color → bone-vis mesh uses default unlit material)."""
|
|
||||||
if scheme == "rainbow_y":
|
if scheme == "rainbow_y":
|
||||||
y = bind_pos_m[:, 1].astype(np.float32)
|
y = bind_pos_m[:, 1].astype(np.float32)
|
||||||
y_min, y_max = float(y.min()), float(y.max())
|
y_min, y_max = float(y.min()), float(y.max())
|
||||||
@ -55,9 +46,8 @@ def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray
|
|||||||
|
|
||||||
|
|
||||||
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
||||||
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y,
|
"""Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit
|
||||||
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound
|
length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward."""
|
||||||
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
|
|
||||||
v = np.array([
|
v = np.array([
|
||||||
[0.0, 0.0, 0.0], # 0: head
|
[0.0, 0.0, 0.0], # 0: head
|
||||||
[0.0, 1.0, 0.0], # 1: tail
|
[0.0, 1.0, 0.0], # 1: tail
|
||||||
@ -78,18 +68,16 @@ def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
|||||||
def _bone_edges(
|
def _bone_edges(
|
||||||
joint_pos_m: np.ndarray, parents: np.ndarray,
|
joint_pos_m: np.ndarray, parents: np.ndarray,
|
||||||
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
|
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
|
||||||
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per
|
"""One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge.
|
||||||
parent→child edge in the hierarchy, skipping edges whose PARENT is a
|
Skips edges whose parent is a root (world-anchor sticks) and zero-length
|
||||||
root joint (those typically anchor the skeleton at world origin and
|
edges."""
|
||||||
just look like a stray stick from origin to the body). Zero-length
|
|
||||||
edges are skipped too."""
|
|
||||||
NJ = joint_pos_m.shape[0]
|
NJ = joint_pos_m.shape[0]
|
||||||
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
|
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
|
||||||
for c in range(NJ):
|
for c in range(NJ):
|
||||||
p = int(parents[c])
|
p = int(parents[c])
|
||||||
if not (0 <= p < NJ and p != c):
|
if not (0 <= p < NJ and p != c):
|
||||||
continue
|
continue
|
||||||
# Skip if parent itself is a root — that bone is a world-anchor stick.
|
# Skip world-anchor sticks: parent itself is a root.
|
||||||
gp = int(parents[p])
|
gp = int(parents[p])
|
||||||
if not (0 <= gp < NJ and gp != p):
|
if not (0 <= gp < NJ and gp != p):
|
||||||
continue
|
continue
|
||||||
@ -104,9 +92,8 @@ def _bone_edges(
|
|||||||
def _build_bone_octahedrons_mesh(
|
def _build_bone_octahedrons_mesh(
|
||||||
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
|
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""One Blender-style octahedron per parent→child edge. Returns
|
"""One octahedron per parent→child edge. Returns (verts, normals, faces,
|
||||||
(verts, normals, faces, joints, weights, child_idx_per_vert);
|
joints, weights, child_idx_per_vert); child_idx feeds per-bone color."""
|
||||||
child_idx feeds per-bone color lookup at the call site."""
|
|
||||||
base_v, base_f = _octahedron_unit()
|
base_v, base_f = _octahedron_unit()
|
||||||
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
@ -117,8 +104,7 @@ def _build_bone_octahedrons_mesh(
|
|||||||
out_w: List[List[float]] = []
|
out_w: List[List[float]] = []
|
||||||
child_per_vert: List[int] = []
|
child_per_vert: List[int] = []
|
||||||
|
|
||||||
# Width scales with length so short bones (fingers, face) don't look chunky
|
# Width scales with length (capped by half_width_m) so short bones aren't chunky.
|
||||||
# next to long ones (limbs, spine). `half_width_m` caps long bones.
|
|
||||||
WIDTH_RATIO = 0.1
|
WIDTH_RATIO = 0.1
|
||||||
MIN_WIDTH = 0.001
|
MIN_WIDTH = 0.001
|
||||||
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
|
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
|
||||||
@ -151,8 +137,8 @@ def _build_bone_octahedrons_mesh(
|
|||||||
out_n.extend(n_world.tolist())
|
out_n.extend(n_world.tolist())
|
||||||
for face in base_f:
|
for face in base_f:
|
||||||
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
|
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
|
||||||
# Dual skin head→parent, tail→child, ridges blend by canonical Y so the
|
# Dual skin (head→parent, tail→child); ridges blend by canonical Y so
|
||||||
# bone stretches between joints instead of going rigid with one.
|
# the bone stretches between joints instead of going rigid with one.
|
||||||
for k in range(base_v.shape[0]):
|
for k in range(base_v.shape[0]):
|
||||||
y_canon = float(base_v[k, 1])
|
y_canon = float(base_v[k, 1])
|
||||||
w_parent = max(0.0, 1.0 - y_canon)
|
w_parent = max(0.0, 1.0 - y_canon)
|
||||||
@ -196,22 +182,17 @@ def build_glb_skeletal(
|
|||||||
bone_vis_color: str = "white",
|
bone_vis_color: str = "white",
|
||||||
include_body_mesh: bool = True,
|
include_body_mesh: bool = True,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.
|
"""Build pose_data as a real Armature GLB with per-bone TRS keyframes. For
|
||||||
|
MHR, facial expression is exposed as 72 morph targets when
|
||||||
|
include_face_morphs=True.
|
||||||
|
|
||||||
For MHR (default) facial expression is exposed as 72 morph targets driven
|
External skeletons (e.g. ComfyUI-Kimodo) can supply
|
||||||
by expr_params per frame when include_face_morphs=True.
|
``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model``
|
||||||
|
may be None then); per-frame state still reads ``pred_global_rots`` /
|
||||||
External skeletons (e.g. ComfyUI-Kimodo) can supply a
|
``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema.
|
||||||
``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
|
|
||||||
entirely. When present, ``model`` may be None and the rig data, bind pose,
|
|
||||||
skin weights, and rest verts come from the override. Per-frame skeletal
|
|
||||||
state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
|
|
||||||
person dict (kimodo populates these from its own FK output). See
|
|
||||||
``glb.shared._get_skeleton_override`` for the override schema.
|
|
||||||
"""
|
"""
|
||||||
frames = pose_data["frames"]
|
frames = pose_data["frames"]
|
||||||
# Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
|
# Only `pred_cam_t` is camera-y-down; everything else is rig-native Y-up.
|
||||||
# faces are all rig-native (Y-up).
|
|
||||||
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
|
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
|
||||||
tracks = collect_tracks(pose_data, track_index)
|
tracks = collect_tracks(pose_data, track_index)
|
||||||
if not tracks:
|
if not tracks:
|
||||||
@ -219,17 +200,14 @@ def build_glb_skeletal(
|
|||||||
|
|
||||||
rig = Rig.from_pose_data(pose_data, model)
|
rig = Rig.from_pose_data(pose_data, model)
|
||||||
NJ = rig.num_joints
|
NJ = rig.num_joints
|
||||||
# NV = rig.num_verts
|
|
||||||
NEXPR = rig.num_expr
|
NEXPR = rig.num_expr
|
||||||
parents = rig.parents
|
parents = rig.parents
|
||||||
if not rig.can_rerun_fk:
|
if not rig.can_rerun_fk:
|
||||||
# External rigs have no PCA pose params to re-run; only stored globals
|
# External rigs have no PCA pose params to re-run; use stored globals.
|
||||||
# are available, and they store joint coords already Y-up.
|
|
||||||
use_stored_global_rots = True
|
use_stored_global_rots = True
|
||||||
joint_coords_y_down = rig.per_frame_y_down
|
joint_coords_y_down = rig.per_frame_y_down
|
||||||
# Skinning is already compacted to ≤8 influences per vertex (MHR averages
|
# Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts
|
||||||
# ~2.8 but some shoulder/hip verts hit 5-8; keeping only 4 there leaks
|
# need >4, else per-bone rotation noise leaks into the mesh).
|
||||||
# per-bone rotation noise into the rendered mesh).
|
|
||||||
joints_8 = rig.lbs_joints
|
joints_8 = rig.lbs_joints
|
||||||
weights_8 = rig.lbs_weights
|
weights_8 = rig.lbs_weights
|
||||||
actual_max_inf = rig.lbs_max_inf
|
actual_max_inf = rig.lbs_max_inf
|
||||||
@ -238,14 +216,12 @@ def build_glb_skeletal(
|
|||||||
use_set1 = actual_max_inf > 4
|
use_set1 = actual_max_inf > 4
|
||||||
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
|
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
|
||||||
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
|
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
|
||||||
# Derive bone locals from the rig's bind globals rather than recomputing
|
# Derive bone locals from bind globals so any `parents`/FK mismatch is
|
||||||
# FK ourselves, so any mismatch between `parents` and the rig's actual FK
|
# absorbed into the local TRS instead of producing wrong globals.
|
||||||
# is absorbed into the local TRS instead of producing wrong globals.
|
|
||||||
bind_global_m = rig.bind_global_m
|
bind_global_m = rig.bind_global_m
|
||||||
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
|
||||||
|
|
||||||
# IBP = inverse of bind global. With bone defaults set to bind_local and
|
# IBP = inverse of bind global → skin_matrix at rest is identity.
|
||||||
# FK composed via `parents`, skin_matrix at rest = identity.
|
|
||||||
ibp_mat4 = ibp_from_bind_global(bind_global_m)
|
ibp_mat4 = ibp_from_bind_global(bind_global_m)
|
||||||
|
|
||||||
w = GLBWriter()
|
w = GLBWriter()
|
||||||
@ -316,9 +292,7 @@ def build_glb_skeletal(
|
|||||||
body_mesh_node_idx: Optional[int] = None
|
body_mesh_node_idx: Optional[int] = None
|
||||||
|
|
||||||
if include_body:
|
if include_body:
|
||||||
# MHR rest verts depend on the subject's shape_params; external rigs
|
# MHR rest verts depend on shape_params; external rigs ignore the arg.
|
||||||
# ship fixed rest verts and ignore the arg (so the empty external
|
|
||||||
# `shape_params` is harmless).
|
|
||||||
shape_params_arr = np.asarray(
|
shape_params_arr = np.asarray(
|
||||||
frames[frame_indices[0]][person_k].get("shape_params", []),
|
frames[frame_indices[0]][person_k].get("shape_params", []),
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
@ -349,8 +323,8 @@ def build_glb_skeletal(
|
|||||||
"indices": indices_acc,
|
"indices": indices_acc,
|
||||||
"mode": 4,
|
"mode": 4,
|
||||||
}
|
}
|
||||||
# See-through body when bones are shown, else opaque (only when a
|
# See-through body when bones are shown, else opaque (only if a
|
||||||
# vertex-color shader baked COLOR_0 — otherwise default material).
|
# shader baked COLOR_0; otherwise default material).
|
||||||
if color_acc is not None or include_bones:
|
if color_acc is not None or include_bones:
|
||||||
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
|
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
|
||||||
primitive["material"] = len(materials) - 1
|
primitive["material"] = len(materials) - 1
|
||||||
@ -373,8 +347,7 @@ def build_glb_skeletal(
|
|||||||
if include_bones:
|
if include_bones:
|
||||||
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
|
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
|
||||||
|
|
||||||
# Indexes `bone_palette`: octahedrons use the bone's child joint so
|
# Color by child joint so every bone has its own color.
|
||||||
# every bone has its own color regardless of skin target.
|
|
||||||
color_idx_per_vert: Optional[np.ndarray] = None
|
color_idx_per_vert: Optional[np.ndarray] = None
|
||||||
hw = float(bone_vis_radius_m)
|
hw = float(bone_vis_radius_m)
|
||||||
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
|
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
|
||||||
@ -422,8 +395,8 @@ def build_glb_skeletal(
|
|||||||
nodes.append(bv_mesh_node)
|
nodes.append(bv_mesh_node)
|
||||||
person_root["children"].append(len(nodes) - 1)
|
person_root["children"].append(len(nodes) - 1)
|
||||||
|
|
||||||
# Per-frame GLOBAL skel state → bone locals via parent-inverse.
|
# Per-frame global skel state → bone locals via parent-inverse. Stored
|
||||||
# Default uses the rig's stored output; the fallback re-runs FK.
|
# output by default; fallback re-runs FK.
|
||||||
if use_stored_global_rots:
|
if use_stored_global_rots:
|
||||||
rig_global_m = global_skel_state_from_pose_data(
|
rig_global_m = global_skel_state_from_pose_data(
|
||||||
pose_data, frame_indices, person_k, NJ,
|
pose_data, frame_indices, person_k, NJ,
|
||||||
@ -437,11 +410,9 @@ def build_glb_skeletal(
|
|||||||
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
|
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
|
||||||
rig_global_m = rig_global_cm.copy().astype(np.float32)
|
rig_global_m = rig_global_cm.copy().astype(np.float32)
|
||||||
rig_global_m[..., :3] *= 0.01
|
rig_global_m[..., :3] *= 0.01
|
||||||
# Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
|
# Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip
|
||||||
# Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
|
# would otherwise propagate into the child's local translation and cause
|
||||||
# only fix locals, the parent's flip propagates into the child's
|
# visible "axis resets" mid-animation.
|
||||||
# local translation (t_local inherits parent sign via q_parent_inv)
|
|
||||||
# and produces visible "axis resets" mid-animation.
|
|
||||||
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
|
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
|
||||||
bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
|
bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
|
||||||
local_t = bone_local_anim[..., :3].astype(np.float32)
|
local_t = bone_local_anim[..., :3].astype(np.float32)
|
||||||
@ -449,20 +420,17 @@ def build_glb_skeletal(
|
|||||||
local_s = bone_local_anim[..., 7].astype(np.float32)
|
local_s = bone_local_anim[..., 7].astype(np.float32)
|
||||||
# Second pass on locals catches residual drift from the parent-inverse.
|
# Second pass on locals catches residual drift from the parent-inverse.
|
||||||
local_q = quat_sign_fix_per_joint(local_q)
|
local_q = quat_sign_fix_per_joint(local_q)
|
||||||
# Hemisphere-align frame 0 with the bind quat so pause/play takes the
|
# Align frame 0 with the bind quat so pause/play takes the short path.
|
||||||
# short path; then re-propagate.
|
|
||||||
bind_q = bind_local[:, 3:7].astype(np.float32)
|
bind_q = bind_local[:, 3:7].astype(np.float32)
|
||||||
if local_q.shape[0] > 0:
|
if local_q.shape[0] > 0:
|
||||||
d0 = (bind_q * local_q[0]).sum(axis=-1)
|
d0 = (bind_q * local_q[0]).sum(axis=-1)
|
||||||
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
|
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
|
||||||
local_q[0] = local_q[0] * sign0
|
local_q[0] = local_q[0] * sign0
|
||||||
local_q = quat_sign_fix_per_joint(local_q)
|
local_q = quat_sign_fix_per_joint(local_q)
|
||||||
# Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
|
# Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand).
|
||||||
# at handstand) that the upstream Smooth node may not catch.
|
|
||||||
if bone_smooth_window and bone_smooth_window > 1:
|
if bone_smooth_window and bone_smooth_window > 1:
|
||||||
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
|
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
|
||||||
# fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
|
# fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift.
|
||||||
# drift into visible flips otherwise.
|
|
||||||
lq64 = local_q.astype(np.float64)
|
lq64 = local_q.astype(np.float64)
|
||||||
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
|
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
|
||||||
local_q = lq64.astype(np.float32)
|
local_q = lq64.astype(np.float32)
|
||||||
@ -527,7 +495,7 @@ def build_glb_skeletal(
|
|||||||
"target": {"node": person_root_idx, "path": "translation"},
|
"target": {"node": person_root_idx, "path": "translation"},
|
||||||
})
|
})
|
||||||
|
|
||||||
# Body-mesh-only: bone-vis primitives have no morph targets.
|
# Body mesh only — bone-vis primitives have no morph targets.
|
||||||
if expr_morph_accs and body_mesh_node_idx is not None:
|
if expr_morph_accs and body_mesh_node_idx is not None:
|
||||||
expr_per_frame = np.stack([
|
expr_per_frame = np.stack([
|
||||||
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
|
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
|
||||||
|
|||||||
@ -34,8 +34,7 @@ def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
|
|||||||
|
|
||||||
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
|
||||||
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
|
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
|
||||||
resolution. Returns (per_frame_bboxes, per_frame_masks) or
|
resolution. Returns (None, None) on empty track / frame-count mismatch."""
|
||||||
(None, None) when the track is empty / frame count doesn't match"""
|
|
||||||
|
|
||||||
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
|
||||||
if packed is None:
|
if packed is None:
|
||||||
@ -100,7 +99,7 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to
|
|||||||
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
|
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
|
||||||
H: int, W: int) -> Dict[str, Any]:
|
H: int, W: int) -> Dict[str, Any]:
|
||||||
"""Re-project every frame's pose through a Load3D 6DOF camera (position/
|
"""Re-project every frame's pose through a Load3D 6DOF camera (position/
|
||||||
target/zoom + optional FOV). Returns a new mhr_pose_data; unchanged on
|
target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on
|
||||||
empty/invalid input."""
|
empty/invalid input."""
|
||||||
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
|
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
|
||||||
if not first_frame:
|
if not first_frame:
|
||||||
@ -158,16 +157,16 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
|
|||||||
y_axis = np.cross(z_axis, x_axis)
|
y_axis = np.cross(z_axis, x_axis)
|
||||||
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
|
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
|
||||||
|
|
||||||
# Eye: dolly along the given offset; for a rotation-only camera (position ==
|
# Eye: dolly along the offset; rotation-only camera keeps the predicted
|
||||||
# target) keep the predicted viewing distance so only orientation/roll changes.
|
# viewing distance so only orientation/roll changes.
|
||||||
if has_offset:
|
if has_offset:
|
||||||
eye = target + offset / max(0.01, zoom)
|
eye = target + offset / max(0.01, zoom)
|
||||||
else:
|
else:
|
||||||
d = max(0.1, float(target[2]))
|
d = max(0.1, float(target[2]))
|
||||||
eye = target - z_axis * (d / max(0.01, zoom))
|
eye = target - z_axis * (d / max(0.01, zoom))
|
||||||
|
|
||||||
# Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint-
|
# Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov
|
||||||
# only change). Three.js fov is vertical → focal from image height.
|
# is vertical → focal from image height.
|
||||||
cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
|
cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
|
||||||
if cam_fov > 0:
|
if cam_fov > 0:
|
||||||
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
|
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
|
||||||
@ -178,10 +177,8 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str,
|
|||||||
|
|
||||||
center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
|
center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
|
||||||
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
|
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
|
||||||
# External rigs (e.g. Kimodo) store pred_joint_coords rig-native Y-up; the
|
# External rigs store pred_joint_coords Y-up; transform them through the
|
||||||
# render openpose/scail keypoint provider resolves from them and flips Y/Z.
|
# camera too (in camera space, then back to Y-up) so they follow the override.
|
||||||
# Transform them through the camera too (in camera space, then back to Y-up)
|
|
||||||
# so those keypoints follow the override instead of staying in the old frame.
|
|
||||||
override = mhr_pose_data.get("_skeleton_override")
|
override = mhr_pose_data.get("_skeleton_override")
|
||||||
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
|
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
|
||||||
new_frames: List[List[Dict[str, Any]]] = []
|
new_frames: List[List[Dict[str, Any]]] = []
|
||||||
@ -242,8 +239,7 @@ def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], p
|
|||||||
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
|
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
|
||||||
|
|
||||||
if per_frame_masks is not None:
|
if per_frame_masks is not None:
|
||||||
# Broadcast a single-mask bundle to per-bbox: when the user supplied one
|
# One mask but multiple bboxes per frame → each bbox gets the same mask.
|
||||||
# mask but multiple bboxes per frame, each bbox gets the same mask.
|
|
||||||
flat_masks = []
|
flat_masks = []
|
||||||
for f in range(N):
|
for f in range(N):
|
||||||
mf = per_frame_masks[f]
|
mf = per_frame_masks[f]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user