From ecbaefd8fc1ab04c92aaac65bb0b81f2f4dbd99f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 16 Jun 2026 20:47:15 +0300 Subject: [PATCH] Big cleanup --- comfy/ldm/sam3d_body/mhr/mhr_head.py | 82 ++--- comfy/ldm/sam3d_body/mhr/mhr_rig.py | 22 +- comfy/ldm/sam3d_body/mhr/mhr_utils.py | 173 +--------- comfy/ldm/sam3d_body/model/camera_modules.py | 19 +- comfy/ldm/sam3d_body/model/model.py | 80 ++--- comfy/ldm/sam3d_body/model/prompt.py | 52 +-- comfy/ldm/sam3d_body/utils.py | 2 - comfy_extras/sam3d_body/export/bvh.py | 46 +-- comfy_extras/sam3d_body/export/capsules.py | 70 ++-- .../sam3d_body/export/glb_openpose.py | 312 ++++++------------ comfy_extras/sam3d_body/export/glb_shared.py | 251 ++++++-------- .../sam3d_body/export/glb_skeletal.py | 122 +++---- comfy_extras/sam3d_body/utils.py | 22 +- 13 files changed, 376 insertions(+), 877 deletions(-) diff --git a/comfy/ldm/sam3d_body/mhr/mhr_head.py b/comfy/ldm/sam3d_body/mhr/mhr_head.py index e30e48cb3..d226180a3 100644 --- a/comfy/ldm/sam3d_body/mhr/mhr_head.py +++ b/comfy/ldm/sam3d_body/mhr/mhr_head.py @@ -4,25 +4,15 @@ import torch import torch.nn as nn from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat -from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask +from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask from ..model.transformer import MLP class MHRHead(nn.Module): - def __init__( - self, - input_dim: int, - mhr_rig, - mlp_depth: int = 1, - extra_joint_regressor: str = "", - mlp_channel_div_factor: int = 8, - enable_hand_model=False, - device=None, - dtype=None, - operations=None, - ): + def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False, + device=None, dtype=None, operations=None): super().__init__() # Store the shared MHRRig as a non-registered Python attribute object.__setattr__(self, "mhr", mhr_rig) @@ -48,9 +38,7 @@ class MHRHead(nn.Module): hidden_dim=input_dim // mlp_channel_div_factor, output_dim=self.npose, num_layers=mlp_depth, - device=device, - dtype=dtype, - operations=operations, + device=device, dtype=dtype, operations=operations, ) # MHR Parameters @@ -75,28 +63,25 @@ class MHRHead(nn.Module): self.local_to_world_wrist = _p(3, 3) self.nonhand_param_idxs = _p(145, dtype=torch.int64) # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader). - # Optional — loaded from the .safetensors if present, otherwise the - # render path falls back to a coarse geometric approximation. - self.register_buffer( - "face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32), - ) + self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32)) - def canonical_vertices(self, device=None): + def canonical_vertices(self): """Return the T-pose vertices for the mean shape (scaled to meters). Runs MHR with zero pose / shape / scale / expression so the returned mesh is the canonical rest pose — fixed per-model """ - dev = device or self.scale_mean.device - dt = self.scale_mean.dtype + device = self.scale_mean.device + dtype = self.scale_mean.dtype B = 1 - global_trans = torch.zeros(B, 3, device=dev, dtype=dt) - global_rot = torch.zeros(B, 3, device=dev, dtype=dt) - body_pose = torch.zeros(B, 130, device=dev, dtype=dt) - hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt) - scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt) - shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt) - expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt) + global_trans = torch.zeros(B, 3, device=device, dtype=dtype) + global_rot = torch.zeros(B, 3, device=device, dtype=dtype) + body_pose = torch.zeros(B, 130, device=device, dtype=dtype) + hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype) + scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype) + shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype) + expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype) + verts = self.mhr_forward( global_trans=global_trans, global_rot=global_rot, @@ -108,20 +93,6 @@ class MHRHead(nn.Module): ) # single-tensor shape (1, N_v, 3) in meters return verts[0] - def get_zero_pose_init(self, factor=1.0): - # Initialize pose token with zero-initialized learnable params - # Note: bias/initial value should be zero-pose in cont, not all-zeros - weights = torch.zeros(1, self.npose) - weights[:, : 6 + self.body_cont_dim] = torch.cat( - [ - torch.FloatTensor([1, 0, 0, 0, 1, 0]), - compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze() - * factor, - ], - dim=0, - ) - return weights - def replace_hands_in_pose(self, full_pose_params, hand_pose_params): assert full_pose_params.shape[1] == 136 @@ -159,12 +130,9 @@ class MHRHead(nn.Module): shape_params, expr_params=None, return_keypoints=False, - do_pcblend=True, return_joint_coords=False, return_model_params=False, return_joint_rotations=False, - scale_offsets=None, - vertex_offsets=None, ): # Align everything to the static buffers dt = self.scale_mean.dtype @@ -206,14 +174,10 @@ class MHRHead(nn.Module): shape_params = shape_params[None] # Convert scale... scales = self.scale_mean[None, :] + scale_params @ self.scale_comps - if scale_offsets is not None: - scales = scales + scale_offsets # Now, figure out the pose. ## 10 here is because it's more stable to optimize global translation in meters. - full_pose_params = torch.cat( - [global_trans * 10, global_rot, body_pose_params], dim=1 - ) # B x 127 + full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127 ## Put in hands if hand_pose_params is not None: full_pose_params = self.replace_hands_in_pose( @@ -268,14 +232,7 @@ class MHRHead(nn.Module): else: return tuple(to_return) - def forward( - self, - x: torch.Tensor, - init_estimate: Optional[torch.Tensor] = None, - do_pcblend=True, - slim_keypoints=False, - intermediate: bool = False, - ): + def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False): """ Args: x: pose token with shape [B, C], usually C=DECODER.DIM @@ -331,7 +288,6 @@ class MHRHead(nn.Module): scale_params=pred_scale, shape_params=pred_shape, expr_params=pred_face, - do_pcblend=do_pcblend, return_keypoints=True, return_joint_coords=True, return_model_params=True, @@ -356,7 +312,7 @@ class MHRHead(nn.Module): # Head-MLP outputs are promoted to fp32 here so the external # pose_output["mhr"] contract has a stable dtype regardless of what # the head ran at (fp16/bf16 for speed). MHR-derived outputs are - # already fp32 from MHR's math layers; the cast on them is a no-op. + # already fp32 from MHR's math layers. output = { "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(), "pred_pose_rotmat": None, diff --git a/comfy/ldm/sam3d_body/mhr/mhr_rig.py b/comfy/ldm/sam3d_body/mhr/mhr_rig.py index f5542563f..90202a6fe 100644 --- a/comfy/ldm/sam3d_body/mhr/mhr_rig.py +++ b/comfy/ldm/sam3d_body/mhr/mhr_rig.py @@ -1,7 +1,7 @@ # Adapted from facebookresearch/MHR (Apache 2.0): # https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py # Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas -# verbatim from the TorchScript source bundled in the upstream mhr_model.pt +# verbatim from the upstream mhr_model.pt # (pymomentum.{skel_state,quaternion,backend.skel_state_backend}). # Original Copyright (c) Meta Platforms, Inc. and affiliates. @@ -52,7 +52,7 @@ def _skel_multiply(s1, s2): Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized before composition. With many FK levels the previously-normalized quats - drift in ULPs; the JIT renormalizes defensively, so we do too to stay + drift in ULPs; upstream renormalizes defensively, so we do too to stay bit-close to its outputs. """ t1, sc1 = s1[..., :3], s1[..., 7:8] @@ -78,7 +78,7 @@ def _skel_transform_points(skel_state, points): def _global_skel_state_from_local(local, pmi_levels): - """FK walk in fp64 (matches the JIT's use_double_precision=True path). + """FK walk in fp64 (matches upstream's use_double_precision=True path). `pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs, one per BFS level. Avoids per-call torch.split + tolist() sync. @@ -95,7 +95,7 @@ def _global_skel_state_from_local(local, pmi_levels): class MHRRig(nn.Module): """Plain-PyTorch reimplementation of Meta's MHR rig. - All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's + All math runs in fp32 (FK upcast to fp64 internally, matching upstream's use_double_precision=True backend) regardless of the host model's dtype. """ @@ -110,13 +110,11 @@ class MHRRig(nn.Module): POSE_CORR_HIDDEN = 3000 POSE_CORR_SPARSE_NNZ = 53136 - def __init__(self, device=None, dtype=None, operations=None): + def __init__(self, device=None): super().__init__() - del dtype, operations - f32 = torch.float32 # All buffers are populated by load_state_dict from the `mhr.*` keys - def _p(*shape, dtype=f32): + def _p(*shape, dtype=torch.float32): return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False) def _b(name, *shape, dtype): self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device)) @@ -147,10 +145,10 @@ class MHRRig(nn.Module): self._pmi_levels_cache = None def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True): - f32 = self.base_shape.dtype - identity_coeffs = identity_coeffs.to(f32) - model_parameters = model_parameters.to(f32) - expr_coeffs = expr_coeffs.to(f32) + dtype = self.base_shape.dtype + identity_coeffs = identity_coeffs.to(dtype) + model_parameters = model_parameters.to(dtype) + expr_coeffs = expr_coeffs.to(dtype) B = identity_coeffs.shape[0] identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs) diff --git a/comfy/ldm/sam3d_body/mhr/mhr_utils.py b/comfy/ldm/sam3d_body/mhr/mhr_utils.py index 421b5fcd7..543137f79 100644 --- a/comfy/ldm/sam3d_body/mhr/mhr_utils.py +++ b/comfy/ldm/sam3d_body/mhr/mhr_utils.py @@ -1,5 +1,5 @@ # MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers -# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity +# (batch6DFromXYZ, batchXYZfrom6D) are the continuity # representation from Zhou et al., "On the Continuity of Rotation # Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035), # implementations from papagina/RotationContinuity: @@ -158,18 +158,10 @@ def _hand_masks(device): m = _HAND_MASK_CACHE.get(device) if m is not None: return m - mask_cont_threedofs = torch.cat( - [torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS] - ).to(device) - mask_cont_onedofs = torch.cat( - [torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS] - ).to(device) - mask_model_params_threedofs = torch.cat( - [torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS] - ).to(device) - mask_model_params_onedofs = torch.cat( - [torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS] - ).to(device) + mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device) + mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device) + mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device) + mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device) m = dict( mask_cont_threedofs=mask_cont_threedofs, mask_cont_onedofs=mask_cont_onedofs, @@ -182,7 +174,6 @@ def _hand_masks(device): def compact_cont_to_model_params_hand(hand_cont): # These are ordered by joint, not model params ^^ - assert hand_cont.shape[-1] == 54 m = _hand_masks(hand_cont.device) mask_cont_threedofs = m["mask_cont_threedofs"] mask_cont_onedofs = m["mask_cont_onedofs"] @@ -209,120 +200,6 @@ def compact_cont_to_model_params_hand(hand_cont): return hand_model_params -def compact_model_params_to_cont_hand(hand_model_params): - # These are ordered by joint, not model params ^^ - assert hand_model_params.shape[-1] == 27 - hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]) - assert sum(hand_dofs_in_order) == 27 - # Mask of 3DoFs into hand_cont - mask_cont_threedofs = torch.cat( - [torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order] - ) - # Mask of 1DoFs (including 2DoF) into hand_cont - mask_cont_onedofs = torch.cat( - [torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order] - ) - # Mask of 3DoFs into hand_model_params - mask_model_params_threedofs = torch.cat( - [torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order] - ) - # Mask of 1DoFs (including 2DoF) into hand_model_params - mask_model_params_onedofs = torch.cat( - [torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order] - ) - - # Convert eulers to hand_cont hand_cont - ## First for 3DoFs - hand_model_params_threedofs = hand_model_params[ - ..., mask_model_params_threedofs - ].unflatten(-1, (-1, 3)) - hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1) - ## Next for 1DoFs - hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs] - hand_cont_onedofs = torch.stack( - [hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1 - ).flatten(-2, -1) - - # Finally, assemble into a 27-dim vector, ordered by joint, then XYZ. - hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params) - hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs - hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs - - return hand_cont - - -def batch9Dfrom6D(poses): - # Args: poses: ... x 6, where "6" is the combined first and second columns - # First, get the rotaiton matrix - x_raw = poses[..., :3] - y_raw = poses[..., 3:] - - x = F.normalize(x_raw, dim=-1) - z = torch.cross(x, y_raw, dim=-1) - z = F.normalize(z, dim=-1) - y = torch.cross(z, x, dim=-1) - - matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9 - - return matrix - - -def batch4Dfrom2D(poses): - # Args: poses: ... x 2, where "2" is sincos - poses_norm = F.normalize(poses, dim=-1) - - poses_4d = torch.stack( - [ - poses_norm[..., 1], - poses_norm[..., 0], - -poses_norm[..., 0], - poses_norm[..., 1], - ], - dim=-1, - ) # Flattened SO2. - - return poses_4d # .... x 4 - - -def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False): - # fmt: off - all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]) - all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]) - all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]) - # fmt: on - num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 - num_1dof_angles = len(all_param_1dof_rot_idxs) - num_1dof_trans = len(all_param_1dof_trans_idxs) - assert body_pose_cont.shape[-1] == ( - 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans - ) - # Get subsets - body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles] - body_cont_1dofs = body_pose_cont[ - ..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles - ] - body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :] - # Convert conts to model params - ## First for 3dofs - body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6)) - body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1) - ## Next for 1dofs - body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos) - body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1) - if inflate_trans: - assert ( - False - ), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots" - else: - ## Nothing to do for trans - body_rotmat_trans = body_cont_trans - # Put them together - body_rotmat_params = torch.cat( - [body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1 - ) - return body_rotmat_params - - _BODY_IDX_CACHE: dict = {} @@ -349,8 +226,6 @@ def compact_cont_to_model_params_body(body_pose_cont): (all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device) num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 num_1dof_angles = len(all_param_1dof_rot_idxs) - num_1dof_trans = len(all_param_1dof_trans_idxs) - assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans # Get subsets body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles] body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles] @@ -372,42 +247,10 @@ def compact_cont_to_model_params_body(body_pose_cont): return body_pose_params -def compact_model_params_to_cont_body(body_pose_params): - # fmt: off - all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]) - all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]) - all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]) - # fmt: on - num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 - num_1dof_angles = len(all_param_1dof_rot_idxs) - num_1dof_trans = len(all_param_1dof_trans_idxs) - assert body_pose_params.shape[-1] == ( - num_3dof_angles + num_1dof_angles + num_1dof_trans - ) - # Take out params - body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()] - body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs] - body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs] - # params to cont - body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten( - -2, -1 - ) - body_cont_1dofs = torch.stack( - [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1 - ).flatten(-2, -1) - body_cont_trans = body_params_trans - # Put them together - body_pose_cont = torch.cat( - [body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1 - ) - return body_pose_cont - - -# fmt: off -mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115] -mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237] +# Hand indices into the 133-dim param and 260-dim cont body-pose vectors. +mhr_param_hand_idxs = list(range(62, 116)) +mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238)) mhr_param_hand_mask = torch.zeros(133).bool() mhr_param_hand_mask[mhr_param_hand_idxs] = True mhr_cont_hand_mask = torch.zeros(260).bool() mhr_cont_hand_mask[mhr_cont_hand_idxs] = True -# fmt: on diff --git a/comfy/ldm/sam3d_body/model/camera_modules.py b/comfy/ldm/sam3d_body/model/camera_modules.py index 0faff8534..c4a3d967e 100644 --- a/comfy/ldm/sam3d_body/model/camera_modules.py +++ b/comfy/ldm/sam3d_body/model/camera_modules.py @@ -43,15 +43,6 @@ class FourierPositionEncoding(nn.Module): self.num_bands = num_bands self.max_resolution = [max_resolution] * n - @property - def channels(self): - num_dims = len(self.max_resolution) - encoding_size = self.num_bands * num_dims - encoding_size *= 2 # sin-cos - encoding_size += num_dims # concat - - return encoding_size - def forward(self, pos: torch.Tensor): fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution) return fourier_pos_enc @@ -118,9 +109,7 @@ class PerspectiveHead(nn.Module): pred_cam: torch.Tensor, bbox_center: torch.Tensor, # [N, 2], in original image space (w, h) bbox_size: torch.Tensor, # [N,], in original image space - img_size: torch.Tensor, cam_int: torch.Tensor, # [B, 3, 3] - use_intrin_center: bool = False, ): batch_size = points_3d.shape[0] pred_cam = pred_cam.clone() @@ -133,12 +122,8 @@ class PerspectiveHead(nn.Module): focal_length = cam_int[:, 0, 0] tz = 2 * focal_length / bs - if not use_intrin_center: - cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs - cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs - else: - cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs - cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs + cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs + cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) diff --git a/comfy/ldm/sam3d_body/model/model.py b/comfy/ldm/sam3d_body/model/model.py index 9d1a53ec8..2063ce559 100644 --- a/comfy/ldm/sam3d_body/model/model.py +++ b/comfy/ldm/sam3d_body/model/model.py @@ -37,20 +37,15 @@ class SAM3DBody(nn.Module): def __init__(self, device=None, dtype=None, operations=None): super().__init__() - # `operations` falls back to torch.nn so the model is constructible - # without comfy.ops; matches the pattern in comfy/ldm/sam3/. - ops = operations if operations is not None else nn - # Per-batch state populated by `_initialize_batch`. self._max_num_person = None - self._person_valid = None self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False) self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False) self.image_size = IMAGE_SIZE - self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops) + self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations) embed_dims = self.backbone.embed_dims # MHR rig shared between body + hand pose heads via a non-registered @@ -72,7 +67,7 @@ class SAM3DBody(nn.Module): self.head_pose.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose.hand_pose_comps.data).float() ) - self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) + self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs) self.head_pose_hand.hand_pose_comps_ori = nn.Parameter( @@ -81,7 +76,7 @@ class SAM3DBody(nn.Module): self.head_pose_hand.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float() ) - self.init_pose_hand = ops.Embedding( + self.init_pose_hand = operations.Embedding( 1, self.head_pose_hand.npose, device=device, dtype=dtype ) @@ -93,25 +88,25 @@ class SAM3DBody(nn.Module): device=device, dtype=dtype, operations=operations, ) self.head_camera = PerspectiveHead(**camera_kwargs) - self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype) + self.init_camera = operations.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype) self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs) - self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype) + self.init_camera_hand = operations.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype) cond_dim = 3 init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim linear_kwargs = dict(device=device, dtype=dtype) - self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) - self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) - self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) - self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) + self.init_to_token_mhr = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs) + self.prev_to_token_mhr = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) + self.init_to_token_mhr_hand = operations.Linear(init_dim, DECODER_DIM, **linear_kwargs) + self.prev_to_token_mhr_hand = operations.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) self.prompt_encoder = PromptEncoder( embed_dim=embed_dims, # match backbone dims so PE adds directly num_body_joints=N_KEYPOINTS, device=device, dtype=dtype, operations=operations, ) - self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) + self.prompt_to_token = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs) decoder_kwargs = dict( dims=DECODER_DIM, @@ -141,11 +136,10 @@ class SAM3DBody(nn.Module): self.keypoint_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS)) - self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) - self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + self.keypoint_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + self.keypoint_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) - self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs) - self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs) + self.hand_box_embedding = operations.Embedding(2, DECODER_DIM, **linear_kwargs) self.bbox_embed = MLP( input_dim=DECODER_DIM, hidden_dim=DECODER_DIM, output_dim=4, num_layers=3, @@ -158,13 +152,13 @@ class SAM3DBody(nn.Module): ) self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs) self.keypoint_posemb_linear_hand = MLP(input_dim=2, **posemb_kwargs) - self.keypoint_feat_linear = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) - self.keypoint_feat_linear_hand = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) + self.keypoint_feat_linear = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs) + self.keypoint_feat_linear_hand = operations.Linear(embed_dims, DECODER_DIM, **linear_kwargs) self.keypoint3d_embedding_idxs = list(range(N_KEYPOINTS)) self.keypoint3d_embedding_idxs_hand = list(range(N_KEYPOINTS)) - self.keypoint3d_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) - self.keypoint3d_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + self.keypoint3d_embedding = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + self.keypoint3d_embedding_hand = operations.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) self.keypoint3d_posemb_linear = MLP(input_dim=3, **posemb_kwargs) self.keypoint3d_posemb_linear_hand = MLP(input_dim=3, **posemb_kwargs) @@ -183,11 +177,9 @@ class SAM3DBody(nn.Module): def _initialize_batch(self, batch: Dict) -> None: if batch["img"].dim() == 5: self._batch_size, self._max_num_person = batch["img"].shape[:2] - self._person_valid = self._flatten_person(batch["person_valid"]) > 0 else: self._batch_size = batch["img"].shape[0] self._max_num_person = 0 - self._person_valid = None def _flatten_person(self, x: torch.Tensor) -> torch.Tensor: assert self._max_num_person is not None, "No max_num_person initialized" @@ -258,11 +250,9 @@ class SAM3DBody(nn.Module): if is_multi_image: assert isinstance(img, list) n = len(img) - H_src, W_src = img[0].shape[:2] src_t = torch.stack(list(img), dim=0) else: n = int(left_xyxy.shape[0]) - H_src, W_src = img.shape[:2] src_t = img.unsqueeze(0).expand(n, -1, -1, -1) H_out, W_out = int(self.image_size[0]), int(self.image_size[1]) @@ -292,14 +282,12 @@ class SAM3DBody(nn.Module): zero_mask_score = torch.zeros((n,), dtype=torch.float32, device=device) person_valid = torch.ones((1, n), dtype=torch.float32, device=device) img_size = torch.tensor([W_out, H_out], dtype=torch.float32, device=device).expand(n, 2).contiguous() - ori_img_size = torch.tensor([W_src, H_src], dtype=torch.float32, device=device).expand(n, 2).contiguous() cam_int_dev = cam_int.to(device).to(dtype=torch.float32) def _build(centers_t, scales_t, mats_t, img_t, boxes_xyxy): return { "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) "img_size": img_size.unsqueeze(0), - "ori_img_size": ori_img_size.unsqueeze(0), "bbox_center": centers_t.to(device).unsqueeze(0), "bbox_scale": scales_t.to(device).unsqueeze(0), "bbox": torch.as_tensor(boxes_xyxy, dtype=torch.float32).to(device).unsqueeze(0), @@ -349,7 +337,6 @@ class SAM3DBody(nn.Module): self, branch: str, image_embeddings: torch.Tensor, - init_estimate: Optional[torch.Tensor] = None, keypoints: Optional[torch.Tensor] = None, prev_estimate: Optional[torch.Tensor] = None, condition_info: Optional[torch.Tensor] = None, @@ -359,7 +346,6 @@ class SAM3DBody(nn.Module): of the pipeline is shared. image_embeddings: (B, C, H, W) backbone features. - init_estimate: (B, 1, C) initial pose+cam estimate to refine. keypoints: (B, N, 3) prompts as (x, y in [0, 1], label). label: 0..K = joint, -1 = incorrect, -2 = invalid. prev_estimate: (B, 1, C) previous estimate for pose refinement. @@ -402,15 +388,11 @@ class SAM3DBody(nn.Module): # .to(image_embeddings) moves weights CPU→GPU under dynamic loading # (they stay on CPU until first use). - if init_estimate is None: - init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) - init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) - init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3) + init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) + init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) + init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3) - init_input = ( - torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1) - if condition_info is not None else init_estimate - ) + init_input = torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1) token_embeddings = init_to_token(init_input).view(batch_size, 1, -1) num_pose_token = token_embeddings.shape[1] # always 1 @@ -495,9 +477,8 @@ class SAM3DBody(nn.Module): def _get_mask_prompt(self, batch, image_embeddings): x_mask = self._flatten_person(batch["mask"]) - # batch tensors are fp32 from prepare_batch; mask_downscaling is in the - # Loader's dtype — cast once so the conv input matches. x_mask = x_mask.to(image_embeddings.dtype) + mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings( x_mask, image_embeddings.shape[0], image_embeddings.shape[2:] ) @@ -546,7 +527,6 @@ class SAM3DBody(nn.Module): # expand+contiguous for the vertices branch. bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx] bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0] - ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx] cam_int = self._flatten_person( batch["cam_int"] .unsqueeze(1) @@ -556,8 +536,7 @@ class SAM3DBody(nn.Module): def _project(points_3d): return head_camera.perspective_projection( - points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int, - use_intrin_center=True, + points_3d, pred_cam, bbox_center, bbox_scale, cam_int, ) cam_out = _project(pose_output["pred_keypoints_3d"]) @@ -632,7 +611,6 @@ class SAM3DBody(nn.Module): tokens_output, pose_output = self.forward_decoder( "body", image_embeddings[self.body_batch_idx], - init_estimate=None, keypoints=keypoints_prompt[self.body_batch_idx], prev_estimate=None, condition_info=condition_info[self.body_batch_idx], @@ -643,7 +621,6 @@ class SAM3DBody(nn.Module): tokens_output_hand, pose_output_hand = self.forward_decoder( "hand", image_embeddings[self.hand_batch_idx], - init_estimate=None, keypoints=keypoints_prompt[self.hand_batch_idx], prev_estimate=None, condition_info=condition_info[self.hand_batch_idx], @@ -661,10 +638,8 @@ class SAM3DBody(nn.Module): # match the head-MLP external contract (_get_hand_box would .float() anyway). if len(self.body_batch_idx): output["mhr"]["hand_box"] = self.bbox_embed(tokens_output).sigmoid().float() - output["mhr"]["hand_logits"] = self.hand_cls_embed(tokens_output).float() if len(self.hand_batch_idx): output["mhr_hand"]["hand_box"] = self.bbox_embed(tokens_output_hand).sigmoid() - output["mhr_hand"]["hand_logits"] = self.hand_cls_embed(tokens_output_hand) return output @@ -715,10 +690,10 @@ class SAM3DBody(nn.Module): # Concat lhand+rhand along dim 0 so backbone+decoder run once on # (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass. batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand) - saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid) + saved_batch_state = (self._batch_size, self._max_num_person) self._initialize_batch(batch_hands) hands_output = self.forward_step(batch_hands, decoder_type="hand") - self._batch_size, self._max_num_person, self._person_valid = saved_batch_state + self._batch_size, self._max_num_person = saved_batch_state n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1] lhand_output, rhand_output = self._split_hand_output(hands_output, n_left) # Free the batched image_embeddings/condition_info (unused downstream); @@ -808,9 +783,7 @@ class SAM3DBody(nn.Module): # to get an updated body pose estimation. self._set_active_branch("body") - right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() - left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() - left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 + # right_kps_full / left_kps_full already computed above (unchanged since). right_kps_crop = self._full_to_crop(batch, right_kps_full) left_kps_crop = self._full_to_crop(batch, left_kps_full) @@ -1030,7 +1003,6 @@ class SAM3DBody(nn.Module): _, pose_output = self.forward_decoder( "body", image_embeddings, - init_estimate=None, # use the default init, not the prev estimate keypoints=keypoint_prompt, prev_estimate=prev_estimate, condition_info=condition_info, diff --git a/comfy/ldm/sam3d_body/model/prompt.py b/comfy/ldm/sam3d_body/model/prompt.py index 2a5466276..5b2ec232d 100644 --- a/comfy/ldm/sam3d_body/model/prompt.py +++ b/comfy/ldm/sam3d_body/model/prompt.py @@ -29,38 +29,37 @@ class PromptEncoder(nn.Module): Encodes prompts for input to SAM's mask decoder. """ super().__init__() - ops = operations if operations is not None else nn self.embed_dim = embed_dim self.num_body_joints = num_body_joints # Keypoint prompts self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.point_embeddings = nn.ModuleList( - [ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)] + [operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)] ) - self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) - self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) + self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype) + self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype) # Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim. - LN2d = LayerNorm2d_op(ops) + LN2d = LayerNorm2d_op(operations) mask_in_chans = 256 self.mask_downscaling = nn.Sequential( - ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype), + operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(mask_in_chans // 64, device=device, dtype=dtype), nn.GELU(), - ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype), + operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(mask_in_chans // 16, device=device, dtype=dtype), nn.GELU(), - ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(mask_in_chans // 4, device=device, dtype=dtype), nn.GELU(), - ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype), + operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(mask_in_chans, device=device, dtype=dtype), nn.GELU(), - ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype), + operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype), ) # Trained values for the gating conv and no_mask_embed are loaded from the state dict - self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) + self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype) def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor: """Positional encoding over the image-embedding grid; (1, C, H, W).""" @@ -120,8 +119,7 @@ class PromptEncoder(nn.Module): Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(keypoints, boxes, masks) - # Anchor device on the input prompts so we don't pull the offloaded - # CPU embedding device under dynamic loading. + ref = keypoints if keypoints is not None else boxes if boxes is not None else masks device = ref.device if ref is not None else self.point_embeddings[0].weight.device weight_dtype = self.invalid_point_embed.weight.dtype @@ -136,23 +134,10 @@ class PromptEncoder(nn.Module): return sparse_embeddings, sparse_masks - def get_mask_embeddings( - self, - masks: Optional[torch.Tensor] = None, - bs: int = 1, - size: Tuple[int, int] = (16, 16), # [H, W] - ) -> torch.Tensor: - """Embeds mask inputs.""" - # masks is always on the active device when present; fall back to the - # downscaling Conv's weight device when it isn't (rare callers). - ref = masks if masks is not None else next(self.mask_downscaling.parameters()) - no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand( - bs, -1, size[0], size[1] - ) - if masks is not None: - mask_embeddings = self.mask_downscaling(masks) - else: - mask_embeddings = no_mask_embeddings + def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> torch.Tensor: + """Embeds mask inputs. Caller casts both outputs to its working dtype.""" + no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1]) + mask_embeddings = self.mask_downscaling(masks) return mask_embeddings, no_mask_embeddings @@ -170,12 +155,9 @@ class PromptableDecoder(nn.Module): repeat_pe: bool = False, do_interm_preds: bool = False, keypoint_token_update: bool = False, - device=None, - dtype=None, - operations=None, + device=None, dtype=None, operations=None, ): super().__init__() - ops = operations if operations is not None else nn self.layers = nn.ModuleList( TransformerDecoderLayer( @@ -193,7 +175,7 @@ class PromptableDecoder(nn.Module): for i in range(depth) ) - self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype) + self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype) self.do_interm_preds = do_interm_preds self.keypoint_token_update = keypoint_token_update diff --git a/comfy/ldm/sam3d_body/utils.py b/comfy/ldm/sam3d_body/utils.py index 45b7cb014..d3e5d36c8 100644 --- a/comfy/ldm/sam3d_body/utils.py +++ b/comfy/ldm/sam3d_body/utils.py @@ -166,12 +166,10 @@ def prepare_batch( mask_score_t = torch.ones((n,), dtype=torch.float32) img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous() - ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous() batch = { "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) "img_size": img_size_t.unsqueeze(0), # (1, N, 2) - "ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2) "bbox_center": centers.unsqueeze(0), # (1, N, 2) "bbox_scale": scales.unsqueeze(0), # (1, N, 2) "bbox": boxes_t.unsqueeze(0), # (1, N, 4) diff --git a/comfy_extras/sam3d_body/export/bvh.py b/comfy_extras/sam3d_body/export/bvh.py index e690a7b86..7e4a9c1a5 100644 --- a/comfy_extras/sam3d_body/export/bvh.py +++ b/comfy_extras/sam3d_body/export/bvh.py @@ -1,11 +1,9 @@ """BVH export for SAM 3D Body pose_data. -BVH stores explicit bone OFFSETs per joint, so any standard importer -(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations -directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0 -(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos + -ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are -intrinsic Z-X-Y Euler degrees. +BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct +anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0 +(static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY +rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees. """ from __future__ import annotations @@ -49,13 +47,10 @@ def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray: def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int: - """First child of the rig's world anchor so the static origin→body stick - bone gets left out. Falls back to the first root joint. - - MHR's joint 0 is a static world anchor whose single child is the pelvis, so - skipping it is correct. External rigs (e.g. SOMA-77) whose root is already - the articulated body root with multiple child chains must keep the root — - descending into one child would drop the sibling limbs from the BVH.""" + """First child of the rig's world anchor, dropping the origin→body stick. + Falls back to the first root joint. External rigs whose root is already the + articulated body root with multiple child chains keep the root — descending + into one child would drop the sibling limbs.""" NJ = parents.shape[0] world_anchors = [j for j in range(NJ) if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)] @@ -93,14 +88,11 @@ def build_bvh( track_index: int = -1, units: str = "cm", ) -> bytes: - """Build a BVH file from pose_data. Returns UTF-8 encoded text bytes. + """Build a BVH file from pose_data. Returns UTF-8 text bytes. `model` may be None when pose_data carries a `_skeleton_override` (external - rigs, e.g. Kimodo); the rig hierarchy/offsets/bind are read from the - override instead of the MHR model. - - `units` is "cm" (default, standard mocap convention) or "m". Affects the - OFFSET and root-position values; rotations are independent of units. + rigs); the rig hierarchy/offsets/bind come from the override. `units` is + "cm" (default) or "m" — affects OFFSET/root-position, not rotations. """ if units not in ("cm", "m"): raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}") @@ -123,10 +115,8 @@ def build_bvh( body_root = _find_bvh_root(parents, is_external) children_map = _build_children_map(parents) - # Bone OFFSETs come from MHR's translation_offsets (joint position - # relative to parent in parent's local-bind frame). For the BVH root, - # we use its bind world position so the skeleton sits at the right - # spot when imported. + # Bone OFFSETs = translation_offsets (joint position relative to parent). + # The BVH root uses its bind world position so the skeleton imports in place. bind_global = rig.bind_global_cm # (NJ, 8) cm bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01 @@ -139,9 +129,8 @@ def build_bvh( _visit(c) _visit(body_root) - # Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative) - # rather than re-running rig.forward, then derive locals with body_root - # treated as the hierarchy root in BVH-space. + # Stored pred_global_rots/pred_joint_coords (authoritative); derive locals + # with body_root as the BVH-space hierarchy root. rig_global_m = global_skel_state_from_pose_data( pose_data, frame_indices, person_k, NJ, joint_coords_y_down=rig.per_frame_y_down, @@ -203,9 +192,8 @@ def build_bvh( lines.append(f"Frames: {n_frames}") lines.append(f"Frame Time: {1.0 / float(fps):.6f}") - # Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per - # frame, columns in `bvh_order` order. Vectorized — savetxt's C-side - # formatting beats Python f-strings by ~10× on long clips. + # Channel matrix per frame: root pos (3) + root rot (3) + non-root rots + # (3 each), columns in `bvh_order`. savetxt is far faster than f-strings. non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64) motion = np.concatenate([ root_pos_m * unit_scale, # (N, 3) diff --git a/comfy_extras/sam3d_body/export/capsules.py b/comfy_extras/sam3d_body/export/capsules.py index 50fe33a5f..ece5ecc5b 100644 --- a/comfy_extras/sam3d_body/export/capsules.py +++ b/comfy_extras/sam3d_body/export/capsules.py @@ -1,12 +1,9 @@ -"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent -torch ray-marching SDF renderer adapted to SAM3DBody pose_data. +"""3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look). -Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps), -projected through the per-person camera (`pred_cam_t` + `focal_length` + -image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose -visual style. Self-contained: no dependency on the SCAIL-Pose package. - -Output: (H, W, 3) fp32 torch.Tensor in [0, 1]. +Each limb is a true 3D capsule (cylinder + hemispherical caps), projected +through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so +closer limbs appear thicker/brighter. Self-contained analytic ray-capsule +renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1]. """ from typing import Any, Dict, List, Optional, Tuple @@ -41,14 +38,12 @@ def _build_specs_from_pose( palette: str, person_brightness_falloff: float = 0.0, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Flatten body + optional hand limbs for one frame into - (starts, ends, colors_rgba, is_hand) in camera coords (Y-down, +Z forward). - Drops endpoints that are non-finite or behind the camera. `is_hand` flags - the hand limbs so the renderer can draw them thinner. + """Flatten body + optional hand limbs for one frame into (starts, ends, + colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite + or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner. - `person_brightness_falloff` mixes each per-person limb color toward white - by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the - mesh rasterizer and GLB exporters.""" + `person_brightness_falloff` mixes each per-person color toward white by + `1 - falloff^k` for track k (track 0 stays vivid).""" starts: List[np.ndarray] = [] ends: List[np.ndarray] = [] colors: List[np.ndarray] = [] @@ -65,8 +60,7 @@ def _build_specs_from_pose( if body_op is None or cam_t is None: continue cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3) - # op-keypoints are camera frame (Y-down); add cam_t to place the - # subject in front of the camera. + # op-keypoints are camera frame; add cam_t to place the subject in front. body_kp = body_op + cam_t_np[None, :] pastel = 0.0 if k == 0 else (1.0 - falloff ** k) @@ -148,10 +142,9 @@ def _ray_capsule_t( ba_len: torch.Tensor, # (M,) segment length radius: torch.Tensor, # (M,) per-capsule radius ) -> torch.Tensor: - """Closed-form ray-capsule intersection. Returns (K, M) tensor of ray - parameters t to the nearest valid hit per capsule, +inf where the ray - misses. A capsule is the union of (cylinder body, hemisphere at A, - hemisphere at B); each component is a quadratic root-find.""" + """Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest + valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere + at A, hemisphere at B), each a quadratic root-find.""" INF = float("inf") r_sq = radius * radius # (M,) @@ -238,9 +231,8 @@ def _render_capsules_torch( z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item())) z_near = max(0.05, z_min - float(radius.max().item())) - # Union of per-capsule screen-space bboxes. Pixels outside this mask - # provably can't hit any capsule, so the analytic intersection only runs - # on the relevant subset of the canvas (~5-15% at 1080p for typical poses). + # Union of per-capsule screen-space bboxes — pixels outside can't hit any + # capsule, so intersection only runs on the relevant subset of the canvas. sz = starts[:, 2].clamp(min=z_near) ez = ends[:, 2].clamp(min=z_near) sx_p = starts[:, 0] * fx / sz + cx @@ -261,16 +253,13 @@ def _render_capsules_torch( if xmax_i > xmin_i and ymax_i > ymin_i: coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True - # Analytic ray-capsule intersection. One pass over the masked pixels — - # the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel - # plus 6 SDF evaluations per hit pixel for finite-difference normals. + # Analytic ray-capsule intersection, one pass over the masked pixels. INF = float("inf") flat_t = torch.full((N,), INF, device=device, dtype=torch.float32) flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long) active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1) if active_idx.numel() > 0: - # Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory - # manageable when both K (image pixels) and M (capsules) are large. + # Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory. chunk_max = max(1, int(4_000_000 / max(M, 1))) for i0 in range(0, active_idx.numel(), chunk_max): sub = active_idx[i0 : i0 + chunk_max] @@ -284,7 +273,7 @@ def _render_capsules_torch( flat_t[winners] = t_min[hit] flat_m_idx[winners] = m_idx[hit] - # Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade. + # Shade via analytic normal (P - closest point on segment). out = torch.zeros((N, 3), dtype=torch.float32, device=device) if background_rgb is not None: out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone() @@ -306,10 +295,10 @@ def _render_capsules_torch( col = colors[m_h, :3] if flat_shade: - # Solid per-limb color (OpenPose look) — no lighting/depth modulation. + # Solid per-limb color (OpenPose look) — no lighting/depth. out[hit_idx] = col return out.view(H, W, 3).clamp(0.0, 1.0) - # SCAIL Blinn-Phong (render_torch.py:290-331). Headlight: light = +Z. + # SCAIL Blinn-Phong, headlight along +Z. diff = torch.clamp(-(normals[:, 2]), min=0.0) diffuse = 0.45 + 0.55 * diff @@ -319,7 +308,7 @@ def _render_capsules_torch( half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8) spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32) - # Mild depth fade matches SCAIL's mm-scale ramp in our meter units. + # Mild depth fade. z_vals = p_hit[:, 2] z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item()) if z_hi - z_lo > 1e-6: @@ -351,21 +340,18 @@ def render_pose_data_capsules( hand_radius_scale: float = 0.4, device: Optional[torch.device] = None, ) -> torch.Tensor: - """Render a frame's pose_data as 3D capsules projected through the per- - person camera. Returns (H, W, 3) fp32 in [0, 1]. + """Render a frame's pose_data as 3D capsules through the per-person camera. + Returns (H, W, 3) fp32 in [0, 1]. - `composite='over'` paints over `background` (black if None); - `composite='mesh_only'` always uses a black canvas. - - `radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`). - Hand limbs use `radius_m * hand_radius_scale` (their bones are far shorter - than body limbs). Camera fx/fy come from each person's `focal_length`. + `composite='over'` paints over `background` (black if None); 'mesh_only' + uses a black canvas. `radius_m` is in meters; hand limbs use + `radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`. """ persons = pose_data["frames"][frame_idx] if device is None: device = comfy.model_management.get_torch_device() - # SAM3DBody shares one camera across the clip — pick from the first valid person. + # SAM3DBody shares one camera across the clip — use the first valid person. fx = fy = float(min(H, W)) for person in persons: f = person.get("focal_length") diff --git a/comfy_extras/sam3d_body/export/glb_openpose.py b/comfy_extras/sam3d_body/export/glb_openpose.py index 666aab69e..38a651399 100644 --- a/comfy_extras/sam3d_body/export/glb_openpose.py +++ b/comfy_extras/sam3d_body/export/glb_openpose.py @@ -1,16 +1,10 @@ """GLB export — OpenPose 18-keypoint visualization mode. -Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d` -(the model's regressed surface keypoints). Each track becomes an armature -with a sibling joint per keypoint; sphere markers + stick/capsule limbs are -skinned to those joints. - -Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and -face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs) -extend the same armature. - -OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are -imported below — they're also used by the 2D and 3D renderers in this package. +Sourced from pose_data's `pred_keypoints_3d`, independent of the MHR rig. Each +track becomes an armature with a joint per keypoint; sphere markers and limbs +are skinned to those joints. Optional hands (`pred_keypoints_3d` 21..62) and +face landmarks (`pred_vertices` at fixed vertex IDs) extend the same armature. +Shared tables/palettes/mappings live in `glb_shared.py`. """ from __future__ import annotations @@ -55,9 +49,8 @@ def _finalize_skinned_mesh( joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray, smooth_shade: bool, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Apply smooth or flat shading to an indexed sphere/stick group mesh and - pack per-vertex colors. Smooth keeps the indexed mesh + per-vertex colors; - flat duplicates verts per face and gathers face-corner colors.""" + """Shade a skinned group mesh and pack per-vertex colors. Smooth keeps the + indexed mesh; flat duplicates verts per face and gathers face-corner colors.""" if smooth_shade: v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights) return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32) @@ -73,10 +66,8 @@ def _finalize_skinned_mesh( def _pair_colors_from_kp( pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1, ) -> np.ndarray: - """Per-limb color = endpoint-vertex color from `kp_colors`. Default - `endpoint=1` picks the second (distal) vertex of each pair, which is - the OpenPose-canonical per-finger gradient when fingers go base→tip - (wrist=0 → thumb1=1 → thumb2=2 …).""" + """Per-limb color from `kp_colors`. `endpoint=1` (default) picks the distal + vertex of each pair — the OpenPose per-finger gradient for base→tip fingers.""" n = len(pairs) out = np.zeros((n, 3), dtype=np.float32) for i, (a, b) in enumerate(pairs): @@ -88,19 +79,13 @@ def _openpose_bind_at_rig_rest( pose_data: Dict[str, Any], *, include_hands: bool, face_vert_ids: Optional[np.ndarray], ) -> Optional[np.ndarray]: - """OpenPose keypoint positions at the rig's REST pose (T-pose at authoring - origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest - TRS) and `rest_verts_m` (mesh rest verts for face landmarks). + """OpenPose keypoint positions at the rig's REST pose, from the override's + `bind_global_m` (joint rest TRS) and `rest_verts_m` (face landmarks). - Used as the static-bind for openpose-mode geometry so the GLB's static - POSITION attribute sits at rig origin — matching skeletal mode's bind and - producing the same 'snap from rest to scene-frame-0' transition at the - start of playback. Without this, the static geometry is at scene-frame-0 - (kp_seq[0]) and viewers that auto-fit on static POSITION will center on - the scene location, hiding the per-frame motion. - - Returns None when the override is missing or doesn't carry all the needed - mappings — caller falls back to per-frame extraction (kp_seq[0]).""" + Used as the static-bind so the GLB's static POSITION sits at rig origin, + matching skeletal mode and producing the same rest→scene-frame-0 transition. + Returns None when the override lacks the needed mappings — caller then falls + back to per-frame extraction (kp_seq[0]).""" override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None if override is None or "bind_global_m" not in override: return None @@ -141,19 +126,12 @@ def _openpose_bind_at_rig_rest( def _extract_openpose_keypoints( pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, ) -> np.ndarray: - """(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres. + """(N, 18, 3) OpenPose keypoints in rig-native Y-up metres. - Two sources, in priority order: - - 1. **External-skeleton path** — when pose_data has `_skeleton_override` - with `openpose18_joint_indices` ((18, 2) int32, see - `_resolve_openpose_keypoints_from_joints`), synthesize from each - person's `pred_joint_coords` directly. The override frame is already - rig-native Y-up, so no axis flip. - 2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the - first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18. - Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we - un-flip y/z to match rig-native Y-up. + External-skeleton path: when the override carries `openpose18_joint_indices` + ((18, 2) int32), synthesize from each person's `pred_joint_coords` (already + Y-up, no flip). MHR70 path (default): re-index `pred_keypoints_3d` to COCO-18 + and un-flip y/z (stored y-down by sam3d_body). """ frames = pose_data["frames"] N = len(frame_indices) @@ -195,10 +173,8 @@ def _extract_openpose_keypoints( for t_idx, t in enumerate(frame_indices): person = frames[t][person_k] if "pred_keypoints_3d" not in person: - # Diagnose the source: external-skeleton producers ship - # `_skeleton_override` instead of MHR70 keypoints. If the - # producer didn't populate `openpose18_joint_indices` either, - # we can't synthesize the 18-keypoint set. + # External-skeleton producer without `openpose18_joint_indices`: + # can't synthesize the 18-keypoint set. if override is not None: raise ValueError( "build_glb_openpose: this pose_data carries " @@ -229,15 +205,11 @@ def _extract_openpose_keypoints( def _extract_openpose_hand_keypoints( pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, ) -> np.ndarray: - """(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native - Y-up frame. + """(N, 42, 3) right+left OpenPose hand keypoints (21+21) in rig-native Y-up. - External-skeleton path: requires `openpose_hand21_r_joint_indices` AND - `openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override. - Resolved against per-frame `pred_joint_coords` like the body path. - - MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21 - (wrist + 5 fingers, thumb→pinky, base→tip).""" + External-skeleton path: needs `openpose_hand21_{r,l}_joint_indices` ((21, 2) + int32) in the override, resolved against `pred_joint_coords`. MHR70 path: + re-orders `pred_keypoints_3d` 21..62 to OpenPose-21 (wrist + 5 fingers).""" frames = pose_data["frames"] N = len(frame_indices) out = np.zeros((N, 42, 3), dtype=np.float32) @@ -307,10 +279,8 @@ def _extract_face_landmarks_from_verts( pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, vert_ids: np.ndarray, ) -> np.ndarray: - """(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices` - at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up. - Each landmark inherits per-frame shape/expr/pose deformation for free - since `pred_vertices` already has it baked in.""" + """(N, K_face, 3) face landmarks sampled from `pred_vertices` at the given + vertex IDs, unflipped to Y-up. Per-frame deformation is already baked in.""" frames = pose_data["frames"] N = len(frame_indices) K = int(vert_ids.shape[0]) @@ -335,18 +305,11 @@ def _build_openpose_spheres( smooth_shade: bool = False, joint_indices: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's - joint, vertex-colored from kp_colors. `base_joint_idx` is added to the - emitted JOINTS_0 indices so callers can place this group at any offset - in the shared skin (body=0, right hand=18, etc.). `joint_indices` (when - given) overrides that with explicit per-sphere joint indices, so callers - can skip keypoints (e.g. SCAIL head dots). - - `smooth_shade=True` keeps the indexed mesh and writes per-vertex - normals via face-normal averaging — round shading on the spheres. - `smooth_shade=False` (default) flat-shades by duplicating verts per - face, matching the existing OpenPose-mode look. Returns - (verts, normals, faces, joints4, weights4, vert_colors).""" + """UV sphere per keypoint, rigidly skinned to that keypoint's joint and + vertex-colored from kp_colors. `base_joint_idx` offsets the emitted JOINTS_0 + indices (body=0, right hand=18, …); `joint_indices`, if given, sets explicit + per-sphere indices so callers can skip keypoints (e.g. SCAIL head dots). + Returns (verts, normals, faces, joints4, weights4, vert_colors).""" sv, sf = uv_sphere_unit() K = bind_kp_m.shape[0] Nv = sv.shape[0] @@ -376,43 +339,23 @@ def _capsule_mesh_local( end_width_frac: float = 0.3, shape: str = "ellipsoid", ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Build a per-limb mesh in limb-local frame along +Y from y=0 (head - pole) to y=L (tail pole). + """Per-limb mesh in limb-local frame along +Y from y=0 (head) to y=L (tail). - `shape` selects the silhouette: - - 'ellipsoid' (default): tips are small hemispheres of radius - `W * end_width_frac`; body has ellipsoidal radius profile - sin(π*u) from w_end at the junctions to W at the middle. Gives - a fat-middle / narrow-end stretched-ellipse look. - - 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant - radius W with no hemisphere caps. Pair with sphere joint markers - of the same radius so the spheres seamlessly cap the open - cylinder ends (the cylinder cross-section ring at the endpoint - lies exactly on the sphere surface). Drawing hemisphere caps - inside the joint sphere creates a visible bump where the cap - pokes out unevenly when sphere radius ≠ cap radius — open - cylinders avoid that. + `shape`: + - 'ellipsoid' (default): hemisphere tips of radius `W * end_width_frac`, + ellipsoidal sin(π·u) body profile (fat middle, narrow ends). + - 'capsule': SCAIL "rig" limb — an OPEN cylinder of constant radius W, + no caps. Pair with same-radius sphere markers so they cap the ends + seamlessly (caps would bump out when sphere radius ≠ cap radius). - Per-limb mesh is required because the cap height (w_end) depends on - the limb width — a single canonical mesh can't produce true - hemispheres for arbitrary L:W ratios in ellipsoid mode. + A per-limb mesh is needed because cap height depends on width — one + canonical mesh can't give true hemispheres for arbitrary L:W in ellipsoid. - Returns: - verts: (Nv, 3) float32 — limb-local positions in meters. - faces: (Nf, 3) uint32 — triangle indices. - weights: (Nv, 2) float32 — (head, tail) skinning weights, linearly - interpolated by axial position (sums to 1). + Returns (verts (Nv,3), faces (Nf,3), weights (Nv,2) head/tail, sums to 1). """ W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6)) if str(shape) == "capsule": - # SCAIL-style "rig" limb: an OPEN cylinder of constant radius W, - # no hemisphere caps. The sphere joint markers at each endpoint - # provide the rounded ends of the bone — when sphere_radius == - # cylinder_radius, the cylinder cross-section ring at the bone - # endpoint lies exactly on the sphere surface, so silhouette is - # seamless. Hemisphere caps would create a visible bump where - # the cap pokes out of the sphere if cap_r ≠ marker_r, so we - # omit them entirely. + # Open cylinder, no caps — sphere markers cap the ends (see docstring). cap_r = 0.0 body_r = W if n_cap_lat is None: @@ -425,7 +368,7 @@ def _capsule_mesh_local( end_frac = float(min(0.95, max(0.05, end_width_frac))) cap_r = max(1e-7, W * end_frac) body_r = W - # Ellipsoid defaults: more body rings to sample the sin(π·u) curve. + # More body rings to sample the sin(π·u) curve. if n_cap_lat is None: n_cap_lat = 3 if n_body is None: @@ -473,10 +416,7 @@ def _capsule_mesh_local( phi = 2.0 * np.pi * k / n_lon verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))]) - # Body intermediate rings (between the cap junctions for capped meshes, - # between the two end rings for open cylinders). For 'capsule' mode - # n_body=0 by default — no intermediate rings needed for a constant- - # radius cylinder. + # Body intermediate rings (none for 'capsule', n_body=0 by default). body_rings: List[int] = [] is_ellipsoid = str(shape) == "ellipsoid" for j in range(1, n_body + 1): @@ -572,11 +512,8 @@ def _scail_redirect_neck_stub(body_kp: np.ndarray) -> np.ndarray: def _openpose_limb_rest_trs( bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...], ) -> Tuple[np.ndarray, np.ndarray]: - """Per-limb rest TRS: - midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b]. - rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate). - Caller uses `midpoints` as each limb joint's rest translation (rotation = - identity), and `rest_axes` to compute per-frame alignment rotations.""" + """Per-limb rest TRS: midpoints (K_pairs, 3) and unit a→b axes (or +Y if + degenerate). Caller uses midpoints as rest translation, axes for alignment.""" K_pairs = len(pairs) mid = np.zeros((K_pairs, 3), dtype=np.float32) axis = np.zeros((K_pairs, 3), dtype=np.float32) @@ -595,13 +532,10 @@ def _openpose_limb_rest_trs( def _openpose_limb_anim_trs( kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: - """Per-frame limb TRS: - anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]). - anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest - axis to its frame-t axis. - Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t * - T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to - track the limb's current direction — no LBS cross-section thinning.""" + """Per-frame limb TRS: anim_mid (N, K_pairs, 3) midpoints and anim_quat + (N, K_pairs, 4 xyzw) aligning each limb's rest axis to its frame-t axis. + Drives skin_matrix(t) = T(mid_t)·R_t·T(-mid_rest) — rigid rotation about + the rest midpoint, no LBS cross-section thinning.""" N = kp_seq.shape[0] K_pairs = len(pairs) anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32) @@ -616,7 +550,7 @@ def _openpose_limb_anim_trs( n = float(np.linalg.norm(d)) if n > 1e-9: R[t, k] = rotation_align(ax_rest, d / n) - quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw + quat = rotmat_to_quat_np(R).astype(np.float32) return anim_mid, quat @@ -628,20 +562,14 @@ def _build_openpose_sticks( smooth_shade: bool = False, end_width_frac: float = 0.3, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Capsule (cylinder + hemispherical caps) per limb pair (a, b). + """Capsule per limb pair (a, b), each sized to its own length/width so caps + are true hemispheres regardless of L:W. Ellipsoid mode auto-clamps width to + `length * 0.1` so short limbs don't look chunky. - Each limb gets its own mesh sized to that limb's length and width so - the caps are TRUE hemispheres of radius `half_width_eff` — the limb - silhouette is rounded-rectangle-like, regardless of L:W ratio. Width - auto-clamped to `length * 0.1` so short limbs (face/ear) don't look - chunky next to long ones. - - Skinning: rigid (weight=1) binding to a per-limb joint at - `limb_joint_base_idx + limb_idx` — the caller animates that joint with - midpoint translation + rest-to-current rotation so each capsule rotates - rigidly with its limb (avoids translation-only LBS cross-section - thinning). Returns flat-shaded (verts, normals, faces, joints4, - weights4, vert_colors).""" + Rigid (weight=1) binding to a per-limb joint at `limb_joint_base_idx + + limb_idx`, which the caller animates with midpoint translation + rotation + (avoids LBS thinning). Returns (verts, normals, faces, joints4, weights4, + vert_colors).""" canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) out_v_chunks: List[np.ndarray] = [] @@ -663,13 +591,10 @@ def _build_openpose_sticks( unit_dir = direction / length R = rotation_align(canonical, unit_dir) if is_capsule: - # SCAIL-style uniform radius — every bone gets the same width. - # `_capsule_mesh_local` clamps internally to L/2-eps so very - # short bones don't go degenerate. + # Uniform radius — every bone the same width (clamped internally). half_width_eff = max(MIN_WIDTH, half_width_m) else: - # Ellipsoid mode: original auto-thinning so short face/ear - # limbs don't look chunky next to long body limbs. + # Auto-thin so short face/ear limbs aren't chunky next to body limbs. half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m)) v_local, f_local, _weights_unused = _capsule_mesh_local( @@ -678,10 +603,8 @@ def _build_openpose_sticks( v_world = v_local @ R.T + head Nv = v_local.shape[0] - # Rigid binding to the per-limb joint. The 2-bone (head, tail) weights - # from `_capsule_mesh_local` are discarded — they're translation-only - # under glTF LBS and don't rotate the cross-section, causing visible - # thinning when the limb axis changes between rest and animated pose. + # Rigid binding to the per-limb joint; the 2-bone weights are discarded + # (translation-only under LBS, would thin the cross-section). j_arr = np.zeros((Nv, 4), dtype=np.uint16) j_arr[:, 0] = limb_idx + limb_joint_base_idx w_arr = np.zeros((Nv, 4), dtype=np.float32) @@ -730,40 +653,24 @@ def build_glb_openpose( stick_end_width_frac: float = 0.6, bone_smooth_window: int = 0, ) -> bytes: - """Build a GLB containing an OpenPose-style 3D skeleton — sphere markers - per keypoint plus rainbow-colored sticks between standard limb pairs. - Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig - forward needed). Optional hand keypoints (also from `pred_keypoints_3d`) - and face landmarks (sampled from `pred_vertices` at fixed head-mesh - vertex IDs) extend the same per-track armature. + """Build a GLB of an OpenPose-style 3D skeleton — sphere markers per keypoint + plus colored sticks between limb pairs, one armature per track. Body from + `pred_keypoints_3d`; optional hands (same source) and face landmarks + (`pred_vertices`) extend each armature. Args: - include_hands: append the standard 21+21 OpenPose hand keypoints to - each track's armature (right hand at MHR70 indices 21..41, - left at 42..62). - hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 × - `marker_radius_m` (hand keypoints are anatomically smaller than - body joints; matches DWPose's smaller hand dots). - hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 × - `stick_radius_m`. - hand_color_style: 'dwpose' (default) = solid-blue hand dots, - rainbow per-finger sticks (controlnet_aux/dwpose convention); - 'openpose' = rainbow per-finger dots AND sticks (matches - poseParameters.cpp::HAND_COLORS_RENDER). - face_style: 'disabled' (default) | 'full' | 'eyes_mouth' — face - landmarks sampled from `pred_vertices` at vertex IDs picked from - `pose_data["canonical_colors"]["positions"]`. 'full' = all ~30 - contour points; 'eyes_mouth' = the eyes + outer-lip subset. - face_marker_radius_m: per-face landmark sphere radius. 0 = auto = - 0.3 × `marker_radius_m` — face landmarks are densely packed - around the eyes/mouth/jaw and need to be much smaller than - body keypoints to keep the layout legible. Face landmarks are - rendered as standalone dots (no contour lines), matching - DWPose's face_pose draw style. - palette: body color scheme. 'openpose' = standard rainbow gradient - per keypoint (canonical OpenPose convention); 'scail' = - SCAIL-Pose style — warm hues right side, cool hues left side, - grey neck-to-nose centerline, distinct per-limb colors. + include_hands: append the 21+21 OpenPose hand keypoints per track. + hand_marker_radius_m: hand sphere radius. 0 = auto = 0.4 × marker_radius_m. + hand_stick_radius_m: hand limb half-width. 0 = auto = 0.5 × stick_radius_m. + hand_color_style: 'dwpose' (default) = solid-blue dots + rainbow sticks; + 'openpose' = rainbow dots AND sticks. + face_style: 'disabled' (default) | 'full' (~30 contour pts) | 'eyes_mouth' + (eyes + outer-lip subset); sampled at vertex IDs from + `canonical_colors["positions"]`. + face_marker_radius_m: face landmark sphere radius. 0 = auto = 0.3 × + marker_radius_m. Rendered as dots only, no contour lines. + palette: 'openpose' = rainbow gradient per keypoint; 'scail' = warm right + / cool left, grey centerline, distinct per-limb colors. """ is_scail = str(palette) == "scail" # SCAIL drops the face bones (13..16) and eye/ear spheres; keeps nose (idx 0, @@ -771,13 +678,11 @@ def build_glb_openpose( body_pairs = OPENPOSE_18_PAIRS[:13] if is_scail else OPENPOSE_18_PAIRS body_sphere_kp = (np.arange(14, dtype=np.int64) if is_scail else np.arange(18, dtype=np.int64)) - if str(palette) == "scail": + if is_scail: body_sphere_colors = SCAIL_KEYPOINT_COLORS_18 body_stick_colors = SCAIL_LIMB_COLORS_17 elif str(palette) == "openpose": - # Existing OpenPose behavior: same rainbow array used for both - # spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of - # the 18-element rainbow — yields a legible per-limb gradient). + # Same rainbow array drives both spheres and sticks. body_sphere_colors = OPENPOSE_RAINBOW_18 body_stick_colors = OPENPOSE_RAINBOW_18 else: @@ -892,13 +797,9 @@ def build_glb_openpose( if bone_smooth_window and bone_smooth_window > 1: kp_seq = gaussian_smooth_positions(kp_seq, int(bone_smooth_window)) - # Static-bind = rig's REST pose when available (override path); else - # fall back to frame 0 of the motion. The rest-pose bind makes the - # GLB's static POSITION attribute sit at rig origin, so viewers - # auto-fit/center on rig origin and the animation visibly snaps from - # rest to scene-frame-0 — matching skeletal mode's behavior. Without - # this, openpose's static geometry is at scene-frame-0 and viewers - # mis-center on the scene location, masking the motion entirely. + # Static-bind = rig REST pose when available, else frame 0. The rest + # bind keeps static POSITION at rig origin so viewers auto-center there + # and the motion is visible (see _openpose_bind_at_rig_rest). bind_kp_m_rest = _openpose_bind_at_rig_rest( pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids, ) @@ -914,7 +815,7 @@ def build_glb_openpose( person_root_idx = len(nodes) - 1 scene_root_indices.append(person_root_idx) - # K keypoint joint nodes (spheres bind here, rigid translation only). + # K keypoint joint nodes (spheres bind here, translation only). joint_node_indices: List[int] = [] for j in range(K): nodes.append({ @@ -926,9 +827,7 @@ def build_glb_openpose( joint_node_indices.append(len(nodes) - 1) person_root["children"].extend(joint_node_indices) - # Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint + - # quaternion that aligns rest-axis → frame-t-axis). Sticks bind - # rigidly to these joints so each capsule rotates with its limb. + # Per-limb rest + per-frame TRS; sticks bind rigidly to these joints. limb_rest_mids_list: List[np.ndarray] = [] limb_rest_axes_list: List[np.ndarray] = [] limb_anim_mids_list: List[np.ndarray] = [] @@ -951,12 +850,10 @@ def build_glb_openpose( limb_rest_axes_list.append(raxis_h) limb_anim_mids_list.append(amid_h) limb_anim_quats_list.append(aquat_h) - limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3) - limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3) - limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4) - # Hemisphere-align consecutive quats per limb so LINEAR interpolation - # takes the short path (otherwise large per-frame rotations can flip - # signs and produce visible "twist back" artifacts mid-playback). + limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) + limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) + limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) + # Hemisphere-align consecutive quats so LINEAR interp takes the short path. limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32) limb_joint_indices: List[int] = [] @@ -970,8 +867,8 @@ def build_glb_openpose( limb_joint_indices.append(len(nodes) - 1) person_root["children"].extend(limb_joint_indices) - # Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints - # (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest. + # Combined skin: keypoint joints then limb joints; IBM = T(-rest) for + # both, yielding identity skin_matrix at rest. all_joint_indices = joint_node_indices + limb_joint_indices ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1)) ibm[:K, :3, 3] = -bind_kp_m @@ -985,10 +882,8 @@ def build_glb_openpose( }) skin_idx = len(skins) - 1 - # Per-group geometry. Spheres bind to keypoint joints (base_joint_idx - # ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈ - # [K, K + K_limbs)). Groups stack body → right hand → left hand → - # face for keypoint joints, and body → R-hand → L-hand for limbs. + # Per-group geometry. Spheres bind to keypoint joints [0, K); sticks to + # limb joints [K, K+K_limbs). Stacked body → R-hand → L-hand → face. group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]] = [] sp = _build_openpose_spheres( @@ -1008,9 +903,7 @@ def build_glb_openpose( group_meshes.append(st) if include_hands: - # Hand stick colors stay rainbow per-finger regardless of - # `hand_color_style` — only the sphere dots switch to solid - # blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py. + # Hand sticks stay rainbow per-finger; only dots switch under 'dwpose'. hand_pair_colors = _pair_colors_from_kp( OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1, ) @@ -1033,9 +926,7 @@ def build_glb_openpose( if K_face > 0: f_off = K_body + K_hands f_bind = bind_kp_m[f_off:f_off + K_face] - # DWPose face = dots only, no contour lines - # (controlnet_aux/dwpose/util.py::draw_facepose draws white - # circles per landmark and never connects them). + # DWPose face = dots only, no contour lines. group_meshes.append(_build_openpose_spheres( f_bind, float(face_marker_radius_m), FACE_LANDMARK_COLORS, base_joint_idx=f_off, @@ -1087,9 +978,8 @@ def build_glb_openpose( "target": {"node": joint_node_indices[j], "path": "translation"}, }) - # Per-limb-joint translation + rotation channels. Stationary limbs - # have their constant TRS baked into the node so they don't bloat the - # animation buffer. + # Per-limb-joint translation + rotation; stationary limbs bake their + # constant TRS into the node instead of an animation channel. for k in range(K_limbs): t_k = limb_anim_mids[:, k, :].astype(np.float32) if (np.ptp(t_k, axis=0) < 1e-6).all(): @@ -1103,9 +993,7 @@ def build_glb_openpose( "target": {"node": limb_joint_indices[k], "path": "translation"}, }) q_k = limb_anim_quats[:, k, :].astype(np.float32) - # ptp on the absolute value handles the +q == -q ambiguity, but - # `quat_sign_fix_per_joint` already aligned signs so a plain ptp - # is fine here. + # Plain ptp is fine — signs already aligned by quat_sign_fix_per_joint. if (np.ptp(q_k, axis=0) < 1e-6).all(): nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist() else: diff --git a/comfy_extras/sam3d_body/export/glb_shared.py b/comfy_extras/sam3d_body/export/glb_shared.py index 0db4813a3..017ad20a5 100644 --- a/comfy_extras/sam3d_body/export/glb_shared.py +++ b/comfy_extras/sam3d_body/export/glb_shared.py @@ -1,15 +1,11 @@ -"""GLB export for SAM 3D Body pose_data. +"""Shared GLB export helpers for SAM 3D Body pose_data. -Mode: skeletal — rebuilds the MHR 127-bone rig. Per-frame local TRS comes from -re-running param_transform on saved mhr_model_params; rest verts from a -zero-pose forward with the person's shape_params; sparse triplet skinning is -compacted to glTF's max-4-influences form; facial expression is re-exposed as -72 morph targets driven by expr_params. - -pred_vertices/pred_cam_t are camera-y-down — un-flipped here so the GLB lives -in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent -them); deformation at extreme joint angles will differ from the SAM3DBody -renderer by the corrective amount. +Skeletal mode rebuilds the MHR 127-bone rig: per-frame local TRS from +param_transform on mhr_model_params, rest verts from a zero-pose forward, +sparse skinning compacted to glTF's 4-influence form, expression re-exposed as +72 morph targets. Camera-y-down data is un-flipped to glTF Y-up. Pose +correctives are dropped (glTF skinning can't represent them), so extreme joint +angles differ from the SAM3DBody renderer by the corrective amount. """ from __future__ import annotations @@ -24,12 +20,11 @@ import torch from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical -# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically -# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)` +# fp32-rounded ln(2); exp(x * _LN2) matches the rig's own 2**x bit-for-bit. _LN2 = 0.6931471824645996 -# Quaternion / rotation helpers (xyzw convention, matching MHR rig) +# Quaternion / rotation helpers (xyzw, matching MHR rig) def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray: """(roll, pitch, yaw) -> (x, y, z, w). Mirrors mhr_rig._euler_xyz_to_quat.""" @@ -96,8 +91,7 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray: def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray: - """Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4. - Endpoints replicate so they aren't pulled toward zero. Returns float64.""" + """Edge-replicate Gaussian smoothing along time (sigma = window/4). float64.""" a = np.asarray(arr, dtype=np.float64) n = a.shape[0] half = window // 2 @@ -117,9 +111,8 @@ def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray: def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray: - """Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns - per joint first, convolves per-component, renormalizes. Suppresses multi- - frame bone spikes at extreme poses without needing the upstream Smooth node.""" + """Smooth a (N, NJ, 4) quaternion sequence along time: sign-align per joint, + convolve per-component, renormalize. Calms bone spikes at extreme poses.""" if window <= 1 or q_seq.shape[0] < 2: return q_seq out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window) @@ -128,18 +121,16 @@ def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray: def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray: - """Gaussian-smooth a (N, K, 3) position sequence along time (edge-replicate - padding). Used to calm jittery keypoint tracks before the openpose rig - derives sphere translations + limb TRS from them.""" + """Smooth a (N, K, 3) position sequence along time. Calms jittery keypoint + tracks before the openpose rig derives sphere translations + limb TRS.""" if window <= 1 or seq.shape[0] < 2: return seq return _gaussian_smooth_time(seq, window).astype(np.float32) def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray: - """Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit - on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim - cartwheel flip). fp64 to avoid drift; normalizes input defensively.""" + """Walk (N, NJ, 4) along time, flipping sign when consecutive frames sit on + opposite hemispheres. Avoids long-path slerp glitches. fp64 internally.""" out = np.array(q_seq, dtype=np.float64, copy=True) norms = np.linalg.norm(out, axis=-1, keepdims=True) out = out / np.maximum(norms, 1e-12) @@ -151,11 +142,9 @@ def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray: def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray: - """Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that - FK over (parents, bone_local) reproduces rig_global. local = - inverse(parent_global) ∘ child_global makes this robust to hierarchy- - convention mismatches: glTF FK gives back exactly rig_global even if - `parents` doesn't match the rig's pmi-walk.""" + """Globals (N, NJ, 8) + parents -> per-bone local TRS so FK reproduces + rig_global. local = inverse(parent_global) ∘ child_global, robust to + hierarchy-convention mismatches in `parents`.""" N, NJ, _ = rig_global.shape bone_local = np.zeros_like(rig_global) for j in range(NJ): @@ -188,8 +177,7 @@ def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray: def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]: """List of (person_index, frame_indices). track_index == -1 means every - present track; empty tracks are dropped. Same person index across frames - is assumed same subject (Smooth/Predict enforce this on tracked bboxes).""" + present track; empty tracks dropped. Same person index = same subject.""" frames = pose_data["frames"] max_p = max((len(f) for f in frames), default=0) if max_p == 0: @@ -257,8 +245,7 @@ class GLBWriter: return len(self.accessors) - 1 def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int: - """Morph-target POSITIONs: spec lets us skip min/max, avoiding a - per-frame delta bbox.""" + """Morph-target POSITIONs: spec lets us skip min/max.""" a = np.ascontiguousarray(arr, dtype=np.float32) view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY) self.accessors.append({ @@ -288,9 +275,8 @@ class GLBWriter: return len(self.accessors) - 1 def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int: - """Animation-output scalars: `count` is keyframes, not floats. Morph- - target weight tracks store N_morph weights per keyframe as flat float32 - with count=N_keyframes.""" + """Animation-output scalars: `count` is keyframes, not floats (morph + weight tracks store N_morph weights per keyframe).""" a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1) view_idx = self._add_view(a.tobytes()) self.accessors.append({ @@ -382,9 +368,8 @@ def bake_vertex_colors( rainbow_tilt_z_deg: float, pastel_mix: float, ) -> Optional[np.ndarray]: - """Per-vertex RGB matching the renderer's shader preset, on the canonical - mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the - viewer's default material handle shading).""" + """Per-vertex RGB matching the renderer's shader preset. Returns (N_v, 3) + float32 in [0, 1], or None for `default` (use the viewer's material).""" if shader == "default" or canonical_colors is None: return None @@ -432,8 +417,8 @@ def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray: def _parents_from_pmi(rig: Any) -> np.ndarray: - """Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child, - row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1.""" + """Parent index per joint from skel_pmi ((2, 266): row 0 child, row 1 + parent, split into BFS levels by skel_pmi_buffer_sizes). Roots = -1.""" NJ = int(rig.NUM_JOINTS) pmi = rig.skel_pmi.cpu().numpy() sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist() @@ -450,47 +435,29 @@ def _parents_from_pmi(rig: Any) -> np.ndarray: def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply - this to bypass MHR rig extraction (see ComfyUI-Kimodo). Required keys: - parents: (NJ,) int32, -1 = root - bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters - lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences - lbs_compact_weights: (V, 8) f32 - lbs_compact_max_inf: int — actual max influences (≤ 8) - rest_verts_m: (V, 3) f32 - faces: (F, 3) uint32 - Optional: - per_frame_y_down: bool — set False if pred_joint_coords are already - rig-native Y-up (kimodo). Default True (MHR). - openpose18_joint_indices: (18, 2) int32 — body OpenPose-18 → joint - index pair, resolved against per-frame - `pred_joint_coords`. Each row is - (joint_a, joint_b); b == -1 = single - joint, else default midpoint of the two - (lets producers approximate keypoints - with no matching joint, e.g. Nose ≈ - midpoint(LeftEye, RightEye)). Enables - `SAM3DBody_ToGLB(mode="openpose")` on - external rigs. - openpose18_joint_weights: (18,) f32 — optional per-keypoint blend - weight for the (a, b) mapping above. - Position = w*joints[a] + (1-w)*joints[b] - when b ≥ 0 (default w=0.5 → midpoint). - Values outside [0, 1] EXTRAPOLATE past - the line segment — used to approximate - landmarks with no nearby joint pair - (e.g. ears: w=2.0 along the eye→eye - axis puts each ear one eye-distance - outside the corresponding eye). Ignored - for single-joint rows (b = -1). - openpose_hand21_r_joint_indices: (21, 2) int32 — right-hand OpenPose-21 - (wrist + 5 fingers × 4 joints, base→tip) - → joint index pair. Required (alongside - the L counterpart) for openpose mode - with include_hands=True. - openpose_hand21_l_joint_indices: (21, 2) int32 — left-hand counterpart. - openpose_hand21_r_joint_weights: (21,) f32 — optional, same semantics as - `openpose18_joint_weights`. - openpose_hand21_l_joint_weights: (21,) f32 — optional, same as above. + this to bypass MHR rig extraction (see ComfyUI-Kimodo). + + Required keys: + parents: (NJ,) int32, -1 = root + bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters + lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences + lbs_compact_weights: (V, 8) f32 + lbs_compact_max_inf: int — actual max influences (≤ 8) + rest_verts_m: (V, 3) f32 + faces: (F, 3) uint32 + + Optional (enable openpose mode on external rigs): + per_frame_y_down: bool — False if pred_joint_coords are already Y-up + (kimodo). Default True (MHR). + openpose18_joint_indices: (18, 2) int32 — body keypoint → (a, b) + joints, resolved against `pred_joint_coords`. + b == -1 = single joint, else midpoint of (a, b). + openpose18_joint_weights: (18,) f32 — blend w: w*a + (1-w)*b + (default 0.5; outside [0,1] extrapolates; ignored + when b == -1). + openpose_hand21_{r,l}_joint_indices: (21, 2) int32 — per-hand keypoint + maps; both required for include_hands=True. + openpose_hand21_{r,l}_joint_weights: (21,) f32 — optional, same as above. """ if pose_data is None: return None @@ -502,12 +469,10 @@ def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) - use that instead of MHR-specific `model.head_pose.mhr` buffers.""" override = _get_skeleton_override(pose_data) if override is not None: - # External rig: caller pre-compacts skin and supplies bind global directly, - # so we don't need MHR's PCA pose / expression bases. + # External rig: skin pre-compacted, bind global supplied directly. parents = np.asarray(override["parents"], dtype=np.int32) rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32) - # BVH needs parent-relative bone OFFSETs (cm). MHR ships these directly; - # external rigs only give bind globals, so derive locals from them. + # BVH needs parent-relative bone offsets (cm); derive from bind globals. bind_global_m = np.asarray(override["bind_global_m"], dtype=np.float32) local_bind = bone_locals_from_globals(bind_global_m[None], parents)[0] joint_translation_offsets = (local_bind[:, :3] * 100.0).astype(np.float32) @@ -560,29 +525,26 @@ def compact_skin_to_n( skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray, num_verts: int, max_inf: int = 8, ) -> Tuple[np.ndarray, np.ndarray, int]: - """Sparse (joint, vert, weight) triplets -> dense (joints[V, max_inf], - weights[V, max_inf]). Keeps `max_inf` largest-magnitude influences, - renormalizes. `actual_max` lets the caller skip JOINTS_1/WEIGHTS_1 when - nothing exceeds 4 influences.""" + """Sparse (joint, vert, weight) triplets -> dense (joints, weights) of shape + (V, max_inf), keeping the largest influences and renormalizing. `actual_max` + lets the caller skip JOINTS_1/WEIGHTS_1 when nothing exceeds 4 influences.""" joints = np.zeros((num_verts, max_inf), dtype=np.uint16) out_w = np.zeros((num_verts, max_inf), dtype=np.float32) counts = np.zeros(num_verts, dtype=np.int32) if vert_indices.size: - # lexsort secondary key first: groups by vert, weights descending within group. + # Group by vert, weights descending within each group. order = np.lexsort((-weights, vert_indices)) vi_sorted = vert_indices[order] sk_sorted = skin_indices[order] w_sorted = weights[order] - # Per-row rank within its vertex group: 0 at each group start, +1 elsewhere. - # group_start[i] is True when vi_sorted[i] starts a new vertex. + # Per-row rank within its vertex group (0 at each group start). n = vi_sorted.size group_start = np.empty(n, dtype=bool) group_start[0] = True np.not_equal(vi_sorted[1:], vi_sorted[:-1], out=group_start[1:]) pos = np.arange(n, dtype=np.int64) - # Position of each row's group start, broadcast forward. group_start_pos = np.maximum.accumulate(np.where(group_start, pos, 0)) rank = pos - group_start_pos @@ -609,9 +571,8 @@ def zero_pose_rest_verts( model: Any, shape_params: np.ndarray, expr_zero: bool = True, pose_data: Optional[Dict[str, Any]] = None, ) -> np.ndarray: - """Rig with zero pose + this subject's shape -> rest verts (V, 3) in - rig-native Y-up meters. External-skeleton path returns `rest_verts_m` - directly (no PCA shape space to expand).""" + """Zero pose + this subject's shape -> rest verts (V, 3) in rig-native Y-up + meters. External path returns `rest_verts_m` directly.""" override = _get_skeleton_override(pose_data) if override is not None: return np.asarray(override["rest_verts_m"], dtype=np.float32) @@ -624,14 +585,11 @@ def zero_pose_rest_verts( sp = torch.from_numpy(np.ascontiguousarray(shape_params, dtype=np.float32)).to(device) if sp.ndim == 1: sp = sp.unsqueeze(0) - # mhr.forward(identity_coeffs, model_parameters, expr_coeffs): - # identity_rest = base_shape + identity_basis @ shape; - # cat([model_params, zeros]) through param_transform; expr added. + # rig.forward(shape, model_params, expr); zero pose + zero expr. model_params = torch.zeros(1, 204, device=device, dtype=dtype) expr = torch.zeros(1, 72, device=device, dtype=dtype) verts, _ = rig(sp.to(dtype), model_params, expr, apply_correctives=False) - # Rig outputs cm; mhr_head divides by 100 for meters. Match that. - verts_m = verts[0].cpu().float().numpy() / 100.0 + verts_m = verts[0].cpu().float().numpy() / 100.0 # cm -> m return verts_m.astype(np.float32) @@ -639,7 +597,7 @@ def global_skel_state_per_frame( model: Any, mhr_model_params: np.ndarray, ) -> np.ndarray: """Rig FK over a batch of mhr_model_params -> (N, NJ, 8) = (t cm, q xyzw, - scale). Bones are shape- and expression-independent so we pass zeros.""" + scale). Bones are shape/expression-independent, so pass zeros.""" inner = model.model if hasattr(model, "model") else model rig = inner.head_pose.mhr device = next(rig.parameters()).device @@ -655,8 +613,8 @@ def global_skel_state_per_frame( def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray: - """(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978 branched, largest-component - pick for stability. Cross-frame sign-fixing is the caller's job.""" + """(..., 3, 3) -> (..., 4) xyzw. Shepperd 1978, largest-component pick. + Cross-frame sign-fixing is the caller's job.""" shape = R.shape[:-2] Rf = R.reshape(-1, 3, 3).astype(np.float64) M = Rf.shape[0] @@ -703,14 +661,12 @@ def global_skel_state_from_pose_data( pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, NJ: int, *, joint_coords_y_down: bool = True, ) -> np.ndarray: - """Build per-frame skel_state from stored pred_global_rots + pred_joint_coords, - bypassing rig.forward. Returns (N, NJ, 8) in METERS, MHR-native frame. + """Per-frame skel_state from stored pred_global_rots + pred_joint_coords, + bypassing rig.forward. Returns (N, NJ, 8) in meters, MHR-native frame. - pred_global_rots are MHR-native (no y/z flip). For MHR, pred_joint_coords - are stored y-down (post-flip), so un-flip when `joint_coords_y_down=True`. - External skeletons (Kimodo) store y-up already → pass False. Scale - defaults to 1 (rig scale isn't preserved in pose_data; close to 1 for - typical body poses).""" + pred_global_rots are MHR-native. pred_joint_coords are y-down for MHR + (un-flipped when `joint_coords_y_down=True`); external rigs store y-up + (pass False). Scale defaults to 1 (not preserved in pose_data).""" frames = pose_data["frames"] N = len(frame_indices) rotmat = np.zeros((N, NJ, 3, 3), dtype=np.float32) @@ -731,10 +687,8 @@ def global_skel_state_from_pose_data( def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray: - """Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm. - Inverse of `lbs_inverse_bind_pose` modulo precision; used as bones' static - TRS so the rest mesh looks correct with no animation playing. External - rig: convert override's `bind_global_m` from m → cm to match this contract.""" + """Rig FK with all-zero params -> bind-pose global skel state (NJ, 8) in cm, + used as bones' static TRS. External rig: convert `bind_global_m` m -> cm.""" override = _get_skeleton_override(pose_data) if override is not None: bind_m = np.asarray(override["bind_global_m"], dtype=np.float32).copy() @@ -746,13 +700,10 @@ def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> n @dataclass class Rig: - """Normalized static rig for the GLB/BVH exporters, independent of where it - came from: an MHR model (`Rig.from_pose_data(pose_data, model)`) or an inline - `pose_data["_skeleton_override"]` (external rigs, e.g. ComfyUI-Kimodo). - - Consumers read these fields and never branch on the source. The only - source-dependent operation is `rest_verts_m` — MHR rest verts depend on the - subject's `shape_params`; external rigs ship fixed rest verts. + """Normalized static rig for the GLB/BVH exporters, source-independent: MHR + model or inline `pose_data["_skeleton_override"]` (external rigs). Consumers + never branch on the source. Only `rest_verts_m` is source-dependent — MHR + expands it from `shape_params`; external rigs ship it fixed. """ parents: np.ndarray # (NJ,) int32, -1 = root joint_offsets_cm: np.ndarray # (NJ, 3) parent-relative bind offsets, cm @@ -816,9 +767,8 @@ class Rig: def ibp_from_bind_global(bind_skel_state_m: np.ndarray) -> np.ndarray: - """Inverse-bind MAT4 by inverting the rig's bind global (meters). Guarantees - IBP[j] = inverse(FK over bind local TRS) — exactly what glTF skinning - needs given bones default to the bind local TRS. Returns (NJ, 4, 4) + """Inverse-bind MAT4 from the rig's bind global (meters). IBP[j] = + inverse(FK over bind local TRS), as glTF skinning needs. Returns (NJ, 4, 4) column-major.""" NJ = bind_skel_state_m.shape[0] t = bind_skel_state_m[:, :3].astype(np.float32) @@ -877,10 +827,8 @@ def _ibp_to_mat4(ibp_skel: np.ndarray) -> np.ndarray: def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]: - """Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring - lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon - matches the 16-segment cylinder used by capsule limbs AND the equator - ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush.""" + """Unit UV sphere, poles ±Y. `n_lat` odd so a ring lands at the equator; + n_lon=16 matches the capsule cylinder so end rings meet flush.""" verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0 for i in range(1, n_lat + 1): lat = -0.5 * np.pi + np.pi * i / (n_lat + 1) @@ -924,8 +872,8 @@ def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndar def flat_shade_mesh( verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Smooth -> flat by duplicating verts per face; each triangle gets 3 - unique verts sharing its face normal. Skinning attrs duplicated alongside.""" + """Flat-shade by duplicating verts per face; each triangle gets 3 unique + verts sharing its face normal. Skinning attrs duplicated alongside.""" F = faces.shape[0] new_v = np.zeros((F * 3, 3), dtype=np.float32) new_n = np.zeros((F * 3, 3), dtype=np.float32) @@ -949,9 +897,8 @@ def flat_shade_mesh( def smooth_shade_mesh( verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Area-weighted per-vertex normals (smooth shading). Geometry, skinning, - indexing pass through unchanged so vertex colors stay aligned. Orphan - verts get +Y fallback.""" + """Area-weighted per-vertex normals. Geometry/skinning/indexing pass through + unchanged so vertex colors stay aligned. Orphan verts get +Y fallback.""" Nv = int(verts.shape[0]) v0 = verts[faces[:, 0]] v1 = verts[faces[:, 1]] @@ -994,11 +941,9 @@ def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray: def make_lit_material( roughness: float = 0.85, double_sided: bool = False, opacity: float = 1.0, ) -> dict: - """Lit PBR material using vertex COLOR_0 multiplicatively. KHR_materials_unlit - is intentionally off so viewer lighting reveals surface form. metallic=0 - keeps the surface dielectric so vertex colors stay readable. roughness=0.85 - suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look. - opacity < 1 switches to alpha-blend (e.g. see-through body mesh over bones).""" + """Lit PBR material using vertex COLOR_0. Dielectric (metallic=0) so colors + stay readable; roughness 0.85 suits rainbow body meshes, 0.3 the glossy + SCAIL rig. opacity < 1 switches to alpha-blend.""" a = float(max(0.0, min(1.0, opacity))) mat = { "pbrMetallicRoughness": { @@ -1182,14 +1127,12 @@ def openpose_render_keypoints( person: Dict[str, Any], pose_data: Optional[Dict[str, Any]], part: str, *, dim: int, H: int = 0, W: int = 0, ) -> Optional[np.ndarray]: - """OpenPose keypoints for one person, in op-layout, CAMERA frame (Y-down). + """OpenPose keypoints for one person, op-layout, camera frame (Y-down). `part` in {'body','hand_r','hand_l'}. dim=3 -> (K, 3) metres pre-cam_t-add; - dim=2 -> (K, 2) image pixels. Returns None when the source data is missing. + dim=2 -> (K, 2) pixels. Returns None when source data is missing. - External rigs (override carries the joint-index map) resolve from per-frame - `pred_joint_coords` (rig-native Y-up -> flipped to camera Y-down, matching - the pred_vertices convention). MHR reindexes the stored - `pred_keypoints_{3d,2d}` via the MHR70 map.""" + External rigs resolve from `pred_joint_coords` (Y-up -> flipped to Y-down); + MHR reindexes stored `pred_keypoints_{3d,2d}` via the MHR70 map.""" map_key, w_key, mhr_map = _OPENPOSE_RENDER_MAPS[part] override = _get_skeleton_override(pose_data) ext_map = override.get(map_key) if override is not None else None @@ -1228,11 +1171,9 @@ def openpose_render_keypoints( return kp_full[mhr_map] -# Face landmarks from the MHR rig (option `face_source="rig"`). -# MHR has no face bones — face deforms via expr_params morphs — so landmarks -# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against -# anatomically-plausible target xyz in canonical Y-up. Iterate visually in -# Blender and tweak targets if landmarks land off-surface. +# Face landmarks (face_source="rig"). MHR has no face bones, so landmarks are +# sourced from `pred_vertices` at vertex IDs picked by NN against the target xyz +# below. Tweak targets if landmarks land off-surface. # (name, target_xyz) in MHR canonical Y-up meters. FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = ( @@ -1290,10 +1231,8 @@ def select_face_landmark_vert_ids( face_mask: Optional[np.ndarray] = None, ) -> np.ndarray: """Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in - canonical positions. Filter: `face_mask` (verts that deform with any of - the 72 expression axes) if available — keeps chin/jaw search off the - neck. Otherwise a position bbox (less reliable; throat verts sometimes - pull chin targets).""" + canonical positions, restricted to `face_mask` verts (expression-deforming) + when available, else a position bbox (less reliable around the chin/jaw).""" P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3) if face_mask is not None and np.asarray(face_mask).any(): valid = np.where(np.asarray(face_mask).reshape(-1))[0] diff --git a/comfy_extras/sam3d_body/export/glb_skeletal.py b/comfy_extras/sam3d_body/export/glb_skeletal.py index 128a1c5bd..5c03ac5c8 100644 --- a/comfy_extras/sam3d_body/export/glb_skeletal.py +++ b/comfy_extras/sam3d_body/export/glb_skeletal.py @@ -1,19 +1,11 @@ """GLB export — skeletal (real armature) mode. -Rebuilds an Armature with the MHR 127-bone rig: - - per-frame local TRS comes from re-running param_transform on the saved - `mhr_model_params`; - - rest verts come from a zero-pose forward with each person's `shape_params`; - - sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form; - - facial expression is re-exposed as 72 morph targets driven by `expr_params` - so face animation survives plain glTF skinning. - -Optional bone visualization (octahedrons) is rigidly -skinned alongside the body mesh — used to preview the armature in glTF -viewers that don't draw bones. - -Shared GLB infra (writer, math, rig static extraction, shaders, normals) -stays in `glb_shared.py`; only this mode's geometry + assembly live here. +Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from +param_transform on `mhr_model_params`, rest verts from a zero-pose forward, +sparse skinning compacted to glTF's 4-influence form, and facial expression as +72 morph targets driven by `expr_params`. Optional octahedron bone-vis is +rigidly skinned alongside for viewers that don't draw bones. Shared infra lives +in `glb_shared.py`. """ from __future__ import annotations @@ -44,8 +36,7 @@ from .glb_shared import ( from comfy_extras.sam3d_body.utils import jet_colormap def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]: - """Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white' - (no per-bone color → bone-vis mesh uses default unlit material).""" + """Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material).""" if scheme == "rainbow_y": y = bind_pos_m[:, 1].astype(np.float32) y_min, y_max = float(y.min()), float(y.max()) @@ -55,9 +46,8 @@ def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]: - """Canonical Blender-style bone octahedron. Head at origin, tail at +Y, - unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound - so cross(v1-v0, v2-v0) points OUTWARD from the bone axis.""" + """Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit + length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward.""" v = np.array([ [0.0, 0.0, 0.0], # 0: head [0.0, 1.0, 0.0], # 1: tail @@ -78,18 +68,16 @@ def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]: def _bone_edges( joint_pos_m: np.ndarray, parents: np.ndarray, ) -> List[Tuple[int, int, np.ndarray, np.ndarray]]: - """Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per - parent→child edge in the hierarchy, skipping edges whose PARENT is a - root joint (those typically anchor the skeleton at world origin and - just look like a stray stick from origin to the body). Zero-length - edges are skipped too.""" + """One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge. + Skips edges whose parent is a root (world-anchor sticks) and zero-length + edges.""" NJ = joint_pos_m.shape[0] out: List[Tuple[int, int, np.ndarray, np.ndarray]] = [] for c in range(NJ): p = int(parents[c]) if not (0 <= p < NJ and p != c): continue - # Skip if parent itself is a root — that bone is a world-anchor stick. + # Skip world-anchor sticks: parent itself is a root. gp = int(parents[p]) if not (0 <= gp < NJ and gp != p): continue @@ -104,9 +92,8 @@ def _bone_edges( def _build_bone_octahedrons_mesh( bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """One Blender-style octahedron per parent→child edge. Returns - (verts, normals, faces, joints, weights, child_idx_per_vert); - child_idx feeds per-bone color lookup at the call site.""" + """One octahedron per parent→child edge. Returns (verts, normals, faces, + joints, weights, child_idx_per_vert); child_idx feeds per-bone color.""" base_v, base_f = _octahedron_unit() canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) @@ -117,8 +104,7 @@ def _build_bone_octahedrons_mesh( out_w: List[List[float]] = [] child_per_vert: List[int] = [] - # Width scales with length so short bones (fingers, face) don't look chunky - # next to long ones (limbs, spine). `half_width_m` caps long bones. + # Width scales with length (capped by half_width_m) so short bones aren't chunky. WIDTH_RATIO = 0.1 MIN_WIDTH = 0.001 for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents): @@ -151,8 +137,8 @@ def _build_bone_octahedrons_mesh( out_n.extend(n_world.tolist()) for face in base_f: out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off]) - # Dual skin head→parent, tail→child, ridges blend by canonical Y so the - # bone stretches between joints instead of going rigid with one. + # Dual skin (head→parent, tail→child); ridges blend by canonical Y so + # the bone stretches between joints instead of going rigid with one. for k in range(base_v.shape[0]): y_canon = float(base_v[k, 1]) w_parent = max(0.0, 1.0 - y_canon) @@ -196,22 +182,17 @@ def build_glb_skeletal( bone_vis_color: str = "white", include_body_mesh: bool = True, ) -> bytes: - """Build pose_data as a real Armature GLB blob with per-bone TRS keyframes. + """Build pose_data as a real Armature GLB with per-bone TRS keyframes. For + MHR, facial expression is exposed as 72 morph targets when + include_face_morphs=True. - For MHR (default) facial expression is exposed as 72 morph targets driven - by expr_params per frame when include_face_morphs=True. - - External skeletons (e.g. ComfyUI-Kimodo) can supply a - ``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction - entirely. When present, ``model`` may be None and the rig data, bind pose, - skin weights, and rest verts come from the override. Per-frame skeletal - state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each - person dict (kimodo populates these from its own FK output). See - ``glb.shared._get_skeleton_override`` for the override schema. + External skeletons (e.g. ComfyUI-Kimodo) can supply + ``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model`` + may be None then); per-frame state still reads ``pred_global_rots`` / + ``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema. """ frames = pose_data["frames"] - # Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis, - # faces are all rig-native (Y-up). + # Only `pred_cam_t` is camera-y-down; everything else is rig-native Y-up. faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32) tracks = collect_tracks(pose_data, track_index) if not tracks: @@ -219,17 +200,14 @@ def build_glb_skeletal( rig = Rig.from_pose_data(pose_data, model) NJ = rig.num_joints - # NV = rig.num_verts NEXPR = rig.num_expr parents = rig.parents if not rig.can_rerun_fk: - # External rigs have no PCA pose params to re-run; only stored globals - # are available, and they store joint coords already Y-up. + # External rigs have no PCA pose params to re-run; use stored globals. use_stored_global_rots = True joint_coords_y_down = rig.per_frame_y_down - # Skinning is already compacted to ≤8 influences per vertex (MHR averages - # ~2.8 but some shoulder/hip verts hit 5-8; keeping only 4 there leaks - # per-bone rotation noise into the rendered mesh). + # Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts + # need >4, else per-bone rotation noise leaks into the mesh). joints_8 = rig.lbs_joints weights_8 = rig.lbs_weights actual_max_inf = rig.lbs_max_inf @@ -238,14 +216,12 @@ def build_glb_skeletal( use_set1 = actual_max_inf > 4 joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None - # Derive bone locals from the rig's bind globals rather than recomputing - # FK ourselves, so any mismatch between `parents` and the rig's actual FK - # is absorbed into the local TRS instead of producing wrong globals. + # Derive bone locals from bind globals so any `parents`/FK mismatch is + # absorbed into the local TRS instead of producing wrong globals. bind_global_m = rig.bind_global_m bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0] - # IBP = inverse of bind global. With bone defaults set to bind_local and - # FK composed via `parents`, skin_matrix at rest = identity. + # IBP = inverse of bind global → skin_matrix at rest is identity. ibp_mat4 = ibp_from_bind_global(bind_global_m) w = GLBWriter() @@ -316,9 +292,7 @@ def build_glb_skeletal( body_mesh_node_idx: Optional[int] = None if include_body: - # MHR rest verts depend on the subject's shape_params; external rigs - # ship fixed rest verts and ignore the arg (so the empty external - # `shape_params` is harmless). + # MHR rest verts depend on shape_params; external rigs ignore the arg. shape_params_arr = np.asarray( frames[frame_indices[0]][person_k].get("shape_params", []), dtype=np.float32, @@ -349,8 +323,8 @@ def build_glb_skeletal( "indices": indices_acc, "mode": 4, } - # See-through body when bones are shown, else opaque (only when a - # vertex-color shader baked COLOR_0 — otherwise default material). + # See-through body when bones are shown, else opaque (only if a + # shader baked COLOR_0; otherwise default material). if color_acc is not None or include_bones: materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0)) primitive["material"] = len(materials) - 1 @@ -373,8 +347,7 @@ def build_glb_skeletal( if include_bones: bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color) - # Indexes `bone_palette`: octahedrons use the bone's child joint so - # every bone has its own color regardless of skin target. + # Color by child joint so every bone has its own color. color_idx_per_vert: Optional[np.ndarray] = None hw = float(bone_vis_radius_m) bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh( @@ -422,8 +395,8 @@ def build_glb_skeletal( nodes.append(bv_mesh_node) person_root["children"].append(len(nodes) - 1) - # Per-frame GLOBAL skel state → bone locals via parent-inverse. - # Default uses the rig's stored output; the fallback re-runs FK. + # Per-frame global skel state → bone locals via parent-inverse. Stored + # output by default; fallback re-runs FK. if use_stored_global_rots: rig_global_m = global_skel_state_from_pose_data( pose_data, frame_indices, person_k, NJ, @@ -437,11 +410,9 @@ def build_glb_skeletal( rig_global_cm = global_skel_state_per_frame(model, mp_per_frame) rig_global_m = rig_global_cm.copy().astype(np.float32) rig_global_m[..., :3] *= 0.01 - # Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's - # Euler-XYZ parametrization wraps at ±180° for spinning joints; if we - # only fix locals, the parent's flip propagates into the child's - # local translation (t_local inherits parent sign via q_parent_inv) - # and produces visible "axis resets" mid-animation. + # Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip + # would otherwise propagate into the child's local translation and cause + # visible "axis resets" mid-animation. rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7]) bone_local_anim = bone_locals_from_globals(rig_global_m, parents) local_t = bone_local_anim[..., :3].astype(np.float32) @@ -449,20 +420,17 @@ def build_glb_skeletal( local_s = bone_local_anim[..., 7].astype(np.float32) # Second pass on locals catches residual drift from the parent-inverse. local_q = quat_sign_fix_per_joint(local_q) - # Hemisphere-align frame 0 with the bind quat so pause/play takes the - # short path; then re-propagate. + # Align frame 0 with the bind quat so pause/play takes the short path. bind_q = bind_local[:, 3:7].astype(np.float32) if local_q.shape[0] > 0: d0 = (bind_q * local_q[0]).sum(axis=-1) sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None] local_q[0] = local_q[0] * sign0 local_q = quat_sign_fix_per_joint(local_q) - # Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity - # at handstand) that the upstream Smooth node may not catch. + # Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand). if bone_smooth_window and bone_smooth_window > 1: local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window)) - # fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit - # drift into visible flips otherwise. + # fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift. lq64 = local_q.astype(np.float64) lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12) local_q = lq64.astype(np.float32) @@ -527,7 +495,7 @@ def build_glb_skeletal( "target": {"node": person_root_idx, "path": "translation"}, }) - # Body-mesh-only: bone-vis primitives have no morph targets. + # Body mesh only — bone-vis primitives have no morph targets. if expr_morph_accs and body_mesh_node_idx is not None: expr_per_frame = np.stack([ np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32) diff --git a/comfy_extras/sam3d_body/utils.py b/comfy_extras/sam3d_body/utils.py index dda3b6d04..d8d0e9eb4 100644 --- a/comfy_extras/sam3d_body/utils.py +++ b/comfy_extras/sam3d_body/utils.py @@ -34,8 +34,7 @@ def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]: def inputs_from_sam3_track(track_data, B: int, H: int, W: int): """Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image - resolution. Returns (per_frame_bboxes, per_frame_masks) or - (None, None) when the track is empty / frame count doesn't match""" + resolution. Returns (None, None) on empty track / frame-count mismatch.""" packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None if packed is None: @@ -100,7 +99,7 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any], H: int, W: int) -> Dict[str, Any]: """Re-project every frame's pose through a Load3D 6DOF camera (position/ - target/zoom + optional FOV). Returns a new mhr_pose_data; unchanged on + target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on empty/invalid input.""" first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else [] if not first_frame: @@ -158,16 +157,16 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, y_axis = np.cross(z_axis, x_axis) R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32) - # Eye: dolly along the given offset; for a rotation-only camera (position == - # target) keep the predicted viewing distance so only orientation/roll changes. + # Eye: dolly along the offset; rotation-only camera keeps the predicted + # viewing distance so only orientation/roll changes. if has_offset: eye = target + offset / max(0.01, zoom) else: d = max(0.1, float(target[2])) eye = target - z_axis * (d / max(0.01, zoom)) - # Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint- - # only change). Three.js fov is vertical → focal from image height. + # Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov + # is vertical → focal from image height. cam_fov = float(camera_info.get("fov", 0.0) or 0.0) if cam_fov > 0: new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0))) @@ -178,10 +177,8 @@ def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, center = np.array([W * 0.5, H * 0.5], dtype=np.float32) reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"} - # External rigs (e.g. Kimodo) store pred_joint_coords rig-native Y-up; the - # render openpose/scail keypoint provider resolves from them and flips Y/Z. - # Transform them through the camera too (in camera space, then back to Y-up) - # so those keypoints follow the override instead of staying in the old frame. + # External rigs store pred_joint_coords Y-up; transform them through the + # camera too (in camera space, then back to Y-up) so they follow the override. override = mhr_pose_data.get("_skeleton_override") joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False)) new_frames: List[List[Dict[str, Any]]] = [] @@ -242,8 +239,7 @@ def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], p img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)] if per_frame_masks is not None: - # Broadcast a single-mask bundle to per-bbox: when the user supplied one - # mask but multiple bboxes per frame, each bbox gets the same mask. + # One mask but multiple bboxes per frame → each bbox gets the same mask. flat_masks = [] for f in range(N): mf = per_frame_masks[f]