Initial sam3d body support

This commit is contained in:
kijai 2026-05-26 02:15:15 +03:00
parent 04879a8113
commit 1294041778
24 changed files with 10261 additions and 292 deletions

View File

@ -0,0 +1,377 @@
from typing import Optional
import torch
import torch.nn as nn
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
from ..model.transformer import MLP
class MHRHead(nn.Module):
def __init__(
self,
input_dim: int,
mhr_rig,
mlp_depth: int = 1,
extra_joint_regressor: str = "",
mlp_channel_div_factor: int = 8,
enable_hand_model=False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
# Store the shared MHRRig as a non-registered Python attribute
object.__setattr__(self, "mhr", mhr_rig)
self.num_shape_comps = 45
self.num_scale_comps = 28
self.num_hand_comps = 54
self.num_face_comps = 72
self.enable_hand_model = enable_hand_model
self.body_cont_dim = 260
self.npose = (
6 # Global Rotation
+ self.body_cont_dim # then body
+ self.num_shape_comps
+ self.num_scale_comps
+ self.num_hand_comps * 2
+ self.num_face_comps
)
self.proj = MLP(
input_dim=input_dim,
hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.npose,
num_layers=mlp_depth,
device=device,
dtype=dtype,
operations=operations,
)
# MHR Parameters
self.num_hand_scale_comps = self.num_scale_comps - 18
self.num_hand_pose_comps = self.num_hand_comps
# Buffers populated by load_state_dict from the safetensors
def _p(*shape, dtype=torch.float32):
return nn.Parameter(torch.empty(*shape, dtype=dtype), requires_grad=False)
self.joint_rotation = _p(127, 3, 3)
self.scale_mean = _p(68)
self.scale_comps = _p(28, 68)
self.faces = _p(36874, 3, dtype=torch.int64)
self.hand_pose_mean = _p(54)
self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
self.keypoint_mapping = _p(308, 18439 + 127)
# Some special buffers for the hand-version
self.right_wrist_coords = _p(3)
self.root_coords = _p(3)
self.local_to_world_wrist = _p(3, 3)
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
# Optional — loaded from the .safetensors if present, otherwise the
# render path falls back to a coarse geometric approximation.
self.register_buffer(
"face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32),
)
def canonical_vertices(self, device=None):
"""Return the T-pose vertices for the mean shape (scaled to meters).
Runs MHR with zero pose / shape / scale / expression so the returned
mesh is the canonical rest pose fixed per-model
"""
dev = device or self.scale_mean.device
dt = self.scale_mean.dtype
B = 1
global_trans = torch.zeros(B, 3, device=dev, dtype=dt)
global_rot = torch.zeros(B, 3, device=dev, dtype=dt)
body_pose = torch.zeros(B, 130, device=dev, dtype=dt)
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt)
scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt)
shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt)
expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt)
verts = self.mhr_forward(
global_trans=global_trans,
global_rot=global_rot,
body_pose_params=body_pose,
hand_pose_params=hand_pose,
scale_params=scale,
shape_params=shape,
expr_params=expr,
) # single-tensor shape (1, N_v, 3) in meters
return verts[0]
def get_zero_pose_init(self, factor=1.0):
# Initialize pose token with zero-initialized learnable params
# Note: bias/initial value should be zero-pose in cont, not all-zeros
weights = torch.zeros(1, self.npose)
weights[:, : 6 + self.body_cont_dim] = torch.cat(
[
torch.FloatTensor([1, 0, 0, 0, 1, 0]),
compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
* factor,
],
dim=0,
)
return weights
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
assert full_pose_params.shape[1] == 136
# This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params.
# Split into left and right hands
left_hand_params, right_hand_params = torch.split(
hand_pose_params,
[self.num_hand_pose_comps, self.num_hand_pose_comps],
dim=1,
)
# Change from cont to model params
left_hand_params_model_params = compact_cont_to_model_params_hand(
self.hand_pose_mean
+ torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
)
right_hand_params_model_params = compact_cont_to_model_params_hand(
self.hand_pose_mean
+ torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
)
# Drop it in
full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params
full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params
return full_pose_params # B x 207
def mhr_forward(
self,
global_trans,
global_rot,
body_pose_params,
hand_pose_params,
scale_params,
shape_params,
expr_params=None,
return_keypoints=False,
do_pcblend=True,
return_joint_coords=False,
return_model_params=False,
return_joint_rotations=False,
scale_offsets=None,
vertex_offsets=None,
):
# Align everything to the static buffers
dt = self.scale_mean.dtype
global_trans = global_trans.to(dt)
global_rot = global_rot.to(dt)
body_pose_params = body_pose_params.to(dt)
if hand_pose_params is not None:
hand_pose_params = hand_pose_params.to(dt)
scale_params = scale_params.to(dt)
shape_params = shape_params.to(dt)
if expr_params is not None:
expr_params = expr_params.to(dt)
if self.enable_hand_model:
# Transfer wrist-centric predictions to the body.
global_rot_ori = global_rot.clone()
global_trans_ori = global_trans.clone()
global_rot = rotmat_to_euler(
"xyz",
euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
)
global_trans = (
-(
euler_to_rotmat("xyz", global_rot)
@ (self.right_wrist_coords - self.root_coords)
+ self.root_coords
)
+ global_trans_ori
)
body_pose_params = body_pose_params[..., :130]
# Convert from scale and shape params to actual scales and vertices
# Add singleton batches in case...
if len(scale_params.shape) == 1:
scale_params = scale_params[None]
if len(shape_params.shape) == 1:
shape_params = shape_params[None]
# Convert scale...
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
if scale_offsets is not None:
scales = scales + scale_offsets
# Now, figure out the pose.
## 10 here is because it's more stable to optimize global translation in meters.
full_pose_params = torch.cat(
[global_trans * 10, global_rot, body_pose_params], dim=1
) # B x 127
## Put in hands
if hand_pose_params is not None:
full_pose_params = self.replace_hands_in_pose(
full_pose_params, hand_pose_params
)
model_params = torch.cat([full_pose_params, scales], dim=1)
if self.enable_hand_model:
# Zero out non-hand parameters
model_params[:, self.nonhand_param_idxs] = 0
curr_skinned_verts, curr_skel_state = self.mhr(
shape_params, model_params, expr_params
)
curr_joint_coords, curr_joint_quats, _ = torch.split(
curr_skel_state, [3, 4, 1], dim=2
)
curr_skinned_verts = curr_skinned_verts / 100
curr_joint_coords = curr_joint_coords / 100
curr_joint_rots = unitquat_to_rotmat(curr_joint_quats)
# Prepare returns
to_return = [curr_skinned_verts]
if return_keypoints:
# Get sapiens 308 keypoints
model_vert_joints = torch.cat(
[curr_skinned_verts, curr_joint_coords], dim=1
) # B x (num_verts + 127) x 3
kp_map = self.keypoint_mapping.to(model_vert_joints.dtype)
model_keypoints_pred = (
(kp_map @ model_vert_joints.permute(1, 0, 2).flatten(1, 2))
.reshape(-1, model_vert_joints.shape[0], 3)
.permute(1, 0, 2)
)
if self.enable_hand_model:
# Zero out everything except for the right hand
model_keypoints_pred[:, :21] = 0
model_keypoints_pred[:, 42:] = 0
to_return = to_return + [model_keypoints_pred]
if return_joint_coords:
to_return = to_return + [curr_joint_coords]
if return_model_params:
to_return = to_return + [model_params]
if return_joint_rotations:
to_return = to_return + [curr_joint_rots]
if isinstance(to_return, list) and len(to_return) == 1:
return to_return[0]
else:
return tuple(to_return)
def forward(
self,
x: torch.Tensor,
init_estimate: Optional[torch.Tensor] = None,
do_pcblend=True,
slim_keypoints=False,
intermediate: bool = False,
):
"""
Args:
x: pose token with shape [B, C], usually C=DECODER.DIM
init_estimate: [B, self.npose]
intermediate: when True, the caller only needs the keypoints/pose
outputs needed by the per-layer keypoint-token update path
vertex output is suppressed so `camera_project` skips the
18439-vertex perspective projection on intermediate decoder
layers. The final layer must call with intermediate=False.
"""
batch_size = x.shape[0]
pred = self.proj(x)
if init_estimate is not None:
pred = pred + init_estimate
# From pred, we want to pull out individual predictions.
## First, get globals
### Global rotation is first 6.
count = 6
global_rot_6d = pred[:, :count]
global_rot_rotmat = rot6d_to_rotmat(global_rot_6d) # B x 3 x 3
global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat) # B x 3
global_trans = torch.zeros_like(global_rot_euler)
## Next, get body pose.
### Hold onto raw, continuous version for iterative correction.
pred_pose_cont = pred[:, count : count + self.body_cont_dim]
count += self.body_cont_dim
### Convert to eulers (and trans)
pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
### Zero-out hands
pred_pose_euler[:, mhr_param_hand_mask] = 0
### Zero-out jaw
pred_pose_euler[:, -3:] = 0
## Get remaining parameters
pred_shape = pred[:, count : count + self.num_shape_comps]
count += self.num_shape_comps
pred_scale = pred[:, count : count + self.num_scale_comps]
count += self.num_scale_comps
pred_hand = pred[:, count : count + self.num_hand_comps * 2]
count += self.num_hand_comps * 2
pred_face = pred[:, count : count + self.num_face_comps] * 0
count += self.num_face_comps
# Run everything through mhr
output = self.mhr_forward(
global_trans=global_trans,
global_rot=global_rot_euler,
body_pose_params=pred_pose_euler,
hand_pose_params=pred_hand,
scale_params=pred_scale,
shape_params=pred_shape,
expr_params=pred_face,
do_pcblend=do_pcblend,
return_keypoints=True,
return_joint_coords=True,
return_model_params=True,
return_joint_rotations=True,
)
# Some existing code to get joints and fix camera system
verts, j3d, jcoords, mhr_model_params, joint_global_rots = output
j3d = j3d[:, :70] # 308 --> 70 keypoints
# Intermediate decoder layers only consume pred_keypoints_3d via the
# keypoint-token update path; suppress verts so camera_project skips
# the 18439-vertex perspective projection.
if intermediate:
verts = None
if verts is not None:
verts[..., [1, 2]] *= -1 # Camera system difference
j3d[..., [1, 2]] *= -1 # Camera system difference
if jcoords is not None:
jcoords[..., [1, 2]] *= -1
# Head-MLP outputs are promoted to fp32 here so the external
# pose_output["mhr"] contract has a stable dtype regardless of what
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
# already fp32 from MHR's math layers; the cast on them is a no-op.
output = {
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
"pred_pose_rotmat": None,
"global_rot": global_rot_euler.float(),
"body_pose": pred_pose_euler.float(),
"shape": pred_shape.float(),
"scale": pred_scale.float(),
"hand": pred_hand.float(),
"face": pred_face.float(),
"pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
"pred_vertices": verts.reshape(batch_size, -1, 3) if verts is not None else None,
"pred_joint_coords": jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None,
"faces": self.faces.cpu().numpy(),
"joint_global_rots": joint_global_rots,
"mhr_model_params": mhr_model_params,
}
return output

View File

@ -0,0 +1,235 @@
# Adapted from facebookresearch/MHR (Apache 2.0):
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning.
from .mhr_utils import batch6DFromXYZ
_LN2 = 0.6931471824645996
def _euler_xyz_to_quat(angles):
"""(roll, pitch, yaw) -> quaternion (x, y, z, w). Matches pymomentum.quaternion.euler_xyz_to_quaternion."""
roll, pitch, yaw = angles.unbind(-1)
cy, sy = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5)
cp, sp = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5)
cr, sr = torch.cos(roll * 0.5), torch.sin(roll * 0.5)
x = sr * cp * cy - cr * sp * sy
y = cr * sp * cy + sr * cp * sy
z = cr * cp * sy - sr * sp * cy
w = cr * cp * cy + sr * sp * sy
return torch.stack([x, y, z, w], dim=-1)
def _quat_multiply(q1, q2):
x1, y1, z1, w1 = q1.unbind(-1)
x2, y2, z2, w2 = q2.unbind(-1)
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
return torch.stack([x, y, z, w], dim=-1)
def _quat_rotate(q, v):
"""Rotate v by unit quaternion q (xyzw). v + 2 * (axis x v * w + axis x (axis x v))."""
axis = q[..., :3]
r = q[..., 3:4]
av = torch.cross(axis, v, dim=-1)
aav = torch.cross(axis, av, dim=-1)
return v + 2.0 * (av * r + aav)
def _skel_multiply(s1, s2):
"""Compose two skel states (..., 8). Returns parent ∘ child.
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
before composition. With many FK levels the previously-normalized quats
drift in ULPs; the JIT renormalizes defensively, so we do too to stay
bit-close to its outputs.
"""
t1, sc1 = s1[..., :3], s1[..., 7:8]
t2, sc2 = s2[..., :3], s2[..., 7:8]
q1 = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12)
q2 = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12)
t_res = t1 + sc1 * _quat_rotate(q1, t2)
q_res = _quat_multiply(q1, q2)
s_res = sc1 * sc2
return torch.cat([t_res, q_res, s_res], dim=-1)
def _skel_transform_points(skel_state, points):
"""Apply skel_state (..., 8) to points (..., 3): t + q * (s * points).
Assumes the quaternion in skel_state is already unit-norm. Callers that
can't guarantee that should normalize first.
"""
t = skel_state[..., :3]
q = skel_state[..., 3:7]
s = skel_state[..., 7:8]
return t + _quat_rotate(q, s * points)
def _global_skel_state_from_local(local, pmi_levels):
"""FK walk in fp64 (matches the JIT's use_double_precision=True path).
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
one per BFS level. Avoids per-call torch.split + tolist() sync.
"""
orig_dtype = local.dtype
g = local.to(torch.float64).clone()
for source, target in pmi_levels:
parent = g.index_select(-2, target)
child = g.index_select(-2, source)
g.index_copy_(-2, source, _skel_multiply(parent, child))
return g.to(orig_dtype)
class MHRRig(nn.Module):
"""Plain-PyTorch reimplementation of Meta's MHR rig.
All math runs in fp32 (FK upcast to fp64 internally, matching the JIT's
use_double_precision=True backend) regardless of the host model's dtype.
"""
NUM_VERTS = 18439
NUM_JOINTS = 127
NUM_LBS_TRIPLETS = 51337
NUM_IDENTITY = 45
NUM_EXPR = 72
PARAM_TRANSFORM_IN = 249 # = model_parameters(204) + identity_coeffs(45)
PARAM_TRANSFORM_OUT = 889 # = NUM_JOINTS * 7
POSE_CORR_IN = 750 # = (NUM_JOINTS - 2) * 6
POSE_CORR_HIDDEN = 3000
POSE_CORR_SPARSE_NNZ = 53136
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
del dtype, operations
f32 = torch.float32
# All buffers are populated by load_state_dict from the `mhr.*` keys
def _p(*shape, dtype=f32):
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
def _b(name, *shape, dtype):
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
self.base_shape = _p(self.NUM_VERTS, 3)
self.identity_basis = _p(self.NUM_IDENTITY, self.NUM_VERTS, 3)
self.expr_basis = _p(self.NUM_EXPR, self.NUM_VERTS, 3)
self.param_transform = _p(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN)
self.skel_joint_translation_offsets = _p(self.NUM_JOINTS, 3)
self.skel_joint_prerotations = _p(self.NUM_JOINTS, 4)
_b("skel_joint_parents", self.NUM_JOINTS, dtype=torch.int32)
_b("skel_pmi", 2, 266, dtype=torch.int64)
_b("skel_pmi_buffer_sizes", 4, dtype=torch.int64)
self.lbs_inverse_bind_pose = _p(self.NUM_JOINTS, 8)
self.lbs_skin_weights = _p(self.NUM_LBS_TRIPLETS)
_b("lbs_skin_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int32)
_b("lbs_vert_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int64)
_b("pose_corr_sparse_indices", 2, self.POSE_CORR_SPARSE_NNZ, dtype=torch.int64)
self.pose_corr_sparse_weight = _p(self.POSE_CORR_SPARSE_NNZ)
self.register_buffer("pose_corr_sparse_shape", torch.tensor([self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device))
self.pose_corr_weight = _p(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN)
self.pose_corr_bias = None
self._pose_corr_sparse_cache = None
self._pmi_levels_cache = None
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
f32 = self.base_shape.dtype
identity_coeffs = identity_coeffs.to(f32)
model_parameters = model_parameters.to(f32)
expr_coeffs = expr_coeffs.to(f32)
B = identity_coeffs.shape[0]
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
cat_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1)
joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, cat_in)
jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
local_t = jp[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0)
local_q = _euler_xyz_to_quat(jp[..., 3:6])
local_q = _quat_multiply(self.skel_joint_prerotations.unsqueeze(0), local_q)
local_s = torch.exp(jp[..., 6:7] * _LN2)
local_state = torch.cat([local_t, local_q, local_s], dim=-1)
skel_state = _global_skel_state_from_local(local_state, self._pmi_levels())
face_expr = torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs)
unposed = identity_rest + face_expr
if apply_correctives:
unposed = unposed + self._pose_correctives(joint_parameters)
verts = self._skin(skel_state, unposed)
return verts, skel_state
def _pose_correctives(self, joint_parameters):
B = joint_parameters.shape[0]
jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
# Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6).
feat = batch6DFromXYZ(jp[:, 2:, 3:6], return_9D=False) # (B, 125, 6)
feat[..., 0] -= 1.0
feat[..., 4] -= 1.0
feat = feat.flatten(1, 2) # (B, 750)
h = (self._sparse_w() @ feat.T).T # (B, 3000)
h = F.relu(h)
out = F.linear(h, self.pose_corr_weight, self.pose_corr_bias) # (B, 55317)
return out.view(B, self.NUM_VERTS, 3)
def _pmi_levels(self):
cached = self._pmi_levels_cache
pmi = self.skel_pmi
if cached is not None and cached[0][0].device == pmi.device:
return cached
sizes = self.skel_pmi_buffer_sizes.tolist()
parts = torch.split(pmi, sizes, dim=1)
levels = [(p[0], p[1]) for p in parts]
self._pmi_levels_cache = levels
return levels
def _sparse_w(self):
cached = self._pose_corr_sparse_cache
w = self.pose_corr_sparse_weight
if cached is not None and cached.device == w.device and cached.dtype == w.dtype:
return cached
sparse = torch.sparse_coo_tensor(
self.pose_corr_sparse_indices,
w,
tuple(self.pose_corr_sparse_shape.tolist()),
check_invariants=False,
).coalesce()
self._pose_corr_sparse_cache = sparse
return sparse
def _skin(self, skel_state, rest_verts):
B = skel_state.shape[0]
ibp = self.lbs_inverse_bind_pose.unsqueeze(0).expand(B, self.NUM_JOINTS, 8)
joint_xform = _skel_multiply(skel_state, ibp)
norm_q = F.normalize(joint_xform[..., 3:7], p=2, dim=-1, eps=1e-12)
joint_xform = torch.cat([joint_xform[..., :3], norm_q, joint_xform[..., 7:8]], dim=-1)
sk_idx = self.lbs_skin_indices.long()
v_idx = self.lbs_vert_indices
w = self.lbs_skin_weights
per_triplet_xform = joint_xform.index_select(-2, sk_idx) # (B, 51337, 8)
per_triplet_rest = rest_verts.index_select(-2, v_idx) # (B, 51337, 3)
contrib = _skel_transform_points(per_triplet_xform, per_triplet_rest) * w.unsqueeze(0).unsqueeze(-1)
out = torch.zeros(B, self.NUM_VERTS, 3, dtype=rest_verts.dtype, device=rest_verts.device)
out.index_add_(-2, v_idx, contrib)
return out

View File

@ -0,0 +1,413 @@
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
# representation from Zhou et al., "On the Continuity of Rotation
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
# implementations from papagina/RotationContinuity:
# https://github.com/papagina/RotationContinuity/blob/758b0ce5/shapenet/code/tools.py
# The compact_cont_to_model_params_* functions are MHR-rig-specific glue.
import torch
import torch.nn.functional as F
def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Compute the angle difference (magnitude) between two batches of SO(3) rotation matrices.
Args:
A: Tensor of shape (*, 3, 3), batch of rotation matrices.
B: Tensor of shape (*, 3, 3), batch of rotation matrices.
Returns:
Tensor of shape (*,), angle differences in radians.
"""
# Compute relative rotation matrix
R_rel = torch.matmul(A, B.transpose(-2, -1)) # (B, 3, 3)
# Compute trace of relative rotation
trace = R_rel[..., 0, 0] + R_rel[..., 1, 1] + R_rel[..., 2, 2] # (B,)
# Compute angle using the trace formula
cos_theta = (trace - 1) / 2
# Clamp for numerical stability
cos_theta_clamped = torch.clamp(cos_theta, -1.0, 1.0)
# Compute angle difference
angle = torch.acos(cos_theta_clamped)
return angle
def fix_wrist_euler(
wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
):
"""
wrist_xzy: B x 2 x 3 (X, Z, Y angles)
Returns: Fixed angles within joint limits
"""
x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2]
x_alt = torch.atan2(torch.sin(x + torch.pi), torch.cos(x + torch.pi))
z_alt = torch.atan2(torch.sin(-(z + torch.pi)), torch.cos(-(z + torch.pi)))
y_alt = torch.atan2(torch.sin(y + torch.pi), torch.cos(y + torch.pi))
# Calculate L2 violation distance
def calc_violation(val, limits):
below = torch.clamp(limits[0] - val, min=0.0)
above = torch.clamp(val - limits[1], min=0.0)
return below**2 + above**2
violation_orig = (
calc_violation(x, limits_x)
+ calc_violation(z, limits_z)
+ calc_violation(y, limits_y)
)
violation_alt = (
calc_violation(x_alt, limits_x)
+ calc_violation(z_alt, limits_z)
+ calc_violation(y_alt, limits_y)
)
# Use alternative where it has lower L2 violation
use_alt = violation_alt < violation_orig
# Stack alternative and apply mask
wrist_xzy_alt = torch.stack([x_alt, z_alt, y_alt], dim=-1)
result = torch.where(use_alt.unsqueeze(-1), wrist_xzy_alt, wrist_xzy)
return result
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py
def batch6DFromXYZ(r, return_9D=False):
"""
Generate a matrix representing a rotation defined by a XYZ-Euler
rotation.
Args:
r: ... x 3 rotation vectors
Returns:
... x 6
"""
rc = torch.cos(r)
rs = torch.sin(r)
cx = rc[..., 0]
cy = rc[..., 1]
cz = rc[..., 2]
sx = rs[..., 0]
sy = rs[..., 1]
sz = rs[..., 2]
result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype).to(r.device)
result[..., 0, 0] = cy * cz
result[..., 0, 1] = -cx * sz + sx * sy * cz
result[..., 0, 2] = sx * sz + cx * sy * cz
result[..., 1, 0] = cy * sz
result[..., 1, 1] = cx * cz + sx * sy * sz
result[..., 1, 2] = -sx * cz + cx * sy * sz
result[..., 2, 0] = -sy
result[..., 2, 1] = sx * cy
result[..., 2, 2] = cx * cy
if not return_9D:
return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1)
else:
return result
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
def batchXYZfrom6D(poses):
# Args: poses: ... x 6, where "6" is the combined first and second columns
# First, get the rotaiton matrix
x_raw = poses[..., :3]
y_raw = poses[..., 3:]
x = F.normalize(x_raw, dim=-1)
z = torch.cross(x, y_raw, dim=-1)
z = F.normalize(z, dim=-1)
y = torch.cross(z, x, dim=-1)
matrix = torch.stack([x, y, z], dim=-1) # ... x 3 x 3
# Now get it into euler
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
sy = torch.sqrt(
matrix[..., 0, 0] * matrix[..., 0, 0] + matrix[..., 1, 0] * matrix[..., 1, 0]
)
singular = sy < 1e-6
singular = singular.float()
x = torch.atan2(matrix[..., 2, 1], matrix[..., 2, 2])
y = torch.atan2(-matrix[..., 2, 0], sy)
z = torch.atan2(matrix[..., 1, 0], matrix[..., 0, 0])
xs = torch.atan2(-matrix[..., 1, 2], matrix[..., 1, 1])
ys = torch.atan2(-matrix[..., 2, 0], sy)
zs = matrix[..., 1, 0] * 0
out_euler = torch.zeros_like(matrix[..., 0])
out_euler[..., 0] = x * (1 - singular) + xs * singular
out_euler[..., 1] = y * (1 - singular) + ys * singular
out_euler[..., 2] = z * (1 - singular) + zs * singular
return out_euler
_HAND_DOFS = [3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]
_HAND_MASK_CACHE: dict = {} # device -> dict of masks
def _hand_masks(device):
m = _HAND_MASK_CACHE.get(device)
if m is not None:
return m
mask_cont_threedofs = torch.cat(
[torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
).to(device)
mask_cont_onedofs = torch.cat(
[torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
mask_model_params_threedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]
).to(device)
mask_model_params_onedofs = torch.cat(
[torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]
).to(device)
m = dict(
mask_cont_threedofs=mask_cont_threedofs,
mask_cont_onedofs=mask_cont_onedofs,
mask_model_params_threedofs=mask_model_params_threedofs,
mask_model_params_onedofs=mask_model_params_onedofs,
)
_HAND_MASK_CACHE[device] = m
return m
def compact_cont_to_model_params_hand(hand_cont):
# These are ordered by joint, not model params ^^
assert hand_cont.shape[-1] == 54
m = _hand_masks(hand_cont.device)
mask_cont_threedofs = m["mask_cont_threedofs"]
mask_cont_onedofs = m["mask_cont_onedofs"]
mask_model_params_threedofs = m["mask_model_params_threedofs"]
mask_model_params_onedofs = m["mask_model_params_onedofs"]
# Convert hand_cont to eulers
## First for 3DoFs
hand_cont_threedofs = hand_cont[..., mask_cont_threedofs].unflatten(-1, (-1, 6))
hand_model_params_threedofs = batchXYZfrom6D(hand_cont_threedofs).flatten(-2, -1)
## Next for 1DoFs
hand_cont_onedofs = hand_cont[..., mask_cont_onedofs].unflatten(
-1, (-1, 2)
) # (sincos)
hand_model_params_onedofs = torch.atan2(
hand_cont_onedofs[..., -2], hand_cont_onedofs[..., -1]
)
# Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
hand_model_params = torch.zeros(*hand_cont.shape[:-1], 27).to(hand_cont)
hand_model_params[..., mask_model_params_threedofs] = hand_model_params_threedofs
hand_model_params[..., mask_model_params_onedofs] = hand_model_params_onedofs
return hand_model_params
def compact_model_params_to_cont_hand(hand_model_params):
# These are ordered by joint, not model params ^^
assert hand_model_params.shape[-1] == 27
hand_dofs_in_order = torch.tensor([3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1])
assert sum(hand_dofs_in_order) == 27
# Mask of 3DoFs into hand_cont
mask_cont_threedofs = torch.cat(
[torch.ones(2 * k).bool() * (k in [3]) for k in hand_dofs_in_order]
)
# Mask of 1DoFs (including 2DoF) into hand_cont
mask_cont_onedofs = torch.cat(
[torch.ones(2 * k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
)
# Mask of 3DoFs into hand_model_params
mask_model_params_threedofs = torch.cat(
[torch.ones(k).bool() * (k in [3]) for k in hand_dofs_in_order]
)
# Mask of 1DoFs (including 2DoF) into hand_model_params
mask_model_params_onedofs = torch.cat(
[torch.ones(k).bool() * (k in [1, 2]) for k in hand_dofs_in_order]
)
# Convert eulers to hand_cont hand_cont
## First for 3DoFs
hand_model_params_threedofs = hand_model_params[
..., mask_model_params_threedofs
].unflatten(-1, (-1, 3))
hand_cont_threedofs = batch6DFromXYZ(hand_model_params_threedofs).flatten(-2, -1)
## Next for 1DoFs
hand_model_params_onedofs = hand_model_params[..., mask_model_params_onedofs]
hand_cont_onedofs = torch.stack(
[hand_model_params_onedofs.sin(), hand_model_params_onedofs.cos()], dim=-1
).flatten(-2, -1)
# Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
hand_cont = torch.zeros(*hand_model_params.shape[:-1], 54).to(hand_model_params)
hand_cont[..., mask_cont_threedofs] = hand_cont_threedofs
hand_cont[..., mask_cont_onedofs] = hand_cont_onedofs
return hand_cont
def batch9Dfrom6D(poses):
# Args: poses: ... x 6, where "6" is the combined first and second columns
# First, get the rotaiton matrix
x_raw = poses[..., :3]
y_raw = poses[..., 3:]
x = F.normalize(x_raw, dim=-1)
z = torch.cross(x, y_raw, dim=-1)
z = F.normalize(z, dim=-1)
y = torch.cross(z, x, dim=-1)
matrix = torch.stack([x, y, z], dim=-1).flatten(-2, -1) # ... x 3 x 3 -> x9
return matrix
def batch4Dfrom2D(poses):
# Args: poses: ... x 2, where "2" is sincos
poses_norm = F.normalize(poses, dim=-1)
poses_4d = torch.stack(
[
poses_norm[..., 1],
poses_norm[..., 0],
-poses_norm[..., 0],
poses_norm[..., 1],
],
dim=-1,
) # Flattened SO2.
return poses_4d # .... x 4
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
# fmt: off
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
# fmt: on
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_cont.shape[-1] == (
2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
)
# Get subsets
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
body_cont_1dofs = body_pose_cont[
..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles
]
body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
# Convert conts to model params
## First for 3dofs
body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs).flatten(-2, -1)
## Next for 1dofs
body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs).flatten(-2, -1)
if inflate_trans:
assert (
False
), "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
else:
## Nothing to do for trans
body_rotmat_trans = body_cont_trans
# Put them together
body_rotmat_params = torch.cat(
[body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1
)
return body_rotmat_params
_BODY_IDX_CACHE: dict = {}
def _body_idxs(device):
cached = _BODY_IDX_CACHE.get(device)
if cached is not None:
return cached
# fmt: off
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]).to(device)
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]).to(device)
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]).to(device)
# fmt: on
cached = (
all_param_3dof_rot_idxs,
all_param_1dof_rot_idxs,
all_param_1dof_trans_idxs,
all_param_3dof_rot_idxs.flatten(),
)
_BODY_IDX_CACHE[device] = cached
return cached
def compact_cont_to_model_params_body(body_pose_cont):
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_cont.shape[-1] == 2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
# Get subsets
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
# Convert conts to model params
## First for 3dofs
body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
body_params_3dofs = batchXYZfrom6D(body_cont_3dofs).flatten(-2, -1)
## Next for 1dofs
body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
body_params_1dofs = torch.atan2(body_cont_1dofs[..., -2], body_cont_1dofs[..., -1])
## Nothing to do for trans
body_params_trans = body_cont_trans
# Put them together
body_pose_params = torch.zeros(*body_pose_cont.shape[:-1], 133, dtype=body_pose_cont.dtype, device=body_pose_cont.device)
body_pose_params[..., idxs_3dof_flat] = body_params_3dofs
body_pose_params[..., all_param_1dof_rot_idxs] = body_params_1dofs
body_pose_params[..., all_param_1dof_trans_idxs] = body_params_trans
return body_pose_params
def compact_model_params_to_cont_body(body_pose_params):
# fmt: off
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)])
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123])
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129])
# fmt: on
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
num_1dof_trans = len(all_param_1dof_trans_idxs)
assert body_pose_params.shape[-1] == (
num_3dof_angles + num_1dof_angles + num_1dof_trans
)
# Take out params
body_params_3dofs = body_pose_params[..., all_param_3dof_rot_idxs.flatten()]
body_params_1dofs = body_pose_params[..., all_param_1dof_rot_idxs]
body_params_trans = body_pose_params[..., all_param_1dof_trans_idxs]
# params to cont
body_cont_3dofs = batch6DFromXYZ(body_params_3dofs.unflatten(-1, (-1, 3))).flatten(
-2, -1
)
body_cont_1dofs = torch.stack(
[body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
).flatten(-2, -1)
body_cont_trans = body_params_trans
# Put them together
body_pose_cont = torch.cat(
[body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1
)
return body_pose_cont
# fmt: off
mhr_param_hand_idxs = [62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115]
mhr_cont_hand_idxs = [72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237]
mhr_param_hand_mask = torch.zeros(133).bool()
mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260).bool()
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
# fmt: on

View File

@ -0,0 +1,155 @@
import math
import einops
import torch
import torch.nn.functional as F
from comfy.ldm.cascade.common import LayerNorm2d_op
from torch import nn
from typing import List, Optional, Tuple, Union
from ..utils import perspective_projection
from .transformer import MLP
class CameraEncoder(nn.Module):
def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None):
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
self.conv = operations.Conv2d(embed_dim + 99, embed_dim, kernel_size=1, bias=False, device=device, dtype=dtype)
self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype)
def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor):
B, D, _h, _w = img_embeddings.shape
scale = 1 / self.patch_size
rays = F.interpolate(rays, scale_factor=(scale, scale), mode="bilinear", align_corners=False, antialias=True)
rays = rays.permute(0, 2, 3, 1).contiguous() # [b, h, w, 2]
rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
rays_embeddings = self.camera(pos=rays.reshape(B, -1, 3)) # (bs, N, 99): rays fourier embedding
rays_embeddings = einops.rearrange(rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w).contiguous()
z = torch.cat([img_embeddings, rays_embeddings], dim=1)
return self.norm(self.conv(z))
class FourierPositionEncoding(nn.Module):
"""Sin/cos Fourier features for ray positions"""
def __init__(self, n: int, num_bands: int, max_resolution: int):
super().__init__()
self.num_bands = num_bands
self.max_resolution = [max_resolution] * n
@property
def channels(self):
num_dims = len(self.max_resolution)
encoding_size = self.num_bands * num_dims
encoding_size *= 2 # sin-cos
encoding_size += num_dims # concat
return encoding_size
def forward(self, pos: torch.Tensor):
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
return fourier_pos_enc
def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0):
b, n = pos.shape[:2]
freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype) for res in max_resolution], dim=0)
per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0)
per_pos_features = per_pos_features.reshape(b, n, -1)
# Sin-Cos
per_pos_features = torch.cat([torch.sin(math.pi * per_pos_features), torch.cos(math.pi * per_pos_features)], dim=-1)
# Concat with initial pos
per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
return per_pos_features
class PerspectiveHead(nn.Module):
"""
Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup).
"""
def __init__(self, input_dim: int, img_size: Union[int, Tuple[int, int]], # model input size (W, H)
mlp_depth: int = 1, mlp_channel_div_factor: int = 8, default_scale_factor: float = 1.0,
device=None, dtype=None, operations=None
):
super().__init__()
# Metadata to compute 3D skeleton and 2D reprojection
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.ncam = 3 # (s, tx, ty)
self.default_scale_factor = default_scale_factor
self.proj = MLP(
input_dim=input_dim,
hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.ncam,
num_layers=mlp_depth,
device=device,
dtype=dtype,
operations=operations,
)
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None):
"""
Args:
x: pose token with shape [B, C], usually C=DECODER.DIM
init_estimate: [B, self.ncam]
"""
pred_cam = self.proj(x)
if init_estimate is not None:
pred_cam = pred_cam + init_estimate
return pred_cam
def perspective_projection(
self,
points_3d: torch.Tensor,
pred_cam: torch.Tensor,
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
bbox_size: torch.Tensor, # [N,], in original image space
img_size: torch.Tensor,
cam_int: torch.Tensor, # [B, 3, 3]
use_intrin_center: bool = False,
):
batch_size = points_3d.shape[0]
pred_cam = pred_cam.clone()
pred_cam[..., [0, 2]] *= -1 # Camera system difference
# Compute camera translation: (scale, x, y) --> (x, y, depth)
# depth ~= f / s, Note that f is in the NDC space
s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
bs = bbox_size * s * self.default_scale_factor + 1e-8
focal_length = cam_int[:, 0, 0]
tz = 2 * focal_length / bs
if not use_intrin_center:
cx = 2 * (bbox_center[:, 0] - (img_size[:, 0] / 2)) / bs
cy = 2 * (bbox_center[:, 1] - (img_size[:, 1] / 2)) / bs
else:
cx = 2 * (bbox_center[:, 0] - (cam_int[:, 0, 2])) / bs
cy = 2 * (bbox_center[:, 1] - (cam_int[:, 1, 2])) / bs
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
# Compute camera translation
j3d_cam = points_3d + pred_cam_t.unsqueeze(1)
# Projection to the image plane, note that the projection output is in original image space now.
j2d = perspective_projection(j3d_cam, cam_int)
return {
"pred_keypoints_2d": j2d.reshape(batch_size, -1, 2),
"pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2],
"pred_cam_t": pred_cam_t, "focal_length": focal_length,
}

View File

@ -0,0 +1,250 @@
# DINOv3 ViT-H+ backbone for SAM 3D Body.
#
# Single-file consolidation of the inference path. SAM 3D Body only ships a
# `dinov3_vith16plus` checkpoint, so the architecture is hardcoded rather
# than reconstructed from Hydra-flavoured configs.
#
# Adapted from facebookresearch/dinov3 (DINOv3 License Agreement). Trimmed
# to what's actually exercised at inference: no multi-crop training path,
# no DINOHead, no causal blocks, no rmsnorm/Mlp variants, no rope shift /
# jitter / rescale (training-time augmentations).
#TODO: Unify with TRELLIS2
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from torch import Tensor, nn
# DINOv3 ViT-H+ architecture constants.
EMBED_DIM = 1280
DEPTH = 32
NUM_HEADS = 20
FFN_RATIO = 6.0
PATCH_SIZE = 16
LAYERSCALE_INIT = 1.0e-5
N_STORAGE_TOKENS = 4
LAYERNORM_EPS = 1e-5 # "layernormbf16" preset uses 1e-5
ROPE_BASE = 100.0
# RoPE (axial sin/cos, no learnable weights)
def _rotate_half(x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def _apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
return x * cos + _rotate_half(x) * sin
class RopePositionEmbedding(nn.Module):
"""Axial RoPE for 2D patch grids; periods buffer is deterministic."""
def __init__(self, embed_dim: int, num_heads: int, dtype=torch.float32, device=None):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
D_head = embed_dim // num_heads
# Periods are persistent so they round-trip through state_dict, but the
# values are deterministic from D_head/base; load_state_dict will
# overwrite this with the saved buffer either way.
periods = ROPE_BASE ** (
2 * torch.arange(D_head // 4, dtype=dtype, device=device) / (D_head // 2)
)
self.register_buffer("periods", periods, persistent=True)
self._dtype = dtype
def forward(self, H: int, W: int) -> Tuple[Tensor, Tensor]:
device, dtype = self.periods.device, self._dtype
coords_h = torch.arange(0.5, H, device=device, dtype=dtype) / H
coords_w = torch.arange(0.5, W, device=device, dtype=dtype) / W
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = 2.0 * coords.flatten(0, 1) - 1.0 # [HW, 2] in [-1, +1]
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
angles = angles.flatten(1, 2).tile(2) # [HW, D_head]
return torch.sin(angles), torch.cos(angles)
def _apply_rope_to_qk(q: Tensor, k: Tensor, rope: Tuple[Tensor, Tensor]):
"""Apply RoPE only to the patch-token slice (skip CLS + storage tokens)."""
sin, cos = rope
rope_dtype = sin.dtype
q_dtype, k_dtype = q.dtype, k.dtype
q = q.to(rope_dtype)
k = k.to(rope_dtype)
prefix = q.shape[-2] - sin.shape[-2]
q_pre, q_rope = q[..., :prefix, :], q[..., prefix:, :]
k_pre, k_rope = k[..., :prefix, :], k[..., prefix:, :]
q = torch.cat([q_pre, _apply_rope(q_rope, sin, cos)], dim=-2)
k = torch.cat([k_pre, _apply_rope(k_rope, sin, cos)], dim=-2)
return q.to(q_dtype), k.to(k_dtype)
# Layers
class LayerScale(nn.Module):
def __init__(self, dim: int, init_values: float, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(
torch.full((dim,), init_values, device=device, dtype=dtype)
)
def forward(self, x: Tensor) -> Tensor:
return x * self.gamma
class SwiGLUFFN(nn.Module):
"""w3(silu(w1(x)) * w2(x))."""
def __init__(self, in_features: int, hidden_features: int, align_to: int = 8,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
d = int(hidden_features * 2 / 3)
h = d + (-d % align_to)
self.w1 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype)
self.w2 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype)
self.w3 = ops.Linear(h, in_features, bias=True, device=device, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
return self.w3(F.silu(self.w1(x)) * self.w2(x))
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.num_heads = num_heads
# DINOv3's `mask_k_bias` zeroes the K third of qkv.bias. The mask is
# deterministic from out_features, so the loader applies it in-place
# once after load_state_dict (see `apply_dinov3_qkv_bias_mask`) and the
# forward stays a plain F.linear.
self.qkv = ops.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype)
def forward(self, x: Tensor, rope: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = qkv.unbind(2)
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
if rope is not None:
q, k = _apply_rope_to_qk(q, k, rope)
# low_precision_attention=False forces attention_sage (when enabled
# globally in comfy) to fall back to pytorch SDPA. SAM 3D Body's
# regression heads (camera projection, MHR rig math) are sensitive
# to attention output precision; sage's int8/fp8 path drifts the
# keypoints and mesh visibly.
x = optimized_attention(
q, k, v, self.num_heads, skip_reshape=True,
low_precision_attention=False,
)
return self.proj(x)
class Block(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_ratio: float,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.norm1 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype)
self.attn = SelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations)
self.ls1 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype)
self.mlp = SwiGLUFFN(dim, int(dim * ffn_ratio), device=device, dtype=dtype, operations=operations)
self.ls2 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype)
def forward(self, x: Tensor, rope=None) -> Tensor:
x = x + self.ls1(self.attn(self.norm1(x), rope=rope))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
def __init__(self, in_chans=3, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.proj = ops.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size,
device=device, dtype=dtype,
)
# Encoder + wrapper
class _DinoEncoder(nn.Module):
"""Inner ViT module. Held under `Dinov3Backbone.encoder` so state_dict
keys (`backbone.encoder.*`) match the upstream layout."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.patch_size = PATCH_SIZE
self.embed_dim = EMBED_DIM
self.patch_embed = PatchEmbed(
embed_dim=EMBED_DIM, patch_size=PATCH_SIZE,
device=device, dtype=dtype, operations=operations,
)
self.cls_token = nn.Parameter(torch.empty(1, 1, EMBED_DIM, device=device, dtype=dtype))
self.storage_tokens = nn.Parameter(
torch.empty(1, N_STORAGE_TOKENS, EMBED_DIM, device=device, dtype=dtype)
)
# The released config sets pos_embed_rope_dtype="fp32"; periods stays
# in fp32 regardless of the backbone weight dtype.
self.rope_embed = RopePositionEmbedding(EMBED_DIM, NUM_HEADS, dtype=torch.float32, device=device)
self.blocks = nn.ModuleList([
Block(EMBED_DIM, NUM_HEADS, FFN_RATIO, device=device, dtype=dtype, operations=operations)
for _ in range(DEPTH)
])
self.norm = ops.LayerNorm(EMBED_DIM, eps=LAYERNORM_EPS, device=device, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
x = self.patch_embed.proj(x) # (B, embed_dim, H, W)
B, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # (B, H*W, embed_dim)
# Prepend CLS + storage tokens.
x = torch.cat([
self.cls_token.expand(B, -1, -1),
self.storage_tokens.expand(B, -1, -1),
x,
], dim=1)
rope = self.rope_embed(H=H, W=W)
for blk in self.blocks:
x = blk(x, rope)
x = self.norm(x)
# Drop CLS + storage tokens; reshape patch grid to (B, C, H, W).
x = x[:, 1 + N_STORAGE_TOKENS :]
return x.reshape(B, H, W, EMBED_DIM).permute(0, 3, 1, 2).contiguous()
class Dinov3Backbone(nn.Module):
"""Public backbone interface used by SAM3DBody."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
self.encoder = _DinoEncoder(device=device, dtype=dtype, operations=operations)
self.patch_size = PATCH_SIZE
self.embed_dim = self.embed_dims = EMBED_DIM
def forward(self, x: Tensor) -> Tensor:
return self.encoder(x)
def apply_dinov3_qkv_bias_mask(backbone: "Dinov3Backbone") -> None:
"""Zero the K third of every block's qkv.bias in-place.
Implements DINOv3's `mask_k_bias` once at load time so the per-block forward
stays a plain F.linear instead of cloning + slicing the bias every call.
"""
for blk in backbone.encoder.blocks:
qkv = blk.attn.qkv
if qkv.bias is not None:
o = qkv.out_features
qkv.bias.data[o // 3 : 2 * o // 3] = 0

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,272 @@
"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them
through a cross-attention transformer decoder over (token, image) pairs.
Both adapted from the SAM-style prompt path (Meta, Apache 2.0):
https://github.com/facebookresearch/segment-anything
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
from comfy.ldm.cascade.common import LayerNorm2d_op
from comfy.ldm.sam3.sam import PositionEmbeddingRandom
from .transformer import TransformerDecoderLayer
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
num_body_joints: int,
device=None,
dtype=None,
operations=None,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
"""
super().__init__()
ops = operations if operations is not None else nn
self.embed_dim = embed_dim
self.num_body_joints = num_body_joints
# Keypoint prompts
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.point_embeddings = nn.ModuleList(
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
)
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
LN2d = LayerNorm2d_op(ops)
mask_in_chans = 256
self.mask_downscaling = nn.Sequential(
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans, device=device, dtype=dtype),
nn.GELU(),
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
)
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
return self.pe_layer(size)
def _embed_keypoints(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Embeds point prompts.
Assuming points have been normalized to [0, 1].
Output shape [B, N, C], mask shape [B, N]
"""
assert points.min() >= 0 and points.max() <= 1
# PE compute in fp32 for precision (sin/cos of large coords), then cast back to the embedding weight dtype
weight_dtype = self.invalid_point_embed.weight.dtype
point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype)
point_embedding[labels == -2] = 0.0 # invalid points
point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding)
for i in range(self.num_body_joints):
point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding)
point_mask = labels > -2
return point_embedding, point_mask
def _get_batch_size(self, keypoints: Optional[torch.Tensor], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor]) -> int:
if keypoints is not None:
return keypoints.shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def forward(
self,
keypoints: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
keypoints (torchTensor or none): point coordinates and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(keypoints, boxes, masks)
# Anchor device on the input prompts so we don't pull the offloaded
# CPU embedding device under dynamic loading.
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
weight_dtype = self.invalid_point_embed.weight.dtype
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=device, dtype=weight_dtype)
sparse_masks = torch.empty((bs, 0), device=device)
if keypoints is not None:
coords = keypoints[:, :, :2]
labels = keypoints[:, :, -1]
point_embeddings, point_mask = self._embed_keypoints(coords, labels)
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)
return sparse_embeddings, sparse_masks
def get_mask_embeddings(
self,
masks: Optional[torch.Tensor] = None,
bs: int = 1,
size: Tuple[int, int] = (16, 16), # [H, W]
) -> torch.Tensor:
"""Embeds mask inputs."""
# masks is always on the active device when present; fall back to the
# downscaling Conv's weight device when it isn't (rare callers).
ref = masks if masks is not None else next(self.mask_downscaling.parameters())
no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
bs, -1, size[0], size[1]
)
if masks is not None:
mask_embeddings = self.mask_downscaling(masks)
else:
mask_embeddings = no_mask_embeddings
return mask_embeddings, no_mask_embeddings
class PromptableDecoder(nn.Module):
"""Cross-attention transformer decoder over (token, image) pairs."""
def __init__(
self,
dims: int,
context_dims: int,
depth: int,
num_heads: int = 8,
head_dims: int = 64,
mlp_dims: int = 1024,
repeat_pe: bool = False,
do_interm_preds: bool = False,
keypoint_token_update: bool = False,
device=None,
dtype=None,
operations=None,
):
super().__init__()
ops = operations if operations is not None else nn
self.layers = nn.ModuleList(
TransformerDecoderLayer(
token_dims=dims,
context_dims=context_dims,
num_heads=num_heads,
head_dims=head_dims,
mlp_dims=mlp_dims,
repeat_pe=repeat_pe,
skip_first_pe=(i == 0),
device=device,
dtype=dtype,
operations=operations,
)
for i in range(depth)
)
self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
self.do_interm_preds = do_interm_preds
self.keypoint_token_update = keypoint_token_update
def forward(
self,
token_embedding: torch.Tensor,
image_embedding: torch.Tensor,
token_augment: Optional[torch.Tensor] = None,
image_augment: Optional[torch.Tensor] = None,
token_mask: Optional[torch.Tensor] = None,
token_to_pose_output_fn=None,
keypoint_token_update_fn=None,
hand_embeddings=None,
hand_augment=None,
):
"""
Args:
token_embedding: [B, N, C]
image_embedding: [B, C, H, W] -- flattened to [B, HW, C] inline
"""
# Channels-last for the transformer.
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
if image_augment is not None:
image_augment = image_augment.flatten(2).permute(0, 2, 1)
if hand_embeddings is not None:
hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
if len(hand_augment) == 1:
# inflate batch dimension
assert len(hand_augment.shape) == 3
hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)
all_pose_outputs = [] if self.do_interm_preds else None
if self.do_interm_preds:
assert token_to_pose_output_fn is not None
layer_idx = 0
for layer_idx, layer in enumerate(self.layers):
if hand_embeddings is None:
token_embedding, image_embedding = layer(
token_embedding, image_embedding,
token_augment, image_augment, token_mask,
)
else:
token_embedding, image_embedding = layer(
token_embedding,
torch.cat([image_embedding, hand_embeddings], dim=1),
token_augment,
torch.cat([image_augment, hand_augment], dim=1),
token_mask,
)
image_embedding = image_embedding[:, : image_augment.shape[1]]
if self.do_interm_preds and layer_idx < len(self.layers) - 1:
curr = token_to_pose_output_fn(
self.norm_final(token_embedding),
prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
layer_idx=layer_idx,
)
all_pose_outputs.append(curr)
if self.keypoint_token_update:
assert keypoint_token_update_fn is not None
token_embedding, token_augment, _, _ = keypoint_token_update_fn(
token_embedding, token_augment, curr, layer_idx,
)
out = self.norm_final(token_embedding)
if self.do_interm_preds:
curr = token_to_pose_output_fn(
out,
prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
layer_idx=layer_idx,
)
all_pose_outputs.append(curr)
return out, all_pose_outputs
return out

View File

@ -0,0 +1,104 @@
from typing import Optional
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act_layer=nn.ReLU, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList(
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
for i in range(num_layers)
)
self.act = act_layer()
def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < len(self.layers) - 1 else layer(x)
return x
class Attention(nn.Module):
def __init__(self, embed_dims, num_heads, query_dims=None, key_dims=None, value_dims=None, qkv_bias=True, proj_bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.query_dims = query_dims or embed_dims
self.key_dims = key_dims or embed_dims
self.value_dims = value_dims or embed_dims
self.embed_dims = embed_dims
self.num_heads = num_heads
self.head_dims = embed_dims // num_heads
lin = lambda i, o, b: operations.Linear(i, o, bias=b, device=device, dtype=dtype)
self.q_proj = lin(self.query_dims, embed_dims, qkv_bias)
self.k_proj = lin(self.key_dims, embed_dims, qkv_bias)
self.v_proj = lin(self.value_dims, embed_dims, qkv_bias)
self.proj = lin(embed_dims, self.query_dims, proj_bias)
def _split(self, x: torch.Tensor) -> torch.Tensor:
b, n, _ = x.shape
return x.reshape(b, n, self.num_heads, self.head_dims).transpose(1, 2)
def forward(self, q, k, v, attn_mask: Optional[torch.Tensor] = None):
q, k, v = self._split(self.q_proj(q)), self._split(self.k_proj(k)), self._split(self.v_proj(v))
x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True, low_precision_attention=False)
return self.proj(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, token_dims, context_dims, num_heads=8, head_dims=64, mlp_dims=1024,
repeat_pe=False, skip_first_pe=False, device=None, dtype=None, operations=None):
super().__init__()
self.repeat_pe = repeat_pe
self.skip_first_pe = skip_first_pe
ln = lambda d: operations.LayerNorm(d, eps=1e-6, device=device, dtype=dtype)
attn_dim = num_heads * head_dims
attn_kwargs = dict(embed_dims=attn_dim, num_heads=num_heads, device=device, dtype=dtype, operations=operations)
if repeat_pe:
self.ln_pe_1, self.ln_pe_2 = ln(token_dims), ln(context_dims)
self.ln1 = ln(token_dims)
self.self_attn = Attention(query_dims=token_dims, key_dims=token_dims, value_dims=token_dims, **attn_kwargs)
self.ln2_1, self.ln2_2 = ln(token_dims), ln(context_dims)
self.cross_attn = Attention(query_dims=token_dims, key_dims=context_dims, value_dims=context_dims, **attn_kwargs)
self.ln3 = ln(token_dims)
self.ffn = MLP(token_dims, mlp_dims, token_dims, num_layers=2, act_layer=nn.GELU, device=device, dtype=dtype, operations=operations)
def forward(self, x, context, x_pe=None, context_pe=None, x_mask=None):
"""x: [B, N_tokens, C], context: [B, N_ctx, C], x_mask: [B, N_tokens] or None."""
# LaPE-style PE re-norm per layer.
if self.repeat_pe and context_pe is not None:
x_pe = self.ln_pe_1(x_pe)
context_pe = self.ln_pe_2(context_pe)
# Self-attn over tokens.
if self.repeat_pe and not self.skip_first_pe and x_pe is not None:
q = k = self.ln1(x) + x_pe
v = self.ln1(x)
else:
q = k = v = self.ln1(x)
attn_mask = None
if x_mask is not None:
attn_mask = x_mask[:, :, None] @ x_mask[:, None, :]
attn_mask.diagonal(dim1=1, dim2=2).fill_(1) # avoid all-invalid rows -> nan
attn_mask = attn_mask > 0
x = x + self.self_attn(q, k, v, attn_mask=attn_mask)
# Cross-attn: tokens attend to image context.
if self.repeat_pe and context_pe is not None:
q = self.ln2_1(x) + x_pe
k = self.ln2_2(context) + context_pe
v = self.ln2_2(context)
else:
q = self.ln2_1(x)
k = v = self.ln2_2(context)
x = x + self.cross_attn(q, k, v)
x = x + self.ffn(self.ln3(x))
return x, context

View File

@ -0,0 +1,341 @@
# The bbox/affine math (xyxy<->cs, get_warp_matrices) is the standard
# top-down pose-estimation crop pipeline from MMPose (Apache 2.0):
# https://github.com/open-mmlab/mmpose — same algorithm as UDP (CVPR 2020).
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
# Bbox + affine math
# All `output_size` / image-shape tuples in this block are (H, W) to match
# the torch.Size convention used everywhere else in the codebase.
def bbox_xyxy2cs(bbox, padding: float) -> Tuple[torch.Tensor, torch.Tensor]:
"""xyxy bbox -> (center, scale) with optional padding multiplier."""
bbox = torch.as_tensor(bbox, dtype=torch.float32)
dim = bbox.dim()
if dim == 1:
bbox = bbox.unsqueeze(0)
x1, y1, x2, y2 = bbox[:, 0:1], bbox[:, 1:2], bbox[:, 2:3], bbox[:, 3:4]
center = torch.cat([x1 + x2, y1 + y2], dim=1) * 0.5
scale = torch.cat([x2 - x1, y2 - y1], dim=1) * padding
if dim == 1:
return center[0], scale[0]
return center, scale
def fix_aspect_ratio(bbox_scale, aspect_ratio: float) -> torch.Tensor:
"""Pad whichever side is too narrow to hit `aspect_ratio` (w/h)."""
bbox_scale = torch.as_tensor(bbox_scale, dtype=torch.float32)
dim = bbox_scale.dim()
if dim == 1:
bbox_scale = bbox_scale.unsqueeze(0)
w, h = bbox_scale[:, 0:1], bbox_scale[:, 1:2]
out = torch.where(
w > h * aspect_ratio,
torch.cat([w, w / aspect_ratio], dim=1),
torch.cat([h * aspect_ratio, h], dim=1),
)
return out[0] if dim == 1 else out
def get_warp_matrices(centers, scales, output_size: Tuple[int, int]) -> torch.Tensor:
"""Batched 2x3 affine matrices mapping each (center, scale) bbox region to
the output box. `output_size` is (H_out, W_out). With rot=0 the MMPose
3-point fit reduces to a closed-form isotropic scale + translate.
"""
centers = torch.as_tensor(centers, dtype=torch.float32)
scales = torch.as_tensor(scales, dtype=torch.float32)
if centers.dim() == 1:
centers = centers.unsqueeze(0)
scales = scales.unsqueeze(0)
n = centers.shape[0]
src_w = scales[:, 0]
dst_h = float(output_size[0])
dst_w = float(output_size[1])
# With rot=0 the warp is just scale + translate (uniform x/y scale based
# on src_w/dst_w). The closed form drops out of MMPose's 3-point solve.
s = dst_w / src_w # (N,)
mats = torch.zeros((n, 2, 3), dtype=torch.float32)
mats[:, 0, 0] = s
mats[:, 1, 1] = s
mats[:, 0, 2] = dst_w * 0.5 - s * centers[:, 0]
mats[:, 1, 2] = dst_h * 0.5 - s * centers[:, 1]
return mats # (N, 2, 3)
def warp_affine_batched(
src_t: torch.Tensor, # (N, C, H_src, W_src) float
mats: torch.Tensor, # (N, 2, 3) float
output_size: Tuple[int, int] # (H_out, W_out)
) -> torch.Tensor:
"""Apply N forward (src->dst) 2x3 affine warps to N source images in one
grid_sample call. Kept generic over arbitrary affines (not specialized to
the scale+translate produced by `get_warp_matrices`) so callers can pass
rotated/sheared affines; the per-crop 3x3 invert is O(N) of trivial work."""
H_out, W_out = int(output_size[0]), int(output_size[1])
N, _, H_src, W_src = src_t.shape
device = src_t.device
# Invert each forward affine; grid_sample needs dst->src.
mats_t = mats.to(device=device, dtype=torch.float32)
bottom = torch.tensor([0.0, 0.0, 1.0], device=device).expand(N, 1, 3)
mats_3 = torch.cat([mats_t, bottom], dim=1) # (N, 3, 3)
mats_inv = torch.linalg.inv(mats_3)[:, :2, :] # (N, 2, 3)
# Output pixel-center grid (i+0.5, j+0.5).
ys, xs = torch.meshgrid(
torch.arange(H_out, dtype=torch.float32, device=device) + 0.5,
torch.arange(W_out, dtype=torch.float32, device=device) + 0.5,
indexing="ij",
)
homo = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1) # (H_out, W_out, 3)
src_pos = torch.einsum("nkl,ijl->nijk", mats_inv, homo) # (N, H_out, W_out, 2)
# Normalize to [-1, 1] grid_sample coords (align_corners=False).
src_pos[..., 0] = src_pos[..., 0] / W_src * 2 - 1
src_pos[..., 1] = src_pos[..., 1] / H_src * 2 - 1
return F.grid_sample(src_t, src_pos, mode="bilinear", padding_mode="zeros", align_corners=False)
# Batch construction (one prediction over N person crops from a single image)
def prepare_batch(
img, # (H, W, 3) uint8 torch tensor or list of such tensors
boxes, # (N, 4) xyxy (numpy or torch)
input_size: Tuple[int, int], # (W, H) of the model crop
bbox_padding: float = 1.25, # xyxy->cs padding multiplier (1.25 body, 0.9 hand)
aspect_ratio: float = 0.75, # w/h of the crop (0.75 matches HMR2/Sapiens)
masks=None, # optional per-person masks
masks_score=None, # optional per-person mask scores
cam_int=None, # optional camera intrinsics
) -> Dict:
"""Build the batch dict the SAM3DBody forward expects, doing the N crops in one batched `grid_sample` call."""
is_multi_image = isinstance(img, list)
if is_multi_image:
assert len(img) == boxes.shape[0]
height, width = img[0].shape[:2]
else:
height, width = img.shape[:2]
n = int(boxes.shape[0])
assert n > 0, "prepare_batch needs at least one box"
W_out, H_out = int(input_size[0]), int(input_size[1])
# Per-box bbox math (cheap, vectorized, CPU).
centers, scales = bbox_xyxy2cs(boxes, padding=bbox_padding)
# Two passes: first hits the upstream bbox aspect (e.g. 0.75 HMR2/Sapiens
# convention), second pads further if the model crop's W_out/H_out differs
# from that. When they match (common case) the second call is a no-op.
scales = fix_aspect_ratio(scales, aspect_ratio)
scales = fix_aspect_ratio(scales, W_out / H_out)
mats = get_warp_matrices(centers, scales, (H_out, W_out)) # (N, 2, 3)
# Stack source images into a contiguous (N, 3, H, W) tensor on CPU.
if is_multi_image:
src_t = torch.stack(list(img), dim=0)
else:
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
src_t = src_t.permute(0, 3, 1, 2).contiguous().float() # (N, 3, H, W) in [0, 255]
warped_t = warp_affine_batched(src_t, mats, (H_out, W_out)) # (N, 3, H_out, W_out)
# Float warp -> floor (matches the legacy uint8 round-trip) -> /255.
img_t = torch.floor(warped_t).clamp_(0.0, 255.0) / 255.0 # (N, 3, H_out, W_out) in [0, 1]
# Masks: zero-init when missing, otherwise stack and warp through the same matrices.
boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
if masks is None:
mask_t = torch.zeros((n, H_out, W_out), dtype=torch.float32)
mask_score_t = torch.zeros((n,), dtype=torch.float32)
else:
# masks is an array of N items, each (H, W) or (H, W, 1).
masks_t = torch.stack([torch.as_tensor(masks[i]) for i in range(n)], dim=0)
if masks_t.dim() == 4 and masks_t.shape[-1] == 1:
masks_t = masks_t[..., 0]
masks_src_t = masks_t.float().unsqueeze(1) # (N, 1, H, W) in [0, 255]
warped_masks = warp_affine_batched(masks_src_t, mats, (H_out, W_out))
mask_t = torch.floor(warped_masks.squeeze(1)).clamp_(0.0, 255.0)
if masks_score is not None:
mask_score_t = torch.as_tensor([masks_score[i] for i in range(n)], dtype=torch.float32)
else:
mask_score_t = torch.ones((n,), dtype=torch.float32)
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()
batch = {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
"ori_img_size": ori_img_size_t.unsqueeze(0),# (1, N, 2)
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)
"affine_trans": mats.unsqueeze(0), # (1, N, 2, 3)
"mask": mask_t.unsqueeze(0).unsqueeze(2), # (1, N, 1, H_out, W_out)
"mask_score": mask_score_t.unsqueeze(0), # (1, N)
"person_valid": torch.ones((1, n), dtype=torch.float32),
}
if cam_int is not None:
batch["cam_int"] = cam_int.to(batch["img"])
else:
# Default intrinsics: focal = sqrt(W^2 + H^2), principal point = image center.
f = (height ** 2 + width ** 2) ** 0.5
batch["cam_int"] = torch.tensor(
[[[f, 0, width / 2.0], [0, f, height / 2.0], [0, 0, 1]]],
).to(batch["img"])
return batch
# Geometry utils
def rot6d_to_rotmat(
x: torch.Tensor # (B, 6) batch of 6-D rotation representations.
) -> torch.Tensor: # (B, 3, 3) rotation matrices.
"""6D continuous rotation rep (Zhou et al., CVPR 2019) -> 3x3 rotation matrix."""
x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
a1, a2 = x[:, :, 0], x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
b3 = torch.linalg.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)
def perspective_projection(
x: torch.Tensor, # (B, N, 3) 3D points in camera coords.
K: torch.Tensor # (B, 3, 3) camera intrinsics.
) -> torch.Tensor: # (B, N, 2) 2D image-plane projections.
"""Project 3D points (already in camera frame) through intrinsics K."""
y = x / x[:, :, -1].unsqueeze(-1) # perspective divide
y = torch.einsum("bij,bkj->bki", K, y) # apply intrinsics
return y[:, :, :2]
# Rotation conversions, behavior mirrors the roma library (https://github.com/naver/roma)
def _axis_rotmat(axis: str, angle: torch.Tensor) -> torch.Tensor:
"""Rotation matrices around a single coordinate axis. Shape (..., 3, 3)."""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
flat = (one, zero, zero,
zero, cos, -sin,
zero, sin, cos)
elif axis == "Y":
flat = (cos, zero, sin,
zero, one, zero,
-sin, zero, cos)
elif axis == "Z":
flat = (cos, -sin, zero,
sin, cos, zero,
zero, zero, one)
else:
raise ValueError(f"Invalid axis {axis!r}; expected X/Y/Z.")
return torch.stack(flat, dim=-1).reshape(angle.shape + (3, 3))
def euler_to_rotmat(convention: str, angles: torch.Tensor) -> torch.Tensor:
"""Euler angles -> rotation matrix, matching roma's case-keyed convention."""
axes = convention.upper()
R0 = _axis_rotmat(axes[0], angles[..., 0])
R1 = _axis_rotmat(axes[1], angles[..., 1])
R2 = _axis_rotmat(axes[2], angles[..., 2])
if convention.islower():
return R2 @ R1 @ R0
return R0 @ R1 @ R2
def _index_from_letter(letter: str) -> int:
return {"X": 0, "Y": 1, "Z": 2}[letter]
def _angle_from_tan(
axis: str,
other_axis: str,
data: torch.Tensor,
horizontal: bool,
tait_bryan: bool,
) -> torch.Tensor:
"""Extract an outer Euler angle from a row/column of a rotation matrix.
Adapted from PyTorch3D's matrix_to_euler_angles helper.
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ("XY", "YZ", "ZX")
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _matrix_to_euler_intrinsic(matrix: torch.Tensor, convention: str) -> torch.Tensor:
"""Decompose a rotation matrix into intrinsic Euler angles (uppercase abc).
Adapted from PyTorch3D's matrix_to_euler_angles.
"""
i0 = _index_from_letter(convention[0])
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
sign = -1.0 if (i0 - i2) in (-1, 2) else 1.0
central = torch.asin(matrix[..., i0, i2] * sign)
else:
central = torch.acos(matrix[..., i0, i0])
out = (
_angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan),
central,
_angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan),
)
return torch.stack(out, dim=-1)
def rotmat_to_euler(convention: str, matrix: torch.Tensor) -> torch.Tensor:
"""Rotation matrix -> Euler angles, inverse of :func:`euler_to_rotmat`.
PyTorch3D's matrix_to_euler_angles uses the convention R = R_a R_b R_c for
convention "abc"; that matches roma's UPPERCASE ordering directly. For
roma's lowercase, the matrix is reversed (R_c R_b R_a), so we decompose
with the reversed convention and flip the angles back to axis order.
"""
if matrix.shape[-2:] != (3, 3):
raise ValueError(f"Expected (..., 3, 3) rotation matrix, got {tuple(matrix.shape)}.")
if convention.isupper():
return _matrix_to_euler_intrinsic(matrix, convention)
decomposed = _matrix_to_euler_intrinsic(matrix, convention.upper()[::-1])
return decomposed.flip(-1)
def unitquat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
"""Unit quaternion (x, y, z, w) -> rotation matrix.
Matches roma.unitquat_to_rotmat (scalar-last). The quaternion is assumed to be normalized.
Args:
quat: (..., 4) unit quaternion.
Returns:
(..., 3, 3) rotation matrix.
"""
x, y, z, w = quat.unbind(dim=-1)
tx, ty, tz = 2 * x, 2 * y, 2 * z
twx, twy, twz = tx * w, ty * w, tz * w
txx, txy, txz = tx * x, ty * x, tz * x
tyy, tyz, tzz = ty * y, tz * y, tz * z
one = torch.ones_like(w)
flat = (
one - (tyy + tzz), txy - twz, txz + twy,
txy + twz, one - (txx + tzz), tyz - twx,
txz - twy, tyz + twx, one - (txx + tyy),
)
return torch.stack(flat, dim=-1).reshape(quat.shape[:-1] + (3, 3))

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import math
from functools import lru_cache
from typing import List, Tuple
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -558,31 +558,47 @@ def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -
class FaceLandmarker(nn.Module):
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
(128², 2m) or 'full' (192² FPN, 5m). State dict uses inner-module prefixes
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
(128², 2m), 'full' (192² FPN, 5m), or 'both' which holds both detectors
alongside shared mesh/blendshapes so detect_batch can run them and per-frame
keep whichever found more faces (tie short). State dict uses inner-module
prefixes `detector.*` (short or single-variant), `detector_full.*` (only
under 'both'), `mesh.*`, `blendshapes.*`; the outer FaceLandmarkerModel
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
"""
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
super().__init__()
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
self.detector_variant = detector_variant
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
if detector_variant == "both":
self.detector = BlazeFace(device=device, dtype=dtype, operations=operations)
self.detector_full = BlazeFaceFullRange(device=device, dtype=dtype, operations=operations)
else:
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}[detector_variant]
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
def _detector(self, variant: str) -> nn.Module:
if self.detector_variant == "both":
return self.detector_full if variant == "full" else self.detector
return self.detector
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
score_thresh: float = _BF_MIN_SCORE,
iou_thresh: float = 0.5):
iou_thresh: float = 0.5,
variant: Optional[str] = None):
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.
`variant` overrides per-call (required on 'both'-mode instances)."""
if not images_rgb_uint8:
return [], [], [], []
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
if variant is None:
variant = "short" if self.detector_variant == "both" else self.detector_variant
detector = self._detector(variant)
device, dtype = detector.stem.weight.device, detector.stem.weight.dtype
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
if self.detector_variant == "full"
if variant == "full"
else (_BF_INPUT_SIZE, _decode_blazeface))
# Same-size frames: stack once and transfer once. Variable size falls back
@ -598,7 +614,7 @@ class FaceLandmarker(nn.Module):
det_crops = [w[0] for w in warps]
sub_rects = [(w[1], w[2], w[3]) for w in warps]
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
regs_b, cls_b = detector(torch.stack(det_crops, dim=0))
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
per_frame = []
for b in range(len(images_rgb_uint8)):
@ -607,15 +623,22 @@ class FaceLandmarker(nn.Module):
return img_raws, sub_rects, sizes, per_frame
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
score_thresh: float = _BF_MIN_SCORE,
variant: Optional[str] = None) -> List[List[dict]]:
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
list per image (empty if nothing detected). Face dict:
bbox_xyxy (4,) image pixels, blendshapes {52} [0,1],
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
192-canonical (pre-transformation) units, presence float (raw logit).
On 'both'-mode instances `variant=None` runs both detectors and keeps
whichever found more faces per frame (tie short).
"""
if variant is None and self.detector_variant == "both":
s = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="short")
f = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="full")
return [s[b] if len(s[b]) >= len(f[b]) else f[b] for b in range(len(images_rgb_uint8))]
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
images_rgb_uint8, score_thresh=score_thresh,
images_rgb_uint8, score_thresh=score_thresh, variant=variant,
)
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
for b, decoded in enumerate(per_frame_dets):

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,6 @@
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node."""
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB
node, plus pose-data exporters (BuildPoseGLB / SavePoseBVH) that accept either
SAM3DBody Predict's MHR pose data or external-rig pose data from Kimodo."""
import json
import logging
@ -15,6 +17,15 @@ import folder_paths
from comfy.cli_args import args
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_extras.sam3d_body.export.bvh import build_bvh
from comfy_extras.sam3d_body.export.glb_openpose import build_glb_openpose
from comfy_extras.sam3d_body.export.glb_skeletal import build_glb_skeletal
MHRPoseData = IO.Custom("MHR_POSE_DATA")
KimodoPoseData = IO.Custom("KIMODO_POSE_DATA")
SAM3DBodyModel = IO.Custom("SAM3D_BODY_MODEL")
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
@ -386,10 +397,437 @@ class SaveGLB(IO.ComfyNode):
return IO.NodeOutput(ui={"3d": results})
def rainbow_tilt_inputs():
"""Shared rainbow-shader tilt inputs (used by Render and ToGLB schemas)."""
return [
IO.Float.Input(
"rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5,
tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
),
IO.Float.Input(
"rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5,
tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
),
]
class BuildPoseGLB(IO.ComfyNode):
"""Convert pose_data to an in-memory animated GLB"""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BuildPoseGLB",
display_name="Build Pose GLB",
description="Convert pose data to an animated GLB",
category="3d",
inputs=[
IO.MultiType.Input(
"pose_data", types=[MHRPoseData, KimodoPoseData],
tooltip=(
"MHR pose data from SAM3DBody_Predict, or external-rig "
"pose data from Kimodo (`_skeleton_override`-augmented)."
),
),
SAM3DBodyModel.Input("sam3d_body_model", optional=True),
IO.DynamicCombo.Input(
"mesh_style",
options=[
IO.DynamicCombo.Option("body_mesh", [
IO.Int.Input(
"bone_smooth_window",
default=0, min=0, max=51, step=2,
tooltip=(
"Gaussian window on per-bone rotation keyframes. 0 = off. "
"7-15 helps cartwheels/spins where upstream Smooth misses spikes."
),
),
IO.DynamicCombo.Input(
"bone_vis",
options=[
IO.DynamicCombo.Option("off", []),
IO.DynamicCombo.Option("octahedrons", [
IO.Float.Input(
"bone_vis_radius_m",
default=0.02, min=0.005, max=0.5, step=0.005,
tooltip="Radius in m (sphere radius / octahedron half-width).",
),
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip=(
"Per-bone vertex colors (unlit material). "
"'white' = none, 'rainbow_y' = head→toe jet."
),
),
]),
IO.DynamicCombo.Option("sticks", [
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip="Per-bone vertex colors (see octahedrons).",
),
]),
],
tooltip=(
"Bone vis shape, rigidly skinned to each joint. "
"'octahedrons' = Blender-style directional bones (joint → "
"primary child); 'sticks' = thin lines."
),
),
IO.DynamicCombo.Input(
"shader",
options=[
IO.DynamicCombo.Option("default", []),
IO.DynamicCombo.Option("rainbow", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.",
),
]),
IO.DynamicCombo.Option("rainbow_face_normal", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.",
),
]),
IO.DynamicCombo.Option("rainbow_face_semantic", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.",
),
]),
],
tooltip=(
"Bake per-vertex colors matching the Render node's shaders "
"(COLOR_0 + KHR_materials_unlit). 'default' = no colors."
),
),
]),
IO.DynamicCombo.Option("bones_only", [
IO.Int.Input(
"bone_smooth_window",
default=0, min=0, max=51, step=2,
tooltip=(
"Gaussian window on per-bone rotation keyframes. 0 = off. "
"7-15 helps cartwheels/spins where upstream Smooth misses spikes."
),
),
IO.DynamicCombo.Input(
"bone_vis",
options=[
IO.DynamicCombo.Option("octahedrons", [
IO.Float.Input(
"bone_vis_radius_m",
default=0.02, min=0.005, max=0.5, step=0.005,
tooltip="Radius in m (sphere radius / octahedron half-width).",
),
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip=(
"Per-bone vertex colors (unlit material). "
"'white' = none, 'rainbow_y' = head→toe jet."
),
),
]),
IO.DynamicCombo.Option("sticks", [
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip="Per-bone vertex colors (see octahedrons).",
),
]),
],
tooltip=(
"Bone vis shape, rigidly skinned to each joint. "
"'octahedrons' = Blender-style directional bones (joint → "
"primary child); 'sticks' = thin lines."
),
),
]),
IO.DynamicCombo.Option("openpose", [
IO.Float.Input(
"marker_radius_m", default=0.010, min=0.005, max=0.1, step=0.001,
tooltip="Sphere radius in m.",
),
IO.Float.Input(
"stick_radius_m", default=0.008, min=0.002, max=0.05, step=0.001,
tooltip="Limb half-width in m. Auto-clamped to bone_length x 0.1.",
),
IO.Boolean.Input(
"include_hands", default=False,
tooltip=(
"Append 21+21 OpenPose hands (wrist + 5 fingers x 4 joints, "
"base→tip) sourced from pred_keypoints_3d."
),
),
IO.Float.Input(
"hand_marker_radius_m", default=0.005, min=0.001, max=0.1, step=0.001,
tooltip="Hand sphere radius in m.",
),
IO.Float.Input(
"hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001,
tooltip="Hand limb half-width in m.",
),
IO.Combo.Input(
"face_source",
options=["off", "rig"],
default="off",
tooltip=(
"'rig' adds ~30 face-contour landmarks sampled from pred_vertices "
"at fixed head-mesh vertex IDs (brow/eyes/nose/mouth/jaw); needs "
"canonical_colors on pose_data."
),
),
IO.Float.Input(
"face_marker_radius_m", default=0.0, min=0.0, max=0.05, step=0.0005,
tooltip="Face dot radius. 0 = auto = 0.3 x marker_radius_m.",
),
]),
IO.DynamicCombo.Option("scail", [
IO.Float.Input(
"stick_radius_m", default=0.022, min=0.002, max=0.1, step=0.001,
tooltip=(
"Cylinder radius in m. Bones are open cylinders at constant "
"radius; joint spheres (auto-sized to match) cap the open ends. "
"SCAIL reference = 0.0215 m."
),
),
IO.Float.Input(
"marker_radius_m", default=0.0, min=0.0, max=0.1, step=0.001,
tooltip="Joint sphere radius. 0 = auto = stick_radius_m (flush cap).",
),
IO.Float.Input(
"material_roughness", default=0.3, min=0.0, max=1.0, step=0.05,
tooltip="PBR roughness. SCAIL ref = 0.3. 1 = matte; 0 = chrome.",
),
IO.Boolean.Input(
"include_hands", default=False,
tooltip="Append 21+21 hand keypoints + capsule sticks per track.",
),
IO.Float.Input(
"hand_marker_radius_m", default=0.005, min=0.001, max=0.05, step=0.001,
tooltip="Hand sphere radius in m.",
),
IO.Float.Input(
"hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001,
tooltip="Hand cylinder radius in m.",
),
]),
],
tooltip=(
"'body_mesh' = real Armature (127 bones, skinning, TRS "
"keyframes, 72 face morphs; needs model). "
"'bones_only' = bone-shape primitives at each joint (preview armature). "
"'openpose' = OpenPose-18 3D skeleton from keypoints "
"(no model needed). 'scail' = SCAIL 3D capsule rig (open "
"cylinders capped flush by joint spheres)."
),
),
IO.Float.Input(
"fps", default=24.0, min=1.0, max=240.0, step=1.0,
tooltip="Animation frame rate.",
),
IO.Combo.Input(
"camera_translation",
options=["off", "centered", "absolute"],
default="off",
tooltip=(
"Bake pred_cam_t into per-track root translation. "
"'off' = origin; 'centered' = delta from frame 0; "
"'absolute' = raw (Z is camera depth — usually meters away)."
),
),
IO.Int.Input(
"track_index", default=-1, min=-1, max=15,
tooltip="-1 = all tracks; ≥0 = single track.",
),
],
outputs=[IO.File3DGLB.Output("glb")],
)
@classmethod
def execute(cls, pose_data, mesh_style, sam3d_body_model=None, fps=24.0, camera_translation="off", track_index=-1) -> IO.NodeOutput:
mesh_style = mesh_style or {"mesh_style": "body_mesh"}
mode_key = mesh_style["mesh_style"]
# `shader` is nested in body_mesh; absent for bones_only.
shader_dict = mesh_style.get("shader") or {}
shader_key = shader_dict.get("shader", "default")
common = dict(
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
shader=str(shader_key),
rainbow_tilt_x_deg=float(shader_dict.get("rainbow_tilt_x", 0.0)),
rainbow_tilt_z_deg=float(shader_dict.get("rainbow_tilt_z", 0.0)),
person_palette_falloff=float(shader_dict.get("person_palette_falloff", 0.6)),
)
if mode_key in ("body_mesh", "bones_only"):
# External rigs (e.g. ComfyUI-Kimodo) supply pose_data["_skeleton_override"]
# so the GLB writer reads rig/bind/skin from there instead of MHR.
has_external_rig = isinstance(pose_data, dict) and ("_skeleton_override" in pose_data)
if sam3d_body_model is None and not has_external_rig:
raise ValueError(
f"BuildPoseGLB: '{mode_key}' mode needs the `sam3d_body_model` input OR a "
"`_skeleton_override` dict in pose_data. Connect the SAM3DBody model "
"or feed pose_data from a node that supplies the override (e.g. KimodoSample)."
)
default_shape = "off" if mode_key == "body_mesh" else "octahedrons"
bone_vis_dict = mesh_style.get("bone_vis", {"bone_vis": default_shape})
bone_vis = str(bone_vis_dict.get("bone_vis", default_shape))
bone_vis_radius_m = float(bone_vis_dict.get("bone_vis_radius_m", 0.04))
bone_vis_color = str(bone_vis_dict.get("bone_vis_color", "white"))
glb_bytes = build_glb_skeletal(
pose_data, sam3d_body_model,
bone_smooth_window=int(mesh_style.get("bone_smooth_window", 0)),
bone_vis=bone_vis,
bone_vis_radius_m=bone_vis_radius_m,
bone_vis_color=bone_vis_color,
include_body_mesh=(mode_key == "body_mesh"),
**common,
)
elif mode_key == "openpose":
# Rig-independent: sourced from pred_keypoints_3d. face_source='rig'
# additionally reads canonical_colors for head-mesh vertex IDs.
glb_bytes = build_glb_openpose(
pose_data,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
marker_radius_m=float(mesh_style.get("marker_radius_m", 0.025)),
stick_radius_m=float(mesh_style.get("stick_radius_m", 0.008)),
include_hands=bool(mesh_style.get("include_hands", False)),
hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
face_source=str(mesh_style.get("face_source", "off")),
face_marker_radius_m=float(mesh_style.get("face_marker_radius_m", 0.0)),
palette="openpose",
shape="ellipsoid",
)
elif mode_key == "scail":
# SCAIL rig: open cylinders capped flush by joint spheres (sphere
# radius defaults to cylinder radius for a seamless silhouette).
cap_stick_radius = float(mesh_style.get("stick_radius_m", 0.022))
cap_marker_radius = float(mesh_style.get("marker_radius_m", 0.0))
if cap_marker_radius <= 0.0:
cap_marker_radius = cap_stick_radius
glb_bytes = build_glb_openpose(
pose_data,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
marker_radius_m=cap_marker_radius,
stick_radius_m=cap_stick_radius,
include_hands=bool(mesh_style.get("include_hands", False)),
hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
face_source="off",
palette="scail",
shape="capsule",
smooth_shade=True,
# SCAIL material: slightly glossy (0.3) + double-sided so the
# inside of the open cylinders shades sensibly at grazing angles.
material_roughness=float(mesh_style.get("material_roughness", 0.3)),
material_double_sided=True,
)
else:
raise ValueError(f"BuildPoseGLB: unknown mesh_style {mode_key!r}")
return IO.NodeOutput(Types.File3D(BytesIO(glb_bytes), file_format="glb"))
class SavePoseBVH(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SavePoseBVH",
description="Save pose data as BVH mocap file",
display_name="Save Pose BVH",
category="3d",
is_output_node=True,
inputs=[
IO.MultiType.Input(
"pose_data", types=[MHRPoseData, KimodoPoseData],
tooltip=(
"MHR pose data from SAM3DBody_Predict, or external-rig "
"pose data from Kimodo."
),
),
SAM3DBodyModel.Input("sam3d_body_model"),
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
IO.Float.Input(
"fps", default=24.0, min=1.0, max=240.0, step=1.0,
tooltip="Animation frame rate (BVH `Frame Time`).",
),
IO.Combo.Input(
"camera_translation",
options=["off", "centered", "absolute"],
default="off",
tooltip=(
"Bake pred_cam_t into the root's position channels. "
"'off' = bind position; 'centered' = delta from frame 0; "
"'absolute' = raw (Z is camera depth — usually meters away)."
),
),
IO.Combo.Input(
"units",
options=["cm", "m"],
default="cm",
tooltip="BVH OFFSET/position units. 'cm' is the mocap standard.",
),
IO.Int.Input(
"track_index", default=0, min=0, max=15,
tooltip="Track to export. BVH carries one skeleton; export multi-person clips one at a time.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
outputs=[],
)
@classmethod
def execute(cls, pose_data, sam3d_body_model, filename_prefix="3d/ComfyUI",
fps=24.0, camera_translation="off", units="cm",
track_index=0) -> IO.NodeOutput:
bvh_bytes = build_bvh(
pose_data, sam3d_body_model,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
units=str(units),
)
full_output_folder, filename, counter, subfolder, _ = \
folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(),
)
f = f"{filename}_{counter:05}_.bvh"
out_path = os.path.join(full_output_folder, f)
with open(out_path, "wb") as fh:
fh.write(bvh_bytes)
return IO.NodeOutput(ui={"3d": [{
"filename": f,
"subfolder": subfolder,
"type": "output",
}]})
class Save3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [SaveGLB]
return [SaveGLB, BuildPoseGLB, SavePoseBVH]
async def comfy_entrypoint() -> Save3DExtension:

View File

@ -2,11 +2,10 @@ import torch
import comfy.utils
import comfy.model_management
import numpy as np
import math
import colorsys
from tqdm import tqdm
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.pose.keypoint_draw import KeypointDraw
from comfy_extras.nodes_lotus import LotusConditioning
@ -73,281 +72,6 @@ def _to_openpose_frames(all_keypoints, all_scores, height, width):
return frames
class KeypointDraw:
"""
Pose keypoint drawing class that supports both numpy and cv2 backends.
"""
def __init__(self):
try:
import cv2
self.draw = cv2
except ImportError:
self.draw = self
# Hand connections (same for both hands)
self.hand_edges = [
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
[0, 5], [5, 6], [6, 7], [7, 8], # index
[0, 9], [9, 10], [10, 11], [11, 12], # middle
[0, 13], [13, 14], [14, 15], [15, 16], # ring
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
]
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
self.body_limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
[1, 16], [16, 18]
]
# Colors matching DWPose
self.colors = [
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
]
@staticmethod
def circle(canvas_np, center, radius, color, **kwargs):
"""Draw a filled circle using NumPy vectorized operations."""
cx, cy = center
h, w = canvas_np.shape[:2]
radius_int = int(np.ceil(radius))
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
if y_max <= y_min or x_max <= x_min:
return
y, x = np.ogrid[y_min:y_max, x_min:x_max]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
"""Draw line using Bresenham's algorithm with NumPy operations."""
x0, y0, x1, y1 = *pt1, *pt2
h, w = canvas_np.shape[:2]
dx, dy = abs(x1 - x0), abs(y1 - y0)
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
err, x, y, line_points = dx - dy, x0, y0, []
while True:
line_points.append((x, y))
if x == x1 and y == y1:
break
e2 = 2 * err
if e2 > -dy:
err, x = err - dy, x + sx
if e2 < dx:
err, y = err + dx, y + sy
if thickness > 1:
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
for px, py in line_points:
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
if y_max > y_min and x_max > x_min:
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
else:
line_points = np.array(line_points)
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
if (valid_points := line_points[valid]).size:
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
@staticmethod
def fillConvexPoly(canvas_np, pts, color, **kwargs):
"""Fill polygon using vectorized scanline algorithm."""
if len(pts) < 3:
return
pts = np.array(pts, dtype=np.int32)
h, w = canvas_np.shape[:2]
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
if y_max <= y_min or x_max <= x_min:
return
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
for i in range(len(pts)):
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
y1, y2 = p1[1], p2[1]
if y1 == y2:
continue
if y1 > y2:
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
continue
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
"""Python implementation of cv2.ellipse2Poly."""
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
angle = angle % 360
if arc_start > arc_end:
arc_start, arc_end = arc_end, arc_start
while arc_start < 0:
arc_start, arc_end = arc_start + 360, arc_end + 360
while arc_end > 360:
arc_end, arc_start = arc_end - 360, arc_start - 360
if arc_end - arc_start > 360:
arc_start, arc_end = 0, 360
angle_rad = math.radians(angle)
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
pts = []
for i in range(arc_start, arc_end + delta, delta):
theta_rad = math.radians(min(i, arc_end))
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
unique_pts, prev_pt = [], (float('inf'), float('inf'))
for pt in pts:
if (pt_tuple := tuple(pt)) != prev_pt:
unique_pts.append(pt)
prev_pt = pt_tuple
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
"""
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
Expected keypoint format (after neck insertion and remapping):
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
- Foot: 18-23 (6 keypoints)
- Face: 24-91 (68 landmarks)
- Right hand: 92-112 (21 keypoints)
- Left hand: 113-133 (21 keypoints)
Args:
canvas: The canvas to draw on (numpy array)
keypoints: Array of keypoint coordinates
scores: Optional confidence scores for each keypoint
threshold: Minimum confidence threshold for drawing keypoints
Returns:
canvas: The canvas with keypoints drawn
"""
H, W, C = canvas.shape
# Draw body limbs
if draw_body and len(keypoints) >= 18:
for i, limb in enumerate(self.body_limbSeq):
# Convert from 1-indexed to 0-indexed
idx1, idx2 = limb[0] - 1, limb[1] - 1
if idx1 >= 18 or idx2 >= 18:
continue
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
Y = [keypoints[idx1][0], keypoints[idx2][0]]
X = [keypoints[idx1][1], keypoints[idx2][1]]
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
if length < 1:
continue
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)])
# Draw body keypoints
if draw_body and len(keypoints) >= 18:
for i in range(18):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw foot keypoints (18-23, 6 keypoints)
if draw_feet and len(keypoints) >= 24:
for i in range(18, 24):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw right hand (92-112)
if draw_hands and len(keypoints) >= 113:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 92 + edge[0], 92 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw right hand keypoints
for i in range(92, 113):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw left hand (113-133)
if draw_hands and len(keypoints) >= 134:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 113 + edge[0], 113 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw left hand keypoints
for i in range(113, 134):
if scores is not None and i < len(scores) and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw face keypoints (24-91) - white dots only, no lines
if draw_face and len(keypoints) >= 92:
eps = 0.01
for i in range(24, 92):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
return canvas
class SDPoseDrawKeypoints(io.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -0,0 +1,348 @@
"""Pose keypoint drawing primitives shared across pose nodes.
`KeypointDraw` exposes a cv2-or-numpy backend so callers can use one drawing
API regardless of whether OpenCV is installed:
kd = KeypointDraw()
kd.draw.circle(canvas, (x, y), radius, color, thickness=-1)
kd.draw.line(canvas, p1, p2, color, thickness=4)
kd.draw.fillConvexPoly(canvas, polygon, color)
kd.draw.ellipse2Poly(center, axes, angle, 0, 360, 1)
It also carries DWPose's body/hand topology + color tables, used by:
- comfy_extras.nodes_sdpose (SDPose pose drawing)
- comfy_extras.pose.export.openpose_2d (SAM 3D Body 2D pose viz)
- comfy_extras.pose.export.glb_shared (SAM 3D Body GLB tables)
"""
import colorsys
import math
import numpy as np
class KeypointDraw:
"""
Pose keypoint drawing class that supports both numpy and cv2 backends.
"""
def __init__(self):
try:
import cv2
self.draw = cv2
except ImportError:
self.draw = self
# Hand connections (same for both hands)
self.hand_edges = [
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
[0, 5], [5, 6], [6, 7], [7, 8], # index
[0, 9], [9, 10], [10, 11], [11, 12], # middle
[0, 13], [13, 14], [14, 15], [15, 16], # ring
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
]
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
self.body_limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
[1, 16], [16, 18]
]
# Colors matching DWPose
self.colors = [
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
]
@staticmethod
def circle(canvas_np, center, radius, color, **kwargs):
"""Draw a filled circle using NumPy vectorized operations."""
cx, cy = center
h, w = canvas_np.shape[:2]
radius_int = int(np.ceil(radius))
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
if y_max <= y_min or x_max <= x_min:
return
y, x = np.ogrid[y_min:y_max, x_min:x_max]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
"""Draw line using Bresenham's algorithm with NumPy operations."""
x0, y0, x1, y1 = *pt1, *pt2
h, w = canvas_np.shape[:2]
dx, dy = abs(x1 - x0), abs(y1 - y0)
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
err, x, y, line_points = dx - dy, x0, y0, []
while True:
line_points.append((x, y))
if x == x1 and y == y1:
break
e2 = 2 * err
if e2 > -dy:
err, x = err - dy, x + sx
if e2 < dx:
err, y = err + dx, y + sy
if thickness > 1:
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
for px, py in line_points:
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
if y_max > y_min and x_max > x_min:
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
else:
line_points = np.array(line_points)
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
if (valid_points := line_points[valid]).size:
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
@staticmethod
def fillConvexPoly(canvas_np, pts, color, **kwargs):
"""Fill polygon using vectorized scanline algorithm."""
if len(pts) < 3:
return
pts = np.array(pts, dtype=np.int32)
h, w = canvas_np.shape[:2]
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
if y_max <= y_min or x_max <= x_min:
return
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
for i in range(len(pts)):
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
y1, y2 = p1[1], p2[1]
if y1 == y2:
continue
if y1 > y2:
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
continue
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
"""Python implementation of cv2.ellipse2Poly."""
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
angle = angle % 360
if arc_start > arc_end:
arc_start, arc_end = arc_end, arc_start
while arc_start < 0:
arc_start, arc_end = arc_start + 360, arc_end + 360
while arc_end > 360:
arc_end, arc_start = arc_end - 360, arc_start - 360
if arc_end - arc_start > 360:
arc_start, arc_end = 0, 360
angle_rad = math.radians(angle)
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
pts = []
for i in range(arc_start, arc_end + delta, delta):
theta_rad = math.radians(min(i, arc_end))
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
unique_pts, prev_pt = [], (float('inf'), float('inf'))
for pt in pts:
if (pt_tuple := tuple(pt)) != prev_pt:
unique_pts.append(pt)
prev_pt = pt_tuple
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
draw_body=True, draw_feet=True, draw_face=True, draw_hands=True,
stick_width=4, face_point_size=3,
marker_radius=4, hand_stick_width=2, hand_marker_radius=4,
limb_alpha=1.0, hand_dot_color=(0, 0, 255)):
"""
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
Expected keypoint format (after neck insertion and remapping):
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
- Foot: 18-23 (6 keypoints)
- Face: 24-91 (68 landmarks)
- Right hand: 92-112 (21 keypoints)
- Left hand: 113-133 (21 keypoints)
Args:
canvas: The canvas to draw on (numpy array)
keypoints: Array of keypoint coordinates
scores: Optional confidence scores for each keypoint
threshold: Minimum confidence threshold for drawing keypoints
stick_width: Body limb half-width (passed to ellipse2Poly).
face_point_size: Radius of the white face dots.
marker_radius: Radius of body/foot dots. Defaults to 4 (DWPose).
hand_stick_width: Thickness of hand limb lines. Defaults to 2.
hand_marker_radius: Radius of hand dots. Defaults to 4.
limb_alpha: Body-limb alpha blend (0..1). 1.0 = opaque fill (default),
<1.0 enables per-limb bbox-clipped alpha overlay (DWPose semantics
where overlapping limbs darken).
hand_dot_color: Either an (R, G, B) tuple/list of ints for solid-color
hand dots (default (0, 0, 255), DWPose blue), or a (21, 3) array
for per-keypoint hand-dot colors (OpenPose-style rainbow palette).
Returns:
canvas: The canvas with keypoints drawn
"""
H, W, C = canvas.shape
# Normalize hand_dot_color to a (21, 3) int array.
hdc_arr = np.asarray(hand_dot_color, dtype=int)
if hdc_arr.ndim == 1:
hdc_arr = np.tile(hdc_arr.reshape(1, 3), (21, 1))
hand_dot_tuples = [tuple(int(c) for c in hdc_arr[i]) for i in range(21)]
do_alpha = float(limb_alpha) < 1.0
# Draw body limbs
if draw_body and len(keypoints) >= 18:
for i, limb in enumerate(self.body_limbSeq):
# Convert from 1-indexed to 0-indexed
idx1, idx2 = limb[0] - 1, limb[1] - 1
if idx1 >= 18 or idx2 >= 18:
continue
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
Y = [keypoints[idx1][0], keypoints[idx2][0]]
X = [keypoints[idx1][1], keypoints[idx2][1]]
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
if length < 1:
continue
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
color = self.colors[i % len(self.colors)]
if do_alpha:
_fill_poly_alpha(canvas, polygon, color, limb_alpha, self.draw)
else:
self.draw.fillConvexPoly(canvas, polygon, color)
# Draw body keypoints
if draw_body and len(keypoints) >= 18:
for i in range(18):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)
# Draw foot keypoints (18-23, 6 keypoints)
if draw_feet and len(keypoints) >= 24:
for i in range(18, 24):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)
# Draw right hand (92-112)
if draw_hands and len(keypoints) >= 113:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 92 + edge[0], 92 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width)
# Draw right hand keypoints
for i in range(92, 113):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 92], thickness=-1)
# Draw left hand (113-133)
if draw_hands and len(keypoints) >= 134:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 113 + edge[0], 113 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width)
# Draw left hand keypoints
for i in range(113, 134):
if scores is not None and i < len(scores) and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 113], thickness=-1)
# Draw face keypoints (24-91) - white dots only, no lines
if draw_face and len(keypoints) >= 92:
eps = 0.01
for i in range(24, 92):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
return canvas
def _fill_poly_alpha(canvas, polygon, color, alpha, draw_backend):
"""Bbox-clipped alpha-blended fillConvexPoly. `canvas` is mutated in-place.
DWPose semantics: each limb blends with `alpha` independently so overlapping
limbs darken further. Operates on the polygon's bbox to avoid copying the
whole canvas per limb.
"""
H, W = canvas.shape[:2]
poly_arr = np.asarray(polygon, dtype=np.int32)
x0 = max(0, int(poly_arr[:, 0].min()))
xN = min(W, int(poly_arr[:, 0].max()) + 1)
y0 = max(0, int(poly_arr[:, 1].min()))
yN = min(H, int(poly_arr[:, 1].max()) + 1)
if xN <= x0 or yN <= y0:
return
local_poly = poly_arr - np.array([x0, y0], dtype=poly_arr.dtype)
roi = canvas[y0:yN, x0:xN].copy()
draw_backend.fillConvexPoly(roi, local_poly, color)
a = float(alpha)
canvas[y0:yN, x0:xN] = np.clip(
roi.astype(np.float32) * a + canvas[y0:yN, x0:xN].astype(np.float32) * (1.0 - a),
0, 255,
).astype(np.uint8)

View File

@ -0,0 +1,207 @@
"""BVH export for SAM 3D Body pose_data.
BVH stores explicit bone OFFSETs per joint, so any standard importer
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
directly no heuristic guessing as needed for glTF. We skip the rig's joint 0
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
intrinsic Z-X-Y Euler degrees.
"""
from __future__ import annotations
import io
from typing import Any, Dict, List
import numpy as np
from .glb_shared import (
bind_skel_state,
bone_locals_from_globals,
collect_tracks,
extract_rig_static,
global_skel_state_from_pose_data,
quat_sign_fix_per_joint,
unflip,
)
def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
"""xyzw quat → intrinsic Z-X-Y Euler degrees, returned as (..., 3) in
(z, x, y) order to match BVH's `CHANNELS Zrotation Xrotation Yrotation`."""
q = np.asarray(quat, dtype=np.float64)
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
# ZXY decomposition R = Rz(c)·Rx(a)·Ry(b):
# M[2][1] = 2(yz + xw) → sin(a)
# M[0][1] = 2(xy - zw) → -cos(a) sin(c)
# M[1][1] = 1 - 2(x² + z²) → cos(a) cos(c)
# M[2][0] = 2(xz - yw) → -cos(a) sin(b)
# M[2][2] = 1 - 2(x² + y²) → cos(a) cos(b)
M21 = np.clip(2.0 * (y * z + x * w), -1.0, 1.0)
M01 = 2.0 * (x * y - z * w)
M11 = 1.0 - 2.0 * (x * x + z * z)
M20 = 2.0 * (x * z - y * w)
M22 = 1.0 - 2.0 * (x * x + y * y)
a = np.arcsin(M21)
c = np.arctan2(-M01, M11)
b = np.arctan2(-M20, M22)
out = np.stack([np.rad2deg(c), np.rad2deg(a), np.rad2deg(b)], axis=-1)
return out.astype(np.float32)
def _find_bvh_root(parents: np.ndarray) -> int:
"""First child of the rig's world anchor so the static origin→body stick
bone gets left out. Falls back to the first root joint."""
NJ = parents.shape[0]
world_anchors = [j for j in range(NJ)
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
if not world_anchors:
return 0
children: List[List[int]] = [[] for _ in range(NJ)]
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
children[p].append(j)
wa = world_anchors[0]
if children[wa]:
return children[wa][0]
return wa
def _build_children_map(parents: np.ndarray) -> List[List[int]]:
NJ = parents.shape[0]
out: List[List[int]] = [[] for _ in range(NJ)]
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
out[p].append(j)
return out
def build_bvh(
pose_data: Dict[str, Any],
model: Any,
*,
fps: float = 24.0,
camera_translation: str = "off",
track_index: int = -1,
units: str = "cm",
) -> bytes:
"""Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.
`units` is "cm" (default, standard mocap convention) or "m". Affects the
OFFSET and root-position values; rotations are independent of units.
"""
if units not in ("cm", "m"):
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
unit_scale = 100.0 if units == "cm" else 1.0
rig_static = extract_rig_static(model)
NJ = int(rig_static["num_joints"])
parents = rig_static["parents"]
frames = pose_data["frames"]
tracks = collect_tracks(pose_data, track_index)
if not tracks:
raise ValueError("build_bvh: no valid tracks in pose_data")
person_k, frame_indices = tracks[0]
n_frames = len(frame_indices)
if n_frames == 0:
raise ValueError("build_bvh: track has zero frames")
body_root = _find_bvh_root(parents)
children_map = _build_children_map(parents)
# Bone OFFSETs come from MHR's translation_offsets (joint position
# relative to parent in parent's local-bind frame). For the BVH root,
# we use its bind world position so the skeleton sits at the right
# spot when imported.
bind_global = bind_skel_state(model) # (NJ, 8) cm
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
offset_m = rig_static["joint_translation_offsets"].astype(np.float64) * 0.01
# DFS order rooted at body_root — matches per-frame channel order.
bvh_order: List[int] = []
def _visit(j: int) -> None:
bvh_order.append(j)
for c in children_map[j]:
_visit(c)
_visit(body_root)
# Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
# rather than re-running rig.forward, then derive locals with body_root
# treated as the hierarchy root in BVH-space.
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
)
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bvh_parents = parents.copy()
bvh_parents[body_root] = -1
bone_local = bone_locals_from_globals(rig_global_m, bvh_parents)
# Second pass catches sign discontinuities from the parent-inverse composition.
bone_local[..., 3:7] = quat_sign_fix_per_joint(bone_local[..., 3:7])
eulers_deg = _quat_to_zxy_euler_deg(bone_local[..., 3:7])
if camera_translation in ("absolute", "centered"):
cam_t = np.stack([
unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
for t in frame_indices
], axis=0).astype(np.float64)
if camera_translation == "centered":
cam_t = cam_t - cam_t[0:1]
root_pos_m = bind_pos_m[body_root][None, :] + cam_t
else:
root_pos_m = np.tile(bind_pos_m[body_root], (n_frames, 1))
lines: List[str] = ["HIERARCHY"]
def _emit_joint(j: int, depth: int, is_root: bool) -> None:
ind = " " * depth
keyword = "ROOT" if is_root else "JOINT"
name = "Hips" if is_root else f"joint_{j:03d}"
lines.append(f"{ind}{keyword} {name}")
lines.append(ind + "{")
o = (bind_pos_m[j] if is_root else offset_m[j]) * unit_scale
lines.append(f"{ind} OFFSET {o[0]:.6f} {o[1]:.6f} {o[2]:.6f}")
if is_root:
lines.append(ind + " CHANNELS 6 Xposition Yposition Zposition "
"Zrotation Xrotation Yrotation")
else:
lines.append(ind + " CHANNELS 3 Zrotation Xrotation Yrotation")
kids = children_map[j]
if kids:
for c in kids:
_emit_joint(c, depth + 1, is_root=False)
else:
# End Site (standard BVH spec) gives leaf bones a drawable length.
lines.append(ind + " End Site")
lines.append(ind + " {")
tip = (offset_m[j] * unit_scale) * 0.3
tip_norm = float(np.linalg.norm(tip))
if tip_norm < 0.5 * unit_scale * 0.01: # < 0.5 mm → fall back
tip = np.array([0.0, 0.05 * unit_scale, 0.0], dtype=np.float64)
lines.append(f"{ind} OFFSET {tip[0]:.6f} {tip[1]:.6f} {tip[2]:.6f}")
lines.append(ind + " }")
lines.append(ind + "}")
_emit_joint(body_root, 0, is_root=True)
lines.append("MOTION")
lines.append(f"Frames: {n_frames}")
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
# Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
# frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
# formatting beats Python f-strings by ~10× on long clips.
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
motion = np.concatenate([
root_pos_m * unit_scale, # (N, 3)
eulers_deg[:, body_root].astype(np.float64), # (N, 3)
eulers_deg[:, non_root_idx, :].reshape(n_frames, -1), # (N, 3*(NJ-1))
], axis=1)
buf = io.StringIO()
np.savetxt(buf, motion, fmt="%.6f")
lines.append(buf.getvalue().rstrip("\n"))
return ("\n".join(lines) + "\n").encode("utf-8")

View File

@ -0,0 +1,403 @@
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps),
projected through the per-person camera (`pred_cam_t` + `focal_length` +
image_size) so closer limbs appear thicker/brighter the SCAIL-Pose
visual style. Self-contained: no dependency on the SCAIL-Pose package.
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
"""
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import comfy.model_management
from .glb_shared import (
OPENPOSE_18_PAIRS,
OPENPOSE18_TO_MHR70,
OPENPOSE_RAINBOW_18,
SCAIL_LIMB_COLORS_17,
OPENPOSE_HAND_PAIRS,
OPENPOSE_HAND21_TO_MHR70_R,
OPENPOSE_HAND21_TO_MHR70_L,
OPENPOSE_HAND_COLORS_21,
)
def _limb_palette_rgb01(palette: str) -> np.ndarray:
"""17 per-limb RGB colors in [0,1] for the OpenPose-18 body limbs."""
if palette == "scail":
return SCAIL_LIMB_COLORS_17.astype(np.float32)
return OPENPOSE_RAINBOW_18[: len(OPENPOSE_18_PAIRS)].astype(np.float32)
def _build_specs_from_pose(
persons: List[Dict[str, Any]],
*,
include_hands: bool,
palette: str,
person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Flatten body + optional hand limbs for one frame into
(starts, ends, colors_rgba) in camera coords (Y-down, +Z forward).
Drops endpoints that are non-finite or behind the camera.
`person_brightness_falloff` mixes each per-person limb color toward white
by `1 - falloff^k` for track index `k` (track 0 stays vivid). Matches the
mesh rasterizer and GLB exporters."""
starts: List[np.ndarray] = []
ends: List[np.ndarray] = []
colors: List[np.ndarray] = []
body_limb_colors = _limb_palette_rgb01(palette)
hand_limb_colors = OPENPOSE_HAND_COLORS_21.astype(np.float32)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
for k, person in enumerate(persons):
kp2d_full = person.get("pred_keypoints_3d")
cam_t = person.get("pred_cam_t")
if kp2d_full is None or cam_t is None:
continue
kp = np.asarray(kp2d_full, dtype=np.float32)
if kp.ndim != 2 or kp.shape[1] != 3 or kp.shape[0] < 70:
continue
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
# pred_keypoints_3d is camera frame (Y-down post-flip); add cam_t to
# place the subject in front of the camera.
kp_cam = kp + cam_t_np[None, :]
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
def _tint(rgb: np.ndarray) -> np.ndarray:
if pastel <= 0:
return rgb
return rgb * (1.0 - pastel) + pastel
# SCAIL drops the 4 face bones (13..16: nose↔eyes, eyes→ears) — in the
# reference, NLF leaves those COCO slots at zero so its `sum==0` skip
# silently culls them. The grey neck limb (12) blends spine direction
# (mid-hip → neck, stable) with the neck→nose direction at 60/40 so
# the stub tracks head pose lightly without flapping around like full
# nose direction does.
body_limb_count = 13 if palette == "scail" else len(OPENPOSE_18_PAIRS)
body_kp = kp_cam[OPENPOSE18_TO_MHR70]
spine_dir = None
if palette == "scail":
mid_hip = 0.5 * (body_kp[8] + body_kp[11]) # 8=RHip, 11=LHip
sd = body_kp[1] - mid_hip # 1=Neck
sd_len = float(np.linalg.norm(sd))
if np.all(np.isfinite(sd)) and sd_len > 1e-6:
spine_dir = sd / sd_len
for limb_i, (a, b) in enumerate(OPENPOSE_18_PAIRS[:body_limb_count]):
sa, sb = body_kp[a], body_kp[b]
if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))):
continue
if sa[2] <= 0 or sb[2] <= 0:
continue
if palette == "scail" and limb_i == 12:
nose_vec = sb - sa
nose_len = float(np.linalg.norm(nose_vec))
if nose_len > 1e-6 and spine_dir is not None:
nose_dir = nose_vec / nose_len
mixed = 0.6 * spine_dir + 0.4 * nose_dir
mixed = mixed / max(float(np.linalg.norm(mixed)), 1e-6)
sb = sa + mixed * (nose_len * 0.5)
elif nose_len > 1e-6:
sb = sa + nose_vec * 0.5
elif spine_dir is not None:
sb = sa + spine_dir * (sd_len * 0.3)
starts.append(sa)
ends.append(sb)
color_rgb = _tint(body_limb_colors[limb_i])
colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0],
dtype=np.float32))
if include_hands:
r_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_R]
l_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_L]
for limb_i, (a, b) in enumerate(OPENPOSE_HAND_PAIRS):
for hand_kp in (r_kp, l_kp):
sa, sb = hand_kp[a], hand_kp[b]
if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))):
continue
if sa[2] <= 0 or sb[2] <= 0:
continue
starts.append(sa)
ends.append(sb)
color_rgb = _tint(hand_limb_colors[(a + b) % len(hand_limb_colors)])
colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0],
dtype=np.float32))
if not starts:
return (np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 4), dtype=np.float32))
return (np.stack(starts).astype(np.float32),
np.stack(ends).astype(np.float32),
np.stack(colors).astype(np.float32))
def _ray_capsule_t(
ray_dirs: torch.Tensor, # (K, 3) unit rays from camera origin
starts: torch.Tensor, # (M, 3)
ends: torch.Tensor, # (M, 3)
ba_norm: torch.Tensor, # (M, 3) unit axis (A → B)
ba_len: torch.Tensor, # (M,) segment length
radius: float,
) -> torch.Tensor:
"""Closed-form ray-capsule intersection. Returns (K, M) tensor of ray
parameters t to the nearest valid hit per capsule, +inf where the ray
misses. A capsule is the union of (cylinder body, hemisphere at A,
hemisphere at B); each component is a quadratic root-find."""
INF = float("inf")
r_sq = float(radius) * float(radius)
# Cached dot products.
dn = ray_dirs @ ba_norm.transpose(0, 1) # (K, M) — d·n
dA = ray_dirs @ starts.transpose(0, 1) # (K, M) — d·A
dB = ray_dirs @ ends.transpose(0, 1) # (K, M) — d·B
An = (starts * ba_norm).sum(-1) # (M,) — A·n
A_sq = (starts * starts).sum(-1) # (M,) — |A|²
B_sq = (ends * ends).sum(-1) # (M,) — |B|²
# Cylinder body: project onto plane ⊥ n and solve |P_⊥(t)|² = r².
a_c = 1.0 - dn * dn # (K, M)
b_c = -2.0 * (dA - dn * An) # (K, M)
c_c = A_sq - An * An - r_sq # (M,)
disc_c = b_c * b_c - 4.0 * a_c * c_c
safe_a = a_c.clamp(min=1e-9)
sqrt_c = torch.sqrt(disc_c.clamp(min=0.0))
t_cyl = (-b_c - sqrt_c) / (2.0 * safe_a)
s_cyl = t_cyl * dn - An # axial projection from A
cyl_ok = (disc_c >= 0) & (a_c > 1e-7) & (t_cyl > 1e-6) & \
(s_cyl >= 0.0) & (s_cyl <= ba_len)
t_cyl = torch.where(cyl_ok, t_cyl, torch.full_like(t_cyl, INF))
# Sphere at A, restricted to the hemisphere with axial projection ≤ 0.
disc_a = dA * dA - (A_sq - r_sq)
sqrt_a = torch.sqrt(disc_a.clamp(min=0.0))
t_sa = dA - sqrt_a
s_a = t_sa * dn - An
a_ok = (disc_a >= 0) & (t_sa > 1e-6) & (s_a <= 0.0)
t_sa = torch.where(a_ok, t_sa, torch.full_like(t_sa, INF))
# Sphere at B, restricted to the hemisphere with axial projection ≥ ba_len.
disc_b = dB * dB - (B_sq - r_sq)
sqrt_b = torch.sqrt(disc_b.clamp(min=0.0))
t_sb = dB - sqrt_b
s_b = t_sb * dn - An
b_ok = (disc_b >= 0) & (t_sb > 1e-6) & (s_b >= ba_len)
t_sb = torch.where(b_ok, t_sb, torch.full_like(t_sb, INF))
return torch.minimum(torch.minimum(t_cyl, t_sa), t_sb)
def _render_capsules_torch(
starts: torch.Tensor,
ends: torch.Tensor,
colors: torch.Tensor,
H: int, W: int,
fx: float, fy: float, cx: float, cy: float,
radius: float,
background_rgb: Optional[torch.Tensor],
device: torch.device,
) -> torch.Tensor:
"""Analytic ray-capsule renderer for a union of capsules. Camera at
origin looking down +Z; pixels in y-down screen coords."""
M = int(starts.shape[0])
if M == 0:
if background_rgb is not None:
return background_rgb.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
yy, xx = torch.meshgrid(
torch.arange(H, device=device, dtype=torch.float32),
torch.arange(W, device=device, dtype=torch.float32),
indexing="ij",
)
u = (xx - cx) / fx
v = (yy - cy) / fy
z = torch.ones_like(u)
ray_dirs = torch.stack([u, v, z], dim=-1)
ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True)
flat_dirs = ray_dirs.view(-1, 3)
N = flat_dirs.shape[0]
ba = ends - starts
ba_len = torch.linalg.norm(ba, dim=1).clamp(min=1e-6)
ba_norm = ba / ba_len.unsqueeze(1)
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
z_near = max(0.05, z_min - radius)
# Union of per-capsule screen-space bboxes. Pixels outside this mask
# provably can't hit any capsule, so the analytic intersection only runs
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
sz = starts[:, 2].clamp(min=z_near)
ez = ends[:, 2].clamp(min=z_near)
sx_p = starts[:, 0] * fx / sz + cx
sy_p = starts[:, 1] * fy / sz + cy
ex_p = ends[:, 0] * fx / ez + cx
ey_p = ends[:, 1] * fy / ez + cy
# Projected radius using the closer endpoint — conservative bbox.
r_pix = radius * fx / torch.minimum(sz, ez)
pad = 2.0
xmin_t = (torch.minimum(sx_p, ex_p) - r_pix - pad).floor().long().clamp(min=0, max=W)
xmax_t = (torch.maximum(sx_p, ex_p) + r_pix + pad).ceil().long().clamp(min=0, max=W)
ymin_t = (torch.minimum(sy_p, ey_p) - r_pix - pad).floor().long().clamp(min=0, max=H)
ymax_t = (torch.maximum(sy_p, ey_p) + r_pix + pad).ceil().long().clamp(min=0, max=H)
# One stack→tolist sync amortizes the GPU→CPU read over all M bboxes.
bboxes_cpu = torch.stack([xmin_t, ymin_t, xmax_t, ymax_t], dim=1).tolist()
coarse_mask = torch.zeros(H, W, dtype=torch.bool, device=device)
for xmin_i, ymin_i, xmax_i, ymax_i in bboxes_cpu:
if xmax_i > xmin_i and ymax_i > ymin_i:
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
# Analytic ray-capsule intersection. One pass over the masked pixels —
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
INF = float("inf")
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
if active_idx.numel() > 0:
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory
# manageable when both K (image pixels) and M (capsules) are large.
chunk_max = max(1, int(4_000_000 / max(M, 1)))
for i0 in range(0, active_idx.numel(), chunk_max):
sub = active_idx[i0 : i0 + chunk_max]
t_KM = _ray_capsule_t(
flat_dirs[sub], starts, ends, ba_norm, ba_len, radius,
)
t_min, m_idx = t_KM.min(dim=1)
hit = t_min < INF
if hit.any():
winners = sub[hit]
flat_t[winners] = t_min[hit]
flat_m_idx[winners] = m_idx[hit]
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade.
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
if background_rgb is not None:
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
hit_idx = torch.nonzero(flat_m_idx >= 0, as_tuple=False).squeeze(1)
if hit_idx.numel() > 0:
rd = flat_dirs[hit_idx]
t_h = flat_t[hit_idx]
m_h = flat_m_idx[hit_idx]
p_hit = rd * t_h.unsqueeze(-1)
A_h = starts[m_h]
n_h = ba_norm[m_h]
L_h = ba_len[m_h]
proj = ((p_hit - A_h) * n_h).sum(-1).clamp(min=0.0)
proj = torch.minimum(proj, L_h)
C_h = A_h + proj.unsqueeze(-1) * n_h
normals = p_hit - C_h
normals = normals / normals.norm(dim=-1, keepdim=True).clamp(min=1e-8)
col = colors[m_h, :3]
# SCAIL shading (render_torch.py:290-331). Light from camera (+Z toward
# subject); diffuse term `N·-L` simplifies to `-N.z`. Specular uses the
# proper Blinn-Phong half-vector `(view + (-L))` — using `diff` as a
# shortcut would lock the highlight to image center.
diff = torch.clamp(-(normals[:, 2]), min=0.0)
diffuse = 0.45 + 0.55 * diff
view_dir = -rd
half_dir = view_dir.clone()
half_dir[:, 2] -= 1.0
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
# SCAIL's reference depth fade uses mm-scale constants (`z_max + 6000`)
# that translate to almost no fade in our meter units — `depth_factor`
# stays ~0.85-1.0. Matching that with a mild ramp.
z_vals = p_hit[:, 2]
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
if z_hi - z_lo > 1e-6:
fade = 1.0 - (z_vals - z_lo) / (z_hi - z_lo)
depth_factor = 0.85 + 0.15 * fade
else:
depth_factor = torch.ones_like(z_vals)
base = col * diffuse.unsqueeze(-1) * depth_factor.unsqueeze(-1)
highlight = (0.5 * spec * depth_factor).unsqueeze(-1)
out[hit_idx] = base + highlight
return out.view(H, W, 3).clamp(0.0, 1.0)
def render_pose_data_capsules(
pose_data: Dict[str, Any],
*,
frame_idx: int,
W: int,
H: int,
background: Optional[torch.Tensor] = None,
composite: str = "over",
radius_m: float = 0.025,
include_hands: bool = False,
palette: str = "scail",
person_brightness_falloff: float = 0.0,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Render a frame's pose_data as 3D capsules projected through the per-
person camera. Returns (H, W, 3) fp32 in [0, 1].
`composite='over'` paints over `background` (black if None);
`composite='mesh_only'` always uses a black canvas.
`radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
Camera fx/fy come from each person's `focal_length` (pixels); cx/cy = center.
"""
persons = pose_data["frames"][frame_idx]
if device is None:
device = comfy.model_management.get_torch_device()
# SAM3DBody shares one camera across the clip — pick from the first valid person.
fx = fy = float(min(H, W))
for person in persons:
f = person.get("focal_length")
if f is None:
continue
fx = fy = float(np.asarray(f, dtype=np.float32).reshape(-1)[0])
break
cx, cy = W * 0.5, H * 0.5
starts_np, ends_np, colors_np = _build_specs_from_pose(
persons, include_hands=include_hands, palette=palette,
person_brightness_falloff=person_brightness_falloff,
)
bg_t: Optional[torch.Tensor] = None
if composite == "over" and background is not None:
bg_t = background.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
if bg_t.shape[:2] != (H, W):
bg_t = bg_t.permute(2, 0, 1).unsqueeze(0)
bg_t = torch.nn.functional.interpolate(
bg_t, size=(H, W), mode="bilinear", align_corners=False,
)
bg_t = bg_t.squeeze(0).permute(1, 2, 0).contiguous()
if starts_np.shape[0] == 0:
if bg_t is not None:
return bg_t
return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
starts_t = torch.from_numpy(starts_np).to(device=device, dtype=torch.float32)
ends_t = torch.from_numpy(ends_np).to(device=device, dtype=torch.float32)
colors_t = torch.from_numpy(colors_np).to(device=device, dtype=torch.float32)
return _render_capsules_torch(
starts_t, ends_t, colors_t,
H=H, W=W, fx=fx, fy=fy, cx=cx, cy=cy,
radius=float(radius_m),
background_rgb=bg_t,
device=device,
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,578 @@
"""GLB export — skeletal (real armature) mode.
Rebuilds an Armature with the MHR 127-bone rig:
- per-frame local TRS comes from re-running param_transform on the saved
`mhr_model_params`;
- rest verts come from a zero-pose forward with each person's `shape_params`;
- sparse triplet skinning is compacted to glTF's max-4-influences-per-vertex form;
- facial expression is re-exposed as 72 morph targets driven by `expr_params`
so face animation survives plain glTF skinning.
Optional bone visualization (octahedrons / sticks) is rigidly
skinned alongside the body mesh used to preview the armature in glTF
viewers that don't draw bones.
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .glb_shared import (
GLBWriter,
bake_vertex_colors,
bind_skel_state,
bone_locals_from_globals,
collect_tracks,
compact_skin_to_n,
compute_normals,
compute_pastel_mix,
extract_rig_static,
flat_shade_mesh,
gaussian_smooth_quats,
global_skel_state_from_pose_data,
global_skel_state_per_frame,
ibp_from_bind_global,
make_lit_material,
quat_sign_fix_per_joint,
rotation_align,
unflip,
zero_pose_rest_verts,
)
from comfy_extras.sam3d_body.utils import jet_colormap
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white'
(no per-bone color bone-vis mesh uses default unlit material)."""
if scheme == "rainbow_y":
y = bind_pos_m[:, 1].astype(np.float32)
y_min, y_max = float(y.min()), float(y.max())
s = np.clip((y - y_min) / max(y_max - y_min, 1e-6), 0.0, 1.0)
return jet_colormap(s)
return None
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y,
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
v = np.array([
[0.0, 0.0, 0.0], # 0: head
[0.0, 1.0, 0.0], # 1: tail
[1.0, 0.1, 0.0], # 2: +X ridge (pre-scale; X/Z scale by half_width)
[-1.0, 0.1, 0.0], # 3: -X ridge
[0.0, 0.1, 1.0], # 4: +Z ridge
[0.0, 0.1, -1.0], # 5: -Z ridge
], dtype=np.float32)
f = np.array([
# head pyramid: outward = away from bone axis, slightly -Y
[0, 2, 4], [0, 5, 2], [0, 3, 5], [0, 4, 3],
# tail pyramid: outward = away from bone axis, slightly +Y
[1, 4, 2], [1, 3, 4], [1, 5, 3], [1, 2, 5],
], dtype=np.uint32)
return v, f
def _bone_edges(
joint_pos_m: np.ndarray, parents: np.ndarray,
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per
parentchild edge in the hierarchy, skipping edges whose PARENT is a
root joint (those typically anchor the skeleton at world origin and
just look like a stray stick from origin to the body). Zero-length
edges are skipped too."""
NJ = joint_pos_m.shape[0]
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
for c in range(NJ):
p = int(parents[c])
if not (0 <= p < NJ and p != c):
continue
# Skip if parent itself is a root — that bone is a world-anchor stick.
gp = int(parents[p])
if not (0 <= gp < NJ and gp != p):
continue
head = joint_pos_m[p].astype(np.float32)
tail = joint_pos_m[c].astype(np.float32)
if float(np.linalg.norm(tail - head)) < 1e-6:
continue
out.append((p, c, head, tail))
return out
def _build_bone_octahedrons_mesh(
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""One Blender-style octahedron per parent→child edge. Returns
(verts, normals, faces, joints, weights, child_idx_per_vert);
child_idx feeds per-bone color lookup at the call site."""
base_v, base_f = _octahedron_unit()
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
out_v: List[List[float]] = []
out_n: List[List[float]] = []
out_f: List[List[int]] = []
out_j: List[List[int]] = []
out_w: List[List[float]] = []
child_per_vert: List[int] = []
# Width scales with length so short bones (fingers, face) don't look chunky
# next to long ones (limbs, spine). `half_width_m` caps long bones.
WIDTH_RATIO = 0.1
MIN_WIDTH = 0.001
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
direction = tail - head
length = float(np.linalg.norm(direction))
if length < 1e-6:
continue
unit_dir = direction / length
R = rotation_align(canonical, unit_dir)
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
scale = np.array([half_width_eff, length, half_width_eff], dtype=np.float32)
v_local = base_v * scale
v_world = v_local @ R.T + head
# head pole outward = -Y, tail pole +Y, ridges outward in XZ.
n_local = np.zeros_like(base_v)
n_local[0] = [0.0, -1.0, 0.0]
n_local[1] = [0.0, 1.0, 0.0]
for k in range(2, 6):
n = base_v[k].copy()
n[1] = 0.0
n_norm = float(np.linalg.norm(n))
if n_norm > 0:
n_local[k] = n / n_norm
n_world = n_local @ R.T
v_off = len(out_v)
out_v.extend(v_world.tolist())
out_n.extend(n_world.tolist())
for face in base_f:
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
# Dual skin head→parent, tail→child, ridges blend by canonical Y so the
# bone stretches between joints instead of going rigid with one.
for k in range(base_v.shape[0]):
y_canon = float(base_v[k, 1])
w_parent = max(0.0, 1.0 - y_canon)
w_child = max(0.0, y_canon)
wsum = w_parent + w_child
if wsum > 0:
w_parent /= wsum
w_child /= wsum
out_j.append([int(parent_idx), int(child_idx), 0, 0])
out_w.append([w_parent, w_child, 0.0, 0.0])
child_per_vert.append(int(child_idx))
if not out_v:
return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 3), dtype=np.uint32), np.zeros((0, 4), dtype=np.uint16),
np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64))
return (np.asarray(out_v, dtype=np.float32),
np.asarray(out_n, dtype=np.float32),
np.asarray(out_f, dtype=np.uint32),
np.asarray(out_j, dtype=np.uint16),
np.asarray(out_w, dtype=np.float32),
np.asarray(child_per_vert, dtype=np.int64))
def build_glb_skeletal(
pose_data: Dict[str, Any],
model: Any = None,
*,
fps: float = 24.0,
camera_translation: str = "off",
track_index: int = -1,
include_face_morphs: bool = True,
shader: str = "default",
rainbow_tilt_x_deg: float = 0.0,
rainbow_tilt_z_deg: float = 0.0,
person_palette_falloff: float = 0.6,
bone_smooth_window: int = 0,
use_stored_global_rots: bool = True,
bone_vis: str = "off",
bone_vis_radius_m: float = 0.04,
bone_vis_color: str = "white",
include_body_mesh: bool = True,
) -> bytes:
"""Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.
For MHR (default) facial expression is exposed as 72 morph targets driven
by expr_params per frame when include_face_morphs=True.
External skeletons (e.g. ComfyUI-Kimodo) can supply a
``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
entirely. When present, ``model`` may be None and the rig data, bind pose,
skin weights, and rest verts come from the override. Per-frame skeletal
state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
person dict (kimodo populates these from its own FK output). See
``glb.shared._get_skeleton_override`` for the override schema.
"""
frames = pose_data["frames"]
# Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
# faces are all rig-native (Y-up).
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
tracks = collect_tracks(pose_data, track_index)
if not tracks:
raise ValueError("build_glb_skeletal: no valid tracks in pose_data")
rig_static = extract_rig_static(model, pose_data)
NJ = rig_static["num_joints"]
NV = rig_static["num_verts"]
NEXPR = rig_static["num_expr"]
parents = rig_static["parents"]
is_external = bool(rig_static.get("_external", False))
if is_external:
# External rigs have no PCA pose params to re-run; only stored globals
# are available, and kimodo stores joint coords already Y-up.
use_stored_global_rots = True
joint_coords_y_down = not is_external
# Compact sparse skinning to 8 influences per vertex into glTF's two
# JOINTS_*/WEIGHTS_* sets. MHR averages ~2.8 influences/vert but some
# shoulder/hip verts have 5-8 where multiple joints cancel — keeping only
# 4 there leaks per-bone rotation noise into the rendered mesh.
if is_external:
joints_8 = rig_static["lbs_compact_joints"]
weights_8 = rig_static["lbs_compact_weights"]
actual_max_inf = rig_static["lbs_compact_max_inf"]
else:
joints_8, weights_8, actual_max_inf = compact_skin_to_n(
rig_static["lbs_skin_indices"], rig_static["lbs_vert_indices"],
rig_static["lbs_skin_weights"], NV, max_inf=8,
)
joints_set0 = np.ascontiguousarray(joints_8[:, :4])
weights_set0 = np.ascontiguousarray(weights_8[:, :4])
use_set1 = actual_max_inf > 4
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
# Derive bone locals from the rig's bind globals rather than recomputing
# FK ourselves, so any mismatch between `parents` and the rig's actual FK
# is absorbed into the local TRS instead of producing wrong globals.
bind_global_cm = bind_skel_state(model, pose_data)
bind_global_m = bind_global_cm.copy().astype(np.float32)
bind_global_m[:, :3] *= 0.01
bind_local = bone_locals_from_globals(bind_global_m[None], rig_static["parents"])[0]
# IBP = inverse of bind global. With bone defaults set to bind_local and
# FK composed via `parents`, skin_matrix at rest = identity.
ibp_mat4 = ibp_from_bind_global(bind_global_m)
w = GLBWriter()
nodes: List[dict] = []
meshes: List[dict] = []
skins: List[dict] = []
materials: List[dict] = []
animations: List[dict] = []
scene_root_indices: List[int] = []
canonical_colors = pose_data.get("canonical_colors")
indices_acc = w.add_indices_u32(faces_native)
joints0_acc = w.add_joints_u16(joints_set0)
weights0_acc = w.add_weights_f32(weights_set0)
joints1_acc = w.add_joints_u16(joints_set1) if use_set1 else None
weights1_acc = w.add_weights_f32(weights_set1) if use_set1 else None
ibm_acc = w.add_mat4_f32(ibp_mat4)
expr_morph_accs: List[int] = []
if include_face_morphs and NEXPR > 0:
eb = rig_static["expr_basis"].astype(np.float32) * 0.01
for e in range(NEXPR):
expr_morph_accs.append(w.add_vec3_f32_no_minmax(eb[e]))
for track_i, (person_k, frame_indices) in enumerate(tracks):
person_root = {"name": f"track{track_i:02d}", "children": []}
nodes.append(person_root)
person_root_idx = len(nodes) - 1
scene_root_indices.append(person_root_idx)
bone_node_indices: List[int] = []
for j in range(NJ):
bone = {
"name": f"bone_{j:03d}",
"translation": bind_local[j, :3].tolist(),
"rotation": bind_local[j, 3:7].tolist(),
"scale": [float(bind_local[j, 7])] * 3,
}
nodes.append(bone)
bone_node_indices.append(len(nodes) - 1)
bone_children: List[List[int]] = [[] for _ in range(NJ)]
bone_root_indices: List[int] = []
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
bone_children[p].append(bone_node_indices[j])
else:
bone_root_indices.append(bone_node_indices[j])
for j in range(NJ):
if bone_children[j]:
nodes[bone_node_indices[j]]["children"] = bone_children[j]
person_root["children"].extend(bone_root_indices)
skin = {
"joints": bone_node_indices,
"inverseBindMatrices": ibm_acc,
"skeleton": bone_root_indices[0] if bone_root_indices else bone_node_indices[0],
}
skins.append(skin)
skin_idx = len(skins) - 1
include_body = bool(include_body_mesh)
include_bones = bone_vis in ("octahedrons", "sticks")
body_mesh_node_idx: Optional[int] = None
if include_body:
# External rigs have no PCA shape — `zero_pose_rest_verts` short-
# circuits to `pose_data["_skeleton_override"]["rest_verts_m"]`,
# so zeroed shape_params is safe there.
if is_external:
shape_params_arr = np.zeros(0, dtype=np.float32)
else:
shape_params_arr = np.asarray(
frames[frame_indices[0]][person_k]["shape_params"], dtype=np.float32,
)
rest_v = zero_pose_rest_verts(model, shape_params_arr, pose_data=pose_data)
normals = compute_normals(rest_v, faces_native)
positions_acc = w.add_vec3_f32(rest_v)
normals_acc = w.add_vec3_f32(normals)
pastel_mix = compute_pastel_mix(track_i, person_palette_falloff)
vcolor = bake_vertex_colors(
canonical_colors, shader,
rainbow_tilt_x_deg, rainbow_tilt_z_deg, pastel_mix,
)
color_acc = w.add_vec3_f32(vcolor) if vcolor is not None else None
attributes = {
"POSITION": positions_acc, "NORMAL": normals_acc,
"JOINTS_0": joints0_acc, "WEIGHTS_0": weights0_acc,
}
if joints1_acc is not None:
attributes["JOINTS_1"] = joints1_acc
attributes["WEIGHTS_1"] = weights1_acc
if color_acc is not None:
attributes["COLOR_0"] = color_acc
primitive = {
"attributes": attributes,
"indices": indices_acc,
"mode": 4,
}
if color_acc is not None:
materials.append(make_lit_material())
primitive["material"] = len(materials) - 1
if expr_morph_accs:
primitive["targets"] = [{"POSITION": a} for a in expr_morph_accs]
mesh = {"primitives": [primitive]}
if expr_morph_accs:
mesh["weights"] = [0.0] * len(expr_morph_accs)
meshes.append(mesh)
mesh_idx = len(meshes) - 1
mesh_node = {
"name": f"track{track_i:02d}_mesh", "mesh": mesh_idx, "skin": skin_idx,
}
nodes.append(mesh_node)
body_mesh_node_idx = len(nodes) - 1
person_root["children"].append(body_mesh_node_idx)
if include_bones:
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
# Indexes `bone_palette`: octahedrons/sticks use the bone's child
# joint so every bone has its own color regardless of skin target.
# 'sticks' = thin octahedrons. glTF LINES skinning is unreliable
# (Three.js' GLTFLoader doesn't animate skinned line primitives),
# so we render triangle tubes instead.
color_idx_per_vert: Optional[np.ndarray] = None
hw = float(bone_vis_radius_m) if bone_vis == "octahedrons" else 0.0035
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
bind_global_m[:, :3], rig_static["parents"], half_width_m=hw,
)
if bv_v.shape[0] > 0:
F = bv_f.shape[0]
expanded_child = np.empty((F * 3,), dtype=np.int64)
for k in range(3):
expanded_child[k::3] = child_per_vert[bv_f[:, k]]
bv_v, bv_n, bv_f, bv_j, bv_w = flat_shade_mesh(bv_v, bv_f, bv_j, bv_w)
color_idx_per_vert = expanded_child
primitive_mode = 4
bv_idx_flat = bv_f.reshape(-1)
if bv_v.shape[0] > 0:
bv_pos_acc = w.add_vec3_f32(bv_v)
bv_idx_acc = w.add_indices_u32(bv_idx_flat)
bv_j_acc = w.add_joints_u16(bv_j)
bv_w_acc = w.add_weights_f32(bv_w)
bv_attrs = {
"POSITION": bv_pos_acc,
"JOINTS_0": bv_j_acc, "WEIGHTS_0": bv_w_acc,
}
if bv_n is not None:
bv_attrs["NORMAL"] = w.add_vec3_f32(bv_n)
if bone_palette is not None and color_idx_per_vert is not None:
bv_color = bone_palette[color_idx_per_vert].astype(np.float32)
bv_attrs["COLOR_0"] = w.add_vec3_f32(bv_color)
bv_primitive = {
"attributes": bv_attrs,
"indices": bv_idx_acc,
"mode": primitive_mode,
}
if bone_palette is not None:
materials.append(make_lit_material())
bv_primitive["material"] = len(materials) - 1
bv_mesh = {"primitives": [bv_primitive]}
meshes.append(bv_mesh)
bv_mesh_node = {
"name": f"track{track_i:02d}_bones",
"mesh": len(meshes) - 1,
"skin": skin_idx,
}
nodes.append(bv_mesh_node)
person_root["children"].append(len(nodes) - 1)
# Per-frame GLOBAL skel state → bone locals via parent-inverse.
# Default uses the rig's stored output; the fallback re-runs FK.
if use_stored_global_rots:
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
joint_coords_y_down=joint_coords_y_down,
)
else:
mp_per_frame = np.stack([
np.asarray(frames[t][person_k]["mhr_model_params"], dtype=np.float32)
for t in frame_indices
], axis=0)
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
rig_global_m = rig_global_cm.copy().astype(np.float32)
rig_global_m[..., :3] *= 0.01
# Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
# Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
# only fix locals, the parent's flip propagates into the child's
# local translation (t_local inherits parent sign via q_parent_inv)
# and produces visible "axis resets" mid-animation.
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bone_local_anim = bone_locals_from_globals(rig_global_m, rig_static["parents"])
local_t = bone_local_anim[..., :3].astype(np.float32)
local_q = bone_local_anim[..., 3:7].astype(np.float32)
local_s = bone_local_anim[..., 7].astype(np.float32)
# Second pass on locals catches residual drift from the parent-inverse.
local_q = quat_sign_fix_per_joint(local_q)
# Hemisphere-align frame 0 with the bind quat so pause/play takes the
# short path; then re-propagate.
bind_q = bind_local[:, 3:7].astype(np.float32)
if local_q.shape[0] > 0:
d0 = (bind_q * local_q[0]).sum(axis=-1)
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
local_q[0] = local_q[0] * sign0
local_q = quat_sign_fix_per_joint(local_q)
# Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
# at handstand) that the upstream Smooth node may not catch.
if bone_smooth_window and bone_smooth_window > 1:
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
# fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
# drift into visible flips otherwise.
lq64 = local_q.astype(np.float64)
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
local_q = lq64.astype(np.float32)
n_frames = len(frame_indices)
times = np.asarray(frame_indices, dtype=np.float32) / float(fps)
time_acc = w.add_scalar_f32(times)
samplers: List[dict] = []
channels: List[dict] = []
for j in range(NJ):
t_j = local_t[:, j, :]
q_j = local_q[:, j, :]
s_j = np.broadcast_to(local_s[:, j:j+1], (n_frames, 3)).astype(np.float32)
t_const = (np.ptp(t_j, axis=0) < 1e-6).all()
q_const = (np.ptp(q_j, axis=0) < 1e-6).all()
s_const = (np.ptp(s_j, axis=0) < 1e-6).all()
if t_const:
nodes[bone_node_indices[j]]["translation"] = t_j[0].tolist()
else:
acc = w.add_vec3_f32_anim(t_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "translation"},
})
if q_const:
nodes[bone_node_indices[j]]["rotation"] = q_j[0].tolist()
else:
acc = w.add_vec4_f32(q_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "rotation"},
})
if s_const:
nodes[bone_node_indices[j]]["scale"] = s_j[0].tolist()
else:
acc = w.add_vec3_f32_anim(s_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "scale"},
})
if camera_translation != "off":
cam_t = np.stack([
unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
for t in frame_indices
], axis=0)
if camera_translation == "centered" and cam_t.shape[0] > 0:
cam_t = cam_t - cam_t[0:1]
if (np.ptp(cam_t, axis=0) < 1e-6).all():
person_root["translation"] = cam_t[0].tolist()
else:
acc = w.add_vec3_f32_anim(cam_t)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": person_root_idx, "path": "translation"},
})
# Body-mesh-only: bone-vis primitives have no morph targets.
if expr_morph_accs and body_mesh_node_idx is not None:
expr_per_frame = np.stack([
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
for t in frame_indices
], axis=0).astype(np.float32)
weights_acc_anim = w.add_scalar_f32_flat(expr_per_frame, count=n_frames * NEXPR)
samplers.append({"input": time_acc, "output": weights_acc_anim, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": body_mesh_node_idx, "path": "weights"},
})
animations.append({
"name": f"track{track_i:02d}",
"samplers": samplers, "channels": channels,
})
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI-SAM3DBody"},
"scene": 0,
"scenes": [{"nodes": scene_root_indices}],
"nodes": nodes,
"meshes": meshes,
"skins": skins,
}
if materials:
gltf["materials"] = materials
if animations:
gltf["animations"] = animations
return w.to_bytes(gltf)

View File

@ -0,0 +1,233 @@
"""2D OpenPose-style skeleton rendering for SAM 3D Body pose_data.
Body / hand drawing is delegated to `KeypointDraw.draw_wholebody_keypoints`
(shared with SDPose). SAM3D-specific: MHR70 -> DWPose-134 keypoint packing,
plus optional rig-projected face landmarks when `pred_face_keypoints_2d`
isn't present (and arbitrary-count face dots, since sapiens-238 doesn't fit
the DWPose face slot).
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
"""
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from PIL import Image
from comfy_extras.pose.keypoint_draw import KeypointDraw
from .glb_shared import (
OPENPOSE18_TO_MHR70,
OPENPOSE_HAND21_TO_MHR70_L,
OPENPOSE_HAND21_TO_MHR70_R,
OPENPOSE_HAND_COLORS_21,
select_face_landmark_vert_ids,
)
_KD = KeypointDraw()
# OpenPose hand palette as a (21, 3) int array (0..255) for KeypointDraw.
_HAND_DOT_PALETTE_OPENPOSE = (OPENPOSE_HAND_COLORS_21 * 255.0).astype(int)
def _project_face_landmarks_2d(
person: Dict[str, Any], face_vert_ids: np.ndarray, H: int, W: int,
) -> Optional[np.ndarray]:
"""Project `pred_vertices[face_vert_ids]` to 2D using each person's
pred_cam_t + focal_length. Same projection used by `_replay_mhr_with_overrides`."""
verts = person.get("pred_vertices")
cam_t = person.get("pred_cam_t")
focal = person.get("focal_length")
if verts is None or cam_t is None or focal is None:
return None
verts = np.asarray(verts, dtype=np.float32)
cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
f = float(np.asarray(focal, dtype=np.float32).reshape(-1)[0])
pts3 = verts[face_vert_ids] + cam_t[None, :]
z = np.maximum(pts3[:, 2:3], 1e-6)
xy = pts3[:, :2] * f
xy = xy + np.array([W * 0.5, H * 0.5], dtype=np.float32)[None, :] * z
return (xy / z).astype(np.float32)
def _pack_dwpose_134(
person: Dict[str, Any], *, include_body: bool, include_hands: bool,
) -> Tuple[np.ndarray, np.ndarray]:
"""Pack a SAM3D person dict into (kp, scores): (134, 2) DWPose-layout
coords + (134,) confidence. Face slot (24-91) is left zeroed; face dots
are drawn separately so SAM3D's 238-sapiens / rig-fallback counts work.
Non-finite or out-of-band entries get score=0 and are filtered downstream."""
kp = np.zeros((134, 2), dtype=np.float32)
scores = np.zeros(134, dtype=np.float32)
kp2d_full = person.get("pred_keypoints_2d")
if kp2d_full is None:
return kp, scores
kp2d = np.asarray(kp2d_full, dtype=np.float32)
if kp2d.ndim != 2 or kp2d.shape[1] != 2 or kp2d.shape[0] < 70:
return kp, scores
if include_body:
body_xy = kp2d[OPENPOSE18_TO_MHR70]
finite = np.isfinite(body_xy).all(axis=1)
kp[:18][finite] = body_xy[finite]
scores[:18][finite] = 1.0
if include_hands:
for slot_start, mhr_idx in ((92, OPENPOSE_HAND21_TO_MHR70_R),
(113, OPENPOSE_HAND21_TO_MHR70_L)):
hand_xy = kp2d[mhr_idx]
finite = np.isfinite(hand_xy).all(axis=1)
kp[slot_start:slot_start + 21][finite] = hand_xy[finite]
scores[slot_start:slot_start + 21][finite] = 1.0
return kp, scores
def _draw_face_dots(
canvas: np.ndarray, face_xy: np.ndarray, marker_radius_px: int,
) -> None:
"""White face dots, variable count (238 sapiens / ~30 rig-projected)."""
H, W = canvas.shape[:2]
pad = int(marker_radius_px)
white = (255, 255, 255)
for i in range(face_xy.shape[0]):
x_, y_ = float(face_xy[i, 0]), float(face_xy[i, 1])
if not (np.isfinite(x_) and np.isfinite(y_)):
continue
x, y = int(round(x_)), int(round(y_))
if x + pad < 0 or x - pad >= W or y + pad < 0 or y - pad >= H:
continue
_KD.draw.circle(canvas, (x, y), int(marker_radius_px), white, thickness=-1)
def render_pose_data_openpose(
pose_data: Dict[str, Any],
*,
frame_idx: int,
W: int,
H: int,
background: Optional[torch.Tensor] = None,
composite: str = "over",
marker_radius_px: int = 4,
stick_width_px: int = 4,
limb_alpha: float = 0.6,
include_body: bool = True,
include_hands: bool = False,
face_style: str = "disabled",
hand_color_style: str = "dwpose",
hand_marker_radius_px: int = 0,
hand_stick_width_px: int = 0,
face_marker_radius_px: int = 3,
person_brightness_falloff: float = 0.0,
) -> torch.Tensor:
"""Render a 2D OpenPose-style skeleton onto an (H, W, 3) canvas.
`composite='over'` paints over `background` (else black canvas).
`hand_marker_radius_px` / `hand_stick_width_px`: 0 = auto = 0.7x / 0.5x
of the body sizes.
`face_style`: 'disabled' = no face dots, 'full' = all face landmarks
(prefers sapiens-238 if present, else rig-fallback ~30), 'eyes_mouth' =
rig-fallback subset (~12 dots: eyes + mouth only). The subset only has a
documented layout for the rig fallback so eyes_mouth always uses it,
regardless of whether sapiens-238 is available.
`person_brightness_falloff` mixes each person's drawn pixels toward white
by `1 - falloff^k` (track 0 stays vivid). Applied post-draw so per-limb
alpha blending against the existing canvas remains correct.
"""
persons = pose_data["frames"][frame_idx]
if composite == "over" and background is not None:
bg = background.cpu().numpy()
canvas = (np.clip(bg, 0.0, 1.0) * 255.0).astype(np.uint8)
if canvas.shape[:2] != (H, W):
canvas = np.array(Image.fromarray(canvas).resize((W, H), Image.LANCZOS))
else:
canvas = np.zeros((H, W, 3), dtype=np.uint8)
# In-place draw needs a contiguous writable buffer.
canvas = np.ascontiguousarray(canvas)
if int(hand_marker_radius_px) <= 0:
hand_marker_radius_px = max(1, int(round(marker_radius_px * 0.7)))
if int(hand_stick_width_px) <= 0:
hand_stick_width_px = max(1, int(round(stick_width_px * 0.5)))
# Eyes+mouth indices into the rig fallback (FACE_LANDMARK_TARGETS): 6..13
# = both eyes, 19..22 = outer-lip ring. Brows/nose/chin/jaw are dropped.
_EYES_MOUTH_IDX = np.array([6, 7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22], dtype=np.int64)
include_face = face_style != "disabled"
use_rig_only = face_style == "eyes_mouth"
# Real 238 sapiens face KPs take priority for 'full'; 'eyes_mouth' always
# falls through to the rig path since sapiens has no documented subset.
face_vert_ids: Optional[np.ndarray] = None
if include_face:
any_real = (not use_rig_only) and any(
p.get("pred_face_keypoints_2d") is not None for p in persons
)
if not any_real:
cc = pose_data.get("canonical_colors") or {}
positions = cc.get("positions")
if positions is not None:
try:
face_vert_ids = select_face_landmark_vert_ids(
np.asarray(positions), face_mask=cc.get("face_mask"),
)
if use_rig_only:
face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
except Exception as e:
print(f"[SAM3DBody] face landmarks disabled - {e}")
face_vert_ids = None
hand_dot_color = (
_HAND_DOT_PALETTE_OPENPOSE if hand_color_style == "openpose"
else (0, 0, 255)
)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
for k, person in enumerate(persons):
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
# Snapshot before this person's strokes so we can identify the pixels
# they touched and blend just those toward white. Drawing happens
# against the live canvas first so limb_alpha blends correctly.
pre = canvas.copy() if pastel > 0 else None
kp134, scores134 = _pack_dwpose_134(
person, include_body=include_body, include_hands=include_hands,
)
_KD.draw_wholebody_keypoints(
canvas, kp134, scores=scores134, threshold=0.5,
draw_body=include_body, draw_feet=False,
draw_face=False, # SAM3D draws face dots separately (variable count)
draw_hands=include_hands,
stick_width=stick_width_px,
marker_radius=marker_radius_px,
hand_stick_width=hand_stick_width_px,
hand_marker_radius=hand_marker_radius_px,
limb_alpha=limb_alpha,
hand_dot_color=hand_dot_color,
)
if include_face:
face_xy = None
real_face = person.get("pred_face_keypoints_2d")
if real_face is not None:
arr = np.asarray(real_face, dtype=np.float32)
if arr.ndim == 2 and arr.shape[1] == 2:
face_xy = arr
elif face_vert_ids is not None:
face_xy = _project_face_landmarks_2d(person, face_vert_ids, H, W)
if face_xy is not None:
_draw_face_dots(canvas, face_xy, face_marker_radius_px)
if pre is not None:
changed = (canvas != pre).any(axis=-1)
if changed.any():
touched = canvas[changed].astype(np.float32)
blended = touched * (1.0 - pastel) + 255.0 * pastel
canvas[changed] = np.clip(blended, 0.0, 255.0).astype(np.uint8)
return torch.from_numpy(canvas.astype(np.float32) / 255.0)

View File

@ -0,0 +1,516 @@
"""Face expression for SAM 3D Body.
Pipeline: comfy_extras.mediapipe.face_landmarker 52 ARKit blendshapes
72-dim MHR expr_params (mapping inlined below).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import comfy.model_management
# Bypass deadzone — jaw signals are clean (open or not)
_NOISE_FREE_BLENDSHAPES = {"jawOpen", "jawForward", "jawLeft", "jawRight"}
# Per-region gain — MP magnitudes vary by family (jaw up to 1.0, eye/brow
# rarely past 0.3), so a single global gain over/underdrives.
_REGION_PREFIXES = {
"mouth": ("jaw", "mouth"),
"eye": ("eye",),
"brow": ("brow", "cheek", "nose"), # cheek/nose read as upper-face
}
def _region_of(arkit_name: str) -> str:
for region, prefixes in _REGION_PREFIXES.items():
for p in prefixes:
if arkit_name.startswith(p):
return region
return "other"
# MHR axis → ARKit driver(s). Each axis collects 1-3 (name, weight) entries;
# the consumer takes max() across them so primary + aux contributions don't
# stack. MHR's 72 expression axes ship as anonymous `shape_c_N` channels in
# the upstream FBX (no semantic names), so this table is hand-derived by
# visual inspection of which axis each ARKit shape drives. Axes 2/3 and
# 12/13 are filled by aux routes only. ARKit shapes with no MHR analog are simply absent.
_AXIS_TO_ARKIT: Dict[int, List[Tuple[str, float]]] = {
0: [("browDownLeft", 1.0)],
1: [("browDownRight", 1.0)],
2: [("cheekPuff", 1.0)],
3: [("cheekPuff", 1.0)],
4: [("cheekSquintLeft", 1.0)],
5: [("cheekSquintRight", 1.0)],
6: [("mouthStretchLeft", 1.0)],
7: [("mouthStretchRight", 1.0)],
8: [("mouthShrugLower", 1.0)],
9: [("mouthShrugUpper", 1.0)],
10: [("mouthDimpleLeft", 1.0)],
11: [("mouthDimpleRight", 1.0)],
12: [("eyeLookDownLeft", 0.3)],
13: [("eyeLookDownRight", 0.3)],
14: [("eyeBlinkLeft", 1.0)],
15: [("eyeBlinkRight", 1.0)],
16: [("eyeLookOutLeft", 1.0)],
17: [("eyeLookInRight", 1.0)],
18: [("eyeLookInLeft", 1.0)],
19: [("eyeLookOutRight", 1.0)],
22: [("eyeLookUpLeft", 1.0), ("browInnerUp", 0.5)],
23: [("eyeLookUpRight", 1.0), ("browInnerUp", 0.5)],
24: [("jawOpen", 1.0), ("mouthLowerDownLeft", 0.3), ("mouthLowerDownRight", 0.3)],
25: [("jawLeft", 1.0)],
26: [("jawRight", 1.0)],
27: [("jawForward", 1.0)],
28: [("eyeSquintLeft", 1.0)],
29: [("eyeSquintRight", 1.0)],
32: [("mouthSmileLeft", 1.0)],
33: [("mouthSmileRight", 1.0)],
40: [("mouthLeft", 1.0)],
41: [("mouthRight", 1.0)],
42: [("mouthFrownLeft", 1.0)],
43: [("mouthFrownRight", 1.0)],
54: [("mouthLowerDownLeft", 1.0)],
55: [("mouthLowerDownRight", 1.0)],
60: [("noseSneerLeft", 1.0)],
61: [("noseSneerRight", 1.0)],
66: [("browOuterUpLeft", 1.0)],
67: [("browOuterUpRight", 1.0)],
68: [("eyeWideLeft", 1.0)],
69: [("eyeWideRight", 1.0)],
70: [("mouthUpperUpLeft", 1.0)],
71: [("mouthUpperUpRight", 1.0)],
}
def _deadzone(x: float, threshold: float) -> float:
"""Zero below threshold, linearly remap (threshold..1] → (0..1] so
amplification doesn't promote MP's per-blendshape noise floor."""
if threshold <= 0.0:
return x
if x <= threshold:
return 0.0
return (x - threshold) / (1.0 - threshold)
def arkit_to_expr_params(
blendshape_coefs: Dict[str, float],
strength: float = 1.0,
mouth_strength: float = 1.0,
eye_strength: float = 1.0,
brow_strength: float = 1.0,
input_threshold: float = 0.0,
n_axes: int = 72,
) -> np.ndarray:
"""Map MediaPipe's 52 ARKit blendshapes to MHR's 72 expr_params axes.
Multiple ARKit names per axis combine via max() so primary + aux routes
don't double up."""
expr = np.zeros(n_axes, dtype=np.float32)
region_scale = {
"mouth": float(mouth_strength), "eye": float(eye_strength),
"brow": float(brow_strength), "other": 1.0,
}
thr = float(input_threshold)
for axis, routes in _AXIS_TO_ARKIT.items():
best = 0.0
for name, weight in routes:
raw = float(blendshape_coefs.get(name, 0.0))
name_thr = 0.0 if name in _NOISE_FREE_BLENDSHAPES else thr
raw = _deadzone(raw, name_thr)
c = raw * region_scale[_region_of(name)] * float(weight)
if c > best:
best = c
expr[axis] = best * strength
return expr
def subtract_per_clip_baseline(
per_frame_coefs: List[Optional[Dict[str, float]]],
percentile: float = 5.0,
) -> List[Optional[Dict[str, float]]]:
"""Subtract per-blendshape p`percentile` baseline, clamp at 0. Adapts to
per-subject MP bias (e.g. resting browOuterUp ~0.15 permanent surprise
under brow_strength=2.0) that a global deadzone can't catch."""
if percentile <= 0.0:
return list(per_frame_coefs)
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
baselines: Dict[str, float] = {}
for n in names:
vals = [c[n] for c in per_frame_coefs if c is not None and n in c]
if vals:
baselines[n] = float(np.percentile(vals, percentile))
return [
None if c is None
else {n: max(0.0, float(v) - baselines.get(n, 0.0)) for n, v in c.items()}
for c in per_frame_coefs
]
def smooth_blendshape_series(
per_frame_coefs: List[Optional[Dict[str, float]]],
window: int = 7,
sigma: Optional[float] = None,
) -> List[Optional[Dict[str, float]]]:
"""Gaussian-smooth each coefficient across time. MP per-frame output swings
30-70% on static faces; smoothing pre-mapping cleans better than smoothing
mesh verts. None frames pass through unchanged."""
if window <= 1:
return list(per_frame_coefs)
if window % 2 == 0:
window += 1
if sigma is None:
sigma = max(1.0, window / 5.0)
x = np.arange(window) - (window - 1) / 2.0
k = np.exp(-(x ** 2) / (2 * sigma ** 2))
k = k / k.sum()
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
if not names:
return list(per_frame_coefs)
N = len(per_frame_coefs)
pad = window // 2
out: List[Optional[Dict[str, float]]] = [None] * N
for name in names:
series = np.zeros(N, dtype=np.float32)
mask = np.zeros(N, dtype=bool)
for i, c in enumerate(per_frame_coefs):
if c is not None:
series[i] = float(c.get(name, 0.0))
mask[i] = True
if not mask.any():
continue
if not mask.all():
idx = np.arange(N)
series = np.interp(idx, idx[mask], series[mask])
padded = np.concatenate(
[np.repeat(series[:1], pad), series, np.repeat(series[-1:], pad)]
)
filt = np.zeros_like(series)
for i, w in enumerate(k):
filt += w * padded[i: i + N]
for i in range(N):
if per_frame_coefs[i] is None:
continue
if out[i] is None:
out[i] = {}
out[i][name] = float(filt[i])
return out
def fill_detection_gaps(
per_frame_coefs: List[Optional[Dict[str, float]]],
method: str = "interpolate",
max_gap: int = 12,
) -> List[Optional[Dict[str, float]]]:
"""Fill missing per-frame dicts so the signal doesn't slam to zero at
undetected frames. method: 'interpolate' | 'hold' | 'zeros'.
`max_gap` applies to 'interpolate' and 'hold' gaps longer than that stay
None (don't fake too far). 'zeros' ignores `max_gap` on purpose: the goal
there is to relax to neutral on every miss, no matter how long, otherwise
long undetected runs would inherit Predict's per-frame expression."""
if method == "zeros":
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
zero = {n: 0.0 for n in names}
return [dict(zero) if c is None else c for c in per_frame_coefs]
N = len(per_frame_coefs)
detected = [i for i, c in enumerate(per_frame_coefs) if c is not None]
if not detected:
return list(per_frame_coefs)
out: List[Optional[Dict[str, float]]] = list(per_frame_coefs)
for fi in range(N):
if out[fi] is not None:
continue
prev_i = next((k for k in range(fi - 1, -1, -1) if per_frame_coefs[k] is not None), None)
next_i = next((k for k in range(fi + 1, N) if per_frame_coefs[k] is not None), None)
if prev_i is None and next_i is None:
continue
max_dist = max(
(fi - prev_i) if prev_i is not None else 10**9,
(next_i - fi) if next_i is not None else 10**9,
)
if max_dist > max_gap:
continue
if method == "hold":
src = per_frame_coefs[prev_i] if prev_i is not None else per_frame_coefs[next_i]
out[fi] = dict(src)
elif method == "interpolate":
if prev_i is None:
out[fi] = dict(per_frame_coefs[next_i])
elif next_i is None:
out[fi] = dict(per_frame_coefs[prev_i])
else:
w = (fi - prev_i) / (next_i - prev_i)
a = per_frame_coefs[prev_i]
b = per_frame_coefs[next_i]
keys = set(a.keys()) | set(b.keys())
out[fi] = {k: (1.0 - w) * a.get(k, 0.0) + w * b.get(k, 0.0) for k in keys}
return out
def detect_faces_in_crop(
inner, image_rgb_uint8: np.ndarray, crop_xyxy: np.ndarray, num_faces: int = 1,
) -> List[dict]:
"""Run detection on a sub-region; remap bbox+landmarks back to full-image
coords. Helps small/distant faces that fall below BlazeFace's min size."""
H, W = image_rgb_uint8.shape[:2]
x1, y1, x2, y2 = (int(round(float(v))) for v in crop_xyxy)
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(W, x2), min(H, y2)
if x2 - x1 < 16 or y2 - y1 < 16:
return []
crop = np.ascontiguousarray(image_rgb_uint8[y1:y2, x1:x2])
faces = inner.face_landmarker.detect_batch([crop], num_faces=num_faces)[0]
bbox_off = np.array([x1, y1, x1, y1], dtype=np.float32)
xy_off = np.array([x1, y1], dtype=np.float32)
for f in faces:
f["bbox_xyxy"] = f["bbox_xyxy"] + bbox_off
f["landmarks_xy"] = f["landmarks_xy"] + xy_off
return faces
# Crop helpers — feed MP a tight head region so it doesn't downsample the face
# to 192px for full-frame detection.
def _expand_bbox(bbox_xyxy: np.ndarray, factor: float, W: int, H: int) -> np.ndarray:
x1, y1, x2, y2 = (float(v) for v in bbox_xyxy)
cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
hw, hh = 0.5 * (x2 - x1) * factor, 0.5 * (y2 - y1) * factor
return np.array([
max(0.0, cx - hw), max(0.0, cy - hh),
min(float(W), cx + hw), min(float(H), cy + hh),
], dtype=np.float32)
def head_region_crop(
person_bbox: np.ndarray, expand: float, W: int, H: int, head_h_frac: float = 0.4,
) -> np.ndarray:
"""Crop upper `head_h_frac` of a body bbox — cropping the whole body wastes
BlazeFace's 128² input on body pixels."""
x1, y1, x2, y2 = (float(v) for v in person_bbox)
body_h = y2 - y1
if body_h <= 0 or x2 - x1 <= 0:
return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
return _expand_bbox(np.array([x1, y1, x2, y1 + body_h * head_h_frac]), expand, W, H)
# mhr70 convention: first five kp are COCO-style face landmarks in pixel coords.
_FACE_KP_INDICES = (0, 1, 2, 3, 4) # nose, L-eye, R-eye, L-ear, R-ear
def head_crop_from_keypoints(
pred_keypoints_2d: np.ndarray, expand: float, W: int, H: int,
) -> Optional[np.ndarray]:
"""Head crop from SAM3D nose/eyes/ears kp. HEAD_FIT pads forehead/chin
since these only span the central face. None if <2 kp in-frame."""
if pred_keypoints_2d is None:
return None
kp = np.asarray(pred_keypoints_2d, dtype=np.float32)
if kp.ndim != 2 or kp.shape[0] <= max(_FACE_KP_INDICES):
return None
face = kp[list(_FACE_KP_INDICES), :2]
in_frame = (face[:, 0] > 0) & (face[:, 1] > 0) & (face[:, 0] < W) & (face[:, 1] < H)
valid = face[in_frame]
if len(valid) < 2:
return None
x1, x2 = float(valid[:, 0].min()), float(valid[:, 0].max())
y1, y2 = float(valid[:, 1].min()), float(valid[:, 1].max())
cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
span = max(x2 - x1, y2 - y1, 1.0)
half = 0.5 * span * 1.8 * float(expand) # 1.8 = pad forehead+chin
return np.array([
max(0.0, cx - half), max(0.0, cy - half),
min(float(W), cx + half), min(float(H), cy + half),
], dtype=np.float32)
# Face → person assignment when running full-frame detection.
def _iou_xyxy(a: np.ndarray, b: np.ndarray) -> float:
ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
inter = iw * ih
if inter <= 0.0:
return 0.0
aw, ah = max(0.0, a[2] - a[0]), max(0.0, a[3] - a[1])
bw, bh = max(0.0, b[2] - b[0]), max(0.0, b[3] - b[1])
union = aw * ah + bw * bh - inter
return float(inter / union) if union > 0.0 else 0.0
def assign_faces_to_persons(
face_bboxes: List[np.ndarray], person_bboxes: List[np.ndarray], min_iou: float = 0.01,
) -> List[Optional[int]]:
if not face_bboxes or not person_bboxes:
return [None] * len(person_bboxes)
assigned: List[Optional[int]] = [None] * len(person_bboxes)
used: set = set()
# Larger persons first — bigger bbox correlates with detectable face.
order = sorted(range(len(person_bboxes)),
key=lambda p: -((person_bboxes[p][2] - person_bboxes[p][0])
* (person_bboxes[p][3] - person_bboxes[p][1])))
for pi in order:
best_iou = min_iou
best_fi = None
pb = person_bboxes[pi]
for fi, fb in enumerate(face_bboxes):
if fi in used:
continue
cx, cy = 0.5 * (fb[0] + fb[2]), 0.5 * (fb[1] + fb[3])
inside = (pb[0] <= cx <= pb[2]) and (pb[1] <= cy <= pb[3])
score = max(_iou_xyxy(fb, pb), 0.5 if inside else 0.0)
if score > best_iou:
best_iou = score
best_fi = fi
if best_fi is not None:
assigned[pi] = best_fi
used.add(best_fi)
return assigned
# Re-run MHR forward after writing expr_params back into pose_frames; updates
# pred_vertices / pred_keypoints_2d/3d / pred_joint_coords / pred_global_rots.
def regenerate_mesh_from_params(inner, pose_frames: List[List[Dict[str, Any]]]) -> None:
"""Re-run MHR forward and write verts/kp3d/kp2d/joint back in place.
Drives MHR via euler params directly because hand refinement zeroes
pred_pose_raw."""
device = comfy.model_management.get_torch_device()
head = inner.head_pose
if head.mhr is None:
return
B = len(pose_frames)
max_p = max((len(f) for f in pose_frames), default=0)
for pid in range(max_p):
grots, bpps, hands, shapes, scales, exprs, cam_ts, fls = [], [], [], [], [], [], [], []
present: List[bool] = []
for fi in range(B):
if pid >= len(pose_frames[fi]):
present.append(False)
continue
p = pose_frames[fi][pid]
needed = ("global_rot", "body_pose_params", "hand_pose_params",
"shape_params", "scale_params", "expr_params",
"pred_cam_t", "focal_length")
if any(p.get(k) is None for k in needed):
present.append(False)
continue
grots.append(np.asarray(p["global_rot"], dtype=np.float32))
bpps.append(np.asarray(p["body_pose_params"], dtype=np.float32))
hands.append(np.asarray(p["hand_pose_params"], dtype=np.float32))
shapes.append(np.asarray(p["shape_params"], dtype=np.float32))
scales.append(np.asarray(p["scale_params"], dtype=np.float32))
exprs.append(np.asarray(p["expr_params"], dtype=np.float32))
cam_ts.append(np.asarray(p["pred_cam_t"], dtype=np.float32))
fls.append(float(np.asarray(p["focal_length"]).reshape(-1)[0]))
present.append(True)
if not any(present):
continue
global_rot_euler = torch.from_numpy(np.stack(grots)).to(device)
body_pose_euler = torch.from_numpy(np.stack(bpps)).to(device)
hand_t = torch.from_numpy(np.stack(hands)).to(device)
shape_t = torch.from_numpy(np.stack(shapes)).to(device)
scale_t = torch.from_numpy(np.stack(scales)).to(device)
expr_t = torch.from_numpy(np.stack(exprs)).to(device)
cam_t_t = torch.from_numpy(np.stack(cam_ts)).to(device)
f_t = torch.tensor(fls, device=device, dtype=torch.float32)
verts, kp3d_full, joint_coords, _, joint_rotmats = head.mhr_forward(
global_trans=torch.zeros_like(global_rot_euler),
global_rot=global_rot_euler,
body_pose_params=body_pose_euler,
hand_pose_params=hand_t,
scale_params=scale_t,
shape_params=shape_t,
expr_params=expr_t,
return_keypoints=True,
return_joint_coords=True,
return_model_params=True,
return_joint_rotations=True,
)
# y/z flip matches head_pose.forward (camera-y-down convention).
verts = verts.clone()
verts[..., [1, 2]] *= -1
kp3d = kp3d_full[:, :70].clone()
kp3d[..., [1, 2]] *= -1
# 238 sapiens face landmarks (70:308) — track retargeted expression
# so openpose face dots follow new mouth/eye/brow shape.
kp3d_face = kp3d_full[:, 70:].clone()
kp3d_face[..., [1, 2]] *= -1
joint_coords = joint_coords.clone()
joint_coords[..., [1, 2]] *= -1
# Recover principal point from any raw frame for reprojection.
cx = cy = 0.0
for fi in range(B):
if not present[fi]:
continue
raw = pose_frames[fi][pid]
kp2d_r = np.asarray(raw["pred_keypoints_2d"], dtype=np.float32)
kp3d_r = np.asarray(raw["pred_keypoints_3d"], dtype=np.float32)
ct_r = np.asarray(raw["pred_cam_t"], dtype=np.float32)
fl_r = float(np.asarray(raw["focal_length"]).reshape(-1)[0])
x, y, z = kp3d_r[0] + ct_r
cx = float(kp2d_r[0, 0] - fl_r * x / max(z, 1e-6))
cy = float(kp2d_r[0, 1] - fl_r * y / max(z, 1e-6))
break
def _project_kp(kp3d_local: torch.Tensor) -> torch.Tensor:
kp3d_cam = kp3d_local + cam_t_t.unsqueeze(1)
u = f_t[:, None] * kp3d_cam[..., 0] / kp3d_cam[..., 2].clamp(min=1e-6) + cx
v = f_t[:, None] * kp3d_cam[..., 1] / kp3d_cam[..., 2].clamp(min=1e-6) + cy
return torch.stack([u, v], dim=-1)
kp2d = _project_kp(kp3d)
kp2d_face = _project_kp(kp3d_face)
verts_np = verts.float().cpu().numpy()
kp3d_np = kp3d.float().cpu().numpy()
kp2d_np = kp2d.float().cpu().numpy()
kp3d_face_np = kp3d_face.float().cpu().numpy()
kp2d_face_np = kp2d_face.float().cpu().numpy()
jc_np = joint_coords.float().cpu().numpy()
jrot_np = joint_rotmats.float().cpu().numpy()
fi_active = 0
for fi in range(B):
if not present[fi]:
continue
pose_frames[fi][pid] = dict(pose_frames[fi][pid])
p = pose_frames[fi][pid]
p["pred_vertices"] = verts_np[fi_active]
p["pred_keypoints_3d"] = kp3d_np[fi_active]
p["pred_keypoints_2d"] = kp2d_np[fi_active]
p["pred_face_keypoints_3d"] = kp3d_face_np[fi_active]
p["pred_face_keypoints_2d"] = kp2d_face_np[fi_active]
p["pred_joint_coords"] = jc_np[fi_active]
p["pred_global_rots"] = jrot_np[fi_active]
fi_active += 1

View File

@ -0,0 +1,467 @@
"""Pure-PyTorch rasterizer for SAM 3D Body meshes.
Algorithm: forward triangle rasterizer with hard z-buffer. Per-face screen
bbox cull faces sorted by bbox size and chunked under a fixed pixel
budget inside-test via edge functions, barycentric interpolation, depth
test via `scatter_reduce_(amin)`.
"""
from typing import Sequence
import numpy as np
import torch
import comfy.model_management
from .utils import jet_colormap
_CANONICAL_PRESETS = {"rainbow", "rainbow_face_normal", "rainbow_face_semantic"}
_rainbow_cache: dict = {}
def rainbow_colors_from_canonical(
positions: np.ndarray,
tilt_x_deg: float = 0.0,
tilt_z_deg: float = 0.0,
) -> np.ndarray:
"""Compute per-vertex jet-colormap RGB from canonical (T-pose, Y-up) vertices.
Args:
positions: (N_v, 3) canonical vertex positions, Y-up (head at max Y).
tilt_x_deg: rotation of the jet axis around X (in degrees). Positive
biases the ramp toward +Z (front).
tilt_z_deg: rotation of the jet axis around Z (in degrees). Positive
biases the ramp toward +X (right, in body frame).
Returns:
(N_v, 3) float32 RGB in [0, 1].
"""
key = (id(positions), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3))
cached = _rainbow_cache.get(key)
if cached is not None:
return cached
theta_x = np.deg2rad(tilt_x_deg)
theta_z = np.deg2rad(tilt_z_deg)
axis = np.array([
np.sin(theta_z),
np.cos(theta_z) * np.cos(theta_x),
np.cos(theta_z) * np.sin(theta_x),
], dtype=np.float32)
s = positions @ axis
s = (s - s.min()) / max(float(s.max() - s.min()), 1e-8)
s = np.clip(s * 0.98, 0.0, 1.0).astype(np.float32)
colors = jet_colormap(s)
_rainbow_cache[key] = colors
if len(_rainbow_cache) > 32:
_rainbow_cache.pop(next(iter(_rainbow_cache)))
return colors
def _vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
"""Area-weighted per-vertex normals; matches `_compute_vertex_normals`."""
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
vn = torch.zeros_like(verts)
vn.index_add_(0, faces[:, 0], fn)
vn.index_add_(0, faces[:, 1], fn)
vn.index_add_(0, faces[:, 2], fn)
return vn / vn.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def _build_vcolor(canonical_colors, shader_preset, tilt_x, tilt_z):
"""Mirrors the canonical_colors -> per-vertex RGB pipeline in
`rasterizer.render_pose_data`. Returns a numpy float32 (V, 3) table."""
positions = np.asarray(canonical_colors.get("positions"), dtype=np.float32)
vcolor = rainbow_colors_from_canonical(positions, tilt_x_deg=tilt_x, tilt_z_deg=tilt_z).copy()
if shader_preset in ("rainbow_face_normal", "rainbow_face_semantic"):
face_mask = canonical_colors.get("face_mask")
if face_mask is not None and face_mask.any():
if shader_preset == "rainbow_face_normal":
norm = np.asarray(canonical_colors["norm"], dtype=np.float32)
vcolor[face_mask] = norm[face_mask]
else: # rainbow_face_semantic
sem = np.asarray(canonical_colors["face_region_rgb"], dtype=np.float32)
assigned = sem.sum(axis=1) > 0
vcolor[assigned] = sem[assigned]
return vcolor
def _rasterize_chunk(
fv_pix: torch.Tensor, # (Fc, 3, 2) — pixel coords (sub-pixel float)
fv_z: torch.Tensor, # (Fc, 3) — image-frame z (smaller=closer)
bb_min_x: torch.Tensor, bb_max_x: torch.Tensor, # (Fc,) clamped int bboxes
bb_min_y: torch.Tensor, bb_max_y: torch.Tensor,
max_sx: int, max_sy: int,
W: int,
):
"""Rasterize a chunk of faces at pixel centers. Returns flat tensors of
inside fragments: (pixel_idx, depth, face_local, bary).
"""
device = fv_pix.device
if max_sx == 0 or max_sy == 0:
return None
sx = bb_max_x - bb_min_x
sy = bb_max_y - bb_min_y
px_off = torch.arange(max_sx, device=device)
py_off = torch.arange(max_sy, device=device)
# Pixel-center sample positions, broadcast to (Fc, max_sy, max_sx).
P_x = (bb_min_x[:, None, None] + px_off[None, None, :]).float() + 0.5
P_y = (bb_min_y[:, None, None] + py_off[None, :, None]).float() + 0.5
in_bb = (px_off[None, None, :] < sx[:, None, None]) & \
(py_off[None, :, None] < sy[:, None, None])
Ax = fv_pix[:, 0, 0][:, None, None]
Ay = fv_pix[:, 0, 1][:, None, None]
Bx = fv_pix[:, 1, 0][:, None, None]
By = fv_pix[:, 1, 1][:, None, None]
Cx = fv_pix[:, 2, 0][:, None, None]
Cy = fv_pix[:, 2, 1][:, None, None]
area2 = (Bx - Ax) * (Cy - Ay) - (By - Ay) * (Cx - Ax) # (Fc, 1, 1)
e_a = (Bx - P_x) * (Cy - P_y) - (By - P_y) * (Cx - P_x)
e_b = (Cx - P_x) * (Ay - P_y) - (Cy - P_y) * (Ax - P_x)
e_c = (Ax - P_x) * (By - P_y) - (Ay - P_y) * (Bx - P_x)
# Same-sign-as-area2 inside test (no back-face culling — match either winding).
nondegen = area2.abs() > 1e-6 # threshold rejects near-degenerate triangles
inside = (e_a * area2 >= 0) & (e_b * area2 >= 0) & (e_c * area2 >= 0)
inside = inside & in_bb & nondegen
if not inside.any():
return None
inv_a2 = torch.where(nondegen, 1.0 / area2, torch.zeros_like(area2))
w_a = e_a * inv_a2
w_b = e_b * inv_a2
w_c = e_c * inv_a2
z_a = fv_z[:, 0, None, None]
z_b = fv_z[:, 1, None, None]
z_c = fv_z[:, 2, None, None]
z_grid = w_a * z_a + w_b * z_b + w_c * z_c
fi, yi, xi = inside.nonzero(as_tuple=True)
px_pixel = bb_min_x[fi] + xi
py_pixel = bb_min_y[fi] + yi
pixel_idx = (py_pixel * W + px_pixel).long()
z_flat = z_grid[fi, yi, xi]
bary_flat = torch.stack([w_a[fi, yi, xi], w_b[fi, yi, xi], w_c[fi, yi, xi]], dim=-1)
return pixel_idx, z_flat, fi.long(), bary_flat
def _rasterize_person(
verts_world: torch.Tensor, faces: torch.Tensor,
focal: float, W: int, H: int,
z_buf: torch.Tensor, color_buf: torch.Tensor, mask_buf: torch.Tensor,
shade_fn,
):
# Project image-frame verts to pixel coords. Skip verts at/behind camera.
z_min_ok = 0.05
valid_v = verts_world[:, 2] > z_min_ok
safe_z = verts_world[:, 2].clamp(min=z_min_ok)
px = 0.5 * W + focal * verts_world[:, 0] / safe_z
py = 0.5 * H + focal * verts_world[:, 1] / safe_z
Fv_pix = torch.stack([px, py], dim=-1)[faces] # (F, 3, 2)
Fv_z = verts_world[faces][..., 2] # (F, 3)
Fv_valid = valid_v[faces].all(dim=-1)
sx_face = Fv_pix[..., 0]
sy_face = Fv_pix[..., 1]
bb_min_x = sx_face.amin(dim=-1).floor().long().clamp(min=0, max=W)
bb_max_x = (sx_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=W)
bb_min_y = sy_face.amin(dim=-1).floor().long().clamp(min=0, max=H)
bb_max_y = (sy_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=H)
sx_all = bb_max_x - bb_min_x
sy_all = bb_max_y - bb_min_y
valid_face = Fv_valid & (sx_all > 0) & (sy_all > 0)
keep = torch.where(valid_face)[0]
if keep.numel() == 0:
return
# Sort kept faces by max bbox dimension so chunks stay similarly-sized.
bbsize = torch.maximum(sx_all, sy_all)[keep]
order = torch.argsort(bbsize)
keep = keep[order]
n = keep.numel()
sx_cpu = sx_all[keep].tolist()
sy_cpu = sy_all[keep].tolist()
bbsize_cpu = bbsize[order].tolist()
PIXEL_BUDGET = 4_000_000
MAX_CHUNK = 8192
i = 0
while i < n:
e = min(i + MAX_CHUNK, n)
# Shrink chunk so worst-case per-face bbox stays within pixel budget.
while e > i + 1:
bb = bbsize_cpu[e - 1]
if (e - i) * bb * bb <= PIXEL_BUDGET:
break
e = max(i + 1, e - max(1, (e - i) // 4))
chunk = keep[i:e]
max_sx = max(sx_cpu[i:e])
max_sy = max(sy_cpu[i:e])
i = e
result = _rasterize_chunk(
Fv_pix[chunk], Fv_z[chunk],
bb_min_x[chunk], bb_max_x[chunk],
bb_min_y[chunk], bb_max_y[chunk],
max_sx, max_sy, W,
)
if result is None:
continue
pixel_idx, z_chunk, face_local, bary = result
face_global = chunk[face_local]
# Atomic depth test against z_buf.
old_at = z_buf[pixel_idx].clone()
z_buf.scatter_reduce_(0, pixel_idx, z_chunk, reduce='amin', include_self=True)
new_at = z_buf[pixel_idx]
is_min = (z_chunk == new_at) & (new_at < old_at)
if not is_min.any():
continue
# Multiple fragments can land on the same pixel and share the new min;
# stable-sort by pixel and keep the first of each run so shade_fn runs
# once per winning pixel. O(M) where M = surviving fragments
surv_pixel = pixel_idx[is_min]
surv_face = face_global[is_min]
surv_bary = bary[is_min]
sort_perm = torch.argsort(surv_pixel, stable=True)
sp = surv_pixel[sort_perm]
first = torch.ones_like(sp, dtype=torch.bool)
first[1:] = sp[1:] != sp[:-1]
selected = sort_perm[first]
wp_idx = surv_pixel[selected]
wp_face = surv_face[selected]
wp_bary = surv_bary[selected]
color_buf[wp_idx] = shade_fn(wp_face, wp_bary)
mask_buf[wp_idx] = True
def _make_shade_fn(
shader_preset, composite,
view_normals_v, view_pos_v, vcolor_v, faces,
base_color, light_dir, pastel_mix,
):
device = view_normals_v.device
base_color_t = torch.as_tensor(base_color, dtype=torch.float32, device=device)
light_dir_t = torch.as_tensor(light_dir, dtype=torch.float32, device=device)
# Light-vector constants — normalized once per render call.
l_unit = -light_dir_t
l_unit = l_unit / l_unit.norm().clamp(min=1e-8)
if pastel_mix <= 0.0:
apply_pastel = lambda rgb: rgb
else:
pm = float(pastel_mix)
apply_pastel = lambda rgb: rgb * (1.0 - pm) + pm
def gather_n(face_idx, bary):
n = (view_normals_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
return n / n.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def gather_pos(face_idx, bary):
return (view_pos_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
def gather_color(face_idx, bary):
return (vcolor_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
if composite == "silhouette":
return lambda fi, ba: torch.ones((fi.shape[0], 3), device=device)
if shader_preset == "normals":
# View-space surface normal encoded as RGB (OpenGL Y+ convention).
# +X right → R; +Y up → G; +Z toward viewer → B. Each face shows mostly
# one channel, matching standard normal-map visualization.
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
return apply_pastel(((n + 1.0) * 0.5).clamp(0.0, 1.0))
return shade
use_vcolor = (shader_preset in _CANONICAL_PRESETS) and (vcolor_v is not None)
if not use_vcolor:
# default.frag: ambient + diffuse + rim
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
v = -gather_pos(face_idx, bary)
v = v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
ndotv = (n * v).sum(dim=-1).clamp(min=0)
rim = (1.0 - ndotv).pow(3.0)
lit = 0.25 * base_color_t \
+ 0.75 * base_color_t * ndotl.unsqueeze(-1) \
+ 0.35 * rim.unsqueeze(-1)
return apply_pastel(lit)
return shade
if shader_preset == "rainbow":
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
return apply_pastel(base * (0.65 + 0.35 * ndotl).unsqueeze(-1))
return shade
# rainbow_face_* → rainbow_lit.frag. All light-direction & half-vector
# constants depend only on the (constant) light_dir, so precompute them.
key_l = l_unit
fill_l = torch.stack([-key_l[0], key_l[1].abs(), -key_l[2]])
view_dir = torch.tensor([0.0, 0.0, 1.0], device=device)
h = key_l + view_dir
h = h / h.norm().clamp(min=1e-8)
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
key_ndotl = (n * key_l).sum(dim=-1).clamp(min=0)
fill_ndotl = (n * fill_l).sum(dim=-1).clamp(min=0)
rim = (1.0 - n[..., 2].clamp(min=0)).pow(2.5) * 0.30
shade_val = (0.45 + 0.45 * key_ndotl + 0.15 * fill_ndotl + rim * 0.5).clamp(min=0.0, max=1.25)
ndoth = (n * h).sum(dim=-1).clamp(min=0)
spec = ndoth.pow(48) * 0.12
lit = base * shade_val.unsqueeze(-1) + spec.unsqueeze(-1)
return apply_pastel(lit)
return shade
def render_pose_data_torch(
pose_data: dict,
frame_idx: int,
W: int,
H: int,
background=None, # Optional[np.ndarray | torch.Tensor] (H, W, 3) fp32 [0, 1]
composite: str = "over",
opacity: float = 1.0,
shader_preset: str = "default",
base_color: Sequence[float] = (0.68, 0.71, 0.78),
light_dir: Sequence[float] = (0.4, -0.7, -0.6),
rainbow_tilt_x_deg: float = 0.0,
rainbow_tilt_z_deg: float = 0.0,
person_brightness_falloff: float = 0.6,
) -> torch.Tensor:
"""Render one frame of persons from `pose_data` at resolution WxH.
Returns an (H, W, 3) float32 torch.Tensor on the comfy compute device,
ready to be stacked into the node's IMAGE output without a CPU round-trip."""
device = comfy.model_management.get_torch_device()
persons = pose_data["frames"][frame_idx] if frame_idx < len(pose_data["frames"]) else []
if len(persons) == 0:
if composite == "over" and background is not None:
if isinstance(background, np.ndarray):
bg = torch.as_tensor(background, dtype=torch.float32, device=device)
else:
bg = background.to(device=device, dtype=torch.float32) if (
background.device != device or background.dtype != torch.float32
) else background
return bg.clamp(0.0, 1.0)
return torch.zeros((H, W, 3), device=device, dtype=torch.float32)
faces = torch.as_tensor(np.asarray(pose_data["faces"], dtype=np.int64), device=device)
canonical_colors = pose_data.get("canonical_colors")
using_canonical = shader_preset in _CANONICAL_PRESETS
if using_canonical and canonical_colors is None:
shader_preset = "default"
using_canonical = False
vcolor = None
if using_canonical:
vcolor_np = _build_vcolor(canonical_colors, shader_preset,
rainbow_tilt_x_deg, rainbow_tilt_z_deg)
vcolor = torch.as_tensor(vcolor_np, dtype=torch.float32, device=device)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
person_pastel = [0.0 if k == 0 else (1.0 - falloff ** k) for k in range(len(persons))]
# Front-to-back draw order so nearer persons overdraw farther ones.
order = sorted(range(len(persons)),
key=lambda i: -float(np.asarray(persons[i]["pred_cam_t"]).reshape(-1)[2]))
HW = H * W
z_buf = torch.full((HW,), float('inf'), device=device, dtype=torch.float32)
color_buf = torch.zeros((HW, 3), device=device, dtype=torch.float32)
mask_buf = torch.zeros(HW, device=device, dtype=torch.bool)
for idx in order:
p = persons[idx]
verts_np = np.asarray(p["pred_vertices"], dtype=np.float32).reshape(-1, 3)
cam_t = np.asarray(p["pred_cam_t"], dtype=np.float32).reshape(3)
verts_world = torch.as_tensor(verts_np + cam_t[None, :],
device=device, dtype=torch.float32)
focal = float(np.asarray(p.get("focal_length", 5000.0)).reshape(-1)[0])
# Image-frame (+Y down, +Z forward) → view-space (+Y up, -Z forward)
# for shading, matching what the GL-style shader math expects.
view_pos_v = torch.stack(
[verts_world[:, 0], -verts_world[:, 1], -verts_world[:, 2]], dim=-1,
)
normals_world = _vertex_normals(verts_world, faces)
view_normals_v = torch.stack(
[normals_world[:, 0], -normals_world[:, 1], -normals_world[:, 2]], dim=-1,
)
vcolor_p = vcolor if (vcolor is not None and vcolor.shape[0] == verts_world.shape[0]) else None
# Only canonical-vcolor shaders need vcolor; geometric shaders
# ('normals', 'depth') and the lit default work without it.
if shader_preset in _CANONICAL_PRESETS and vcolor_p is None:
effective_preset = "default"
else:
effective_preset = shader_preset
shade_fn = _make_shade_fn(
effective_preset, composite,
view_normals_v, view_pos_v, vcolor_p, faces,
base_color, light_dir, person_pastel[idx],
)
_rasterize_person(
verts_world, faces, focal, W, H,
z_buf, color_buf, mask_buf, shade_fn,
)
# Stay on GPU through readback + composite.
if shader_preset == "depth":
# z_buf already holds linear image-frame z (smaller=closer; +inf where no mesh covers)
# Normalize within the rendered mesh's range: near=white, far=black, background=black
mask_2d = mask_buf.reshape(H, W)
z_2d = z_buf.reshape(H, W)
if mask_2d.any():
zin = z_2d[mask_2d]
zmin = zin.min()
zr = (zin.max() - zmin).clamp(min=1e-6)
norm = torch.where(mask_2d, 1.0 - (z_2d - zmin) / zr, torch.zeros_like(z_2d))
else:
norm = torch.zeros((H, W), device=device, dtype=torch.float32)
rendered = torch.stack([norm, norm, norm], dim=-1)
mask_f = mask_2d.float()
else:
rendered = color_buf.reshape(H, W, 3).clamp(0.0, 1.0)
mask_f = mask_buf.reshape(H, W).float()
if composite == "over" and background is not None:
if isinstance(background, np.ndarray):
bg = torch.as_tensor(background, dtype=torch.float32, device=device)
else:
bg = background.to(device=device, dtype=torch.float32)
a = mask_f.unsqueeze(-1)
if opacity != 1.0:
a = a * float(opacity)
rendered = torch.lerp(bg, rendered, a)
return rendered

View File

@ -0,0 +1,397 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
from comfy.ldm.sam3d_body.utils import prepare_batch
from comfy.ldm.sam3.tracker import unpack_masks
from comfy.ldm.sam3d_body.model.model import SAM3DBody
import comfy.model_management
from comfy_api.latest import io
import comfy.utils
from tqdm import tqdm
def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
"""xyxy bounds of a binary mask, with sub-5px speckles filtered out."""
m = mask[..., 0] if mask.dim() == 3 else mask
m_bool = m > 0
if not m_bool.any():
return None
t = m_bool.to(torch.float32)[None, None]
eroded = -F.max_pool2d(-t, kernel_size=5, stride=1, padding=2)
keep = eroded[0, 0] > 0.5
if not keep.any():
keep = m_bool
ys, xs = torch.where(keep)
return torch.stack([
xs.min().float(), ys.min().float(),
(xs.max() + 1).float(), (ys.max() + 1).float(),
])
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
resolution. Returns (per_frame_bboxes, per_frame_masks) or
(None, None) when the track is empty / frame count doesn't match"""
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
if packed is None:
return None, None
unpacked = unpack_masks(packed) # (N, K, Hm, Wm) bool
N, K = unpacked.shape[:2]
if N != B or K == 0:
return None, None
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
resized = F.interpolate(
unpacked.float().reshape(N * K, 1, Hm, Wm),
size=(H, W), mode="bilinear", align_corners=False,
)
arr = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W).cpu()
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
per_frame_bboxes = []
for f in range(N):
derived = []
for k in range(K):
b = _bbox_from_mask(arr[f, k])
derived.append(b if b is not None else full_frame_bbox)
per_frame_bboxes.append(torch.stack(derived, dim=0))
return per_frame_bboxes, per_frame_masks
# Soft budget for the batched Predict path
BATCHED_CROPS_PER_CHUNK = 64
def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[torch.Tensor]:
"""(1,3,3) intrinsic matrix from a vertical FOV in degrees. Matches MoGe2's
convention (vertical focal for both axes). Returns None for fov<=0 so the
caller falls back to prepare_batch's diagonal-focal default."""
if fov_degrees <= 0:
return None
focal = height / (2.0 * math.tan(math.radians(fov_degrees) / 2.0))
return torch.tensor(
[[[focal, 0.0, width / 2.0],
[0.0, focal, height / 2.0],
[0.0, 0.0, 1.0]]],
dtype=torch.float32,
)
def cam_int_from_moge(moge_geometry, height: int, width: int) -> Optional[torch.Tensor]:
"""(1,3,3) intrinsic matrix from a MoGe geometry payload. Uses MoGe's
vertical focal for both axes; forces principal point to image center
(overrides MoGe's predicted cx/cy to match prepare_batch's convention)."""
if moge_geometry is None:
return None
# MOGE_GEOMETRY is a dict with optional keys (see comfy_extras/nodes_moge.py).
K_norm = moge_geometry.get("intrinsics") if isinstance(moge_geometry, dict) else None
if K_norm is None:
return None
if K_norm.ndim == 3:
K_norm = K_norm[0]
# MoGe stores fy in height-units (multiply by H to get pixels); vfov = fy.
fy_norm = float(K_norm[1, 1].item())
focal = fy_norm * height
return torch.tensor(
[[[focal, 0.0, width / 2.0],
[0.0, focal, height / 2.0],
[0.0, 0.0, 1.0]]],
dtype=torch.float32,
)
def run_batched_single_chunk(
inner: SAM3DBody,
frames_rgb: List[torch.Tensor],
per_frame_boxes: List[torch.Tensor],
per_frame_masks: Optional[List[torch.Tensor]],
image_size: Tuple[int, int],
inference_type: str,
K: int,
cam_int: Optional[torch.Tensor] = None,
) -> List[List[Dict[str, Any]]]:
"""Run a SINGLE chunk of frames through run_inference in one forward."""
N = len(frames_rgb)
total = N * K
# Reset stateful caches on the model
for attr in ("batch", "image_embeddings", "output"):
if hasattr(inner, attr):
setattr(inner, attr, None)
inner.prev_prompt = []
boxes_stacked = torch.stack(
[per_frame_boxes[f][k] for f in range(N) for k in range(K)], dim=0
)
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
if per_frame_masks is not None:
# Broadcast a single-mask bundle to per-bbox: when the user supplied one
# mask but multiple bboxes per frame, each bbox gets the same mask.
flat_masks = []
for f in range(N):
mf = per_frame_masks[f]
if mf.shape[0] == 1 and K > 1:
mf = mf.repeat_interleave(K, dim=0)
flat_masks.extend([mf[k] for k in range(K)])
masks_stacked = torch.stack(flat_masks, dim=0)
masks_score = torch.ones(total, dtype=torch.float32)
else:
masks_stacked = None
masks_score = None
batch = prepare_batch(
img_per_crop, boxes_stacked,
input_size=image_size,
masks=masks_stacked, masks_score=masks_score, cam_int=cam_int,
)
device = comfy.model_management.get_torch_device()
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
inner._initialize_batch(batch)
outputs = inner.run_inference(
img_per_crop,
batch,
inference_type=inference_type,
thresh_wrist_angle=1.4,
)
if inference_type == "full":
pose_output, batch_lhand, batch_rhand, _, _ = outputs
else:
pose_output = outputs
batch_lhand = batch_rhand = None
out = {k: v.cpu().numpy() for k, v in pose_output["mhr"].items()
if v is not None and k != "faces"}
# Snapshot batch['bbox'] to CPU before we release `batch` references
batch_bbox_cpu = batch["bbox"][0].cpu().numpy()
lhand_bboxes = rhand_bboxes = None
if inference_type == "full" and batch_lhand is not None and batch_rhand is not None:
lhand_bboxes = [_bbox_from_center_scale(batch_lhand, i) for i in range(total)]
rhand_bboxes = [_bbox_from_center_scale(batch_rhand, i) for i in range(total)]
del pose_output, batch, batch_lhand, batch_rhand, outputs
frames_out: List[List[Dict[str, Any]]] = []
for f in range(N):
persons: List[Dict[str, Any]] = []
for k in range(K):
idx = f * K + k
p: Dict[str, Any] = {
"bbox": batch_bbox_cpu[idx],
"focal_length": out["focal_length"][idx],
"pred_keypoints_3d": out["pred_keypoints_3d"][idx],
"pred_keypoints_2d": out["pred_keypoints_2d"][idx],
"pred_vertices": out["pred_vertices"][idx],
"pred_cam_t": out["pred_cam_t"][idx],
"pred_pose_raw": out["pred_pose_raw"][idx],
"global_rot": out["global_rot"][idx],
"body_pose_params": out["body_pose"][idx],
"hand_pose_params": out["hand"][idx],
"scale_params": out["scale"][idx],
"shape_params": out["shape"][idx],
"expr_params": out["face"][idx],
"mask": (per_frame_masks[f][k] if per_frame_masks[f].shape[0] > 1 else per_frame_masks[f][0])
if per_frame_masks is not None else None,
"pred_joint_coords": out["pred_joint_coords"][idx],
"pred_global_rots": out["joint_global_rots"][idx],
"mhr_model_params": out["mhr_model_params"][idx],
# 238 face landmarks from sapiens-308 (indices 70..308 of the pre-slice keypoint tensor).
"pred_face_keypoints_3d": out["pred_face_keypoints_3d"][idx] if "pred_face_keypoints_3d" in out else None,
"pred_face_keypoints_2d": out["pred_face_keypoints_2d"][idx] if "pred_face_keypoints_2d" in out else None,
}
if lhand_bboxes is not None:
p["lhand_bbox"] = lhand_bboxes[idx]
p["rhand_bbox"] = rhand_bboxes[idx]
persons.append(p)
frames_out.append(persons)
return frames_out
def run_batched_frames(
inner: SAM3DBody,
frames_rgb: List[torch.Tensor],
per_frame_boxes: List[torch.Tensor],
per_frame_masks: Optional[List[torch.Tensor]],
image_size: Tuple[int, int],
inference_type: str,
cam_int: Optional[torch.Tensor] = None,
pbar: Optional[comfy.utils.ProgressBar] = None,
crops_per_chunk: int = BATCHED_CROPS_PER_CHUNK,
) -> List[List[Dict[str, Any]]]:
"""Run the clip through chunked batched run_inference calls.
Supports K persons per frame (K must be the same across frames padded
externally). Splits frames into chunks so chunk_frames * K <= budget; each
chunk is one body forward + optional hand forwards over its person-crops
"""
N = len(frames_rgb)
assert N > 0, "empty frame list"
K_set = {len(b) for b in per_frame_boxes}
assert len(K_set) == 1, f"batched path requires same bbox count per frame, got {K_set}"
K = K_set.pop()
assert K >= 1, "need at least one bbox per frame"
chunk_frames = max(1, crops_per_chunk // K)
results: List[List[Dict[str, Any]]] = []
with tqdm(total=N, desc="SAM3D body inference") as t:
for start in range(0, N, chunk_frames):
end = min(N, start + chunk_frames)
sub_frames = frames_rgb[start:end]
sub_boxes = per_frame_boxes[start:end]
sub_masks = None if per_frame_masks is None else per_frame_masks[start:end]
chunk_result = run_batched_single_chunk(
inner, sub_frames, sub_boxes, sub_masks,
image_size, inference_type, K,
cam_int=cam_int,
)
results.extend(chunk_result)
t.update(end - start)
if pbar is not None:
pbar.update(end - start)
# Drop GPU caches so the next chunk starts from a clean allocator state
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def _bbox_from_center_scale(batch, idx: int) -> np.ndarray:
cx = batch["bbox_center"].flatten(0, 1)[idx][0].item()
cy = batch["bbox_center"].flatten(0, 1)[idx][1].item()
sx = batch["bbox_scale"].flatten(0, 1)[idx][0].item()
sy = batch["bbox_scale"].flatten(0, 1)[idx][1].item()
return np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2], dtype=np.float32)
# Wire types and small helpers shared across the SAM 3D Body node modules.
def image_to_uint8(image: torch.Tensor) -> torch.Tensor:
"""ComfyUI image tensor (any shape, float 0..1) → uint8 tensor in [0, 255] on CPU."""
return (image * 255.0).clamp(0.0, 255.0).to(dtype=torch.uint8, device="cpu")
def compute_canonical_colors(model) -> Dict[str, np.ndarray]:
"""Canonical rest-pose data for shader color lookups: positions (Nv,3),
norm (Nv,3 in [0,1]), face_mask, head_mask, and face_region_rgb
(per-region painted color from the .safetensors)."""
verts = model.head_pose.canonical_vertices().float().cpu().numpy()
faces = model.head_pose.faces.cpu().numpy()
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
fn = np.cross(v1 - v0, v2 - v0).astype(np.float32)
vn = np.zeros_like(verts, dtype=np.float32)
np.add.at(vn, faces[:, 0], fn)
np.add.at(vn, faces[:, 1], fn)
np.add.at(vn, faces[:, 2], fn)
ln = np.linalg.norm(vn, axis=1, keepdims=True)
ln[ln < 1e-8] = 1.0
vn = vn / ln
norm_map = ((vn + 1.0) * 0.5).astype(np.float32)
face_mask = _compute_face_mask(model)
# Head: above jaw-neck (y>1.43) and narrower than shoulders (|x|<0.11).
# Ears reach |x|≈0.09; shoulders start at |x|≈0.20.
head_mask = (verts[:, 1] > 1.43) & (np.abs(verts[:, 0]) < 0.11)
# Painted per-vertex face region RGB ships in the model .safetensors as
# `head_pose.face_region_rgb` and gets loaded by load_state_dict.
face_region_rgb = model.head_pose.face_region_rgb.detach().float().cpu().numpy()
return {
"positions": verts.astype(np.float32),
"norm": norm_map,
"face_mask": face_mask,
"head_mask": head_mask,
"face_region_rgb": face_region_rgb,
}
def compute_hand_vert_mask(model, hand_radius_m: float = 0.15, weight_threshold: float = 0.5) -> np.ndarray:
"""(Nv,) bool mask of hand-region verts. Picks joints within `hand_radius_m`
of the mhr70 hand keypoint clusters (indices 21..62), then sums sparse LBS
weights across them; verts above `weight_threshold` are hand verts."""
head = model.head_pose
mhr = head.mhr
device = head.scale_mean.device
zeros = lambda *s: torch.zeros(1, *s, device=device)
out = head.mhr_forward(
global_trans=zeros(3),
global_rot=zeros(3),
body_pose_params=zeros(130),
hand_pose_params=zeros(head.num_hand_comps * 2),
scale_params=zeros(head.num_scale_comps),
shape_params=zeros(head.num_shape_comps),
expr_params=zeros(head.num_face_comps),
return_keypoints=True,
return_joint_coords=True,
)
# Output order with these flags: (verts, kp, jcoords). See mhr_head.mhr_forward.
_, kp, jcoords = out[0], out[1], out[2]
kp = kp[0, :70].cpu().numpy()
jcoords = jcoords[0].cpu().numpy()
right_center = kp[21:42].mean(axis=0)
left_center = kp[42:63].mean(axis=0)
j_dist_r = np.linalg.norm(jcoords - right_center, axis=1)
j_dist_l = np.linalg.norm(jcoords - left_center, axis=1)
is_hand_joint = (j_dist_r < hand_radius_m) | (j_dist_l < hand_radius_m)
lbs_w = mhr.lbs_skin_weights.cpu().numpy()
lbs_v = mhr.lbs_vert_indices.cpu().numpy()
lbs_j = mhr.lbs_skin_indices.cpu().numpy()
is_hand_joint_f = is_hand_joint.astype(np.float32)
n_verts = mhr.NUM_VERTS
hand_mass = np.zeros(n_verts, dtype=np.float32)
np.add.at(hand_mass, lbs_v, lbs_w * is_hand_joint_f[lbs_j])
return hand_mass >= weight_threshold
def _compute_face_mask(model, disp_threshold_m: float = 1e-4) -> np.ndarray:
"""(Nv,) bool mask of verts that move with face expression. Sweeps each of
the 72 expression axes at coef=+1.0 and flags any vert that moves more
than `disp_threshold_m` for at least one axis."""
head = model.head_pose
device = head.scale_mean.device
num_face = head.num_face_comps
zeros = lambda *s: torch.zeros(1, *s, device=device)
neutral_kw = dict(
global_trans=zeros(3),
global_rot=zeros(3),
body_pose_params=zeros(130),
hand_pose_params=zeros(head.num_hand_comps * 2),
scale_params=zeros(head.num_scale_comps),
shape_params=zeros(head.num_shape_comps),
expr_params=zeros(num_face),
)
v0 = head.mhr_forward(**neutral_kw).cpu().numpy()[0] # (Nv, 3)
face_mask = np.zeros(v0.shape[0], dtype=bool)
for axis in range(num_face):
expr = zeros(num_face)
expr[0, axis] = 1.0
kw = dict(neutral_kw)
kw["expr_params"] = expr
v = head.mhr_forward(**kw).cpu().numpy()[0]
face_mask |= (np.linalg.norm(v - v0, axis=1) > disp_threshold_m)
return face_mask
def jet_colormap(s: np.ndarray) -> np.ndarray:
"""matplotlib jet, (N,) in [0,1] -> (N, 3) float32 RGB."""
s = np.asarray(s, dtype=np.float32).clip(0.0, 1.0)
r = np.interp(s, [0.0, 0.35, 0.66, 0.89, 1.0], [0.0, 0.0, 1.0, 1.0, 0.5])
g = np.interp(s, [0.0, 0.125, 0.375, 0.64, 0.91, 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
b = np.interp(s, [0.0, 0.11, 0.34, 0.65, 1.0], [0.5, 1.0, 1.0, 0.0, 0.0])
return np.stack([r, g, b], axis=-1).astype(np.float32)

View File

@ -2445,6 +2445,7 @@ async def init_builtin_extra_nodes():
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_sam3d_body.py",
]
import_failed = []