From 12940417782daf40fdb314a49f0673b926b63912 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 26 May 2026 02:15:15 +0300 Subject: [PATCH] Initial sam3d body support --- comfy/ldm/sam3d_body/mhr/mhr_head.py | 377 ++++++ comfy/ldm/sam3d_body/mhr/mhr_rig.py | 235 ++++ comfy/ldm/sam3d_body/mhr/mhr_utils.py | 413 ++++++ comfy/ldm/sam3d_body/model/camera_modules.py | 155 +++ comfy/ldm/sam3d_body/model/dinov3.py | 250 ++++ comfy/ldm/sam3d_body/model/model.py | 1197 +++++++++++++++++ comfy/ldm/sam3d_body/model/prompt.py | 272 ++++ comfy/ldm/sam3d_body/model/transformer.py | 104 ++ comfy/ldm/sam3d_body/utils.py | 341 +++++ comfy_extras/mediapipe/face_landmarker.py | 49 +- comfy_extras/nodes_sam3d_body.py | 1014 ++++++++++++++ comfy_extras/nodes_save_3d.py | 442 +++++- comfy_extras/nodes_sdpose.py | 278 +--- comfy_extras/pose/keypoint_draw.py | 348 +++++ comfy_extras/sam3d_body/export/bvh.py | 207 +++ comfy_extras/sam3d_body/export/capsules.py | 403 ++++++ .../sam3d_body/export/glb_openpose.py | 1138 ++++++++++++++++ comfy_extras/sam3d_body/export/glb_shared.py | 1138 ++++++++++++++++ .../sam3d_body/export/glb_skeletal.py | 578 ++++++++ comfy_extras/sam3d_body/export/openpose_2d.py | 233 ++++ comfy_extras/sam3d_body/face_expression.py | 516 +++++++ comfy_extras/sam3d_body/rasterizer.py | 467 +++++++ comfy_extras/sam3d_body/utils.py | 397 ++++++ nodes.py | 1 + 24 files changed, 10261 insertions(+), 292 deletions(-) create mode 100644 comfy/ldm/sam3d_body/mhr/mhr_head.py create mode 100644 comfy/ldm/sam3d_body/mhr/mhr_rig.py create mode 100644 comfy/ldm/sam3d_body/mhr/mhr_utils.py create mode 100644 comfy/ldm/sam3d_body/model/camera_modules.py create mode 100644 comfy/ldm/sam3d_body/model/dinov3.py create mode 100644 comfy/ldm/sam3d_body/model/model.py create mode 100644 comfy/ldm/sam3d_body/model/prompt.py create mode 100644 comfy/ldm/sam3d_body/model/transformer.py create mode 100644 comfy/ldm/sam3d_body/utils.py create mode 100644 
comfy_extras/nodes_sam3d_body.py create mode 100644 comfy_extras/pose/keypoint_draw.py create mode 100644 comfy_extras/sam3d_body/export/bvh.py create mode 100644 comfy_extras/sam3d_body/export/capsules.py create mode 100644 comfy_extras/sam3d_body/export/glb_openpose.py create mode 100644 comfy_extras/sam3d_body/export/glb_shared.py create mode 100644 comfy_extras/sam3d_body/export/glb_skeletal.py create mode 100644 comfy_extras/sam3d_body/export/openpose_2d.py create mode 100644 comfy_extras/sam3d_body/face_expression.py create mode 100644 comfy_extras/sam3d_body/rasterizer.py create mode 100644 comfy_extras/sam3d_body/utils.py diff --git a/comfy/ldm/sam3d_body/mhr/mhr_head.py b/comfy/ldm/sam3d_body/mhr/mhr_head.py new file mode 100644 index 000000000..e30e48cb3 --- /dev/null +++ b/comfy/ldm/sam3d_body/mhr/mhr_head.py @@ -0,0 +1,377 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat +from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask + +from ..model.transformer import MLP + + +class MHRHead(nn.Module): + + def __init__( + self, + input_dim: int, + mhr_rig, + mlp_depth: int = 1, + extra_joint_regressor: str = "", + mlp_channel_div_factor: int = 8, + enable_hand_model=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + # Store the shared MHRRig as a non-registered Python attribute + object.__setattr__(self, "mhr", mhr_rig) + + self.num_shape_comps = 45 + self.num_scale_comps = 28 + self.num_hand_comps = 54 + self.num_face_comps = 72 + self.enable_hand_model = enable_hand_model + + self.body_cont_dim = 260 + self.npose = ( + 6 # Global Rotation + + self.body_cont_dim # then body + + self.num_shape_comps + + self.num_scale_comps + + self.num_hand_comps * 2 + + self.num_face_comps + ) + + self.proj = MLP( + input_dim=input_dim, + 
hidden_dim=input_dim // mlp_channel_div_factor, + output_dim=self.npose, + num_layers=mlp_depth, + device=device, + dtype=dtype, + operations=operations, + ) + + # MHR Parameters + self.num_hand_scale_comps = self.num_scale_comps - 18 + self.num_hand_pose_comps = self.num_hand_comps + + # Buffers populated by load_state_dict from the safetensors + def _p(*shape, dtype=torch.float32): + return nn.Parameter(torch.empty(*shape, dtype=dtype), requires_grad=False) + self.joint_rotation = _p(127, 3, 3) + self.scale_mean = _p(68) + self.scale_comps = _p(28, 68) + self.faces = _p(36874, 3, dtype=torch.int64) + self.hand_pose_mean = _p(54) + self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False) + self.hand_joint_idxs_left = _p(27, dtype=torch.int64) + self.hand_joint_idxs_right = _p(27, dtype=torch.int64) + self.keypoint_mapping = _p(308, 18439 + 127) + # Some special buffers for the hand-version + self.right_wrist_coords = _p(3) + self.root_coords = _p(3) + self.local_to_world_wrist = _p(3, 3) + self.nonhand_param_idxs = _p(145, dtype=torch.int64) + # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader). + # Optional — loaded from the .safetensors if present, otherwise the + # render path falls back to a coarse geometric approximation. + self.register_buffer( + "face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32), + ) + + def canonical_vertices(self, device=None): + """Return the T-pose vertices for the mean shape (scaled to meters). 
+ + Runs MHR with zero pose / shape / scale / expression so the returned + mesh is the canonical rest pose — fixed per-model + """ + dev = device or self.scale_mean.device + dt = self.scale_mean.dtype + B = 1 + global_trans = torch.zeros(B, 3, device=dev, dtype=dt) + global_rot = torch.zeros(B, 3, device=dev, dtype=dt) + body_pose = torch.zeros(B, 130, device=dev, dtype=dt) + hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt) + scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt) + shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt) + expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt) + verts = self.mhr_forward( + global_trans=global_trans, + global_rot=global_rot, + body_pose_params=body_pose, + hand_pose_params=hand_pose, + scale_params=scale, + shape_params=shape, + expr_params=expr, + ) # single-tensor shape (1, N_v, 3) in meters + return verts[0] + + def get_zero_pose_init(self, factor=1.0): + # Initialize pose token with zero-initialized learnable params + # Note: bias/initial value should be zero-pose in cont, not all-zeros + weights = torch.zeros(1, self.npose) + weights[:, : 6 + self.body_cont_dim] = torch.cat( + [ + torch.FloatTensor([1, 0, 0, 0, 1, 0]), + compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze() + * factor, + ], + dim=0, + ) + return weights + + def replace_hands_in_pose(self, full_pose_params, hand_pose_params): + assert full_pose_params.shape[1] == 136 + + # This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params. 
+ # Split into left and right hands + left_hand_params, right_hand_params = torch.split( + hand_pose_params, + [self.num_hand_pose_comps, self.num_hand_pose_comps], + dim=1, + ) + + # Change from cont to model params + left_hand_params_model_params = compact_cont_to_model_params_hand( + self.hand_pose_mean + + torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps) + ) + right_hand_params_model_params = compact_cont_to_model_params_hand( + self.hand_pose_mean + + torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps) + ) + + # Drop it in + full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params + full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params + + return full_pose_params # B x 207 + + def mhr_forward( + self, + global_trans, + global_rot, + body_pose_params, + hand_pose_params, + scale_params, + shape_params, + expr_params=None, + return_keypoints=False, + do_pcblend=True, + return_joint_coords=False, + return_model_params=False, + return_joint_rotations=False, + scale_offsets=None, + vertex_offsets=None, + ): + # Align everything to the static buffers + dt = self.scale_mean.dtype + global_trans = global_trans.to(dt) + global_rot = global_rot.to(dt) + body_pose_params = body_pose_params.to(dt) + if hand_pose_params is not None: + hand_pose_params = hand_pose_params.to(dt) + scale_params = scale_params.to(dt) + shape_params = shape_params.to(dt) + if expr_params is not None: + expr_params = expr_params.to(dt) + + if self.enable_hand_model: + # Transfer wrist-centric predictions to the body. 
+ global_rot_ori = global_rot.clone() + global_trans_ori = global_trans.clone() + global_rot = rotmat_to_euler( + "xyz", + euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist, + ) + global_trans = ( + -( + euler_to_rotmat("xyz", global_rot) + @ (self.right_wrist_coords - self.root_coords) + + self.root_coords + ) + + global_trans_ori + ) + + body_pose_params = body_pose_params[..., :130] + + # Convert from scale and shape params to actual scales and vertices + + # Add singleton batches in case... + if len(scale_params.shape) == 1: + scale_params = scale_params[None] + if len(shape_params.shape) == 1: + shape_params = shape_params[None] + # Convert scale... + scales = self.scale_mean[None, :] + scale_params @ self.scale_comps + if scale_offsets is not None: + scales = scales + scale_offsets + + # Now, figure out the pose. + ## 10 here is because it's more stable to optimize global translation in meters. + full_pose_params = torch.cat( + [global_trans * 10, global_rot, body_pose_params], dim=1 + ) # B x 127 + ## Put in hands + if hand_pose_params is not None: + full_pose_params = self.replace_hands_in_pose( + full_pose_params, hand_pose_params + ) + model_params = torch.cat([full_pose_params, scales], dim=1) + + if self.enable_hand_model: + # Zero out non-hand parameters + model_params[:, self.nonhand_param_idxs] = 0 + + curr_skinned_verts, curr_skel_state = self.mhr( + shape_params, model_params, expr_params + ) + curr_joint_coords, curr_joint_quats, _ = torch.split( + curr_skel_state, [3, 4, 1], dim=2 + ) + curr_skinned_verts = curr_skinned_verts / 100 + curr_joint_coords = curr_joint_coords / 100 + curr_joint_rots = unitquat_to_rotmat(curr_joint_quats) + + # Prepare returns + to_return = [curr_skinned_verts] + if return_keypoints: + # Get sapiens 308 keypoints + model_vert_joints = torch.cat( + [curr_skinned_verts, curr_joint_coords], dim=1 + ) # B x (num_verts + 127) x 3 + + kp_map = self.keypoint_mapping.to(model_vert_joints.dtype) + 
model_keypoints_pred = ( + (kp_map @ model_vert_joints.permute(1, 0, 2).flatten(1, 2)) + .reshape(-1, model_vert_joints.shape[0], 3) + .permute(1, 0, 2) + ) + + if self.enable_hand_model: + # Zero out everything except for the right hand + model_keypoints_pred[:, :21] = 0 + model_keypoints_pred[:, 42:] = 0 + + to_return = to_return + [model_keypoints_pred] + if return_joint_coords: + to_return = to_return + [curr_joint_coords] + if return_model_params: + to_return = to_return + [model_params] + if return_joint_rotations: + to_return = to_return + [curr_joint_rots] + + if isinstance(to_return, list) and len(to_return) == 1: + return to_return[0] + else: + return tuple(to_return) + + def forward( + self, + x: torch.Tensor, + init_estimate: Optional[torch.Tensor] = None, + do_pcblend=True, + slim_keypoints=False, + intermediate: bool = False, + ): + """ + Args: + x: pose token with shape [B, C], usually C=DECODER.DIM + init_estimate: [B, self.npose] + intermediate: when True, the caller only needs the keypoints/pose + outputs needed by the per-layer keypoint-token update path — + vertex output is suppressed so `camera_project` skips the + 18439-vertex perspective projection on intermediate decoder + layers. The final layer must call with intermediate=False. + """ + batch_size = x.shape[0] + pred = self.proj(x) + if init_estimate is not None: + pred = pred + init_estimate + + # From pred, we want to pull out individual predictions. + + ## First, get globals + ### Global rotation is first 6. + count = 6 + global_rot_6d = pred[:, :count] + global_rot_rotmat = rot6d_to_rotmat(global_rot_6d) # B x 3 x 3 + global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat) # B x 3 + global_trans = torch.zeros_like(global_rot_euler) + + ## Next, get body pose. + ### Hold onto raw, continuous version for iterative correction. 
+ pred_pose_cont = pred[:, count : count + self.body_cont_dim] + count += self.body_cont_dim + ### Convert to eulers (and trans) + pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont) + ### Zero-out hands + pred_pose_euler[:, mhr_param_hand_mask] = 0 + ### Zero-out jaw + pred_pose_euler[:, -3:] = 0 + + ## Get remaining parameters + pred_shape = pred[:, count : count + self.num_shape_comps] + count += self.num_shape_comps + pred_scale = pred[:, count : count + self.num_scale_comps] + count += self.num_scale_comps + pred_hand = pred[:, count : count + self.num_hand_comps * 2] + count += self.num_hand_comps * 2 + pred_face = pred[:, count : count + self.num_face_comps] * 0 + count += self.num_face_comps + + # Run everything through mhr + output = self.mhr_forward( + global_trans=global_trans, + global_rot=global_rot_euler, + body_pose_params=pred_pose_euler, + hand_pose_params=pred_hand, + scale_params=pred_scale, + shape_params=pred_shape, + expr_params=pred_face, + do_pcblend=do_pcblend, + return_keypoints=True, + return_joint_coords=True, + return_model_params=True, + return_joint_rotations=True, + ) + + # Some existing code to get joints and fix camera system + verts, j3d, jcoords, mhr_model_params, joint_global_rots = output + j3d = j3d[:, :70] # 308 --> 70 keypoints + + # Intermediate decoder layers only consume pred_keypoints_3d via the + # keypoint-token update path; suppress verts so camera_project skips + # the 18439-vertex perspective projection. + if intermediate: + verts = None + if verts is not None: + verts[..., [1, 2]] *= -1 # Camera system difference + j3d[..., [1, 2]] *= -1 # Camera system difference + if jcoords is not None: + jcoords[..., [1, 2]] *= -1 + + # Head-MLP outputs are promoted to fp32 here so the external + # pose_output["mhr"] contract has a stable dtype regardless of what + # the head ran at (fp16/bf16 for speed). MHR-derived outputs are + # already fp32 from MHR's math layers; the cast on them is a no-op. 
+ output = { + "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(), + "pred_pose_rotmat": None, + "global_rot": global_rot_euler.float(), + "body_pose": pred_pose_euler.float(), + "shape": pred_shape.float(), + "scale": pred_scale.float(), + "hand": pred_hand.float(), + "face": pred_face.float(), + "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3), + "pred_vertices": verts.reshape(batch_size, -1, 3) if verts is not None else None, + "pred_joint_coords": jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None, + "faces": self.faces.cpu().numpy(), + "joint_global_rots": joint_global_rots, + "mhr_model_params": mhr_model_params, + } + + return output diff --git a/comfy/ldm/sam3d_body/mhr/mhr_rig.py b/comfy/ldm/sam3d_body/mhr/mhr_rig.py new file mode 100644 index 000000000..f5542563f --- /dev/null +++ b/comfy/ldm/sam3d_body/mhr/mhr_rig.py @@ -0,0 +1,235 @@ +# Adapted from facebookresearch/MHR (Apache 2.0): +# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py +# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas +# verbatim from the TorchScript source bundled in the upstream mhr_model.pt +# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}). +# Original Copyright (c) Meta Platforms, Inc. and affiliates. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning. + +from .mhr_utils import batch6DFromXYZ + +_LN2 = 0.6931471824645996 + +def _euler_xyz_to_quat(angles): + """(roll, pitch, yaw) -> quaternion (x, y, z, w). 
Matches pymomentum.quaternion.euler_xyz_to_quaternion.""" + roll, pitch, yaw = angles.unbind(-1) + cy, sy = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5) + cp, sp = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5) + cr, sr = torch.cos(roll * 0.5), torch.sin(roll * 0.5) + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + w = cr * cp * cy + sr * sp * sy + return torch.stack([x, y, z, w], dim=-1) + + +def _quat_multiply(q1, q2): + x1, y1, z1, w1 = q1.unbind(-1) + x2, y2, z2, w2 = q2.unbind(-1) + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + return torch.stack([x, y, z, w], dim=-1) + + +def _quat_rotate(q, v): + """Rotate v by unit quaternion q (xyzw). v + 2 * (axis x v * w + axis x (axis x v)).""" + axis = q[..., :3] + r = q[..., 3:4] + av = torch.cross(axis, v, dim=-1) + aav = torch.cross(axis, av, dim=-1) + return v + 2.0 * (av * r + aav) + + +def _skel_multiply(s1, s2): + """Compose two skel states (..., 8). Returns parent ∘ child. + + Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized + before composition. With many FK levels the previously-normalized quats + drift in ULPs; the JIT renormalizes defensively, so we do too to stay + bit-close to its outputs. + """ + t1, sc1 = s1[..., :3], s1[..., 7:8] + t2, sc2 = s2[..., :3], s2[..., 7:8] + q1 = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12) + q2 = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12) + t_res = t1 + sc1 * _quat_rotate(q1, t2) + q_res = _quat_multiply(q1, q2) + s_res = sc1 * sc2 + return torch.cat([t_res, q_res, s_res], dim=-1) + + +def _skel_transform_points(skel_state, points): + """Apply skel_state (..., 8) to points (..., 3): t + q * (s * points). + + Assumes the quaternion in skel_state is already unit-norm. Callers that + can't guarantee that should normalize first. 
+ """ + t = skel_state[..., :3] + q = skel_state[..., 3:7] + s = skel_state[..., 7:8] + return t + _quat_rotate(q, s * points) + + +def _global_skel_state_from_local(local, pmi_levels): + """FK walk in fp64 (matches the JIT's use_double_precision=True path). + + `pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs, + one per BFS level. Avoids per-call torch.split + tolist() sync. + """ + orig_dtype = local.dtype + g = local.to(torch.float64).clone() + for source, target in pmi_levels: + parent = g.index_select(-2, target) + child = g.index_select(-2, source) + g.index_copy_(-2, source, _skel_multiply(parent, child)) + return g.to(orig_dtype) + + +class MHRRig(nn.Module): + """Plain-PyTorch reimplementation of Meta's MHR rig. + + All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's + use_double_precision=True backend) regardless of the host model's dtype. + """ + + NUM_VERTS = 18439 + NUM_JOINTS = 127 + NUM_LBS_TRIPLETS = 51337 + NUM_IDENTITY = 45 + NUM_EXPR = 72 + PARAM_TRANSFORM_IN = 249 # = model_parameters(204) + identity_coeffs(45) + PARAM_TRANSFORM_OUT = 889 # = NUM_JOINTS * 7 + POSE_CORR_IN = 750 # = (NUM_JOINTS - 2) * 6 + POSE_CORR_HIDDEN = 3000 + POSE_CORR_SPARSE_NNZ = 53136 + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + del dtype, operations + f32 = torch.float32 + + # All buffers are populated by load_state_dict from the `mhr.*` keys + def _p(*shape, dtype=f32): + return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False) + def _b(name, *shape, dtype): + self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device)) + + self.base_shape = _p(self.NUM_VERTS, 3) + self.identity_basis = _p(self.NUM_IDENTITY, self.NUM_VERTS, 3) + self.expr_basis = _p(self.NUM_EXPR, self.NUM_VERTS, 3) + self.param_transform = _p(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN) + + self.skel_joint_translation_offsets = _p(self.NUM_JOINTS, 3) + 
self.skel_joint_prerotations = _p(self.NUM_JOINTS, 4) + _b("skel_joint_parents", self.NUM_JOINTS, dtype=torch.int32) + _b("skel_pmi", 2, 266, dtype=torch.int64) + _b("skel_pmi_buffer_sizes", 4, dtype=torch.int64) + + self.lbs_inverse_bind_pose = _p(self.NUM_JOINTS, 8) + self.lbs_skin_weights = _p(self.NUM_LBS_TRIPLETS) + _b("lbs_skin_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int32) + _b("lbs_vert_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int64) + + _b("pose_corr_sparse_indices", 2, self.POSE_CORR_SPARSE_NNZ, dtype=torch.int64) + self.pose_corr_sparse_weight = _p(self.POSE_CORR_SPARSE_NNZ) + + self.register_buffer("pose_corr_sparse_shape", torch.tensor([self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device)) + self.pose_corr_weight = _p(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN) + self.pose_corr_bias = None + self._pose_corr_sparse_cache = None + self._pmi_levels_cache = None + + def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True): + f32 = self.base_shape.dtype + identity_coeffs = identity_coeffs.to(f32) + model_parameters = model_parameters.to(f32) + expr_coeffs = expr_coeffs.to(f32) + B = identity_coeffs.shape[0] + + identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs) + + cat_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1) + joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, cat_in) + + jp = joint_parameters.view(B, self.NUM_JOINTS, 7) + local_t = jp[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0) + local_q = _euler_xyz_to_quat(jp[..., 3:6]) + local_q = _quat_multiply(self.skel_joint_prerotations.unsqueeze(0), local_q) + local_s = torch.exp(jp[..., 6:7] * _LN2) + local_state = torch.cat([local_t, local_q, local_s], dim=-1) + + skel_state = _global_skel_state_from_local(local_state, self._pmi_levels()) + + face_expr = torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs) + unposed = 
identity_rest + face_expr + if apply_correctives: + unposed = unposed + self._pose_correctives(joint_parameters) + + verts = self._skin(skel_state, unposed) + return verts, skel_state + + def _pose_correctives(self, joint_parameters): + B = joint_parameters.shape[0] + jp = joint_parameters.view(B, self.NUM_JOINTS, 7) + # Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6). + feat = batch6DFromXYZ(jp[:, 2:, 3:6], return_9D=False) # (B, 125, 6) + feat[..., 0] -= 1.0 + feat[..., 4] -= 1.0 + feat = feat.flatten(1, 2) # (B, 750) + + h = (self._sparse_w() @ feat.T).T # (B, 3000) + h = F.relu(h) + out = F.linear(h, self.pose_corr_weight, self.pose_corr_bias) # (B, 55317) + return out.view(B, self.NUM_VERTS, 3) + + def _pmi_levels(self): + cached = self._pmi_levels_cache + pmi = self.skel_pmi + if cached is not None and cached[0][0].device == pmi.device: + return cached + sizes = self.skel_pmi_buffer_sizes.tolist() + parts = torch.split(pmi, sizes, dim=1) + levels = [(p[0], p[1]) for p in parts] + self._pmi_levels_cache = levels + return levels + + def _sparse_w(self): + cached = self._pose_corr_sparse_cache + w = self.pose_corr_sparse_weight + if cached is not None and cached.device == w.device and cached.dtype == w.dtype: + return cached + sparse = torch.sparse_coo_tensor( + self.pose_corr_sparse_indices, + w, + tuple(self.pose_corr_sparse_shape.tolist()), + check_invariants=False, + ).coalesce() + self._pose_corr_sparse_cache = sparse + return sparse + + def _skin(self, skel_state, rest_verts): + B = skel_state.shape[0] + ibp = self.lbs_inverse_bind_pose.unsqueeze(0).expand(B, self.NUM_JOINTS, 8) + joint_xform = _skel_multiply(skel_state, ibp) + + norm_q = F.normalize(joint_xform[..., 3:7], p=2, dim=-1, eps=1e-12) + joint_xform = torch.cat([joint_xform[..., :3], norm_q, joint_xform[..., 7:8]], dim=-1) + + sk_idx = self.lbs_skin_indices.long() + v_idx = self.lbs_vert_indices + w = self.lbs_skin_weights + + per_triplet_xform = 
joint_xform.index_select(-2, sk_idx) # (B, 51337, 8) + per_triplet_rest = rest_verts.index_select(-2, v_idx) # (B, 51337, 3) + contrib = _skel_transform_points(per_triplet_xform, per_triplet_rest) * w.unsqueeze(0).unsqueeze(-1) + + out = torch.zeros(B, self.NUM_VERTS, 3, dtype=rest_verts.dtype, device=rest_verts.device) + out.index_add_(-2, v_idx, contrib) + return out diff --git a/comfy/ldm/sam3d_body/mhr/mhr_utils.py b/comfy/ldm/sam3d_body/mhr/mhr_utils.py new file mode 100644 index 000000000..421b5fcd7 --- /dev/null +++ b/comfy/ldm/sam3d_body/mhr/mhr_utils.py @@ -0,0 +1,413 @@ +# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers +# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity +# representation from Zhou et al., "On the Continuity of Rotation +# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035), +# implementations from papagina/RotationContinuity: +# https://github.com/papagina/RotationContinuity/blob/758b0ce5/shapenet/code/tools.py +# The compact_cont_to_model_params_* functions are MHR-rig-specific glue. + +import torch +import torch.nn.functional as F + + +def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Compute the angle difference (magnitude) between two batches of SO(3) rotation matrices. + Args: + A: Tensor of shape (*, 3, 3), batch of rotation matrices. + B: Tensor of shape (*, 3, 3), batch of rotation matrices. + Returns: + Tensor of shape (*,), angle differences in radians. 
+ """ + # Compute relative rotation matrix + R_rel = torch.matmul(A, B.transpose(-2, -1)) # (B, 3, 3) + # Compute trace of relative rotation + trace = R_rel[..., 0, 0] + R_rel[..., 1, 1] + R_rel[..., 2, 2] # (B,) + # Compute angle using the trace formula + cos_theta = (trace - 1) / 2 + # Clamp for numerical stability + cos_theta_clamped = torch.clamp(cos_theta, -1.0, 1.0) + # Compute angle difference + angle = torch.acos(cos_theta_clamped) + return angle + + +def fix_wrist_euler( + wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5) +): + """ + wrist_xzy: B x 2 x 3 (X, Z, Y angles) + Returns: Fixed angles within joint limits + """ + x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2] + + x_alt = torch.atan2(torch.sin(x + torch.pi), torch.cos(x + torch.pi)) + z_alt = torch.atan2(torch.sin(-(z + torch.pi)), torch.cos(-(z + torch.pi))) + y_alt = torch.atan2(torch.sin(y + torch.pi), torch.cos(y + torch.pi)) + + # Calculate L2 violation distance + def calc_violation(val, limits): + below = torch.clamp(limits[0] - val, min=0.0) + above = torch.clamp(val - limits[1], min=0.0) + return below**2 + above**2 + + violation_orig = ( + calc_violation(x, limits_x) + + calc_violation(z, limits_z) + + calc_violation(y, limits_y) + ) + + violation_alt = ( + calc_violation(x_alt, limits_x) + + calc_violation(z_alt, limits_z) + + calc_violation(y_alt, limits_y) + ) + + # Use alternative where it has lower L2 violation + use_alt = violation_alt < violation_orig + + # Stack alternative and apply mask + wrist_xzy_alt = torch.stack([x_alt, z_alt, y_alt], dim=-1) + result = torch.where(use_alt.unsqueeze(-1), wrist_xzy_alt, wrist_xzy) + + return result + + +# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py +def batch6DFromXYZ(r, return_9D=False): + """ + Generate a matrix representing a rotation defined by a XYZ-Euler + rotation. + + Args: + r: ... 
x 3 rotation vectors + + Returns: + ... x 6 + """ + rc = torch.cos(r) + rs = torch.sin(r) + cx = rc[..., 0] + cy = rc[..., 1] + cz = rc[..., 2] + sx = rs[..., 0] + sy = rs[..., 1] + sz = rs[..., 2] + + result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype).to(r.device) + + result[..., 0, 0] = cy * cz + result[..., 0, 1] = -cx * sz + sx * sy * cz + result[..., 0, 2] = sx * sz + cx * sy * cz + result[..., 1, 0] = cy * sz + result[..., 1, 1] = cx * cz + sx * sy * sz + result[..., 1, 2] = -sx * cz + cx * sy * sz + result[..., 2, 0] = -sy + result[..., 2, 1] = sx * cy + result[..., 2, 2] = cx * cy + + if not return_9D: + return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1) + else: + return result + + +# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82 +def batchXYZfrom6D(poses): + # Args: poses: ... x 6, where "6" is the combined first and second columns + # First, get the rotaiton matrix + x_raw = poses[..., :3] + y_raw = poses[..., 3:] + + x = F.normalize(x_raw, dim=-1) + z = torch.cross(x, y_raw, dim=-1) + z = F.normalize(z, dim=-1) + y = torch.cross(z, x, dim=-1) + + matrix = torch.stack([x, y, z], dim=-1) # ... 
x 3 x 3 + + # Now get it into euler + # https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412 + sy = torch.sqrt( + matrix[..., 0, 0] * matrix[..., 0, 0] + matrix[..., 1, 0] * matrix[..., 1, 0] + ) + singular = sy < 1e-6 + singular = singular.float() + + x = torch.atan2(matrix[..., 2, 1], matrix[..., 2, 2]) + y = torch.atan2(-matrix[..., 2, 0], sy) + z = torch.atan2(matrix[..., 1, 0], matrix[..., 0, 0]) + + xs = torch.atan2(-matrix[..., 1, 2], matrix[..., 1, 1]) + ys = torch.atan2(-matrix[..., 2, 0], sy) + zs = matrix[..., 1, 0] * 0 + + out_euler = torch.zeros_like(matrix[..., 0]) + out_euler[..., 0] = x * (1 - singular) + xs * singular + out_euler[..., 1] = y * (1 - singular) + ys * singular + out_euler[..., 2] = z * (1 - singular) + zs * singular + + return out_euler + + +_HAND_DOFS = [3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1] +_HAND_MASK_CACHE: dict = {} # device -> dict of masks + + +def _hand_masks(device): + m = _HAND_MASK_CACHE.get(device) + if m is not None: + return m + mask_cont_threedofs = torch.cat( + [torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS] + ).to(device) + mask_cont_onedofs = torch.cat( + [torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS] + ).to(device) + mask_model_params_threedofs = torch.cat( + [torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS] + ).to(device) + mask_model_params_onedofs = torch.cat( + [torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS] + ).to(device) + m = dict( + mask_cont_threedofs=mask_cont_threedofs, + mask_cont_onedofs=mask_cont_onedofs, + mask_model_params_threedofs=mask_model_params_threedofs, + mask_model_params_onedofs=mask_model_params_onedofs, + ) + _HAND_MASK_CACHE[device] = m + return m + + +def compact_cont_to_model_params_hand(hand_cont): + # These are ordered by joint, not model params ^^ + assert hand_cont.shape[-1] == 54 + m = _hand_masks(hand_cont.device) + 
mask_cont_threedofs = m["mask_cont_threedofs"] + mask_cont_onedofs = m["mask_cont_onedofs"] + mask_model_params_threedofs = m["mask_model_params_threedofs"] + mask_model_params_onedofs = m["mask_model_params_onedofs"] + + # Convert hand_cont to eulers + ## First for 3DoFs + hand_cont_threedofs = hand_cont[..., mask_cont_threedofs].unflatten(-1, (-1, 6)) + hand_model_params_threedofs = batchXYZfrom6D(hand_cont_threedofs).flatten(-2, -1) + ## Next for 1DoFs + hand_cont_onedofs = hand_cont[..., mask_cont_onedofs].unflatten( + -1, (-1, 2) + ) # (sincos) + hand_model_params_onedofs = torch.atan2( + hand_cont_onedofs[..., -2], hand_cont_onedofs[..., -1] + ) + + # Finally, assemble into a 27-dim vector, ordered by joint, then XYZ. + hand_model_params = torch.zeros(*hand_cont.shape[:-1], 27).to(hand_cont) + hand_model_params[..., mask_model_params_threedofs] = hand_model_params_threedofs + hand_model_params[..., mask_model_params_onedofs] = hand_model_params_onedofs + + return hand_model_params + + +def compact_model_params_to_cont_hand(hand_model_params): + # These are ordered by joint, not model params ^^ + assert hand_model_params.shape[-1] == 27 + hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]) + assert sum(hand_dofs_in_order) == 27 + # Mask of 3DoFs into hand_cont + mask_cont_threedofs = torch.cat( + [torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order] + ) + # Mask of 1DoFs (including 2DoF) into hand_cont + mask_cont_onedofs = torch.cat( + [torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order] + ) + # Mask of 3DoFs into hand_model_params + mask_model_params_threedofs = torch.cat( + [torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order] + ) + # Mask of 1DoFs (including 2DoF) into hand_model_params + mask_model_params_onedofs = torch.cat( + [torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order] + ) + + # Convert eulers to hand_cont + ## First for 3DoFs +
hand_model_params_threedofs = hand_model_params[ + ..., mask_model_params_threedofs + ].unflatten(-1, (-1, 3)) + hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1) + ## Next for 1DoFs + hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs] + hand_cont_onedofs = torch.stack( + [hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1 + ).flatten(-2, -1) + + # Finally, assemble into a 54-dim continuous vector, ordered by joint. + hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params) + hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs + hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs + + return hand_cont + + +def batch9Dfrom6D(poses): + # Args: poses: ... x 6, where "6" is the combined first and second columns + # First, get the rotation matrix + x_raw = poses[..., :3] + y_raw = poses[..., 3:] + + x = F.normalize(x_raw, dim=-1) + z = torch.cross(x, y_raw, dim=-1) + z = F.normalize(z, dim=-1) + y = torch.cross(z, x, dim=-1) + + matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9 + + return matrix + + +def batch4Dfrom2D(poses): + # Args: poses: ... x 2, where "2" is sincos + poses_norm = F.normalize(poses, dim=-1) + + poses_4d = torch.stack( + [ + poses_norm[..., 1], + poses_norm[..., 0], + -poses_norm[..., 0], + poses_norm[..., 1], + ], + dim=-1, + ) # Flattened SO2. + + return poses_4d # ....
x 4 + + +def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False): + # fmt: off + all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]) + all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]) + all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]) + # fmt: on + num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 + num_1dof_angles = len(all_param_1dof_rot_idxs) + num_1dof_trans = len(all_param_1dof_trans_idxs) + assert body_pose_cont.shape[-1] == ( + 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans + ) + # Get subsets + body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles] + body_cont_1dofs = body_pose_cont[ + ..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles + ] + body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :] + # Convert conts to model params + ## First for 3dofs + body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6)) + body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1) + ## Next for 1dofs + body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos) + body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1) + if inflate_trans: + assert ( + False + ), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots" + else: + ## Nothing to do for trans + body_rotmat_trans = body_cont_trans + # Put them together + body_rotmat_params = 
torch.cat( + [body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1 + ) + return body_rotmat_params + + +_BODY_IDX_CACHE: dict = {} + + +def _body_idxs(device): + cached = _BODY_IDX_CACHE.get(device) + if cached is not None: + return cached + # fmt: off + all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]).to(device) + all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]).to(device) + all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]).to(device) + # fmt: on + cached = ( + all_param_3dof_rot_idxs, + all_param_1dof_rot_idxs, + all_param_1dof_trans_idxs, + all_param_3dof_rot_idxs.flatten(), + ) + _BODY_IDX_CACHE[device] = cached + return cached + + +def compact_cont_to_model_params_body(body_pose_cont): + (all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device) + num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 + num_1dof_angles = len(all_param_1dof_rot_idxs) + num_1dof_trans = len(all_param_1dof_trans_idxs) + assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans + # Get subsets + body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles] + body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles] + body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :] + # Convert conts to model params + ## First for 
3dofs + body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6)) + body_params_3dofs = batchXYZfrom6D(body_cont_3dofs).flatten(-2, -1) + ## Next for 1dofs + body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos) + body_params_1dofs = torch.atan2(body_cont_1dofs[..., -2], body_cont_1dofs[..., -1]) + ## Nothing to do for trans + body_params_trans = body_cont_trans + # Put them together + body_pose_params = torch.zeros(*body_pose_cont.shape[:-1], 133, dtype=body_pose_cont.dtype, device=body_pose_cont.device) + body_pose_params[..., idxs_3dof_flat] = body_params_3dofs + body_pose_params[..., all_param_1dof_rot_idxs] = body_params_1dofs + body_pose_params[..., all_param_1dof_trans_idxs] = body_params_trans + return body_pose_params + + +def compact_model_params_to_cont_body(body_pose_params): + # fmt: off + all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]) + all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]) + all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]) + # fmt: on + num_3dof_angles = len(all_param_3dof_rot_idxs) * 3 + num_1dof_angles = len(all_param_1dof_rot_idxs) + num_1dof_trans = len(all_param_1dof_trans_idxs) + assert body_pose_params.shape[-1] == ( + num_3dof_angles + num_1dof_angles + num_1dof_trans + ) + # Take out params + body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()] + body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs] 
+ body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs] + # params to cont + body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten( + -2, -1 + ) + body_cont_1dofs = torch.stack( + [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1 + ).flatten(-2, -1) + body_cont_trans = body_params_trans + # Put them together + body_pose_cont = torch.cat( + [body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1 + ) + return body_pose_cont + + +# fmt: off +mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115] +mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237] +mhr_param_hand_mask = torch.zeros(133).bool() +mhr_param_hand_mask[mhr_param_hand_idxs] = True +mhr_cont_hand_mask = torch.zeros(260).bool() +mhr_cont_hand_mask[mhr_cont_hand_idxs] = True +# fmt: on diff --git a/comfy/ldm/sam3d_body/model/camera_modules.py b/comfy/ldm/sam3d_body/model/camera_modules.py new file mode 100644 index 000000000..0faff8534 --- /dev/null +++ b/comfy/ldm/sam3d_body/model/camera_modules.py @@ -0,0 +1,155 @@ +import math + +import einops +import torch +import torch.nn.functional as F +from comfy.ldm.cascade.common import LayerNorm2d_op +from torch import nn + +from typing import List, Optional, Tuple, Union + +from ..utils import perspective_projection +from .transformer import MLP + +class CameraEncoder(nn.Module): + def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None): + 
super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64) + + self.conv = operations.Conv2d(embed_dim + 99, embed_dim, kernel_size=1, bias=False, device=device, dtype=dtype) + self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype) + + def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor): + B, D, _h, _w = img_embeddings.shape + + scale = 1 / self.patch_size + rays = F.interpolate(rays, scale_factor=(scale, scale), mode="bilinear", align_corners=False, antialias=True) + rays = rays.permute(0, 2, 3, 1).contiguous() # [b, h, w, 2] + rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1) + rays_embeddings = self.camera(pos=rays.reshape(B, -1, 3)) # (bs, N, 99): rays fourier embedding + rays_embeddings = einops.rearrange(rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w).contiguous() + + z = torch.cat([img_embeddings, rays_embeddings], dim=1) + return self.norm(self.conv(z)) + + +class FourierPositionEncoding(nn.Module): + """Sin/cos Fourier features for ray positions""" + + def __init__(self, n: int, num_bands: int, max_resolution: int): + super().__init__() + self.num_bands = num_bands + self.max_resolution = [max_resolution] * n + + @property + def channels(self): + num_dims = len(self.max_resolution) + encoding_size = self.num_bands * num_dims + encoding_size *= 2 # sin-cos + encoding_size += num_dims # concat + + return encoding_size + + def forward(self, pos: torch.Tensor): + fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution) + return fourier_pos_enc + + +def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0): + b, n = pos.shape[:2] + + freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype) for res in max_resolution], dim=0) + + 
per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0) + per_pos_features = per_pos_features.reshape(b, n, -1) + + # Sin-Cos + per_pos_features = torch.cat([torch.sin(math.pi * per_pos_features), torch.cos(math.pi * per_pos_features)], dim=-1) + + # Concat with initial pos + per_pos_features = torch.cat([pos, per_pos_features], dim=-1) + + return per_pos_features + + +class PerspectiveHead(nn.Module): + """ + Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup). + """ + + def __init__(self, input_dim: int, img_size: Union[int, Tuple[int, int]], # model input size (W, H) + mlp_depth: int = 1, mlp_channel_div_factor: int = 8, default_scale_factor: float = 1.0, + device=None, dtype=None, operations=None + ): + super().__init__() + + # Metadata to compute 3D skeleton and 2D reprojection + self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) + self.ncam = 3 # (s, tx, ty) + self.default_scale_factor = default_scale_factor + + self.proj = MLP( + input_dim=input_dim, + hidden_dim=input_dim // mlp_channel_div_factor, + output_dim=self.ncam, + num_layers=mlp_depth, + device=device, + dtype=dtype, + operations=operations, + ) + + def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None): + """ + Args: + x: pose token with shape [B, C], usually C=DECODER.DIM + init_estimate: [B, self.ncam] + """ + pred_cam = self.proj(x) + if init_estimate is not None: + pred_cam = pred_cam + init_estimate + + return pred_cam + + def perspective_projection( + self, + points_3d: torch.Tensor, + pred_cam: torch.Tensor, + bbox_center: torch.Tensor, # [N, 2], in original image space (w, h) + bbox_size: torch.Tensor, # [N,], in original image space + img_size: torch.Tensor, + cam_int: torch.Tensor, # [B, 3, 3] + use_intrin_center: bool = False, + ): + batch_size = points_3d.shape[0] + pred_cam = pred_cam.clone() + pred_cam[..., [0, 2]] *= -1 # 
Camera system difference + + # Compute camera translation: (scale, x, y) --> (x, y, depth) + # depth ~= f / s, Note that f is in the NDC space + s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2] + bs = bbox_size * s * self.default_scale_factor + 1e-8 + focal_length = cam_int[:, 0, 0] + tz = 2 * focal_length / bs + + if not use_intrin_center: + cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs + cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs + else: + cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs + cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs + + pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) + + # Compute camera translation + j3d_cam = points_3d + pred_cam_t.unsqueeze(1) + + # Projection to the image plane, note that the projection output is in original image space now. + j2d = perspective_projection(j3d_cam, cam_int) + + return { + "pred_keypoints_2d": j2d.reshape(batch_size, -1, 2), + "pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2], + "pred_cam_t": pred_cam_t, "focal_length": focal_length, + } diff --git a/comfy/ldm/sam3d_body/model/dinov3.py b/comfy/ldm/sam3d_body/model/dinov3.py new file mode 100644 index 000000000..97637b2f9 --- /dev/null +++ b/comfy/ldm/sam3d_body/model/dinov3.py @@ -0,0 +1,250 @@ +# DINOv3 ViT-H+ backbone for SAM 3D Body. +# +# Single-file consolidation of the inference path. SAM 3D Body only ships a +# `dinov3_vith16plus` checkpoint, so the architecture is hardcoded rather +# than reconstructed from Hydra-flavoured configs. +# +# Adapted from facebookresearch/dinov3 (DINOv3 License Agreement). Trimmed +# to what's actually exercised at inference: no multi-crop training path, +# no DINOHead, no causal blocks, no rmsnorm/Mlp variants, no rope shift / +# jitter / rescale (training-time augmentations). 
+ +#TODO: Unify with TRELLIS2 + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention +from torch import Tensor, nn + +# DINOv3 ViT-H+ architecture constants. +EMBED_DIM = 1280 +DEPTH = 32 +NUM_HEADS = 20 +FFN_RATIO = 6.0 +PATCH_SIZE = 16 +LAYERSCALE_INIT = 1.0e-5 +N_STORAGE_TOKENS = 4 +LAYERNORM_EPS = 1e-5 # "layernormbf16" preset uses 1e-5 +ROPE_BASE = 100.0 + +# RoPE (axial sin/cos, no learnable weights) + +def _rotate_half(x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + +def _apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + return x * cos + _rotate_half(x) * sin + + +class RopePositionEmbedding(nn.Module): + """Axial RoPE for 2D patch grids; periods buffer is deterministic.""" + + def __init__(self, embed_dim: int, num_heads: int, dtype=torch.float32, device=None): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + D_head = embed_dim // num_heads + # Periods are persistent so they round-trip through state_dict, but the + # values are deterministic from D_head/base; load_state_dict will + # overwrite this with the saved buffer either way. 
+ periods = ROPE_BASE ** ( + 2 * torch.arange(D_head // 4, dtype=dtype, device=device) / (D_head // 2) + ) + self.register_buffer("periods", periods, persistent=True) + self._dtype = dtype + + def forward(self, H: int, W: int) -> Tuple[Tensor, Tensor]: + device, dtype = self.periods.device, self._dtype + coords_h = torch.arange(0.5, H, device=device, dtype=dtype) / H + coords_w = torch.arange(0.5, W, device=device, dtype=dtype) / W + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = 2.0 * coords.flatten(0, 1) - 1.0 # [HW, 2] in [-1, +1] + angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + angles = angles.flatten(1, 2).tile(2) # [HW, D_head] + return torch.sin(angles), torch.cos(angles) + + +def _apply_rope_to_qk(q: Tensor, k: Tensor, rope: Tuple[Tensor, Tensor]): + """Apply RoPE only to the patch-token slice (skip CLS + storage tokens).""" + sin, cos = rope + rope_dtype = sin.dtype + q_dtype, k_dtype = q.dtype, k.dtype + q = q.to(rope_dtype) + k = k.to(rope_dtype) + prefix = q.shape[-2] - sin.shape[-2] + q_pre, q_rope = q[..., :prefix, :], q[..., prefix:, :] + k_pre, k_rope = k[..., :prefix, :], k[..., prefix:, :] + q = torch.cat([q_pre, _apply_rope(q_rope, sin, cos)], dim=-2) + k = torch.cat([k_pre, _apply_rope(k_rope, sin, cos)], dim=-2) + return q.to(q_dtype), k.to(k_dtype) + +# Layers + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values: float, device=None, dtype=None): + super().__init__() + self.gamma = nn.Parameter( + torch.full((dim,), init_values, device=device, dtype=dtype) + ) + + def forward(self, x: Tensor) -> Tensor: + return x * self.gamma + + +class SwiGLUFFN(nn.Module): + """w3(silu(w1(x)) * w2(x)).""" + + def __init__(self, in_features: int, hidden_features: int, align_to: int = 8, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + d = int(hidden_features * 2 / 3) + h = d + (-d % align_to) + 
self.w1 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype) + self.w2 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype) + self.w3 = ops.Linear(h, in_features, bias=True, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + return self.w3(F.silu(self.w1(x)) * self.w2(x)) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.num_heads = num_heads + # DINOv3's `mask_k_bias` zeroes the K third of qkv.bias. The mask is + # deterministic from out_features, so the loader applies it in-place + # once after load_state_dict (see `apply_dinov3_qkv_bias_mask`) and the + # forward stays a plain F.linear. + self.qkv = ops.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + def forward(self, x: Tensor, rope: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = qkv.unbind(2) + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) + if rope is not None: + q, k = _apply_rope_to_qk(q, k, rope) + # low_precision_attention=False forces attention_sage (when enabled + # globally in comfy) to fall back to pytorch SDPA. SAM 3D Body's + # regression heads (camera projection, MHR rig math) are sensitive + # to attention output precision; sage's int8/fp8 path drifts the + # keypoints and mesh visibly. 
+ x = optimized_attention( + q, k, v, self.num_heads, skip_reshape=True, + low_precision_attention=False, + ) + return self.proj(x) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, ffn_ratio: float, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.norm1 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype) + self.attn = SelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations) + self.ls1 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype) + self.norm2 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype) + self.mlp = SwiGLUFFN(dim, int(dim * ffn_ratio), device=device, dtype=dtype, operations=operations) + self.ls2 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype) + + def forward(self, x: Tensor, rope=None) -> Tensor: + x = x + self.ls1(self.attn(self.norm1(x), rope=rope)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans=3, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.proj = ops.Conv2d( + in_chans, embed_dim, + kernel_size=patch_size, stride=patch_size, + device=device, dtype=dtype, + ) + +# Encoder + wrapper + +class _DinoEncoder(nn.Module): + """Inner ViT module. 
Held under `Dinov3Backbone.encoder` so state_dict + keys (`backbone.encoder.*`) match the upstream layout.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.patch_size = PATCH_SIZE + self.embed_dim = EMBED_DIM + + self.patch_embed = PatchEmbed( + embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, + device=device, dtype=dtype, operations=operations, + ) + self.cls_token = nn.Parameter(torch.empty(1, 1, EMBED_DIM, device=device, dtype=dtype)) + self.storage_tokens = nn.Parameter( + torch.empty(1, N_STORAGE_TOKENS, EMBED_DIM, device=device, dtype=dtype) + ) + # The released config sets pos_embed_rope_dtype="fp32"; periods stays + # in fp32 regardless of the backbone weight dtype. + self.rope_embed = RopePositionEmbedding(EMBED_DIM, NUM_HEADS, dtype=torch.float32, device=device) + + self.blocks = nn.ModuleList([ + Block(EMBED_DIM, NUM_HEADS, FFN_RATIO, device=device, dtype=dtype, operations=operations) + for _ in range(DEPTH) + ]) + self.norm = ops.LayerNorm(EMBED_DIM, eps=LAYERNORM_EPS, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + x = self.patch_embed.proj(x) # (B, embed_dim, H, W) + B, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) # (B, H*W, embed_dim) + + # Prepend CLS + storage tokens. + x = torch.cat([ + self.cls_token.expand(B, -1, -1), + self.storage_tokens.expand(B, -1, -1), + x, + ], dim=1) + + rope = self.rope_embed(H=H, W=W) + for blk in self.blocks: + x = blk(x, rope) + x = self.norm(x) + + # Drop CLS + storage tokens; reshape patch grid to (B, C, H, W). 
+ x = x[:, 1 + N_STORAGE_TOKENS :] + return x.reshape(B, H, W, EMBED_DIM).permute(0, 3, 1, 2).contiguous() + + +class Dinov3Backbone(nn.Module): + """Public backbone interface used by SAM3DBody.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = _DinoEncoder(device=device, dtype=dtype, operations=operations) + self.patch_size = PATCH_SIZE + self.embed_dim = self.embed_dims = EMBED_DIM + + def forward(self, x: Tensor) -> Tensor: + return self.encoder(x) + + +def apply_dinov3_qkv_bias_mask(backbone: "Dinov3Backbone") -> None: + """Zero the K third of every block's qkv.bias in-place. + + Implements DINOv3's `mask_k_bias` once at load time so the per-block forward + stays a plain F.linear instead of cloning + slicing the bias every call. + """ + for blk in backbone.encoder.blocks: + qkv = blk.attn.qkv + if qkv.bias is not None: + o = qkv.out_features + qkv.bias.data[o // 3 : 2 * o // 3] = 0 diff --git a/comfy/ldm/sam3d_body/model/model.py b/comfy/ldm/sam3d_body/model/model.py new file mode 100644 index 000000000..dbc4b8a3e --- /dev/null +++ b/comfy/ldm/sam3d_body/model/model.py @@ -0,0 +1,1197 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +from comfy.ldm.sam3.sam import PositionEmbeddingRandom + +from .dinov3 import Dinov3Backbone +from .prompt import PromptEncoder, PromptableDecoder +from ..mhr.mhr_head import MHRHead +from ..mhr.mhr_rig import MHRRig +from ..mhr.mhr_utils import fix_wrist_euler, rotation_angle_difference +from .transformer import MLP +from .camera_modules import CameraEncoder, PerspectiveHead +from comfy_extras.mediapipe.face_landmarker import FaceLandmarker +from ..utils import bbox_xyxy2cs, fix_aspect_ratio, get_warp_matrices, warp_affine_batched, euler_to_rotmat, rotmat_to_euler + +# Architecture constants for the released `dinov3-h+` SAM 3D Body checkpoint. 
+IMAGE_SIZE = (512, 512) +IMAGE_MEAN = (0.485, 0.456, 0.406) +IMAGE_STD = (0.229, 0.224, 0.225) +DECODER_DIM = 1024 +DECODER_DEPTH = 6 +DECODER_HEADS = 8 +DECODER_DIM_HEAD = 64 +DECODER_MLP_DIM = 1024 +MHR_MLP_DEPTH = 2 +CAMERA_MLP_DEPTH = 2 +CAMERA_DEFAULT_SCALE_FACTOR_HAND = 10.0 +N_KEYPOINTS = 70 # mhr70 + + +class SAM3DBody(nn.Module): + pelvis_idx = [9, 10] # left_hip, right_hip + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + # `operations` falls back to torch.nn so the model is constructible + # without comfy.ops; matches the pattern in comfy/ldm/sam3/. + ops = operations if operations is not None else nn + + # Per-batch state populated by `_initialize_batch`. + self._max_num_person = None + self._person_valid = None + + self.register_buffer("image_mean", torch.tensor(IMAGE_MEAN).view(-1, 1, 1), False) + self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False) + + self.image_size = IMAGE_SIZE + + self.backbone = Dinov3Backbone(device=device, dtype=dtype, operations=operations) + embed_dims = self.backbone.embed_dims + + # MHR rig shared between body + hand pose heads via a non-registered + # Python ref, so state_dict has one top-level `mhr.*` key tree (not + # duplicated under `head_pose.mhr.*` AND `head_pose_hand.mhr.*`). 
+ self.mhr = MHRRig(device=device) + + head_kwargs = dict( + input_dim=DECODER_DIM, + mlp_depth=MHR_MLP_DEPTH, + mhr_rig=self.mhr, + mlp_channel_div_factor=1, + device=device, dtype=dtype, operations=operations, + ) + self.head_pose = MHRHead(**head_kwargs) + self.head_pose.hand_pose_comps_ori = nn.Parameter( + self.head_pose.hand_pose_comps.clone(), requires_grad=False + ) + self.head_pose.hand_pose_comps.data = ( + torch.eye(54).to(self.head_pose.hand_pose_comps.data).float() + ) + self.init_pose = ops.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) + + self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs) + self.head_pose_hand.hand_pose_comps_ori = nn.Parameter( + self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False + ) + self.head_pose_hand.hand_pose_comps.data = ( + torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float() + ) + self.init_pose_hand = ops.Embedding( + 1, self.head_pose_hand.npose, device=device, dtype=dtype + ) + + camera_kwargs = dict( + input_dim=DECODER_DIM, + img_size=IMAGE_SIZE, + mlp_depth=CAMERA_MLP_DEPTH, + mlp_channel_div_factor=1, + device=device, dtype=dtype, operations=operations, + ) + self.head_camera = PerspectiveHead(**camera_kwargs) + self.init_camera = ops.Embedding(1, self.head_camera.ncam, device=device, dtype=dtype) + + self.head_camera_hand = PerspectiveHead(default_scale_factor=CAMERA_DEFAULT_SCALE_FACTOR_HAND, **camera_kwargs) + self.init_camera_hand = ops.Embedding(1, self.head_camera_hand.ncam, device=device, dtype=dtype) + + cond_dim = 3 + init_dim = self.head_pose.npose + self.head_camera.ncam + cond_dim + linear_kwargs = dict(device=device, dtype=dtype) + self.init_to_token_mhr = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) + self.prev_to_token_mhr = ops.Linear(init_dim - cond_dim, DECODER_DIM, **linear_kwargs) + self.init_to_token_mhr_hand = ops.Linear(init_dim, DECODER_DIM, **linear_kwargs) + self.prev_to_token_mhr_hand = ops.Linear(init_dim - cond_dim, 
DECODER_DIM, **linear_kwargs) + + self.prompt_encoder = PromptEncoder( + embed_dim=embed_dims, # match backbone dims so PE adds directly + num_body_joints=N_KEYPOINTS, + device=device, dtype=dtype, operations=operations, + ) + self.prompt_to_token = ops.Linear(embed_dims, DECODER_DIM, **linear_kwargs) + + decoder_kwargs = dict( + dims=DECODER_DIM, + context_dims=embed_dims, + depth=DECODER_DEPTH, + num_heads=DECODER_HEADS, + head_dims=DECODER_DIM_HEAD, + mlp_dims=DECODER_MLP_DIM, + repeat_pe=True, + do_interm_preds=True, + keypoint_token_update="v2", + device=device, dtype=dtype, operations=operations, + ) + self.decoder = PromptableDecoder(**decoder_kwargs) + self.decoder_hand = PromptableDecoder(**decoder_kwargs) + self.hand_pe_layer = PositionEmbeddingRandom(embed_dims // 2) + + # Inference-time dtype set by the Loader via model.backbone.to(dtype). + self.backbone_dtype = torch.float32 + + ray_kwargs = dict( + embed_dim=embed_dims, patch_size=self.backbone.patch_size, + device=device, dtype=dtype, operations=operations, + ) + self.ray_cond_emb = CameraEncoder(**ray_kwargs) + self.ray_cond_emb_hand = CameraEncoder(**ray_kwargs) + + self.keypoint_embedding_idxs = list(range(N_KEYPOINTS)) + self.keypoint_embedding_idxs_hand = list(range(N_KEYPOINTS)) + self.keypoint_embedding = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + self.keypoint_embedding_hand = ops.Embedding(N_KEYPOINTS, DECODER_DIM, **linear_kwargs) + + self.hand_box_embedding = ops.Embedding(2, DECODER_DIM, **linear_kwargs) + self.hand_cls_embed = ops.Linear(DECODER_DIM, 2, **linear_kwargs) + self.bbox_embed = MLP( + input_dim=DECODER_DIM, hidden_dim=DECODER_DIM, + output_dim=4, num_layers=3, + device=device, dtype=dtype, operations=operations, + ) + + posemb_kwargs = dict( + hidden_dim=DECODER_DIM, output_dim=DECODER_DIM, num_layers=2, + device=device, dtype=dtype, operations=operations, + ) + self.keypoint_posemb_linear = MLP(input_dim=2, **posemb_kwargs) + self.keypoint_posemb_linear_hand 
def data_preprocess(self, inputs: torch.Tensor) -> torch.Tensor:
    """Normalize an image batch with ``self.image_mean`` / ``self.image_std``.

    Before normalizing, the value range of ``inputs`` is aligned with the
    range the stored mean is expressed in: inputs whose maximum exceeds 1
    are divided by 255 when the mean's maximum is <= 1, and inputs whose
    maximum is <= 1 are multiplied by 255 when the mean's maximum exceeds 1.

    Args:
        inputs: Image tensor that broadcasts against ``self.image_mean`` and
            ``self.image_std``.

    Returns:
        ``(inputs - image_mean) / image_std`` after the optional rescale.
    """
    # Every tensor reduction used in a Python ``if`` forces a host sync on
    # an accelerator; evaluate each maximum exactly once and branch on floats.
    input_max = float(inputs.max())
    mean_max = float(self.image_mean.max())
    if input_max > 1 and mean_max <= 1.0:
        inputs = inputs / 255.0
    elif input_max <= 1.0 and mean_max > 1:
        inputs = inputs * 255.0
    return (inputs - self.image_mean) / self.image_std


def _initialize_batch(self, batch: Dict) -> None:
    """Cache batch bookkeeping derived from ``batch["img"]``.

    A 5-D image tensor is read as (batch, person, ...): its first two sizes
    become ``_batch_size`` / ``_max_num_person`` and ``_person_valid`` is the
    flattened ``batch["person_valid"] > 0``. Any other rank stores
    ``_max_num_person = 0`` (which makes ``_flatten_person`` a no-op) and
    ``_person_valid = None``.
    """
    img = batch["img"]
    if img.dim() == 5:
        # Sizes must be stored first: _flatten_person reads them.
        self._batch_size, self._max_num_person = img.shape[:2]
        self._person_valid = self._flatten_person(batch["person_valid"]) > 0
    else:
        self._batch_size = img.shape[0]
        self._max_num_person = 0
        self._person_valid = None


def _flatten_person(self, x: torch.Tensor) -> torch.Tensor:
    """Merge the leading (batch, person) dims of ``x`` into one.

    Returns ``x`` unchanged when ``_max_num_person`` is 0.

    Raises:
        AssertionError: if ``_max_num_person`` is still ``None``.
    """
    # Explicit raise (same exception type and message as the former
    # ``assert``) so the check survives ``python -O``.
    if self._max_num_person is None:
        raise AssertionError("No max_num_person initialized")
    if self._max_num_person:
        # reshape == view for contiguous input, but also accepts
        # non-contiguous tensors (e.g. expanded / transposed ones).
        x = x.reshape(self._batch_size * self._max_num_person, *x.shape[2:])
    return x


def _set_active_branch(self, kind: str) -> None:
    """Select which decoder handles the flattened batch.

    ``"body"`` assigns every flattened index to ``body_batch_idx`` and
    empties ``hand_batch_idx``; ``"hand"`` does the opposite.

    Raises:
        ValueError: for any other ``kind``.
    """
    total = self._batch_size * self._max_num_person
    every_idx = list(range(total))
    if kind == "body":
        self.body_batch_idx, self.hand_batch_idx = every_idx, []
    elif kind == "hand":
        self.body_batch_idx, self.hand_batch_idx = [], every_idx
    else:
        raise ValueError(f"Invalid branch kind: {kind!r}")


@staticmethod
def _concat_hand_batches(a: Dict, b: Dict) -> Dict:
    """Merge two batch dicts along dim 0 so both hands share one pass.

    For every key of ``a``: tensors present in both dicts are concatenated
    on dim 0, lists present in both are joined, and anything else is taken
    from ``a``. Keys that exist only in ``b`` are ignored.
    """
    merged = {}
    for key, left in a.items():
        right = b.get(key)
        if isinstance(left, torch.Tensor) and isinstance(right, torch.Tensor):
            merged[key] = torch.cat([left, right], dim=0)
        elif isinstance(left, list) and isinstance(right, list):
            merged[key] = left + right
        else:
            merged[key] = left
    return merged


@staticmethod
def _split_hand_output(batched: Dict, n_left: int) -> Tuple[Dict, Dict]:
    """Split a merged hand-decoder output back into (left, right).

    Only ``batched["mhr_hand"]`` is split: tensor values are sliced at
    ``n_left`` along dim 0, every other value is shared by both halves.
    Each returned dict has ``"mhr": None``; other keys of ``batched`` are
    not carried over.
    """
    left_mhr: Dict[str, Any] = {}
    right_mhr: Dict[str, Any] = {}
    for key, value in batched["mhr_hand"].items():
        if isinstance(value, torch.Tensor):
            left_mhr[key] = value[:n_left]
            right_mhr[key] = value[n_left:]
        else:
            # Non-tensor entries (arrays, None, ...) are shared as-is.
            left_mhr[key] = value
            right_mhr[key] = value
    return (
        {"mhr": None, "mhr_hand": left_mhr},
        {"mhr": None, "mhr_hand": right_mhr},
    )
Bit-exact match + for the CPU `prepare_batch` × 2 path, with the source uploaded once + and both warps issued through one batched grid_sample.""" + + device = comfy.model_management.get_torch_device() + if is_multi_image: + assert isinstance(img, list) + n = len(img) + H_src, W_src = img[0].shape[:2] + src_t = torch.stack(list(img), dim=0) + else: + n = int(left_xyxy.shape[0]) + H_src, W_src = img.shape[:2] + src_t = img.unsqueeze(0).expand(n, -1, -1, -1) + + H_out, W_out = int(self.image_size[0]), int(self.image_size[1]) + bbox_padding = 0.9 # matches transform_hand + aspect = 0.75 + + def _meta(boxes_xyxy): + centers, scales = bbox_xyxy2cs(boxes_xyxy, padding=bbox_padding) + scales = fix_aspect_ratio(scales, aspect) + scales = fix_aspect_ratio(scales, W_out / H_out) + mats = get_warp_matrices(centers, scales, (H_out, W_out)) + return centers, scales, mats + + l_centers, l_scales, l_mats = _meta(left_xyxy) + r_centers, r_scales, r_mats = _meta(right_xyxy) + + src_t = src_t.to(device, non_blocking=True).permute(0, 3, 1, 2).float() + + warped_l = warp_affine_batched(torch.flip(src_t, dims=[3]), l_mats, (H_out, W_out)) + warped_r = warp_affine_batched(src_t, r_mats, (H_out, W_out)) + # floor -> /255 matches the per-item uint8 round-trip path. + l_img = (torch.floor(warped_l).clamp_(0.0, 255.0) / 255.0).contiguous() + r_img = (torch.floor(warped_r).clamp_(0.0, 255.0) / 255.0).contiguous() + + # All-zero mask + score 0 (matches prepare_batch's masks=None path). 
def _get_decoder_condition(self, batch: Dict) -> torch.Tensor:
    """Build the CLIFF-style conditioning vector for every flattened person.

    Each row is ``((cx - img_cx) / f, (cy - img_cy) / f, b / f)`` where
    ``(cx, cy)`` is the bbox center, ``b`` is the first bbox-scale component,
    ``f`` is ``cam_int[0, 0]`` and ``(img_cx, img_cy)`` is
    ``(cam_int[0, 2], cam_int[1, 2]))``.

    Args:
        batch: Needs ``"img"`` (person count read from dim 1),
            ``"bbox_center"``, ``"bbox_scale"`` and ``"cam_int"``.

    Returns:
        Tensor with three columns per flattened person, cast to the dtype of
        ``batch["img"]``.
    """
    num_person = batch["img"].shape[1]
    centers = self._flatten_person(batch["bbox_center"])
    cx, cy = torch.chunk(centers, chunks=2, dim=-1)
    box_size = self._flatten_person(batch["bbox_scale"])[:, [0]]

    # One intrinsics matrix per image -> one per person; contiguous() so the
    # expanded tensor can be flattened.
    cam_int_per_person = self._flatten_person(
        batch["cam_int"].unsqueeze(1).expand(-1, num_person, -1, -1).contiguous()
    )
    focal_length = cam_int_per_person[:, 0, 0]
    principal_point = cam_int_per_person[:, [0, 1], [2, 2]]

    condition_info = torch.cat(
        [cx - principal_point[:, [0]], cy - principal_point[:, [1]], box_size],
        dim=-1,
    )
    # All three columns are divided by the same focal length, so a single
    # broadcast divide replaces the two in-place slice updates.
    condition_info = condition_info / focal_length.unsqueeze(-1)
    return condition_info.type(batch["img"].dtype)


@staticmethod
def _append_token_block(token_embeddings, token_augment, embedding_weight, batch_size):
    """Append one learned token block to the decoder token sequence.

    ``embedding_weight`` is cast to the dtype/device of ``token_embeddings``,
    broadcast over the batch and concatenated on dim 1; ``token_augment``
    receives a zero block of the same shape.

    Returns:
        ``(token_embeddings, token_augment, start_idx)`` where ``start_idx``
        is the dim-1 position at which the new block begins.
    """
    start_idx = token_embeddings.shape[1]
    # expand() is a stride-0 view; torch.cat copies it into the result, so
    # no intermediate (batch, K, D) tensor needs to be materialized.
    block = embedding_weight.to(token_embeddings)[None, :, :].expand(batch_size, -1, -1)
    token_embeddings = torch.cat([token_embeddings, block], dim=1)
    token_augment = torch.cat([token_augment, block.new_zeros(block.shape)], dim=1)
    return token_embeddings, token_augment, start_idx
+ image_augment_fn = self.prompt_encoder.get_dense_pe + elif branch == "hand": + init_pose_emb = self.init_pose_hand + init_camera_emb = self.init_camera_hand + init_to_token = self.init_to_token_mhr_hand + prev_to_token = self.prev_to_token_mhr_hand + ray_cond_emb = self.ray_cond_emb_hand + ray_cond_key = "ray_cond_hand" + head_pose = self.head_pose_hand + head_camera = self.head_camera_hand + keypoint_embedding = self.keypoint_embedding_hand + keypoint3d_embedding = self.keypoint3d_embedding_hand + decoder = self.decoder_hand + batch_idx = self.hand_batch_idx + # Hand decoder has its own PE layer (not the prompt encoder's). + image_augment_fn = self.hand_pe_layer + else: + raise ValueError(f"Invalid branch: {branch!r}") + + batch_size = image_embeddings.shape[0] + + # .to(image_embeddings) moves weights CPU→GPU under dynamic loading + # (they stay on CPU until first use). + if init_estimate is None: + init_pose = init_pose_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) + init_camera = init_camera_emb.weight.to(image_embeddings).expand(batch_size, -1).unsqueeze(1) + init_estimate = torch.cat([init_pose, init_camera], dim=-1) # B x 1 x (404 + 3) + + init_input = ( + torch.cat([condition_info.view(batch_size, 1, -1), init_estimate], dim=-1) + if condition_info is not None else init_estimate + ) + token_embeddings = init_to_token(init_input).view(batch_size, 1, -1) + num_pose_token = token_embeddings.shape[1] # always 1 + + image_augment, token_augment, token_mask = None, None, None + if keypoints is not None: + if prev_estimate is None: + prev_estimate = init_estimate + prev_embeddings = prev_to_token(prev_estimate).view(batch_size, 1, -1) + + # PE generated in fp32; cast back to decoder dtype. + image_augment = image_augment_fn(image_embeddings.shape[-2:]).to(image_embeddings.dtype) + + # ray_cond is fp32 from get_ray_condition; cast so CameraEncoder's + # internal cat doesn't silently promote everything back to fp32. 
+ image_embeddings = ray_cond_emb( + image_embeddings, batch[ray_cond_key].type(image_embeddings.dtype), + ) + + # Keypoints start as [0, 0, -2]. Labels select the embedding + # weight (special for -2, -1, then per joint). + prompt_embeddings, _ = self.prompt_encoder(keypoints=keypoints) + prompt_embeddings = self.prompt_to_token(prompt_embeddings) + + # Pin dtypes so a silent fp16→fp32 promotion in any branch + # (init/prev/prompt) doesn't break the index_put assigns below. + token_embeddings = torch.cat( + [token_embeddings, prev_embeddings, prompt_embeddings], dim=1, + ).to(image_embeddings.dtype) + prev_embeddings = prev_embeddings.to(image_embeddings.dtype) + prompt_embeddings = prompt_embeddings.to(image_embeddings.dtype) + + token_augment = torch.zeros_like(token_embeddings) + token_augment[:, [num_pose_token]] = prev_embeddings + token_augment[:, (num_pose_token + 1):] = prompt_embeddings + + token_embeddings, token_augment, hand_det_emb_start_idx = self._append_token_block( + token_embeddings, token_augment, self.hand_box_embedding.weight, batch_size, + ) + token_embeddings, token_augment, kps_emb_start_idx = self._append_token_block( + token_embeddings, token_augment, keypoint_embedding.weight, batch_size, + ) + token_embeddings, token_augment, kps3d_emb_start_idx = self._append_token_block( + token_embeddings, token_augment, keypoint3d_embedding.weight, batch_size, + ) + + last_layer_idx = len(decoder.layers) - 1 + def token_to_pose_output_fn(tokens, prev_pose_output, layer_idx): + pose_token = tokens[:, 0] + prev_pose = init_pose.view(batch_size, -1) + prev_camera = init_camera.view(batch_size, -1) + # Suppress vertices on non-final layers — kp-token updates only + # need keypoints, so we skip the 18439-vertex perspective projection. 
def _get_mask_prompt(self, batch, image_embeddings):
    """Return the dense mask embedding to add to ``image_embeddings``.

    Persons whose ``mask_score`` is > 0 get their mask embedding scaled by
    that score; all others get the prompt encoder's "no mask" embedding.
    """
    # Cast the mask to the embedding dtype so it matches the prompt encoder.
    x_mask = self._flatten_person(batch["mask"]).to(image_embeddings.dtype)
    mask_embeddings, no_mask_embeddings = self.prompt_encoder.get_mask_embeddings(
        x_mask, image_embeddings.shape[0], image_embeddings.shape[2:]
    )

    mask_score = (
        self._flatten_person(batch["mask_score"])
        .view(-1, 1, 1, 1)
        .to(image_embeddings.dtype)
    )
    return torch.where(
        mask_score > 0,
        mask_score * mask_embeddings.to(image_embeddings),
        no_mask_embeddings.to(image_embeddings),
    )


def _full_to_crop(
    self,
    batch: Dict,
    pred_keypoints_2d: torch.Tensor,
    batch_idx: Optional[Any] = None,
) -> torch.Tensor:
    """Map full-image 2D keypoints into crop-normalized coordinates.

    Keypoints are made homogeneous, multiplied by each person's
    ``affine_trans``, divided by ``img_size`` and shifted by -0.5.

    Args:
        batch: Needs ``"affine_trans"`` and ``"img_size"``.
        pred_keypoints_2d: Keypoints in full-image pixels, one row of
            keypoints per selected person.
        batch_idx: Optional index (callers in this file pass lists of ints)
            selecting which flattened persons the keypoints belong to;
            ``None`` uses all of them.
    """
    homogeneous = torch.cat(
        [pred_keypoints_2d, torch.ones_like(pred_keypoints_2d[:, :, [-1]])], dim=-1
    )
    affine_trans = self._flatten_person(batch["affine_trans"])
    img_size = self._flatten_person(batch["img_size"])
    if batch_idx is not None:
        affine_trans = affine_trans[batch_idx]
        img_size = img_size[batch_idx]
    affine_trans = affine_trans.to(homogeneous)
    img_size = img_size.unsqueeze(1)

    cropped = homogeneous @ affine_trans.mT
    return cropped[..., :2] / img_size - 0.5


def camera_project(self, pose_output: Dict, batch: Dict, branch: str = "body") -> Dict:
    """Project predicted 3D keypoints (and vertices, if present) to 2D.

    ``branch == "hand"`` uses ``head_camera_hand`` / ``hand_batch_idx``;
    anything else uses the body counterparts. The projection result for
    ``pred_keypoints_3d`` is merged into ``pose_output``; when
    ``pred_vertices`` is present its projection is stored under
    ``"pred_keypoints_2d_verts"``. ``pose_output`` is updated in place and
    returned.
    """
    use_hand = branch == "hand"
    head_camera = self.head_camera_hand if use_hand else self.head_camera
    batch_idx = self.hand_batch_idx if use_hand else self.body_batch_idx
    pred_cam = pose_output["pred_cam"]

    # Slice the per-person bbox / intrinsics once; both projections reuse them.
    bbox_center = self._flatten_person(batch["bbox_center"])[batch_idx]
    bbox_scale = self._flatten_person(batch["bbox_scale"])[batch_idx, 0]
    ori_img_size = self._flatten_person(batch["ori_img_size"])[batch_idx]
    cam_int = self._flatten_person(
        batch["cam_int"]
        .unsqueeze(1)
        .expand(-1, batch["img"].shape[1], -1, -1)
        .contiguous()
    )[batch_idx]

    def _project(points_3d):
        return head_camera.perspective_projection(
            points_3d, pred_cam, bbox_center, bbox_scale, ori_img_size, cam_int,
            use_intrin_center=True,
        )

    cam_out = _project(pose_output["pred_keypoints_3d"])
    if pose_output.get("pred_vertices") is not None:
        pose_output["pred_keypoints_2d_verts"] = _project(
            pose_output["pred_vertices"]
        )["pred_keypoints_2d"]

    pose_output.update(cam_out)
    return pose_output


def get_ray_condition(self, batch):
    """Compute per-pixel camera-ray (x, y) directions for every crop.

    Crop pixel coordinates are mapped back to the full image by inverting
    the diagonal part and translation of ``affine_trans``, then the
    principal point is subtracted and the result divided by the focal
    lengths from ``cam_int``.

    Returns:
        Tensor of shape (B, N, 2, H, W) in the dtype of ``batch["img"]``,
        channel 0 holding x and channel 1 holding y.
    """
    B, N, _, H, W = batch["img"].shape
    affine = batch["affine_trans"]
    cam_int = batch["cam_int"]

    # indexing="xy" returns grids shaped (len(second), len(first)), so x (W)
    # must be passed first to obtain an (H, W) grid. Passing (H, W) only
    # happened to work for square crops.
    xs = torch.arange(W, device=affine.device)
    ys = torch.arange(H, device=affine.device)
    grid_xy = torch.stack(torch.meshgrid(xs, ys, indexing="xy"), dim=-1)
    grid_xy = grid_xy[None, None].expand(B, N, -1, -1, -1)  # B x N x H x W x 2

    affine_scale = affine[:, :, None, None, [0, 1], [0, 1]]
    affine_shift = affine[:, :, None, None, [0, 1], [2, 2]]
    grid_xy = grid_xy / affine_scale
    grid_xy = grid_xy - affine_shift / affine_scale

    # Subtract the principal point and normalize by focal length -> rays.
    grid_xy = grid_xy - cam_int[:, None, None, None, [0, 1], [2, 2]]
    grid_xy = grid_xy / cam_int[:, None, None, None, [0, 1], [0, 1]]

    return grid_xy.permute(0, 1, 4, 2, 3).to(batch["img"].dtype)
+ if len(self.body_batch_idx): + batch["ray_cond"] = ray_cond[self.body_batch_idx].clone() + if len(self.hand_batch_idx): + batch["ray_cond_hand"] = ray_cond[self.hand_batch_idx].clone() + ray_cond = None + + image_embeddings = self.backbone(x.type(self.backbone_dtype)) + # bf16 mantissa too lossy for the heads — promote back. fp16 survives. + if self.backbone_dtype != torch.float16: + image_embeddings = image_embeddings.type(x.dtype) + + image_embeddings = image_embeddings + self._get_mask_prompt(batch, image_embeddings) + + # condition_info is fp32 from `_get_decoder_condition`; align to + # decoder dtype so the downstream cat doesn't auto-promote. + condition_info = self._get_decoder_condition(batch).type(image_embeddings.dtype) + + # Seed prompt: all-invalid keypoints (label = -2). + keypoints_prompt = torch.zeros((batch_size * num_person, 1, 3)).to(batch["img"]) + keypoints_prompt[:, :, -1] = -2 + + pose_output, pose_output_hand = None, None + if len(self.body_batch_idx): + tokens_output, pose_output = self.forward_decoder( + "body", + image_embeddings[self.body_batch_idx], + init_estimate=None, + keypoints=keypoints_prompt[self.body_batch_idx], + prev_estimate=None, + condition_info=condition_info[self.body_batch_idx], + batch=batch, + ) + pose_output = pose_output[-1] + if len(self.hand_batch_idx): + tokens_output_hand, pose_output_hand = self.forward_decoder( + "hand", + image_embeddings[self.hand_batch_idx], + init_estimate=None, + keypoints=keypoints_prompt[self.hand_batch_idx], + prev_estimate=None, + condition_info=condition_info[self.hand_batch_idx], + batch=batch, + ) + pose_output_hand = pose_output_hand[-1] + + output = { + "mhr": pose_output, + "mhr_hand": pose_output_hand, + "condition_info": condition_info, + "image_embeddings": image_embeddings, + } + # hand_box is (x1, y1, w, h) ∈ [0, 1]. Body path promotes to fp32 to + # match the head-MLP external contract (_get_hand_box would .float() anyway). 
def forward_step(self, batch: Dict, decoder_type: str = "body") -> Dict:
    """Run one decoder pass with the requested branch active.

    Args:
        batch: Prepared crop batch, forwarded to ``forward_pose_branch``.
        decoder_type: ``"body"`` or ``"hand"``; forwarded to
            ``_set_active_branch``, which rejects any other value.

    Returns:
        The single output dict produced by ``forward_pose_branch`` (the
        former ``Tuple[Dict, Dict]`` annotation did not match what is
        actually returned).
    """
    # Activate the branch first: forward_pose_branch reads the index lists
    # that _set_active_branch writes.
    self._set_active_branch(decoder_type)
    return self.forward_pose_branch(batch)
+ tmp = left_xyxy.clone() + left_xyxy[:, 0] = width - tmp[:, 2] - 1 + left_xyxy[:, 2] = width - tmp[:, 0] - 1 + + batch_lhand, batch_rhand = self._prepare_hand_batches_gpu( + img, left_xyxy, right_xyxy, cam_int.clone(), is_multi_image, + ) + # Concat lhand+rhand along dim 0 so backbone+decoder run once on + # (2, num_person, ...) — saves one full DINOv3 ViT-H+ pass. + batch_hands = self._concat_hand_batches(batch_lhand, batch_rhand) + saved_batch_state = (self._batch_size, self._max_num_person, self._person_valid) + self._initialize_batch(batch_hands) + hands_output = self.forward_step(batch_hands, decoder_type="hand") + self._batch_size, self._max_num_person, self._person_valid = saved_batch_state + n_left = batch_lhand["img"].shape[0] * batch_lhand["img"].shape[1] + lhand_output, rhand_output = self._split_hand_output(hands_output, n_left) + # Free the batched image_embeddings/condition_info (unused downstream); + # mhr_hand views into the underlying tensors stay alive via l/rhand_output. + del hands_output, batch_hands + + # Unflip left-hand output. Keep MHR consts as 0-d on-device tensors — + # `.item()` would force four hard CPU<->GPU syncs in the hot path. + _lhand_scale = lhand_output["mhr_hand"]["scale"] + scale_r_hands_mean = self.head_pose.scale_mean[8].to(_lhand_scale) + scale_l_hands_mean = self.head_pose.scale_mean[9].to(_lhand_scale) + scale_r_hands_std = self.head_pose.scale_comps[8, 8].to(_lhand_scale) + scale_l_hands_std = self.head_pose.scale_comps[9, 9].to(_lhand_scale) + lhand_output["mhr_hand"]["scale"][:, 9] = ( + (scale_r_hands_mean + scale_r_hands_std * lhand_output["mhr_hand"]["scale"][:, 8]) + - scale_l_hands_mean + ) / scale_l_hands_std + # Right-hand global rotation flipped → used as left. 
+ lhand_output["mhr_hand"]["joint_global_rots"][:, 78] = \ + lhand_output["mhr_hand"]["joint_global_rots"][:, 42].clone() + lhand_output["mhr_hand"]["joint_global_rots"][:, 78, [1, 2], :] *= -1 + lhand_output["mhr_hand"]["hand"][:, :54] = lhand_output["mhr_hand"]["hand"][:, 54:] + batch_lhand["bbox_center"][:, :, 0] = width - batch_lhand["bbox_center"][:, :, 0] - 1 + + # 3. Validity criteria for replacing body-decoder hand pose. + # (a) local wrist pose difference: hand vs body wrist rotations + joint_rotations = pose_output["mhr"]["joint_global_rots"] + _dev = joint_rotations.device + lowarm_joint_idxs = torch.LongTensor([76, 40]).to(_dev) # left, right + lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] + wrist_twist_joint_idxs = torch.LongTensor([77, 41]).to(_dev) + wrist_zero_rot_pose = lowarm_joint_rotations @ self.head_pose.joint_rotation[wrist_twist_joint_idxs] + pred_global_wrist_rotmat = torch.stack( + [lhand_output["mhr_hand"]["joint_global_rots"][:, 78], + rhand_output["mhr_hand"]["joint_global_rots"][:, 42]], + dim=1, + ) + fused_local_wrist_rotmat = torch.einsum( + "kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose, + ) + angle_difference_valid_mask = rotation_angle_difference( + ori_local_wrist_rotmat, fused_local_wrist_rotmat, + ) < thresh_wrist_angle + + # (b) hand box big enough to give the decoder useful pixels + hand_box_size_thresh = 64 + hand_box_size_valid_mask = torch.stack( + [(batch_lhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1), + (batch_rhand["bbox_scale"].flatten(0, 1) > hand_box_size_thresh).all(dim=1)], + dim=1, + ) + + # (c) all hand 2D keypoints inside the crop box + hand_kps2d_thresh = 0.5 + hand_kps2d_valid_mask = torch.stack( + [lhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh, + rhand_output["mhr_hand"]["pred_keypoints_2d_cropped"].abs().amax(dim=(1, 2)) < hand_kps2d_thresh], + dim=1, + ) + + # (d) hand-decoder wrist close to 
body-decoder wrist in 2D + hand_wrist_kps2d_thresh = 0.25 + kps_right_wrist_idx, kps_left_wrist_idx = 41, 62 + right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() + left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() + left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 + body_right_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() + body_left_kps_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_wrist_idx]].clone() + right_kps_dist = (right_kps_full - body_right_kps_full).flatten(0, 1).norm(dim=-1) \ + / batch_lhand["bbox_scale"].flatten(0, 1)[:, 0] + left_kps_dist = (left_kps_full - body_left_kps_full).flatten(0, 1).norm(dim=-1) \ + / batch_rhand["bbox_scale"].flatten(0, 1)[:, 0] + hand_wrist_kps2d_valid_mask = torch.stack( + [left_kps_dist < hand_wrist_kps2d_thresh, + right_kps_dist < hand_wrist_kps2d_thresh], + dim=1, + ) + + hand_valid_mask = ( + angle_difference_valid_mask + & hand_box_size_valid_mask + & hand_kps2d_valid_mask + & hand_wrist_kps2d_valid_mask + ) + + # Re-prompt body decoder with hand-decoder wrists + body-decoder elbows + # to get an updated body pose estimation. 
+ self._set_active_branch("body") + + right_kps_full = rhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() + left_kps_full = lhand_output["mhr_hand"]["pred_keypoints_2d"][:, [kps_right_wrist_idx]].clone() + left_kps_full[:, :, 0] = width - left_kps_full[:, :, 0] - 1 + right_kps_crop = self._full_to_crop(batch, right_kps_full) + left_kps_crop = self._full_to_crop(batch, left_kps_full) + + kps_right_elbow_idx, kps_left_elbow_idx = 8, 7 + right_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_right_elbow_idx]].clone() + left_kps_elbow_full = pose_output["mhr"]["pred_keypoints_2d"][:, [kps_left_elbow_idx]].clone() + right_kps_elbow_crop = self._full_to_crop(batch, right_kps_elbow_full) + left_kps_elbow_crop = self._full_to_crop(batch, left_kps_elbow_full) + + keypoint_prompt = torch.cat( + [right_kps_crop, left_kps_crop, right_kps_elbow_crop, left_kps_elbow_crop], dim=1, + ) + keypoint_prompt = torch.cat([keypoint_prompt, keypoint_prompt[..., [-1]]], dim=-1) + keypoint_prompt[:, 0, -1] = kps_right_wrist_idx + keypoint_prompt[:, 1, -1] = kps_left_wrist_idx + keypoint_prompt[:, 2, -1] = kps_right_elbow_idx + keypoint_prompt[:, 3, -1] = kps_left_elbow_idx + + if keypoint_prompt.shape[0] > 1: + invalid_prompt = ( + (keypoint_prompt[..., 0] < -0.5) + | (keypoint_prompt[..., 0] > 0.5) + | (keypoint_prompt[..., 1] < -0.5) + | (keypoint_prompt[..., 1] > 0.5) + | (~hand_valid_mask[..., [1, 0, 1, 0]]) + ).unsqueeze(-1) + dummy_prompt = torch.zeros((1, 1, 3)).to(keypoint_prompt) + dummy_prompt[:, :, -1] = -2 + # Shift [-0.5, 0.5] → [0, 1] for the prompt encoder. 
+ keypoint_prompt[:, :, :2] = torch.clamp(keypoint_prompt[:, :, :2] + 0.5, 0.0, 1.0) + keypoint_prompt = torch.where(invalid_prompt, dummy_prompt, keypoint_prompt) + else: + valid_keypoint = ( + torch.all( + (keypoint_prompt[:, :, :2] > -0.5) & (keypoint_prompt[:, :, :2] < 0.5), + dim=2, + ) + & hand_valid_mask[..., [1, 0, 1, 0]] + ).squeeze() + keypoint_prompt = keypoint_prompt[:, valid_keypoint] + keypoint_prompt[:, :, :2] = torch.clamp(keypoint_prompt[:, :, :2] + 0.5, 0.0, 1.0) + + if keypoint_prompt.numel() != 0: + pose_output, _ = self.run_keypoint_prompt(batch, pose_output, keypoint_prompt) + + # 4. Drop hand pose / scale / shape from the hand decoder into the body output. + updated_hand_pose = torch.cat( + [lhand_output["mhr_hand"]["hand"][:, :54], + rhand_output["mhr_hand"]["hand"][:, 54:]], + dim=1, + ) + updated_scale = pose_output["mhr"]["scale"].clone() + updated_scale[:, 9] = lhand_output["mhr_hand"]["scale"][:, 9] + updated_scale[:, 8] = rhand_output["mhr_hand"]["scale"][:, 8] + updated_scale[:, 18:] = ( + lhand_output["mhr_hand"]["scale"][:, 18:] + + rhand_output["mhr_hand"]["scale"][:, 18:] + ) / 2 + updated_shape = pose_output["mhr"]["shape"].clone() + updated_shape[:, 40:] = ( + lhand_output["mhr_hand"]["shape"][:, 40:] + + rhand_output["mhr_hand"]["shape"][:, 40:] + ) / 2 + + # 5. IK: solve local wrist Euler from the (updated) global wrist rotmat. 
+ joint_rotations = self.head_pose.mhr_forward( + global_trans=pose_output["mhr"]["global_rot"] * 0, + global_rot=pose_output["mhr"]["global_rot"], + body_pose_params=pose_output["mhr"]["body_pose"], + hand_pose_params=updated_hand_pose, + scale_params=updated_scale, + shape_params=updated_shape, + expr_params=pose_output["mhr"]["face"], + return_joint_rotations=True, + )[1] + _dev = joint_rotations.device + lowarm_joint_idxs = torch.LongTensor([76, 40]).to(_dev) + lowarm_joint_rotations = joint_rotations[:, lowarm_joint_idxs] + # joint_rotation is a static buffer at head dtype; cast to MHR's fp32 + # to keep the rotation matmul in fp32. + wrist_twist_joint_idxs = torch.LongTensor([77, 41]).to(_dev) + wrist_zero_rot_pose = lowarm_joint_rotations @ \ + self.head_pose.joint_rotation[wrist_twist_joint_idxs].to(joint_rotations.dtype) + pred_global_wrist_rotmat = torch.stack( + [lhand_output["mhr_hand"]["joint_global_rots"][:, 78], + rhand_output["mhr_hand"]["joint_global_rots"][:, 42]], + dim=1, + ) + fused_local_wrist_rotmat = torch.einsum( + "kabc,kabd->kadc", pred_global_wrist_rotmat, wrist_zero_rot_pose, + ) + wrist_xzy = fix_wrist_euler(rotmat_to_euler("XZY", fused_local_wrist_rotmat)) + + valid_angle = ( + (rotation_angle_difference(ori_local_wrist_rotmat, fused_local_wrist_rotmat) < thresh_wrist_angle) + & hand_valid_mask + ).unsqueeze(-1) + + body_pose = pose_output["mhr"]["body_pose"][ + :, [41, 43, 42, 31, 33, 32] + ].unflatten(1, (2, 3)) + updated_body_pose = torch.where(valid_angle, wrist_xzy, body_pose) + pose_output["mhr"]["body_pose"][:, [41, 43, 42, 31, 33, 32]] = ( + updated_body_pose.flatten(1, 2) + ) + + hand_pose = pose_output["mhr"]["hand"].unflatten(1, (2, 54)) + pose_output["mhr"]["hand"] = torch.where( + valid_angle, updated_hand_pose.unflatten(1, (2, 54)), hand_pose + ).flatten(1, 2) + + hand_scale = torch.stack( + [pose_output["mhr"]["scale"][:, 9], pose_output["mhr"]["scale"][:, 8]], + dim=1, + ) + updated_hand_scale = torch.stack( + 
[updated_scale[:, 9], updated_scale[:, 8]], dim=1 + ) + masked_hand_scale = torch.where( + valid_angle.squeeze(-1), updated_hand_scale, hand_scale + ) + pose_output["mhr"]["scale"][:, 9] = masked_hand_scale[:, 0] + pose_output["mhr"]["scale"][:, 8] = masked_hand_scale[:, 1] + + # Replace shared shape and scale + pose_output["mhr"]["scale"][:, 18:] = torch.where( + valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0, + ( + lhand_output["mhr_hand"]["scale"][:, 18:] + * valid_angle.squeeze(-1)[:, [0]] + + rhand_output["mhr_hand"]["scale"][:, 18:] + * valid_angle.squeeze(-1)[:, [1]] + ) + / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8), + pose_output["mhr"]["scale"][:, 18:], + ) + pose_output["mhr"]["shape"][:, 40:] = torch.where( + valid_angle.squeeze(-1).sum(dim=1, keepdim=True) > 0, + ( + lhand_output["mhr_hand"]["shape"][:, 40:] + * valid_angle.squeeze(-1)[:, [0]] + + rhand_output["mhr_hand"]["shape"][:, 40:] + * valid_angle.squeeze(-1)[:, [1]] + ) + / (valid_angle.squeeze(-1).sum(dim=1, keepdim=True) + 1e-8), + pose_output["mhr"]["shape"][:, 40:], + ) + + # Re-run MHR forward with the updated parameters. + verts, j3d, jcoords, mhr_model_params, joint_global_rots = self.head_pose.mhr_forward( + global_trans=pose_output["mhr"]["global_rot"] * 0, + global_rot=pose_output["mhr"]["global_rot"], + body_pose_params=pose_output["mhr"]["body_pose"], + hand_pose_params=pose_output["mhr"]["hand"], + scale_params=pose_output["mhr"]["scale"], + shape_params=pose_output["mhr"]["shape"], + expr_params=pose_output["mhr"]["face"], + return_keypoints=True, + return_joint_coords=True, + return_model_params=True, + return_joint_rotations=True, + ) + # j3d: 308 → 70 body/hand kps + 238 face landmarks. All four buffers + # need the same y/z flip so they share a coordinate system. 
def run_keypoint_prompt(self, batch, output, keypoint_prompt):
    """Re-run the body decoder with keypoint prompts.

    The previous estimate is assembled from ``output["mhr"]`` (raw pose,
    shape, scale, hand, face, then camera) and passed to ``forward_decoder``
    together with the cached image embeddings and condition info.

    Returns:
        ``(output, keypoint_prompt)``; ``output["mhr"]`` is replaced in
        place by the decoder's last pose output.
    """
    previous = output["mhr"]
    pose_part = torch.cat(
        [
            previous["pred_pose_raw"],
            previous["shape"],
            previous["scale"],
            previous["hand"],
            previous["face"],
        ],
        dim=1,
    ).unsqueeze(1)
    prev_estimate = torch.cat(
        [pose_part, previous["pred_cam"].unsqueeze(1)], dim=-1,
    )

    _, pose_outputs = self.forward_decoder(
        "body",
        output["image_embeddings"],
        init_estimate=None,  # default init; the previous pose goes in prev_estimate
        keypoints=keypoint_prompt,
        prev_estimate=prev_estimate,
        condition_info=output["condition_info"],
        batch=batch,
    )

    output.update({"mhr": pose_outputs[-1]})
    return output, keypoint_prompt


def _get_hand_box(self, pose_output, batch):
    """Convert predicted hand boxes into full-image xyxy boxes.

    ``pose_output["mhr"]["hand_box"]`` is indexed as (person, hand, 4) with
    hand 0 = left and hand 1 = right; its first two values are used as the
    box center and the last two as its size, all scaled by
    ``self.image_size[0]``. Boxes are squared using the longer side and then
    mapped back to the full image by inverting the crop affine (only its
    ``[0, 0]`` scale and translation column are used, i.e. no rotation is
    assumed).

    Side effect: stores ``left_center`` / ``left_scale`` / ``right_center`` /
    ``right_scale`` (full-image units) in ``batch``.

    Returns:
        ``(left_xyxy, right_xyxy)``, each with one row of 4 values per person.
    """
    hand_box = pose_output["mhr"]["hand_box"]

    # hand_box has one row per flattened person, so flatten the affine's
    # (batch, person) dims to match. Equivalent to ``affine_trans[0]`` when
    # the batch dimension is 1, and correct for larger batches.
    affine_trans = batch["affine_trans"].flatten(0, 1)
    affine_scale = affine_trans[:, 0, 0][:, None]
    affine_offset = affine_trans[:, :2, 2]

    boxes = {}
    for side, hand_idx in (("left", 0), ("right", 1)):
        crop_box = hand_box[:, hand_idx] * self.image_size[0]
        # Square the box (long side wins), then undo the crop transform.
        side_len = crop_box[:, 2:].amax(dim=1, keepdim=True).repeat(1, 2)
        center = (crop_box[:, :2] - affine_offset) / affine_scale
        scale = side_len / affine_scale
        batch[f"{side}_center"] = center
        batch[f"{side}_scale"] = scale
        half = scale / 2
        boxes[side] = torch.cat([center - half, center + half], dim=1)

    return boxes["left"], boxes["right"]
`branch` picks body/hand attrs; + # rest is identical. Called via keypoint_token_update_fn_comb in + # forward_decoder. + def _keypoint_token_update( + self, + branch: str, + kps_emb_start_idx, + image_embeddings, + token_embeddings, + token_augment, + pose_output, + layer_idx, + ): + if branch == "body": + decoder_layers = self.decoder.layers + kp_emb_w = self.keypoint_embedding.weight + kp_idxs = self.keypoint_embedding_idxs + posemb_linear = self.keypoint_posemb_linear + feat_linear = self.keypoint_feat_linear + else: + decoder_layers = self.decoder_hand.layers + kp_emb_w = self.keypoint_embedding_hand.weight + kp_idxs = self.keypoint_embedding_idxs_hand + posemb_linear = self.keypoint_posemb_linear_hand + feat_linear = self.keypoint_feat_linear_hand + + # Last layer's pose output is final — nothing to inject back. + if layer_idx == len(decoder_layers) - 1: + return token_embeddings, token_augment, pose_output, layer_idx + + token_augment = token_augment.clone() + num_keypoints = kp_emb_w.shape[0] + + # kp comes from fp32 MHR/cam projection; cast once to decoder dtype + # so posemb / grid_sample match. + pred_keypoints_2d_cropped = pose_output["pred_keypoints_2d_cropped"].clone()[:, kp_idxs] + pred_keypoints_2d_depth = pose_output["pred_keypoints_2d_depth"].clone()[:, kp_idxs] + pred_keypoints_2d_cropped = pred_keypoints_2d_cropped.to(image_embeddings.dtype) + + # Mask out-of-frame OR behind-camera keypoints' contributions. 
+ pred_keypoints_2d_cropped_01 = pred_keypoints_2d_cropped + 0.5 + invalid_mask = ( + (pred_keypoints_2d_cropped_01[:, :, 0] < 0) + | (pred_keypoints_2d_cropped_01[:, :, 0] > 1) + | (pred_keypoints_2d_cropped_01[:, :, 1] < 0) + | (pred_keypoints_2d_cropped_01[:, :, 1] > 1) + | (pred_keypoints_2d_depth[:, :] < 1e-5) + ) + + token_augment[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] = ( + posemb_linear(pred_keypoints_2d_cropped) * (~invalid_mask[:, :, None]) + ) + + # Bilinear-sample image features at each kp's projected location. + # grid_sample wants -1..1; cropped form is -0.5..0.5, so ×2. + sample_points = pred_keypoints_2d_cropped * 2 + feats = F.grid_sample( + image_embeddings, sample_points[:, :, None, :], + mode="bilinear", padding_mode="zeros", align_corners=False, + ).squeeze(3).permute(0, 2, 1) + feats = feats * (~invalid_mask[:, :, None]) + + token_embeddings = token_embeddings.clone() + token_embeddings[:, kps_emb_start_idx : kps_emb_start_idx + num_keypoints, :] += ( + feat_linear(feats) + ) + + return token_embeddings, token_augment, pose_output, layer_idx + + def _keypoint3d_token_update( + self, + branch: str, + kps3d_emb_start_idx, + token_embeddings, + token_augment, + pose_output, + layer_idx, + ): + if branch == "body": + decoder_layers = self.decoder.layers + kp3d_emb_w = self.keypoint3d_embedding.weight + kp3d_idxs = self.keypoint3d_embedding_idxs + posemb_linear = self.keypoint3d_posemb_linear + else: + decoder_layers = self.decoder_hand.layers + kp3d_emb_w = self.keypoint3d_embedding_hand.weight + kp3d_idxs = self.keypoint3d_embedding_idxs_hand + posemb_linear = self.keypoint3d_posemb_linear_hand + + if layer_idx == len(decoder_layers) - 1: + return token_embeddings, token_augment, pose_output, layer_idx + + num_keypoints3d = kp3d_emb_w.shape[0] + + # Pelvis-normalize so 3D kps live in subject-centric coords (don't + # leak global cam translation into the token signal). 
Cast to decoder + # dtype before posemb_linear writes back into token_augment. + pred_keypoints_3d = pose_output["pred_keypoints_3d"].clone() + pred_keypoints_3d = pred_keypoints_3d - ( + pred_keypoints_3d[:, [self.pelvis_idx[0]], :] + + pred_keypoints_3d[:, [self.pelvis_idx[1]], :] + ) / 2 + pred_keypoints_3d = pred_keypoints_3d[:, kp3d_idxs].to(token_augment.dtype) + + token_augment = token_augment.clone() + token_augment[:, kps3d_emb_start_idx : kps3d_emb_start_idx + num_keypoints3d, :] = ( + posemb_linear(pred_keypoints_3d) + ) + + return token_embeddings, token_augment, pose_output, layer_idx diff --git a/comfy/ldm/sam3d_body/model/prompt.py b/comfy/ldm/sam3d_body/model/prompt.py new file mode 100644 index 000000000..2a5466276 --- /dev/null +++ b/comfy/ldm/sam3d_body/model/prompt.py @@ -0,0 +1,272 @@ +"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them +through a cross-attention transformer decoder over (token, image) pairs. + +Both adapted from the SAM-style prompt path (Meta, Apache 2.0): +https://github.com/facebookresearch/segment-anything +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from comfy.ldm.cascade.common import LayerNorm2d_op +from comfy.ldm.sam3.sam import PositionEmbeddingRandom + +from .transformer import TransformerDecoderLayer + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + num_body_joints: int, + device=None, + dtype=None, + operations=None, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. 
+ """ + super().__init__() + ops = operations if operations is not None else nn + self.embed_dim = embed_dim + self.num_body_joints = num_body_joints + + # Keypoint prompts + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + self.point_embeddings = nn.ModuleList( + [ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)] + ) + self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) + self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) + + # Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim. + LN2d = LayerNorm2d_op(ops) + mask_in_chans = 256 + self.mask_downscaling = nn.Sequential( + ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(mask_in_chans // 64, device=device, dtype=dtype), + nn.GELU(), + ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(mask_in_chans // 16, device=device, dtype=dtype), + nn.GELU(), + ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(mask_in_chans // 4, device=device, dtype=dtype), + nn.GELU(), + ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(mask_in_chans, device=device, dtype=dtype), + nn.GELU(), + ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype), + ) + # Trained values for the gating conv and no_mask_embed are loaded from the state dict + self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype) + + def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor: + """Positional encoding over the image-embedding grid; (1, C, H, W).""" + return self.pe_layer(size) + + def _embed_keypoints(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Embeds point prompts. + Assuming points have been normalized to [0, 1]. 
+ + Output shape [B, N, C], mask shape [B, N] + """ + assert points.min() >= 0 and points.max() <= 1 + # PE compute in fp32 for precision (sin/cos of large coords), then cast back to the embedding weight dtype + weight_dtype = self.invalid_point_embed.weight.dtype + point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype) + point_embedding[labels == -2] = 0.0 # invalid points + point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding) + for i in range(self.num_body_joints): + point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding) + + point_mask = labels > -2 + return point_embedding, point_mask + + def _get_batch_size(self, keypoints: Optional[torch.Tensor], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor]) -> int: + if keypoints is not None: + return keypoints.shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def forward( + self, + keypoints: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + masks: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + keypoints (torchTensor or none): point coordinates and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. 
+ torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(keypoints, boxes, masks) + # Anchor device on the input prompts so we don't pull the offloaded + # CPU embedding device under dynamic loading. + ref = keypoints if keypoints is not None else boxes if boxes is not None else masks + device = ref.device if ref is not None else self.point_embeddings[0].weight.device + weight_dtype = self.invalid_point_embed.weight.dtype + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=device, dtype=weight_dtype) + sparse_masks = torch.empty((bs, 0), device=device) + if keypoints is not None: + coords = keypoints[:, :, :2] + labels = keypoints[:, :, -1] + point_embeddings, point_mask = self._embed_keypoints(coords, labels) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + sparse_masks = torch.cat([sparse_masks, point_mask], dim=1) + + return sparse_embeddings, sparse_masks + + def get_mask_embeddings( + self, + masks: Optional[torch.Tensor] = None, + bs: int = 1, + size: Tuple[int, int] = (16, 16), # [H, W] + ) -> torch.Tensor: + """Embeds mask inputs.""" + # masks is always on the active device when present; fall back to the + # downscaling Conv's weight device when it isn't (rare callers). 
class PromptableDecoder(nn.Module):
    """Cross-attention transformer decoder over (token, image) pairs.

    Stacks ``depth`` ``TransformerDecoderLayer`` modules followed by a final
    LayerNorm. When ``do_interm_preds`` is set, a pose estimate is produced
    after every layer through ``token_to_pose_output_fn``; when
    ``keypoint_token_update`` is also set, each intermediate estimate is fed
    back into the tokens through ``keypoint_token_update_fn``.

    ``operations`` is the layer factory namespace (anything exposing
    ``Linear`` / ``LayerNorm``); ``None`` falls back to ``torch.nn`` for the
    final norm *and* for every decoder layer.
    """

    def __init__(
        self,
        dims: int,
        context_dims: int,
        depth: int,
        num_heads: int = 8,
        head_dims: int = 64,
        mlp_dims: int = 1024,
        repeat_pe: bool = False,
        do_interm_preds: bool = False,
        keypoint_token_update: bool = False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        ops = operations if operations is not None else nn

        # Hand the resolved factory (never None) to the layers so that
        # operations=None behaves the same for them as for norm_final.
        self.layers = nn.ModuleList(
            TransformerDecoderLayer(
                token_dims=dims,
                context_dims=context_dims,
                num_heads=num_heads,
                head_dims=head_dims,
                mlp_dims=mlp_dims,
                repeat_pe=repeat_pe,
                skip_first_pe=(i == 0),
                device=device,
                dtype=dtype,
                operations=ops,
            )
            for i in range(depth)
        )

        self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
        self.do_interm_preds = do_interm_preds
        self.keypoint_token_update = keypoint_token_update

    def forward(
        self,
        token_embedding: torch.Tensor,
        image_embedding: torch.Tensor,
        token_augment: Optional[torch.Tensor] = None,
        image_augment: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
        token_to_pose_output_fn=None,
        keypoint_token_update_fn=None,
        hand_embeddings=None,
        hand_augment=None,
    ):
        """Run the decoder stack.

        Args:
            token_embedding: [B, N, C] query tokens.
            image_embedding: [B, C, H, W]; flattened to [B, HW, C] here.
            token_augment: optional additive term handed to every layer
                together with the tokens.
            image_augment: optional [B, C, H, W] counterpart for the image
                tokens; flattened like ``image_embedding``.
            token_mask: optional token mask forwarded to every layer.
            token_to_pose_output_fn: callable
                ``(tokens, prev_pose_output=..., layer_idx=...)``; required
                when ``do_interm_preds`` is set.
            keypoint_token_update_fn: callable
                ``(tokens, token_augment, pose_output, layer_idx)`` returning
                a 4-tuple; required when ``keypoint_token_update`` is set.
            hand_embeddings: optional extra context, flattened like the
                image and appended to it for every layer.
            hand_augment: augment for ``hand_embeddings``; a batch of one is
                repeated to the hand batch size.

        Returns:
            The final-normed tokens, or ``(tokens, all_pose_outputs)`` when
            ``do_interm_preds`` is set.

        Raises:
            ValueError: if ``hand_embeddings`` is given without both
                ``hand_augment`` and ``image_augment``.
        """
        use_hand = hand_embeddings is not None
        if use_hand and (hand_augment is None or image_augment is None):
            # Previously surfaced as an opaque AttributeError / TypeError
            # from flatten() or torch.cat() on None.
            raise ValueError(
                "hand_embeddings requires both hand_augment and image_augment"
            )

        # Channels-last for the transformer.
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        if image_augment is not None:
            image_augment = image_augment.flatten(2).permute(0, 2, 1)

        context_augment = image_augment
        num_image_tokens = image_embedding.shape[1]
        if use_hand:
            hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
            hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
            if len(hand_augment) == 1:
                # Inflate the batch dimension to match the hand tokens.
                hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)
            # Loop-invariant, so build the combined augment once.
            context_augment = torch.cat([image_augment, hand_augment], dim=1)

        all_pose_outputs = [] if self.do_interm_preds else None
        if self.do_interm_preds:
            assert token_to_pose_output_fn is not None

        # Pre-set so the final prediction below has an index even when the
        # stack is empty.
        layer_idx = 0
        last_idx = len(self.layers) - 1
        for layer_idx, layer in enumerate(self.layers):
            if use_hand:
                context = torch.cat([image_embedding, hand_embeddings], dim=1)
            else:
                context = image_embedding
            token_embedding, context = layer(
                token_embedding, context, token_augment, context_augment, token_mask,
            )
            # Keep only the image part; the hand tokens are re-appended from
            # the untouched hand_embeddings on the next iteration.
            image_embedding = context[:, :num_image_tokens] if use_hand else context

            if self.do_interm_preds and layer_idx < last_idx:
                curr = token_to_pose_output_fn(
                    self.norm_final(token_embedding),
                    prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                    layer_idx=layer_idx,
                )
                all_pose_outputs.append(curr)
                if self.keypoint_token_update:
                    assert keypoint_token_update_fn is not None
                    token_embedding, token_augment, _, _ = keypoint_token_update_fn(
                        token_embedding, token_augment, curr, layer_idx,
                    )

        out = self.norm_final(token_embedding)
        if self.do_interm_preds:
            curr = token_to_pose_output_fn(
                out,
                prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
                layer_idx=layer_idx,
            )
            all_pose_outputs.append(curr)
            return out, all_pose_outputs
        return out
class MLP(nn.Module):
    """Plain feed-forward stack of ``num_layers`` linear layers.

    Layer widths are ``input_dim -> hidden_dim x (num_layers - 1) ->
    output_dim``. The activation is applied after every layer except the
    last one.

    ``operations`` is the layer factory namespace (anything exposing
    ``Linear``); ``None`` falls back to ``torch.nn``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act_layer=nn.ReLU, device=None, dtype=None, operations=None):
        super().__init__()
        # operations defaults to None; without this fallback the default
        # call crashed with "'NoneType' object has no attribute 'Linear'".
        ops = operations if operations is not None else nn
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            ops.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
            for i in range(num_layers)
        )
        self.act = act_layer()

    def forward(self, x):
        """Apply the layers in order; no activation after the final layer."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.act(x)
        return x


class Attention(nn.Module):
    """Multi-head attention with separate q/k/v input widths.

    q, k and v are projected to ``embed_dims``, split into ``num_heads``
    heads of ``embed_dims // num_heads`` channels, attended with
    ``optimized_attention`` and projected back to ``query_dims``.

    ``operations`` is the layer factory namespace (anything exposing
    ``Linear``); ``None`` falls back to ``torch.nn``.
    """

    def __init__(self, embed_dims, num_heads, query_dims=None, key_dims=None, value_dims=None, qkv_bias=True, proj_bias=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        ops = operations if operations is not None else nn
        self.query_dims = query_dims or embed_dims
        self.key_dims = key_dims or embed_dims
        self.value_dims = value_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.head_dims = embed_dims // num_heads

        def linear(in_features, out_features, bias):
            return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

        self.q_proj = linear(self.query_dims, embed_dims, qkv_bias)
        self.k_proj = linear(self.key_dims, embed_dims, qkv_bias)
        self.v_proj = linear(self.value_dims, embed_dims, qkv_bias)
        self.proj = linear(embed_dims, self.query_dims, proj_bias)

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        """[B, N, embed_dims] -> [B, num_heads, N, head_dims]."""
        b, n, _ = x.shape
        return x.reshape(b, n, self.num_heads, self.head_dims).transpose(1, 2)

    def forward(self, q, k, v, attn_mask: Optional[torch.Tensor] = None):
        """Project q/k/v, attend per head and project the result back.

        ``attn_mask`` is handed to ``optimized_attention`` unchanged.
        """
        q, k, v = self._split(self.q_proj(q)), self._split(self.k_proj(k)), self._split(self.v_proj(v))
        # Heads are already split above, hence skip_reshape=True.
        x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True, low_precision_attention=False)
        return self.proj(x)
def bbox_xyxy2cs(bbox, padding: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert xyxy boxes to (center, scale), scale multiplied by ``padding``.

    Accepts a single box (1-D) or a batch (2-D); the outputs have the same
    rank as the input. Anything ``torch.as_tensor`` accepts is allowed.
    """
    boxes = torch.as_tensor(bbox, dtype=torch.float32)
    single = boxes.dim() == 1
    if single:
        boxes = boxes.unsqueeze(0)

    top_left = boxes[:, 0:2]
    bottom_right = boxes[:, 2:4]
    center = (top_left + bottom_right) * 0.5
    scale = (bottom_right - top_left) * padding

    if single:
        return center[0], scale[0]
    return center, scale


def fix_aspect_ratio(bbox_scale, aspect_ratio: float) -> torch.Tensor:
    """Grow one side of each (w, h) scale until w / h equals ``aspect_ratio``.

    The side that is too short is padded; the other side is kept. Accepts a
    single scale (1-D) or a batch (2-D) and returns the same rank.
    """
    sizes = torch.as_tensor(bbox_scale, dtype=torch.float32)
    single = sizes.dim() == 1
    if single:
        sizes = sizes.unsqueeze(0)

    width = sizes[:, 0:1]
    height = sizes[:, 1:2]
    width_for_height = height * aspect_ratio
    too_wide = width > width_for_height

    # Too wide -> keep the width and derive the height; otherwise keep the
    # height and derive the width.
    new_width = torch.where(too_wide, width, width_for_height)
    new_height = torch.where(too_wide, width / aspect_ratio, height)
    fixed = torch.cat([new_width, new_height], dim=1)

    return fixed[0] if single else fixed
def get_warp_matrices(centers, scales, output_size: Tuple[int, int]) -> torch.Tensor:
    """Build one 2x3 affine per (center, scale) box mapping it to the output.

    ``output_size`` is (H_out, W_out). Without rotation the warp is a
    uniform zoom of ``W_out / scale_w`` plus a translation that puts the box
    center at the center of the output. 1-D inputs are treated as a batch of
    one. Returns a float32 tensor of shape (N, 2, 3).
    """
    centers = torch.as_tensor(centers, dtype=torch.float32)
    scales = torch.as_tensor(scales, dtype=torch.float32)
    if centers.dim() == 1:
        centers = centers.unsqueeze(0)
        scales = scales.unsqueeze(0)

    out_h = float(output_size[0])
    out_w = float(output_size[1])
    # Only the box width drives the zoom; x and y share the same factor.
    zoom = out_w / scales[:, 0]  # (N,)

    count = centers.shape[0]
    mats = torch.zeros((count, 2, 3), dtype=torch.float32)
    for axis, out_extent in enumerate((out_w, out_h)):
        mats[:, axis, axis] = zoom
        mats[:, axis, 2] = out_extent * 0.5 - zoom * centers[:, axis]
    return mats  # (N, 2, 3)


def warp_affine_batched(
        src_t: torch.Tensor,            # (N, C, H_src, W_src) float
        mats: torch.Tensor,             # (N, 2, 3) float
        output_size: Tuple[int, int]    # (H_out, W_out)
        ) -> torch.Tensor:
    """Warp N images with N forward (src->dst) 2x3 affines in one call.

    Each affine is inverted (``grid_sample`` needs dst->src), evaluated at
    the output pixel centers and normalized to [-1, 1]. Sampling is bilinear
    with zero padding and ``align_corners=False``. Works for arbitrary
    invertible affines, not just the scale + translate ones produced by
    ``get_warp_matrices``.
    """
    out_h, out_w = int(output_size[0]), int(output_size[1])
    batch, _, src_h, src_w = src_t.shape
    dev = src_t.device

    # Promote to 3x3 so the forward maps can be inverted.
    forward = mats.to(device=dev, dtype=torch.float32)
    last_row = torch.tensor([0.0, 0.0, 1.0], device=dev).expand(batch, 1, 3)
    inverse = torch.linalg.inv(torch.cat([forward, last_row], dim=1))[:, :2, :]  # (N, 2, 3)

    # Homogeneous coordinates of every output pixel center (x+0.5, y+0.5).
    row_centers = torch.arange(out_h, dtype=torch.float32, device=dev) + 0.5
    col_centers = torch.arange(out_w, dtype=torch.float32, device=dev) + 0.5
    grid_y, grid_x = torch.meshgrid(row_centers, col_centers, indexing="ij")
    pixels = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1)  # (H_out, W_out, 3)

    sample_at = torch.einsum("nkl,ijl->nijk", inverse, pixels)  # (N, H_out, W_out, 2)
    # Source pixel coordinates -> grid_sample's [-1, 1] range.
    for channel, extent in enumerate((src_w, src_h)):
        sample_at[..., channel] = sample_at[..., channel] / extent * 2 - 1

    return F.grid_sample(src_t, sample_at, mode="bilinear", padding_mode="zeros", align_corners=False)
+ scales = fix_aspect_ratio(scales, aspect_ratio) + scales = fix_aspect_ratio(scales, W_out / H_out) + mats = get_warp_matrices(centers, scales, (H_out, W_out)) # (N, 2, 3) + + # Stack source images into a contiguous (N, 3, H, W) tensor on CPU. + if is_multi_image: + src_t = torch.stack(list(img), dim=0) + else: + src_t = img.unsqueeze(0).expand(n, -1, -1, -1) + src_t = src_t.permute(0, 3, 1, 2).contiguous().float() # (N, 3, H, W) in [0, 255] + + warped_t = warp_affine_batched(src_t, mats, (H_out, W_out)) # (N, 3, H_out, W_out) + # Float warp -> floor (matches the legacy uint8 round-trip) -> /255. + img_t = torch.floor(warped_t).clamp_(0.0, 255.0) / 255.0 # (N, 3, H_out, W_out) in [0, 1] + + # Masks: zero-init when missing, otherwise stack and warp through the same matrices. + boxes_t = torch.as_tensor(boxes, dtype=torch.float32) + if masks is None: + mask_t = torch.zeros((n, H_out, W_out), dtype=torch.float32) + mask_score_t = torch.zeros((n,), dtype=torch.float32) + else: + # masks is an array of N items, each (H, W) or (H, W, 1). 
+ masks_t = torch.stack([torch.as_tensor(masks[i]) for i in range(n)], dim=0) + if masks_t.dim() == 4 and masks_t.shape[-1] == 1: + masks_t = masks_t[..., 0] + masks_src_t = masks_t.float().unsqueeze(1) # (N, 1, H, W) in [0, 255] + warped_masks = warp_affine_batched(masks_src_t, mats, (H_out, W_out)) + mask_t = torch.floor(warped_masks.squeeze(1)).clamp_(0.0, 255.0) + if masks_score is not None: + mask_score_t = torch.as_tensor([masks_score[i] for i in range(n)], dtype=torch.float32) + else: + mask_score_t = torch.ones((n,), dtype=torch.float32) + + img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous() + ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous() + + batch = { + "img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out) + "img_size": img_size_t.unsqueeze(0), # (1, N, 2) + "ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2) + "bbox_center": centers.unsqueeze(0), # (1, N, 2) + "bbox_scale": scales.unsqueeze(0), # (1, N, 2) + "bbox": boxes_t.unsqueeze(0), # (1, N, 4) + "affine_trans": mats.unsqueeze(0), # (1, N, 2, 3) + "mask": mask_t.unsqueeze(0).unsqueeze(2), # (1, N, 1, H_out, W_out) + "mask_score": mask_score_t.unsqueeze(0), # (1, N) + "person_valid": torch.ones((1, n), dtype=torch.float32), + } + + if cam_int is not None: + batch["cam_int"] = cam_int.to(batch["img"]) + else: + # Default intrinsics: focal = sqrt(W^2 + H^2), principal point = image center. + f = (height ** 2 + width ** 2) ** 0.5 + batch["cam_int"] = torch.tensor( + [[[f, 0, width / 2.0], [0, f, height / 2.0], [0, 0, 1]]], + ).to(batch["img"]) + + return batch + + +# Geometry utils + +def rot6d_to_rotmat( + x: torch.Tensor # (B, 6) batch of 6-D rotation representations. + ) -> torch.Tensor: # (B, 3, 3) rotation matrices. 
+ """6D continuous rotation rep (Zhou et al., CVPR 2019) -> 3x3 rotation matrix.""" + x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous() + a1, a2 = x[:, :, 0], x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1) + b3 = torch.linalg.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def perspective_projection( + x: torch.Tensor, # (B, N, 3) 3D points in camera coords. + K: torch.Tensor # (B, 3, 3) camera intrinsics. + ) -> torch.Tensor: # (B, N, 2) 2D image-plane projections. + """Project 3D points (already in camera frame) through intrinsics K.""" + y = x / x[:, :, -1].unsqueeze(-1) # perspective divide + y = torch.einsum("bij,bkj->bki", K, y) # apply intrinsics + return y[:, :, :2] + + +# Rotation conversions, behavior mirrors the roma library (https://github.com/naver/roma) + +def _axis_rotmat(axis: str, angle: torch.Tensor) -> torch.Tensor: + """Rotation matrices around a single coordinate axis. Shape (..., 3, 3).""" + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + if axis == "X": + flat = (one, zero, zero, + zero, cos, -sin, + zero, sin, cos) + elif axis == "Y": + flat = (cos, zero, sin, + zero, one, zero, + -sin, zero, cos) + elif axis == "Z": + flat = (cos, -sin, zero, + sin, cos, zero, + zero, zero, one) + else: + raise ValueError(f"Invalid axis {axis!r}; expected X/Y/Z.") + return torch.stack(flat, dim=-1).reshape(angle.shape + (3, 3)) + + +def euler_to_rotmat(convention: str, angles: torch.Tensor) -> torch.Tensor: + """Euler angles -> rotation matrix, matching roma's case-keyed convention.""" + axes = convention.upper() + R0 = _axis_rotmat(axes[0], angles[..., 0]) + R1 = _axis_rotmat(axes[1], angles[..., 1]) + R2 = _axis_rotmat(axes[2], angles[..., 2]) + if convention.islower(): + return R2 @ R1 @ R0 + return R0 @ R1 @ R2 + + +def _index_from_letter(letter: str) -> int: + return {"X": 0, "Y": 1, "Z": 2}[letter] + + 
def _angle_from_tan(
    axis: str,
    other_axis: str,
    data: torch.Tensor,
    horizontal: bool,
    tait_bryan: bool,
) -> torch.Tensor:
    """Extract an outer Euler angle from a row/column of a rotation matrix.

    Adapted from PyTorch3D's matrix_to_euler_angles helper.

    Args:
        axis: letter of the axis the angle rotates about.
        other_axis: letter of the neighbouring (middle) axis.
        data: (..., 3) slice of the rotation matrix to read from.
        horizontal: swaps which two components of ``data`` are used.
        tait_bryan: True when the first and last axes of the convention
            differ; selects the sign pattern of the ``atan2`` arguments.
    """
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ("XY", "YZ", "ZX")
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def _matrix_to_euler_intrinsic(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """Decompose a rotation matrix into intrinsic Euler angles (uppercase abc).

    Adapted from PyTorch3D's matrix_to_euler_angles. Unlike that reference,
    the entry fed to ``asin`` / ``acos`` is clamped to [-1, 1]: float
    round-off can push it marginally past +/-1 (typically near gimbal lock),
    which used to turn the middle angle into NaN.

    Raises:
        KeyError: if a letter of ``convention`` is not "X", "Y" or "Z".
    """
    # Local lookup so this function does not depend on another helper.
    axis_index = {"X": 0, "Y": 1, "Z": 2}
    i0 = axis_index[convention[0]]
    i2 = axis_index[convention[2]]
    tait_bryan = i0 != i2
    if tait_bryan:
        sign = -1.0 if (i0 - i2) in (-1, 2) else 1.0
        central = torch.asin(torch.clamp(matrix[..., i0, i2] * sign, -1.0, 1.0))
    else:
        central = torch.acos(torch.clamp(matrix[..., i0, i0], -1.0, 1.0))

    out = (
        _angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan),
        central,
        _angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan),
    )
    return torch.stack(out, dim=-1)


def rotmat_to_euler(convention: str, matrix: torch.Tensor) -> torch.Tensor:
    """Rotation matrix -> Euler angles, inverse of :func:`euler_to_rotmat`.

    PyTorch3D's matrix_to_euler_angles uses the convention R = R_a R_b R_c for
    convention "abc"; that matches roma's UPPERCASE ordering directly. For
    roma's lowercase, the matrix is reversed (R_c R_b R_a), so we decompose
    with the reversed convention and flip the angles back to axis order.

    Args:
        convention: three axis letters; all-uppercase is decomposed
            directly, anything else takes the reversed path.
        matrix: (..., 3, 3) rotation matrices.

    Returns:
        (..., 3) angles in the axis order of ``convention``.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f"Expected (..., 3, 3) rotation matrix, got {tuple(matrix.shape)}.")
    if convention.isupper():
        return _matrix_to_euler_intrinsic(matrix, convention)
    decomposed = _matrix_to_euler_intrinsic(matrix, convention.upper()[::-1])
    return decomposed.flip(-1)


def unitquat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """Unit quaternion (x, y, z, w) -> rotation matrix.

    Matches roma.unitquat_to_rotmat (scalar-last). The quaternion is assumed
    to be normalized; it is not re-normalized here.

    Args:
        quat: (..., 4) unit quaternion.
    Returns:
        (..., 3, 3) rotation matrix.
    """
    x, y, z, w = quat.unbind(dim=-1)
    tx, ty, tz = 2 * x, 2 * y, 2 * z
    twx, twy, twz = tx * w, ty * w, tz * w
    txx, txy, txz = tx * x, ty * x, tz * x
    tyy, tyz, tzz = ty * y, tz * y, tz * z
    one = torch.ones_like(w)
    # Row-major entries of the 3x3 matrix.
    flat = (
        one - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, one - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, one - (txx + tyy),
    )
    return torch.stack(flat, dim=-1).reshape(quat.shape[:-1] + (3, 3))
State dict uses inner-module prefixes - `detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel + (128², ≤2m), 'full' (192² FPN, ≤5m), or 'both' — which holds both detectors + alongside shared mesh/blendshapes so detect_batch can run them and per-frame + keep whichever found more faces (tie → short). State dict uses inner-module + prefixes `detector.*` (short or single-variant), `detector_full.*` (only + under 'both'), `mesh.*`, `blendshapes.*`; the outer FaceLandmarkerModel wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading. """ def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"): super().__init__() - det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant) - self.detector_variant = detector_variant - self.detector = det_cls(device=device, dtype=dtype, operations=operations) + if detector_variant == "both": + self.detector = BlazeFace(device=device, dtype=dtype, operations=operations) + self.detector_full = BlazeFaceFullRange(device=device, dtype=dtype, operations=operations) + else: + det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}[detector_variant] + self.detector = det_cls(device=device, dtype=dtype, operations=operations) self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations) self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations) self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False) + def _detector(self, variant: str) -> nn.Module: + if self.detector_variant == "both": + return self.detector_full if variant == "full" else self.detector + return self.detector + def run_detector_batch(self, images_rgb_uint8: List[np.ndarray], score_thresh: float = _BF_MIN_SCORE, - iou_thresh: float = 0.5): + iou_thresh: float = 0.5, + variant: Optional[str] = None): """Batched detector pass. 
Returns (img_raws, sub_rects, sizes, per_frame_decoded) - where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.""" + where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords. + `variant` overrides per-call (required on 'both'-mode instances).""" if not images_rgb_uint8: return [], [], [], [] - device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype + if variant is None: + variant = "short" if self.detector_variant == "both" else self.detector_variant + detector = self._detector(variant) + device, dtype = detector.stem.weight.device, detector.stem.weight.dtype det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range) - if self.detector_variant == "full" + if variant == "full" else (_BF_INPUT_SIZE, _decode_blazeface)) # Same-size frames: stack once and transfer once. Variable size falls back @@ -598,7 +614,7 @@ class FaceLandmarker(nn.Module): det_crops = [w[0] for w in warps] sub_rects = [(w[1], w[2], w[3]) for w in warps] - regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0)) + regs_b, cls_b = detector(torch.stack(det_crops, dim=0)) regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy() per_frame = [] for b in range(len(images_rgb_uint8)): @@ -607,15 +623,22 @@ class FaceLandmarker(nn.Module): return img_raws, sub_rects, sizes, per_frame def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1, - score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]: + score_thresh: float = _BF_MIN_SCORE, + variant: Optional[str] = None) -> List[List[dict]]: """Full pipeline batched across `images_rgb_uint8`. Returns one face-dict list per image (empty if nothing detected). Face dict: bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1], landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in 192-canonical (pre-transformation) units, presence float (raw logit). 
+ On 'both'-mode instances `variant=None` runs both detectors and keeps + whichever found more faces per frame (tie → short). """ + if variant is None and self.detector_variant == "both": + s = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="short") + f = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="full") + return [s[b] if len(s[b]) >= len(f[b]) else f[b] for b in range(len(images_rgb_uint8))] img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch( - images_rgb_uint8, score_thresh=score_thresh, + images_rgb_uint8, score_thresh=score_thresh, variant=variant, ) # tensor-normalized → image-normalized [0,1] for _detection_to_face_rect. for b, decoded in enumerate(per_frame_dets): diff --git a/comfy_extras/nodes_sam3d_body.py b/comfy_extras/nodes_sam3d_body.py new file mode 100644 index 000000000..afa6855a5 --- /dev/null +++ b/comfy_extras/nodes_sam3d_body.py @@ -0,0 +1,1014 @@ +"""SAM 3D Body — Predict + Smooth nodes and their inference helpers.""" + +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from tqdm import tqdm +from scipy.signal import savgol_coeffs + +import comfy.model_management +import comfy.utils +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override +import folder_paths + +from comfy.ldm.sam3d_body.model.model import SAM3DBody +from comfy.ldm.sam3d_body.model.dinov3 import apply_dinov3_qkv_bias_mask +from comfy_extras.sam3d_body.utils import ( + cam_int_from_fov, + cam_int_from_moge, + inputs_from_sam3_track, + run_batched_frames, + run_batched_single_chunk, + compute_canonical_colors, + compute_hand_vert_mask, + ) + +from comfy_extras.sam3d_body.rasterizer import render_pose_data_torch as render_pose_data +from comfy_extras.sam3d_body.export.capsules import render_pose_data_capsules +from comfy_extras.sam3d_body.export.openpose_2d import render_pose_data_openpose 
from comfy_extras.sam3d_body import face_expression as fx
from comfy_extras.sam3d_body.utils import image_to_uint8


SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# MHRPoseData = SAM3DBody_Predict's native output (carries mhr_model_params,
# shape_params, expr_params, MHR70 keypoint layout, canonical_colors keyed to
# MHR mesh, hand_vert_mask from MHR LBS). The export-side consumers
# (BuildPoseGLB / SavePoseBVH in comfy_extras/nodes_save_3d.py) also accept
# KIMODO_POSE_DATA via a MultiType union — those types are mirrored there.
MHRPoseData = io.Custom("MHR_POSE_DATA")
SAM3DBodyModel = io.Custom("SAM3D_BODY_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")

# Loader

class SAM3DBody_Loader(io.ComfyNode):
    """Node that loads SAM 3D Body weights and returns them wrapped in a model patcher."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one combo listing files of the 'detection' folder, one model output."""
        return io.Schema(
            node_id="SAM3DBody_Loader",
            display_name="Load SAM3D Body Model",
            category="image/detection/sam3dbody/", #TODO: better category?
            inputs=[
                io.Combo.Input(
                    "model_file",
                    options=folder_paths.get_filename_list("detection"),
                    tooltip="SAM 3D Body weights (.safetensors) in the 'detection' folder",
                ),
            ],
            outputs=[SAM3DBodyModel.Output("model", display_name="sam3d_body_model")],
        )


    @classmethod
    def execute(cls, model_file) -> io.NodeOutput:
        """Load ``model_file`` from the 'detection' folder and build the patched model.

        Args:
            model_file: file name as offered by the ``model_file`` combo.

        Returns:
            ``io.NodeOutput`` carrying a ``CoreModelPatcher`` around the
            ``SAM3DBody`` module.
        """
        path = folder_paths.get_full_path_or_raise("detection", model_file)
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        # Rewrite checkpoint key names: every ".layers.0.0." segment becomes ".layers.0.".
        sd = {k.replace(".layers.0.0.", ".layers.0."): v for k, v in sd.items()}

        # Derive the compute dtype and the matching operations set from the
        # dtype of the loaded weights and the target device.
        load_device = comfy.model_management.get_torch_device()
        weight_dtype = comfy.utils.weight_dtype(sd)
        torch_dtype = comfy.model_management.unet_dtype(
            device=load_device, model_params=-1, weight_dtype=weight_dtype,
        )
        manual_cast_dtype = comfy.model_management.unet_manual_cast(torch_dtype, load_device)
        operations = comfy.ops.pick_operations(torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True)

        model = SAM3DBody(dtype=torch_dtype, operations=operations)
        # strict=False: keys missing from / unexpected in the checkpoint do not raise.
        model.load_state_dict(sd, strict=False)

        # NOTE(review): applied after the weights are loaded; what the mask
        # does is defined in dinov3.py and is not visible here.
        apply_dinov3_qkv_bias_mask(model.backbone)

        model.eval()
        model.backbone_dtype = torch_dtype
        # Stashed on the module; SAM3DBody_Predict.execute reads these back
        # with getattr(inner, "_sam3d_...", default).
        model._sam3d_image_size = model.image_size

        model._sam3d_canonical_colors = compute_canonical_colors(model)
        model._sam3d_hand_vert_mask = compute_hand_vert_mask(model)

        patcher = comfy.model_patcher.CoreModelPatcher(
            model,
            load_device=load_device,
            offload_device=comfy.model_management.unet_offload_device(),
            size=comfy.model_management.module_size(model),
        )
        return io.NodeOutput(patcher)
+ ), + ), + ], + outputs=[MHRPoseData.Output("mhr_pose_data")], + ) + + @classmethod + def execute(cls, sam3d_body_model, image, sam3_track_data=None, run_hand_refinement=True, fov_degrees=0.0, moge_geometry=None, chunk_size=144) -> io.NodeOutput: + comfy.model_management.load_model_gpu(sam3d_body_model) + inner: SAM3DBody = sam3d_body_model.model + + B, H, W, _ = image.shape + image_size = getattr(inner, "_sam3d_image_size", (512, 512)) + + per_frame_bboxes, per_frame_masks = (None, None) + if sam3_track_data is not None: + per_frame_bboxes, per_frame_masks = inputs_from_sam3_track(sam3_track_data, B, H, W) + if per_frame_bboxes is None: + # No track wired (or empty / frame count mismatch) — single-person + # full-frame fallback. Multi-person scenes need SAM3 Video Track. + full_frame_bbox = torch.tensor([[0.0, 0.0, float(W), float(H)]], dtype=torch.float32) + per_frame_bboxes = [full_frame_bbox.clone() for _ in range(B)] + per_frame_masks = None + inference_type = "full" if run_hand_refinement else "body" + # Precedence: explicit fov_degrees > MoGe estimate > diagonal default. + cam_int = cam_int_from_fov(int(H), int(W), float(fov_degrees)) + if cam_int is None: + cam_int = cam_int_from_moge(moge_geometry, int(H), int(W)) + + frames_rgb: List[Optional[torch.Tensor]] = [] + for f in range(B): + if per_frame_bboxes[f].shape[0] == 0: + frames_rgb.append(None) + else: + frames_rgb.append(image_to_uint8(image[f])) + + # Batched path requires uniform non-zero K across all frames. 
+ bbox_counts = {per_frame_bboxes[f].shape[0] for f in range(B) if frames_rgb[f] is not None} + can_batch = ( + len(bbox_counts) == 1 + and 0 not in bbox_counts + and all(frames_rgb[f] is not None for f in range(B)) + ) + + frames_out: List[List[Dict[str, Any]]] = [] + pbar = comfy.utils.ProgressBar(B) + + if can_batch and B > 0: + frames_out = run_batched_frames( + inner, frames_rgb, per_frame_bboxes, per_frame_masks, + image_size, inference_type, + cam_int=cam_int, + pbar=pbar, + crops_per_chunk=int(chunk_size), + ) + else: + # Mixed K per frame — call the batched path once per frame. + for f in range(B): + if frames_rgb[f] is None or per_frame_bboxes[f].shape[0] == 0: + frames_out.append([]) + pbar.update(1) + continue + mask_f = [per_frame_masks[f]] if per_frame_masks is not None else None + chunk = run_batched_single_chunk( + inner, [frames_rgb[f]], [per_frame_bboxes[f]], mask_f, + image_size, inference_type, + K=int(per_frame_bboxes[f].shape[0]), + cam_int=cam_int, + ) + frames_out.append(chunk[0]) + pbar.update(1) + + mhr_pose_data = { + "frames": frames_out, + "faces": inner.head_pose.faces.cpu().numpy(), + "image_size": (int(H), int(W)), + "canonical_colors": getattr(inner, "_sam3d_canonical_colors", None), + "hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None), + } + return io.NodeOutput(mhr_pose_data) + + +class SAM3DBody_FaceExpression(io.ComfyNode): + """Drive MHR face blendshapes from the core MediaPipe Face Landmarker. + + Detects per-frame faces, IoU-matches each to a tracked person, maps the 52 + ARKit blendshapes onto MHR's 72-axis `expr_params`, and re-runs MHR forward + so pred_vertices/pred_keypoints reflect the new expression. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3DBody_FaceExpression", + description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.", + display_name="Face Expression to SAM3D Body", #TODO: better name? 
+ category="image/detection/sam3dbody/", + inputs=[ + MHRPoseData.Input("mhr_pose_data"), + SAM3DBodyModel.Input("sam3d_body_model"), + io.Image.Input("image"), + io.Float.Input( + "strength", default=1.0, min=0.0, max=4.0, step=0.05, + tooltip="Global multiplier on all blendshapes. >1 exaggerates.", + ), + io.Float.Input( + "mouth_strength", default=1.0, min=0.0, max=4.0, step=0.05, + tooltip="Multiplier on mouth/jaw shapes. MP's jawOpen saturates near 1.0.", + advanced=True, + ), + io.Float.Input( + "eye_strength", default=2.0, min=0.0, max=4.0, step=0.05, + tooltip="Multiplier on eye shapes. MP rarely exceeds 0.5; 2-3x often needed.", + advanced=True, + ), + io.Float.Input( + "brow_strength", default=2.0, min=0.0, max=4.0, step=0.05, + tooltip="Multiplier on brow/cheek/sneer shapes. MP outputs ~0.1-0.3; 2-3x.", + advanced=True, + ), + io.Float.Input( + "input_threshold", default=0.02, min=0.0, max=0.5, step=0.01, + tooltip=( + "Deadzone on MediaPipe's raw output (below = zero, above = linear remap). " + ), + advanced=True, + ), + io.Int.Input( + "blendshape_smooth_window", default=7, min=1, max=31, step=2, + tooltip=( + "Gaussian window on MediaPipe's per-frame signal before MHR mapping. " + "MediaPipe's raw output swings 30-70% frame-to-frame on static faces. " + "1 = disabled. Use odd values." 
+ ), + advanced=True, + ), + ], + outputs=[ + MHRPoseData.Output("mhr_pose_data"), + ], + ) + + @classmethod + def execute(cls, mhr_pose_data, sam3d_body_model, image, + strength=1.0, mouth_strength=1.0, eye_strength=1.0, brow_strength=1.0, + input_threshold=0.15, blendshape_smooth_window=7) -> io.NodeOutput: + + comfy.model_management.load_model_gpu(sam3d_body_model) + inner: SAM3DBody = sam3d_body_model.model + + frames = mhr_pose_data["frames"] + B = len(frames) + if B == 0: + return io.NodeOutput(mhr_pose_data) + + img_np = (image * 255.0).clamp(0.0, 255.0).to(torch.uint8).cpu().numpy() + new_frames: List[List[Dict[str, Any]]] = [[dict(p) for p in f] for f in frames] + + max_persons = max((len(f) for f in new_frames), default=0) + per_person_coefs: List[List[Optional[Dict[str, float]]]] = [ + [None] * B for _ in range(max_persons) + ] + pbar = comfy.utils.ProgressBar(B) + n_total_frames_with_persons = 0 + + crop_factor = 1.2 + + # Auto-pick full-frame vs per-person crops. BlazeFace full-range needs + # ≥32px face in its 192px input; below that we escalate to per-person + # crops. Face height ≈ 20% of body-bbox height (rough but stable). + H_img0, W_img0 = img_np.shape[1], img_np.shape[2] + min_bbox_px = 32.0 * max(H_img0, W_img0) / (192.0 * 0.20) + use_per_person_crops = any( + (p["bbox"][3] - p["bbox"][1]) < min_bbox_px + for persons in new_frames for p in persons + ) + + for fi in tqdm(range(B), desc="SAM3D face expression detect"): + persons = new_frames[fi] + img_fi = img_np[min(fi, img_np.shape[0] - 1)] + + if not persons: + pbar.update(1) + continue + n_total_frames_with_persons += 1 + + person_bboxes = [np.asarray(p["bbox"], dtype=np.float32) for p in persons] + H_img, W_img = img_fi.shape[:2] + + if use_per_person_crops: + # One MP call per person on a tight head crop — recovers small/ + # distant faces that the full-frame 192px BlazeFace would miss. 
+ for pid, pb in enumerate(person_bboxes): + if pid >= max_persons: + continue + cr = fx.head_crop_from_keypoints( + persons[pid].get("pred_keypoints_2d"), crop_factor, W_img, H_img, + ) + if cr is None: + cr = fx.head_region_crop(pb, crop_factor, W_img, H_img) + faces = fx.detect_faces_in_crop(inner, img_fi, cr, num_faces=1) + if not faces: + continue + # Pick face closest to person bbox center when a neighbor leaks in. + pcx, pcy = 0.5 * (pb[0] + pb[2]), 0.5 * (pb[1] + pb[3]) + best = min( + faces, + key=lambda f: (0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]) - pcx) ** 2 + + (0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3]) - pcy) ** 2, + ) + per_person_coefs[pid][fi] = best["blendshapes"] + else: + faces = inner.face_landmarker.detect_batch([img_fi], num_faces=max(1, len(persons)))[0] + if faces: + face_bboxes = [f["bbox_xyxy"] for f in faces] + assignment = fx.assign_faces_to_persons(face_bboxes, person_bboxes) + for pid, face_idx in enumerate(assignment): + if face_idx is None or pid >= max_persons: + continue + per_person_coefs[pid][fi] = faces[face_idx]["blendshapes"] + + pbar.update(1) + + # Baseline subtraction. MP has subject-specific rest bias (e.g. + # naturally-raised brow at 0.15); without subtraction, strength + # multipliers bake that into every frame. Per-clip needs ~30 frames + # or it would zero out the expression. + BASELINE_MIN_FRAMES = 30 + if n_total_frames_with_persons >= BASELINE_MIN_FRAMES: + for pid in range(max_persons): + per_person_coefs[pid] = fx.subtract_per_clip_baseline( + per_person_coefs[pid], percentile=5.0, + ) + else: + logging.warning( + f"[SAM 3D Body FaceExpression] per-clip baseline subtraction " + f"needs ~{BASELINE_MIN_FRAMES}+ frames with detections; " + f"got {n_total_frames_with_persons}. Skipping subtraction." + ) + + # Smooth raw signal AFTER baseline subtraction but BEFORE gap fill — + # MP's per-frame noise gets averaged out at the source. 
class SAM3DBody_Smooth(io.ComfyNode):
    """Reduce frame-to-frame jitter via vertex-space temporal averaging.
    Backs off on mesh-geometry keys when the subject rotates fast (averaging
    across a spin flattens the mesh); camera-space keys still get full
    smoothing.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node inputs (pose data, strength, method, window, rotation threshold)."""
        return io.Schema(
            node_id="SAM3DBody_Smooth",
            description="Reduce frame-to-frame jitter via vertex-space temporal averaging",
            display_name="Smooth SAM3D Body Pose Frames",
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                io.Float.Input(
                    "strength",
                    default=1.0, min=0.0, max=1.0, step=0.05,
                    tooltip="Blend raw (0) → smoothed (1).",
                ),
                io.Combo.Input(
                    "method",
                    options=["gaussian", "savgol"],
                    default="gaussian", advanced=True,
                    tooltip=(
                        "'gaussian': symmetric weighted average — phase-preserving "
                        "(no time-shift), best general-purpose smoother. "
                        "'savgol': sliding polynomial fit — preserves sharp peaks "
                    ),
                ),
                io.Int.Input(
                    "window",
                    default=7, min=1, max=51, step=2, advanced=True,
                    tooltip="Temporal window in frames (odd values).",
                ),
                io.Float.Input(
                    "rotation_threshold_deg",
                    default=15.0, min=0.0, max=45.0, step=1.0, advanced=True,
                    tooltip=(
                        "Geometry smoothing drops to RAW above this root-rotation "
                        "rate (deg/frame) to preserve fast spins. 15° suits most "
                        "content; low values trigger on ordinary jitter and "
                        "silently sabotage smoothing. 0 = disable backoff."
                    ),
                ),
            ],
            outputs=[MHRPoseData.Output("mhr_pose_data")],
        )

    @classmethod
    def execute(cls, mhr_pose_data, method, window, strength, rotation_threshold_deg) -> io.NodeOutput:
        """Temporally filter every numeric per-person key of ``mhr_pose_data``.

        Persons are matched across frames by their list position. For each
        person and key the values are stacked over time (gaps are filled with
        the nearest observed value), filtered with the chosen kernel and
        blended with the raw series. The input dict and its person dicts are
        not modified; a shallow copy with new ``"frames"`` is returned.

        Returns the input unchanged when ``strength <= 0``, ``window <= 1``,
        there are fewer than two frames, or no frame contains a person.
        """
        if strength <= 0.0 or window <= 1:
            return io.NodeOutput(mhr_pose_data)

        frames = mhr_pose_data["frames"]
        B = len(frames)
        if B < 2:
            return io.NodeOutput(mhr_pose_data)
        max_p = max((len(f) for f in frames), default=0)
        if max_p == 0:
            return io.NodeOutput(mhr_pose_data)

        # Geometry keys rotate with the subject, so linear averaging during
        # fast spins flattens the mesh — these get per-frame adaptive strength.
        keys_geom = {
            "pred_vertices", "pred_keypoints_3d", "pred_joint_coords",
            "pred_global_rots", "mhr_model_params", "body_pose_params",
            "global_rot", "pred_pose_raw",
        }
        # Camera / appearance / 2D keys are safe to smooth linearly.
        keys_cam = {
            "pred_cam_t", "pred_keypoints_2d", "focal_length",
            "shape_params", "scale_params", "hand_pose_params", "expr_params",
        }
        all_keys = sorted(keys_geom | keys_cam)

        kernel = _smoothing_kernel(method, window)
        smoothed = [list(f) for f in frames]

        base_blend = float(strength)
        rot_thresh = float(np.deg2rad(max(0.0, rotation_threshold_deg)))

        for pid in range(max_p):
            valid = np.array([pid < len(f) for f in frames], dtype=bool)
            if valid.sum() < 2:
                continue

            # Adaptive blend per frame from `global_rot` (euler ZYX).
            geom_blend = np.full(B, base_blend, dtype=np.float32)
            if rot_thresh > 0.0:
                ang = _root_rotation_rate(frames, pid, valid)
                # Peak-smear over ±window/2 — neighbors of a rotated frame
                # must back off too, or the temporal average pulls the rotated
                # pose in and flattens the mesh anyway.
                half = max(1, window // 2)
                ang_smooth = np.zeros_like(ang)
                for fi in range(B):
                    lo = max(0, fi - half)
                    hi = min(B, fi + half + 1)
                    ang_smooth[fi] = ang[lo:hi].max()
                # base_blend at no rotation, 0 at threshold.
                ratio = np.clip(ang_smooth / rot_thresh, 0.0, 1.0)
                geom_blend = base_blend * (1.0 - ratio)

            # Copy each touched person dict once so the caller's dicts stay intact.
            for fi in range(B):
                if valid[fi]:
                    smoothed[fi][pid] = dict(smoothed[fi][pid])

            for key in all_keys:
                present = [
                    bool(valid[fi]) and frames[fi][pid].get(key) is not None
                    for fi in range(B)
                ]
                if not any(present):
                    continue
                first = present.index(True)
                ref = np.asarray(frames[first][pid][key])
                if ref.dtype.kind not in "fiu":
                    continue
                stacked = np.zeros((B,) + ref.shape, dtype=np.float32)
                for fi in range(B):
                    if present[fi]:
                        stacked[fi] = np.asarray(frames[fi][pid][key], dtype=np.float32)
                    elif fi > 0:
                        # Hold the last value across gaps.
                        stacked[fi] = stacked[fi - 1]
                # Back-fill frames before the first observation. Leaving them
                # at zero would let the filter drag the person's first valid
                # frames toward the origin.
                stacked[:first] = stacked[first]
                filtered = _apply_temporal_filter(stacked, kernel)

                if key in keys_geom:
                    b = geom_blend
                    while b.ndim < stacked.ndim:
                        b = b[..., None]
                    out = (1.0 - b) * stacked + b * filtered
                else:
                    out = (1.0 - base_blend) * stacked + base_blend * filtered

                for fi in range(B):
                    if valid[fi]:
                        smoothed[fi][pid][key] = out[fi].astype(ref.dtype)

        new_pose = dict(mhr_pose_data)
        new_pose["frames"] = smoothed
        return io.NodeOutput(new_pose)


def _root_rotation_rate(frames, pid: int, valid: np.ndarray) -> np.ndarray:
    """Per-frame root rotation angle (radians) of person ``pid`` versus the previous frame.

    Entry ``fi`` stays 0 when either frame ``fi`` or ``fi - 1`` lacks the
    person or its ``global_rot``.
    """
    B = len(frames)
    rotmats = [None] * B
    for fi in range(B):
        if not valid[fi]:
            continue
        gr = frames[fi][pid].get("global_rot")
        if gr is None:
            continue
        eul = np.asarray(gr, dtype=np.float32).reshape(3)
        # ZYX convention: R = Rz @ Ry @ Rx
        cz, sz = np.cos(eul[0]), np.sin(eul[0])
        cy, sy = np.cos(eul[1]), np.sin(eul[1])
        cx, sx = np.cos(eul[2]), np.sin(eul[2])
        Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=np.float32)
        Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float32)
        Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=np.float32)
        rotmats[fi] = Rz @ Ry @ Rx
    ang = np.zeros(B, dtype=np.float32)
    for fi in range(1, B):
        cur, prev = rotmats[fi], rotmats[fi - 1]
        if cur is None or prev is None:
            continue
        # Angle of the relative rotation, recovered from its trace.
        R_delta = cur @ prev.T
        cos_a = float(np.clip((np.trace(R_delta) - 1.0) / 2.0, -1.0, 1.0))
        ang[fi] = abs(np.arccos(cos_a))
    return ang


def _smoothing_kernel(method: str, window: int) -> np.ndarray:
    """Return a float32 1-D kernel of odd length for ``method``.

    ``window`` is clamped to at least 1 and bumped to the next odd number.
    "savgol" yields Savitzky-Golay coefficients (order 3 for windows >= 5);
    any other value yields a normalized Gaussian.
    """
    window = max(1, int(window))
    if window % 2 == 0:
        window += 1
    if method == "savgol":
        order = 3 if window >= 5 else min(window - 1, 1)
        return savgol_coeffs(window, order).astype(np.float32)
    # gaussian (default)
    sigma = max(1.0, window / 5.0)
    x = np.arange(window) - (window - 1) / 2.0
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return (k / k.sum()).astype(np.float32)


def _apply_temporal_filter(stacked: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """stacked: (B, *feature_shape). Returns same shape, smoothed over axis 0.

    The series is edge-padded by repeating its first and last frame, so the
    output keeps all B frames.
    """
    B = stacked.shape[0]
    w = len(kernel)
    pad = w // 2
    flat = stacked.reshape(B, -1)  # (B, K)
    padded = np.concatenate(
        [np.repeat(flat[:1], pad, axis=0), flat, np.repeat(flat[-1:], pad, axis=0)],
        axis=0,
    )  # (B + 2*pad, K)
    out = np.zeros_like(flat)
    # Accumulate one kernel tap at a time over a shifted window of the padded series.
    for i, k in enumerate(kernel):
        out += k * padded[i : i + B]
    return out.reshape(stacked.shape)


# Render

def rainbow_tilt_inputs():
    """Shared rainbow-shader tilt inputs (used by Render and ToGLB schemas)."""
    return [
        io.Float.Input(
            "rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5,
            tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
        ),
        io.Float.Input(
            "rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5,
            tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
        ),
    ]
def _render_capsules_mode_inputs():
    """Widgets for the capsule render style (the 'scail' option of ``render_style``).

    Returns the inputs in display order: capsule radius, hand overlay style,
    face overlay style, per-person palette falloff.
    """
    capsule_radius = io.Float.Input(
        "radius_m", default=0.022, min=0.005, max=0.2, step=0.001,
        tooltip="Capsule radius in meters (SCAIL reference: ~0.022 m).",
    )
    hand_overlay = io.Combo.Input(
        "hand_style",
        options=["disabled", "dwpose", "openpose"],
        default="dwpose",
        tooltip=(
            "Composite 2D OpenPose hands on top of the 3D capsule body "
            "(matches SCAIL — no 3D hand capsules). 'disabled' = no hands. "
            "'dwpose' = solid-blue hand dots; 'openpose' = rainbow dots. "
            "Sticks stay rainbow per-finger either way."
        ),
    )
    face_overlay = io.Combo.Input(
        "face_style",
        options=["disabled", "full", "eyes_mouth"],
        default="disabled",
        tooltip=(
            "'full' = all face landmarks (sapiens-238 if present, else "
            "rig-fallback ~30). 'eyes_mouth' = rig-fallback subset (~12 "
            "dots: eyes + outer lips only). 'disabled' = no face dots."
        ),
    )
    palette_falloff = io.Float.Input(
        "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
        tooltip=(
            "Per-person desaturation: track k blends toward white by "
            "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
        ),
    )
    return [capsule_radius, hand_overlay, face_overlay, palette_falloff]
def _scale_pose_data(mhr_pose_data: Dict[str, Any], new_H: int, new_W: int) -> Dict[str, Any]:
    """Return pose data re-targeted to a ``new_H`` x ``new_W`` canvas.

    Focal lengths are stored in pixels of the original image, so they are
    multiplied by one uniform factor (the smaller of the two axis ratios) to
    keep the field of view. Pre-stored 2D coordinates and bboxes receive the
    same factor plus a centering offset, which places them in the letterboxed
    region where the 3D-projected body lands when the aspect ratio changes.
    Scaling x and y independently would pull overlays away from the body.

    The input is returned as-is when the size is unchanged; otherwise a
    shallow copy with new ``"frames"`` and ``"image_size"`` is returned and
    the input is left untouched.
    """
    src_H, src_W = mhr_pose_data["image_size"]
    if new_H == src_H and new_W == src_W:
        return mhr_pose_data

    scale = min(new_W / src_W, new_H / src_H)
    shift_x = (new_W - scale * src_W) * 0.5
    shift_y = (new_H - scale * src_H) * 0.5

    def _rescale_person(person: Dict[str, Any]) -> Dict[str, Any]:
        # Work on a copy so the caller's person dict keeps its values.
        person = dict(person)
        focal = person.get("focal_length")
        if focal is not None:
            person["focal_length"] = np.asarray(focal, dtype=np.float32) * scale
        for name in ("pred_keypoints_2d", "pred_face_keypoints_2d"):
            points = person.get(name)
            if points is None:
                continue
            points = np.asarray(points, dtype=np.float32).copy()
            points[..., 0] = points[..., 0] * scale + shift_x
            points[..., 1] = points[..., 1] * scale + shift_y
            person[name] = points
        box = person.get("bbox")
        if box is not None:
            box = np.asarray(box, dtype=np.float32).copy()
            box[..., [0, 2]] = box[..., [0, 2]] * scale + shift_x
            box[..., [1, 3]] = box[..., [1, 3]] * scale + shift_y
            person["bbox"] = box
        return person

    rescaled_frames: List[List[Dict[str, Any]]] = [
        [_rescale_person(person) for person in frame]
        for frame in mhr_pose_data["frames"]
    ]
    result = dict(mhr_pose_data)
    result["image_size"] = (new_H, new_W)
    result["frames"] = rescaled_frames
    return result
+ ), + ), + io.Int.Input( + "height", default=0, min=0, max=16384, step=8, + tooltip=( + "Output height in pixels. 0 = use pose data's native " + "image_size. If only one of width/height is set, the " + "other is derived preserving the original aspect." + ), + ), + io.DynamicCombo.Input( + "render_style", + options=[ + io.DynamicCombo.Option("mesh", _render_mesh_mode_inputs()), + io.DynamicCombo.Option("silhouette", []), + io.DynamicCombo.Option("openpose", _render_openpose_mode_inputs()), + io.DynamicCombo.Option("scail", _render_capsules_mode_inputs()), + ], + tooltip=( + "'mesh' = 3D MHR mesh rasterized through the camera. " + "'silhouette' = binary mask of the mesh (white-on-black, " + "background ignored). 'openpose' = flat 2D skeleton " + "from pred_keypoints_2d (DWPose look). 'scail' = SCAIL " + "3D capsules via torch SDF ray-march (proper occlusion / depth)." + ), + ), + ], + outputs=[io.Image.Output("image")], + ) + + + @classmethod + def execute(cls, mhr_pose_data, background=None, width=0, height=0, render_style=None) -> io.NodeOutput: + render_style = render_style or {"render_style": "mesh"} + mode_key = render_style.get("render_style", "mesh") + + native_H, native_W = mhr_pose_data["image_size"] + new_W, new_H = int(width), int(height) + if new_W == 0 and new_H == 0: + H, W = native_H, native_W + px_scale = 1.0 + else: + if new_W == 0: + new_W = max(1, round(native_W * new_H / native_H)) + elif new_H == 0: + new_H = max(1, round(native_H * new_W / native_W)) + mhr_pose_data = _scale_pose_data(mhr_pose_data, new_H, new_W) + H, W = new_H, new_W + # Marker/stick px constants are authored for native resolution — + # scale them so the openpose overlay reads at the same relative size. 
+ px_scale = min(new_W / native_W, new_H / native_H) + + B = len(mhr_pose_data["frames"]) + if B == 0: + return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32)) + + out_device = comfy.model_management.intermediate_device() + bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32) + + + if mode_key == "silhouette": + composite = "silhouette" + elif bg_t is not None: + composite = "over" + else: + composite = "mesh_only" + + if mode_key == "openpose": + marker_radius_px = max(1, int(round(render_style.get("marker_radius_px", 4) * px_scale))) + stick_width_px = max(1, int(round(render_style.get("stick_width_px", 4) * px_scale))) + limb_alpha = float(render_style.get("limb_alpha", 0.6)) + face_style = str(render_style.get("face_style", "disabled")) + hand_style = str(render_style.get("hand_style", "disabled")) + include_hands = hand_style != "disabled" + hand_color_style = hand_style if include_hands else "dwpose" + person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6)) + elif mode_key == "scail": + cap_radius_m = float(render_style.get("radius_m", 0.030)) + cap_hand_style = str(render_style.get("hand_style", "disabled")) + cap_include_hands = cap_hand_style != "disabled" + cap_hand_color_style = cap_hand_style if cap_include_hands else "dwpose" + cap_face_style = str(render_style.get("face_style", "disabled")) + person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6)) + elif mode_key == "mesh": + shader_dict = render_style.get("shader") or {} + shader_key = shader_dict.get("shader", "default") + rainbow_tilt_x = float(shader_dict.get("rainbow_tilt_x", 0.0)) + rainbow_tilt_z = float(shader_dict.get("rainbow_tilt_z", -35.0)) + opacity = float(render_style.get("opacity", 1.0)) + person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6)) + region = str(render_style.get("region", "full_body")) + + if region == "hands_only": + hand_mask = 
mhr_pose_data["hand_vert_mask"] + faces_full = np.asarray(mhr_pose_data["faces"]) + keep = hand_mask[faces_full].all(axis=1) + mhr_pose_data = dict(mhr_pose_data) + mhr_pose_data["faces"] = np.ascontiguousarray( + faces_full[keep], dtype=faces_full.dtype, + ) + else: # silhouette — no shader/opacity controls, mask is binary + shader_key = "default" + rainbow_tilt_x = 0.0 + rainbow_tilt_z = -35.0 + opacity = 1.0 + person_palette_falloff = 0.6 + + frames_out = [] + pbar = comfy.utils.ProgressBar(B) + desc = ( + "SAM3D openpose-2D render" if mode_key == "openpose" + else "SAM3D SCAIL-3D render" if mode_key == "scail" + else "SAM3D silhouette" if mode_key == "silhouette" + else "SAM3D render" + ) + for f in tqdm(range(B), desc=desc): + bg_f = None + if bg_t is not None: + bg_f = bg_t[min(f, bg_t.shape[0] - 1)] + if mode_key == "openpose": + img = render_pose_data_openpose( + mhr_pose_data, frame_idx=f, W=W, H=H, + background=bg_f, + composite=composite, + marker_radius_px=marker_radius_px, + stick_width_px=stick_width_px, + limb_alpha=limb_alpha, + include_hands=include_hands, + face_style=face_style, + hand_color_style=hand_color_style, + person_brightness_falloff=person_palette_falloff, + ) + elif mode_key == "scail": + # SCAIL renders body as 3D capsules + 2D openpose hands on top + img = render_pose_data_capsules( + mhr_pose_data, frame_idx=f, W=W, H=H, + background=bg_f, + composite=composite, + radius_m=cap_radius_m, + include_hands=False, + palette="scail", + person_brightness_falloff=person_palette_falloff, + ) + if cap_include_hands or cap_face_style != "disabled": + scail_overlay_px = max(1, int(round(4 * px_scale))) + scail_face_px = max(1, int(round(1 * px_scale))) + img = render_pose_data_openpose( + mhr_pose_data, frame_idx=f, W=W, H=H, + background=img, + composite="over", + include_body=False, + include_hands=cap_include_hands, + face_style=cap_face_style, + marker_radius_px=scail_overlay_px, + stick_width_px=scail_overlay_px, + 
face_marker_radius_px=scail_face_px, + hand_color_style=cap_hand_color_style, + person_brightness_falloff=person_palette_falloff, + ) + else: + img = render_pose_data( + mhr_pose_data, frame_idx=f, W=W, H=H, + background=bg_f, composite=composite, opacity=opacity, + shader_preset=shader_key, + rainbow_tilt_x_deg=rainbow_tilt_x, + rainbow_tilt_z_deg=rainbow_tilt_z, + person_brightness_falloff=person_palette_falloff, + ) + frames_out.append(img) + pbar.update(1) + + out_image = torch.stack(frames_out, dim=0) + if out_image.device != out_device: + out_image = out_image.to(out_device) + return io.NodeOutput(out_image) + + +class SAM3DBodyExtension(ComfyExtension): + @override + async def get_node_list(self) -> List[type[io.ComfyNode]]: + return [ + SAM3DBody_Loader, + SAM3DBody_Predict, + SAM3DBody_FaceExpression, + SAM3DBody_Smooth, + SAM3DBody_Render, + ] + + +async def comfy_entrypoint() -> SAM3DBodyExtension: + return SAM3DBodyExtension() diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index c03524246..f0a57a2f1 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -1,4 +1,6 @@ -"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node.""" +"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB +node, plus pose-data exporters (BuildPoseGLB / SavePoseBVH) that accept either +SAM3DBody Predict's MHR pose data or external-rig pose data from Kimodo.""" import json import logging @@ -15,6 +17,15 @@ import folder_paths from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.sam3d_body.export.bvh import build_bvh +from comfy_extras.sam3d_body.export.glb_openpose import build_glb_openpose +from comfy_extras.sam3d_body.export.glb_skeletal import build_glb_skeletal + + +MHRPoseData = IO.Custom("MHR_POSE_DATA") +KimodoPoseData = IO.Custom("KIMODO_POSE_DATA") +SAM3DBodyModel = IO.Custom("SAM3D_BODY_MODEL") + def 
pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, @@ -386,10 +397,437 @@ class SaveGLB(IO.ComfyNode): return IO.NodeOutput(ui={"3d": results}) +def rainbow_tilt_inputs(): + """Shared rainbow-shader tilt inputs (used by Render and ToGLB schemas).""" + return [ + IO.Float.Input( + "rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5, + tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.", + ), + IO.Float.Input( + "rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5, + tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.", + ), + ] + + +class BuildPoseGLB(IO.ComfyNode): + """Convert pose_data to an in-memory animated GLB""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BuildPoseGLB", + display_name="Build Pose GLB", + description="Convert pose data to an animated GLB", + category="3d", + inputs=[ + IO.MultiType.Input( + "pose_data", types=[MHRPoseData, KimodoPoseData], + tooltip=( + "MHR pose data from SAM3DBody_Predict, or external-rig " + "pose data from Kimodo (`_skeleton_override`-augmented)." + ), + ), + SAM3DBodyModel.Input("sam3d_body_model", optional=True), + IO.DynamicCombo.Input( + "mesh_style", + options=[ + IO.DynamicCombo.Option("body_mesh", [ + IO.Int.Input( + "bone_smooth_window", + default=0, min=0, max=51, step=2, + tooltip=( + "Gaussian window on per-bone rotation keyframes. 0 = off. " + "7-15 helps cartwheels/spins where upstream Smooth misses spikes." 
+ ), + ), + IO.DynamicCombo.Input( + "bone_vis", + options=[ + IO.DynamicCombo.Option("off", []), + IO.DynamicCombo.Option("octahedrons", [ + IO.Float.Input( + "bone_vis_radius_m", + default=0.02, min=0.005, max=0.5, step=0.005, + tooltip="Radius in m (sphere radius / octahedron half-width).", + ), + IO.Combo.Input( + "bone_vis_color", + options=["white", "rainbow_y"], + default="rainbow_y", + tooltip=( + "Per-bone vertex colors (unlit material). " + "'white' = none, 'rainbow_y' = head→toe jet." + ), + ), + ]), + IO.DynamicCombo.Option("sticks", [ + IO.Combo.Input( + "bone_vis_color", + options=["white", "rainbow_y"], + default="rainbow_y", + tooltip="Per-bone vertex colors (see octahedrons).", + ), + ]), + ], + tooltip=( + "Bone vis shape, rigidly skinned to each joint. " + "'octahedrons' = Blender-style directional bones (joint → " + "primary child); 'sticks' = thin lines." + ), + ), + IO.DynamicCombo.Input( + "shader", + options=[ + IO.DynamicCombo.Option("default", []), + IO.DynamicCombo.Option("rainbow", [ + *rainbow_tilt_inputs(), + IO.Float.Input( + "person_palette_falloff", + default=0.6, min=0.1, max=1.0, step=0.05, + tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.", + ), + ]), + IO.DynamicCombo.Option("rainbow_face_normal", [ + *rainbow_tilt_inputs(), + IO.Float.Input( + "person_palette_falloff", + default=0.6, min=0.1, max=1.0, step=0.05, + tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.", + ), + ]), + IO.DynamicCombo.Option("rainbow_face_semantic", [ + *rainbow_tilt_inputs(), + IO.Float.Input( + "person_palette_falloff", + default=0.6, min=0.1, max=1.0, step=0.05, + tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.", + ), + ]), + ], + tooltip=( + "Bake per-vertex colors matching the Render node's shaders " + "(COLOR_0 + KHR_materials_unlit). 'default' = no colors." 
+ ), + ), + ]), + IO.DynamicCombo.Option("bones_only", [ + IO.Int.Input( + "bone_smooth_window", + default=0, min=0, max=51, step=2, + tooltip=( + "Gaussian window on per-bone rotation keyframes. 0 = off. " + "7-15 helps cartwheels/spins where upstream Smooth misses spikes." + ), + ), + IO.DynamicCombo.Input( + "bone_vis", + options=[ + IO.DynamicCombo.Option("octahedrons", [ + IO.Float.Input( + "bone_vis_radius_m", + default=0.02, min=0.005, max=0.5, step=0.005, + tooltip="Radius in m (sphere radius / octahedron half-width).", + ), + IO.Combo.Input( + "bone_vis_color", + options=["white", "rainbow_y"], + default="rainbow_y", + tooltip=( + "Per-bone vertex colors (unlit material). " + "'white' = none, 'rainbow_y' = head→toe jet." + ), + ), + ]), + IO.DynamicCombo.Option("sticks", [ + IO.Combo.Input( + "bone_vis_color", + options=["white", "rainbow_y"], + default="rainbow_y", + tooltip="Per-bone vertex colors (see octahedrons).", + ), + ]), + ], + tooltip=( + "Bone vis shape, rigidly skinned to each joint. " + "'octahedrons' = Blender-style directional bones (joint → " + "primary child); 'sticks' = thin lines." + ), + ), + ]), + IO.DynamicCombo.Option("openpose", [ + IO.Float.Input( + "marker_radius_m", default=0.010, min=0.005, max=0.1, step=0.001, + tooltip="Sphere radius in m.", + ), + IO.Float.Input( + "stick_radius_m", default=0.008, min=0.002, max=0.05, step=0.001, + tooltip="Limb half-width in m. Auto-clamped to bone_length x 0.1.", + ), + IO.Boolean.Input( + "include_hands", default=False, + tooltip=( + "Append 21+21 OpenPose hands (wrist + 5 fingers x 4 joints, " + "base→tip) sourced from pred_keypoints_3d." 
+ ), + ), + IO.Float.Input( + "hand_marker_radius_m", default=0.005, min=0.001, max=0.1, step=0.001, + tooltip="Hand sphere radius in m.", + ), + IO.Float.Input( + "hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001, + tooltip="Hand limb half-width in m.", + ), + IO.Combo.Input( + "face_source", + options=["off", "rig"], + default="off", + tooltip=( + "'rig' adds ~30 face-contour landmarks sampled from pred_vertices " + "at fixed head-mesh vertex IDs (brow/eyes/nose/mouth/jaw); needs " + "canonical_colors on pose_data." + ), + ), + IO.Float.Input( + "face_marker_radius_m", default=0.0, min=0.0, max=0.05, step=0.0005, + tooltip="Face dot radius. 0 = auto = 0.3 x marker_radius_m.", + ), + ]), + IO.DynamicCombo.Option("scail", [ + IO.Float.Input( + "stick_radius_m", default=0.022, min=0.002, max=0.1, step=0.001, + tooltip=( + "Cylinder radius in m. Bones are open cylinders at constant " + "radius; joint spheres (auto-sized to match) cap the open ends. " + "SCAIL reference = 0.0215 m." + ), + ), + IO.Float.Input( + "marker_radius_m", default=0.0, min=0.0, max=0.1, step=0.001, + tooltip="Joint sphere radius. 0 = auto = stick_radius_m (flush cap).", + ), + IO.Float.Input( + "material_roughness", default=0.3, min=0.0, max=1.0, step=0.05, + tooltip="PBR roughness. SCAIL ref = 0.3. 1 = matte; 0 = chrome.", + ), + IO.Boolean.Input( + "include_hands", default=False, + tooltip="Append 21+21 hand keypoints + capsule sticks per track.", + ), + IO.Float.Input( + "hand_marker_radius_m", default=0.005, min=0.001, max=0.05, step=0.001, + tooltip="Hand sphere radius in m.", + ), + IO.Float.Input( + "hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001, + tooltip="Hand cylinder radius in m.", + ), + ]), + ], + tooltip=( + "'body_mesh' = real Armature (127 bones, skinning, TRS " + "keyframes, 72 face morphs; needs model). " + "'bones_only' = bone-shape primitives at each joint (preview armature). 
" + "'openpose' = OpenPose-18 3D skeleton from keypoints " + "(no model needed). 'scail' = SCAIL 3D capsule rig (open " + "cylinders capped flush by joint spheres)." + ), + ), + IO.Float.Input( + "fps", default=24.0, min=1.0, max=240.0, step=1.0, + tooltip="Animation frame rate.", + ), + IO.Combo.Input( + "camera_translation", + options=["off", "centered", "absolute"], + default="off", + tooltip=( + "Bake pred_cam_t into per-track root translation. " + "'off' = origin; 'centered' = delta from frame 0; " + "'absolute' = raw (Z is camera depth — usually meters away)." + ), + ), + IO.Int.Input( + "track_index", default=-1, min=-1, max=15, + tooltip="-1 = all tracks; ≥0 = single track.", + ), + ], + outputs=[IO.File3DGLB.Output("glb")], + ) + + @classmethod + def execute(cls, pose_data, mesh_style, sam3d_body_model=None, fps=24.0, camera_translation="off", track_index=-1) -> IO.NodeOutput: + mesh_style = mesh_style or {"mesh_style": "body_mesh"} + mode_key = mesh_style["mesh_style"] + # `shader` is nested in body_mesh; absent for bones_only. + shader_dict = mesh_style.get("shader") or {} + shader_key = shader_dict.get("shader", "default") + common = dict( + fps=float(fps), + camera_translation=str(camera_translation), + track_index=int(track_index), + shader=str(shader_key), + rainbow_tilt_x_deg=float(shader_dict.get("rainbow_tilt_x", 0.0)), + rainbow_tilt_z_deg=float(shader_dict.get("rainbow_tilt_z", 0.0)), + person_palette_falloff=float(shader_dict.get("person_palette_falloff", 0.6)), + ) + if mode_key in ("body_mesh", "bones_only"): + # External rigs (e.g. ComfyUI-Kimodo) supply pose_data["_skeleton_override"] + # so the GLB writer reads rig/bind/skin from there instead of MHR. + has_external_rig = isinstance(pose_data, dict) and ("_skeleton_override" in pose_data) + if sam3d_body_model is None and not has_external_rig: + raise ValueError( + f"BuildPoseGLB: '{mode_key}' mode needs the `sam3d_body_model` input OR a " + "`_skeleton_override` dict in pose_data. 
Connect the SAM3DBody model " + "or feed pose_data from a node that supplies the override (e.g. KimodoSample)." + ) + default_shape = "off" if mode_key == "body_mesh" else "octahedrons" + bone_vis_dict = mesh_style.get("bone_vis", {"bone_vis": default_shape}) + bone_vis = str(bone_vis_dict.get("bone_vis", default_shape)) + bone_vis_radius_m = float(bone_vis_dict.get("bone_vis_radius_m", 0.04)) + bone_vis_color = str(bone_vis_dict.get("bone_vis_color", "white")) + glb_bytes = build_glb_skeletal( + pose_data, sam3d_body_model, + bone_smooth_window=int(mesh_style.get("bone_smooth_window", 0)), + bone_vis=bone_vis, + bone_vis_radius_m=bone_vis_radius_m, + bone_vis_color=bone_vis_color, + include_body_mesh=(mode_key == "body_mesh"), + **common, + ) + elif mode_key == "openpose": + # Rig-independent: sourced from pred_keypoints_3d. face_source='rig' + # additionally reads canonical_colors for head-mesh vertex IDs. + glb_bytes = build_glb_openpose( + pose_data, + fps=float(fps), + camera_translation=str(camera_translation), + track_index=int(track_index), + marker_radius_m=float(mesh_style.get("marker_radius_m", 0.025)), + stick_radius_m=float(mesh_style.get("stick_radius_m", 0.008)), + include_hands=bool(mesh_style.get("include_hands", False)), + hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)), + hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)), + face_source=str(mesh_style.get("face_source", "off")), + face_marker_radius_m=float(mesh_style.get("face_marker_radius_m", 0.0)), + palette="openpose", + shape="ellipsoid", + ) + elif mode_key == "scail": + # SCAIL rig: open cylinders capped flush by joint spheres (sphere + # radius defaults to cylinder radius for a seamless silhouette). 
+ cap_stick_radius = float(mesh_style.get("stick_radius_m", 0.022)) + cap_marker_radius = float(mesh_style.get("marker_radius_m", 0.0)) + if cap_marker_radius <= 0.0: + cap_marker_radius = cap_stick_radius + glb_bytes = build_glb_openpose( + pose_data, + fps=float(fps), + camera_translation=str(camera_translation), + track_index=int(track_index), + marker_radius_m=cap_marker_radius, + stick_radius_m=cap_stick_radius, + include_hands=bool(mesh_style.get("include_hands", False)), + hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)), + hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)), + face_source="off", + palette="scail", + shape="capsule", + smooth_shade=True, + # SCAIL material: slightly glossy (0.3) + double-sided so the + # inside of the open cylinders shades sensibly at grazing angles. + material_roughness=float(mesh_style.get("material_roughness", 0.3)), + material_double_sided=True, + ) + else: + raise ValueError(f"BuildPoseGLB: unknown mesh_style {mode_key!r}") + + return IO.NodeOutput(Types.File3D(BytesIO(glb_bytes), file_format="glb")) + + +class SavePoseBVH(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SavePoseBVH", + description="Save pose data as BVH mocap file", + display_name="Save Pose BVH", + category="3d", + is_output_node=True, + inputs=[ + IO.MultiType.Input( + "pose_data", types=[MHRPoseData, KimodoPoseData], + tooltip=( + "MHR pose data from SAM3DBody_Predict, or external-rig " + "pose data from Kimodo." + ), + ), + SAM3DBodyModel.Input("sam3d_body_model"), + IO.String.Input("filename_prefix", default="3d/ComfyUI"), + IO.Float.Input( + "fps", default=24.0, min=1.0, max=240.0, step=1.0, + tooltip="Animation frame rate (BVH `Frame Time`).", + ), + IO.Combo.Input( + "camera_translation", + options=["off", "centered", "absolute"], + default="off", + tooltip=( + "Bake pred_cam_t into the root's position channels. 
" + "'off' = bind position; 'centered' = delta from frame 0; " + "'absolute' = raw (Z is camera depth — usually meters away)." + ), + ), + IO.Combo.Input( + "units", + options=["cm", "m"], + default="cm", + tooltip="BVH OFFSET/position units. 'cm' is the mocap standard.", + ), + IO.Int.Input( + "track_index", default=0, min=0, max=15, + tooltip="Track to export. BVH carries one skeleton; export multi-person clips one at a time.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + outputs=[], + ) + + @classmethod + def execute(cls, pose_data, sam3d_body_model, filename_prefix="3d/ComfyUI", + fps=24.0, camera_translation="off", units="cm", + track_index=0) -> IO.NodeOutput: + bvh_bytes = build_bvh( + pose_data, sam3d_body_model, + fps=float(fps), + camera_translation=str(camera_translation), + track_index=int(track_index), + units=str(units), + ) + + full_output_folder, filename, counter, subfolder, _ = \ + folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory(), + ) + f = f"{filename}_{counter:05}_.bvh" + out_path = os.path.join(full_output_folder, f) + with open(out_path, "wb") as fh: + fh.write(bvh_bytes) + + return IO.NodeOutput(ui={"3d": [{ + "filename": f, + "subfolder": subfolder, + "type": "output", + }]}) + + class Save3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [SaveGLB] + return [SaveGLB, BuildPoseGLB, SavePoseBVH] async def comfy_entrypoint() -> Save3DExtension: diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py index 20d459b00..ebac4e829 100644 --- a/comfy_extras/nodes_sdpose.py +++ b/comfy_extras/nodes_sdpose.py @@ -2,11 +2,10 @@ import torch import comfy.utils import comfy.model_management import numpy as np -import math -import colorsys from tqdm import tqdm from typing_extensions import override from comfy_api.latest import ComfyExtension, io +from comfy_extras.pose.keypoint_draw import KeypointDraw from 
comfy_extras.nodes_lotus import LotusConditioning @@ -73,281 +72,6 @@ def _to_openpose_frames(all_keypoints, all_scores, height, width): return frames -class KeypointDraw: - """ - Pose keypoint drawing class that supports both numpy and cv2 backends. - """ - def __init__(self): - try: - import cv2 - self.draw = cv2 - except ImportError: - self.draw = self - - # Hand connections (same for both hands) - self.hand_edges = [ - [0, 1], [1, 2], [2, 3], [3, 4], # thumb - [0, 5], [5, 6], [6, 7], [7, 8], # index - [0, 9], [9, 10], [10, 11], [11, 12], # middle - [0, 13], [13, 14], [14, 15], [15, 16], # ring - [0, 17], [17, 18], [18, 19], [19, 20], # pinky - ] - - # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) - self.body_limbSeq = [ - [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], - [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], - [1, 16], [16, 18] - ] - - # Colors matching DWPose - self.colors = [ - [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], - [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], - [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], - [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] - ] - - @staticmethod - def circle(canvas_np, center, radius, color, **kwargs): - """Draw a filled circle using NumPy vectorized operations.""" - cx, cy = center - h, w = canvas_np.shape[:2] - - radius_int = int(np.ceil(radius)) - - y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) - x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) - - if y_max <= y_min or x_max <= x_min: - return - - y, x = np.ogrid[y_min:y_max, x_min:x_max] - mask = (x - cx)**2 + (y - cy)**2 <= radius**2 - canvas_np[y_min:y_max, x_min:x_max][mask] = color - - @staticmethod - def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): - """Draw line using Bresenham's algorithm with NumPy operations.""" - x0, y0, x1, y1 = *pt1, *pt2 - h, w = 
canvas_np.shape[:2] - dx, dy = abs(x1 - x0), abs(y1 - y0) - sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) - err, x, y, line_points = dx - dy, x0, y0, [] - - while True: - line_points.append((x, y)) - if x == x1 and y == y1: - break - e2 = 2 * err - if e2 > -dy: - err, x = err - dy, x + sx - if e2 < dx: - err, y = err + dx, y + sy - - if thickness > 1: - radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) - for px, py in line_points: - y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) - if y_max > y_min and x_max > x_min: - yy, xx = np.ogrid[y_min:y_max, x_min:x_max] - canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color - else: - line_points = np.array(line_points) - valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) - if (valid_points := line_points[valid]).size: - canvas_np[valid_points[:, 1], valid_points[:, 0]] = color - - @staticmethod - def fillConvexPoly(canvas_np, pts, color, **kwargs): - """Fill polygon using vectorized scanline algorithm.""" - if len(pts) < 3: - return - pts = np.array(pts, dtype=np.int32) - h, w = canvas_np.shape[:2] - y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) - if y_max <= y_min or x_max <= x_min: - return - yy, xx = np.mgrid[y_min:y_max, x_min:x_max] - mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) - - for i in range(len(pts)): - p1, p2 = pts[i], pts[(i + 1) % len(pts)] - y1, y2 = p1[1], p2[1] - if y1 == y2: - continue - if y1 > y2: - p1, p2, y1, y2 = p2, p1, p2[1], p1[1] - if not (edge_mask := (yy >= y1) & (yy < y2)).any(): - continue - mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) - - canvas_np[y_min:y_max, x_min:x_max][mask] = color - - @staticmethod - def ellipse2Poly(center, 
axes, angle, arc_start, arc_end, delta=1, **kwargs): - """Python implementation of cv2.ellipse2Poly.""" - axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output - angle = angle % 360 - if arc_start > arc_end: - arc_start, arc_end = arc_end, arc_start - while arc_start < 0: - arc_start, arc_end = arc_start + 360, arc_end + 360 - while arc_end > 360: - arc_end, arc_start = arc_end - 360, arc_start - 360 - if arc_end - arc_start > 360: - arc_start, arc_end = 0, 360 - - angle_rad = math.radians(angle) - alpha, beta = math.cos(angle_rad), math.sin(angle_rad) - pts = [] - for i in range(arc_start, arc_end + delta, delta): - theta_rad = math.radians(min(i, arc_end)) - x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) - pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) - - unique_pts, prev_pt = [], (float('inf'), float('inf')) - for pt in pts: - if (pt_tuple := tuple(pt)) != prev_pt: - unique_pts.append(pt) - prev_pt = pt_tuple - - return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] - - def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, - draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3): - """ - Draw wholebody keypoints (134 keypoints after processing) in DWPose style. 
- - Expected keypoint format (after neck insertion and remapping): - - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) - - Foot: 18-23 (6 keypoints) - - Face: 24-91 (68 landmarks) - - Right hand: 92-112 (21 keypoints) - - Left hand: 113-133 (21 keypoints) - - Args: - canvas: The canvas to draw on (numpy array) - keypoints: Array of keypoint coordinates - scores: Optional confidence scores for each keypoint - threshold: Minimum confidence threshold for drawing keypoints - - Returns: - canvas: The canvas with keypoints drawn - """ - H, W, C = canvas.shape - - # Draw body limbs - if draw_body and len(keypoints) >= 18: - for i, limb in enumerate(self.body_limbSeq): - # Convert from 1-indexed to 0-indexed - idx1, idx2 = limb[0] - 1, limb[1] - 1 - - if idx1 >= 18 or idx2 >= 18: - continue - - if scores is not None: - if scores[idx1] < threshold or scores[idx2] < threshold: - continue - - Y = [keypoints[idx1][0], keypoints[idx2][0]] - X = [keypoints[idx1][1], keypoints[idx2][1]] - mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 - length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) - - if length < 1: - continue - - angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) - - polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) - - self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)]) - - # Draw body keypoints - if draw_body and len(keypoints) >= 18: - for i in range(18): - if scores is not None and scores[i] < threshold: - continue - x, y = int(keypoints[i][0]), int(keypoints[i][1]) - if 0 <= x < W and 0 <= y < H: - self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) - - # Draw foot keypoints (18-23, 6 keypoints) - if draw_feet and len(keypoints) >= 24: - for i in range(18, 24): - if scores is not None and scores[i] < threshold: - continue - x, y = int(keypoints[i][0]), int(keypoints[i][1]) - if 0 <= x < W and 0 <= y < H: - 
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) - - # Draw right hand (92-112) - if draw_hands and len(keypoints) >= 113: - eps = 0.01 - for ie, edge in enumerate(self.hand_edges): - idx1, idx2 = 92 + edge[0], 92 + edge[1] - if scores is not None: - if scores[idx1] < threshold or scores[idx2] < threshold: - continue - - x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) - x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) - - if x1 > eps and y1 > eps and x2 > eps and y2 > eps: - if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: - # HSV to RGB conversion for rainbow colors - r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) - color = (int(r * 255), int(g * 255), int(b * 255)) - self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) - - # Draw right hand keypoints - for i in range(92, 113): - if scores is not None and scores[i] < threshold: - continue - x, y = int(keypoints[i][0]), int(keypoints[i][1]) - if x > eps and y > eps and 0 <= x < W and 0 <= y < H: - self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) - - # Draw left hand (113-133) - if draw_hands and len(keypoints) >= 134: - eps = 0.01 - for ie, edge in enumerate(self.hand_edges): - idx1, idx2 = 113 + edge[0], 113 + edge[1] - if scores is not None: - if scores[idx1] < threshold or scores[idx2] < threshold: - continue - - x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) - x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) - - if x1 > eps and y1 > eps and x2 > eps and y2 > eps: - if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: - # HSV to RGB conversion for rainbow colors - r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) - color = (int(r * 255), int(g * 255), int(b * 255)) - self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) - - # Draw left hand keypoints - for i in range(113, 134): - if scores is not None and i < len(scores) and 
scores[i] < threshold: - continue - x, y = int(keypoints[i][0]), int(keypoints[i][1]) - if x > eps and y > eps and 0 <= x < W and 0 <= y < H: - self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) - - # Draw face keypoints (24-91) - white dots only, no lines - if draw_face and len(keypoints) >= 92: - eps = 0.01 - for i in range(24, 92): - if scores is not None and scores[i] < threshold: - continue - x, y = int(keypoints[i][0]), int(keypoints[i][1]) - if x > eps and y > eps and 0 <= x < W and 0 <= y < H: - self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) - - return canvas - class SDPoseDrawKeypoints(io.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/pose/keypoint_draw.py b/comfy_extras/pose/keypoint_draw.py new file mode 100644 index 000000000..505205f4e --- /dev/null +++ b/comfy_extras/pose/keypoint_draw.py @@ -0,0 +1,348 @@ +"""Pose keypoint drawing primitives shared across pose nodes. + +`KeypointDraw` exposes a cv2-or-numpy backend so callers can use one drawing +API regardless of whether OpenCV is installed: + + kd = KeypointDraw() + kd.draw.circle(canvas, (x, y), radius, color, thickness=-1) + kd.draw.line(canvas, p1, p2, color, thickness=4) + kd.draw.fillConvexPoly(canvas, polygon, color) + kd.draw.ellipse2Poly(center, axes, angle, 0, 360, 1) + +It also carries DWPose's body/hand topology + color tables, used by: + - comfy_extras.nodes_sdpose (SDPose pose drawing) + - comfy_extras.sam3d_body.export.openpose_2d (SAM 3D Body 2D pose viz) + - comfy_extras.sam3d_body.export.glb_shared (SAM 3D Body GLB tables) +""" + +import colorsys +import math + +import numpy as np + + +class KeypointDraw: + """ + Pose keypoint drawing class that supports both numpy and cv2 backends. 
+ """ + def __init__(self): + try: + import cv2 + self.draw = cv2 + except ImportError: + self.draw = self + + # Hand connections (same for both hands) + self.hand_edges = [ + [0, 1], [1, 2], [2, 3], [3, 4], # thumb + [0, 5], [5, 6], [6, 7], [7, 8], # index + [0, 9], [9, 10], [10, 11], [11, 12], # middle + [0, 13], [13, 14], [14, 15], [15, 16], # ring + [0, 17], [17, 18], [18, 19], [19, 20], # pinky + ] + + # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) + self.body_limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], + [1, 16], [16, 18] + ] + + # Colors matching DWPose + self.colors = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] + ] + + @staticmethod + def circle(canvas_np, center, radius, color, **kwargs): + """Draw a filled circle using NumPy vectorized operations.""" + cx, cy = center + h, w = canvas_np.shape[:2] + + radius_int = int(np.ceil(radius)) + + y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) + x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) + + if y_max <= y_min or x_max <= x_min: + return + + y, x = np.ogrid[y_min:y_max, x_min:x_max] + mask = (x - cx)**2 + (y - cy)**2 <= radius**2 + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): + """Draw line using Bresenham's algorithm with NumPy operations.""" + x0, y0, x1, y1 = *pt1, *pt2 + h, w = canvas_np.shape[:2] + dx, dy = abs(x1 - x0), abs(y1 - y0) + sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) + err, x, y, line_points = dx - dy, x0, y0, [] + + while True: + line_points.append((x, y)) + if x == x1 and y == y1: + break + e2 = 2 * err 
+ if e2 > -dy: + err, x = err - dy, x + sx + if e2 < dx: + err, y = err + dx, y + sy + + if thickness > 1: + radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) + for px, py in line_points: + y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) + if y_max > y_min and x_max > x_min: + yy, xx = np.ogrid[y_min:y_max, x_min:x_max] + canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color + else: + line_points = np.array(line_points) + valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) + if (valid_points := line_points[valid]).size: + canvas_np[valid_points[:, 1], valid_points[:, 0]] = color + + @staticmethod + def fillConvexPoly(canvas_np, pts, color, **kwargs): + """Fill polygon using vectorized scanline algorithm.""" + if len(pts) < 3: + return + pts = np.array(pts, dtype=np.int32) + h, w = canvas_np.shape[:2] + y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) + if y_max <= y_min or x_max <= x_min: + return + yy, xx = np.mgrid[y_min:y_max, x_min:x_max] + mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) + + for i in range(len(pts)): + p1, p2 = pts[i], pts[(i + 1) % len(pts)] + y1, y2 = p1[1], p2[1] + if y1 == y2: + continue + if y1 > y2: + p1, p2, y1, y2 = p2, p1, p2[1], p1[1] + if not (edge_mask := (yy >= y1) & (yy < y2)).any(): + continue + mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) + + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs): + """Python implementation of cv2.ellipse2Poly.""" + axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output + angle = angle % 360 + if arc_start > arc_end: + arc_start, arc_end = arc_end, 
arc_start + while arc_start < 0: + arc_start, arc_end = arc_start + 360, arc_end + 360 + while arc_end > 360: + arc_end, arc_start = arc_end - 360, arc_start - 360 + if arc_end - arc_start > 360: + arc_start, arc_end = 0, 360 + + angle_rad = math.radians(angle) + alpha, beta = math.cos(angle_rad), math.sin(angle_rad) + pts = [] + for i in range(arc_start, arc_end + delta, delta): + theta_rad = math.radians(min(i, arc_end)) + x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) + pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) + + unique_pts, prev_pt = [], (float('inf'), float('inf')) + for pt in pts: + if (pt_tuple := tuple(pt)) != prev_pt: + unique_pts.append(pt) + prev_pt = pt_tuple + + return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] + + def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, + draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, + stick_width=4, face_point_size=3, + marker_radius=4, hand_stick_width=2, hand_marker_radius=4, + limb_alpha=1.0, hand_dot_color=(0, 0, 255)): + """ + Draw wholebody keypoints (134 keypoints after processing) in DWPose style. + + Expected keypoint format (after neck insertion and remapping): + - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) + - Foot: 18-23 (6 keypoints) + - Face: 24-91 (68 landmarks) + - Right hand: 92-112 (21 keypoints) + - Left hand: 113-133 (21 keypoints) + + Args: + canvas: The canvas to draw on (numpy array) + keypoints: Array of keypoint coordinates + scores: Optional confidence scores for each keypoint + threshold: Minimum confidence threshold for drawing keypoints + stick_width: Body limb half-width (passed to ellipse2Poly). + face_point_size: Radius of the white face dots. + marker_radius: Radius of body/foot dots. Defaults to 4 (DWPose). + hand_stick_width: Thickness of hand limb lines. Defaults to 2. 
+ hand_marker_radius: Radius of hand dots. Defaults to 4. + limb_alpha: Body-limb alpha blend (0..1). 1.0 = opaque fill (default), + <1.0 enables per-limb bbox-clipped alpha overlay (DWPose semantics + where overlapping limbs darken). + hand_dot_color: Either an (R, G, B) tuple/list of ints for solid-color + hand dots (default (0, 0, 255), DWPose blue), or a (21, 3) array + for per-keypoint hand-dot colors (OpenPose-style rainbow palette). + + Returns: + canvas: The canvas with keypoints drawn + """ + H, W, C = canvas.shape + + # Normalize hand_dot_color to a (21, 3) int array. + hdc_arr = np.asarray(hand_dot_color, dtype=int) + if hdc_arr.ndim == 1: + hdc_arr = np.tile(hdc_arr.reshape(1, 3), (21, 1)) + hand_dot_tuples = [tuple(int(c) for c in hdc_arr[i]) for i in range(21)] + + do_alpha = float(limb_alpha) < 1.0 + + # Draw body limbs + if draw_body and len(keypoints) >= 18: + for i, limb in enumerate(self.body_limbSeq): + # Convert from 1-indexed to 0-indexed + idx1, idx2 = limb[0] - 1, limb[1] - 1 + + if idx1 >= 18 or idx2 >= 18: + continue + + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + Y = [keypoints[idx1][0], keypoints[idx2][0]] + X = [keypoints[idx1][1], keypoints[idx2][1]] + mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 + length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) + + if length < 1: + continue + + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) + + color = self.colors[i % len(self.colors)] + if do_alpha: + _fill_poly_alpha(canvas, polygon, color, limb_alpha, self.draw) + else: + self.draw.fillConvexPoly(canvas, polygon, color) + + # Draw body keypoints + if draw_body and len(keypoints) >= 18: + for i in range(18): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + 
self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1) + + # Draw foot keypoints (18-23, 6 keypoints) + if draw_feet and len(keypoints) >= 24: + for i in range(18, 24): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1) + + # Draw right hand (92-112) + if draw_hands and len(keypoints) >= 113: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 92 + edge[0], 92 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width) + + # Draw right hand keypoints + for i in range(92, 113): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 92], thickness=-1) + + # Draw left hand (113-133) + if draw_hands and len(keypoints) >= 134: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 113 + edge[0], 113 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + 
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width) + + # Draw left hand keypoints + for i in range(113, 134): + if scores is not None and i < len(scores) and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 113], thickness=-1) + + # Draw face keypoints (24-91) - white dots only, no lines + if draw_face and len(keypoints) >= 92: + eps = 0.01 + for i in range(24, 92): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) + + return canvas + + +def _fill_poly_alpha(canvas, polygon, color, alpha, draw_backend): + """Bbox-clipped alpha-blended fillConvexPoly. `canvas` is mutated in-place. + + DWPose semantics: each limb blends with `alpha` independently so overlapping + limbs darken further. Operates on the polygon's bbox to avoid copying the + whole canvas per limb. 
+ """ + H, W = canvas.shape[:2] + poly_arr = np.asarray(polygon, dtype=np.int32) + x0 = max(0, int(poly_arr[:, 0].min())) + xN = min(W, int(poly_arr[:, 0].max()) + 1) + y0 = max(0, int(poly_arr[:, 1].min())) + yN = min(H, int(poly_arr[:, 1].max()) + 1) + if xN <= x0 or yN <= y0: + return + local_poly = poly_arr - np.array([x0, y0], dtype=poly_arr.dtype) + roi = canvas[y0:yN, x0:xN].copy() + draw_backend.fillConvexPoly(roi, local_poly, color) + a = float(alpha) + canvas[y0:yN, x0:xN] = np.clip( + roi.astype(np.float32) * a + canvas[y0:yN, x0:xN].astype(np.float32) * (1.0 - a), + 0, 255, + ).astype(np.uint8) diff --git a/comfy_extras/sam3d_body/export/bvh.py b/comfy_extras/sam3d_body/export/bvh.py new file mode 100644 index 000000000..906631977 --- /dev/null +++ b/comfy_extras/sam3d_body/export/bvh.py @@ -0,0 +1,207 @@ +"""BVH export for SAM 3D Body pose_data. + +BVH stores explicit bone OFFSETs per joint, so any standard importer +(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations +directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0 +(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos + +ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are +intrinsic Z-X-Y Euler degrees. 
+""" + +from __future__ import annotations + +import io +from typing import Any, Dict, List + +import numpy as np + +from .glb_shared import ( + bind_skel_state, + bone_locals_from_globals, + collect_tracks, + extract_rig_static, + global_skel_state_from_pose_data, + quat_sign_fix_per_joint, + unflip, +) + + +def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray: + """xyzw quat → intrinsic Z-X-Y Euler degrees, returned as (..., 3) in + (z, x, y) order to match BVH's `CHANNELS Zrotation Xrotation Yrotation`.""" + q = np.asarray(quat, dtype=np.float64) + x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] + # ZXY decomposition R = Rz(c)·Rx(a)·Ry(b): + # M[2][1] = 2(yz + xw) → sin(a) + # M[0][1] = 2(xy - zw) → -cos(a) sin(c) + # M[1][1] = 1 - 2(x² + z²) → cos(a) cos(c) + # M[2][0] = 2(xz - yw) → -cos(a) sin(b) + # M[2][2] = 1 - 2(x² + y²) → cos(a) cos(b) + M21 = np.clip(2.0 * (y * z + x * w), -1.0, 1.0) + M01 = 2.0 * (x * y - z * w) + M11 = 1.0 - 2.0 * (x * x + z * z) + M20 = 2.0 * (x * z - y * w) + M22 = 1.0 - 2.0 * (x * x + y * y) + a = np.arcsin(M21) + c = np.arctan2(-M01, M11) + b = np.arctan2(-M20, M22) + out = np.stack([np.rad2deg(c), np.rad2deg(a), np.rad2deg(b)], axis=-1) + return out.astype(np.float32) + + +def _find_bvh_root(parents: np.ndarray) -> int: + """First child of the rig's world anchor so the static origin→body stick + bone gets left out. 
Falls back to the first root joint.""" + NJ = parents.shape[0] + world_anchors = [j for j in range(NJ) + if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)] + if not world_anchors: + return 0 + children: List[List[int]] = [[] for _ in range(NJ)] + for j in range(NJ): + p = int(parents[j]) + if 0 <= p < NJ and p != j: + children[p].append(j) + wa = world_anchors[0] + if children[wa]: + return children[wa][0] + return wa + + +def _build_children_map(parents: np.ndarray) -> List[List[int]]: + NJ = parents.shape[0] + out: List[List[int]] = [[] for _ in range(NJ)] + for j in range(NJ): + p = int(parents[j]) + if 0 <= p < NJ and p != j: + out[p].append(j) + return out + + +def build_bvh( + pose_data: Dict[str, Any], + model: Any, + *, + fps: float = 24.0, + camera_translation: str = "off", + track_index: int = -1, + units: str = "cm", +) -> bytes: + """Build a BVH file from pose_data. Returns UTF-8 encoded text bytes. + + `units` is "cm" (default, standard mocap convention) or "m". Affects the + OFFSET and root-position values; rotations are independent of units. + """ + if units not in ("cm", "m"): + raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}") + unit_scale = 100.0 if units == "cm" else 1.0 + + rig_static = extract_rig_static(model) + NJ = int(rig_static["num_joints"]) + parents = rig_static["parents"] + frames = pose_data["frames"] + + tracks = collect_tracks(pose_data, track_index) + if not tracks: + raise ValueError("build_bvh: no valid tracks in pose_data") + person_k, frame_indices = tracks[0] + n_frames = len(frame_indices) + if n_frames == 0: + raise ValueError("build_bvh: track has zero frames") + + body_root = _find_bvh_root(parents) + children_map = _build_children_map(parents) + + # Bone OFFSETs come from MHR's translation_offsets (joint position + # relative to parent in parent's local-bind frame). For the BVH root, + # we use its bind world position so the skeleton sits at the right + # spot when imported. 
+ bind_global = bind_skel_state(model) # (NJ, 8) cm + bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m + offset_m = rig_static["joint_translation_offsets"].astype(np.float64) * 0.01 + + # DFS order rooted at body_root — matches per-frame channel order. + bvh_order: List[int] = [] + def _visit(j: int) -> None: + bvh_order.append(j) + for c in children_map[j]: + _visit(c) + _visit(body_root) + + # Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative) + # rather than re-running rig.forward, then derive locals with body_root + # treated as the hierarchy root in BVH-space. + rig_global_m = global_skel_state_from_pose_data( + pose_data, frame_indices, person_k, NJ, + ) + rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7]) + bvh_parents = parents.copy() + bvh_parents[body_root] = -1 + bone_local = bone_locals_from_globals(rig_global_m, bvh_parents) + # Second pass catches sign discontinuities from the parent-inverse composition. + bone_local[..., 3:7] = quat_sign_fix_per_joint(bone_local[..., 3:7]) + + eulers_deg = _quat_to_zxy_euler_deg(bone_local[..., 3:7]) + + if camera_translation in ("absolute", "centered"): + cam_t = np.stack([ + unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32)) + for t in frame_indices + ], axis=0).astype(np.float64) + if camera_translation == "centered": + cam_t = cam_t - cam_t[0:1] + root_pos_m = bind_pos_m[body_root][None, :] + cam_t + else: + root_pos_m = np.tile(bind_pos_m[body_root], (n_frames, 1)) + + lines: List[str] = ["HIERARCHY"] + + def _emit_joint(j: int, depth: int, is_root: bool) -> None: + ind = " " * depth + keyword = "ROOT" if is_root else "JOINT" + name = "Hips" if is_root else f"joint_{j:03d}" + lines.append(f"{ind}{keyword} {name}") + lines.append(ind + "{") + o = (bind_pos_m[j] if is_root else offset_m[j]) * unit_scale + lines.append(f"{ind} OFFSET {o[0]:.6f} {o[1]:.6f} {o[2]:.6f}") + if is_root: + lines.append(ind + " CHANNELS 6 Xposition 
Yposition Zposition " + "Zrotation Xrotation Yrotation") + else: + lines.append(ind + " CHANNELS 3 Zrotation Xrotation Yrotation") + kids = children_map[j] + if kids: + for c in kids: + _emit_joint(c, depth + 1, is_root=False) + else: + # End Site (standard BVH spec) gives leaf bones a drawable length. + lines.append(ind + " End Site") + lines.append(ind + " {") + tip = (offset_m[j] * unit_scale) * 0.3 + tip_norm = float(np.linalg.norm(tip)) + if tip_norm < 0.5 * unit_scale * 0.01: # < 5 mm → fall back + tip = np.array([0.0, 0.05 * unit_scale, 0.0], dtype=np.float64) + lines.append(f"{ind} OFFSET {tip[0]:.6f} {tip[1]:.6f} {tip[2]:.6f}") + lines.append(ind + " }") + lines.append(ind + "}") + + _emit_joint(body_root, 0, is_root=True) + + lines.append("MOTION") + lines.append(f"Frames: {n_frames}") + lines.append(f"Frame Time: {1.0 / float(fps):.6f}") + + # Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per + # frame, columns in `bvh_order` order. Vectorized — savetxt's C-side + # formatting beats Python f-strings by ~10× on long clips. + non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64) + motion = np.concatenate([ + root_pos_m * unit_scale, # (N, 3) + eulers_deg[:, body_root].astype(np.float64), # (N, 3) + eulers_deg[:, non_root_idx, :].reshape(n_frames, -1), # (N, 3*(len(bvh_order)-1)) + ], axis=1) + buf = io.StringIO() + np.savetxt(buf, motion, fmt="%.6f") + lines.append(buf.getvalue().rstrip("\n")) + + return ("\n".join(lines) + "\n").encode("utf-8") diff --git a/comfy_extras/sam3d_body/export/capsules.py b/comfy_extras/sam3d_body/export/capsules.py new file mode 100644 index 000000000..c9278f4fb --- /dev/null +++ b/comfy_extras/sam3d_body/export/capsules.py @@ -0,0 +1,403 @@ +"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent +torch analytic ray-capsule renderer adapted to SAM3DBody pose_data. 
+ +Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps), +projected through the per-person camera (`pred_cam_t` + `focal_length` + +image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose +visual style. Self-contained: no dependency on the SCAIL-Pose package. + +Output: (H, W, 3) fp32 torch.Tensor in [0, 1]. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +import comfy.model_management + +from .glb_shared import ( + OPENPOSE_18_PAIRS, + OPENPOSE18_TO_MHR70, + OPENPOSE_RAINBOW_18, + SCAIL_LIMB_COLORS_17, + OPENPOSE_HAND_PAIRS, + OPENPOSE_HAND21_TO_MHR70_R, + OPENPOSE_HAND21_TO_MHR70_L, + OPENPOSE_HAND_COLORS_21, +) + + +def _limb_palette_rgb01(palette: str) -> np.ndarray: + """17 per-limb RGB colors in [0,1] for the OpenPose-18 body limbs.""" + if palette == "scail": + return SCAIL_LIMB_COLORS_17.astype(np.float32) + return OPENPOSE_RAINBOW_18[: len(OPENPOSE_18_PAIRS)].astype(np.float32) + + +def _build_specs_from_pose( + persons: List[Dict[str, Any]], + *, + include_hands: bool, + palette: str, + person_brightness_falloff: float = 0.0, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Flatten body + optional hand limbs for one frame into + (starts, ends, colors_rgba) in camera coords (Y-down, +Z forward). + Drops endpoints that are non-finite or behind the camera. + + `person_brightness_falloff` mixes each per-person limb color toward white + by `1 - falloff^k` for track index `k` (track 0 stays vivid). 
Matches the + mesh rasterizer and GLB exporters.""" + starts: List[np.ndarray] = [] + ends: List[np.ndarray] = [] + colors: List[np.ndarray] = [] + + body_limb_colors = _limb_palette_rgb01(palette) + hand_limb_colors = OPENPOSE_HAND_COLORS_21.astype(np.float32) + + falloff = max(0.0, min(1.0, float(person_brightness_falloff))) + + for k, person in enumerate(persons): + kp2d_full = person.get("pred_keypoints_3d") + cam_t = person.get("pred_cam_t") + if kp2d_full is None or cam_t is None: + continue + kp = np.asarray(kp2d_full, dtype=np.float32) + if kp.ndim != 2 or kp.shape[1] != 3 or kp.shape[0] < 70: + continue + cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3) + # pred_keypoints_3d is camera frame (Y-down post-flip); add cam_t to + # place the subject in front of the camera. + kp_cam = kp + cam_t_np[None, :] + + pastel = 0.0 if k == 0 else (1.0 - falloff ** k) + + def _tint(rgb: np.ndarray) -> np.ndarray: + if pastel <= 0: + return rgb + return rgb * (1.0 - pastel) + pastel + + # SCAIL drops the 4 face bones (13..16: nose↔eyes, eyes→ears) — in the + # reference, NLF leaves those COCO slots at zero so its `sum==0` skip + # silently culls them. The grey neck limb (12) blends spine direction + # (mid-hip → neck, stable) with the neck→nose direction at 60/40 so + # the stub tracks head pose lightly without flapping around like full + # nose direction does. 
+ body_limb_count = 13 if palette == "scail" else len(OPENPOSE_18_PAIRS) + body_kp = kp_cam[OPENPOSE18_TO_MHR70] + spine_dir = None + if palette == "scail": + mid_hip = 0.5 * (body_kp[8] + body_kp[11]) # 8=RHip, 11=LHip + sd = body_kp[1] - mid_hip # 1=Neck + sd_len = float(np.linalg.norm(sd)) + if np.all(np.isfinite(sd)) and sd_len > 1e-6: + spine_dir = sd / sd_len + for limb_i, (a, b) in enumerate(OPENPOSE_18_PAIRS[:body_limb_count]): + sa, sb = body_kp[a], body_kp[b] + if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))): + continue + if sa[2] <= 0 or sb[2] <= 0: + continue + if palette == "scail" and limb_i == 12: + nose_vec = sb - sa + nose_len = float(np.linalg.norm(nose_vec)) + if nose_len > 1e-6 and spine_dir is not None: + nose_dir = nose_vec / nose_len + mixed = 0.6 * spine_dir + 0.4 * nose_dir + mixed = mixed / max(float(np.linalg.norm(mixed)), 1e-6) + sb = sa + mixed * (nose_len * 0.5) + elif nose_len > 1e-6: + sb = sa + nose_vec * 0.5 + elif spine_dir is not None: + sb = sa + spine_dir * (sd_len * 0.3) + starts.append(sa) + ends.append(sb) + color_rgb = _tint(body_limb_colors[limb_i]) + colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0], + dtype=np.float32)) + + if include_hands: + r_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_R] + l_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_L] + for limb_i, (a, b) in enumerate(OPENPOSE_HAND_PAIRS): + for hand_kp in (r_kp, l_kp): + sa, sb = hand_kp[a], hand_kp[b] + if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))): + continue + if sa[2] <= 0 or sb[2] <= 0: + continue + starts.append(sa) + ends.append(sb) + color_rgb = _tint(hand_limb_colors[(a + b) % len(hand_limb_colors)]) + colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0], + dtype=np.float32)) + + if not starts: + return (np.zeros((0, 3), dtype=np.float32), + np.zeros((0, 3), dtype=np.float32), + np.zeros((0, 4), dtype=np.float32)) + return (np.stack(starts).astype(np.float32), + np.stack(ends).astype(np.float32), + 
np.stack(colors).astype(np.float32)) + + +def _ray_capsule_t( + ray_dirs: torch.Tensor, # (K, 3) unit rays from camera origin + starts: torch.Tensor, # (M, 3) + ends: torch.Tensor, # (M, 3) + ba_norm: torch.Tensor, # (M, 3) unit axis (A → B) + ba_len: torch.Tensor, # (M,) segment length + radius: float, +) -> torch.Tensor: + """Closed-form ray-capsule intersection. Returns (K, M) tensor of ray + parameters t to the nearest valid hit per capsule, +inf where the ray + misses. A capsule is the union of (cylinder body, hemisphere at A, + hemisphere at B); each component is a quadratic root-find.""" + INF = float("inf") + r_sq = float(radius) * float(radius) + + # Cached dot products. + dn = ray_dirs @ ba_norm.transpose(0, 1) # (K, M) — d·n + dA = ray_dirs @ starts.transpose(0, 1) # (K, M) — d·A + dB = ray_dirs @ ends.transpose(0, 1) # (K, M) — d·B + An = (starts * ba_norm).sum(-1) # (M,) — A·n + A_sq = (starts * starts).sum(-1) # (M,) — |A|² + B_sq = (ends * ends).sum(-1) # (M,) — |B|² + + # Cylinder body: project onto plane ⊥ n and solve |P_⊥(t)|² = r². + a_c = 1.0 - dn * dn # (K, M) + b_c = -2.0 * (dA - dn * An) # (K, M) + c_c = A_sq - An * An - r_sq # (M,) + disc_c = b_c * b_c - 4.0 * a_c * c_c + safe_a = a_c.clamp(min=1e-9) + sqrt_c = torch.sqrt(disc_c.clamp(min=0.0)) + t_cyl = (-b_c - sqrt_c) / (2.0 * safe_a) + s_cyl = t_cyl * dn - An # axial projection from A + cyl_ok = (disc_c >= 0) & (a_c > 1e-7) & (t_cyl > 1e-6) & \ + (s_cyl >= 0.0) & (s_cyl <= ba_len) + t_cyl = torch.where(cyl_ok, t_cyl, torch.full_like(t_cyl, INF)) + + # Sphere at A, restricted to the hemisphere with axial projection ≤ 0. + disc_a = dA * dA - (A_sq - r_sq) + sqrt_a = torch.sqrt(disc_a.clamp(min=0.0)) + t_sa = dA - sqrt_a + s_a = t_sa * dn - An + a_ok = (disc_a >= 0) & (t_sa > 1e-6) & (s_a <= 0.0) + t_sa = torch.where(a_ok, t_sa, torch.full_like(t_sa, INF)) + + # Sphere at B, restricted to the hemisphere with axial projection ≥ ba_len. 
+ disc_b = dB * dB - (B_sq - r_sq) + sqrt_b = torch.sqrt(disc_b.clamp(min=0.0)) + t_sb = dB - sqrt_b + s_b = t_sb * dn - An + b_ok = (disc_b >= 0) & (t_sb > 1e-6) & (s_b >= ba_len) + t_sb = torch.where(b_ok, t_sb, torch.full_like(t_sb, INF)) + + return torch.minimum(torch.minimum(t_cyl, t_sa), t_sb) + + +def _render_capsules_torch( + starts: torch.Tensor, + ends: torch.Tensor, + colors: torch.Tensor, + H: int, W: int, + fx: float, fy: float, cx: float, cy: float, + radius: float, + background_rgb: Optional[torch.Tensor], + device: torch.device, +) -> torch.Tensor: + """Analytic ray-capsule renderer for a union of capsules. Camera at + origin looking down +Z; pixels in y-down screen coords.""" + M = int(starts.shape[0]) + if M == 0: + if background_rgb is not None: + return background_rgb.to(device=device, dtype=torch.float32).clamp(0.0, 1.0) + return torch.zeros(H, W, 3, dtype=torch.float32, device=device) + + yy, xx = torch.meshgrid( + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + indexing="ij", + ) + u = (xx - cx) / fx + v = (yy - cy) / fy + z = torch.ones_like(u) + ray_dirs = torch.stack([u, v, z], dim=-1) + ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True) + flat_dirs = ray_dirs.view(-1, 3) + N = flat_dirs.shape[0] + + ba = ends - starts + ba_len = torch.linalg.norm(ba, dim=1).clamp(min=1e-6) + ba_norm = ba / ba_len.unsqueeze(1) + + z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item())) + z_near = max(0.05, z_min - radius) + + # Union of per-capsule screen-space bboxes. Pixels outside this mask + # provably can't hit any capsule, so the analytic intersection only runs + # on the relevant subset of the canvas (~5-15% at 1080p for typical poses). 
+ sz = starts[:, 2].clamp(min=z_near) + ez = ends[:, 2].clamp(min=z_near) + sx_p = starts[:, 0] * fx / sz + cx + sy_p = starts[:, 1] * fy / sz + cy + ex_p = ends[:, 0] * fx / ez + cx + ey_p = ends[:, 1] * fy / ez + cy + # Projected radius using the closer endpoint — conservative bbox. + r_pix = radius * fx / torch.minimum(sz, ez) + pad = 2.0 + xmin_t = (torch.minimum(sx_p, ex_p) - r_pix - pad).floor().long().clamp(min=0, max=W) + xmax_t = (torch.maximum(sx_p, ex_p) + r_pix + pad).ceil().long().clamp(min=0, max=W) + ymin_t = (torch.minimum(sy_p, ey_p) - r_pix - pad).floor().long().clamp(min=0, max=H) + ymax_t = (torch.maximum(sy_p, ey_p) + r_pix + pad).ceil().long().clamp(min=0, max=H) + # One stack→tolist sync amortizes the GPU→CPU read over all M bboxes. + bboxes_cpu = torch.stack([xmin_t, ymin_t, xmax_t, ymax_t], dim=1).tolist() + coarse_mask = torch.zeros(H, W, dtype=torch.bool, device=device) + for xmin_i, ymin_i, xmax_i, ymax_i in bboxes_cpu: + if xmax_i > xmin_i and ymax_i > ymin_i: + coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True + + # Analytic ray-capsule intersection. One pass over the masked pixels — + # the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel + # plus 6 SDF evaluations per hit pixel for finite-difference normals. + INF = float("inf") + flat_t = torch.full((N,), INF, device=device, dtype=torch.float32) + flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long) + active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1) + if active_idx.numel() > 0: + # Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory + # manageable when both K (image pixels) and M (capsules) are large. 
+ chunk_max = max(1, int(4_000_000 / max(M, 1))) + for i0 in range(0, active_idx.numel(), chunk_max): + sub = active_idx[i0 : i0 + chunk_max] + t_KM = _ray_capsule_t( + flat_dirs[sub], starts, ends, ba_norm, ba_len, radius, + ) + t_min, m_idx = t_KM.min(dim=1) + hit = t_min < INF + if hit.any(): + winners = sub[hit] + flat_t[winners] = t_min[hit] + flat_m_idx[winners] = m_idx[hit] + + # Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade. + out = torch.zeros((N, 3), dtype=torch.float32, device=device) + if background_rgb is not None: + out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone() + hit_idx = torch.nonzero(flat_m_idx >= 0, as_tuple=False).squeeze(1) + if hit_idx.numel() > 0: + rd = flat_dirs[hit_idx] + t_h = flat_t[hit_idx] + m_h = flat_m_idx[hit_idx] + p_hit = rd * t_h.unsqueeze(-1) + + A_h = starts[m_h] + n_h = ba_norm[m_h] + L_h = ba_len[m_h] + proj = ((p_hit - A_h) * n_h).sum(-1).clamp(min=0.0) + proj = torch.minimum(proj, L_h) + C_h = A_h + proj.unsqueeze(-1) * n_h + normals = p_hit - C_h + normals = normals / normals.norm(dim=-1, keepdim=True).clamp(min=1e-8) + + col = colors[m_h, :3] + # SCAIL shading (render_torch.py:290-331). Light from camera (+Z toward + # subject); diffuse term `N·-L` simplifies to `-N.z`. Specular uses the + # proper Blinn-Phong half-vector `(view + (-L))` — using `diff` as a + # shortcut would lock the highlight to image center. + diff = torch.clamp(-(normals[:, 2]), min=0.0) + diffuse = 0.45 + 0.55 * diff + + view_dir = -rd + half_dir = view_dir.clone() + half_dir[:, 2] -= 1.0 + half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8) + spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32) + + # SCAIL's reference depth fade uses mm-scale constants (`z_max + 6000`) + # that translate to almost no fade in our meter units — `depth_factor` + # stays ~0.85-1.0. Matching that with a mild ramp. 
+ z_vals = p_hit[:, 2] + z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item()) + if z_hi - z_lo > 1e-6: + fade = 1.0 - (z_vals - z_lo) / (z_hi - z_lo) + depth_factor = 0.85 + 0.15 * fade + else: + depth_factor = torch.ones_like(z_vals) + + base = col * diffuse.unsqueeze(-1) * depth_factor.unsqueeze(-1) + highlight = (0.5 * spec * depth_factor).unsqueeze(-1) + out[hit_idx] = base + highlight + + return out.view(H, W, 3).clamp(0.0, 1.0) + + +def render_pose_data_capsules( + pose_data: Dict[str, Any], + *, + frame_idx: int, + W: int, + H: int, + background: Optional[torch.Tensor] = None, + composite: str = "over", + radius_m: float = 0.025, + include_hands: bool = False, + palette: str = "scail", + person_brightness_falloff: float = 0.0, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Render a frame's pose_data as 3D capsules projected through the per- + person camera. Returns (H, W, 3) fp32 in [0, 1]. + + `composite='over'` paints over `background` (black if None); + `composite='mesh_only'` always uses a black canvas. + + `radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`). + Camera fx/fy come from each person's `focal_length` (pixels); cx/cy = center. + """ + persons = pose_data["frames"][frame_idx] + if device is None: + device = comfy.model_management.get_torch_device() + + # SAM3DBody shares one camera across the clip — pick from the first valid person. 
+ fx = fy = float(min(H, W)) + for person in persons: + f = person.get("focal_length") + if f is None: + continue + fx = fy = float(np.asarray(f, dtype=np.float32).reshape(-1)[0]) + break + cx, cy = W * 0.5, H * 0.5 + + starts_np, ends_np, colors_np = _build_specs_from_pose( + persons, include_hands=include_hands, palette=palette, + person_brightness_falloff=person_brightness_falloff, + ) + + bg_t: Optional[torch.Tensor] = None + if composite == "over" and background is not None: + bg_t = background.to(device=device, dtype=torch.float32).clamp(0.0, 1.0) + if bg_t.shape[:2] != (H, W): + bg_t = bg_t.permute(2, 0, 1).unsqueeze(0) + bg_t = torch.nn.functional.interpolate( + bg_t, size=(H, W), mode="bilinear", align_corners=False, + ) + bg_t = bg_t.squeeze(0).permute(1, 2, 0).contiguous() + + if starts_np.shape[0] == 0: + if bg_t is not None: + return bg_t + return torch.zeros(H, W, 3, dtype=torch.float32, device=device) + + starts_t = torch.from_numpy(starts_np).to(device=device, dtype=torch.float32) + ends_t = torch.from_numpy(ends_np).to(device=device, dtype=torch.float32) + colors_t = torch.from_numpy(colors_np).to(device=device, dtype=torch.float32) + + return _render_capsules_torch( + starts_t, ends_t, colors_t, + H=H, W=W, fx=fx, fy=fy, cx=cx, cy=cy, + radius=float(radius_m), + background_rgb=bg_t, + device=device, + ) diff --git a/comfy_extras/sam3d_body/export/glb_openpose.py b/comfy_extras/sam3d_body/export/glb_openpose.py new file mode 100644 index 000000000..602edde09 --- /dev/null +++ b/comfy_extras/sam3d_body/export/glb_openpose.py @@ -0,0 +1,1138 @@ +"""GLB export — OpenPose 18-keypoint visualization mode. + +Independent of the MHR rig — sourced from pose_data's `pred_keypoints_3d` +(the model's regressed surface keypoints). Each track becomes an armature +with a sibling joint per keypoint; sphere markers + stick/capsule limbs are +skinned to those joints. 
+ +Optional hand keypoints (also from `pred_keypoints_3d`, indices 21..62) and +face landmarks (sampled from `pred_vertices` at fixed head-mesh vertex IDs) +extend the same armature. + +OpenPose-shared tables / palettes / mappings live in `glb_shared.py` and are +imported below — they're also used by the 2D and 3D renderers in this package. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from .glb_shared import ( + DWPOSE_HAND_COLORS_21, + FACE_LANDMARK_COLORS, + FACE_LANDMARK_TARGETS, + GLBWriter, + OPENPOSE18_TO_MHR70, + OPENPOSE_18_NAMES, + OPENPOSE_18_PAIRS, + OPENPOSE_HAND21_NAMES, + OPENPOSE_HAND21_TO_MHR70_L, + OPENPOSE_HAND21_TO_MHR70_R, + OPENPOSE_HAND_COLORS_21, + OPENPOSE_HAND_PAIRS, + OPENPOSE_RAINBOW_18, + SCAIL_KEYPOINT_COLORS_18, + SCAIL_LIMB_COLORS_17, + collect_tracks, + flat_shade_mesh, + make_lit_material, + quat_sign_fix_per_joint, + rotation_align, + rotmat_to_quat_np, + select_face_landmark_vert_ids, + smooth_shade_mesh, + unflip, + uv_sphere_unit, +) + + +def _finalize_skinned_mesh( + verts: np.ndarray, faces: np.ndarray, + joints: np.ndarray, weights: np.ndarray, vert_colors: np.ndarray, + smooth_shade: bool, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Apply smooth or flat shading to an indexed sphere/stick group mesh and + pack per-vertex colors. 
Smooth keeps the indexed mesh + per-vertex colors; + flat duplicates verts per face and gathers face-corner colors.""" + if smooth_shade: + v_f, n_f, f_f, j_f, w_f = smooth_shade_mesh(verts, faces, joints, weights) + return v_f, n_f, f_f, j_f, w_f, vert_colors.astype(np.float32) + F = faces.shape[0] + pre_faces = faces.copy() + v_f, n_f, f_f, j_f, w_f = flat_shade_mesh(verts, faces, joints, weights) + c_f = np.zeros((F * 3, 3), dtype=np.float32) + for k in range(3): + c_f[k::3] = vert_colors[pre_faces[:, k]] + return v_f, n_f, f_f, j_f, w_f, c_f + + +def _pair_colors_from_kp( + pairs: Tuple[Tuple[int, int], ...], kp_colors: np.ndarray, endpoint: int = 1, +) -> np.ndarray: + """Per-limb color = endpoint-vertex color from `kp_colors`. Default + `endpoint=1` picks the second (distal) vertex of each pair, which is + the OpenPose-canonical per-finger gradient when fingers go base→tip + (wrist=0 → thumb1=1 → thumb2=2 …).""" + n = len(pairs) + out = np.zeros((n, 3), dtype=np.float32) + for i, (a, b) in enumerate(pairs): + out[i] = kp_colors[b if endpoint == 1 else a] + return out + + +def _openpose_bind_at_rig_rest( + pose_data: Dict[str, Any], *, + include_hands: bool, face_vert_ids: Optional[np.ndarray], +) -> Optional[np.ndarray]: + """OpenPose keypoint positions at the rig's REST pose (T-pose at authoring + origin), built from the `_skeleton_override`'s `bind_global_m` (joint rest + TRS) and `rest_verts_m` (mesh rest verts for face landmarks). + + Used as the static-bind for openpose-mode geometry so the GLB's static + POSITION attribute sits at rig origin — matching skeletal mode's bind and + producing the same 'snap from rest to scene-frame-0' transition at the + start of playback. Without this, the static geometry is at scene-frame-0 + (kp_seq[0]) and viewers that auto-fit on static POSITION will center on + the scene location, hiding the per-frame motion. 
+ + Returns None when the override is missing or doesn't carry all the needed + mappings — caller falls back to per-frame extraction (kp_seq[0]).""" + override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None + if override is None or "bind_global_m" not in override: + return None + op18 = override.get("openpose18_joint_indices") + if op18 is None: + return None + rest_pos = np.asarray(override["bind_global_m"], dtype=np.float32)[:, :3] + op18_w = override.get("openpose18_joint_weights") + parts: List[np.ndarray] = [ + _resolve_openpose_keypoints_from_joints( + rest_pos, np.asarray(op18, dtype=np.int64), + weights=None if op18_w is None else np.asarray(op18_w, dtype=np.float32), + ) + ] + if include_hands: + op21_r = override.get("openpose_hand21_r_joint_indices") + op21_l = override.get("openpose_hand21_l_joint_indices") + if op21_r is None or op21_l is None: + return None + op21_r_w = override.get("openpose_hand21_r_joint_weights") + op21_l_w = override.get("openpose_hand21_l_joint_weights") + parts.append(_resolve_openpose_keypoints_from_joints( + rest_pos, np.asarray(op21_r, dtype=np.int64), + weights=None if op21_r_w is None else np.asarray(op21_r_w, dtype=np.float32), + )) + parts.append(_resolve_openpose_keypoints_from_joints( + rest_pos, np.asarray(op21_l, dtype=np.int64), + weights=None if op21_l_w is None else np.asarray(op21_l_w, dtype=np.float32), + )) + if face_vert_ids is not None: + rest_verts = override.get("rest_verts_m") + if rest_verts is None: + return None + parts.append(np.asarray(rest_verts, dtype=np.float32)[face_vert_ids]) + return np.concatenate(parts, axis=0).astype(np.float32) + + +def _resolve_openpose_keypoints_from_joints( + joints: np.ndarray, mapping: np.ndarray, + weights: Optional[np.ndarray] = None, +) -> np.ndarray: + """Resolve a `(K, 2)` joint-index → keypoint mapping against a per-frame + `(J, 3)` joint-position array. + + Row `(a, b)` with `b == -1` uses `joints[a]` directly (any weight ignored). 
+ Row `(a, b)` with `b >= 0` returns `w * joints[a] + (1 - w) * joints[b]`: + - default (weights=None): `w = 0.5` → plain midpoint, useful for + keypoints that genuinely lie between two joints (Nose ≈ midpoint of + eyes). + - explicit `w` outside [0, 1] EXTRAPOLATES past the line segment, which + is how we approximate landmarks that have no rig joint AND no + in-between joint pair (Ears ≈ RightEye + (RightEye − LeftEye), i.e. + `w_a = 2.0` along the eye→ear axis).""" + a = mapping[:, 0].astype(np.int64) + b = mapping[:, 1].astype(np.int64) + pos_a = joints[a] + has_b = b >= 0 + if not has_b.any(): + return pos_a.astype(np.float32, copy=False) + b_safe = np.where(has_b, b, a) + pos_b = joints[b_safe] + if weights is None: + w_a = np.where(has_b, 0.5, 1.0).astype(np.float32) + else: + w_a = np.where(has_b, np.asarray(weights, dtype=np.float32), 1.0) + w_b = (1.0 - w_a).astype(np.float32) + out = pos_a * w_a[:, None] + pos_b * w_b[:, None] + return out.astype(np.float32, copy=False) + + +def _extract_openpose_keypoints( + pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, +) -> np.ndarray: + """(N, 18, 3) OpenPose keypoint positions in rig-native Y-up metres. + + Two sources, in priority order: + + 1. **External-skeleton path** — when pose_data has `_skeleton_override` + with `openpose18_joint_indices` ((18, 2) int32, see + `_resolve_openpose_keypoints_from_joints`), synthesize from each + person's `pred_joint_coords` directly. The override frame is already + rig-native Y-up, so no axis flip. + 2. **MHR70 path** (default for SAM3DBody_Predict output) — re-index the + first 70 of 308 MHR keypoints (`pred_keypoints_3d`) to COCO-18. + Stored y-down (post `j3d[..., [1,2]] *= -1` in sam3d_body), so we + un-flip y/z to match rig-native Y-up. 
+ """ + frames = pose_data["frames"] + N = len(frame_indices) + out = np.zeros((N, 18, 3), dtype=np.float32) + + override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None + op18 = override.get("openpose18_joint_indices") if override is not None else None + if op18 is not None: + op18 = np.asarray(op18, dtype=np.int64) + if op18.ndim != 2 or op18.shape != (18, 2): + raise ValueError( + "build_glb_openpose: `openpose18_joint_indices` in " + "`_skeleton_override` must be shape (18, 2); got " + f"{tuple(op18.shape)}. Each row is (joint_a, joint_b); " + "use joint_b=-1 for single-joint keypoints." + ) + op18_w = override.get("openpose18_joint_weights") + if op18_w is not None: + op18_w = np.asarray(op18_w, dtype=np.float32) + if op18_w.shape != (18,): + raise ValueError( + "build_glb_openpose: `openpose18_joint_weights` must be " + f"shape (18,); got {tuple(op18_w.shape)}." + ) + for t_idx, t in enumerate(frame_indices): + person = frames[t][person_k] + if "pred_joint_coords" not in person: + raise ValueError( + "build_glb_openpose: external-skeleton path needs " + "per-frame `pred_joint_coords` (J, 3) on each person; " + f"missing at frame={t}, track={person_k}." + ) + joints = np.asarray(person["pred_joint_coords"], dtype=np.float32) + out[t_idx] = _resolve_openpose_keypoints_from_joints( + joints, op18, weights=op18_w, + ) + return out + + for t_idx, t in enumerate(frame_indices): + person = frames[t][person_k] + if "pred_keypoints_3d" not in person: + # Diagnose the source: external-skeleton producers ship + # `_skeleton_override` instead of MHR70 keypoints. If the + # producer didn't populate `openpose18_joint_indices` either, + # we can't synthesize the 18-keypoint set. + if override is not None: + raise ValueError( + "build_glb_openpose: this pose_data carries " + "`_skeleton_override` but it doesn't include " + "`openpose18_joint_indices` and the per-frame person " + "dict is missing `pred_keypoints_3d`. 
Ask the upstream " + "node to populate `openpose18_joint_indices` on the " + "override (a (18, 2) int32 mapping into its joint list), " + "or switch SAM3DBody_ToGLB to `skeletal` mode." + ) + present_keys = sorted(person.keys()) + raise ValueError( + "build_glb_openpose: pose_data is missing " + "`pred_keypoints_3d` (frame=%d, track=%d). Keys present " + "on this person: %s. Re-run SAM3DBody_Predict — older " + "saved pose_data may pre-date the field, and any " + "intermediate node that rebuilds person dicts must " + "preserve it." + % (t, person_k, present_keys) + ) + kp = np.asarray(person["pred_keypoints_3d"], dtype=np.float32) + out[t_idx] = kp[OPENPOSE18_TO_MHR70] + out[..., 1] *= -1.0 + out[..., 2] *= -1.0 + return out + + +def _extract_openpose_hand_keypoints( + pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, +) -> np.ndarray: + """(N, 42, 3) right+left OpenPose hand keypoints (21 + 21) in rig-native + Y-up frame. + + External-skeleton path: requires `openpose_hand21_r_joint_indices` AND + `openpose_hand21_l_joint_indices` ((21, 2) int32 each) in the override. + Resolved against per-frame `pred_joint_coords` like the body path. + + MHR70 path: re-orders `pred_keypoints_3d` indices 21..62 to OpenPose-21 + (wrist + 5 fingers, thumb→pinky, base→tip).""" + frames = pose_data["frames"] + N = len(frame_indices) + out = np.zeros((N, 42, 3), dtype=np.float32) + + override = pose_data.get("_skeleton_override") if isinstance(pose_data, dict) else None + op21_r = override.get("openpose_hand21_r_joint_indices") if override is not None else None + op21_l = override.get("openpose_hand21_l_joint_indices") if override is not None else None + if override is not None and (op21_r is not None or op21_l is not None): + if op21_r is None or op21_l is None: + raise ValueError( + "build_glb_openpose: external skeleton must supply BOTH " + "`openpose_hand21_r_joint_indices` and " + "`openpose_hand21_l_joint_indices` for include_hands=True." 
+ ) + op21_r = np.asarray(op21_r, dtype=np.int64) + op21_l = np.asarray(op21_l, dtype=np.int64) + for arr, side in ((op21_r, "r"), (op21_l, "l")): + if arr.ndim != 2 or arr.shape != (21, 2): + raise ValueError( + f"build_glb_openpose: `openpose_hand21_{side}_joint_indices` " + f"must be shape (21, 2); got {tuple(arr.shape)}." + ) + op21_r_w = override.get("openpose_hand21_r_joint_weights") + op21_l_w = override.get("openpose_hand21_l_joint_weights") + op21_r_w = (np.asarray(op21_r_w, dtype=np.float32) + if op21_r_w is not None else None) + op21_l_w = (np.asarray(op21_l_w, dtype=np.float32) + if op21_l_w is not None else None) + for t_idx, t in enumerate(frame_indices): + person = frames[t][person_k] + if "pred_joint_coords" not in person: + raise ValueError( + "build_glb_openpose: external-skeleton path needs " + "per-frame `pred_joint_coords` for hands." + ) + joints = np.asarray(person["pred_joint_coords"], dtype=np.float32) + out[t_idx, :21] = _resolve_openpose_keypoints_from_joints( + joints, op21_r, weights=op21_r_w, + ) + out[t_idx, 21:] = _resolve_openpose_keypoints_from_joints( + joints, op21_l, weights=op21_l_w, + ) + return out + + for t_idx, t in enumerate(frame_indices): + person = frames[t][person_k] + if "pred_keypoints_3d" not in person: + if override is not None: + raise ValueError( + "build_glb_openpose: include_hands=True with an external " + "skeleton needs `openpose_hand21_r_joint_indices` and " + "`openpose_hand21_l_joint_indices` on `_skeleton_override`. " + "Disable hands or ask the upstream node to populate them." + ) + raise ValueError( + "build_glb_openpose: pose_data is missing `pred_keypoints_3d`." 
+ ) + kp = np.asarray(person["pred_keypoints_3d"], dtype=np.float32) + out[t_idx, :21] = kp[OPENPOSE_HAND21_TO_MHR70_R] + out[t_idx, 21:] = kp[OPENPOSE_HAND21_TO_MHR70_L] + out[..., 1] *= -1.0 + out[..., 2] *= -1.0 + return out + + +def _extract_face_landmarks_from_verts( + pose_data: Dict[str, Any], frame_indices: List[int], person_k: int, + vert_ids: np.ndarray, +) -> np.ndarray: + """(N, K_face, 3) face landmarks sampled from per-frame `pred_vertices` + at the supplied head-mesh vertex IDs, unflipped to MHR-native Y-up. + Each landmark inherits per-frame shape/expr/pose deformation for free + since `pred_vertices` already has it baked in.""" + frames = pose_data["frames"] + N = len(frame_indices) + K = int(vert_ids.shape[0]) + out = np.zeros((N, K, 3), dtype=np.float32) + for t_idx, t in enumerate(frame_indices): + person = frames[t][person_k] + if "pred_vertices" not in person: + raise ValueError( + "build_glb_openpose: face_source='rig' needs `pred_vertices` " + "on every frame — re-run Predict to populate it." + ) + v = np.asarray(person["pred_vertices"], dtype=np.float32).reshape(-1, 3) + out[t_idx] = v[vert_ids] + out[..., 1] *= -1.0 + out[..., 2] *= -1.0 + return out + + +def _build_openpose_spheres( + bind_kp_m: np.ndarray, radius_m: float, kp_colors: np.ndarray, + base_joint_idx: int = 0, + smooth_shade: bool = False, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """UV sphere per OpenPose keypoint, rigidly skinned to that keypoint's + joint, vertex-colored from kp_colors. `base_joint_idx` is added to the + emitted JOINTS_0 indices so callers can place this group at any offset + in the shared skin (body=0, right hand=18, etc.). + + `smooth_shade=True` keeps the indexed mesh and writes per-vertex + normals via face-normal averaging — round shading on the spheres. + `smooth_shade=False` (default) flat-shades by duplicating verts per + face, matching the existing OpenPose-mode look. 
Returns + (verts, normals, faces, joints4, weights4, vert_colors).""" + sv, sf = uv_sphere_unit() + K = bind_kp_m.shape[0] + Nv = sv.shape[0] + Nf = sf.shape[0] + out_v = np.zeros((K * Nv, 3), dtype=np.float32) + out_n = np.zeros((K * Nv, 3), dtype=np.float32) + out_f = np.zeros((K * Nf, 3), dtype=np.uint32) + out_j = np.zeros((K * Nv, 4), dtype=np.uint16) + out_w = np.zeros((K * Nv, 4), dtype=np.float32) + out_c = np.zeros((K * Nv, 3), dtype=np.float32) + for j in range(K): + v_off = j * Nv + out_v[v_off:v_off + Nv] = sv * radius_m + bind_kp_m[j] + out_n[v_off:v_off + Nv] = sv + out_f[j * Nf:(j + 1) * Nf] = sf + v_off + out_j[v_off:v_off + Nv, 0] = j + base_joint_idx + out_w[v_off:v_off + Nv, 0] = 1.0 + out_c[v_off:v_off + Nv] = kp_colors[j] + return _finalize_skinned_mesh(out_v, out_f, out_j, out_w, out_c, smooth_shade) + + +def _capsule_mesh_local( + L: float, W: float, *, + n_cap_lat: Optional[int] = None, + n_body: Optional[int] = None, + n_lon: Optional[int] = None, + end_width_frac: float = 0.3, + shape: str = "ellipsoid", +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Build a per-limb mesh in limb-local frame along +Y from y=0 (head + pole) to y=L (tail pole). + + `shape` selects the silhouette: + - 'ellipsoid' (default): tips are small hemispheres of radius + `W * end_width_frac`; body has ellipsoidal radius profile + sin(π*u) from w_end at the junctions to W at the middle. Gives + a fat-middle / narrow-end stretched-ellipse look. + - 'capsule': SCAIL-style "rig" limb — an OPEN cylinder of constant + radius W with no hemisphere caps. Pair with sphere joint markers + of the same radius so the spheres seamlessly cap the open + cylinder ends (the cylinder cross-section ring at the endpoint + lies exactly on the sphere surface). Drawing hemisphere caps + inside the joint sphere creates a visible bump where the cap + pokes out unevenly when sphere radius ≠ cap radius — open + cylinders avoid that. 
+ + Per-limb mesh is required because the cap height (w_end) depends on + the limb width — a single canonical mesh can't produce true + hemispheres for arbitrary L:W ratios in ellipsoid mode. + + Returns: + verts: (Nv, 3) float32 — limb-local positions in meters. + faces: (Nf, 3) uint32 — triangle indices. + weights: (Nv, 2) float32 — (head, tail) skinning weights, linearly + interpolated by axial position (sums to 1). + """ + W = max(1e-6, min(float(W), float(L) * 0.5 - 1e-6)) + if str(shape) == "capsule": + # SCAIL-style "rig" limb: an OPEN cylinder of constant radius W, + # no hemisphere caps. The sphere joint markers at each endpoint + # provide the rounded ends of the bone — when sphere_radius == + # cylinder_radius, the cylinder cross-section ring at the bone + # endpoint lies exactly on the sphere surface, so silhouette is + # seamless. Hemisphere caps would create a visible bump where + # the cap pokes out of the sphere if cap_r ≠ marker_r, so we + # omit them entirely. + cap_r = 0.0 + body_r = W + if n_cap_lat is None: n_cap_lat = 0 + if n_body is None: n_body = 0 + if n_lon is None: n_lon = 16 + elif str(shape) == "ellipsoid": + end_frac = float(min(0.95, max(0.05, end_width_frac))) + cap_r = max(1e-7, W * end_frac) + body_r = W + # Ellipsoid defaults: more body rings to sample the sin(π·u) curve. + if n_cap_lat is None: n_cap_lat = 3 + if n_body is None: n_body = 7 + if n_lon is None: n_lon = 12 + else: + raise ValueError( + f"_capsule_mesh_local: unknown shape={shape!r} " + "(expected 'ellipsoid' or 'capsule')" + ) + if 2.0 * cap_r >= L: + cap_r = max(0.0, L * 0.5 - 1e-6) + body_len = float(L) - 2.0 * cap_r + n_cap_lat = max(0, int(n_cap_lat)) + n_body = max(0, int(n_body)) + n_lon = max(3, int(n_lon)) + + has_caps = n_cap_lat > 0 + + verts: List[List[float]] = [] + head_pole = -1 + tail_pole = -1 + head_rings: List[int] = [] + tail_rings: List[int] = [] + + if has_caps: + # Head pole vertex at y=0 (south pole of head hemisphere). 
+ head_pole = len(verts) + verts.append([0.0, 0.0, 0.0]) + # Head cap rings (i = 1..n_cap_lat). Last ring (i=n_cap_lat, + # theta=π/2) is the head-body junction at y=cap_r, r=cap_r. + for i in range(1, n_cap_lat + 1): + theta = (np.pi * 0.5) * i / n_cap_lat + y = cap_r * (1.0 - np.cos(theta)) + r = cap_r * np.sin(theta) + head_rings.append(len(verts)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([r * float(np.cos(phi)), float(y), r * float(np.sin(phi))]) + else: + # Open cylinder: no caps, no pole. Add an end ring at y=0 directly. + head_rings.append(len(verts)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([body_r * float(np.cos(phi)), 0.0, body_r * float(np.sin(phi))]) + + # Body intermediate rings (between the cap junctions for capped meshes, + # between the two end rings for open cylinders). For 'capsule' mode + # n_body=0 by default — no intermediate rings needed for a constant- + # radius cylinder. + body_rings: List[int] = [] + is_ellipsoid = str(shape) == "ellipsoid" + for j in range(1, n_body + 1): + u = j / (n_body + 1) + y = cap_r + body_len * u + if is_ellipsoid: + r = cap_r + (body_r - cap_r) * float(np.sin(np.pi * u)) + else: + r = body_r + body_rings.append(len(verts)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([r * float(np.cos(phi)), float(y), r * float(np.sin(phi))]) + + if has_caps: + # Tail cap rings (i = 0..n_cap_lat-1). First ring (i=0, theta=π/2) + # is the body-tail junction at y=L-cap_r, r=cap_r; last + # (i=n_cap_lat-1) is the ring just before the pole. 
+ for i in range(0, n_cap_lat): + theta = (np.pi * 0.5) * (n_cap_lat - i) / n_cap_lat + y = float(L) - cap_r * (1.0 - np.cos(theta)) + r = cap_r * np.sin(theta) + tail_rings.append(len(verts)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([r * float(np.cos(phi)), float(y), r * float(np.sin(phi))]) + tail_pole = len(verts) + verts.append([0.0, float(L), 0.0]) + else: + # Open cylinder end ring at y=L. + tail_rings.append(len(verts)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([body_r * float(np.cos(phi)), float(L), body_r * float(np.sin(phi))]) + + faces: List[List[int]] = [] + + if has_caps: + # Head pole fan — outward (-Y) normal at the south pole. + r0 = head_rings[0] + for k in range(n_lon): + a = r0 + k + b = r0 + (k + 1) % n_lon + faces.append([head_pole, a, b]) + + # All inter-ring quads in axial order. + all_rings = head_rings + body_rings + tail_rings + for i in range(len(all_rings) - 1): + rl = all_rings[i] + rh = all_rings[i + 1] + for k in range(n_lon): + a = rl + k + b = rl + (k + 1) % n_lon + c = rh + (k + 1) % n_lon + d = rh + k + faces.append([a, c, b]) + faces.append([a, d, c]) + + if has_caps: + # Tail pole fan — outward (+Y) normal at the north pole. + rL = tail_rings[-1] + for k in range(n_lon): + a = rL + k + b = rL + (k + 1) % n_lon + faces.append([tail_pole, b, a]) + + v_arr = np.asarray(verts, dtype=np.float32) + weights = np.zeros((v_arr.shape[0], 2), dtype=np.float32) + weights[:, 1] = np.clip(v_arr[:, 1] / max(float(L), 1e-12), 0.0, 1.0) + weights[:, 0] = 1.0 - weights[:, 1] + + return v_arr, np.asarray(faces, dtype=np.uint32), weights + + +def _openpose_limb_rest_trs( + bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...], +) -> Tuple[np.ndarray, np.ndarray]: + """Per-limb rest TRS: + midpoints (K_pairs, 3): rest midpoint between bind_kp_m[a] and bind_kp_m[b]. + rest_axes (K_pairs, 3): unit direction a→b at rest (or +Y if degenerate). 
+ Caller uses `midpoints` as each limb joint's rest translation (rotation = + identity), and `rest_axes` to compute per-frame alignment rotations.""" + K_pairs = len(pairs) + mid = np.zeros((K_pairs, 3), dtype=np.float32) + axis = np.zeros((K_pairs, 3), dtype=np.float32) + axis[:, 1] = 1.0 + for k, (a, b) in enumerate(pairs): + a_pos = bind_kp_m[a] + b_pos = bind_kp_m[b] + mid[k] = 0.5 * (a_pos + b_pos) + d = b_pos - a_pos + n = float(np.linalg.norm(d)) + if n > 1e-9: + axis[k] = d / n + return mid, axis + + +def _openpose_limb_anim_trs( + kp_seq: np.ndarray, pairs: Tuple[Tuple[int, int], ...], rest_axes: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """Per-frame limb TRS: + anim_mid (N, K_pairs, 3): midpoint of (kp_seq[t][a], kp_seq[t][b]). + anim_quat (N, K_pairs, 4): rotation (xyzw) that aligns each limb's rest + axis to its frame-t axis. + Together with rest TRS, this drives `skin_matrix(t) = T(mid_t) * R_t * + T(-mid_rest)` so each capsule rigidly rotates about its rest midpoint to + track the limb's current direction — no LBS cross-section thinning.""" + N = kp_seq.shape[0] + K_pairs = len(pairs) + anim_mid = np.zeros((N, K_pairs, 3), dtype=np.float32) + R = np.tile(np.eye(3, dtype=np.float32), (N, K_pairs, 1, 1)) + for k, (a, b) in enumerate(pairs): + ax_rest = rest_axes[k] + for t in range(N): + a_pos = kp_seq[t, a] + b_pos = kp_seq[t, b] + anim_mid[t, k] = 0.5 * (a_pos + b_pos) + d = b_pos - a_pos + n = float(np.linalg.norm(d)) + if n > 1e-9: + R[t, k] = rotation_align(ax_rest, d / n) + quat = rotmat_to_quat_np(R).astype(np.float32) # (N, K_pairs, 4) xyzw + return anim_mid, quat + + +def _build_openpose_sticks( + bind_kp_m: np.ndarray, pairs: Tuple[Tuple[int, int], ...], + half_width_m: float, pair_colors: np.ndarray, + limb_joint_base_idx: int = 0, + shape: str = "ellipsoid", + smooth_shade: bool = False, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Capsule (cylinder + hemispherical caps) per limb pair (a, 
b). + + Each limb gets its own mesh sized to that limb's length and width so + the caps are TRUE hemispheres of radius `half_width_eff` — the limb + silhouette is rounded-rectangle-like, regardless of L:W ratio. Width + auto-clamped to `length * 0.1` so short limbs (face/ear) don't look + chunky next to long ones. + + Skinning: rigid (weight=1) binding to a per-limb joint at + `limb_joint_base_idx + limb_idx` — the caller animates that joint with + midpoint translation + rest-to-current rotation so each capsule rotates + rigidly with its limb (avoids translation-only LBS cross-section + thinning). Returns flat-shaded (verts, normals, faces, joints4, + weights4, vert_colors).""" + canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) + + out_v_chunks: List[np.ndarray] = [] + out_f_chunks: List[np.ndarray] = [] + out_j_chunks: List[np.ndarray] = [] + out_w_chunks: List[np.ndarray] = [] + out_c_chunks: List[np.ndarray] = [] + v_total = 0 + WIDTH_RATIO = 0.1 + MIN_WIDTH = 0.001 + is_capsule = str(shape) == "capsule" + for limb_idx, (a, b) in enumerate(pairs): + head = bind_kp_m[a] + tail = bind_kp_m[b] + direction = tail - head + length = float(np.linalg.norm(direction)) + if length < 1e-6: + continue + unit_dir = direction / length + R = rotation_align(canonical, unit_dir) + if is_capsule: + # SCAIL-style uniform radius — every bone gets the same width. + # `_capsule_mesh_local` clamps internally to L/2-eps so very + # short bones don't go degenerate. + half_width_eff = max(MIN_WIDTH, half_width_m) + else: + # Ellipsoid mode: original auto-thinning so short face/ear + # limbs don't look chunky next to long body limbs. + half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m)) + + v_local, f_local, _weights_unused = _capsule_mesh_local( + length, half_width_eff, shape=shape, + ) + v_world = v_local @ R.T + head + Nv = v_local.shape[0] + + # Rigid binding to the per-limb joint. 
The 2-bone (head, tail) weights + # from `_capsule_mesh_local` are discarded — they're translation-only + # under glTF LBS and don't rotate the cross-section, causing visible + # thinning when the limb axis changes between rest and animated pose. + j_arr = np.zeros((Nv, 4), dtype=np.uint16) + j_arr[:, 0] = limb_idx + limb_joint_base_idx + w_arr = np.zeros((Nv, 4), dtype=np.float32) + w_arr[:, 0] = 1.0 + c_arr = np.tile(pair_colors[limb_idx], (Nv, 1)).astype(np.float32) + + out_v_chunks.append(v_world) + out_f_chunks.append(f_local + v_total) + out_j_chunks.append(j_arr) + out_w_chunks.append(w_arr) + out_c_chunks.append(c_arr) + v_total += Nv + + if not out_v_chunks: + return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32), + np.zeros((0, 3), dtype=np.uint32), np.zeros((0, 4), dtype=np.uint16), + np.zeros((0, 4), dtype=np.float32), np.zeros((0, 3), dtype=np.float32)) + + verts = np.concatenate(out_v_chunks, axis=0) + faces = np.concatenate(out_f_chunks, axis=0) + joints = np.concatenate(out_j_chunks, axis=0) + weights = np.concatenate(out_w_chunks, axis=0) + colors = np.concatenate(out_c_chunks, axis=0) + return _finalize_skinned_mesh(verts, faces, joints, weights, colors, smooth_shade) + + +def build_glb_openpose( + pose_data: Dict[str, Any], + *, + fps: float = 24.0, + camera_translation: str = "off", + track_index: int = -1, + marker_radius_m: float = 0.025, + stick_radius_m: float = 0.008, + include_hands: bool = False, + hand_marker_radius_m: float = 0.0, + hand_stick_radius_m: float = 0.0, + hand_color_style: str = "dwpose", + face_source: str = "off", + face_marker_radius_m: float = 0.0, + palette: str = "openpose", + shape: str = "ellipsoid", + smooth_shade: bool = False, + material_roughness: float = 0.85, + material_double_sided: bool = False, +) -> bytes: + """Build a GLB containing an OpenPose-style 3D skeleton — sphere markers + per keypoint plus rainbow-colored sticks between standard limb pairs. 
+ Body keypoints are sourced from pose_data's `pred_keypoints_3d` (no rig + forward needed). Optional hand keypoints (also from `pred_keypoints_3d`) + and face landmarks (sampled from `pred_vertices` at fixed head-mesh + vertex IDs) extend the same per-track armature. + + Args: + include_hands: append the standard 21+21 OpenPose hand keypoints to + each track's armature (right hand at MHR70 indices 21..41, + left at 42..62). + hand_marker_radius_m: per-hand sphere radius. 0 = auto = 0.4 × + `marker_radius_m` (hand keypoints are anatomically smaller than + body joints; matches DWPose's smaller hand dots). + hand_stick_radius_m: per-hand limb half-width. 0 = auto = 0.5 × + `stick_radius_m`. + hand_color_style: 'dwpose' (default) = solid-blue hand dots, + rainbow per-finger sticks (controlnet_aux/dwpose convention); + 'openpose' = rainbow per-finger dots AND sticks (matches + poseParameters.cpp::HAND_COLORS_RENDER). + face_source: 'off' (default) | 'rig' — when 'rig', adds ~30 face + contour landmarks sampled from `pred_vertices` at vertex IDs + picked from `pose_data["canonical_colors"]["positions"]`. + face_marker_radius_m: per-face landmark sphere radius. 0 = auto = + 0.3 × `marker_radius_m` — face landmarks are densely packed + around the eyes/mouth/jaw and need to be much smaller than + body keypoints to keep the layout legible. Face landmarks are + rendered as standalone dots (no contour lines), matching + DWPose's face_pose draw style. + palette: body color scheme. 'openpose' = standard rainbow gradient + per keypoint (canonical OpenPose convention); 'scail' = + SCAIL-Pose style — warm hues right side, cool hues left side, + grey neck-to-nose centerline, distinct per-limb colors. 
+ """ + if str(palette) == "scail": + body_sphere_colors = SCAIL_KEYPOINT_COLORS_18 + body_stick_colors = SCAIL_LIMB_COLORS_17 + elif str(palette) == "openpose": + # Existing OpenPose behavior: same rainbow array used for both + # spheres (per-keypoint) and sticks (per-limb, indexed 0..16 of + # the 18-element rainbow — yields a legible per-limb gradient). + body_sphere_colors = OPENPOSE_RAINBOW_18 + body_stick_colors = OPENPOSE_RAINBOW_18 + else: + raise ValueError( + f"build_glb_openpose: unknown palette={palette!r} " + "(expected 'openpose' or 'scail')" + ) + + if float(hand_marker_radius_m) <= 0.0: + hand_marker_radius_m = float(marker_radius_m) * 0.4 + if float(hand_stick_radius_m) <= 0.0: + hand_stick_radius_m = float(stick_radius_m) * 0.5 + if float(face_marker_radius_m) <= 0.0: + face_marker_radius_m = float(marker_radius_m) * 0.3 + if hand_color_style == "dwpose": + hand_sphere_colors = DWPOSE_HAND_COLORS_21 + elif hand_color_style == "openpose": + hand_sphere_colors = OPENPOSE_HAND_COLORS_21 + else: + raise ValueError( + f"build_glb_openpose: unknown hand_color_style=" + f"{hand_color_style!r} (expected 'dwpose' or 'openpose')" + ) + tracks = collect_tracks(pose_data, track_index) + if not tracks: + raise ValueError("build_glb_openpose: no valid tracks in pose_data") + + face_vert_ids: Optional[np.ndarray] = None + if face_source == "rig": + canonical_colors = pose_data.get("canonical_colors") or {} + positions = canonical_colors.get("positions") + if positions is None: + raise ValueError( + "build_glb_openpose: face_source='rig' needs " + "pose_data['canonical_colors']['positions'] (computed at " + "model load and attached by Predict). Ensure the SAM3DBody " + "Loader+Predict ran upstream of this node." 
+ ) + face_vert_ids = select_face_landmark_vert_ids( + np.asarray(positions), + face_mask=canonical_colors.get("face_mask"), + ) + elif face_source != "off": + raise ValueError( + f"build_glb_openpose: unknown face_source={face_source!r} " + "(expected 'off' or 'rig')" + ) + + K_body = 18 + K_hands = 42 if include_hands else 0 + K_face = int(face_vert_ids.shape[0]) if face_vert_ids is not None else 0 + K = K_body + K_hands + K_face + + # Limb counts: one joint per stick pair. Limb joints carry translation + + # rotation so each capsule rotates rigidly with its limb (no LBS thinning). + K_body_limbs = len(OPENPOSE_18_PAIRS) + K_hand_limbs = len(OPENPOSE_HAND_PAIRS) if include_hands else 0 + K_limbs = K_body_limbs + 2 * K_hand_limbs # face has no sticks + + # Joint name list mirrors the keypoint stacking order: body → hands → face. + joint_names: List[str] = [f"openpose_{n}" for n in OPENPOSE_18_NAMES] + if include_hands: + joint_names.extend([f"openpose_R_{n}" for n in OPENPOSE_HAND21_NAMES]) + joint_names.extend([f"openpose_L_{n}" for n in OPENPOSE_HAND21_NAMES]) + if K_face > 0: + joint_names.extend([f"openpose_face_{name}" + for name, _ in FACE_LANDMARK_TARGETS]) + + # Limb joint names, stacked body → R-hand → L-hand to match the limb + # joint ordering in skin.joints (after the K keypoint joints). 
+ limb_names: List[str] = [ + f"openpose_limb_{OPENPOSE_18_NAMES[a]}_{OPENPOSE_18_NAMES[b]}" + for (a, b) in OPENPOSE_18_PAIRS + ] + if include_hands: + for side in ("R", "L"): + for (a, b) in OPENPOSE_HAND_PAIRS: + limb_names.append( + f"openpose_{side}hand_limb_" + f"{OPENPOSE_HAND21_NAMES[a]}_{OPENPOSE_HAND21_NAMES[b]}" + ) + + w = GLBWriter() + nodes: List[dict] = [] + meshes: List[dict] = [] + skins: List[dict] = [] + materials: List[dict] = [] + animations: List[dict] = [] + scene_root_indices: List[int] = [] + + for track_i, (person_k, frame_indices) in enumerate(tracks): + body_seq = _extract_openpose_keypoints(pose_data, frame_indices, person_k) + n_frames = body_seq.shape[0] + if n_frames == 0: + continue + + seq_chunks: List[np.ndarray] = [body_seq] + if include_hands: + seq_chunks.append(_extract_openpose_hand_keypoints( + pose_data, frame_indices, person_k)) + if face_vert_ids is not None: + seq_chunks.append(_extract_face_landmarks_from_verts( + pose_data, frame_indices, person_k, face_vert_ids)) + kp_seq = np.concatenate(seq_chunks, axis=1) # (N, K, 3) + + # Static-bind = rig's REST pose when available (override path); else + # fall back to frame 0 of the motion. The rest-pose bind makes the + # GLB's static POSITION attribute sit at rig origin, so viewers + # auto-fit/center on rig origin and the animation visibly snaps from + # rest to scene-frame-0 — matching skeletal mode's behavior. Without + # this, openpose's static geometry is at scene-frame-0 and viewers + # mis-center on the scene location, masking the motion entirely. 
+ bind_kp_m_rest = _openpose_bind_at_rig_rest( + pose_data, include_hands=include_hands, face_vert_ids=face_vert_ids, + ) + bind_kp_m = (bind_kp_m_rest if bind_kp_m_rest is not None + else kp_seq[0].astype(np.float32)) + + person_root: Dict[str, Any] = {"name": f"track{track_i:02d}", "children": []} + nodes.append(person_root) + person_root_idx = len(nodes) - 1 + scene_root_indices.append(person_root_idx) + + # K keypoint joint nodes (spheres bind here, rigid translation only). + joint_node_indices: List[int] = [] + for j in range(K): + nodes.append({ + "name": joint_names[j], + "translation": bind_kp_m[j].tolist(), + "rotation": [0.0, 0.0, 0.0, 1.0], + "scale": [1.0, 1.0, 1.0], + }) + joint_node_indices.append(len(nodes) - 1) + person_root["children"].extend(joint_node_indices) + + # Per-limb REST TRS (midpoint + axis) and per-frame TRS (midpoint + + # quaternion that aligns rest-axis → frame-t-axis). Sticks bind + # rigidly to these joints so each capsule rotates with its limb. + limb_rest_mids_list: List[np.ndarray] = [] + limb_rest_axes_list: List[np.ndarray] = [] + limb_anim_mids_list: List[np.ndarray] = [] + limb_anim_quats_list: List[np.ndarray] = [] + rmid_b, raxis_b = _openpose_limb_rest_trs(bind_kp_m[:K_body], OPENPOSE_18_PAIRS) + amid_b, aquat_b = _openpose_limb_anim_trs(kp_seq[:, :K_body], OPENPOSE_18_PAIRS, raxis_b) + limb_rest_mids_list.append(rmid_b) + limb_rest_axes_list.append(raxis_b) + limb_anim_mids_list.append(amid_b) + limb_anim_quats_list.append(aquat_b) + if include_hands: + for h_off in (K_body, K_body + 21): + rmid_h, raxis_h = _openpose_limb_rest_trs( + bind_kp_m[h_off:h_off + 21], OPENPOSE_HAND_PAIRS, + ) + amid_h, aquat_h = _openpose_limb_anim_trs( + kp_seq[:, h_off:h_off + 21], OPENPOSE_HAND_PAIRS, raxis_h, + ) + limb_rest_mids_list.append(rmid_h) + limb_rest_axes_list.append(raxis_h) + limb_anim_mids_list.append(amid_h) + limb_anim_quats_list.append(aquat_h) + limb_rest_mids = np.concatenate(limb_rest_mids_list, axis=0) # (K_limbs, 3) 
+ limb_anim_mids = np.concatenate(limb_anim_mids_list, axis=1) # (N, K_limbs, 3) + limb_anim_quats = np.concatenate(limb_anim_quats_list, axis=1) # (N, K_limbs, 4) + # Hemisphere-align consecutive quats per limb so LINEAR interpolation + # takes the short path (otherwise large per-frame rotations can flip + # signs and produce visible "twist back" artifacts mid-playback). + limb_anim_quats = quat_sign_fix_per_joint(limb_anim_quats).astype(np.float32) + + limb_joint_indices: List[int] = [] + for k in range(K_limbs): + nodes.append({ + "name": limb_names[k], + "translation": limb_rest_mids[k].tolist(), + "rotation": [0.0, 0.0, 0.0, 1.0], + "scale": [1.0, 1.0, 1.0], + }) + limb_joint_indices.append(len(nodes) - 1) + person_root["children"].extend(limb_joint_indices) + + # Combined skin: keypoint joints (IBM = T(-bind_kp_m)) then limb joints + # (IBM = T(-limb_rest_mid)). Both yield identity skin_matrix at rest. + all_joint_indices = joint_node_indices + limb_joint_indices + ibm = np.tile(np.eye(4, dtype=np.float32), (K + K_limbs, 1, 1)) + ibm[:K, :3, 3] = -bind_kp_m + if K_limbs > 0: + ibm[K:K + K_limbs, :3, 3] = -limb_rest_mids + ibm_acc = w.add_mat4_f32(ibm.transpose(0, 2, 1).astype(np.float32)) + skins.append({ + "joints": all_joint_indices, + "inverseBindMatrices": ibm_acc, + "skeleton": person_root_idx, + }) + skin_idx = len(skins) - 1 + + # Per-group geometry. Spheres bind to keypoint joints (base_joint_idx + # ∈ [0, K)); sticks bind to limb joints (limb_joint_base_idx ∈ + # [K, K + K_limbs)). Groups stack body → right hand → left hand → + # face for keypoint joints, and body → R-hand → L-hand for limbs. 
+ group_meshes: List[Tuple[np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray, np.ndarray]] = [] + sp = _build_openpose_spheres( + bind_kp_m[:K_body], float(marker_radius_m), + body_sphere_colors, base_joint_idx=0, + smooth_shade=smooth_shade, + ) + st = _build_openpose_sticks( + bind_kp_m[:K_body], OPENPOSE_18_PAIRS, float(stick_radius_m), + body_stick_colors, limb_joint_base_idx=K, # body limbs start at K + shape=shape, + smooth_shade=smooth_shade, + ) + group_meshes.append(sp) + group_meshes.append(st) + + if include_hands: + # Hand stick colors stay rainbow per-finger regardless of + # `hand_color_style` — only the sphere dots switch to solid + # blue under 'dwpose'. Matches controlnet_aux/dwpose/util.py. + hand_pair_colors = _pair_colors_from_kp( + OPENPOSE_HAND_PAIRS, OPENPOSE_HAND_COLORS_21, endpoint=1, + ) + for hand_i, h_off in enumerate((K_body, K_body + 21)): # right, then left + h_bind = bind_kp_m[h_off:h_off + 21] + group_meshes.append(_build_openpose_spheres( + h_bind, float(hand_marker_radius_m), + hand_sphere_colors, base_joint_idx=h_off, + smooth_shade=smooth_shade, + )) + group_meshes.append(_build_openpose_sticks( + h_bind, OPENPOSE_HAND_PAIRS, float(hand_stick_radius_m), + hand_pair_colors, + limb_joint_base_idx=K + K_body_limbs + hand_i * K_hand_limbs, + shape=shape, + smooth_shade=smooth_shade, + )) + + if K_face > 0: + f_off = K_body + K_hands + f_bind = bind_kp_m[f_off:f_off + K_face] + # DWPose face = dots only, no contour lines + # (controlnet_aux/dwpose/util.py::draw_facepose draws white + # circles per landmark and never connects them). 
+ group_meshes.append(_build_openpose_spheres( + f_bind, float(face_marker_radius_m), + FACE_LANDMARK_COLORS, base_joint_idx=f_off, + smooth_shade=smooth_shade, + )) + + primitives: List[dict] = [] + for (v_arr, n_arr, f_arr, j_arr, w_arr, c_arr) in group_meshes: + if v_arr.shape[0] == 0: + continue + attrs = { + "POSITION": w.add_vec3_f32(v_arr), + "NORMAL": w.add_vec3_f32(n_arr), + "JOINTS_0": w.add_joints_u16(j_arr), + "WEIGHTS_0": w.add_weights_f32(w_arr), + "COLOR_0": w.add_vec3_f32(c_arr), + } + materials.append(make_lit_material( + roughness=material_roughness, + double_sided=material_double_sided, + )) + primitives.append({ + "attributes": attrs, + "indices": w.add_indices_u32(f_arr.reshape(-1)), + "mode": 4, + "material": len(materials) - 1, + }) + if not primitives: + continue + meshes.append({"primitives": primitives}) + nodes.append({ + "name": f"track{track_i:02d}_openpose", + "mesh": len(meshes) - 1, + "skin": skin_idx, + }) + person_root["children"].append(len(nodes) - 1) + + times = np.asarray(frame_indices, dtype=np.float32) / float(fps) + time_acc = w.add_scalar_f32(times) + samplers: List[dict] = [] + channels: List[dict] = [] + for j in range(K): + t_j = kp_seq[:, j, :].astype(np.float32) + if (np.ptp(t_j, axis=0) < 1e-6).all(): + nodes[joint_node_indices[j]]["translation"] = t_j[0].tolist() + continue + acc = w.add_vec3_f32_anim(t_j) + samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"}) + channels.append({ + "sampler": len(samplers) - 1, + "target": {"node": joint_node_indices[j], "path": "translation"}, + }) + + # Per-limb-joint translation + rotation channels. Stationary limbs + # have their constant TRS baked into the node so they don't bloat the + # animation buffer. 
+ for k in range(K_limbs): + t_k = limb_anim_mids[:, k, :].astype(np.float32) + if (np.ptp(t_k, axis=0) < 1e-6).all(): + nodes[limb_joint_indices[k]]["translation"] = t_k[0].tolist() + else: + acc = w.add_vec3_f32_anim(t_k) + samplers.append({"input": time_acc, "output": acc, + "interpolation": "LINEAR"}) + channels.append({ + "sampler": len(samplers) - 1, + "target": {"node": limb_joint_indices[k], "path": "translation"}, + }) + q_k = limb_anim_quats[:, k, :].astype(np.float32) + # ptp on the absolute value handles the +q == -q ambiguity, but + # `quat_sign_fix_per_joint` already aligned signs so a plain ptp + # is fine here. + if (np.ptp(q_k, axis=0) < 1e-6).all(): + nodes[limb_joint_indices[k]]["rotation"] = q_k[0].tolist() + else: + acc = w.add_vec4_f32(q_k) + samplers.append({"input": time_acc, "output": acc, + "interpolation": "LINEAR"}) + channels.append({ + "sampler": len(samplers) - 1, + "target": {"node": limb_joint_indices[k], "path": "rotation"}, + }) + + if camera_translation != "off": + frames = pose_data["frames"] + cam_t = np.stack([ + unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32)) + for t in frame_indices + ], axis=0) + if camera_translation == "centered" and cam_t.shape[0] > 0: + cam_t = cam_t - cam_t[0:1] + if (np.ptp(cam_t, axis=0) < 1e-6).all(): + person_root["translation"] = cam_t[0].tolist() + else: + acc = w.add_vec3_f32_anim(cam_t) + samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"}) + channels.append({ + "sampler": len(samplers) - 1, + "target": {"node": person_root_idx, "path": "translation"}, + }) + + animations.append({ + "name": f"track{track_i:02d}", + "samplers": samplers, "channels": channels, + }) + + if not scene_root_indices: + raise ValueError("build_glb_openpose: produced no tracks") + + gltf: Dict[str, Any] = { + "asset": {"version": "2.0", "generator": "ComfyUI-SAM3DBody"}, + "scene": 0, + "scenes": [{"nodes": scene_root_indices}], + "nodes": nodes, + "meshes": meshes, + 
"skins": skins, + } + if materials: + gltf["materials"] = materials + if animations: + gltf["animations"] = animations + return w.to_bytes(gltf) diff --git a/comfy_extras/sam3d_body/export/glb_shared.py b/comfy_extras/sam3d_body/export/glb_shared.py new file mode 100644 index 000000000..66a732a18 --- /dev/null +++ b/comfy_extras/sam3d_body/export/glb_shared.py @@ -0,0 +1,1138 @@ +"""GLB export for SAM 3D Body pose_data. + +Mode: skeletal — rebuilds the MHR 127-bone rig. Per-frame local TRS comes from +re-running param_transform on saved mhr_model_params; rest verts from a +zero-pose forward with the person's shape_params; sparse triplet skinning is +compacted to glTF's max-4-influences form; facial expression is re-exposed as +72 morph targets driven by expr_params. + +pred_vertices/pred_cam_t are camera-y-down — un-flipped here so the GLB lives +in glTF-spec Y-up. Pose correctives are dropped (glTF skinning can't represent +them); deformation at extreme joint angles will differ from the SAM3DBody +renderer by the corrective amount. +""" + +from __future__ import annotations + +import json +import struct +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch + +from comfy_extras.sam3d_body.rasterizer import rainbow_colors_from_canonical + +# fp32-rounded ln(2). Used as `exp(x * _LN2)` to compute 2**x bit-identically +# to the rig's own `torch.exp(jp[..., 6:7] * _LN2)` +_LN2 = 0.6931471824645996 + + +# Quaternion / rotation helpers (xyzw convention, matching MHR rig) + +def _euler_xyz_to_quat_np(angles: np.ndarray) -> np.ndarray: + """(roll, pitch, yaw) -> (x, y, z, w). 
Mirrors mhr_rig._euler_xyz_to_quat.""" + roll, pitch, yaw = angles[..., 0], angles[..., 1], angles[..., 2] + cy, sy = np.cos(yaw * 0.5), np.sin(yaw * 0.5) + cp, sp = np.cos(pitch * 0.5), np.sin(pitch * 0.5) + cr, sr = np.cos(roll * 0.5), np.sin(roll * 0.5) + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + w = cr * cp * cy + sr * sp * sy + return np.stack([x, y, z, w], axis=-1) + + +def _quat_multiply_np(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: + """xyzw product. Mirrors mhr_rig._quat_multiply.""" + x1, y1, z1, w1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3] + x2, y2, z2, w2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3] + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + return np.stack([x, y, z, w], axis=-1) + + +def _quat_rotate_np(q: np.ndarray, v: np.ndarray) -> np.ndarray: + """Rotate v by unit xyzw q. Mirrors mhr_rig._quat_rotate.""" + axis = q[..., :3] + r = q[..., 3:4] + av = np.cross(axis, v, axis=-1) + aav = np.cross(axis, av, axis=-1) + return v + 2.0 * (av * r + aav) + + +def _skel_state_inverse_np(skel_state: np.ndarray) -> np.ndarray: + """Inverse of (t, q, s). Normalizes q first so non-unit input is OK.""" + t = skel_state[..., :3] + q = skel_state[..., 3:7] + s = skel_state[..., 7:8] + q = q / np.maximum(np.linalg.norm(q, axis=-1, keepdims=True), 1e-12) + s_safe = np.where(np.abs(s) > 1e-12, s, 1.0) + s_inv = 1.0 / s_safe + q_inv = np.concatenate([-q[..., :3], q[..., 3:4]], axis=-1) + t_inv = -s_inv * _quat_rotate_np(q_inv, t) + return np.concatenate([t_inv, q_inv, s_inv], axis=-1) + + +def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray: + """s1 ∘ s2. 
Mirrors mhr_rig._skel_multiply.""" + t1 = s1[..., :3] + q1 = s1[..., 3:7] + sc1 = s1[..., 7:8] + + t2 = s2[..., :3] + q2 = s2[..., 3:7] + sc2 = s2[..., 7:8] + # Defensive normalization to match the rig's `F.normalize` calls. + q1 = q1 / np.maximum(np.linalg.norm(q1, axis=-1, keepdims=True), 1e-12) + q2 = q2 / np.maximum(np.linalg.norm(q2, axis=-1, keepdims=True), 1e-12) + t_res = t1 + sc1 * _quat_rotate_np(q1, t2) + q_res = _quat_multiply_np(q1, q2) + s_res = sc1 * sc2 + return np.concatenate([t_res, q_res, s_res], axis=-1) + + +def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray: + """Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns + per joint first, convolves per-component, renormalizes. Suppresses multi- + frame bone spikes at extreme poses without needing the upstream Smooth node.""" + if window <= 1 or q_seq.shape[0] < 2: + return q_seq + aligned = quat_sign_fix_per_joint(q_seq).astype(np.float64) + n = q_seq.shape[0] + half = window // 2 + sigma = max(0.5, window / 4.0) + x = np.arange(-half, half + 1, dtype=np.float64) + kernel = np.exp(-x * x / (2.0 * sigma * sigma)) + kernel = kernel / kernel.sum() + # Edge-replicate padding so endpoints don't get pulled toward zero. + pad = half + padded = np.concatenate([ + np.broadcast_to(aligned[:1], (pad,) + aligned.shape[1:]), + aligned, + np.broadcast_to(aligned[-1:], (pad,) + aligned.shape[1:]), + ], axis=0) + out = np.zeros_like(aligned) + for k, w in enumerate(kernel): + out += w * padded[k:k + n] + norms = np.linalg.norm(out, axis=-1, keepdims=True) + out = out / np.maximum(norms, 1e-12) + return out.astype(np.float32) + + +def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray: + """Walk (N, NJ, 4) along time, flip sign whenever consecutive frames sit + on opposite hemispheres. Eliminates long-path slerp glitches (mid-anim + cartwheel flip). 
fp64 to avoid drift; normalizes input defensively.""" + out = np.array(q_seq, dtype=np.float64, copy=True) + norms = np.linalg.norm(out, axis=-1, keepdims=True) + out = out / np.maximum(norms, 1e-12) + for t in range(1, out.shape[0]): + dots = (out[t - 1] * out[t]).sum(axis=-1) + sign = np.where(dots < 0.0, -1.0, 1.0)[:, None] + out[t] = out[t] * sign + return out.astype(np.float32) + + +def bone_locals_from_globals(rig_global: np.ndarray, parents: np.ndarray) -> np.ndarray: + """Globals (N, NJ, 8) + parents -> per-bone local TRS (N, NJ, 8) such that + FK over (parents, bone_local) reproduces rig_global. local = + inverse(parent_global) ∘ child_global makes this robust to hierarchy- + convention mismatches: glTF FK gives back exactly rig_global even if + `parents` doesn't match the rig's pmi-walk.""" + N, NJ, _ = rig_global.shape + bone_local = np.zeros_like(rig_global) + for j in range(NJ): + p = int(parents[j]) + if 0 <= p < NJ and p != j: + parent_g = rig_global[:, p] + parent_g_inv = _skel_state_inverse_np(parent_g) + bone_local[:, j] = _skel_state_compose_np(parent_g_inv, rig_global[:, j]) + else: + bone_local[:, j] = rig_global[:, j] + return bone_local + + +def _quat_to_mat3_np(q: np.ndarray) -> np.ndarray: + x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] + n = x * x + y * y + z * z + w * w + s = np.where(n > 0, 2.0 / n, 0.0) + R = np.empty(q.shape[:-1] + (3, 3), dtype=q.dtype) + R[..., 0, 0] = 1 - s * (y * y + z * z) + R[..., 0, 1] = s * (x * y - z * w) + R[..., 0, 2] = s * (x * z + y * w) + R[..., 1, 0] = s * (x * y + z * w) + R[..., 1, 1] = 1 - s * (x * x + z * z) + R[..., 1, 2] = s * (y * z - x * w) + R[..., 2, 0] = s * (x * z - y * w) + R[..., 2, 1] = s * (y * z + x * w) + R[..., 2, 2] = 1 - s * (x * x + y * y) + return R + + +def collect_tracks(pose_data: Dict[str, Any], track_index: int) -> List[Tuple[int, List[int]]]: + """List of (person_index, frame_indices). track_index == -1 means every + present track; empty tracks are dropped. 
Same person index across frames + is assumed same subject (Smooth/Predict enforce this on tracked bboxes).""" + frames = pose_data["frames"] + max_p = max((len(f) for f in frames), default=0) + if max_p == 0: + return [] + if track_index >= 0: + if track_index >= max_p: + return [] + wanted = [track_index] + else: + wanted = list(range(max_p)) + + tracks: List[Tuple[int, List[int]]] = [] + for k in wanted: + valid = [t for t, fr in enumerate(frames) if k < len(fr)] + if valid: + tracks.append((k, valid)) + return tracks + + +# glTF binary builder + + +_FLOAT = 5126 +_USHORT = 5123 +_UINT = 5125 +_BYTE_ARRAY = 34962 +_BYTE_ELEMENT = 34963 + + +def _pad4(buf: bytes, fill: bytes = b"\x00") -> bytes: + n = (4 - (len(buf) % 4)) % 4 + return buf + fill * n + + +class GLBWriter: + """Builds a single .glb from incremental accessor/bufferView additions.""" + + def __init__(self) -> None: + self._buffer = bytearray() + self.bufferViews: List[dict] = [] + self.accessors: List[dict] = [] + + def _add_view(self, data: bytes, *, target: Optional[int] = None) -> int: + offset = len(self._buffer) + self._buffer += data + # 4-byte align so subsequent views start on a boundary. 
+ pad = (4 - (offset + len(data)) % 4) % 4 + if pad: + self._buffer += b"\x00" * pad + view = {"buffer": 0, "byteOffset": offset, "byteLength": len(data)} + if target is not None: + view["target"] = target + self.bufferViews.append(view) + return len(self.bufferViews) - 1 + + def add_vec3_f32(self, arr: np.ndarray, *, target: int = _BYTE_ARRAY) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes(), target=target) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "VEC3", + "min": a.min(axis=0).tolist(), "max": a.max(axis=0).tolist(), + }) + return len(self.accessors) - 1 + + def add_vec3_f32_no_minmax(self, arr: np.ndarray) -> int: + """Morph-target POSITIONs: spec lets us skip min/max, avoiding a + per-frame delta bbox.""" + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "VEC3", + }) + return len(self.accessors) - 1 + + def add_indices_u32(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.uint32).reshape(-1) + view_idx = self._add_view(a.tobytes(), target=_BYTE_ELEMENT) + self.accessors.append({ + "bufferView": view_idx, "componentType": _UINT, + "count": int(a.size), "type": "SCALAR", + }) + return len(self.accessors) - 1 + + def add_scalar_f32(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1) + view_idx = self._add_view(a.tobytes()) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": int(a.size), "type": "SCALAR", + "min": [float(a.min())] if a.size else [0.0], + "max": [float(a.max())] if a.size else [0.0], + }) + return len(self.accessors) - 1 + + def add_scalar_f32_flat(self, arr: np.ndarray, count: int) -> int: + """Animation-output scalars: `count` is keyframes, not floats. 
Morph- + target weight tracks store N_morph weights per keyframe as flat float32 + with count=N_keyframes.""" + a = np.ascontiguousarray(arr, dtype=np.float32).reshape(-1) + view_idx = self._add_view(a.tobytes()) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": int(count), "type": "SCALAR", + }) + return len(self.accessors) - 1 + + def add_vec3_f32_anim(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes()) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "VEC3", + }) + return len(self.accessors) - 1 + + def add_vec4_f32(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes()) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "VEC4", + }) + return len(self.accessors) - 1 + + def add_mat4_f32(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes()) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "MAT4", + }) + return len(self.accessors) - 1 + + def add_joints_u16(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.uint16) + view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY) + self.accessors.append({ + "bufferView": view_idx, "componentType": _USHORT, + "count": a.shape[0], "type": "VEC4", + }) + return len(self.accessors) - 1 + + def add_weights_f32(self, arr: np.ndarray) -> int: + a = np.ascontiguousarray(arr, dtype=np.float32) + view_idx = self._add_view(a.tobytes(), target=_BYTE_ARRAY) + self.accessors.append({ + "bufferView": view_idx, "componentType": _FLOAT, + "count": a.shape[0], "type": "VEC4", + }) + return len(self.accessors) - 1 + + def to_bytes(self, gltf: dict) -> bytes: + gltf["buffers"] = [{"byteLength": 
len(self._buffer)}] + gltf["bufferViews"] = self.bufferViews + gltf["accessors"] = self.accessors + + json_bytes = json.dumps(gltf, separators=(",", ":")).encode("utf-8") + json_padded = _pad4(json_bytes, fill=b" ") + bin_padded = _pad4(bytes(self._buffer)) + + total = 12 + 8 + len(json_padded) + 8 + len(bin_padded) + header = struct.pack("<4sII", b"glTF", 2, total) + json_chunk = struct.pack(" np.ndarray: + out = np.array(arr, dtype=np.float32, copy=True) + out[..., 1] *= -1.0 + out[..., 2] *= -1.0 + return out + + +_BAKEABLE_SHADERS = { + "default", "rainbow", + "rainbow_face_normal", "rainbow_face_semantic", +} + + +def bake_vertex_colors( + canonical_colors: Optional[Dict[str, np.ndarray]], + shader: str, + rainbow_tilt_x_deg: float, + rainbow_tilt_z_deg: float, + pastel_mix: float, +) -> Optional[np.ndarray]: + """Per-vertex RGB matching the renderer's shader preset, on the canonical + mesh. Returns (N_v, 3) float32 in [0, 1], or None for `default` (let the + viewer's default material handle shading).""" + if shader == "default" or canonical_colors is None: + return None + + positions = np.asarray(canonical_colors["positions"], dtype=np.float32) + + vcolor = rainbow_colors_from_canonical( + positions, tilt_x_deg=rainbow_tilt_x_deg, tilt_z_deg=rainbow_tilt_z_deg, + ).copy() + if shader in ("rainbow_face_normal", "rainbow_face_semantic"): + face_mask = canonical_colors.get("face_mask") + if face_mask is not None and np.asarray(face_mask).any(): + if shader == "rainbow_face_normal": + norm = np.asarray(canonical_colors["norm"], dtype=np.float32) + vcolor[face_mask] = norm[face_mask] + else: # rainbow_face_semantic + sem = np.asarray(canonical_colors["face_region_rgb"], dtype=np.float32) + assigned = sem.sum(axis=1) > 0 + vcolor[assigned] = sem[assigned] + + # SCAIL-style per-person pastel mix toward white (track 0 = full color). 
+ pm = max(0.0, min(1.0, float(pastel_mix))) + if pm > 0: + vcolor = vcolor * (1.0 - pm) + pm + return np.clip(vcolor, 0.0, 1.0).astype(np.float32) + + +def compute_pastel_mix(track_i: int, falloff: float) -> float: + """SCAIL-style desaturation: track 0 = 0.0, track k = 1 - falloff^k.""" + f = max(0.1, min(1.0, float(falloff))) + return 0.0 if track_i == 0 else (1.0 - f ** track_i) + + +def compute_normals(verts: np.ndarray, faces: np.ndarray) -> np.ndarray: + v0 = verts[faces[:, 0]] + v1 = verts[faces[:, 1]] + v2 = verts[faces[:, 2]] + fn = np.cross(v1 - v0, v2 - v0).astype(np.float32) + vn = np.zeros_like(verts, dtype=np.float32) + np.add.at(vn, faces[:, 0], fn) + np.add.at(vn, faces[:, 1], fn) + np.add.at(vn, faces[:, 2], fn) + ln = np.linalg.norm(vn, axis=1, keepdims=True) + ln[ln < 1e-8] = 1.0 + return (vn / ln).astype(np.float32) + + +def _parents_from_pmi(rig: Any) -> np.ndarray: + """Parent index per joint from skel_pmi. pmi is (2, 266): row 0 = child, + row 1 = parent, split into BFS levels by skel_pmi_buffer_sizes. Roots = -1.""" + NJ = int(rig.NUM_JOINTS) + pmi = rig.skel_pmi.cpu().numpy() + sizes = rig.skel_pmi_buffer_sizes.cpu().numpy().tolist() + parents = np.full(NJ, -1, dtype=np.int32) + offset = 0 + for sz in sizes: + if sz > 0: + src = pmi[0, offset:offset + sz].astype(np.int64) + tgt = pmi[1, offset:offset + sz].astype(np.int64) + parents[src] = tgt + offset += sz + return parents + + +def _get_skeleton_override(pose_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Return ``_skeleton_override`` dict if present. Non-MHR skeletons supply + this to bypass MHR rig extraction (see ComfyUI-Kimodo). 
Required keys: + parents: (NJ,) int32, -1 = root + bind_global_m: (NJ, 8) f32 — [t.xyz | q.xyzw | scale], meters + lbs_compact_joints: (V, 8) uint16 — pre-compacted skin influences + lbs_compact_weights: (V, 8) f32 + lbs_compact_max_inf: int — actual max influences (≤ 8) + rest_verts_m: (V, 3) f32 + faces: (F, 3) uint32 + Optional: + per_frame_y_down: bool — set False if pred_joint_coords are already + rig-native Y-up (kimodo). Default True (MHR). + openpose18_joint_indices: (18, 2) int32 — body OpenPose-18 → joint + index pair, resolved against per-frame + `pred_joint_coords`. Each row is + (joint_a, joint_b); b == -1 = single + joint, else default midpoint of the two + (lets producers approximate keypoints + with no matching joint, e.g. Nose ≈ + midpoint(LeftEye, RightEye)). Enables + `SAM3DBody_ToGLB(mode="openpose")` on + external rigs. + openpose18_joint_weights: (18,) f32 — optional per-keypoint blend + weight for the (a, b) mapping above. + Position = w*joints[a] + (1-w)*joints[b] + when b ≥ 0 (default w=0.5 → midpoint). + Values outside [0, 1] EXTRAPOLATE past + the line segment — used to approximate + landmarks with no nearby joint pair + (e.g. ears: w=2.0 along the eye→eye + axis puts each ear one eye-distance + outside the corresponding eye). Ignored + for single-joint rows (b = -1). + openpose_hand21_r_joint_indices: (21, 2) int32 — right-hand OpenPose-21 + (wrist + 5 fingers × 4 joints, base→tip) + → joint index pair. Required (alongside + the L counterpart) for openpose mode + with include_hands=True. + openpose_hand21_l_joint_indices: (21, 2) int32 — left-hand counterpart. + openpose_hand21_r_joint_weights: (21,) f32 — optional, same semantics as + `openpose18_joint_weights`. + openpose_hand21_l_joint_weights: (21,) f32 — optional, same as above. 
+ """ + if pose_data is None: + return None + return pose_data.get("_skeleton_override") + + +def extract_rig_static(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> Dict[str, np.ndarray]: + """Static rig buffers as numpy. If `pose_data` carries `_skeleton_override`, + use that instead of MHR-specific `model.head_pose.mhr` buffers.""" + override = _get_skeleton_override(pose_data) + if override is not None: + # External rig: caller pre-compacts skin and supplies bind global directly, + # so we don't need MHR's PCA pose / expression bases. + parents = np.asarray(override["parents"], dtype=np.int32) + rest_v = np.asarray(override["rest_verts_m"], dtype=np.float32) + return { + "parents": parents, + "parents_pmi": parents, + "lbs_compact_joints": np.asarray(override["lbs_compact_joints"], dtype=np.uint16), + "lbs_compact_weights": np.asarray(override["lbs_compact_weights"], dtype=np.float32), + "lbs_compact_max_inf": int(override.get("lbs_compact_max_inf", 4)), + "faces": np.asarray(override["faces"], dtype=np.uint32), + "num_joints": int(parents.shape[0]), + "num_verts": int(rest_v.shape[0]), + "num_expr": 0, + "num_shape": 0, + "_external": True, + } + + inner = model.model if hasattr(model, "model") else model + rig = inner.head_pose.mhr + head = inner.head_pose + + def _np(t: torch.Tensor) -> np.ndarray: + return t.cpu().numpy() + + # `skel_joint_parents` encodes the anatomical hierarchy; pmi-derived order + # is BFS-optimized for parallel FK and may include traversal quirks. 
def compact_skin_to_n(
    skin_indices: np.ndarray, vert_indices: np.ndarray, weights: np.ndarray,
    num_verts: int, max_inf: int = 8,
) -> Tuple[np.ndarray, np.ndarray, int]:
    """Densify sparse (joint, vert, weight) triplets into per-vertex tables.

    Returns ``(joints[V, max_inf] uint16, weights[V, max_inf] float32,
    actual_max)``. For every vertex the `max_inf` highest-weight influences
    are kept (highest first) and the kept weights are rescaled to sum to 1.
    A vertex left with no positive total weight is pinned to joint 0 with
    weight 1. `actual_max` is the largest per-vertex influence count after
    clamping to `max_inf`, so callers can tell whether more than 4 slots are
    actually in use.
    """
    joint_table = np.zeros((num_verts, max_inf), dtype=np.uint16)
    weight_table = np.zeros((num_verts, max_inf), dtype=np.float32)
    influence_counts = np.zeros(num_verts, dtype=np.int32)

    if vert_indices.size > 0:
        # Primary key = vertex id, secondary key = descending weight.
        perm = np.lexsort((-weights, vert_indices))
        v_sorted = vert_indices[perm]
        j_sorted = skin_indices[perm]
        w_sorted = weights[perm]

        # v_sorted is ascending, so a left-sided search of each id in the
        # array itself yields the index where that vertex's run begins.
        run_start = np.searchsorted(v_sorted, v_sorted, side="left")
        slot = np.arange(v_sorted.size, dtype=np.int64) - run_start

        fits = slot < max_inf
        rows = v_sorted[fits]
        cols = slot[fits]
        joint_table[rows, cols] = j_sorted[fits].astype(np.uint16, copy=False)
        weight_table[rows, cols] = w_sorted[fits].astype(np.float32, copy=False)

        per_vert = np.bincount(v_sorted, minlength=num_verts)
        influence_counts = np.minimum(per_vert, max_inf).astype(np.int32)

    row_total = weight_table.sum(axis=1)
    has_weight = row_total > 0
    weight_table[has_weight] = weight_table[has_weight] / row_total[has_weight, None]
    # Unweighted vertices follow joint 0 rigidly (no-op when there are none).
    weight_table[~has_weight, 0] = 1.0

    actual_max = int(influence_counts.max()) if influence_counts.size else 0
    return joint_table, weight_table, actual_max
def rotmat_to_quat_np(R: np.ndarray) -> np.ndarray:
    """Convert rotation matrices (..., 3, 3) to quaternions (..., 4), xyzw.

    Uses the branched Shepperd construction: each matrix is routed to the
    formula whose divisor comes from the positive trace or, failing that, the
    largest diagonal entry. Arithmetic runs in float64; the result is float32.
    No sign continuity across frames is enforced here.
    """
    lead_shape = R.shape[:-2]
    m = R.reshape(-1, 3, 3).astype(np.float64)
    xx, xy, xz = m[:, 0, 0], m[:, 0, 1], m[:, 0, 2]
    yx, yy, yz = m[:, 1, 0], m[:, 1, 1], m[:, 1, 2]
    zx, zy, zz = m[:, 2, 0], m[:, 2, 1], m[:, 2, 2]

    quat = np.zeros((m.shape[0], 4), dtype=np.float64)
    tr = xx + yy + zz

    # Four mutually exclusive selectors that together cover every matrix.
    use_w = tr > 0
    use_x = ~use_w & (xx > yy) & (xx > zz)
    use_y = ~use_w & ~use_x & (yy > zz)
    use_z = ~(use_w | use_x | use_y)

    k = use_w
    s = np.sqrt(tr[k] + 1.0) * 2.0
    quat[k] = np.stack(
        [(zy[k] - yz[k]) / s, (xz[k] - zx[k]) / s, (yx[k] - xy[k]) / s, 0.25 * s],
        axis=1,
    )

    k = use_x
    s = np.sqrt(1.0 + xx[k] - yy[k] - zz[k]) * 2.0
    quat[k] = np.stack(
        [0.25 * s, (xy[k] + yx[k]) / s, (xz[k] + zx[k]) / s, (zy[k] - yz[k]) / s],
        axis=1,
    )

    k = use_y
    s = np.sqrt(1.0 + yy[k] - xx[k] - zz[k]) * 2.0
    quat[k] = np.stack(
        [(xy[k] + yx[k]) / s, 0.25 * s, (yz[k] + zy[k]) / s, (xz[k] - zx[k]) / s],
        axis=1,
    )

    k = use_z
    s = np.sqrt(1.0 + zz[k] - xx[k] - yy[k]) * 2.0
    quat[k] = np.stack(
        [(xz[k] + zx[k]) / s, (yz[k] + zy[k]) / s, 0.25 * s, (yx[k] - xy[k]) / s],
        axis=1,
    )

    return quat.reshape(lead_shape + (4,)).astype(np.float32)
def bind_skel_state(model: Any, pose_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
    """Return the bind-pose global skeleton state, shape (NJ, 8), in cm.

    MHR path: runs `global_skel_state_per_frame` on a single all-zero
    parameter row and returns that one frame. External-rig path (when
    `pose_data` carries a skeleton override): copies the override's
    `bind_global_m` and scales its first three columns (translation) by 100,
    i.e. metres -> centimetres, so both paths hand back the same units.
    """
    external = _get_skeleton_override(pose_data)
    if external is None:
        # One frame of zeros; [0] drops the batch axis.
        rest_params = np.zeros((1, 204), dtype=np.float32)
        return global_skel_state_per_frame(model, rest_params)[0]

    # Copy so the caller's override arrays are never modified.
    state = np.asarray(external["bind_global_m"], dtype=np.float32).copy()
    state[:, :3] = state[:, :3] * 100.0
    return state
def _local_trs_per_frame(
    rig_static: Dict[str, np.ndarray], mhr_model_params: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Per-frame local bone TRS derived from the saved model parameters.

    Returns ``(translation[N, NJ, 3], rotation[N, NJ, 4] xyzw, scale[N, NJ])``,
    all float32. Translations are scaled by 0.01 on the way out (rig cm ->
    metres). Applies the rig's `param_transform` to the parameters, then adds
    the joint offsets / pre-rotations; no skinning is performed.
    """
    param_transform = rig_static["param_transform"]          # (889, 249) = (127*7, 204+45)
    offsets_cm = rig_static["joint_translation_offsets"]     # (127, 3) cm
    prerotations = rig_static["joint_prerotations"]          # (127, 4)
    num_joints = rig_static["num_joints"]

    num_frames = mhr_model_params.shape[0]
    num_given = mhr_model_params.shape[1]
    # Zero-pad the parameter rows up to the transform's input width.
    padded = np.zeros((num_frames, param_transform.shape[1]), dtype=np.float32)
    padded[:, :num_given] = mhr_model_params.astype(np.float32)

    # Seven values per joint: 3 translation, 3 euler, 1 log2-scale.
    joint_params = (padded @ param_transform.T).reshape(num_frames, num_joints, 7)

    translation_cm = joint_params[..., :3] + offsets_cm[None]
    euler_quat = _euler_xyz_to_quat_np(joint_params[..., 3:6])
    rotation = _quat_multiply_np(prerotations[None], euler_quat)
    scale = np.exp(joint_params[..., 6] * _LN2)

    translation_m = (translation_cm * 0.01).astype(np.float32)
    return translation_m, rotation.astype(np.float32), scale.astype(np.float32)
ibp_skel[:, 7] + R = _quat_to_mat3_np(q) + M = np.zeros((NJ, 4, 4), dtype=np.float32) + M[:, :3, :3] = R * s[:, None, None] + M[:, :3, 3] = t + M[:, 3, 3] = 1.0 + return M.transpose(0, 2, 1).astype(np.float32) + + +def uv_sphere_unit(n_lat: int = 9, n_lon: int = 16) -> Tuple[np.ndarray, np.ndarray]: + """Unit UV sphere, poles ±Y. `n_lat` kept ODD by default so one ring + lands at the equator. Default (9, 16) gives 146 verts / 288 faces — n_lon + matches the 16-segment cylinder used by capsule limbs AND the equator + ring aligns 1-to-1 with the cylinder end ring, so silhouettes meet flush.""" + verts: List[List[float]] = [[0.0, -1.0, 0.0]] # south pole at index 0 + for i in range(1, n_lat + 1): + lat = -0.5 * np.pi + np.pi * i / (n_lat + 1) + y = float(np.sin(lat)) + r = float(np.cos(lat)) + for k in range(n_lon): + phi = 2.0 * np.pi * k / n_lon + verts.append([r * float(np.cos(phi)), y, r * float(np.sin(phi))]) + north_idx = len(verts) + verts.append([0.0, 1.0, 0.0]) + + faces: List[List[int]] = [] + # South cap — winding gives -Y outward normal. + south_ring = 1 + for k in range(n_lon): + a = south_ring + k + b = south_ring + (k + 1) % n_lon + faces.append([0, a, b]) + # Inter-ring quads, outward radial. + for i in range(n_lat - 1): + rl = 1 + i * n_lon + rh = 1 + (i + 1) * n_lon + for k in range(n_lon): + a = rl + k + b = rl + (k + 1) % n_lon + c = rh + (k + 1) % n_lon + d = rh + k + faces.append([a, c, b]) + faces.append([a, d, c]) + # North cap — winding gives +Y outward normal. 
def flat_shade_mesh(
    verts: np.ndarray, faces: np.ndarray, joints: np.ndarray, weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Convert a smooth-shaded mesh to flat shading by un-sharing vertices.

    Every triangle receives three private vertices that all carry the
    triangle's face normal. Skinning rows are duplicated alongside, keeping
    whatever influence width the inputs have (4, 8, ...) instead of assuming
    exactly 4 columns.

    Args:
        verts: (V, 3) positions.
        faces: (F, 3) vertex indices.
        joints: (V, K) joint indices per vertex.
        weights: (V, K) skin weights per vertex.

    Returns:
        ``(new_verts[3F, 3] f32, new_normals[3F, 3] f32, new_faces[F, 3] u32,
        new_joints[3F, K] u16, new_weights[3F, K] f32)``. Output vertex
        ``3*f + k`` is corner ``k`` of input face ``f``.
    """
    faces = np.asarray(faces)
    num_faces = faces.shape[0]
    # Row-major flatten gives f0c0, f0c1, f0c2, f1c0, ... - the output order.
    corner_ids = faces[:, :3].reshape(-1)

    new_v = np.asarray(verts)[corner_ids].astype(np.float32)
    new_j = np.asarray(joints)[corner_ids].astype(np.uint16)
    new_w = np.asarray(weights)[corner_ids].astype(np.float32)
    new_f = np.arange(num_faces * 3, dtype=np.uint32).reshape(num_faces, 3)

    v0 = verts[faces[:, 0]]
    v1 = verts[faces[:, 1]]
    v2 = verts[faces[:, 2]]
    fn = np.cross(v1 - v0, v2 - v0)
    fn_len = np.linalg.norm(fn, axis=1, keepdims=True)
    # Degenerate (near zero-area) triangles fall back to a +Y normal.
    fn = np.where(fn_len > 1e-8, fn / np.maximum(fn_len, 1e-12), np.array([[0.0, 1.0, 0.0]]))
    new_n = np.repeat(fn, 3, axis=0).astype(np.float32)

    return new_v, new_n, new_f, new_j, new_w
def rotation_align(from_vec: np.ndarray, to_vec: np.ndarray) -> np.ndarray:
    """Return the 3x3 float32 rotation taking unit `from_vec` onto unit `to_vec`.

    General case: Rodrigues' formula about the axis ``from_vec x to_vec``.
    Parallel vectors give the identity. Anti-parallel vectors give a 180-degree
    turn about an axis perpendicular to `from_vec`; the X axis is used whenever
    `from_vec` is perpendicular to X, which yields ``diag(1, -1, -1)`` for the
    +Y -> -Y case. For other directions a perpendicular axis is derived from
    `from_vec` itself so the result really maps `from_vec` to `-from_vec`.
    """
    cos_t = float(np.dot(from_vec, to_vec))
    cross = np.cross(from_vec, to_vec)
    sin_t = float(np.linalg.norm(cross))
    if sin_t < 1e-8:
        if cos_t > 0:
            return np.eye(3, dtype=np.float32)
        # Anti-aligned: the cross product carries no axis information, so
        # build one by projecting a world axis onto the plane orthogonal to
        # from_vec. Seeding with the axis from_vec is *least* aligned with
        # keeps the projection well-conditioned.
        f = np.asarray(from_vec, dtype=np.float64).reshape(3)
        f_len = float(np.linalg.norm(f))
        if f_len > 0.0:
            f = f / f_len
        seed = np.zeros(3, dtype=np.float64)
        seed[0 if abs(f[0]) <= abs(f[1]) else 1] = 1.0
        axis = seed - float(np.dot(seed, f)) * f
        axis = axis / np.linalg.norm(axis)
        # Rotation by pi about unit axis a is 2*a*a^T - I.
        return (2.0 * np.outer(axis, axis) - np.eye(3)).astype(np.float32)
    axis = cross / sin_t
    K = np.array([
        [0.0, -axis[2], axis[1]],
        [axis[2], 0.0, -axis[0]],
        [-axis[1], axis[0], 0.0],
    ], dtype=np.float32)
    return (np.eye(3, dtype=np.float32) + sin_t * K + (1.0 - cos_t) * (K @ K)).astype(np.float32)
roughness=0.85 + suits dense rainbow body meshes; 0.3 matches SCAIL-Pose's glossy rig look.""" + mat = { + "pbrMetallicRoughness": { + "baseColorFactor": [1.0, 1.0, 1.0, 1.0], + "metallicFactor": 0.0, + "roughnessFactor": float(max(0.0, min(1.0, roughness))), + }, + } + if double_sided: + mat["doubleSided"] = True + return mat + + +# OpenPose 18-keypoint viz (independent of MHR rig — uses pred_keypoints_3d, +# the model's regressed surface keypoints). + + +OPENPOSE_18_NAMES = ( + "Nose", "Neck", "RShoulder", "RElbow", "RWrist", + "LShoulder", "LElbow", "LWrist", "RHip", "RKnee", + "RAnkle", "LHip", "LKnee", "LAnkle", "REye", + "LEye", "REar", "LEar", +) + +# COCO-18 OpenPose -> MHR70. Subset of `MHR70_TO_OPENPOSE` in +# comfy/ldm/sam3d/mhr70.py (no toes/heels). +OPENPOSE18_TO_MHR70 = np.array([ + 0, # 0 Nose + 69, # 1 Neck + 6, # 2 RShoulder + 8, # 3 RElbow + 41, # 4 RWrist + 5, # 5 LShoulder + 7, # 6 LElbow + 62, # 7 LWrist + 10, # 8 RHip + 12, # 9 RKnee + 14, # 10 RAnkle + 9, # 11 LHip + 11, # 12 LKnee + 13, # 13 LAnkle + 2, # 14 REye + 1, # 15 LEye + 4, # 16 REar + 3, # 17 LEar +], dtype=np.int64) + +# OpenPose limb pairs + rainbow palette delegate to the canonical DWPose tables +# carried by `comfy_extras.pose.keypoint_draw.KeypointDraw` (also used by nodes_sdpose). +# `body_limbSeq` is 1-indexed there; we use 0-indexed throughout this module. +from comfy_extras.pose.keypoint_draw import KeypointDraw as _KeypointDraw +_KD = _KeypointDraw() +OPENPOSE_18_PAIRS = tuple((a - 1, b - 1) for a, b in _KD.body_limbSeq) +OPENPOSE_RAINBOW_18 = (np.array(_KD.colors, dtype=np.float32) / 255.0) + + +# SCAIL-Pose limb palette (17 limbs in `OPENPOSE_18_PAIRS` order): warm = +# right side, cool = left, grey centerline, pink/violet face. Matches +# ComfyUI-SCAIL-Pose's `nlf_render.py::ordered_colors_255`. 
SCAIL_LIMB_COLORS_17 = (np.array([
    [255,   0,   0],   #  0 Neck → R.Shoulder (Red)
    [  0, 255, 255],   #  1 Neck → L.Shoulder (Cyan)
    [255,  85,   0],   #  2 R.Shoulder → R.Elbow (Orange)
    [255, 170,   0],   #  3 R.Elbow → R.Wrist (Golden Orange)
    [  0, 170, 255],   #  4 L.Shoulder → L.Elbow (Sky Blue)
    [  0,  85, 255],   #  5 L.Elbow → L.Wrist (Medium Blue)
    [180, 255,   0],   #  6 Neck → R.Hip (Yellow-Green)
    [  0, 255,   0],   #  7 R.Hip → R.Knee (Bright Green)
    [  0, 255,  85],   #  8 R.Knee → R.Ankle (Light Green-Blue)
    [  0,   0, 255],   #  9 Neck → L.Hip (Pure Blue)
    [ 85,   0, 255],   # 10 L.Hip → L.Knee (Purple-Blue)
    [170,   0, 255],   # 11 L.Knee → L.Ankle (Medium Purple)
    [150, 150, 150],   # 12 Neck → Nose (Grey)
    [255,   0, 170],   # 13 Nose → R.Eye (Pink-Magenta)
    [ 50,   0, 255],   # 14 R.Eye → R.Ear (Dark Violet)
    [255,   0, 170],   # 15 Nose → L.Eye (Pink-Magenta)
    [ 50,   0, 255],   # 16 L.Eye → L.Ear (Dark Violet)
], dtype=np.float32) / 255.0)


def _scail_keypoint_colors_18(limb_pairs: Tuple[Tuple[int, int], ...] = None) -> np.ndarray:
    """Derive an (18, 3) float32 keypoint palette from the SCAIL limb palette.

    Walks the limbs in order; a keypoint takes the color of the first limb
    that ends on it (second element of the pair). Keypoints that never appear
    as a limb's second element keep the mid-grey fill. `limb_pairs` defaults
    to `OPENPOSE_18_PAIRS`.
    """
    if limb_pairs is None:
        limb_pairs = OPENPOSE_18_PAIRS
    grey = np.array([0.6, 0.6, 0.6], dtype=np.float32)
    palette = np.tile(grey, (18, 1))
    for limb_idx, pair in enumerate(limb_pairs):
        _proximal, distal = pair
        # A row still equal to the grey fill has not been claimed yet.
        unclaimed = bool(np.all(palette[distal] == 0.6))
        if unclaimed:
            palette[distal] = SCAIL_LIMB_COLORS_17[limb_idx]
    return palette


SCAIL_KEYPOINT_COLORS_18 = _scail_keypoint_colors_18()


# OpenPose hand: 21 kp per hand = wrist + 5 fingers × 4 joints (proximal→distal).
# MHR70 stores fingers as (tip, joint1, joint2, joint3=MCP) so we reverse each
# 4-tuple. See comfy/ldm/sam3d/mhr70.py.
+OPENPOSE_HAND21_NAMES = ( + "wrist", + "thumb1", "thumb2", "thumb3", "thumb4", + "index1", "index2", "index3", "index4", + "middle1", "middle2", "middle3", "middle4", + "ring1", "ring2", "ring3", "ring4", + "pinky1", "pinky2", "pinky3", "pinky4", +) + +OPENPOSE_HAND21_TO_MHR70_R = np.array([ + 41, # 0 right_wrist + 24, 23, 22, 21, # thumb base→tip + 28, 27, 26, 25, # index + 32, 31, 30, 29, # middle + 36, 35, 34, 33, # ring + 40, 39, 38, 37, # pinky +], dtype=np.int64) + +OPENPOSE_HAND21_TO_MHR70_L = np.array([ + 62, # 0 left_wrist + 45, 44, 43, 42, # thumb base→tip + 49, 48, 47, 46, # index + 53, 52, 51, 50, # middle + 57, 56, 55, 54, # ring + 61, 60, 59, 58, # pinky +], dtype=np.int64) + +# OpenPose hand limbs: 5 chains × 4 bones, delegated to KeypointDraw.hand_edges. +OPENPOSE_HAND_PAIRS = tuple(tuple(e) for e in _KD.hand_edges) + +# OpenPose hand colors (poseParameters.cpp::HAND_COLORS_RENDER): wrist grey, +# then per-finger base→tip gradient red/yellow/green/cyan/magenta. +OPENPOSE_HAND_COLORS_21 = (np.array([ + [100, 100, 100], + [100, 0, 0], [150, 0, 0], [200, 0, 0], [255, 0, 0], + [100, 100, 0], [150, 150, 0], [200, 200, 0], [255, 255, 0], + [ 0, 100, 50], [ 0, 150, 75], [ 0, 200, 100], [ 0, 255, 125], + [ 0, 100, 100], [ 0, 150, 150], [ 0, 200, 200], [ 0, 255, 255], + [100, 0, 100], [150, 0, 150], [200, 0, 200], [255, 0, 255], +], dtype=np.float32) / 255.0) + +# DWPose: solid blue hand dots, rainbow per-finger bones (matches +# controlnet_aux/dwpose/util.py::draw_handpose). +DWPOSE_HAND_COLORS_21 = np.tile( + np.array([[0.0, 0.0, 1.0]], dtype=np.float32), (21, 1) +) + + +# Face landmarks from the MHR rig (option `face_source="rig"`). +# MHR has no face bones — face deforms via expr_params morphs — so landmarks +# are sourced from `pred_vertices` at fixed vertex IDs picked by NN against +# anatomically-plausible target xyz in canonical Y-up. Iterate visually in +# Blender and tweak targets if landmarks land off-surface. 
+ +# (name, target_xyz) in MHR canonical Y-up meters. +FACE_LANDMARK_TARGETS: Tuple[Tuple[str, Tuple[float, float, float]], ...] = ( + # Brows — 3 per side, outer→inner + ("r_brow_outer", (-0.058, 1.690, 0.090)), + ("r_brow_mid", (-0.040, 1.695, 0.105)), + ("r_brow_inner", (-0.020, 1.692, 0.115)), + ("l_brow_inner", (+0.020, 1.692, 0.115)), + ("l_brow_mid", (+0.040, 1.695, 0.105)), + ("l_brow_outer", (+0.058, 1.690, 0.090)), + # Right eye — outer/top/inner/bottom + ("r_eye_outer", (-0.058, 1.660, 0.085)), + ("r_eye_top", (-0.040, 1.673, 0.090)), + ("r_eye_inner", (-0.022, 1.665, 0.092)), + ("r_eye_bot", (-0.040, 1.652, 0.090)), + # Left eye + ("l_eye_outer", (+0.058, 1.660, 0.085)), + ("l_eye_top", (+0.040, 1.673, 0.090)), + ("l_eye_inner", (+0.022, 1.665, 0.092)), + ("l_eye_bot", (+0.040, 1.652, 0.090)), + # Nose + ("nose_bridge", (0.000, 1.660, 0.110)), + ("nose_mid", (0.000, 1.620, 0.125)), + ("nose_tip", (0.000, 1.585, 0.135)), + ("nostril_r", (-0.014, 1.580, 0.115)), + ("nostril_l", (+0.014, 1.580, 0.115)), + # Mouth — 4 outer-lip points + ("mouth_r_corner", (-0.030, 1.540, 0.105)), + ("upper_lip_mid", (+0.000, 1.555, 0.115)), + ("mouth_l_corner", (+0.030, 1.540, 0.105)), + ("lower_lip_mid", (+0.000, 1.530, 0.110)), + # Chin + jaw line — Y raised so NN search lands on chin tip / jaw underside + # (above the jaw-neck boundary at y~1.47) instead of throat verts. + ("chin", (0.000, 1.498, 0.108)), + ("r_jaw_low", (-0.038, 1.512, 0.100)), + ("r_jaw_mid", (-0.062, 1.535, 0.080)), + ("r_jaw_high", (-0.078, 1.562, 0.060)), + ("l_jaw_low", (+0.038, 1.512, 0.100)), + ("l_jaw_mid", (+0.062, 1.535, 0.080)), + ("l_jaw_high", (+0.078, 1.562, 0.060)), +) + +# Solid white face landmarks — matches DWPose, reads cleanly against the +# rainbow body palette. 
+def _face_landmark_colors() -> np.ndarray: + white = np.array([1.0, 1.0, 1.0], dtype=np.float32) + return np.tile(white, (len(FACE_LANDMARK_TARGETS), 1)) + + +FACE_LANDMARK_COLORS: np.ndarray = _face_landmark_colors() + + +def select_face_landmark_vert_ids( + canonical_positions: np.ndarray, + face_mask: Optional[np.ndarray] = None, +) -> np.ndarray: + """Pick MHR head vertex IDs for each `FACE_LANDMARK_TARGETS` by NN in + canonical positions. Filter: `face_mask` (verts that deform with any of + the 72 expression axes) if available — keeps chin/jaw search off the + neck. Otherwise a position bbox (less reliable; throat verts sometimes + pull chin targets).""" + P = np.asarray(canonical_positions, dtype=np.float32).reshape(-1, 3) + if face_mask is not None and np.asarray(face_mask).any(): + valid = np.where(np.asarray(face_mask).reshape(-1))[0] + else: + head_mask = (P[:, 1] > 1.47) & (np.abs(P[:, 0]) < 0.11) & (P[:, 2] > 0.04) + valid = np.where(head_mask)[0] + if valid.size == 0: + raise ValueError( + "select_face_landmark_vert_ids: no head verts matched the " + "canonical filter — check that pose_data.canonical_colors " + "holds the MHR rest-pose positions / face_mask." + ) + P_valid = P[valid] + out = np.empty(len(FACE_LANDMARK_TARGETS), dtype=np.int64) + for i, (_, xyz) in enumerate(FACE_LANDMARK_TARGETS): + target = np.asarray(xyz, dtype=np.float32) + d2 = np.sum((P_valid - target) ** 2, axis=1) + out[i] = int(valid[int(d2.argmin())]) + return out diff --git a/comfy_extras/sam3d_body/export/glb_skeletal.py b/comfy_extras/sam3d_body/export/glb_skeletal.py new file mode 100644 index 000000000..94783b5a0 --- /dev/null +++ b/comfy_extras/sam3d_body/export/glb_skeletal.py @@ -0,0 +1,578 @@ +"""GLB export — skeletal (real armature) mode. 
+ +Rebuilds an Armature with the MHR 127-bone rig: + - per-frame local TRS comes from re-running param_transform on the saved + `mhr_model_params`; + - rest verts come from a zero-pose forward with each person's `shape_params`; + - sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form; + - facial expression is re-exposed as 72 morph targets driven by `expr_params` + so face animation survives plain glTF skinning. + +Optional bone visualization (octahedrons / sticks) is rigidly +skinned alongside the body mesh — used to preview the armature in glTF +viewers that don't draw bones. + +Shared GLB infra (writer, math, rig static extraction, shaders, normals) +stays in `glb_shared.py`; only this mode's geometry + assembly live here. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from .glb_shared import ( + GLBWriter, + bake_vertex_colors, + bind_skel_state, + bone_locals_from_globals, + collect_tracks, + compact_skin_to_n, + compute_normals, + compute_pastel_mix, + extract_rig_static, + flat_shade_mesh, + gaussian_smooth_quats, + global_skel_state_from_pose_data, + global_skel_state_per_frame, + ibp_from_bind_global, + make_lit_material, + quat_sign_fix_per_joint, + rotation_align, + unflip, + zero_pose_rest_verts, +) + +from comfy_extras.sam3d_body.utils import jet_colormap + +def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]: + """Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white' + (no per-bone color → bone-vis mesh uses default unlit material).""" + if scheme == "rainbow_y": + y = bind_pos_m[:, 1].astype(np.float32) + y_min, y_max = float(y.min()), float(y.max()) + s = np.clip((y - y_min) / max(y_max - y_min, 1e-6), 0.0, 1.0) + return jet_colormap(s) + return None + + +def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]: + """Canonical Blender-style bone octahedron. 
Head at origin, tail at +Y, + unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound + so cross(v1-v0, v2-v0) points OUTWARD from the bone axis.""" + v = np.array([ + [0.0, 0.0, 0.0], # 0: head + [0.0, 1.0, 0.0], # 1: tail + [1.0, 0.1, 0.0], # 2: +X ridge (pre-scale; X/Z scale by half_width) + [-1.0, 0.1, 0.0], # 3: -X ridge + [0.0, 0.1, 1.0], # 4: +Z ridge + [0.0, 0.1, -1.0], # 5: -Z ridge + ], dtype=np.float32) + f = np.array([ + # head pyramid: outward = away from bone axis, slightly -Y + [0, 2, 4], [0, 5, 2], [0, 3, 5], [0, 4, 3], + # tail pyramid: outward = away from bone axis, slightly +Y + [1, 4, 2], [1, 3, 4], [1, 5, 3], [1, 2, 5], + ], dtype=np.uint32) + return v, f + + +def _bone_edges( + joint_pos_m: np.ndarray, parents: np.ndarray, +) -> List[Tuple[int, int, np.ndarray, np.ndarray]]: + """Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per + parent→child edge in the hierarchy, skipping edges whose PARENT is a + root joint (those typically anchor the skeleton at world origin and + just look like a stray stick from origin to the body). Zero-length + edges are skipped too.""" + NJ = joint_pos_m.shape[0] + out: List[Tuple[int, int, np.ndarray, np.ndarray]] = [] + for c in range(NJ): + p = int(parents[c]) + if not (0 <= p < NJ and p != c): + continue + # Skip if parent itself is a root — that bone is a world-anchor stick. + gp = int(parents[p]) + if not (0 <= gp < NJ and gp != p): + continue + head = joint_pos_m[p].astype(np.float32) + tail = joint_pos_m[c].astype(np.float32) + if float(np.linalg.norm(tail - head)) < 1e-6: + continue + out.append((p, c, head, tail)) + return out + + +def _build_bone_octahedrons_mesh( + bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """One Blender-style octahedron per parent→child edge. 
Returns + (verts, normals, faces, joints, weights, child_idx_per_vert); + child_idx feeds per-bone color lookup at the call site.""" + base_v, base_f = _octahedron_unit() + canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32) + + out_v: List[List[float]] = [] + out_n: List[List[float]] = [] + out_f: List[List[int]] = [] + out_j: List[List[int]] = [] + out_w: List[List[float]] = [] + child_per_vert: List[int] = [] + + # Width scales with length so short bones (fingers, face) don't look chunky + # next to long ones (limbs, spine). `half_width_m` caps long bones. + WIDTH_RATIO = 0.1 + MIN_WIDTH = 0.001 + for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents): + direction = tail - head + length = float(np.linalg.norm(direction)) + if length < 1e-6: + continue + unit_dir = direction / length + R = rotation_align(canonical, unit_dir) + + half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m)) + scale = np.array([half_width_eff, length, half_width_eff], dtype=np.float32) + v_local = base_v * scale + v_world = v_local @ R.T + head + + # head pole outward = -Y, tail pole +Y, ridges outward in XZ. + n_local = np.zeros_like(base_v) + n_local[0] = [0.0, -1.0, 0.0] + n_local[1] = [0.0, 1.0, 0.0] + for k in range(2, 6): + n = base_v[k].copy() + n[1] = 0.0 + n_norm = float(np.linalg.norm(n)) + if n_norm > 0: + n_local[k] = n / n_norm + n_world = n_local @ R.T + + v_off = len(out_v) + out_v.extend(v_world.tolist()) + out_n.extend(n_world.tolist()) + for face in base_f: + out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off]) + # Dual skin head→parent, tail→child, ridges blend by canonical Y so the + # bone stretches between joints instead of going rigid with one. 
def build_glb_skeletal(
    pose_data: Dict[str, Any],
    model: Any = None,
    *,
    fps: float = 24.0,
    camera_translation: str = "off",
    track_index: int = -1,
    include_face_morphs: bool = True,
    shader: str = "default",
    rainbow_tilt_x_deg: float = 0.0,
    rainbow_tilt_z_deg: float = 0.0,
    person_palette_falloff: float = 0.6,
    bone_smooth_window: int = 0,
    use_stored_global_rots: bool = True,
    bone_vis: str = "off",
    bone_vis_radius_m: float = 0.04,
    bone_vis_color: str = "white",
    include_body_mesh: bool = True,
) -> bytes:
    """Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.

    For MHR (default) facial expression is exposed as 72 morph targets driven
    by expr_params per frame when include_face_morphs=True.

    External skeletons (e.g. ComfyUI-Kimodo) can supply a
    ``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
    entirely. When present, ``model`` may be None and the rig data, bind pose,
    skin weights, and rest verts come from the override. Per-frame skeletal
    state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
    person dict (kimodo populates these from its own FK output). See
    ``glb.shared._get_skeleton_override`` for the override schema.

    Args:
        pose_data: dict with at least ``"frames"`` and ``"faces"``; optional
            ``"canonical_colors"`` feeds the vertex-colour bake.
        model: rig source handed to the rig helpers (may be None for
            external skeletons).
        fps: keyframe rate; frame index / fps gives the keyframe time.
        camera_translation: ``"off"`` skips root translation, ``"centered"``
            subtracts the first frame's ``pred_cam_t``, any other value
            writes it as-is.
        track_index: forwarded to ``collect_tracks``.
        include_face_morphs: emit expression morph targets + weight channel.
        shader, rainbow_tilt_x_deg, rainbow_tilt_z_deg: forwarded to
            ``bake_vertex_colors``.
        person_palette_falloff: forwarded to ``compute_pastel_mix``.
        bone_smooth_window: > 1 enables Gaussian smoothing of local quats.
        use_stored_global_rots: read stored per-frame globals instead of
            re-running FK from ``mhr_model_params`` (forced on for external
            rigs).
        bone_vis: ``"octahedrons"`` or ``"sticks"`` adds a skinned bone mesh.
        bone_vis_radius_m: octahedron half-width cap (sticks use 0.0035).
        bone_vis_color: forwarded to ``_bone_colors_rgb``.
        include_body_mesh: emit the skinned body mesh.

    Returns:
        The serialized GLB as bytes.

    Raises:
        ValueError: if ``fps`` is not a positive number or no valid track is
            found in ``pose_data``.
    """
    # A zero/negative/NaN fps would silently produce inf/NaN keyframe times
    # (NumPy division does not raise), i.e. an unplayable file.
    if not float(fps) > 0.0:
        raise ValueError(f"build_glb_skeletal: fps must be > 0, got {fps!r}")

    frames = pose_data["frames"]
    # Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
    # faces are all rig-native (Y-up).
    faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
    tracks = collect_tracks(pose_data, track_index)
    if not tracks:
        raise ValueError("build_glb_skeletal: no valid tracks in pose_data")

    rig_static = extract_rig_static(model, pose_data)
    NJ = rig_static["num_joints"]
    NV = rig_static["num_verts"]
    NEXPR = rig_static["num_expr"]
    parents = rig_static["parents"]
    is_external = bool(rig_static.get("_external", False))
    if is_external:
        # External rigs have no PCA pose params to re-run; only stored globals
        # are available, and kimodo stores joint coords already Y-up.
        use_stored_global_rots = True
    joint_coords_y_down = not is_external
    # Compact sparse skinning to 8 influences per vertex into glTF's two
    # JOINTS_*/WEIGHTS_* sets. MHR averages ~2.8 influences/vert but some
    # shoulder/hip verts have 5-8 where multiple joints cancel — keeping only
    # 4 there leaks per-bone rotation noise into the rendered mesh.
    if is_external:
        joints_8 = rig_static["lbs_compact_joints"]
        weights_8 = rig_static["lbs_compact_weights"]
        actual_max_inf = rig_static["lbs_compact_max_inf"]
    else:
        joints_8, weights_8, actual_max_inf = compact_skin_to_n(
            rig_static["lbs_skin_indices"], rig_static["lbs_vert_indices"],
            rig_static["lbs_skin_weights"], NV, max_inf=8,
        )
    joints_set0 = np.ascontiguousarray(joints_8[:, :4])
    weights_set0 = np.ascontiguousarray(weights_8[:, :4])
    use_set1 = actual_max_inf > 4
    joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
    weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
    # Derive bone locals from the rig's bind globals rather than recomputing
    # FK ourselves, so any mismatch between `parents` and the rig's actual FK
    # is absorbed into the local TRS instead of producing wrong globals.
    bind_global_cm = bind_skel_state(model, pose_data)
    bind_global_m = bind_global_cm.copy().astype(np.float32)
    bind_global_m[:, :3] *= 0.01
    bind_local = bone_locals_from_globals(bind_global_m[None], rig_static["parents"])[0]

    # IBP = inverse of bind global. With bone defaults set to bind_local and
    # FK composed via `parents`, skin_matrix at rest = identity.
    ibp_mat4 = ibp_from_bind_global(bind_global_m)

    w = GLBWriter()

    nodes: List[dict] = []
    meshes: List[dict] = []
    skins: List[dict] = []
    materials: List[dict] = []
    animations: List[dict] = []
    scene_root_indices: List[int] = []
    canonical_colors = pose_data.get("canonical_colors")

    # Accessors shared by every track (topology + skinning are per-rig).
    indices_acc = w.add_indices_u32(faces_native)
    joints0_acc = w.add_joints_u16(joints_set0)
    weights0_acc = w.add_weights_f32(weights_set0)
    joints1_acc = w.add_joints_u16(joints_set1) if use_set1 else None
    weights1_acc = w.add_weights_f32(weights_set1) if use_set1 else None
    ibm_acc = w.add_mat4_f32(ibp_mat4)

    expr_morph_accs: List[int] = []
    if include_face_morphs and NEXPR > 0:
        eb = rig_static["expr_basis"].astype(np.float32) * 0.01
        for e in range(NEXPR):
            expr_morph_accs.append(w.add_vec3_f32_no_minmax(eb[e]))

    for track_i, (person_k, frame_indices) in enumerate(tracks):
        person_root = {"name": f"track{track_i:02d}", "children": []}
        nodes.append(person_root)
        person_root_idx = len(nodes) - 1
        scene_root_indices.append(person_root_idx)

        bone_node_indices: List[int] = []
        for j in range(NJ):
            bone = {
                "name": f"bone_{j:03d}",
                "translation": bind_local[j, :3].tolist(),
                "rotation": bind_local[j, 3:7].tolist(),
                "scale": [float(bind_local[j, 7])] * 3,
            }
            nodes.append(bone)
            bone_node_indices.append(len(nodes) - 1)

        # Wire the bone hierarchy; joints whose parent is invalid or self
        # become roots under the per-person node.
        bone_children: List[List[int]] = [[] for _ in range(NJ)]
        bone_root_indices: List[int] = []
        for j in range(NJ):
            p = int(parents[j])
            if 0 <= p < NJ and p != j:
                bone_children[p].append(bone_node_indices[j])
            else:
                bone_root_indices.append(bone_node_indices[j])
        for j in range(NJ):
            if bone_children[j]:
                nodes[bone_node_indices[j]]["children"] = bone_children[j]
        person_root["children"].extend(bone_root_indices)

        skin = {
            "joints": bone_node_indices,
            "inverseBindMatrices": ibm_acc,
            "skeleton": bone_root_indices[0] if bone_root_indices else bone_node_indices[0],
        }
        skins.append(skin)
        skin_idx = len(skins) - 1

        include_body = bool(include_body_mesh)
        include_bones = bone_vis in ("octahedrons", "sticks")
        body_mesh_node_idx: Optional[int] = None

        if include_body:
            # External rigs have no PCA shape — `zero_pose_rest_verts` short-
            # circuits to `pose_data["_skeleton_override"]["rest_verts_m"]`,
            # so zeroed shape_params is safe there.
            if is_external:
                shape_params_arr = np.zeros(0, dtype=np.float32)
            else:
                shape_params_arr = np.asarray(
                    frames[frame_indices[0]][person_k]["shape_params"], dtype=np.float32,
                )
            rest_v = zero_pose_rest_verts(model, shape_params_arr, pose_data=pose_data)
            normals = compute_normals(rest_v, faces_native)
            positions_acc = w.add_vec3_f32(rest_v)
            normals_acc = w.add_vec3_f32(normals)

            pastel_mix = compute_pastel_mix(track_i, person_palette_falloff)
            vcolor = bake_vertex_colors(
                canonical_colors, shader,
                rainbow_tilt_x_deg, rainbow_tilt_z_deg, pastel_mix,
            )
            color_acc = w.add_vec3_f32(vcolor) if vcolor is not None else None

            attributes = {
                "POSITION": positions_acc, "NORMAL": normals_acc,
                "JOINTS_0": joints0_acc, "WEIGHTS_0": weights0_acc,
            }
            if joints1_acc is not None:
                attributes["JOINTS_1"] = joints1_acc
                attributes["WEIGHTS_1"] = weights1_acc
            if color_acc is not None:
                attributes["COLOR_0"] = color_acc
            primitive = {
                "attributes": attributes,
                "indices": indices_acc,
                "mode": 4,
            }
            if color_acc is not None:
                materials.append(make_lit_material())
                primitive["material"] = len(materials) - 1
            if expr_morph_accs:
                primitive["targets"] = [{"POSITION": a} for a in expr_morph_accs]

            mesh = {"primitives": [primitive]}
            if expr_morph_accs:
                mesh["weights"] = [0.0] * len(expr_morph_accs)
            meshes.append(mesh)
            mesh_idx = len(meshes) - 1

            mesh_node = {
                "name": f"track{track_i:02d}_mesh", "mesh": mesh_idx, "skin": skin_idx,
            }
            nodes.append(mesh_node)
            body_mesh_node_idx = len(nodes) - 1
            person_root["children"].append(body_mesh_node_idx)

        if include_bones:
            bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)

            # Indexes `bone_palette`: octahedrons/sticks use the bone's child
            # joint so every bone has its own color regardless of skin target.
            # 'sticks' = thin octahedrons. glTF LINES skinning is unreliable
            # (Three.js' GLTFLoader doesn't animate skinned line primitives),
            # so we render triangle tubes instead.
            color_idx_per_vert: Optional[np.ndarray] = None
            hw = float(bone_vis_radius_m) if bone_vis == "octahedrons" else 0.0035
            bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
                bind_global_m[:, :3], rig_static["parents"], half_width_m=hw,
            )
            if bv_v.shape[0] > 0:
                # Expand the per-vertex child index to per-corner order before
                # flat shading rewrites the vertex buffer.
                F = bv_f.shape[0]
                expanded_child = np.empty((F * 3,), dtype=np.int64)
                for k in range(3):
                    expanded_child[k::3] = child_per_vert[bv_f[:, k]]
                bv_v, bv_n, bv_f, bv_j, bv_w = flat_shade_mesh(bv_v, bv_f, bv_j, bv_w)
                color_idx_per_vert = expanded_child
                primitive_mode = 4
                bv_idx_flat = bv_f.reshape(-1)

            if bv_v.shape[0] > 0:
                bv_pos_acc = w.add_vec3_f32(bv_v)
                bv_idx_acc = w.add_indices_u32(bv_idx_flat)
                bv_j_acc = w.add_joints_u16(bv_j)
                bv_w_acc = w.add_weights_f32(bv_w)
                bv_attrs = {
                    "POSITION": bv_pos_acc,
                    "JOINTS_0": bv_j_acc, "WEIGHTS_0": bv_w_acc,
                }
                if bv_n is not None:
                    bv_attrs["NORMAL"] = w.add_vec3_f32(bv_n)
                if bone_palette is not None and color_idx_per_vert is not None:
                    bv_color = bone_palette[color_idx_per_vert].astype(np.float32)
                    bv_attrs["COLOR_0"] = w.add_vec3_f32(bv_color)
                bv_primitive = {
                    "attributes": bv_attrs,
                    "indices": bv_idx_acc,
                    "mode": primitive_mode,
                }
                if bone_palette is not None:
                    materials.append(make_lit_material())
                    bv_primitive["material"] = len(materials) - 1
                bv_mesh = {"primitives": [bv_primitive]}
                meshes.append(bv_mesh)
                bv_mesh_node = {
                    "name": f"track{track_i:02d}_bones",
                    "mesh": len(meshes) - 1,
                    "skin": skin_idx,
                }
                nodes.append(bv_mesh_node)
                person_root["children"].append(len(nodes) - 1)

        # Per-frame GLOBAL skel state → bone locals via parent-inverse.
        # Default uses the rig's stored output; the fallback re-runs FK.
        if use_stored_global_rots:
            rig_global_m = global_skel_state_from_pose_data(
                pose_data, frame_indices, person_k, NJ,
                joint_coords_y_down=joint_coords_y_down,
            )
        else:
            mp_per_frame = np.stack([
                np.asarray(frames[t][person_k]["mhr_model_params"], dtype=np.float32)
                for t in frame_indices
            ], axis=0)
            rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
            rig_global_m = rig_global_cm.copy().astype(np.float32)
            rig_global_m[..., :3] *= 0.01
        # Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
        # Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
        # only fix locals, the parent's flip propagates into the child's
        # local translation (t_local inherits parent sign via q_parent_inv)
        # and produces visible "axis resets" mid-animation.
        rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
        bone_local_anim = bone_locals_from_globals(rig_global_m, rig_static["parents"])
        local_t = bone_local_anim[..., :3].astype(np.float32)
        local_q = bone_local_anim[..., 3:7].astype(np.float32)
        local_s = bone_local_anim[..., 7].astype(np.float32)
        # Second pass on locals catches residual drift from the parent-inverse.
        local_q = quat_sign_fix_per_joint(local_q)
        # Hemisphere-align frame 0 with the bind quat so pause/play takes the
        # short path; then re-propagate.
        bind_q = bind_local[:, 3:7].astype(np.float32)
        if local_q.shape[0] > 0:
            d0 = (bind_q * local_q[0]).sum(axis=-1)
            sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
            local_q[0] = local_q[0] * sign0
            local_q = quat_sign_fix_per_joint(local_q)
        # Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
        # at handstand) that the upstream Smooth node may not catch.
        if bone_smooth_window and bone_smooth_window > 1:
            local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
        # fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
        # drift into visible flips otherwise.
        lq64 = local_q.astype(np.float64)
        lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
        local_q = lq64.astype(np.float32)

        n_frames = len(frame_indices)
        times = np.asarray(frame_indices, dtype=np.float32) / float(fps)
        time_acc = w.add_scalar_f32(times)

        samplers: List[dict] = []
        channels: List[dict] = []

        for j in range(NJ):
            t_j = local_t[:, j, :]
            q_j = local_q[:, j, :]
            s_j = np.broadcast_to(local_s[:, j:j+1], (n_frames, 3)).astype(np.float32)

            # Channels that never move are baked into the node's rest TRS
            # instead of being written as keyframes.
            t_const = (np.ptp(t_j, axis=0) < 1e-6).all()
            q_const = (np.ptp(q_j, axis=0) < 1e-6).all()
            s_const = (np.ptp(s_j, axis=0) < 1e-6).all()

            if t_const:
                nodes[bone_node_indices[j]]["translation"] = t_j[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(t_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "translation"},
                })

            if q_const:
                nodes[bone_node_indices[j]]["rotation"] = q_j[0].tolist()
            else:
                acc = w.add_vec4_f32(q_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "rotation"},
                })

            if s_const:
                nodes[bone_node_indices[j]]["scale"] = s_j[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(s_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "scale"},
                })

        if camera_translation != "off":
            cam_t = np.stack([
                unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
                for t in frame_indices
            ], axis=0)
            if camera_translation == "centered" and cam_t.shape[0] > 0:
                cam_t = cam_t - cam_t[0:1]
            if (np.ptp(cam_t, axis=0) < 1e-6).all():
                person_root["translation"] = cam_t[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(cam_t)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": person_root_idx, "path": "translation"},
                })

        # Body-mesh-only: bone-vis primitives have no morph targets.
        if expr_morph_accs and body_mesh_node_idx is not None:
            expr_per_frame = np.stack([
                np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
                for t in frame_indices
            ], axis=0).astype(np.float32)
            weights_acc_anim = w.add_scalar_f32_flat(expr_per_frame, count=n_frames * NEXPR)
            samplers.append({"input": time_acc, "output": weights_acc_anim, "interpolation": "LINEAR"})
            channels.append({
                "sampler": len(samplers) - 1,
                "target": {"node": body_mesh_node_idx, "path": "weights"},
            })

        # glTF 2.0 requires `animation.channels` / `animation.samplers` to
        # hold at least one entry. A fully static track (e.g. a single frame)
        # has everything baked into rest TRS, so emit no animation for it.
        if channels:
            animations.append({
                "name": f"track{track_i:02d}",
                "samplers": samplers, "channels": channels,
            })

    gltf = {
        "asset": {"version": "2.0", "generator": "ComfyUI-SAM3DBody"},
        "scene": 0,
        "scenes": [{"nodes": scene_root_indices}],
        "nodes": nodes,
        "meshes": meshes,
        "skins": skins,
    }
    if materials:
        gltf["materials"] = materials
    if animations:
        gltf["animations"] = animations

    return w.to_bytes(gltf)
SAM3D-specific: MHR70 -> DWPose-134 keypoint packing, +plus optional rig-projected face landmarks when `pred_face_keypoints_2d` +isn't present (and arbitrary-count face dots, since sapiens-238 doesn't fit +the DWPose face slot). + +Output: (H, W, 3) fp32 torch.Tensor in [0, 1]. +""" + +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +from PIL import Image + +from comfy_extras.pose.keypoint_draw import KeypointDraw + +from .glb_shared import ( + OPENPOSE18_TO_MHR70, + OPENPOSE_HAND21_TO_MHR70_L, + OPENPOSE_HAND21_TO_MHR70_R, + OPENPOSE_HAND_COLORS_21, + select_face_landmark_vert_ids, +) + + +_KD = KeypointDraw() +# OpenPose hand palette as a (21, 3) int array (0..255) for KeypointDraw. +_HAND_DOT_PALETTE_OPENPOSE = (OPENPOSE_HAND_COLORS_21 * 255.0).astype(int) + + +def _project_face_landmarks_2d( + person: Dict[str, Any], face_vert_ids: np.ndarray, H: int, W: int, +) -> Optional[np.ndarray]: + """Project `pred_vertices[face_vert_ids]` to 2D using each person's + pred_cam_t + focal_length. Same projection used by `_replay_mhr_with_overrides`.""" + verts = person.get("pred_vertices") + cam_t = person.get("pred_cam_t") + focal = person.get("focal_length") + if verts is None or cam_t is None or focal is None: + return None + verts = np.asarray(verts, dtype=np.float32) + cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3) + f = float(np.asarray(focal, dtype=np.float32).reshape(-1)[0]) + pts3 = verts[face_vert_ids] + cam_t[None, :] + z = np.maximum(pts3[:, 2:3], 1e-6) + xy = pts3[:, :2] * f + xy = xy + np.array([W * 0.5, H * 0.5], dtype=np.float32)[None, :] * z + return (xy / z).astype(np.float32) + + +def _pack_dwpose_134( + person: Dict[str, Any], *, include_body: bool, include_hands: bool, +) -> Tuple[np.ndarray, np.ndarray]: + """Pack a SAM3D person dict into (kp, scores): (134, 2) DWPose-layout + coords + (134,) confidence. 
def render_pose_data_openpose(
    pose_data: Dict[str, Any],
    *,
    frame_idx: int,
    W: int,
    H: int,
    background: Optional[torch.Tensor] = None,
    composite: str = "over",
    marker_radius_px: int = 4,
    stick_width_px: int = 4,
    limb_alpha: float = 0.6,
    include_body: bool = True,
    include_hands: bool = False,
    face_style: str = "disabled",
    hand_color_style: str = "dwpose",
    hand_marker_radius_px: int = 0,
    hand_stick_width_px: int = 0,
    face_marker_radius_px: int = 3,
    person_brightness_falloff: float = 0.0,
) -> torch.Tensor:
    """Render a 2D OpenPose-style skeleton onto an (H, W, 3) canvas.

    `composite='over'` paints over `background` (else black canvas).
    `hand_marker_radius_px` / `hand_stick_width_px`: 0 = auto = 0.7x / 0.5x
    of the body sizes.
    `face_style`: 'disabled' = no face dots, 'full' = all face landmarks
    (prefers sapiens-238 if present, else rig-fallback ~30), 'eyes_mouth' =
    rig-fallback subset (~12 dots: eyes + mouth only). The subset only has a
    documented layout for the rig fallback so eyes_mouth always uses it,
    regardless of whether sapiens-238 is available.
    `person_brightness_falloff` mixes each person's drawn pixels toward white
    by `1 - falloff^k` (track 0 stays vivid). Applied post-draw so per-limb
    alpha blending against the existing canvas remains correct.

    Returns an (H, W, 3) float32 tensor in [0, 1].
    """
    persons = pose_data["frames"][frame_idx]

    if composite == "over" and background is not None:
        # detach() so a tensor that still tracks gradients can be converted.
        bg = background.detach().cpu().numpy()
        canvas = (np.clip(bg, 0.0, 1.0) * 255.0).astype(np.uint8)
        if canvas.shape[:2] != (H, W):
            canvas = np.array(Image.fromarray(canvas).resize((W, H), Image.LANCZOS))
    else:
        canvas = np.zeros((H, W, 3), dtype=np.uint8)
    # In-place draw needs a contiguous writable buffer.
    canvas = np.ascontiguousarray(canvas)

    if int(hand_marker_radius_px) <= 0:
        hand_marker_radius_px = max(1, int(round(marker_radius_px * 0.7)))
    if int(hand_stick_width_px) <= 0:
        hand_stick_width_px = max(1, int(round(stick_width_px * 0.5)))

    # Eyes+mouth indices into the rig fallback (FACE_LANDMARK_TARGETS): 6..13
    # = both eyes, 19..22 = outer-lip ring. Brows/nose/chin/jaw are dropped.
    _EYES_MOUTH_IDX = np.array([6, 7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22], dtype=np.int64)

    include_face = face_style != "disabled"
    use_rig_only = face_style == "eyes_mouth"

    # Real 238 sapiens face KPs take priority for 'full'; 'eyes_mouth' always
    # falls through to the rig path since sapiens has no documented subset.
    face_vert_ids: Optional[np.ndarray] = None
    if include_face:
        any_real = (not use_rig_only) and any(
            p.get("pred_face_keypoints_2d") is not None for p in persons
        )
        if not any_real:
            cc = pose_data.get("canonical_colors") or {}
            positions = cc.get("positions")
            if positions is not None:
                # Best-effort: a failed landmark lookup disables face dots
                # rather than aborting the whole render.
                try:
                    face_vert_ids = select_face_landmark_vert_ids(
                        np.asarray(positions), face_mask=cc.get("face_mask"),
                    )
                    if use_rig_only:
                        face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
                except Exception as e:
                    print(f"[SAM3DBody] face landmarks disabled - {e}")
                    face_vert_ids = None

    hand_dot_color = (
        _HAND_DOT_PALETTE_OPENPOSE if hand_color_style == "openpose"
        else (0, 0, 255)
    )

    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))

    for k, person in enumerate(persons):
        pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
        # Snapshot before this person's strokes so we can identify the pixels
        # they touched and blend just those toward white. Drawing happens
        # against the live canvas first so limb_alpha blends correctly.
        pre = canvas.copy() if pastel > 0 else None

        kp134, scores134 = _pack_dwpose_134(
            person, include_body=include_body, include_hands=include_hands,
        )
        _KD.draw_wholebody_keypoints(
            canvas, kp134, scores=scores134, threshold=0.5,
            draw_body=include_body, draw_feet=False,
            draw_face=False,  # SAM3D draws face dots separately (variable count)
            draw_hands=include_hands,
            stick_width=stick_width_px,
            marker_radius=marker_radius_px,
            hand_stick_width=hand_stick_width_px,
            hand_marker_radius=hand_marker_radius_px,
            limb_alpha=limb_alpha,
            hand_dot_color=hand_dot_color,
        )

        if include_face:
            face_xy = None
            # 'eyes_mouth' must ignore sapiens keypoints even when present;
            # otherwise the full 238-dot set would be drawn instead of the
            # documented eyes+mouth subset.
            real_face = None if use_rig_only else person.get("pred_face_keypoints_2d")
            if real_face is not None:
                arr = np.asarray(real_face, dtype=np.float32)
                if arr.ndim == 2 and arr.shape[1] == 2:
                    face_xy = arr
            elif face_vert_ids is not None:
                face_xy = _project_face_landmarks_2d(person, face_vert_ids, H, W)
            if face_xy is not None:
                _draw_face_dots(canvas, face_xy, face_marker_radius_px)

        if pre is not None:
            changed = (canvas != pre).any(axis=-1)
            if changed.any():
                touched = canvas[changed].astype(np.float32)
                blended = touched * (1.0 - pastel) + 255.0 * pastel
                canvas[changed] = np.clip(blended, 0.0, 255.0).astype(np.uint8)

    return torch.from_numpy(canvas.astype(np.float32) / 255.0)
def _region_of(arkit_name: str) -> str:
    """Return the gain region ('mouth' / 'eye' / 'brow') whose prefix list
    matches `arkit_name`, or 'other' when no prefix in `_REGION_PREFIXES`
    applies. Regions are tried in the table's insertion order."""
    return next(
        (
            region
            for region, prefixes in _REGION_PREFIXES.items()
            # str.startswith accepts a tuple and tests every prefix in it.
            if arkit_name.startswith(prefixes)
        ),
        "other",
    )
+_AXIS_TO_ARKIT: Dict[int, List[Tuple[str, float]]] = { + 0: [("browDownLeft", 1.0)], + 1: [("browDownRight", 1.0)], + 2: [("cheekPuff", 1.0)], + 3: [("cheekPuff", 1.0)], + 4: [("cheekSquintLeft", 1.0)], + 5: [("cheekSquintRight", 1.0)], + 6: [("mouthStretchLeft", 1.0)], + 7: [("mouthStretchRight", 1.0)], + 8: [("mouthShrugLower", 1.0)], + 9: [("mouthShrugUpper", 1.0)], + 10: [("mouthDimpleLeft", 1.0)], + 11: [("mouthDimpleRight", 1.0)], + 12: [("eyeLookDownLeft", 0.3)], + 13: [("eyeLookDownRight", 0.3)], + 14: [("eyeBlinkLeft", 1.0)], + 15: [("eyeBlinkRight", 1.0)], + 16: [("eyeLookOutLeft", 1.0)], + 17: [("eyeLookInRight", 1.0)], + 18: [("eyeLookInLeft", 1.0)], + 19: [("eyeLookOutRight", 1.0)], + 22: [("eyeLookUpLeft", 1.0), ("browInnerUp", 0.5)], + 23: [("eyeLookUpRight", 1.0), ("browInnerUp", 0.5)], + 24: [("jawOpen", 1.0), ("mouthLowerDownLeft", 0.3), ("mouthLowerDownRight", 0.3)], + 25: [("jawLeft", 1.0)], + 26: [("jawRight", 1.0)], + 27: [("jawForward", 1.0)], + 28: [("eyeSquintLeft", 1.0)], + 29: [("eyeSquintRight", 1.0)], + 32: [("mouthSmileLeft", 1.0)], + 33: [("mouthSmileRight", 1.0)], + 40: [("mouthLeft", 1.0)], + 41: [("mouthRight", 1.0)], + 42: [("mouthFrownLeft", 1.0)], + 43: [("mouthFrownRight", 1.0)], + 54: [("mouthLowerDownLeft", 1.0)], + 55: [("mouthLowerDownRight", 1.0)], + 60: [("noseSneerLeft", 1.0)], + 61: [("noseSneerRight", 1.0)], + 66: [("browOuterUpLeft", 1.0)], + 67: [("browOuterUpRight", 1.0)], + 68: [("eyeWideLeft", 1.0)], + 69: [("eyeWideRight", 1.0)], + 70: [("mouthUpperUpLeft", 1.0)], + 71: [("mouthUpperUpRight", 1.0)], +} + + +def _deadzone(x: float, threshold: float) -> float: + """Zero below threshold, linearly remap (threshold..1] → (0..1] so + amplification doesn't promote MP's per-blendshape noise floor.""" + if threshold <= 0.0: + return x + if x <= threshold: + return 0.0 + return (x - threshold) / (1.0 - threshold) + + +def arkit_to_expr_params( + blendshape_coefs: Dict[str, float], + strength: float = 1.0, + 
def subtract_per_clip_baseline(
    per_frame_coefs: List[Optional[Dict[str, float]]],
    percentile: float = 5.0,
) -> List[Optional[Dict[str, float]]]:
    """Remove each blendshape's resting level over the clip.

    For every coefficient name, the `percentile`-th percentile of its values
    across the frames that contain it is taken as the baseline, subtracted,
    and the result clamped at 0. This adapts to per-subject MP bias (e.g. a
    resting browOuterUp ~0.15) that a global deadzone can't catch. `None`
    frames stay `None`; `percentile <= 0` returns a shallow copy of the
    input list untouched.
    """
    if percentile <= 0.0:
        return list(per_frame_coefs)

    # Gather every observed value per blendshape name in one pass.
    samples: Dict[str, List[float]] = {}
    for frame in per_frame_coefs:
        if frame is None:
            continue
        for key, value in frame.items():
            samples.setdefault(key, []).append(value)

    floor = {
        key: float(np.percentile(values, percentile))
        for key, values in samples.items()
    }

    adjusted: List[Optional[Dict[str, float]]] = []
    for frame in per_frame_coefs:
        if frame is None:
            adjusted.append(None)
            continue
        adjusted.append({
            key: max(0.0, float(value) - floor.get(key, 0.0))
            for key, value in frame.items()
        })
    return adjusted
def fill_detection_gaps(
    per_frame_coefs: List[Optional[Dict[str, float]]],
    method: str = "interpolate",
    max_gap: int = 12,
) -> List[Optional[Dict[str, float]]]:
    """Fill missing per-frame dicts so the signal doesn't slam to zero at
    undetected frames. method: 'interpolate' | 'hold' | 'zeros'.

    `max_gap` applies to 'interpolate' and 'hold': a run of consecutive
    undetected frames is filled only when the whole run is at most `max_gap`
    frames long; longer runs stay entirely None (don't fake too far).
    Interior runs are interpolated between (or held from) their neighbours;
    leading / trailing runs have a single neighbour and copy it.
    'zeros' ignores `max_gap` on purpose: the goal there is to relax to
    neutral on every miss, no matter how long, otherwise long undetected
    runs would inherit Predict's per-frame expression.

    An unrecognized `method` fills nothing. The input list is not mutated;
    filled frames are fresh dicts.
    """
    if method == "zeros":
        names: set = set()
        for c in per_frame_coefs:
            if c is not None:
                names.update(c.keys())
        zero = {n: 0.0 for n in names}
        return [dict(zero) if c is None else c for c in per_frame_coefs]

    N = len(per_frame_coefs)
    out: List[Optional[Dict[str, float]]] = list(per_frame_coefs)
    detected = [i for i, c in enumerate(per_frame_coefs) if c is not None]
    if not detected or method not in ("hold", "interpolate"):
        return out

    # Walk the runs between consecutive detections; the None sentinels stand
    # for "before the first" / "after the last" detection. Judging the run as
    # a whole keeps a too-long gap from being filled only in its middle.
    bounds: List[Optional[int]] = [None, *detected, None]
    for prev_i, next_i in zip(bounds, bounds[1:]):
        start = 0 if prev_i is None else prev_i + 1
        stop = N if next_i is None else next_i
        gap_len = stop - start
        if gap_len <= 0 or gap_len > max_gap:
            continue

        before = per_frame_coefs[prev_i] if prev_i is not None else None
        after = per_frame_coefs[next_i] if next_i is not None else None
        for fi in range(start, stop):
            if before is None:
                # Leading run: only a later detection exists.
                out[fi] = dict(after)
            elif after is None or method == "hold":
                # Trailing run, or explicit hold of the previous detection.
                out[fi] = dict(before)
            else:
                w = (fi - prev_i) / (next_i - prev_i)
                keys = set(before.keys()) | set(after.keys())
                out[fi] = {
                    k: (1.0 - w) * before.get(k, 0.0) + w * after.get(k, 0.0)
                    for k in keys
                }
    return out
Helps small/distant faces that fall below BlazeFace's min size.""" + H, W = image_rgb_uint8.shape[:2] + x1, y1, x2, y2 = (int(round(float(v))) for v in crop_xyxy) + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(W, x2), min(H, y2) + if x2 - x1 < 16 or y2 - y1 < 16: + return [] + crop = np.ascontiguousarray(image_rgb_uint8[y1:y2, x1:x2]) + faces = inner.face_landmarker.detect_batch([crop], num_faces=num_faces)[0] + bbox_off = np.array([x1, y1, x1, y1], dtype=np.float32) + xy_off = np.array([x1, y1], dtype=np.float32) + for f in faces: + f["bbox_xyxy"] = f["bbox_xyxy"] + bbox_off + f["landmarks_xy"] = f["landmarks_xy"] + xy_off + return faces + + +# Crop helpers — feed MP a tight head region so it doesn't downsample the face +# to 192px for full-frame detection. + +def _expand_bbox(bbox_xyxy: np.ndarray, factor: float, W: int, H: int) -> np.ndarray: + x1, y1, x2, y2 = (float(v) for v in bbox_xyxy) + cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2) + hw, hh = 0.5 * (x2 - x1) * factor, 0.5 * (y2 - y1) * factor + return np.array([ + max(0.0, cx - hw), max(0.0, cy - hh), + min(float(W), cx + hw), min(float(H), cy + hh), + ], dtype=np.float32) + + +def head_region_crop( + person_bbox: np.ndarray, expand: float, W: int, H: int, head_h_frac: float = 0.4, +) -> np.ndarray: + """Crop upper `head_h_frac` of a body bbox — cropping the whole body wastes + BlazeFace's 128² input on body pixels.""" + x1, y1, x2, y2 = (float(v) for v in person_bbox) + body_h = y2 - y1 + if body_h <= 0 or x2 - x1 <= 0: + return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32) + return _expand_bbox(np.array([x1, y1, x2, y1 + body_h * head_h_frac]), expand, W, H) + + +# mhr70 convention: first five kp are COCO-style face landmarks in pixel coords. +_FACE_KP_INDICES = (0, 1, 2, 3, 4) # nose, L-eye, R-eye, L-ear, R-ear + + +def head_crop_from_keypoints( + pred_keypoints_2d: np.ndarray, expand: float, W: int, H: int, +) -> Optional[np.ndarray]: + """Head crop from SAM3D nose/eyes/ears kp. 
def _iou_xyxy(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0.0 when the boxes do not overlap or the union area is not
    positive.
    """
    left = max(a[0], b[0])
    top = max(a[1], b[1])
    right = min(a[2], b[2])
    bottom = min(a[3], b[3])

    overlap = max(0.0, right - left) * max(0.0, bottom - top)
    if overlap <= 0.0:
        return 0.0

    area_a = max(0.0, a[2] - a[0]) * max(0.0, a[3] - a[1])
    area_b = max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
    total = area_a + area_b - overlap
    if total > 0.0:
        return float(overlap / total)
    return 0.0
def regenerate_mesh_from_params(inner, pose_frames: List[List[Dict[str, Any]]]) -> None:
    """Re-run the MHR forward pass from stored parameters and refresh derived outputs.

    For each person slot ``pid`` the per-frame parameters are stacked into one
    batch, passed through ``inner.head_pose.mhr_forward`` and the results are
    written back into ``pose_frames``. MHR is driven via euler params directly
    because hand refinement zeroes pred_pose_raw.

    Args:
        inner: Model wrapper; only ``inner.head_pose`` (its ``mhr`` attribute
            and ``mhr_forward`` method) is used here.
        pose_frames: ``pose_frames[frame][person]`` dicts. A person is
            processed only on frames where all of ``global_rot``,
            ``body_pose_params``, ``hand_pose_params``, ``shape_params``,
            ``scale_params``, ``expr_params``, ``pred_cam_t`` and
            ``focal_length`` are present and not None.

    Returns:
        None. Each processed entry ``pose_frames[fi][pid]`` is replaced by a
        shallow copy with ``pred_vertices``, ``pred_keypoints_3d``,
        ``pred_keypoints_2d``, ``pred_face_keypoints_3d``,
        ``pred_face_keypoints_2d``, ``pred_joint_coords`` and
        ``pred_global_rots`` overwritten. Nothing happens when
        ``inner.head_pose.mhr`` is None.
    """
    device = comfy.model_management.get_torch_device()
    head = inner.head_pose
    if head.mhr is None:
        return

    B = len(pose_frames)
    # Frames may hold different numbers of persons; iterate up to the largest.
    max_p = max((len(f) for f in pose_frames), default=0)

    for pid in range(max_p):
        # Per-parameter lists holding one entry per frame where this person is usable.
        grots, bpps, hands, shapes, scales, exprs, cam_ts, fls = [], [], [], [], [], [], [], []
        # present[fi] records whether frame fi contributed a row to the batch.
        present: List[bool] = []
        for fi in range(B):
            if pid >= len(pose_frames[fi]):
                present.append(False)
                continue
            p = pose_frames[fi][pid]
            needed = ("global_rot", "body_pose_params", "hand_pose_params",
                      "shape_params", "scale_params", "expr_params",
                      "pred_cam_t", "focal_length")
            if any(p.get(k) is None for k in needed):
                present.append(False)
                continue
            grots.append(np.asarray(p["global_rot"], dtype=np.float32))
            bpps.append(np.asarray(p["body_pose_params"], dtype=np.float32))
            hands.append(np.asarray(p["hand_pose_params"], dtype=np.float32))
            shapes.append(np.asarray(p["shape_params"], dtype=np.float32))
            scales.append(np.asarray(p["scale_params"], dtype=np.float32))
            exprs.append(np.asarray(p["expr_params"], dtype=np.float32))
            cam_ts.append(np.asarray(p["pred_cam_t"], dtype=np.float32))
            # focal_length may be a scalar or an array; only its first element is used.
            fls.append(float(np.asarray(p["focal_length"]).reshape(-1)[0]))
            present.append(True)
        if not any(present):
            continue

        # Stack the usable frames into one batch on the compute device.
        global_rot_euler = torch.from_numpy(np.stack(grots)).to(device)
        body_pose_euler = torch.from_numpy(np.stack(bpps)).to(device)
        hand_t = torch.from_numpy(np.stack(hands)).to(device)
        shape_t = torch.from_numpy(np.stack(shapes)).to(device)
        scale_t = torch.from_numpy(np.stack(scales)).to(device)
        expr_t = torch.from_numpy(np.stack(exprs)).to(device)
        cam_t_t = torch.from_numpy(np.stack(cam_ts)).to(device)
        f_t = torch.tensor(fls, device=device, dtype=torch.float32)

        # Global translation is passed as zeros; the camera translation is
        # added separately when projecting below.
        verts, kp3d_full, joint_coords, _, joint_rotmats = head.mhr_forward(
            global_trans=torch.zeros_like(global_rot_euler),
            global_rot=global_rot_euler,
            body_pose_params=body_pose_euler,
            hand_pose_params=hand_t,
            scale_params=scale_t,
            shape_params=shape_t,
            expr_params=expr_t,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )
        # y/z flip matches head_pose.forward (camera-y-down convention).
        verts = verts.clone()
        verts[..., [1, 2]] *= -1
        # The first 70 keypoints are the body set stored as pred_keypoints_3d.
        kp3d = kp3d_full[:, :70].clone()
        kp3d[..., [1, 2]] *= -1
        # 238 sapiens face landmarks (70:308) — track retargeted expression
        # so openpose face dots follow new mouth/eye/brow shape.
        kp3d_face = kp3d_full[:, 70:].clone()
        kp3d_face[..., [1, 2]] *= -1
        joint_coords = joint_coords.clone()
        joint_coords[..., [1, 2]] *= -1

        # Recover the principal point by inverting the pinhole projection of
        # keypoint 0 on the first usable frame; that single (cx, cy) is then
        # applied to every frame of this person.
        # NOTE(review): this reads "pred_keypoints_2d"/"pred_keypoints_3d",
        # which are not part of the `needed` check above — confirm they are
        # always populated. Also assumes the principal point is the same on
        # all frames — TODO confirm.
        cx = cy = 0.0
        for fi in range(B):
            if not present[fi]:
                continue
            raw = pose_frames[fi][pid]
            kp2d_r = np.asarray(raw["pred_keypoints_2d"], dtype=np.float32)
            kp3d_r = np.asarray(raw["pred_keypoints_3d"], dtype=np.float32)
            ct_r = np.asarray(raw["pred_cam_t"], dtype=np.float32)
            fl_r = float(np.asarray(raw["focal_length"]).reshape(-1)[0])
            x, y, z = kp3d_r[0] + ct_r
            cx = float(kp2d_r[0, 0] - fl_r * x / max(z, 1e-6))
            cy = float(kp2d_r[0, 1] - fl_r * y / max(z, 1e-6))
            break

        def _project_kp(kp3d_local: torch.Tensor) -> torch.Tensor:
            # Pinhole projection: add the per-frame camera translation, divide
            # by depth (clamped away from zero), scale by the per-frame focal
            # length and offset by the recovered principal point.
            kp3d_cam = kp3d_local + cam_t_t.unsqueeze(1)
            u = f_t[:, None] * kp3d_cam[..., 0] / kp3d_cam[..., 2].clamp(min=1e-6) + cx
            v = f_t[:, None] * kp3d_cam[..., 1] / kp3d_cam[..., 2].clamp(min=1e-6) + cy
            return torch.stack([u, v], dim=-1)

        kp2d = _project_kp(kp3d)
        kp2d_face = _project_kp(kp3d_face)

        # Move everything to CPU numpy once, then slice per frame below.
        verts_np = verts.float().cpu().numpy()
        kp3d_np = kp3d.float().cpu().numpy()
        kp2d_np = kp2d.float().cpu().numpy()
        kp3d_face_np = kp3d_face.float().cpu().numpy()
        kp2d_face_np = kp2d_face.float().cpu().numpy()
        jc_np = joint_coords.float().cpu().numpy()
        jrot_np = joint_rotmats.float().cpu().numpy()

        # fi_active walks the batch rows; it advances only on frames that
        # contributed a row, keeping batch index and frame index aligned.
        fi_active = 0
        for fi in range(B):
            if not present[fi]:
                continue
            # Replace the entry with a shallow copy before writing, so the
            # dict object previously stored in this slot is left unmodified.
            pose_frames[fi][pid] = dict(pose_frames[fi][pid])
            p = pose_frames[fi][pid]
            p["pred_vertices"] = verts_np[fi_active]
            p["pred_keypoints_3d"] = kp3d_np[fi_active]
            p["pred_keypoints_2d"] = kp2d_np[fi_active]
            p["pred_face_keypoints_3d"] = kp3d_face_np[fi_active]
            p["pred_face_keypoints_2d"] = kp2d_face_np[fi_active]
            p["pred_joint_coords"] = jc_np[fi_active]
            p["pred_global_rots"] = jrot_np[fi_active]
            fi_active += 1
def rainbow_colors_from_canonical(
    positions: np.ndarray,
    tilt_x_deg: float = 0.0,
    tilt_z_deg: float = 0.0,
) -> np.ndarray:
    """Compute per-vertex jet-colormap RGB from canonical (T-pose, Y-up) vertices.

    Results are memoised in the module-level ``_rainbow_cache`` (at most 32
    entries, oldest evicted first). The returned array is the cached object
    itself, so callers that want to modify it must copy it first.

    Args:
        positions: (N_v, 3) canonical vertex positions, Y-up (head at max Y).
        tilt_x_deg: rotation of the jet axis around X (in degrees). Positive
            biases the ramp toward +Z (front).
        tilt_z_deg: rotation of the jet axis around Z (in degrees). Positive
            biases the ramp toward +X (right, in body frame).

    Returns:
        (N_v, 3) float32 RGB in [0, 1].
    """
    key = (id(positions), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3))
    entry = _rainbow_cache.get(key)
    # id() is only unique among objects that are alive at the same time. Each
    # entry therefore keeps a reference to the array it was computed from
    # (so its id cannot be recycled while cached), and a hit additionally
    # has to be that very object.
    if entry is not None and entry[0] is positions:
        return entry[1]

    theta_x = np.deg2rad(tilt_x_deg)
    theta_z = np.deg2rad(tilt_z_deg)
    # Direction along which the colour ramp runs; +Y when both tilts are 0.
    axis = np.array([
        np.sin(theta_z),
        np.cos(theta_z) * np.cos(theta_x),
        np.cos(theta_z) * np.sin(theta_x),
    ], dtype=np.float32)

    # Project onto the axis and normalise to [0, 1]; the floor on the range
    # avoids dividing by zero for degenerate (flat) inputs.
    s = positions @ axis
    s = (s - s.min()) / max(float(s.max() - s.min()), 1e-8)
    s = np.clip(s * 0.98, 0.0, 1.0).astype(np.float32)
    colors = jet_colormap(s)

    _rainbow_cache[key] = (positions, colors)
    if len(_rainbow_cache) > 32:
        # dicts iterate in insertion order, so the first key is the oldest.
        _rainbow_cache.pop(next(iter(_rainbow_cache)))
    return colors
def _rasterize_chunk(
    fv_pix: torch.Tensor,        # (Fc, 3, 2) — pixel coords (sub-pixel float)
    fv_z: torch.Tensor,          # (Fc, 3) — image-frame z (smaller=closer)
    bb_min_x: torch.Tensor, bb_max_x: torch.Tensor,  # (Fc,) clamped int bboxes
    bb_min_y: torch.Tensor, bb_max_y: torch.Tensor,
    max_sx: int, max_sy: int,
    W: int,
):
    """Sample every face of a chunk at the pixel centres of its bounding box.

    All faces share one ``(max_sy, max_sx)`` sampling grid anchored at each
    face's own bbox minimum; samples beyond a face's own bbox extent are
    masked out. Faces of either winding are accepted (no back-face culling).

    Returns:
        None when the grid is empty or no sample is covered. Otherwise a
        tuple ``(pixel_idx, depth, face_local, bary)`` of flat per-fragment
        tensors: linear pixel index ``y * W + x``, interpolated depth, index
        of the face within this chunk, and the three barycentric weights.
    """
    if max_sx == 0 or max_sy == 0:
        return None
    dev = fv_pix.device

    # Grid offsets and per-face bbox origin/extent, shaped for broadcasting
    # to (Fc, max_sy, max_sx).
    cols = torch.arange(max_sx, device=dev)[None, None, :]
    rows = torch.arange(max_sy, device=dev)[None, :, None]
    origin_x = bb_min_x[:, None, None]
    origin_y = bb_min_y[:, None, None]
    extent_x = (bb_max_x - bb_min_x)[:, None, None]
    extent_y = (bb_max_y - bb_min_y)[:, None, None]

    # Sample at pixel centres.
    sample_x = (origin_x + cols).float() + 0.5
    sample_y = (origin_y + rows).float() + 0.5
    within_bbox = (cols < extent_x) & (rows < extent_y)

    ax, ay = fv_pix[:, 0, 0, None, None], fv_pix[:, 0, 1, None, None]
    bx, by = fv_pix[:, 1, 0, None, None], fv_pix[:, 1, 1, None, None]
    cx, cy = fv_pix[:, 2, 0, None, None], fv_pix[:, 2, 1, None, None]

    def edge(px, py, qx, qy):
        # Twice the signed area of the triangle (sample, p, q).
        return (px - sample_x) * (qy - sample_y) - (py - sample_y) * (qx - sample_x)

    signed_area = (bx - ax) * (cy - ay) - (by - ay) * (cx - ax)  # (Fc, 1, 1)
    edge_a = edge(bx, by, cx, cy)
    edge_b = edge(cx, cy, ax, ay)
    edge_c = edge(ax, ay, bx, by)

    # Near-degenerate triangles are rejected; a sample is inside when all
    # three edge values carry the same sign as the triangle's own area.
    usable = signed_area.abs() > 1e-6
    covered = (
        (edge_a * signed_area >= 0)
        & (edge_b * signed_area >= 0)
        & (edge_c * signed_area >= 0)
    )
    covered = covered & within_bbox & usable
    if not covered.any():
        return None

    inv_area = torch.where(usable, 1.0 / signed_area, torch.zeros_like(signed_area))
    bary_a = edge_a * inv_area
    bary_b = edge_b * inv_area
    bary_c = edge_c * inv_area

    depth = (
        bary_a * fv_z[:, 0, None, None]
        + bary_b * fv_z[:, 1, None, None]
        + bary_c * fv_z[:, 2, None, None]
    )

    face_local, row_idx, col_idx = covered.nonzero(as_tuple=True)
    hit = (face_local, row_idx, col_idx)
    pixel_x = bb_min_x[face_local] + col_idx
    pixel_y = bb_min_y[face_local] + row_idx
    pixel_idx = (pixel_y * W + pixel_x).long()
    bary = torch.stack([bary_a[hit], bary_b[hit], bary_c[hit]], dim=-1)
    return pixel_idx, depth[hit], face_local.long(), bary
+ z_min_ok = 0.05 + valid_v = verts_world[:, 2] > z_min_ok + safe_z = verts_world[:, 2].clamp(min=z_min_ok) + px = 0.5 * W + focal * verts_world[:, 0] / safe_z + py = 0.5 * H + focal * verts_world[:, 1] / safe_z + + Fv_pix = torch.stack([px, py], dim=-1)[faces] # (F, 3, 2) + Fv_z = verts_world[faces][..., 2] # (F, 3) + Fv_valid = valid_v[faces].all(dim=-1) + + sx_face = Fv_pix[..., 0] + sy_face = Fv_pix[..., 1] + bb_min_x = sx_face.amin(dim=-1).floor().long().clamp(min=0, max=W) + bb_max_x = (sx_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=W) + bb_min_y = sy_face.amin(dim=-1).floor().long().clamp(min=0, max=H) + bb_max_y = (sy_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=H) + sx_all = bb_max_x - bb_min_x + sy_all = bb_max_y - bb_min_y + valid_face = Fv_valid & (sx_all > 0) & (sy_all > 0) + + keep = torch.where(valid_face)[0] + if keep.numel() == 0: + return + + # Sort kept faces by max bbox dimension so chunks stay similarly-sized. + bbsize = torch.maximum(sx_all, sy_all)[keep] + order = torch.argsort(bbsize) + keep = keep[order] + n = keep.numel() + sx_cpu = sx_all[keep].tolist() + sy_cpu = sy_all[keep].tolist() + bbsize_cpu = bbsize[order].tolist() + + PIXEL_BUDGET = 4_000_000 + MAX_CHUNK = 8192 + + i = 0 + while i < n: + e = min(i + MAX_CHUNK, n) + # Shrink chunk so worst-case per-face bbox stays within pixel budget. + while e > i + 1: + bb = bbsize_cpu[e - 1] + if (e - i) * bb * bb <= PIXEL_BUDGET: + break + e = max(i + 1, e - max(1, (e - i) // 4)) + chunk = keep[i:e] + max_sx = max(sx_cpu[i:e]) + max_sy = max(sy_cpu[i:e]) + i = e + + result = _rasterize_chunk( + Fv_pix[chunk], Fv_z[chunk], + bb_min_x[chunk], bb_max_x[chunk], + bb_min_y[chunk], bb_max_y[chunk], + max_sx, max_sy, W, + ) + if result is None: + continue + pixel_idx, z_chunk, face_local, bary = result + face_global = chunk[face_local] + + # Atomic depth test against z_buf. 
def _make_shade_fn(
    shader_preset, composite,
    view_normals_v, view_pos_v, vcolor_v, faces,
    base_color, light_dir, pastel_mix,
):
    """Build the per-fragment shading callback used while rasterizing one person.

    The returned callable takes ``(face_idx, bary)`` — face indices into
    ``faces`` and the matching barycentric weights — and returns one RGB row
    per fragment. Which shader is built depends on ``composite`` and
    ``shader_preset``:

    * ``composite == "silhouette"``: constant white.
    * ``"normals"``: view-space normal encoded as RGB.
    * canonical presets with a colour table: ``"rainbow"`` uses a simple
      diffuse term, the other canonical presets use key + fill + rim +
      specular lighting.
    * anything else: ambient + diffuse + rim on ``base_color``.

    ``pastel_mix`` > 0 blends the shaded colour toward white by that amount
    (not applied to the silhouette output).
    """
    device = view_normals_v.device
    albedo = torch.as_tensor(base_color, dtype=torch.float32, device=device)
    incoming = torch.as_tensor(light_dir, dtype=torch.float32, device=device)
    # Unit vector opposite to the light direction — normalized once per call.
    to_light = -incoming
    to_light = to_light / to_light.norm().clamp(min=1e-8)

    if pastel_mix <= 0.0:
        def tint(rgb):
            return rgb
    else:
        blend = float(pastel_mix)

        def tint(rgb):
            return rgb * (1.0 - blend) + blend

    def interpolate(table, face_idx, bary):
        # Barycentric blend of a per-vertex table over the selected faces.
        return (table[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)

    def unit_normal(face_idx, bary):
        raw = interpolate(view_normals_v, face_idx, bary)
        return raw / raw.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    def clamped_dot(normal, direction):
        return (normal * direction).sum(dim=-1).clamp(min=0)

    if composite == "silhouette":
        def shade_silhouette(fi, ba):
            return torch.ones((fi.shape[0], 3), device=device)
        return shade_silhouette

    if shader_preset == "normals":
        # View-space surface normal encoded as RGB (OpenGL Y+ convention).
        # +X right → R; +Y up → G; +Z toward viewer → B.
        def shade_normals(face_idx, bary):
            normal = unit_normal(face_idx, bary)
            return tint(((normal + 1.0) * 0.5).clamp(0.0, 1.0))
        return shade_normals

    has_table = (shader_preset in _CANONICAL_PRESETS) and (vcolor_v is not None)

    if not has_table:
        # default.frag: ambient + diffuse + rim
        def shade_default(face_idx, bary):
            normal = unit_normal(face_idx, bary)
            to_eye = -interpolate(view_pos_v, face_idx, bary)
            to_eye = to_eye / to_eye.norm(dim=-1, keepdim=True).clamp(min=1e-8)
            ndotl = clamped_dot(normal, to_light)
            ndotv = clamped_dot(normal, to_eye)
            rim = (1.0 - ndotv).pow(3.0)
            ambient = 0.25 * albedo
            diffuse = 0.75 * albedo * ndotl.unsqueeze(-1)
            rim_term = 0.35 * rim.unsqueeze(-1)
            return tint(ambient + diffuse + rim_term)
        return shade_default

    if shader_preset == "rainbow":
        def shade_rainbow(face_idx, bary):
            base = interpolate(vcolor_v, face_idx, bary)
            normal = unit_normal(face_idx, bary)
            ndotl = clamped_dot(normal, to_light)
            return tint(base * (0.65 + 0.35 * ndotl).unsqueeze(-1))
        return shade_rainbow

    # rainbow_face_* → rainbow_lit.frag. The key/fill/half vectors depend only
    # on the constant light direction, so they are computed once here.
    key_dir = to_light
    fill_dir = torch.stack([-key_dir[0], key_dir[1].abs(), -key_dir[2]])
    eye_dir = torch.tensor([0.0, 0.0, 1.0], device=device)
    half_dir = key_dir + eye_dir
    half_dir = half_dir / half_dir.norm().clamp(min=1e-8)

    def shade_rainbow_lit(face_idx, bary):
        base = interpolate(vcolor_v, face_idx, bary)
        normal = unit_normal(face_idx, bary)
        key_ndotl = clamped_dot(normal, key_dir)
        fill_ndotl = clamped_dot(normal, fill_dir)
        rim = (1.0 - normal[..., 2].clamp(min=0)).pow(2.5) * 0.30
        shade_val = (0.45 + 0.45 * key_ndotl + 0.15 * fill_ndotl + rim * 0.5).clamp(min=0.0, max=1.25)
        ndoth = clamped_dot(normal, half_dir)
        spec = ndoth.pow(48) * 0.12
        lit = base * shade_val.unsqueeze(-1) + spec.unsqueeze(-1)
        return tint(lit)
    return shade_rainbow_lit
+ + Returns an (H, W, 3) float32 torch.Tensor on the comfy compute device, + ready to be stacked into the node's IMAGE output without a CPU round-trip.""" + device = comfy.model_management.get_torch_device() + persons = pose_data["frames"][frame_idx] if frame_idx < len(pose_data["frames"]) else [] + if len(persons) == 0: + if composite == "over" and background is not None: + if isinstance(background, np.ndarray): + bg = torch.as_tensor(background, dtype=torch.float32, device=device) + else: + bg = background.to(device=device, dtype=torch.float32) if ( + background.device != device or background.dtype != torch.float32 + ) else background + return bg.clamp(0.0, 1.0) + return torch.zeros((H, W, 3), device=device, dtype=torch.float32) + + faces = torch.as_tensor(np.asarray(pose_data["faces"], dtype=np.int64), device=device) + + canonical_colors = pose_data.get("canonical_colors") + using_canonical = shader_preset in _CANONICAL_PRESETS + if using_canonical and canonical_colors is None: + shader_preset = "default" + using_canonical = False + + vcolor = None + if using_canonical: + vcolor_np = _build_vcolor(canonical_colors, shader_preset, + rainbow_tilt_x_deg, rainbow_tilt_z_deg) + vcolor = torch.as_tensor(vcolor_np, dtype=torch.float32, device=device) + + falloff = max(0.0, min(1.0, float(person_brightness_falloff))) + person_pastel = [0.0 if k == 0 else (1.0 - falloff ** k) for k in range(len(persons))] + + # Front-to-back draw order so nearer persons overdraw farther ones. 
+ order = sorted(range(len(persons)), + key=lambda i: -float(np.asarray(persons[i]["pred_cam_t"]).reshape(-1)[2])) + + HW = H * W + z_buf = torch.full((HW,), float('inf'), device=device, dtype=torch.float32) + color_buf = torch.zeros((HW, 3), device=device, dtype=torch.float32) + mask_buf = torch.zeros(HW, device=device, dtype=torch.bool) + + for idx in order: + p = persons[idx] + verts_np = np.asarray(p["pred_vertices"], dtype=np.float32).reshape(-1, 3) + cam_t = np.asarray(p["pred_cam_t"], dtype=np.float32).reshape(3) + verts_world = torch.as_tensor(verts_np + cam_t[None, :], + device=device, dtype=torch.float32) + focal = float(np.asarray(p.get("focal_length", 5000.0)).reshape(-1)[0]) + + # Image-frame (+Y down, +Z forward) → view-space (+Y up, -Z forward) + # for shading, matching what the GL-style shader math expects. + view_pos_v = torch.stack( + [verts_world[:, 0], -verts_world[:, 1], -verts_world[:, 2]], dim=-1, + ) + normals_world = _vertex_normals(verts_world, faces) + view_normals_v = torch.stack( + [normals_world[:, 0], -normals_world[:, 1], -normals_world[:, 2]], dim=-1, + ) + + vcolor_p = vcolor if (vcolor is not None and vcolor.shape[0] == verts_world.shape[0]) else None + # Only canonical-vcolor shaders need vcolor; geometric shaders + # ('normals', 'depth') and the lit default work without it. + if shader_preset in _CANONICAL_PRESETS and vcolor_p is None: + effective_preset = "default" + else: + effective_preset = shader_preset + + shade_fn = _make_shade_fn( + effective_preset, composite, + view_normals_v, view_pos_v, vcolor_p, faces, + base_color, light_dir, person_pastel[idx], + ) + + _rasterize_person( + verts_world, faces, focal, W, H, + z_buf, color_buf, mask_buf, shade_fn, + ) + + # Stay on GPU through readback + composite. 
def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
    """Return the xyxy bounds of a binary mask, ignoring sub-5px speckles.

    Speckles are removed with a 5x5 morphological opening (erosion followed
    by dilation). Erosion on its own would also shave two pixels off every
    side of the regions that survive, making the box tighter than the mask
    it is meant to bound; the dilation grows the survivors back to their
    original outline.

    Args:
        mask: (H, W) mask, or a 3-D mask whose first channel along the last
            axis is used. Values > 0 count as foreground.

    Returns:
        A 4-element float tensor ``[x1, y1, x2, y2]`` with exclusive maxima,
        or None when the mask has no foreground. If the opening removes
        everything, the bounds of the unfiltered mask are returned instead.
    """
    m = mask[..., 0] if mask.dim() == 3 else mask
    m_bool = m > 0
    if not m_bool.any():
        return None
    t = m_bool.to(torch.float32)[None, None]
    # Min-pool written as a negated max-pool. max_pool2d pads with -inf, so
    # out-of-image samples never influence either pooling step.
    eroded = -F.max_pool2d(-t, kernel_size=5, stride=1, padding=2)
    opened = F.max_pool2d(eroded, kernel_size=5, stride=1, padding=2)
    keep = (opened[0, 0] > 0.5) & m_bool
    if not keep.any():
        # Everything was thinner than the window; fall back to the raw mask.
        keep = m_bool
    ys, xs = torch.where(keep)
    return torch.stack([
        xs.min().float(), ys.min().float(),
        (xs.max() + 1).float(), (ys.max() + 1).float(),
    ])
def cam_int_from_moge(moge_geometry, height: int, width: int) -> Optional[torch.Tensor]:
    """Build a (1, 3, 3) pixel-space intrinsic matrix from a MoGe geometry payload.

    The focal length for both axes is MoGe's vertical focal (``K[1, 1]``)
    multiplied by ``height``. The principal point is forced to the image
    centre, overriding MoGe's predicted cx/cy to match prepare_batch's
    convention.

    Returns:
        The float32 matrix, or None when the payload is missing, is not a
        dict, or has no ``"intrinsics"`` entry.
    """
    # MOGE_GEOMETRY is a dict with optional keys (see comfy_extras/nodes_moge.py).
    if not isinstance(moge_geometry, dict):
        return None
    intrinsics = moge_geometry.get("intrinsics")
    if intrinsics is None:
        return None

    # A 3-D input is treated as batched; only its first matrix is used.
    if intrinsics.ndim == 3:
        intrinsics = intrinsics[0]

    # MoGe stores fy in height-units (multiply by H to get pixels); vfov = fy.
    focal = float(intrinsics[1, 1].item()) * height
    center_x = width / 2.0
    center_y = height / 2.0
    rows = [
        [focal, 0.0, center_x],
        [0.0, focal, center_y],
        [0.0, 0.0, 1.0],
    ]
    return torch.tensor([rows], dtype=torch.float32)
+ flat_masks = [] + for f in range(N): + mf = per_frame_masks[f] + if mf.shape[0] == 1 and K > 1: + mf = mf.repeat_interleave(K, dim=0) + flat_masks.extend([mf[k] for k in range(K)]) + masks_stacked = torch.stack(flat_masks, dim=0) + masks_score = torch.ones(total, dtype=torch.float32) + else: + masks_stacked = None + masks_score = None + + batch = prepare_batch( + img_per_crop, boxes_stacked, + input_size=image_size, + masks=masks_stacked, masks_score=masks_score, cam_int=cam_int, + ) + device = comfy.model_management.get_torch_device() + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + inner._initialize_batch(batch) + + outputs = inner.run_inference( + img_per_crop, + batch, + inference_type=inference_type, + thresh_wrist_angle=1.4, + ) + + if inference_type == "full": + pose_output, batch_lhand, batch_rhand, _, _ = outputs + else: + pose_output = outputs + batch_lhand = batch_rhand = None + + out = {k: v.cpu().numpy() for k, v in pose_output["mhr"].items() + if v is not None and k != "faces"} + + # Snapshot batch['bbox'] to CPU before we release `batch` references + batch_bbox_cpu = batch["bbox"][0].cpu().numpy() + lhand_bboxes = rhand_bboxes = None + if inference_type == "full" and batch_lhand is not None and batch_rhand is not None: + lhand_bboxes = [_bbox_from_center_scale(batch_lhand, i) for i in range(total)] + rhand_bboxes = [_bbox_from_center_scale(batch_rhand, i) for i in range(total)] + + del pose_output, batch, batch_lhand, batch_rhand, outputs + + frames_out: List[List[Dict[str, Any]]] = [] + for f in range(N): + persons: List[Dict[str, Any]] = [] + for k in range(K): + idx = f * K + k + p: Dict[str, Any] = { + "bbox": batch_bbox_cpu[idx], + "focal_length": out["focal_length"][idx], + "pred_keypoints_3d": out["pred_keypoints_3d"][idx], + "pred_keypoints_2d": out["pred_keypoints_2d"][idx], + "pred_vertices": out["pred_vertices"][idx], + "pred_cam_t": out["pred_cam_t"][idx], + "pred_pose_raw": 
def run_batched_frames(
    inner: SAM3DBody,
    frames_rgb: List[torch.Tensor],
    per_frame_boxes: List[torch.Tensor],
    per_frame_masks: Optional[List[torch.Tensor]],
    image_size: Tuple[int, int],
    inference_type: str,
    cam_int: Optional[torch.Tensor] = None,
    pbar: Optional[comfy.utils.ProgressBar] = None,
    crops_per_chunk: int = BATCHED_CROPS_PER_CHUNK,
) -> List[List[Dict[str, Any]]]:
    """Run a clip through chunked, batched ``run_batched_single_chunk`` calls.

    Supports K persons per frame; K must be identical across frames (padding
    is the caller's responsibility). Frames are split into chunks so that
    ``chunk_frames * K <= crops_per_chunk`` (with a minimum of one frame per
    chunk); each chunk is one body forward plus optional hand forwards over
    its person-crops.

    Args:
        inner: model object forwarded unchanged to ``run_batched_single_chunk``.
        frames_rgb: one image tensor per frame.
        per_frame_boxes: one box collection per frame, each of length K.
        per_frame_masks: optional per-frame masks, sliced in lockstep with
            the frames; ``None`` disables mask conditioning.
        image_size: forwarded unchanged to ``run_batched_single_chunk``.
        inference_type: forwarded unchanged to ``run_batched_single_chunk``.
        cam_int: optional camera intrinsics, forwarded unchanged.
        pbar: optional ComfyUI progress bar, advanced by the number of frames
            finished after every chunk.
        crops_per_chunk: upper bound on person-crops processed per chunk.

    Returns:
        One list per frame, each holding one result dict per person.

    Raises:
        AssertionError: if the frame list is empty, the per-frame box counts
            differ, or a frame has no boxes.
    """
    N = len(frames_rgb)
    # NOTE: kept as asserts so the exception type seen by existing callers is unchanged.
    assert N > 0, "empty frame list"
    K_set = {len(b) for b in per_frame_boxes}
    assert len(K_set) == 1, f"batched path requires same bbox count per frame, got {K_set}"
    K = K_set.pop()
    assert K >= 1, "need at least one bbox per frame"

    # At least one frame per chunk, even when K alone exceeds the crop budget.
    chunk_frames = max(1, crops_per_chunk // K)
    results: List[List[Dict[str, Any]]] = []
    with tqdm(total=N, desc="SAM3D body inference") as t:
        for start in range(0, N, chunk_frames):
            end = min(N, start + chunk_frames)
            sub_masks = None if per_frame_masks is None else per_frame_masks[start:end]
            chunk_result = run_batched_single_chunk(
                inner, frames_rgb[start:end], per_frame_boxes[start:end], sub_masks,
                image_size, inference_type, K,
                cam_int=cam_int,
            )
            results.extend(chunk_result)

            done = end - start
            t.update(done)
            if pbar is not None:
                pbar.update(done)
            # Drop allocator caches so the next chunk starts from a clean state.
            # The inference device comes from comfy.model_management, so release the
            # cache through it as well instead of assuming a CUDA backend.
            comfy.model_management.soft_empty_cache()
    return results


def _bbox_from_center_scale(batch, idx: int) -> np.ndarray:
    """Return the ``[x0, y0, x1, y1]`` float32 box of crop ``idx``.

    ``batch["bbox_center"]`` and ``batch["bbox_scale"]`` have their first two
    dimensions flattened, so ``idx`` addresses the flattened crop dimension.
    Assumes each flattened entry is an ``(x, y)`` pair — TODO confirm against
    the batch producer.
    """
    # One flatten and one host transfer per tensor (instead of two of each);
    # .tolist() yields the same Python floats as the per-element .item() calls.
    cx, cy = batch["bbox_center"].flatten(0, 1)[idx][:2].tolist()
    sx, sy = batch["bbox_scale"].flatten(0, 1)[idx][:2].tolist()
    half_w, half_h = sx / 2, sy / 2
    return np.array([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dtype=np.float32)

# Wire types and small helpers shared across the SAM 3D Body node modules.
def image_to_uint8(image: torch.Tensor) -> torch.Tensor:
    """ComfyUI image tensor (any shape, float 0..1) -> uint8 tensor in [0, 255] on CPU.

    Values are scaled by 255 and clamped to [0, 255] before the cast, so
    out-of-range inputs saturate instead of wrapping around.
    """
    scaled = (image * 255.0).clamp(0.0, 255.0)
    return scaled.to(dtype=torch.uint8, device="cpu")


def compute_canonical_colors(model) -> Dict[str, np.ndarray]:
    """Canonical rest-pose data for shader color lookups.

    Returns a dict with:
        positions: (Nv, 3) float32 rest-pose vertex positions.
        norm: (Nv, 3) float32 vertex normals remapped from [-1, 1] to [0, 1].
        face_mask: bool mask from ``_compute_face_mask``.
        head_mask: bool mask from a fixed y / |x| box test on the rest pose.
        face_region_rgb: per-vertex painted region color read from
            ``model.head_pose.face_region_rgb``.
    """
    verts = model.head_pose.canonical_vertices().float().cpu().numpy()
    faces = model.head_pose.faces.cpu().numpy()

    # Un-normalized face normals (length is proportional to triangle area), summed
    # onto each of the triangle's three corners and normalized per vertex.
    v0 = verts[faces[:, 0]]
    v1 = verts[faces[:, 1]]
    v2 = verts[faces[:, 2]]
    fn = np.cross(v1 - v0, v2 - v0).astype(np.float32)
    vn = np.zeros_like(verts, dtype=np.float32)
    for corner in range(3):
        np.add.at(vn, faces[:, corner], fn)
    ln = np.linalg.norm(vn, axis=1, keepdims=True)
    ln[ln < 1e-8] = 1.0  # leave (near-)zero normals untouched instead of dividing by ~0
    vn = vn / ln
    norm_map = ((vn + 1.0) * 0.5).astype(np.float32)

    face_mask = _compute_face_mask(model)
    # Head: above jaw-neck (y>1.43) and narrower than shoulders (|x|<0.11).
    # Ears reach |x|≈0.09; shoulders start at |x|≈0.20.
    head_mask = (verts[:, 1] > 1.43) & (np.abs(verts[:, 0]) < 0.11)

    # Painted per-vertex face region RGB ships in the model .safetensors as
    # `head_pose.face_region_rgb` and gets loaded by load_state_dict.
    face_region_rgb = model.head_pose.face_region_rgb.detach().float().cpu().numpy()

    return {
        "positions": verts.astype(np.float32),
        "norm": norm_map,
        "face_mask": face_mask,
        "head_mask": head_mask,
        "face_region_rgb": face_region_rgb,
    }


def _neutral_mhr_kwargs(head) -> Dict[str, torch.Tensor]:
    """All-zero (batch of 1) ``mhr_forward`` parameters on the head's device."""
    device = head.scale_mean.device

    def zeros(*shape: int) -> torch.Tensor:
        return torch.zeros(1, *shape, device=device)

    return {
        "global_trans": zeros(3),
        "global_rot": zeros(3),
        "body_pose_params": zeros(130),
        "hand_pose_params": zeros(head.num_hand_comps * 2),
        "scale_params": zeros(head.num_scale_comps),
        "shape_params": zeros(head.num_shape_comps),
        "expr_params": zeros(head.num_face_comps),
    }


def compute_hand_vert_mask(model, hand_radius_m: float = 0.15, weight_threshold: float = 0.5) -> np.ndarray:
    """(Nv,) bool mask of hand-region verts.

    Picks joints within `hand_radius_m` of the mhr70 hand keypoint clusters
    (indices 21..62), then sums sparse LBS weights across them; verts whose
    summed weight is at least `weight_threshold` are hand verts.
    """
    head = model.head_pose
    mhr = head.mhr

    # Pure geometry query: no autograd graph is needed, and outputs that track
    # gradients could not be converted to NumPy.
    with torch.no_grad():
        out = head.mhr_forward(
            **_neutral_mhr_kwargs(head),
            return_keypoints=True,
            return_joint_coords=True,
        )
    # Output order with these flags: (verts, kp, jcoords). See mhr_head.mhr_forward.
    kp = out[1].float()[0, :70].cpu().numpy()
    jcoords = out[2].float()[0].cpu().numpy()

    # Keypoints 21..41 form one hand cluster and 42..62 the other.
    right_center = kp[21:42].mean(axis=0)
    left_center = kp[42:63].mean(axis=0)
    j_dist_r = np.linalg.norm(jcoords - right_center, axis=1)
    j_dist_l = np.linalg.norm(jcoords - left_center, axis=1)
    is_hand_joint = (j_dist_r < hand_radius_m) | (j_dist_l < hand_radius_m)

    lbs_w = mhr.lbs_skin_weights.detach().float().cpu().numpy()
    lbs_v = mhr.lbs_vert_indices.cpu().numpy()
    lbs_j = mhr.lbs_skin_indices.cpu().numpy()
    is_hand_joint_f = is_hand_joint.astype(np.float32)

    # Scatter-add each sparse (vertex, joint, weight) entry whose joint is a hand joint.
    hand_mass = np.zeros(mhr.NUM_VERTS, dtype=np.float32)
    np.add.at(hand_mass, lbs_v, lbs_w * is_hand_joint_f[lbs_j])
    return hand_mass >= weight_threshold


def _compute_face_mask(model, disp_threshold_m: float = 1e-4) -> np.ndarray:
    """(Nv,) bool mask of verts that move with face expression.

    Sweeps each expression axis (``head.num_face_comps`` of them) at
    coef=+1.0 and flags any vert that moves more than `disp_threshold_m`
    for at least one axis.
    """
    head = model.head_pose
    neutral_kw = _neutral_mhr_kwargs(head)

    # Forward-only sweep: skip autograd bookkeeping for every pass.
    with torch.no_grad():
        v0 = head.mhr_forward(**neutral_kw).float().cpu().numpy()[0]  # (Nv, 3)

        face_mask = np.zeros(v0.shape[0], dtype=bool)
        for axis in range(head.num_face_comps):
            expr = torch.zeros_like(neutral_kw["expr_params"])
            expr[0, axis] = 1.0
            kw = {**neutral_kw, "expr_params": expr}
            v = head.mhr_forward(**kw).float().cpu().numpy()[0]
            face_mask |= np.linalg.norm(v - v0, axis=1) > disp_threshold_m
    return face_mask


def jet_colormap(s: np.ndarray) -> np.ndarray:
    """matplotlib jet, (N,) in [0,1] -> (N, 3) float32 RGB.

    Inputs outside [0, 1] are clipped; each channel is a piecewise-linear
    interpolation over its own list of control points.
    """
    s = np.asarray(s, dtype=np.float32).clip(0.0, 1.0)
    r = np.interp(s, [0.0, 0.35, 0.66, 0.89, 1.0], [0.0, 0.0, 1.0, 1.0, 0.5])
    g = np.interp(s, [0.0, 0.125, 0.375, 0.64, 0.91, 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
    b = np.interp(s, [0.0, 0.11, 0.34, 0.65, 1.0], [0.5, 1.0, 1.0, 0.0, 0.0])
    return np.stack([r, g, b], axis=-1).astype(np.float32)