This commit is contained in:
Jukka Seppänen 2026-07-02 17:56:30 -07:00 committed by GitHub
commit 26d2fc8217
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 9929 additions and 321 deletions

View File

@ -156,10 +156,12 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations, use_mask_token=True):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
# mask_token is a pre-training param, omit it when the checkpoint does not ship it so strict loading stays clean
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) if use_mask_token else None
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
@ -212,7 +214,7 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
def __init__(self, config, dtype, device, operations, use_mask_token=True):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
@ -228,7 +230,7 @@ class DINOv3ViTModel(nn.Module):
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
dtype=dtype, device=device, operations=operations, use_mask_token=use_mask_token
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
@ -240,6 +242,10 @@ class DINOv3ViTModel(nn.Module):
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
self.patch_size = patch_size
self.embed_dim = self.embed_dims = hidden_size
self.num_prefix_tokens = 1 + num_register_tokens # cls + register
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
@ -257,3 +263,12 @@ class DINOv3ViTModel(nn.Module):
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None
def forward_features(self, pixel_values, **kwargs):
"""Dense (B, C, H, W) patch-feature grid, CLS + register tokens dropped."""
sequence_output = self.forward(pixel_values, **kwargs)[0]
b = pixel_values.shape[0]
h = pixel_values.shape[-2] // self.patch_size
w = pixel_values.shape[-1] // self.patch_size
patches = sequence_output[:, self.num_prefix_tokens:, :]
return patches.reshape(b, h, w, self.embed_dim).permute(0, 3, 1, 2).contiguous()

View File

@ -9,8 +9,7 @@ from torchvision.ops import roi_align
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
from comfy.ldm.sam3.sam import SAM3VisionBackbone, MLP, PositionEmbeddingSine
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
from comfy.ops import cast_to_input

View File

@ -0,0 +1,333 @@
from typing import Optional
import torch
import torch.nn as nn
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, mhr_param_hand_mask
from ..model.transformer import MLP
class MHRHead(nn.Module):
def __init__(self, input_dim: int, mhr_rig, mlp_depth: int = 1, mlp_channel_div_factor: int = 8, enable_hand_model=False,
device=None, dtype=None, operations=None):
super().__init__()
# Store the shared MHRRig as a non-registered Python attribute
object.__setattr__(self, "mhr", mhr_rig)
self.num_shape_comps = 45
self.num_scale_comps = 28
self.num_hand_comps = 54
self.num_face_comps = 72
self.enable_hand_model = enable_hand_model
self.body_cont_dim = 260
self.npose = (
6 # Global Rotation
+ self.body_cont_dim # then body
+ self.num_shape_comps
+ self.num_scale_comps
+ self.num_hand_comps * 2
+ self.num_face_comps
)
self.proj = MLP(
input_dim=input_dim,
hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.npose,
num_layers=mlp_depth,
device=device, dtype=dtype, operations=operations,
)
# MHR Parameters
self.num_hand_scale_comps = self.num_scale_comps - 18
self.num_hand_pose_comps = self.num_hand_comps
# Buffers populated by load_state_dict from the safetensors
def _p(*shape, dtype=torch.float32):
return nn.Parameter(torch.empty(*shape, dtype=dtype), requires_grad=False)
self.joint_rotation = _p(127, 3, 3)
self.scale_mean = _p(68)
self.scale_comps = _p(28, 68)
self.register_buffer("faces", torch.empty(36874, 3, dtype=torch.int64))
self.hand_pose_mean = _p(54)
self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
self.register_buffer("hand_joint_idxs_left", torch.empty(27, dtype=torch.int64))
self.register_buffer("hand_joint_idxs_right", torch.empty(27, dtype=torch.int64))
self.keypoint_mapping = _p(308, 18439 + 127)
# Some special buffers for the hand-version
self.right_wrist_coords = _p(3)
self.root_coords = _p(3)
self.local_to_world_wrist = _p(3, 3)
self.register_buffer("nonhand_param_idxs", torch.empty(145, dtype=torch.int64))
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
def canonical_vertices(self):
"""Return the T-pose vertices for the mean shape (scaled to meters).
Runs MHR with zero pose / shape / scale / expression so the returned
mesh is the canonical rest pose fixed per-model
"""
device = self.scale_mean.device
dtype = self.scale_mean.dtype
B = 1
global_trans = torch.zeros(B, 3, device=device, dtype=dtype)
global_rot = torch.zeros(B, 3, device=device, dtype=dtype)
body_pose = torch.zeros(B, 130, device=device, dtype=dtype)
hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=device, dtype=dtype)
scale = torch.zeros(B, self.num_scale_comps, device=device, dtype=dtype)
shape = torch.zeros(B, self.num_shape_comps, device=device, dtype=dtype)
expr = torch.zeros(B, self.num_face_comps, device=device, dtype=dtype)
verts = self.mhr_forward(
global_trans=global_trans,
global_rot=global_rot,
body_pose_params=body_pose,
hand_pose_params=hand_pose,
scale_params=scale,
shape_params=shape,
expr_params=expr,
) # single-tensor shape (1, N_v, 3) in meters
return verts[0]
def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
assert full_pose_params.shape[1] == 136
# This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params.
# Split into left and right hands
left_hand_params, right_hand_params = torch.split(
hand_pose_params,
[self.num_hand_pose_comps, self.num_hand_pose_comps],
dim=1,
)
# Change from cont to model params
left_hand_params_model_params = compact_cont_to_model_params_hand(
self.hand_pose_mean
+ torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
)
right_hand_params_model_params = compact_cont_to_model_params_hand(
self.hand_pose_mean
+ torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
)
# Drop it in
full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params
full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params
return full_pose_params # B x 207
def mhr_forward(
self,
global_trans,
global_rot,
body_pose_params,
hand_pose_params,
scale_params,
shape_params,
expr_params=None,
return_keypoints=False,
return_joint_coords=False,
return_model_params=False,
return_joint_rotations=False,
):
# Align everything to the static buffers
dt = self.scale_mean.dtype
global_trans = global_trans.to(dt)
global_rot = global_rot.to(dt)
body_pose_params = body_pose_params.to(dt)
if hand_pose_params is not None:
hand_pose_params = hand_pose_params.to(dt)
scale_params = scale_params.to(dt)
shape_params = shape_params.to(dt)
if expr_params is not None:
expr_params = expr_params.to(dt)
if self.enable_hand_model:
# Transfer wrist-centric predictions to the body.
global_rot_ori = global_rot.clone()
global_trans_ori = global_trans.clone()
global_rot = rotmat_to_euler(
"xyz",
euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
)
global_trans = (
-(
euler_to_rotmat("xyz", global_rot)
@ (self.right_wrist_coords - self.root_coords)
+ self.root_coords
)
+ global_trans_ori
)
body_pose_params = body_pose_params[..., :130]
# Convert from scale and shape params to actual scales and vertices
# Add singleton batches in case...
if len(scale_params.shape) == 1:
scale_params = scale_params[None]
if len(shape_params.shape) == 1:
shape_params = shape_params[None]
# Convert scale...
scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
# Now, figure out the pose.
## 10 here is because it's more stable to optimize global translation in meters.
full_pose_params = torch.cat([global_trans * 10, global_rot, body_pose_params], dim=1) # B x 127
## Put in hands
if hand_pose_params is not None:
full_pose_params = self.replace_hands_in_pose(
full_pose_params, hand_pose_params
)
model_params = torch.cat([full_pose_params, scales], dim=1)
if self.enable_hand_model:
# Zero out non-hand parameters
model_params[:, self.nonhand_param_idxs] = 0
curr_skinned_verts, curr_skel_state = self.mhr(
shape_params, model_params, expr_params
)
curr_joint_coords, curr_joint_quats, _ = torch.split(
curr_skel_state, [3, 4, 1], dim=2
)
curr_skinned_verts = curr_skinned_verts / 100
curr_joint_coords = curr_joint_coords / 100
curr_joint_rots = unitquat_to_rotmat(curr_joint_quats)
# Prepare returns
to_return = [curr_skinned_verts]
if return_keypoints:
# Get sapiens 308 keypoints
model_vert_joints = torch.cat(
[curr_skinned_verts, curr_joint_coords], dim=1
) # B x (num_verts + 127) x 3
kp_map = self.keypoint_mapping.to(model_vert_joints.dtype)
model_keypoints_pred = (
(kp_map @ model_vert_joints.permute(1, 0, 2).flatten(1, 2))
.reshape(-1, model_vert_joints.shape[0], 3)
.permute(1, 0, 2)
)
if self.enable_hand_model:
# Zero out everything except for the right hand
model_keypoints_pred[:, :21] = 0
model_keypoints_pred[:, 42:] = 0
to_return = to_return + [model_keypoints_pred]
if return_joint_coords:
to_return = to_return + [curr_joint_coords]
if return_model_params:
to_return = to_return + [model_params]
if return_joint_rotations:
to_return = to_return + [curr_joint_rots]
if isinstance(to_return, list) and len(to_return) == 1:
return to_return[0]
else:
return tuple(to_return)
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None, intermediate: bool = False):
"""
Args:
x: pose token with shape [B, C], usually C=DECODER.DIM
init_estimate: [B, self.npose]
intermediate: when True, the caller only needs the keypoints/pose
outputs needed by the per-layer keypoint-token update path
vertex output is suppressed so `camera_project` skips the
18439-vertex perspective projection on intermediate decoder
layers. The final layer must call with intermediate=False.
"""
batch_size = x.shape[0]
pred = self.proj(x)
if init_estimate is not None:
pred = pred + init_estimate
# From pred, we want to pull out individual predictions.
## First, get globals
### Global rotation is first 6.
count = 6
global_rot_6d = pred[:, :count]
global_rot_rotmat = rot6d_to_rotmat(global_rot_6d) # B x 3 x 3
global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat) # B x 3
global_trans = torch.zeros_like(global_rot_euler)
## Next, get body pose.
### Hold onto raw, continuous version for iterative correction.
pred_pose_cont = pred[:, count : count + self.body_cont_dim]
count += self.body_cont_dim
### Convert to eulers (and trans)
pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
### Zero-out hands
pred_pose_euler[:, mhr_param_hand_mask] = 0
### Zero-out jaw
pred_pose_euler[:, -3:] = 0
## Get remaining parameters
pred_shape = pred[:, count : count + self.num_shape_comps]
count += self.num_shape_comps
pred_scale = pred[:, count : count + self.num_scale_comps]
count += self.num_scale_comps
pred_hand = pred[:, count : count + self.num_hand_comps * 2]
count += self.num_hand_comps * 2
pred_face = pred[:, count : count + self.num_face_comps] * 0
count += self.num_face_comps
# Run everything through mhr
output = self.mhr_forward(
global_trans=global_trans,
global_rot=global_rot_euler,
body_pose_params=pred_pose_euler,
hand_pose_params=pred_hand,
scale_params=pred_scale,
shape_params=pred_shape,
expr_params=pred_face,
return_keypoints=True,
return_joint_coords=True,
return_model_params=True,
return_joint_rotations=True,
)
# Some existing code to get joints and fix camera system
verts, j3d, jcoords, mhr_model_params, joint_global_rots = output
j3d = j3d[:, :70] # 308 --> 70 keypoints
# Intermediate decoder layers only consume pred_keypoints_3d via the
# keypoint-token update path; suppress verts so camera_project skips
# the 18439-vertex perspective projection.
if intermediate:
verts = None
if verts is not None:
verts[..., [1, 2]] *= -1 # Camera system difference
j3d[..., [1, 2]] *= -1 # Camera system difference
if jcoords is not None:
jcoords[..., [1, 2]] *= -1
# Head-MLP outputs are promoted to fp32 here so the external
# pose_output["mhr"] contract has a stable dtype regardless of what
# the head ran at (fp16/bf16 for speed). MHR-derived outputs are
# already fp32 from MHR's math layers.
output = {
"pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
"pred_pose_rotmat": None,
"global_rot": global_rot_euler.float(),
"body_pose": pred_pose_euler.float(),
"shape": pred_shape.float(),
"scale": pred_scale.float(),
"hand": pred_hand.float(),
"face": pred_face.float(),
"pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
"pred_vertices": verts.reshape(batch_size, -1, 3) if verts is not None else None,
"pred_joint_coords": jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None,
"faces": self.faces.cpu().numpy(),
"joint_global_rots": joint_global_rots,
"mhr_model_params": mhr_model_params,
}
return output

View File

@ -0,0 +1,233 @@
# Adapted from facebookresearch/MHR (Apache 2.0):
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
# verbatim from the upstream mhr_model.pt
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning.
from .mhr_utils import batch6DFromXYZ
_LN2 = 0.6931471824645996
def _euler_xyz_to_quat(angles):
"""(roll, pitch, yaw) -> quaternion (x, y, z, w). Matches pymomentum.quaternion.euler_xyz_to_quaternion."""
roll, pitch, yaw = angles.unbind(-1)
cy, sy = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5)
cp, sp = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5)
cr, sr = torch.cos(roll * 0.5), torch.sin(roll * 0.5)
x = sr * cp * cy - cr * sp * sy
y = cr * sp * cy + sr * cp * sy
z = cr * cp * sy - sr * sp * cy
w = cr * cp * cy + sr * sp * sy
return torch.stack([x, y, z, w], dim=-1)
def _quat_multiply(q1, q2):
x1, y1, z1, w1 = q1.unbind(-1)
x2, y2, z2, w2 = q2.unbind(-1)
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
return torch.stack([x, y, z, w], dim=-1)
def _quat_rotate(q, v):
"""Rotate v by unit quaternion q (xyzw). v + 2 * (axis x v * w + axis x (axis x v))."""
axis = q[..., :3]
r = q[..., 3:4]
av = torch.cross(axis, v, dim=-1)
aav = torch.cross(axis, av, dim=-1)
return v + 2.0 * (av * r + aav)
def _skel_multiply(s1, s2):
"""Compose two skel states (..., 8). Returns parent ∘ child.
Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
before composition. With many FK levels the previously-normalized quats
drift in ULPs; upstream renormalizes defensively, so we do too to stay
bit-close to its outputs.
"""
t1, sc1 = s1[..., :3], s1[..., 7:8]
t2, sc2 = s2[..., :3], s2[..., 7:8]
q1 = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12)
q2 = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12)
t_res = t1 + sc1 * _quat_rotate(q1, t2)
q_res = _quat_multiply(q1, q2)
s_res = sc1 * sc2
return torch.cat([t_res, q_res, s_res], dim=-1)
def _skel_transform_points(skel_state, points):
"""Apply skel_state (..., 8) to points (..., 3): t + q * (s * points).
Assumes the quaternion in skel_state is already unit-norm. Callers that
can't guarantee that should normalize first.
"""
t = skel_state[..., :3]
q = skel_state[..., 3:7]
s = skel_state[..., 7:8]
return t + _quat_rotate(q, s * points)
def _global_skel_state_from_local(local, pmi_levels):
"""FK walk in fp64 (matches upstream's use_double_precision=True path).
`pmi_levels` is a precomputed list of (source_idx, target_idx) tensor pairs,
one per BFS level. Avoids per-call torch.split + tolist() sync.
"""
orig_dtype = local.dtype
g = local.to(torch.float64).clone()
for source, target in pmi_levels:
parent = g.index_select(-2, target)
child = g.index_select(-2, source)
g.index_copy_(-2, source, _skel_multiply(parent, child))
return g.to(orig_dtype)
class MHRRig(nn.Module):
"""Plain-PyTorch reimplementation of Meta's MHR rig.
All math runs in fp32 (FK upcast to fp64 internally, matching upstream's
use_double_precision=True backend) regardless of the host model's dtype.
"""
NUM_VERTS = 18439
NUM_JOINTS = 127
NUM_LBS_TRIPLETS = 51337
NUM_IDENTITY = 45
NUM_EXPR = 72
PARAM_TRANSFORM_IN = 249 # = model_parameters(204) + identity_coeffs(45)
PARAM_TRANSFORM_OUT = 889 # = NUM_JOINTS * 7
POSE_CORR_IN = 750 # = (NUM_JOINTS - 2) * 6
POSE_CORR_HIDDEN = 3000
POSE_CORR_SPARSE_NNZ = 53136
def __init__(self, device=None):
super().__init__()
# All buffers are populated by load_state_dict from the `mhr.*` keys
def _p(*shape, dtype=torch.float32):
return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)
def _b(name, *shape, dtype):
self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))
self.base_shape = _p(self.NUM_VERTS, 3)
self.identity_basis = _p(self.NUM_IDENTITY, self.NUM_VERTS, 3)
self.expr_basis = _p(self.NUM_EXPR, self.NUM_VERTS, 3)
self.param_transform = _p(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN)
self.skel_joint_translation_offsets = _p(self.NUM_JOINTS, 3)
self.skel_joint_prerotations = _p(self.NUM_JOINTS, 4)
_b("skel_joint_parents", self.NUM_JOINTS, dtype=torch.int32)
_b("skel_pmi", 2, 266, dtype=torch.int64)
_b("skel_pmi_buffer_sizes", 4, dtype=torch.int64)
self.lbs_inverse_bind_pose = _p(self.NUM_JOINTS, 8)
self.lbs_skin_weights = _p(self.NUM_LBS_TRIPLETS)
_b("lbs_skin_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int32)
_b("lbs_vert_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int64)
_b("pose_corr_sparse_indices", 2, self.POSE_CORR_SPARSE_NNZ, dtype=torch.int64)
self.pose_corr_sparse_weight = _p(self.POSE_CORR_SPARSE_NNZ)
self.register_buffer("pose_corr_sparse_shape", torch.tensor([self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device))
self.pose_corr_weight = _p(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN)
self.pose_corr_bias = None
self._pose_corr_sparse_cache = None
self._pmi_levels_cache = None
def forward(self, identity_coeffs, model_parameters, expr_coeffs, apply_correctives: bool = True):
dtype = self.base_shape.dtype
identity_coeffs = identity_coeffs.to(dtype)
model_parameters = model_parameters.to(dtype)
expr_coeffs = expr_coeffs.to(dtype)
B = identity_coeffs.shape[0]
identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)
cat_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1)
joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, cat_in)
jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
local_t = jp[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0)
local_q = _euler_xyz_to_quat(jp[..., 3:6])
local_q = _quat_multiply(self.skel_joint_prerotations.unsqueeze(0), local_q)
local_s = torch.exp(jp[..., 6:7] * _LN2)
local_state = torch.cat([local_t, local_q, local_s], dim=-1)
skel_state = _global_skel_state_from_local(local_state, self._pmi_levels())
face_expr = torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs)
unposed = identity_rest + face_expr
if apply_correctives:
unposed = unposed + self._pose_correctives(joint_parameters)
verts = self._skin(skel_state, unposed)
return verts, skel_state
def _pose_correctives(self, joint_parameters):
B = joint_parameters.shape[0]
jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
# Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6).
feat = batch6DFromXYZ(jp[:, 2:, 3:6], return_9D=False) # (B, 125, 6)
feat[..., 0] -= 1.0
feat[..., 4] -= 1.0
feat = feat.flatten(1, 2) # (B, 750)
h = (self._sparse_w() @ feat.T).T # (B, 3000)
h = F.relu(h)
out = F.linear(h, self.pose_corr_weight, self.pose_corr_bias) # (B, 55317)
return out.view(B, self.NUM_VERTS, 3)
def _pmi_levels(self):
cached = self._pmi_levels_cache
pmi = self.skel_pmi
if cached is not None and cached[0][0].device == pmi.device:
return cached
sizes = self.skel_pmi_buffer_sizes.tolist()
parts = torch.split(pmi, sizes, dim=1)
levels = [(p[0], p[1]) for p in parts]
self._pmi_levels_cache = levels
return levels
def _sparse_w(self):
cached = self._pose_corr_sparse_cache
w = self.pose_corr_sparse_weight
if cached is not None and cached.device == w.device and cached.dtype == w.dtype:
return cached
sparse = torch.sparse_coo_tensor(
self.pose_corr_sparse_indices,
w,
tuple(self.pose_corr_sparse_shape.tolist()),
check_invariants=False,
).coalesce()
self._pose_corr_sparse_cache = sparse
return sparse
def _skin(self, skel_state, rest_verts):
B = skel_state.shape[0]
ibp = self.lbs_inverse_bind_pose.unsqueeze(0).expand(B, self.NUM_JOINTS, 8)
joint_xform = _skel_multiply(skel_state, ibp)
norm_q = F.normalize(joint_xform[..., 3:7], p=2, dim=-1, eps=1e-12)
joint_xform = torch.cat([joint_xform[..., :3], norm_q, joint_xform[..., 7:8]], dim=-1)
sk_idx = self.lbs_skin_indices.long()
v_idx = self.lbs_vert_indices
w = self.lbs_skin_weights
per_triplet_xform = joint_xform.index_select(-2, sk_idx) # (B, 51337, 8)
per_triplet_rest = rest_verts.index_select(-2, v_idx) # (B, 51337, 3)
contrib = _skel_transform_points(per_triplet_xform, per_triplet_rest) * w.unsqueeze(0).unsqueeze(-1)
out = torch.zeros(B, self.NUM_VERTS, 3, dtype=rest_verts.dtype, device=rest_verts.device)
out.index_add_(-2, v_idx, contrib)
return out

View File

@ -0,0 +1,256 @@
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
# (batch6DFromXYZ, batchXYZfrom6D) are the continuity
# representation from Zhou et al., "On the Continuity of Rotation
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
# implementations from papagina/RotationContinuity:
# https://github.com/papagina/RotationContinuity/blob/758b0ce5/shapenet/code/tools.py
# The compact_cont_to_model_params_* functions are MHR-rig-specific glue.
import torch
import torch.nn.functional as F
def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Compute the angle difference (magnitude) between two batches of SO(3) rotation matrices.
Args:
A: Tensor of shape (*, 3, 3), batch of rotation matrices.
B: Tensor of shape (*, 3, 3), batch of rotation matrices.
Returns:
Tensor of shape (*,), angle differences in radians.
"""
# Compute relative rotation matrix
R_rel = torch.matmul(A, B.transpose(-2, -1)) # (B, 3, 3)
# Compute trace of relative rotation
trace = R_rel[..., 0, 0] + R_rel[..., 1, 1] + R_rel[..., 2, 2] # (B,)
# Compute angle using the trace formula
cos_theta = (trace - 1) / 2
# Clamp for numerical stability
cos_theta_clamped = torch.clamp(cos_theta, -1.0, 1.0)
# Compute angle difference
angle = torch.acos(cos_theta_clamped)
return angle
def fix_wrist_euler(
wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
):
"""
wrist_xzy: B x 2 x 3 (X, Z, Y angles)
Returns: Fixed angles within joint limits
"""
x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2]
x_alt = torch.atan2(torch.sin(x + torch.pi), torch.cos(x + torch.pi))
z_alt = torch.atan2(torch.sin(-(z + torch.pi)), torch.cos(-(z + torch.pi)))
y_alt = torch.atan2(torch.sin(y + torch.pi), torch.cos(y + torch.pi))
# Calculate L2 violation distance
def calc_violation(val, limits):
below = torch.clamp(limits[0] - val, min=0.0)
above = torch.clamp(val - limits[1], min=0.0)
return below**2 + above**2
violation_orig = (
calc_violation(x, limits_x)
+ calc_violation(z, limits_z)
+ calc_violation(y, limits_y)
)
violation_alt = (
calc_violation(x_alt, limits_x)
+ calc_violation(z_alt, limits_z)
+ calc_violation(y_alt, limits_y)
)
# Use alternative where it has lower L2 violation
use_alt = violation_alt < violation_orig
# Stack alternative and apply mask
wrist_xzy_alt = torch.stack([x_alt, z_alt, y_alt], dim=-1)
result = torch.where(use_alt.unsqueeze(-1), wrist_xzy_alt, wrist_xzy)
return result
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py
def batch6DFromXYZ(r, return_9D=False):
"""
Generate a matrix representing a rotation defined by a XYZ-Euler
rotation.
Args:
r: ... x 3 rotation vectors
Returns:
... x 6
"""
rc = torch.cos(r)
rs = torch.sin(r)
cx = rc[..., 0]
cy = rc[..., 1]
cz = rc[..., 2]
sx = rs[..., 0]
sy = rs[..., 1]
sz = rs[..., 2]
result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype, device=r.device)
result[..., 0, 0] = cy * cz
result[..., 0, 1] = -cx * sz + sx * sy * cz
result[..., 0, 2] = sx * sz + cx * sy * cz
result[..., 1, 0] = cy * sz
result[..., 1, 1] = cx * cz + sx * sy * sz
result[..., 1, 2] = -sx * cz + cx * sy * sz
result[..., 2, 0] = -sy
result[..., 2, 1] = sx * cy
result[..., 2, 2] = cx * cy
if not return_9D:
return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1)
else:
return result
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
def batchXYZfrom6D(poses):
# Args: poses: ... x 6, where "6" is the combined first and second columns
# First, get the rotaiton matrix
x_raw = poses[..., :3]
y_raw = poses[..., 3:]
x = F.normalize(x_raw, dim=-1)
z = torch.cross(x, y_raw, dim=-1)
z = F.normalize(z, dim=-1)
y = torch.cross(z, x, dim=-1)
matrix = torch.stack([x, y, z], dim=-1) # ... x 3 x 3
# Now get it into euler
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
sy = torch.sqrt(
matrix[..., 0, 0] * matrix[..., 0, 0] + matrix[..., 1, 0] * matrix[..., 1, 0]
)
singular = sy < 1e-6
singular = singular.float()
x = torch.atan2(matrix[..., 2, 1], matrix[..., 2, 2])
y = torch.atan2(-matrix[..., 2, 0], sy)
z = torch.atan2(matrix[..., 1, 0], matrix[..., 0, 0])
xs = torch.atan2(-matrix[..., 1, 2], matrix[..., 1, 1])
ys = torch.atan2(-matrix[..., 2, 0], sy)
zs = matrix[..., 1, 0] * 0
out_euler = torch.zeros_like(matrix[..., 0])
out_euler[..., 0] = x * (1 - singular) + xs * singular
out_euler[..., 1] = y * (1 - singular) + ys * singular
out_euler[..., 2] = z * (1 - singular) + zs * singular
return out_euler
_HAND_DOFS = [3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]
_HAND_MASK_CACHE: dict = {} # device -> dict of masks
def _hand_masks(device):
m = _HAND_MASK_CACHE.get(device)
if m is not None:
return m
mask_cont_threedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
mask_cont_onedofs = torch.cat([torch.ones(2 * k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
mask_model_params_threedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k == 3) for k in _HAND_DOFS]).to(device)
mask_model_params_onedofs = torch.cat([torch.ones(k, dtype=torch.bool) * (k in (1, 2)) for k in _HAND_DOFS]).to(device)
m = dict(
mask_cont_threedofs=mask_cont_threedofs,
mask_cont_onedofs=mask_cont_onedofs,
mask_model_params_threedofs=mask_model_params_threedofs,
mask_model_params_onedofs=mask_model_params_onedofs,
)
_HAND_MASK_CACHE[device] = m
return m
def compact_cont_to_model_params_hand(hand_cont):
# These are ordered by joint, not model params ^^
m = _hand_masks(hand_cont.device)
mask_cont_threedofs = m["mask_cont_threedofs"]
mask_cont_onedofs = m["mask_cont_onedofs"]
mask_model_params_threedofs = m["mask_model_params_threedofs"]
mask_model_params_onedofs = m["mask_model_params_onedofs"]
# Convert hand_cont to eulers
## First for 3DoFs
hand_cont_threedofs = hand_cont[..., mask_cont_threedofs].unflatten(-1, (-1, 6))
hand_model_params_threedofs = batchXYZfrom6D(hand_cont_threedofs).flatten(-2, -1)
## Next for 1DoFs
hand_cont_onedofs = hand_cont[..., mask_cont_onedofs].unflatten(
-1, (-1, 2)
) # (sincos)
hand_model_params_onedofs = torch.atan2(
hand_cont_onedofs[..., -2], hand_cont_onedofs[..., -1]
)
# Finally, assemble into a 27-dim vector, ordered by joint, then XYZ.
hand_model_params = torch.zeros(*hand_cont.shape[:-1], 27).to(hand_cont)
hand_model_params[..., mask_model_params_threedofs] = hand_model_params_threedofs
hand_model_params[..., mask_model_params_onedofs] = hand_model_params_onedofs
return hand_model_params
_BODY_IDX_CACHE: dict = {}
def _body_idxs(device):
cached = _BODY_IDX_CACHE.get(device)
if cached is not None:
return cached
# fmt: off
all_param_3dof_rot_idxs = torch.LongTensor([(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]).to(device)
all_param_1dof_rot_idxs = torch.LongTensor([1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]).to(device)
all_param_1dof_trans_idxs = torch.LongTensor([124, 125, 126, 127, 128, 129]).to(device)
# fmt: on
cached = (
all_param_3dof_rot_idxs,
all_param_1dof_rot_idxs,
all_param_1dof_trans_idxs,
all_param_3dof_rot_idxs.flatten(),
)
_BODY_IDX_CACHE[device] = cached
return cached
def compact_cont_to_model_params_body(body_pose_cont):
(all_param_3dof_rot_idxs, all_param_1dof_rot_idxs, all_param_1dof_trans_idxs, idxs_3dof_flat) = _body_idxs(body_pose_cont.device)
num_3dof_angles = len(all_param_3dof_rot_idxs) * 3
num_1dof_angles = len(all_param_1dof_rot_idxs)
# Get subsets
body_cont_3dofs = body_pose_cont[..., : 2 * num_3dof_angles]
body_cont_1dofs = body_pose_cont[..., 2 * num_3dof_angles : 2 * num_3dof_angles + 2 * num_1dof_angles]
body_cont_trans = body_pose_cont[..., 2 * num_3dof_angles + 2 * num_1dof_angles :]
# Convert conts to model params
## First for 3dofs
body_cont_3dofs = body_cont_3dofs.unflatten(-1, (-1, 6))
body_params_3dofs = batchXYZfrom6D(body_cont_3dofs).flatten(-2, -1)
## Next for 1dofs
body_cont_1dofs = body_cont_1dofs.unflatten(-1, (-1, 2)) # (sincos)
body_params_1dofs = torch.atan2(body_cont_1dofs[..., -2], body_cont_1dofs[..., -1])
## Nothing to do for trans
body_params_trans = body_cont_trans
# Put them together
body_pose_params = torch.zeros(*body_pose_cont.shape[:-1], 133, dtype=body_pose_cont.dtype, device=body_pose_cont.device)
body_pose_params[..., idxs_3dof_flat] = body_params_3dofs
body_pose_params[..., all_param_1dof_rot_idxs] = body_params_1dofs
body_pose_params[..., all_param_1dof_trans_idxs] = body_params_trans
return body_pose_params
# Hand indices into the 133-dim param and 260-dim cont body-pose vectors.
mhr_param_hand_idxs = list(range(62, 116))
mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))
mhr_param_hand_mask = torch.zeros(133).bool()
mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260).bool()
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True

View File

@ -0,0 +1,140 @@
import math
import einops
import torch
import torch.nn.functional as F
from comfy.ldm.cascade.common import LayerNorm2d_op
from torch import nn
from typing import List, Optional, Tuple, Union
from ..utils import perspective_projection
from .transformer import MLP
class CameraEncoder(nn.Module):
def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None):
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)
self.conv = operations.Conv2d(embed_dim + 99, embed_dim, kernel_size=1, bias=False, device=device, dtype=dtype)
self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype)
def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor):
B, D, _h, _w = img_embeddings.shape
scale = 1 / self.patch_size
rays = F.interpolate(rays, scale_factor=(scale, scale), mode="bilinear", align_corners=False, antialias=True)
rays = rays.permute(0, 2, 3, 1).contiguous() # [b, h, w, 2]
rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
rays_embeddings = self.camera(pos=rays.reshape(B, -1, 3)) # (bs, N, 99): rays fourier embedding
rays_embeddings = einops.rearrange(rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w).contiguous()
z = torch.cat([img_embeddings, rays_embeddings], dim=1)
return self.norm(self.conv(z))
class FourierPositionEncoding(nn.Module):
"""Sin/cos Fourier features for ray positions"""
def __init__(self, n: int, num_bands: int, max_resolution: int):
super().__init__()
self.num_bands = num_bands
self.max_resolution = [max_resolution] * n
def forward(self, pos: torch.Tensor):
fourier_pos_enc = _generate_fourier_features(pos, num_bands=self.num_bands, max_resolution=self.max_resolution)
return fourier_pos_enc
def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0):
b, n = pos.shape[:2]
freq_bands = torch.stack([torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype) for res in max_resolution], dim=0)
per_pos_features = torch.stack([pos[i, :, :][:, :, None] * freq_bands[None, :, :] for i in range(b)], 0)
per_pos_features = per_pos_features.reshape(b, n, -1)
# Sin-Cos
per_pos_features = torch.cat([torch.sin(math.pi * per_pos_features), torch.cos(math.pi * per_pos_features)], dim=-1)
# Concat with initial pos
per_pos_features = torch.cat([pos, per_pos_features], dim=-1)
return per_pos_features
class PerspectiveHead(nn.Module):
"""
Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup).
"""
def __init__(self, input_dim: int, img_size: Union[int, Tuple[int, int]], # model input size (W, H)
mlp_depth: int = 1, mlp_channel_div_factor: int = 8, default_scale_factor: float = 1.0,
device=None, dtype=None, operations=None
):
super().__init__()
# Metadata to compute 3D skeleton and 2D reprojection
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.ncam = 3 # (s, tx, ty)
self.default_scale_factor = default_scale_factor
self.proj = MLP(
input_dim=input_dim,
hidden_dim=input_dim // mlp_channel_div_factor,
output_dim=self.ncam,
num_layers=mlp_depth,
device=device,
dtype=dtype,
operations=operations,
)
def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None):
"""
Args:
x: pose token with shape [B, C], usually C=DECODER.DIM
init_estimate: [B, self.ncam]
"""
pred_cam = self.proj(x)
if init_estimate is not None:
pred_cam = pred_cam + init_estimate
return pred_cam
def perspective_projection(
self,
points_3d: torch.Tensor,
pred_cam: torch.Tensor,
bbox_center: torch.Tensor, # [N, 2], in original image space (w, h)
bbox_size: torch.Tensor, # [N,], in original image space
cam_int: torch.Tensor, # [B, 3, 3]
):
batch_size = points_3d.shape[0]
pred_cam = pred_cam.clone()
pred_cam[..., [0, 2]] *= -1 # Camera system difference
# Compute camera translation: (scale, x, y) --> (x, y, depth)
# depth ~= f / s, Note that f is in the NDC space
s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
bs = bbox_size * s * self.default_scale_factor + 1e-8
focal_length = cam_int[:, 0, 0]
tz = 2 * focal_length / bs
cx = 2 * (bbox_center[:, 0] - cam_int[:, 0, 2]) / bs
cy = 2 * (bbox_center[:, 1] - cam_int[:, 1, 2]) / bs
pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
# Compute camera translation
j3d_cam = points_3d + pred_cam_t.unsqueeze(1)
# Projection to the image plane, note that the projection output is in original image space now.
j2d = perspective_projection(j3d_cam, cam_int)
return {
"pred_keypoints_2d": j2d.reshape(batch_size, -1, 2),
"pred_keypoints_2d_depth": j3d_cam.reshape(batch_size, -1, 3)[:, :, 2],
"pred_cam_t": pred_cam_t, "focal_length": focal_length,
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,254 @@
"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them
through a cross-attention transformer decoder over (token, image) pairs.
Both adapted from the SAM-style prompt path (Meta, Apache 2.0):
https://github.com/facebookresearch/segment-anything
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
from comfy.ldm.cascade.common import LayerNorm2d_op
from comfy.ldm.sam3.sam import PositionEmbeddingRandom
from .transformer import TransformerDecoderLayer
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
num_body_joints: int,
device=None,
dtype=None,
operations=None,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
"""
super().__init__()
self.embed_dim = embed_dim
self.num_body_joints = num_body_joints
# Keypoint prompts
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.point_embeddings = nn.ModuleList(
[operations.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
)
self.not_a_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
self.invalid_point_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
LN2d = LayerNorm2d_op(operations)
mask_in_chans = 256
self.mask_downscaling = nn.Sequential(
operations.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
nn.GELU(),
operations.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
nn.GELU(),
operations.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
nn.GELU(),
operations.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
LN2d(mask_in_chans, device=device, dtype=dtype),
nn.GELU(),
operations.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
)
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
self.no_mask_embed = operations.Embedding(1, embed_dim, device=device, dtype=dtype)
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
"""Positional encoding over the image-embedding grid; (1, C, H, W)."""
return self.pe_layer(size)
def _embed_keypoints(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Embeds point prompts.
Assuming points have been normalized to [0, 1].
Output shape [B, N, C], mask shape [B, N]
"""
assert points.min() >= 0 and points.max() <= 1
# PE compute in fp32 for precision (sin/cos of large coords), then cast back to the embedding weight dtype
weight_dtype = self.invalid_point_embed.weight.dtype
point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype)
point_embedding[labels == -2] = 0.0 # invalid points
point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding)
for i in range(self.num_body_joints):
point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding)
point_mask = labels > -2
return point_embedding, point_mask
def _get_batch_size(self, keypoints: Optional[torch.Tensor], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor]) -> int:
if keypoints is not None:
return keypoints.shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def forward(
self,
keypoints: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
keypoints (torchTensor or none): point coordinates and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(keypoints, boxes, masks)
ref = keypoints if keypoints is not None else boxes if boxes is not None else masks
device = ref.device if ref is not None else self.point_embeddings[0].weight.device
weight_dtype = self.invalid_point_embed.weight.dtype
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=device, dtype=weight_dtype)
sparse_masks = torch.empty((bs, 0), device=device)
if keypoints is not None:
coords = keypoints[:, :, :2]
labels = keypoints[:, :, -1]
point_embeddings, point_mask = self._embed_keypoints(coords, labels)
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
sparse_masks = torch.cat([sparse_masks, point_mask], dim=1)
return sparse_embeddings, sparse_masks
def get_mask_embeddings(self, masks: torch.Tensor, bs: int = 1, size: Tuple[int, int] = (16, 16)) -> torch.Tensor:
"""Embeds mask inputs. Caller casts both outputs to its working dtype."""
no_mask_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, size[0], size[1])
mask_embeddings = self.mask_downscaling(masks)
return mask_embeddings, no_mask_embeddings
class PromptableDecoder(nn.Module):
"""Cross-attention transformer decoder over (token, image) pairs."""
def __init__(
self,
dims: int,
context_dims: int,
depth: int,
num_heads: int = 8,
head_dims: int = 64,
mlp_dims: int = 1024,
repeat_pe: bool = False,
do_interm_preds: bool = False,
keypoint_token_update: bool = False,
device=None, dtype=None, operations=None,
):
super().__init__()
self.layers = nn.ModuleList(
TransformerDecoderLayer(
token_dims=dims,
context_dims=context_dims,
num_heads=num_heads,
head_dims=head_dims,
mlp_dims=mlp_dims,
repeat_pe=repeat_pe,
skip_first_pe=(i == 0),
device=device,
dtype=dtype,
operations=operations,
)
for i in range(depth)
)
self.norm_final = operations.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
self.do_interm_preds = do_interm_preds
self.keypoint_token_update = keypoint_token_update
def forward(
self,
token_embedding: torch.Tensor,
image_embedding: torch.Tensor,
token_augment: Optional[torch.Tensor] = None,
image_augment: Optional[torch.Tensor] = None,
token_mask: Optional[torch.Tensor] = None,
token_to_pose_output_fn=None,
keypoint_token_update_fn=None,
hand_embeddings=None,
hand_augment=None,
):
"""
Args:
token_embedding: [B, N, C]
image_embedding: [B, C, H, W] -- flattened to [B, HW, C] inline
"""
# Channels-last for the transformer.
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
if image_augment is not None:
image_augment = image_augment.flatten(2).permute(0, 2, 1)
if hand_embeddings is not None:
hand_embeddings = hand_embeddings.flatten(2).permute(0, 2, 1)
hand_augment = hand_augment.flatten(2).permute(0, 2, 1)
if len(hand_augment) == 1:
# inflate batch dimension
assert len(hand_augment.shape) == 3
hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)
all_pose_outputs = [] if self.do_interm_preds else None
if self.do_interm_preds:
assert token_to_pose_output_fn is not None
layer_idx = 0
for layer_idx, layer in enumerate(self.layers):
if hand_embeddings is None:
token_embedding, image_embedding = layer(
token_embedding, image_embedding,
token_augment, image_augment, token_mask,
)
else:
token_embedding, image_embedding = layer(
token_embedding,
torch.cat([image_embedding, hand_embeddings], dim=1),
token_augment,
torch.cat([image_augment, hand_augment], dim=1),
token_mask,
)
image_embedding = image_embedding[:, : image_augment.shape[1]]
if self.do_interm_preds and layer_idx < len(self.layers) - 1:
curr = token_to_pose_output_fn(
self.norm_final(token_embedding),
prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
layer_idx=layer_idx,
)
all_pose_outputs.append(curr)
if self.keypoint_token_update:
assert keypoint_token_update_fn is not None
token_embedding, token_augment, _, _ = keypoint_token_update_fn(
token_embedding, token_augment, curr, layer_idx,
)
out = self.norm_final(token_embedding)
if self.do_interm_preds:
curr = token_to_pose_output_fn(
out,
prev_pose_output=all_pose_outputs[-1] if all_pose_outputs else None,
layer_idx=layer_idx,
)
all_pose_outputs.append(curr)
return out, all_pose_outputs
return out

View File

@ -0,0 +1,104 @@
from typing import Optional
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act_layer=nn.ReLU, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList(
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
for i in range(num_layers)
)
self.act = act_layer()
def forward(self, x):
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < len(self.layers) - 1 else layer(x)
return x
class Attention(nn.Module):
def __init__(self, embed_dims, num_heads, query_dims=None, key_dims=None, value_dims=None, qkv_bias=True, proj_bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.query_dims = query_dims or embed_dims
self.key_dims = key_dims or embed_dims
self.value_dims = value_dims or embed_dims
self.embed_dims = embed_dims
self.num_heads = num_heads
self.head_dims = embed_dims // num_heads
lin = lambda i, o, b: operations.Linear(i, o, bias=b, device=device, dtype=dtype)
self.q_proj = lin(self.query_dims, embed_dims, qkv_bias)
self.k_proj = lin(self.key_dims, embed_dims, qkv_bias)
self.v_proj = lin(self.value_dims, embed_dims, qkv_bias)
self.proj = lin(embed_dims, self.query_dims, proj_bias)
def _split(self, x: torch.Tensor) -> torch.Tensor:
b, n, _ = x.shape
return x.reshape(b, n, self.num_heads, self.head_dims).transpose(1, 2)
def forward(self, q, k, v, attn_mask: Optional[torch.Tensor] = None):
q, k, v = self._split(self.q_proj(q)), self._split(self.k_proj(k)), self._split(self.v_proj(v))
x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True, low_precision_attention=False)
return self.proj(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, token_dims, context_dims, num_heads=8, head_dims=64, mlp_dims=1024,
repeat_pe=False, skip_first_pe=False, device=None, dtype=None, operations=None):
super().__init__()
self.repeat_pe = repeat_pe
self.skip_first_pe = skip_first_pe
ln = lambda d: operations.LayerNorm(d, eps=1e-6, device=device, dtype=dtype)
attn_dim = num_heads * head_dims
attn_kwargs = dict(embed_dims=attn_dim, num_heads=num_heads, device=device, dtype=dtype, operations=operations)
if repeat_pe:
self.ln_pe_1, self.ln_pe_2 = ln(token_dims), ln(context_dims)
self.ln1 = ln(token_dims)
self.self_attn = Attention(query_dims=token_dims, key_dims=token_dims, value_dims=token_dims, **attn_kwargs)
self.ln2_1, self.ln2_2 = ln(token_dims), ln(context_dims)
self.cross_attn = Attention(query_dims=token_dims, key_dims=context_dims, value_dims=context_dims, **attn_kwargs)
self.ln3 = ln(token_dims)
self.ffn = MLP(token_dims, mlp_dims, token_dims, num_layers=2, act_layer=nn.GELU, device=device, dtype=dtype, operations=operations)
def forward(self, x, context, x_pe=None, context_pe=None, x_mask=None):
"""x: [B, N_tokens, C], context: [B, N_ctx, C], x_mask: [B, N_tokens] or None."""
# LaPE-style PE re-norm per layer.
if self.repeat_pe and context_pe is not None:
x_pe = self.ln_pe_1(x_pe)
context_pe = self.ln_pe_2(context_pe)
# Self-attn over tokens.
if self.repeat_pe and not self.skip_first_pe and x_pe is not None:
q = k = self.ln1(x) + x_pe
v = self.ln1(x)
else:
q = k = v = self.ln1(x)
attn_mask = None
if x_mask is not None:
attn_mask = x_mask[:, :, None] @ x_mask[:, None, :]
attn_mask.diagonal(dim1=1, dim2=2).fill_(1) # avoid all-invalid rows -> nan
attn_mask = attn_mask > 0
x = x + self.self_attn(q, k, v, attn_mask=attn_mask)
# Cross-attn: tokens attend to image context.
if self.repeat_pe and context_pe is not None:
q = self.ln2_1(x) + x_pe
k = self.ln2_2(context) + context_pe
v = self.ln2_2(context)
else:
q = self.ln2_1(x)
k = v = self.ln2_2(context)
x = x + self.cross_attn(q, k, v)
x = x + self.ffn(self.ln3(x))
return x, context

View File

@ -0,0 +1,339 @@
# The bbox/affine math (xyxy<->cs, get_warp_matrices) is the standard
# top-down pose-estimation crop pipeline from MMPose (Apache 2.0):
# https://github.com/open-mmlab/mmpose — same algorithm as UDP (CVPR 2020).
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
# Bbox + affine math
# All `output_size` / image-shape tuples in this block are (H, W) to match
# the torch.Size convention used everywhere else in the codebase.
def bbox_xyxy2cs(bbox, padding: float) -> Tuple[torch.Tensor, torch.Tensor]:
"""xyxy bbox -> (center, scale) with optional padding multiplier."""
bbox = torch.as_tensor(bbox, dtype=torch.float32)
dim = bbox.dim()
if dim == 1:
bbox = bbox.unsqueeze(0)
x1, y1, x2, y2 = bbox[:, 0:1], bbox[:, 1:2], bbox[:, 2:3], bbox[:, 3:4]
center = torch.cat([x1 + x2, y1 + y2], dim=1) * 0.5
scale = torch.cat([x2 - x1, y2 - y1], dim=1) * padding
if dim == 1:
return center[0], scale[0]
return center, scale
def fix_aspect_ratio(bbox_scale, aspect_ratio: float) -> torch.Tensor:
"""Pad whichever side is too narrow to hit `aspect_ratio` (w/h)."""
bbox_scale = torch.as_tensor(bbox_scale, dtype=torch.float32)
dim = bbox_scale.dim()
if dim == 1:
bbox_scale = bbox_scale.unsqueeze(0)
w, h = bbox_scale[:, 0:1], bbox_scale[:, 1:2]
out = torch.where(
w > h * aspect_ratio,
torch.cat([w, w / aspect_ratio], dim=1),
torch.cat([h * aspect_ratio, h], dim=1),
)
return out[0] if dim == 1 else out
def get_warp_matrices(centers, scales, output_size: Tuple[int, int]) -> torch.Tensor:
"""Batched 2x3 affine matrices mapping each (center, scale) bbox region to
the output box. `output_size` is (H_out, W_out). With rot=0 the MMPose
3-point fit reduces to a closed-form isotropic scale + translate.
"""
centers = torch.as_tensor(centers, dtype=torch.float32)
scales = torch.as_tensor(scales, dtype=torch.float32)
if centers.dim() == 1:
centers = centers.unsqueeze(0)
scales = scales.unsqueeze(0)
n = centers.shape[0]
src_w = scales[:, 0]
dst_h = float(output_size[0])
dst_w = float(output_size[1])
# With rot=0 the warp is just scale + translate (uniform x/y scale based
# on src_w/dst_w). The closed form drops out of MMPose's 3-point solve.
s = dst_w / src_w # (N,)
mats = torch.zeros((n, 2, 3), dtype=torch.float32)
mats[:, 0, 0] = s
mats[:, 1, 1] = s
mats[:, 0, 2] = dst_w * 0.5 - s * centers[:, 0]
mats[:, 1, 2] = dst_h * 0.5 - s * centers[:, 1]
return mats # (N, 2, 3)
def warp_affine_batched(
src_t: torch.Tensor, # (N, C, H_src, W_src) float
mats: torch.Tensor, # (N, 2, 3) float
output_size: Tuple[int, int] # (H_out, W_out)
) -> torch.Tensor:
"""Apply N forward (src->dst) 2x3 affine warps to N source images in one
grid_sample call. Kept generic over arbitrary affines (not specialized to
the scale+translate produced by `get_warp_matrices`) so callers can pass
rotated/sheared affines; the per-crop 3x3 invert is O(N) of trivial work."""
H_out, W_out = int(output_size[0]), int(output_size[1])
N, _, H_src, W_src = src_t.shape
device = src_t.device
# Invert each forward affine; grid_sample needs dst->src.
mats_t = mats.to(device=device, dtype=torch.float32)
bottom = torch.tensor([0.0, 0.0, 1.0], device=device).expand(N, 1, 3)
mats_3 = torch.cat([mats_t, bottom], dim=1) # (N, 3, 3)
mats_inv = torch.linalg.inv(mats_3)[:, :2, :] # (N, 2, 3)
# Output pixel-center grid (i+0.5, j+0.5).
ys, xs = torch.meshgrid(
torch.arange(H_out, dtype=torch.float32, device=device) + 0.5,
torch.arange(W_out, dtype=torch.float32, device=device) + 0.5,
indexing="ij",
)
homo = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1) # (H_out, W_out, 3)
src_pos = torch.einsum("nkl,ijl->nijk", mats_inv, homo) # (N, H_out, W_out, 2)
# Normalize to [-1, 1] grid_sample coords (align_corners=False).
src_pos[..., 0] = src_pos[..., 0] / W_src * 2 - 1
src_pos[..., 1] = src_pos[..., 1] / H_src * 2 - 1
return F.grid_sample(src_t, src_pos, mode="bilinear", padding_mode="zeros", align_corners=False)
# Batch construction (one prediction over N person crops from a single image)
def prepare_batch(
img, # (H, W, 3) uint8 torch tensor or list of such tensors
boxes, # (N, 4) xyxy (numpy or torch)
input_size: Tuple[int, int], # (W, H) of the model crop
bbox_padding: float = 1.25, # xyxy->cs padding multiplier (1.25 body, 0.9 hand)
aspect_ratio: float = 0.75, # w/h of the crop (0.75 matches HMR2/Sapiens)
masks=None, # optional per-person masks
masks_score=None, # optional per-person mask scores
cam_int=None, # optional camera intrinsics
) -> Dict:
"""Build the batch dict the SAM3DBody forward expects, doing the N crops in one batched `grid_sample` call."""
is_multi_image = isinstance(img, list)
if is_multi_image:
assert len(img) == boxes.shape[0]
height, width = img[0].shape[:2]
else:
height, width = img.shape[:2]
n = int(boxes.shape[0])
assert n > 0, "prepare_batch needs at least one box"
W_out, H_out = int(input_size[0]), int(input_size[1])
# Per-box bbox math (cheap, vectorized, CPU).
centers, scales = bbox_xyxy2cs(boxes, padding=bbox_padding)
# Two passes: first hits the upstream bbox aspect (e.g. 0.75 HMR2/Sapiens
# convention), second pads further if the model crop's W_out/H_out differs
# from that. When they match (common case) the second call is a no-op.
scales = fix_aspect_ratio(scales, aspect_ratio)
scales = fix_aspect_ratio(scales, W_out / H_out)
mats = get_warp_matrices(centers, scales, (H_out, W_out)) # (N, 2, 3)
# Stack source images into a contiguous (N, 3, H, W) tensor on CPU.
if is_multi_image:
src_t = torch.stack(list(img), dim=0)
else:
src_t = img.unsqueeze(0).expand(n, -1, -1, -1)
src_t = src_t.permute(0, 3, 1, 2).contiguous().float() # (N, 3, H, W) in [0, 255]
warped_t = warp_affine_batched(src_t, mats, (H_out, W_out)) # (N, 3, H_out, W_out)
# Float warp -> floor (matches the legacy uint8 round-trip) -> /255.
img_t = torch.floor(warped_t).clamp_(0.0, 255.0) / 255.0 # (N, 3, H_out, W_out) in [0, 1]
# Masks: zero-init when missing, otherwise stack and warp through the same matrices.
boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
if masks is None:
mask_t = torch.zeros((n, H_out, W_out), dtype=torch.float32)
mask_score_t = torch.zeros((n,), dtype=torch.float32)
else:
# masks is an array of N items, each (H, W) or (H, W, 1).
masks_t = torch.stack([torch.as_tensor(masks[i]) for i in range(n)], dim=0)
if masks_t.dim() == 4 and masks_t.shape[-1] == 1:
masks_t = masks_t[..., 0]
masks_src_t = masks_t.float().unsqueeze(1) # (N, 1, H, W) in [0, 255]
warped_masks = warp_affine_batched(masks_src_t, mats, (H_out, W_out))
mask_t = torch.floor(warped_masks.squeeze(1)).clamp_(0.0, 255.0)
if masks_score is not None:
mask_score_t = torch.as_tensor([masks_score[i] for i in range(n)], dtype=torch.float32)
else:
mask_score_t = torch.ones((n,), dtype=torch.float32)
img_size_t = torch.tensor([W_out, H_out], dtype=torch.float32).expand(n, 2).contiguous()
batch = {
"img": img_t.unsqueeze(0), # (1, N, 3, H_out, W_out)
"img_size": img_size_t.unsqueeze(0), # (1, N, 2)
"bbox_center": centers.unsqueeze(0), # (1, N, 2)
"bbox_scale": scales.unsqueeze(0), # (1, N, 2)
"bbox": boxes_t.unsqueeze(0), # (1, N, 4)
"affine_trans": mats.unsqueeze(0), # (1, N, 2, 3)
"mask": mask_t.unsqueeze(0).unsqueeze(2), # (1, N, 1, H_out, W_out)
"mask_score": mask_score_t.unsqueeze(0), # (1, N)
"person_valid": torch.ones((1, n), dtype=torch.float32),
}
if cam_int is not None:
batch["cam_int"] = cam_int.to(batch["img"])
else:
# Default intrinsics: focal = sqrt(W^2 + H^2), principal point = image center.
f = (height ** 2 + width ** 2) ** 0.5
batch["cam_int"] = torch.tensor(
[[[f, 0, width / 2.0], [0, f, height / 2.0], [0, 0, 1]]],
).to(batch["img"])
return batch
# Geometry utils
def rot6d_to_rotmat(
x: torch.Tensor # (B, 6) batch of 6-D rotation representations.
) -> torch.Tensor: # (B, 3, 3) rotation matrices.
"""6D continuous rotation rep (Zhou et al., CVPR 2019) -> 3x3 rotation matrix."""
x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
a1, a2 = x[:, :, 0], x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
b3 = torch.linalg.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)
def perspective_projection(
x: torch.Tensor, # (B, N, 3) 3D points in camera coords.
K: torch.Tensor # (B, 3, 3) camera intrinsics.
) -> torch.Tensor: # (B, N, 2) 2D image-plane projections.
"""Project 3D points (already in camera frame) through intrinsics K."""
y = x / x[:, :, -1].unsqueeze(-1) # perspective divide
y = torch.einsum("bij,bkj->bki", K, y) # apply intrinsics
return y[:, :, :2]
# Rotation conversions, behavior mirrors the roma library (https://github.com/naver/roma)
def _axis_rotmat(axis: str, angle: torch.Tensor) -> torch.Tensor:
"""Rotation matrices around a single coordinate axis. Shape (..., 3, 3)."""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
flat = (one, zero, zero,
zero, cos, -sin,
zero, sin, cos)
elif axis == "Y":
flat = (cos, zero, sin,
zero, one, zero,
-sin, zero, cos)
elif axis == "Z":
flat = (cos, -sin, zero,
sin, cos, zero,
zero, zero, one)
else:
raise ValueError(f"Invalid axis {axis!r}; expected X/Y/Z.")
return torch.stack(flat, dim=-1).reshape(angle.shape + (3, 3))
def euler_to_rotmat(convention: str, angles: torch.Tensor) -> torch.Tensor:
"""Euler angles -> rotation matrix, matching roma's case-keyed convention."""
axes = convention.upper()
R0 = _axis_rotmat(axes[0], angles[..., 0])
R1 = _axis_rotmat(axes[1], angles[..., 1])
R2 = _axis_rotmat(axes[2], angles[..., 2])
if convention.islower():
return R2 @ R1 @ R0
return R0 @ R1 @ R2
def _index_from_letter(letter: str) -> int:
return {"X": 0, "Y": 1, "Z": 2}[letter]
def _angle_from_tan(
axis: str,
other_axis: str,
data: torch.Tensor,
horizontal: bool,
tait_bryan: bool,
) -> torch.Tensor:
"""Extract an outer Euler angle from a row/column of a rotation matrix.
Adapted from PyTorch3D's matrix_to_euler_angles helper.
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ("XY", "YZ", "ZX")
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _matrix_to_euler_intrinsic(matrix: torch.Tensor, convention: str) -> torch.Tensor:
"""Decompose a rotation matrix into intrinsic Euler angles (uppercase abc).
Adapted from PyTorch3D's matrix_to_euler_angles.
"""
i0 = _index_from_letter(convention[0])
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
sign = -1.0 if (i0 - i2) in (-1, 2) else 1.0
central = torch.asin(matrix[..., i0, i2] * sign)
else:
central = torch.acos(matrix[..., i0, i0])
out = (
_angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan),
central,
_angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan),
)
return torch.stack(out, dim=-1)
def rotmat_to_euler(convention: str, matrix: torch.Tensor) -> torch.Tensor:
"""Rotation matrix -> Euler angles, inverse of :func:`euler_to_rotmat`.
PyTorch3D's matrix_to_euler_angles uses the convention R = R_a R_b R_c for
convention "abc"; that matches roma's UPPERCASE ordering directly. For
roma's lowercase, the matrix is reversed (R_c R_b R_a), so we decompose
with the reversed convention and flip the angles back to axis order.
"""
if matrix.shape[-2:] != (3, 3):
raise ValueError(f"Expected (..., 3, 3) rotation matrix, got {tuple(matrix.shape)}.")
if convention.isupper():
return _matrix_to_euler_intrinsic(matrix, convention)
decomposed = _matrix_to_euler_intrinsic(matrix, convention.upper()[::-1])
return decomposed.flip(-1)
def unitquat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
"""Unit quaternion (x, y, z, w) -> rotation matrix.
Matches roma.unitquat_to_rotmat (scalar-last). The quaternion is assumed to be normalized.
Args:
quat: (..., 4) unit quaternion.
Returns:
(..., 3, 3) rotation matrix.
"""
x, y, z, w = quat.unbind(dim=-1)
tx, ty, tz = 2 * x, 2 * y, 2 * z
twx, twy, twz = tx * w, ty * w, tz * w
txx, txy, txz = tx * x, ty * x, tz * x
tyy, tyz, tzz = ty * y, tz * y, tz * z
one = torch.ones_like(w)
flat = (
one - (tyy + tzz), txy - twz, txz + twy,
txy + twz, one - (txx + tzz), tyz - twx,
txz - twy, tyz + twx, one - (txx + tyy),
)
return torch.stack(flat, dim=-1).reshape(quat.shape[:-1] + (3, 3))

View File

@ -713,6 +713,12 @@ class File3DFBX(ComfyTypeIO):
Type = File3D
@comfytype(io_type="FILE_3D_BVH")
class File3DBVH(ComfyTypeIO):
"""BVH format 3D file - skeletal motion capture animation (no geometry)."""
Type = File3D
@comfytype(io_type="FILE_3D_OBJ")
class File3DOBJ(ComfyTypeIO):
"""OBJ format 3D file - simple geometry format."""
@ -2384,6 +2390,7 @@ __all__ = [
"File3DGLB",
"File3DGLTF",
"File3DFBX",
"File3DBVH",
"File3DOBJ",
"File3DSTL",
"File3DUSDZ",

View File

@ -4,7 +4,7 @@ BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
import math
from functools import lru_cache
from typing import List, Tuple
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -556,32 +556,41 @@ def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -
class FaceLandmarker(nn.Module):
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
(128², 2m) or 'full' (192² FPN, 5m). State dict uses inner-module prefixes
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
"""
"""BlazeFace → FaceMesh v2 → blendshapes."""
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
super().__init__()
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
self.detector_variant = detector_variant
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
if detector_variant == "both":
self.detector = BlazeFace(device=device, dtype=dtype, operations=operations)
self.detector_full = BlazeFaceFullRange(device=device, dtype=dtype, operations=operations)
else:
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}[detector_variant]
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
def _detector(self, variant: str) -> nn.Module:
if self.detector_variant == "both":
return self.detector_full if variant == "full" else self.detector
return self.detector
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
score_thresh: float = _BF_MIN_SCORE,
iou_thresh: float = 0.5):
iou_thresh: float = 0.5,
variant: Optional[str] = None):
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.
`variant` overrides per-call."""
if not images_rgb_uint8:
return [], [], [], []
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
if variant is None:
variant = "short" if self.detector_variant == "both" else self.detector_variant
detector = self._detector(variant)
device, dtype = detector.stem.weight.device, detector.stem.weight.dtype
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
if self.detector_variant == "full"
if variant == "full"
else (_BF_INPUT_SIZE, _decode_blazeface))
# Same-size frames: stack once and transfer once. Variable size falls back
@ -597,7 +606,7 @@ class FaceLandmarker(nn.Module):
det_crops = [w[0] for w in warps]
sub_rects = [(w[1], w[2], w[3]) for w in warps]
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
regs_b, cls_b = detector(torch.stack(det_crops, dim=0))
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
per_frame = []
for b in range(len(images_rgb_uint8)):
@ -606,15 +615,22 @@ class FaceLandmarker(nn.Module):
return img_raws, sub_rects, sizes, per_frame
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
score_thresh: float = _BF_MIN_SCORE,
variant: Optional[str] = None) -> List[List[dict]]:
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
list per image (empty if nothing detected). Face dict:
bbox_xyxy (4,) image pixels, blendshapes {52} [0,1],
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
192-canonical (pre-transformation) units, presence float (raw logit).
On 'both'-mode instances `variant=None` runs both detectors and keeps
whichever found more faces per frame (tie short).
"""
if variant is None and self.detector_variant == "both":
s = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="short")
f = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="full")
return [s[b] if len(s[b]) >= len(f[b]) else f[b] for b in range(len(images_rgb_uint8))]
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
images_rgb_uint8, score_thresh=score_thresh,
images_rgb_uint8, score_thresh=score_thresh, variant=variant,
)
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
for b, decoded in enumerate(per_frame_dets):

View File

@ -1,6 +1,8 @@
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
import math
import torch
import comfy.utils
@ -391,10 +393,57 @@ class MoGePointMapToMesh(io.ComfyNode):
return io.NodeOutput(mesh)
class MoGeGeometryToFOV(io.ComfyNode):
"""Extract horizontal/vertical FOV from MoGe intrinsics, e.g. fov_y to feed SAM3DBody_Predict."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeGeometryToFOV",
search_aliases=["moge", "fov", "geometry", "intrinsics", "field of view"],
display_name="Get FoV from MoGe Geometry",
description="Derive the field of view and focal length from MoGe intrinsics.",
category="image/geometry estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Combo.Input("axis", options=["vertical", "horizontal", "diagonal"], default="vertical",
tooltip="'vertical' (fov_y), 'horizontal' (fov_x), or 'diagonal'."),
io.Combo.Input("unit", options=["degrees", "radians"], default="degrees",
tooltip="Output unit for the FOV."),
],
outputs=[
io.Float.Output(display_name="fov"),
io.Float.Output(display_name="focal_pixels"),
],
)
@classmethod
def execute(cls, moge_geometry, axis, unit) -> io.NodeOutput:
K = moge_geometry.get("intrinsics") if isinstance(moge_geometry, dict) else None
if K is None:
raise ValueError("moge_geometry has no intrinsics (panorama geometry has none).")
if K.ndim == 3:
K = K[0]
# MoGe normalizes fx by width and fy by height; with cx=cy=0.5 the half-extent
# in normalized units is 0.5, so fov = 2*atan(0.5 / f) per axis (hypot for diagonal).
hx = 0.5 / float(K[0, 0].item())
hy = 0.5 / float(K[1, 1].item())
half_tan = {"horizontal": hx, "vertical": hy, "diagonal": math.hypot(hx, hy)}[axis]
fov_radians = 2.0 * math.atan(half_tan)
fov = fov_radians if unit == "radians" else math.degrees(fov_radians)
# Pixels are square here, so fy*H == fx*W is the single lens focal in pixels.
src = next((moge_geometry[k] for k in ("image", "points", "depth") if k in moge_geometry), None)
if src is None:
raise ValueError("moge_geometry has no image/points/depth to read the pixel height from.")
H = int(src.shape[1])
focal_pixels = float(K[1, 1].item()) * H
return io.NodeOutput(fov, focal_pixels)
class MoGeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh]
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV]
async def comfy_entrypoint() -> MoGeExtension:

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,3 @@
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node."""
import json
import logging
import os
@ -15,6 +13,15 @@ import folder_paths
from comfy.cli_args import args
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_extras.sam3d_body.export.bvh import build_bvh
from comfy_extras.sam3d_body.export.glb_openpose import build_glb_openpose
from comfy_extras.sam3d_body.export.glb_skeletal import build_glb_skeletal
MHRPoseData = IO.Custom("MHR_POSE_DATA")
KimodoPoseData = IO.Custom("KIMODO_POSE_DATA")
SAM3DBodyModel = IO.Custom("SAM3D_BODY_MODEL")
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
@ -335,6 +342,7 @@ class SaveGLB(IO.ComfyNode):
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DBVH,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DPLY,
@ -406,10 +414,391 @@ class SaveGLB(IO.ComfyNode):
return IO.NodeOutput(ui={"3d": results})
def rainbow_tilt_inputs():
"""Shared rainbow-shader tilt inputs (used by Render and ToGLB schemas)."""
return [
IO.Float.Input(
"rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5,
tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
),
IO.Float.Input(
"rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5,
tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
),
]
def camera_translation_input():
"""Shared camera_translation combo (Create 3D Animation glb + bvh paths)."""
return IO.Combo.Input(
"camera_translation",
options=["off", "centered", "absolute"],
default="off",
tooltip=(
"Bake pred_cam_t into the root's translation "
"'off' = bind position "
"'centered' = delta from frame 0 "
"'absolute' = raw (Z is camera depth — usually meters away)."
),
)
class BuildPoseFile(IO.ComfyNode):
"""Build an animated GLB from pose data, or save it as a BVH mocap file."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BuildPoseFile",
display_name="Create 3D Animation File",
description="Build an animated GLB from pose data, or save a BVH mocap file.",
search_aliases=["pose animation", "mocap", "glb", "bvh", "build pose file", "save pose file"],
category="3d",
inputs=[
IO.MultiType.Input(
"pose_data", types=[MHRPoseData, KimodoPoseData],
tooltip=("3D pose data."),
),
SAM3DBodyModel.Input("sam3d_body_model", optional=True),
IO.DynamicCombo.Input(
"format",
options=[
IO.DynamicCombo.Option("glb", [
IO.DynamicCombo.Input(
"mesh_style",
options=[
IO.DynamicCombo.Option("body_mesh", [
IO.DynamicCombo.Input(
"bone_vis",
options=[
IO.DynamicCombo.Option("off", []),
IO.DynamicCombo.Option("octahedrons", [
IO.Float.Input(
"bone_vis_radius_m",
default=0.02, min=0.005, max=0.5, step=0.005, advanced=True,
tooltip="Radius in m (sphere radius / octahedron half-width).",
),
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip=(
"Per-bone vertex colors (unlit material). "
"'white' = none, 'rainbow_y' = head→toe jet."
),
),
]),
],
tooltip=("Bone vis shape, rigidly skinned to each joint. "),
),
IO.DynamicCombo.Input(
"shader",
options=[
IO.DynamicCombo.Option("default", []),
IO.DynamicCombo.Option("rainbow", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: each track gets (1 - falloff^k) pastel mix.",
),
]),
IO.DynamicCombo.Option("rainbow_face_normal", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: each track gets (1 - falloff^k) pastel mix.",
),
]),
IO.DynamicCombo.Option("rainbow_face_semantic", [
*rainbow_tilt_inputs(),
IO.Float.Input(
"person_palette_falloff",
default=0.6, min=0.1, max=1.0, step=0.05,
tooltip="Per-person desaturation: each track gets (1 - falloff^k) pastel mix.",
),
]),
],
tooltip=(
"Bake per-vertex colors matching the Render node's shaders "
"(COLOR_0 + KHR_materials_unlit). 'default' = no colors."
),
),
]),
IO.DynamicCombo.Option("bones_only", [
IO.DynamicCombo.Input(
"bone_vis",
options=[
IO.DynamicCombo.Option("octahedrons", [
IO.Float.Input(
"bone_vis_radius_m",
default=0.02, min=0.005, max=0.5, step=0.005, advanced=True,
tooltip="Radius in m (sphere radius / octahedron half-width).",
),
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip=(
"Per-bone vertex colors (unlit material). "
"'white' = none, 'rainbow_y' = head→toe jet."
),
),
]),
],
tooltip=(
"Bone vis shape, rigidly skinned to each joint. "
"'octahedrons' = Blender-style directional bones (joint → "
"primary child)."
),
),
]),
IO.DynamicCombo.Option("openpose", [
IO.Float.Input(
"marker_radius_m", default=0.010, min=0.005, max=0.1, step=0.001, advanced=True,
tooltip="Sphere radius in m.",
),
IO.Float.Input(
"stick_radius_m", default=0.008, min=0.002, max=0.05, step=0.001, advanced=True,
tooltip="Limb half-width in m. Auto-clamped to bone_length x 0.1.",
),
IO.Boolean.Input(
"include_hands", default=False,
tooltip=(
"Append 21+21 OpenPose hands (wrist + 5 fingers x 4 joints, "
"base→tip) sourced from pred_keypoints_3d."
),
),
IO.Float.Input(
"hand_marker_radius_m", default=0.005, min=0.001, max=0.1, step=0.001, advanced=True,
tooltip="Hand sphere radius in m.",
),
IO.Float.Input(
"hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001, advanced=True,
tooltip="Hand limb half-width in m.",
),
IO.Combo.Input(
"face_style",
options=["disabled", "full", "eyes_mouth"],
default="disabled",
tooltip=(
"Face-contour landmarks sampled from pred_vertices at fixed "
"head-mesh vertex IDs (needs canonical_colors on pose_data). "
"'full' = all ~30 points; 'eyes_mouth' = eyes + outer lips only."
),
),
IO.Float.Input(
"face_marker_radius_m", default=0.0, min=0.0, max=0.05, step=0.0005, advanced=True,
tooltip="Face dot radius. 0 = auto = 0.3 x marker_radius_m.",
),
]),
IO.DynamicCombo.Option("scail", [
IO.Float.Input(
"stick_radius_m", default=0.022, min=0.002, max=0.1, step=0.001, advanced=True,
tooltip=(
"Cylinder radius in m. Bones are open cylinders at constant "
"radius; joint spheres (auto-sized to match) cap the open ends. "
"SCAIL reference = 0.0215 m."
),
),
IO.Float.Input(
"marker_radius_m", default=0.0, min=0.0, max=0.1, step=0.001, advanced=True,
tooltip="Joint sphere radius. 0 = auto = stick_radius_m (flush cap).",
),
IO.Float.Input(
"material_roughness", default=0.3, min=0.0, max=1.0, step=0.05, advanced=True,
tooltip="PBR roughness. SCAIL ref = 0.3. 1 = matte; 0 = chrome.",
),
IO.Boolean.Input(
"include_hands", default=False,
tooltip="Append 21+21 hand keypoints + capsule sticks per track.",
),
IO.Float.Input(
"hand_marker_radius_m", default=0.005, min=0.001, max=0.05, step=0.001, advanced=True,
tooltip="Hand sphere radius in m.",
),
IO.Float.Input(
"hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001, advanced=True,
tooltip="Hand cylinder radius in m.",
),
IO.Combo.Input(
"face_style",
options=["disabled", "full", "eyes_mouth"],
default="disabled",
tooltip=(
"Face-contour landmarks sampled from pred_vertices (needs "
"canonical_colors on pose_data). 'full' = all ~30 points; "
"'eyes_mouth' = eyes + outer lips only."
),
),
]),
],
tooltip=(
"'body_mesh' = real Armature (127 bones, skinning, TRS keyframes, 72 face morphs; needs model). "
"'bones_only' = bone-shape primitives at each joint (preview armature). "
"'openpose' = OpenPose-18 3D skeleton from keypoints "
"'scail' = SCAIL 3D capsule rig (open cylinders capped flush by joint spheres)."
),
),
IO.Int.Input(
"bone_smooth_window",
default=0, min=0, max=51, step=2,
tooltip=(
"Gaussian smoothing window on per-bone rotation keyframes / keypoint "
"tracks. 0 = off. 7-15 calms spins/jitter where upstream Smooth misses spikes."
),
),
]),
IO.DynamicCombo.Option("bvh", [
IO.Combo.Input(
"units",
options=["cm", "m"],
default="cm",
tooltip="BVH OFFSET/position units. 'cm' is the mocap standard.",
),
]),
],
tooltip=(
"Output format, both fed to Save 3D Model to write to disk. "
"'glb' = animated GLB (mesh / bones / openpose / scail). "
"'bvh' = BVH mocap clip (one skeleton; needs the model)."
),
),
IO.Float.Input(
"fps", default=24.0, min=1.0, max=240.0, step=1.0,
tooltip="Animation frame rate.",
),
camera_translation_input(),
IO.Int.Input(
"track_index", default=-1, min=-1, max=15,
tooltip="-1 = all tracks; ≥0 = single track.",
),
],
outputs=[IO.File3DAny.Output("model_3d")],
)
@classmethod
def execute(cls, pose_data, format, sam3d_body_model=None, fps=24.0, camera_translation="off", track_index=-1) -> IO.NodeOutput:
format = format or {"format": "glb"}
fmt = format.get("format", "glb")
if fmt == "bvh":
# External rigs (e.g. Kimodo) supply pose_data["_skeleton_override"]
has_external_rig = isinstance(pose_data, dict) and ("_skeleton_override" in pose_data)
if sam3d_body_model is None and not has_external_rig:
raise ValueError(
"Create 3D Animation: 'bvh' format needs the `sam3d_body_model` input OR a "
"`_skeleton_override` dict in pose_data (e.g. from KimodoSample)."
)
# BVH carries one skeleton; -1 (all tracks) collapses to the first.
ti = int(track_index)
if ti < 0:
ti = 0
bvh_bytes = build_bvh(
pose_data, sam3d_body_model,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=ti,
units=str(format.get("units", "cm")),
)
return IO.NodeOutput(Types.File3D(BytesIO(bvh_bytes), file_format="bvh"))
mesh_style = format.get("mesh_style") or {"mesh_style": "body_mesh"}
bone_smooth_window = int(format.get("bone_smooth_window", 0))
mode_key = mesh_style["mesh_style"]
# `shader` is nested in body_mesh; absent for bones_only.
shader_dict = mesh_style.get("shader") or {}
shader_key = shader_dict.get("shader", "default")
common = dict(
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
shader=str(shader_key),
rainbow_tilt_x_deg=float(shader_dict.get("rainbow_tilt_x", 0.0)),
rainbow_tilt_z_deg=float(shader_dict.get("rainbow_tilt_z", 0.0)),
person_palette_falloff=float(shader_dict.get("person_palette_falloff", 0.6)),
)
if mode_key in ("body_mesh", "bones_only"):
# External rigs (e.g. ComfyUI-Kimodo) supply pose_data["_skeleton_override"]
# so the GLB writer reads rig/bind/skin from there instead of MHR.
has_external_rig = isinstance(pose_data, dict) and ("_skeleton_override" in pose_data)
if sam3d_body_model is None and not has_external_rig:
raise ValueError(
f"BuildPoseFile: '{mode_key}' mode needs the `sam3d_body_model` input OR a "
"`_skeleton_override` dict in pose_data. Connect the SAM3DBody model "
"or feed pose_data from a node that supplies the override (e.g. KimodoSample)."
)
default_shape = "off" if mode_key == "body_mesh" else "octahedrons"
bone_vis_dict = mesh_style.get("bone_vis", {"bone_vis": default_shape})
bone_vis = str(bone_vis_dict.get("bone_vis", default_shape))
bone_vis_radius_m = float(bone_vis_dict.get("bone_vis_radius_m", 0.04))
bone_vis_color = str(bone_vis_dict.get("bone_vis_color", "white"))
glb_bytes = build_glb_skeletal(
pose_data, sam3d_body_model,
bone_smooth_window=int(bone_smooth_window),
bone_vis=bone_vis,
bone_vis_radius_m=bone_vis_radius_m,
bone_vis_color=bone_vis_color,
include_body_mesh=(mode_key == "body_mesh"),
**common,
)
elif mode_key == "openpose":
# Rig-independent: sourced from pred_keypoints_3d. face_source='rig'
# additionally reads canonical_colors for head-mesh vertex IDs.
glb_bytes = build_glb_openpose(
pose_data,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
marker_radius_m=float(mesh_style.get("marker_radius_m", 0.025)),
stick_radius_m=float(mesh_style.get("stick_radius_m", 0.008)),
include_hands=bool(mesh_style.get("include_hands", False)),
hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
face_style=str(mesh_style.get("face_style", "disabled")),
face_marker_radius_m=float(mesh_style.get("face_marker_radius_m", 0.0)),
palette="openpose",
shape="ellipsoid",
bone_smooth_window=int(bone_smooth_window),
)
elif mode_key == "scail":
# SCAIL rig: open cylinders capped flush by joint spheres (sphere
# radius defaults to cylinder radius for a seamless silhouette).
cap_stick_radius = float(mesh_style.get("stick_radius_m", 0.022))
cap_marker_radius = float(mesh_style.get("marker_radius_m", 0.0))
if cap_marker_radius <= 0.0:
cap_marker_radius = cap_stick_radius
glb_bytes = build_glb_openpose(
pose_data,
fps=float(fps),
camera_translation=str(camera_translation),
track_index=int(track_index),
marker_radius_m=cap_marker_radius,
stick_radius_m=cap_stick_radius,
include_hands=bool(mesh_style.get("include_hands", False)),
hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
face_style=str(mesh_style.get("face_style", "disabled")),
palette="scail",
shape="capsule",
smooth_shade=True,
# SCAIL material: slightly glossy (0.3) + double-sided so the
# inside of the open cylinders shades sensibly at grazing angles.
material_roughness=float(mesh_style.get("material_roughness", 0.3)),
material_double_sided=True,
bone_smooth_window=int(bone_smooth_window),
)
else:
raise ValueError(f"BuildPoseGLB: unknown mesh_style {mode_key!r}")
return IO.NodeOutput(Types.File3D(BytesIO(glb_bytes), file_format="glb"))
class Save3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [SaveGLB]
return [SaveGLB, BuildPoseFile]
async def comfy_entrypoint() -> Save3DExtension:

View File

@ -2,11 +2,10 @@ import torch
import comfy.utils
import comfy.model_management
import numpy as np
import math
import colorsys
from tqdm import tqdm
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.pose.keypoint_draw import KeypointDraw
from comfy_extras.nodes_lotus import LotusConditioning
@ -73,299 +72,6 @@ def _to_openpose_frames(all_keypoints, all_scores, height, width):
return frames
class KeypointDraw:
"""
Pose keypoint drawing class that supports both numpy and cv2 backends.
"""
def __init__(self):
try:
import cv2
self.draw = cv2
except ImportError:
self.draw = self
# Hand connections (same for both hands)
self.hand_edges = [
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
[0, 5], [5, 6], [6, 7], [7, 8], # index
[0, 9], [9, 10], [10, 11], [11, 12], # middle
[0, 13], [13, 14], [14, 15], [15, 16], # ring
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
]
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
self.body_limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14]
]
# Head connections (1-indexed, converted to 0-indexed)
self.head_edges = [
[2, 1], [1, 15], [15, 17], [1, 16], [16, 18]
]
# Colors matching DWPose
self.colors = [
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
]
@staticmethod
def circle(canvas_np, center, radius, color, **kwargs):
"""Draw a filled circle using NumPy vectorized operations."""
cx, cy = center
h, w = canvas_np.shape[:2]
radius_int = int(np.ceil(radius))
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
if y_max <= y_min or x_max <= x_min:
return
y, x = np.ogrid[y_min:y_max, x_min:x_max]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
"""Draw line using Bresenham's algorithm with NumPy operations."""
x0, y0, x1, y1 = *pt1, *pt2
h, w = canvas_np.shape[:2]
dx, dy = abs(x1 - x0), abs(y1 - y0)
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
err, x, y, line_points = dx - dy, x0, y0, []
while True:
line_points.append((x, y))
if x == x1 and y == y1:
break
e2 = 2 * err
if e2 > -dy:
err, x = err - dy, x + sx
if e2 < dx:
err, y = err + dx, y + sy
if thickness > 1:
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
for px, py in line_points:
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
if y_max > y_min and x_max > x_min:
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
else:
line_points = np.array(line_points)
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
if (valid_points := line_points[valid]).size:
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
@staticmethod
def fillConvexPoly(canvas_np, pts, color, **kwargs):
"""Fill polygon using vectorized scanline algorithm."""
if len(pts) < 3:
return
pts = np.array(pts, dtype=np.int32)
h, w = canvas_np.shape[:2]
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
if y_max <= y_min or x_max <= x_min:
return
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
for i in range(len(pts)):
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
y1, y2 = p1[1], p2[1]
if y1 == y2:
continue
if y1 > y2:
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
continue
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
"""Python implementation of cv2.ellipse2Poly."""
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
angle = angle % 360
if arc_start > arc_end:
arc_start, arc_end = arc_end, arc_start
while arc_start < 0:
arc_start, arc_end = arc_start + 360, arc_end + 360
while arc_end > 360:
arc_end, arc_start = arc_end - 360, arc_start - 360
if arc_end - arc_start > 360:
arc_start, arc_end = 0, 360
angle_rad = math.radians(angle)
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
pts = []
for i in range(arc_start, arc_end + delta, delta):
theta_rad = math.radians(min(i, arc_end))
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
unique_pts, prev_pt = [], (float('inf'), float('inf'))
for pt in pts:
if (pt_tuple := tuple(pt)) != prev_pt:
unique_pts.append(pt)
prev_pt = pt_tuple
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
draw_body=True, draw_head=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
"""
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
Expected keypoint format (after neck insertion and remapping):
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
- Foot: 18-23 (6 keypoints)
- Face: 24-91 (68 landmarks)
- Right hand: 92-112 (21 keypoints)
- Left hand: 113-133 (21 keypoints)
Args:
canvas: The canvas to draw on (numpy array)
keypoints: Array of keypoint coordinates
scores: Optional confidence scores for each keypoint
threshold: Minimum confidence threshold for drawing keypoints
Returns:
canvas: The canvas with keypoints drawn
"""
H, W, C = canvas.shape
# Draw body limbs & head connections
if (draw_body or draw_head) and len(keypoints) >= 18:
colorIndexOffset = 0
edges = []
if draw_body:
edges += self.body_limbSeq
else:
colorIndexOffset += len(self.body_limbSeq)
if draw_head:
edges += self.head_edges
for i, limb in enumerate(edges):
# Convert from 1-indexed to 0-indexed
idx1, idx2 = limb[0] - 1, limb[1] - 1
if idx1 >= 18 or idx2 >= 18:
continue
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
Y = [keypoints[idx1][0], keypoints[idx2][0]]
X = [keypoints[idx1][1], keypoints[idx2][1]]
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
if length < 1:
continue
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
self.draw.fillConvexPoly(canvas, polygon, self.colors[(i + colorIndexOffset) % len(self.colors)])
# Draw body & head keypoints
if (draw_body or draw_head) and len(keypoints) >= 18:
head_keypoints = {0, 14, 15, 16, 17} # nose, eyes, ears
neck_point = 1
for i in range(18):
if not draw_head and i in head_keypoints:
continue
if not draw_body and i not in head_keypoints and i != neck_point:
continue
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw foot keypoints (18-23, 6 keypoints)
if draw_feet and len(keypoints) >= 24:
for i in range(18, 24):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
# Draw right hand (92-112)
if draw_hands and len(keypoints) >= 113:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 92 + edge[0], 92 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw right hand keypoints
for i in range(92, 113):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw left hand (113-133)
if draw_hands and len(keypoints) >= 134:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 113 + edge[0], 113 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
# Draw left hand keypoints
for i in range(113, 134):
if scores is not None and i < len(scores) and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
# Draw face keypoints (24-91) - white dots only, no lines
if draw_face and len(keypoints) >= 92:
eps = 0.01
for i in range(24, 92):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
return canvas
class SDPoseDrawKeypoints(io.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -0,0 +1,371 @@
"""Pose keypoint drawing primitives shared across pose nodes.
`KeypointDraw` exposes a cv2-or-numpy backend so callers can use one drawing
API regardless of whether OpenCV is installed:
kd = KeypointDraw()
kd.draw.circle(canvas, (x, y), radius, color, thickness=-1)
kd.draw.line(canvas, p1, p2, color, thickness=4)
kd.draw.fillConvexPoly(canvas, polygon, color)
kd.draw.ellipse2Poly(center, axes, angle, 0, 360, 1)
It also carries DWPose's body/hand topology + color tables, used by:
- comfy_extras.nodes_sdpose (SDPose pose drawing)
- comfy_extras.pose.export.openpose_2d (SAM 3D Body 2D pose viz)
- comfy_extras.pose.export.glb_shared (SAM 3D Body GLB tables)
"""
import colorsys
import math
import numpy as np
class KeypointDraw:
"""
Pose keypoint drawing class that supports both numpy and cv2 backends.
"""
def __init__(self):
try:
import cv2
self.draw = cv2
except ImportError:
self.draw = self
# Hand connections (same for both hands)
self.hand_edges = [
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
[0, 5], [5, 6], [6, 7], [7, 8], # index
[0, 9], [9, 10], [10, 11], [11, 12], # middle
[0, 13], [13, 14], [14, 15], [15, 16], # ring
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
]
# Head connections (1-indexed, converted to 0-indexed): nose-neck, eyes, ears
self.head_edges = [
[2, 1], [1, 15], [15, 17], [1, 16], [16, 18]
]
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed).
# body_limbSeq is the full 18-point skeleton (body + head_edges last); the head
# edges are kept as the trailing entries so callers can toggle them via draw_head.
self.body_limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14],
] + self.head_edges
# Colors matching DWPose
self.colors = [
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
]
@staticmethod
def circle(canvas_np, center, radius, color, **kwargs):
"""Draw a filled circle using NumPy vectorized operations."""
cx, cy = center
h, w = canvas_np.shape[:2]
radius_int = int(np.ceil(radius))
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
if y_max <= y_min or x_max <= x_min:
return
y, x = np.ogrid[y_min:y_max, x_min:x_max]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
"""Draw line using Bresenham's algorithm with NumPy operations."""
x0, y0, x1, y1 = *pt1, *pt2
h, w = canvas_np.shape[:2]
dx, dy = abs(x1 - x0), abs(y1 - y0)
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
err, x, y, line_points = dx - dy, x0, y0, []
while True:
line_points.append((x, y))
if x == x1 and y == y1:
break
e2 = 2 * err
if e2 > -dy:
err, x = err - dy, x + sx
if e2 < dx:
err, y = err + dx, y + sy
if thickness > 1:
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
for px, py in line_points:
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
if y_max > y_min and x_max > x_min:
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
else:
line_points = np.array(line_points)
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
if (valid_points := line_points[valid]).size:
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
@staticmethod
def fillConvexPoly(canvas_np, pts, color, **kwargs):
"""Fill polygon using vectorized scanline algorithm."""
if len(pts) < 3:
return
pts = np.array(pts, dtype=np.int32)
h, w = canvas_np.shape[:2]
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
if y_max <= y_min or x_max <= x_min:
return
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
for i in range(len(pts)):
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
y1, y2 = p1[1], p2[1]
if y1 == y2:
continue
if y1 > y2:
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
continue
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
canvas_np[y_min:y_max, x_min:x_max][mask] = color
@staticmethod
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
"""Python implementation of cv2.ellipse2Poly."""
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
angle = angle % 360
if arc_start > arc_end:
arc_start, arc_end = arc_end, arc_start
while arc_start < 0:
arc_start, arc_end = arc_start + 360, arc_end + 360
while arc_end > 360:
arc_end, arc_start = arc_end - 360, arc_start - 360
if arc_end - arc_start > 360:
arc_start, arc_end = 0, 360
angle_rad = math.radians(angle)
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
pts = []
for i in range(arc_start, arc_end + delta, delta):
theta_rad = math.radians(min(i, arc_end))
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
unique_pts, prev_pt = [], (float('inf'), float('inf'))
for pt in pts:
if (pt_tuple := tuple(pt)) != prev_pt:
unique_pts.append(pt)
prev_pt = pt_tuple
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
draw_body=True, draw_head=True, draw_feet=True, draw_face=True, draw_hands=True,
stick_width=4, face_point_size=3,
marker_radius=4, hand_stick_width=2, hand_marker_radius=4,
limb_alpha=1.0, hand_dot_color=(0, 0, 255)):
"""
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
Expected keypoint format (after neck insertion and remapping):
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
- Foot: 18-23 (6 keypoints)
- Face: 24-91 (68 landmarks)
- Right hand: 92-112 (21 keypoints)
- Left hand: 113-133 (21 keypoints)
Args:
canvas: The canvas to draw on (numpy array)
keypoints: Array of keypoint coordinates
scores: Optional confidence scores for each keypoint
threshold: Minimum confidence threshold for drawing keypoints
draw_head: Toggle head edges/keypoints (nose, eyes, ears) independently of draw_body.
stick_width: Body limb half-width (passed to ellipse2Poly).
face_point_size: Radius of the white face dots.
marker_radius: Radius of body/foot dots. Defaults to 4 (DWPose).
hand_stick_width: Thickness of hand limb lines. Defaults to 2.
hand_marker_radius: Radius of hand dots. Defaults to 4.
limb_alpha: Body-limb alpha blend (0..1). 1.0 = opaque fill (default),
<1.0 enables per-limb bbox-clipped alpha overlay (DWPose semantics
where overlapping limbs darken).
hand_dot_color: Either an (R, G, B) tuple/list of ints for solid-color
hand dots (default (0, 0, 255), DWPose blue), or a (21, 3) array
for per-keypoint hand-dot colors (OpenPose-style rainbow palette).
Returns:
canvas: The canvas with keypoints drawn
"""
H, W, C = canvas.shape
# Normalize hand_dot_color to a (21, 3) int array.
hdc_arr = np.asarray(hand_dot_color, dtype=int)
if hdc_arr.ndim == 1:
hdc_arr = np.tile(hdc_arr.reshape(1, 3), (21, 1))
hand_dot_tuples = [tuple(int(c) for c in hdc_arr[i]) for i in range(21)]
do_alpha = float(limb_alpha) < 1.0
# Draw body limbs & head connections. body_limbSeq holds the full skeleton
# with head edges trailing; draw_body / draw_head toggle each group while the
# color index stays aligned to the full sequence.
if (draw_body or draw_head) and len(keypoints) >= 18:
body_core = self.body_limbSeq[:len(self.body_limbSeq) - len(self.head_edges)]
edges, color_offset = [], 0
if draw_body:
edges += body_core
else:
color_offset += len(body_core)
if draw_head:
edges += self.head_edges
for i, limb in enumerate(edges):
# Convert from 1-indexed to 0-indexed
idx1, idx2 = limb[0] - 1, limb[1] - 1
if idx1 >= 18 or idx2 >= 18:
continue
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
Y = [keypoints[idx1][0], keypoints[idx2][0]]
X = [keypoints[idx1][1], keypoints[idx2][1]]
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
if length < 1:
continue
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
color = self.colors[(i + color_offset) % len(self.colors)]
if do_alpha:
_fill_poly_alpha(canvas, polygon, color, limb_alpha, self.draw)
else:
self.draw.fillConvexPoly(canvas, polygon, color)
# Draw body & head keypoints
if (draw_body or draw_head) and len(keypoints) >= 18:
head_keypoints = {0, 14, 15, 16, 17} # nose, eyes, ears
neck_point = 1
for i in range(18):
if not draw_head and i in head_keypoints:
continue
if not draw_body and i not in head_keypoints and i != neck_point:
continue
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)
# Draw foot keypoints (18-23, 6 keypoints)
if draw_feet and len(keypoints) >= 24:
for i in range(18, 24):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)
# Draw right hand (92-112)
if draw_hands and len(keypoints) >= 113:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 92 + edge[0], 92 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width)
# Draw right hand keypoints
for i in range(92, 113):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 92], thickness=-1)
# Draw left hand (113-133)
if draw_hands and len(keypoints) >= 134:
eps = 0.01
for ie, edge in enumerate(self.hand_edges):
idx1, idx2 = 113 + edge[0], 113 + edge[1]
if scores is not None:
if scores[idx1] < threshold or scores[idx2] < threshold:
continue
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
# HSV to RGB conversion for rainbow colors
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
color = (int(r * 255), int(g * 255), int(b * 255))
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width)
# Draw left hand keypoints
for i in range(113, 134):
if scores is not None and i < len(scores) and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), hand_marker_radius, hand_dot_tuples[i - 113], thickness=-1)
# Draw face keypoints (24-91) - white dots only, no lines
if draw_face and len(keypoints) >= 92:
eps = 0.01
for i in range(24, 92):
if scores is not None and scores[i] < threshold:
continue
x, y = int(keypoints[i][0]), int(keypoints[i][1])
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
return canvas
def _fill_poly_alpha(canvas, polygon, color, alpha, draw_backend):
"""Bbox-clipped alpha-blended fillConvexPoly. `canvas` is mutated in-place.
DWPose semantics: each limb blends with `alpha` independently so overlapping
limbs darken further. Operates on the polygon's bbox to avoid copying the
whole canvas per limb.
"""
H, W = canvas.shape[:2]
poly_arr = np.asarray(polygon, dtype=np.int32)
x0 = max(0, int(poly_arr[:, 0].min()))
xN = min(W, int(poly_arr[:, 0].max()) + 1)
y0 = max(0, int(poly_arr[:, 1].min()))
yN = min(H, int(poly_arr[:, 1].max()) + 1)
if xN <= x0 or yN <= y0:
return
local_poly = poly_arr - np.array([x0, y0], dtype=poly_arr.dtype)
roi = canvas[y0:yN, x0:xN].copy()
draw_backend.fillConvexPoly(roi, local_poly, color)
a = float(alpha)
canvas[y0:yN, x0:xN] = np.clip(
roi.astype(np.float32) * a + canvas[y0:yN, x0:xN].astype(np.float32) * (1.0 - a),
0, 255,
).astype(np.uint8)

View File

@ -0,0 +1,207 @@
"""BVH export for SAM 3D Body pose_data.
BVH stores explicit bone OFFSETs per joint, so standard importers reconstruct
anatomical bone orientations directly (unlike glTF). We skip the rig's joint 0
(static world anchor) and use joint 1 as the ROOT (6 channels: XYZ pos + ZXY
rot); other joints get 3 channels. Rotations are intrinsic Z-X-Y Euler degrees.
"""
from __future__ import annotations
import io
from typing import Any, Dict, List
import numpy as np
from .glb_shared import (
Rig,
bone_locals_from_globals,
collect_tracks,
global_skel_state_from_pose_data,
quat_sign_fix_per_joint,
unflip,
)
def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
"""xyzw quat → intrinsic Z-X-Y Euler degrees, returned as (..., 3) in
(z, x, y) order to match BVH's `CHANNELS Zrotation Xrotation Yrotation`."""
q = np.asarray(quat, dtype=np.float64)
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
# ZXY decomposition R = Rz(c)·Rx(a)·Ry(b):
# M[2][1] = 2(yz + xw) → sin(a)
# M[0][1] = 2(xy - zw) → -cos(a) sin(c)
# M[1][1] = 1 - 2(x² + z²) → cos(a) cos(c)
# M[2][0] = 2(xz - yw) → -cos(a) sin(b)
# M[2][2] = 1 - 2(x² + y²) → cos(a) cos(b)
M21 = np.clip(2.0 * (y * z + x * w), -1.0, 1.0)
M01 = 2.0 * (x * y - z * w)
M11 = 1.0 - 2.0 * (x * x + z * z)
M20 = 2.0 * (x * z - y * w)
M22 = 1.0 - 2.0 * (x * x + y * y)
a = np.arcsin(M21)
c = np.arctan2(-M01, M11)
b = np.arctan2(-M20, M22)
out = np.stack([np.rad2deg(c), np.rad2deg(a), np.rad2deg(b)], axis=-1)
return out.astype(np.float32)
def _find_bvh_root(parents: np.ndarray, is_external: bool = False) -> int:
"""First child of the rig's world anchor, dropping the origin→body stick.
Falls back to the first root joint. External rigs whose root is already the
articulated body root with multiple child chains keep the root descending
into one child would drop the sibling limbs."""
NJ = parents.shape[0]
world_anchors = [j for j in range(NJ)
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
if not world_anchors:
return 0
children: List[List[int]] = [[] for _ in range(NJ)]
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
children[p].append(j)
wa = world_anchors[0]
if children[wa]:
if is_external and len(children[wa]) > 1:
return wa
return children[wa][0]
return wa
def _build_children_map(parents: np.ndarray) -> List[List[int]]:
NJ = parents.shape[0]
out: List[List[int]] = [[] for _ in range(NJ)]
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
out[p].append(j)
return out
def build_bvh(
pose_data: Dict[str, Any],
model: Any = None,
*,
fps: float = 24.0,
camera_translation: str = "off",
track_index: int = -1,
units: str = "cm",
) -> bytes:
"""Build a BVH file from pose_data. Returns UTF-8 text bytes.
`model` may be None when pose_data carries a `_skeleton_override` (external
rigs); the rig hierarchy/offsets/bind come from the override. `units` is
"cm" (default) or "m" affects OFFSET/root-position, not rotations.
"""
if units not in ("cm", "m"):
raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
unit_scale = 100.0 if units == "cm" else 1.0
rig = Rig.from_pose_data(pose_data, model)
is_external = not rig.can_rerun_fk
NJ = rig.num_joints
parents = rig.parents
frames = pose_data["frames"]
tracks = collect_tracks(pose_data, track_index)
if not tracks:
raise ValueError("build_bvh: no valid tracks in pose_data")
person_k, frame_indices = tracks[0]
n_frames = len(frame_indices)
if n_frames == 0:
raise ValueError("build_bvh: track has zero frames")
body_root = _find_bvh_root(parents, is_external)
children_map = _build_children_map(parents)
# Bone OFFSETs = translation_offsets (joint position relative to parent).
# The BVH root uses its bind world position so the skeleton imports in place.
bind_global = rig.bind_global_cm # (NJ, 8) cm
bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01 # (NJ, 3) m
offset_m = rig.joint_offsets_cm.astype(np.float64) * 0.01
# DFS order rooted at body_root — matches per-frame channel order.
bvh_order: List[int] = []
def _visit(j: int) -> None:
bvh_order.append(j)
for c in children_map[j]:
_visit(c)
_visit(body_root)
# Stored pred_global_rots/pred_joint_coords (authoritative); derive locals
# with body_root as the BVH-space hierarchy root.
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
joint_coords_y_down=rig.per_frame_y_down,
)
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bvh_parents = parents.copy()
bvh_parents[body_root] = -1
bone_local = bone_locals_from_globals(rig_global_m, bvh_parents)
# Second pass catches sign discontinuities from the parent-inverse composition.
bone_local[..., 3:7] = quat_sign_fix_per_joint(bone_local[..., 3:7])
eulers_deg = _quat_to_zxy_euler_deg(bone_local[..., 3:7])
if camera_translation in ("absolute", "centered"):
cam_t = np.stack([
unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
for t in frame_indices
], axis=0).astype(np.float64)
if camera_translation == "centered":
cam_t = cam_t - cam_t[0:1]
root_pos_m = bind_pos_m[body_root][None, :] + cam_t
else:
root_pos_m = np.tile(bind_pos_m[body_root], (n_frames, 1))
lines: List[str] = ["HIERARCHY"]
def _emit_joint(j: int, depth: int, is_root: bool) -> None:
ind = " " * depth
keyword = "ROOT" if is_root else "JOINT"
name = "Hips" if is_root else f"joint_{j:03d}"
lines.append(f"{ind}{keyword} {name}")
lines.append(ind + "{")
o = (bind_pos_m[j] if is_root else offset_m[j]) * unit_scale
lines.append(f"{ind} OFFSET {o[0]:.6f} {o[1]:.6f} {o[2]:.6f}")
if is_root:
lines.append(ind + " CHANNELS 6 Xposition Yposition Zposition "
"Zrotation Xrotation Yrotation")
else:
lines.append(ind + " CHANNELS 3 Zrotation Xrotation Yrotation")
kids = children_map[j]
if kids:
for c in kids:
_emit_joint(c, depth + 1, is_root=False)
else:
# End Site (standard BVH spec) gives leaf bones a drawable length.
lines.append(ind + " End Site")
lines.append(ind + " {")
tip = (offset_m[j] * unit_scale) * 0.3
tip_norm = float(np.linalg.norm(tip))
if tip_norm < 0.5 * unit_scale * 0.01: # < 0.5 mm → fall back
tip = np.array([0.0, 0.05 * unit_scale, 0.0], dtype=np.float64)
lines.append(f"{ind} OFFSET {tip[0]:.6f} {tip[1]:.6f} {tip[2]:.6f}")
lines.append(ind + " }")
lines.append(ind + "}")
_emit_joint(body_root, 0, is_root=True)
lines.append("MOTION")
lines.append(f"Frames: {n_frames}")
lines.append(f"Frame Time: {1.0 / float(fps):.6f}")
# Channel matrix per frame: root pos (3) + root rot (3) + non-root rots
# (3 each), columns in `bvh_order`. savetxt is far faster than f-strings.
non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
motion = np.concatenate([
root_pos_m * unit_scale, # (N, 3)
eulers_deg[:, body_root].astype(np.float64), # (N, 3)
eulers_deg[:, non_root_idx, :].reshape(n_frames, -1), # (N, 3*(NJ-1))
], axis=1)
buf = io.StringIO()
np.savetxt(buf, motion, fmt="%.6f")
lines.append(buf.getvalue().rstrip("\n"))
return ("\n".join(lines) + "\n").encode("utf-8")

View File

@ -0,0 +1,397 @@
"""3D capsule rendering for OpenPose-style skeletons (SCAIL-Pose look).
Each limb is a true 3D capsule (cylinder + hemispherical caps), projected
through the per-person camera (`pred_cam_t` + `focal_length` + image_size) so
closer limbs appear thicker/brighter. Self-contained analytic ray-capsule
renderer. Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
"""
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import comfy.model_management
from .glb_shared import (
OPENPOSE_18_PAIRS,
OPENPOSE_RAINBOW_18,
SCAIL_LIMB_COLORS_17,
OPENPOSE_HAND_PAIRS,
OPENPOSE_HAND_COLORS_21,
openpose_render_keypoints,
)
def _limb_palette_rgb01(palette: str) -> np.ndarray:
"""17 per-limb RGB colors in [0,1] for the OpenPose-18 body limbs."""
if palette == "scail":
return SCAIL_LIMB_COLORS_17.astype(np.float32)
return OPENPOSE_RAINBOW_18[: len(OPENPOSE_18_PAIRS)].astype(np.float32)
def _build_specs_from_pose(
persons: List[Dict[str, Any]],
pose_data: Dict[str, Any],
*,
include_hands: bool,
palette: str,
person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Flatten body + optional hand limbs for one frame into (starts, ends,
colors_rgba, is_hand) in camera coords (Y-down, +Z forward). Drops non-finite
or behind-camera endpoints; `is_hand` lets the renderer draw hands thinner.
`person_brightness_falloff` mixes each per-person color toward white by
`1 - falloff^k` for track k (track 0 stays vivid)."""
starts: List[np.ndarray] = []
ends: List[np.ndarray] = []
colors: List[np.ndarray] = []
is_hand: List[bool] = []
body_limb_colors = _limb_palette_rgb01(palette)
hand_limb_colors = OPENPOSE_HAND_COLORS_21.astype(np.float32)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
for k, person in enumerate(persons):
cam_t = person.get("pred_cam_t")
body_op = openpose_render_keypoints(person, pose_data, "body", dim=3)
if body_op is None or cam_t is None:
continue
cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
# op-keypoints are camera frame; add cam_t to place the subject in front.
body_kp = body_op + cam_t_np[None, :]
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
def _tint(rgb: np.ndarray) -> np.ndarray:
if pastel <= 0:
return rgb
return rgb * (1.0 - pastel) + pastel
# SCAIL skips face bones (13..16) and redirects limb 12 into a short
# head stub blending spine + neck→nose direction.
body_limb_count = 13 if palette == "scail" else len(OPENPOSE_18_PAIRS)
spine_dir = None
if palette == "scail":
mid_hip = 0.5 * (body_kp[8] + body_kp[11]) # 8=RHip, 11=LHip
sd = body_kp[1] - mid_hip # 1=Neck
sd_len = float(np.linalg.norm(sd))
if np.all(np.isfinite(sd)) and sd_len > 1e-6:
spine_dir = sd / sd_len
for limb_i, (a, b) in enumerate(OPENPOSE_18_PAIRS[:body_limb_count]):
sa, sb = body_kp[a], body_kp[b]
if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))):
continue
if sa[2] <= 0 or sb[2] <= 0:
continue
if palette == "scail" and limb_i == 12:
nose_vec = sb - sa
nose_len = float(np.linalg.norm(nose_vec))
if nose_len > 1e-6 and spine_dir is not None:
nose_dir = nose_vec / nose_len
mixed = 0.6 * spine_dir + 0.4 * nose_dir
mixed = mixed / max(float(np.linalg.norm(mixed)), 1e-6)
sb = sa + mixed * (nose_len * 0.5)
elif nose_len > 1e-6:
sb = sa + nose_vec * 0.5
elif spine_dir is not None:
sb = sa + spine_dir * (sd_len * 0.3)
starts.append(sa)
ends.append(sb)
is_hand.append(False)
color_rgb = _tint(body_limb_colors[limb_i])
colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0],
dtype=np.float32))
if include_hands:
hand_ops = [openpose_render_keypoints(person, pose_data, p, dim=3)
for p in ("hand_r", "hand_l")]
hand_kps = [h + cam_t_np[None, :] for h in hand_ops if h is not None]
for limb_i, (a, b) in enumerate(OPENPOSE_HAND_PAIRS):
for hand_kp in hand_kps:
sa, sb = hand_kp[a], hand_kp[b]
if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))):
continue
if sa[2] <= 0 or sb[2] <= 0:
continue
starts.append(sa)
ends.append(sb)
is_hand.append(True)
color_rgb = _tint(hand_limb_colors[(a + b) % len(hand_limb_colors)])
colors.append(np.array([color_rgb[0], color_rgb[1], color_rgb[2], 1.0],
dtype=np.float32))
if not starts:
return (np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 4), dtype=np.float32),
np.zeros((0,), dtype=bool))
return (np.stack(starts).astype(np.float32),
np.stack(ends).astype(np.float32),
np.stack(colors).astype(np.float32),
np.asarray(is_hand, dtype=bool))
def _ray_capsule_t(
ray_dirs: torch.Tensor, # (K, 3) unit rays from camera origin
starts: torch.Tensor, # (M, 3)
ends: torch.Tensor, # (M, 3)
ba_norm: torch.Tensor, # (M, 3) unit axis (A → B)
ba_len: torch.Tensor, # (M,) segment length
radius: torch.Tensor, # (M,) per-capsule radius
) -> torch.Tensor:
"""Closed-form ray-capsule intersection -> (K, M) ray params t to the nearest
valid hit per capsule, +inf on miss. Capsule = union of (cylinder, hemisphere
at A, hemisphere at B), each a quadratic root-find."""
INF = float("inf")
r_sq = radius * radius # (M,)
# Cached dot products.
dn = ray_dirs @ ba_norm.transpose(0, 1) # (K, M) — d·n
dA = ray_dirs @ starts.transpose(0, 1) # (K, M) — d·A
dB = ray_dirs @ ends.transpose(0, 1) # (K, M) — d·B
An = (starts * ba_norm).sum(-1) # (M,) — A·n
A_sq = (starts * starts).sum(-1) # (M,) — |A|²
B_sq = (ends * ends).sum(-1) # (M,) — |B|²
# Cylinder body: project onto plane ⊥ n and solve |P_⊥(t)|² = r².
a_c = 1.0 - dn * dn # (K, M)
b_c = -2.0 * (dA - dn * An) # (K, M)
c_c = A_sq - An * An - r_sq # (M,)
disc_c = b_c * b_c - 4.0 * a_c * c_c
safe_a = a_c.clamp(min=1e-9)
sqrt_c = torch.sqrt(disc_c.clamp(min=0.0))
t_cyl = (-b_c - sqrt_c) / (2.0 * safe_a)
s_cyl = t_cyl * dn - An # axial projection from A
cyl_ok = (disc_c >= 0) & (a_c > 1e-7) & (t_cyl > 1e-6) & \
(s_cyl >= 0.0) & (s_cyl <= ba_len)
t_cyl = torch.where(cyl_ok, t_cyl, torch.full_like(t_cyl, INF))
# Sphere at A, restricted to the hemisphere with axial projection ≤ 0.
disc_a = dA * dA - (A_sq - r_sq)
sqrt_a = torch.sqrt(disc_a.clamp(min=0.0))
t_sa = dA - sqrt_a
s_a = t_sa * dn - An
a_ok = (disc_a >= 0) & (t_sa > 1e-6) & (s_a <= 0.0)
t_sa = torch.where(a_ok, t_sa, torch.full_like(t_sa, INF))
# Sphere at B, restricted to the hemisphere with axial projection ≥ ba_len.
disc_b = dB * dB - (B_sq - r_sq)
sqrt_b = torch.sqrt(disc_b.clamp(min=0.0))
t_sb = dB - sqrt_b
s_b = t_sb * dn - An
b_ok = (disc_b >= 0) & (t_sb > 1e-6) & (s_b >= ba_len)
t_sb = torch.where(b_ok, t_sb, torch.full_like(t_sb, INF))
return torch.minimum(torch.minimum(t_cyl, t_sa), t_sb)
def _render_capsules_torch(
starts: torch.Tensor,
ends: torch.Tensor,
colors: torch.Tensor,
H: int, W: int,
fx: float, fy: float, cx: float, cy: float,
radius: torch.Tensor, # scalar or (M,) per-capsule radius
background_rgb: Optional[torch.Tensor],
device: torch.device,
flat_shade: bool = False,
) -> torch.Tensor:
"""Analytic ray-capsule renderer for a union of capsules. Camera at
origin looking down +Z; pixels in y-down screen coords."""
M = int(starts.shape[0])
if M == 0:
if background_rgb is not None:
return background_rgb.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
yy, xx = torch.meshgrid(
torch.arange(H, device=device, dtype=torch.float32),
torch.arange(W, device=device, dtype=torch.float32),
indexing="ij",
)
u = (xx - cx) / fx
v = (yy - cy) / fy
z = torch.ones_like(u)
ray_dirs = torch.stack([u, v, z], dim=-1)
ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True)
flat_dirs = ray_dirs.view(-1, 3)
N = flat_dirs.shape[0]
radius = torch.as_tensor(radius, device=device, dtype=torch.float32)
if radius.ndim == 0:
radius = radius.expand(M)
ba = ends - starts
ba_len = torch.linalg.norm(ba, dim=1).clamp(min=1e-6)
ba_norm = ba / ba_len.unsqueeze(1)
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
z_near = max(0.05, z_min - float(radius.max().item()))
# Union of per-capsule screen-space bboxes — pixels outside can't hit any
# capsule, so intersection only runs on the relevant subset of the canvas.
sz = starts[:, 2].clamp(min=z_near)
ez = ends[:, 2].clamp(min=z_near)
sx_p = starts[:, 0] * fx / sz + cx
sy_p = starts[:, 1] * fy / sz + cy
ex_p = ends[:, 0] * fx / ez + cx
ey_p = ends[:, 1] * fy / ez + cy
# Projected radius using the closer endpoint — conservative bbox.
r_pix = radius * fx / torch.minimum(sz, ez)
pad = 2.0
xmin_t = (torch.minimum(sx_p, ex_p) - r_pix - pad).floor().long().clamp(min=0, max=W)
xmax_t = (torch.maximum(sx_p, ex_p) + r_pix + pad).ceil().long().clamp(min=0, max=W)
ymin_t = (torch.minimum(sy_p, ey_p) - r_pix - pad).floor().long().clamp(min=0, max=H)
ymax_t = (torch.maximum(sy_p, ey_p) + r_pix + pad).ceil().long().clamp(min=0, max=H)
# One stack→tolist sync amortizes the GPU→CPU read over all M bboxes.
bboxes_cpu = torch.stack([xmin_t, ymin_t, xmax_t, ymax_t], dim=1).tolist()
coarse_mask = torch.zeros(H, W, dtype=torch.bool, device=device)
for xmin_i, ymin_i, xmax_i, ymax_i in bboxes_cpu:
if xmax_i > xmin_i and ymax_i > ymin_i:
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
# Analytic ray-capsule intersection, one pass over the masked pixels.
INF = float("inf")
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
if active_idx.numel() > 0:
# Cap per-chunk (K, M) tensors to ~4M elements to bound peak memory.
chunk_max = max(1, int(4_000_000 / max(M, 1)))
for i0 in range(0, active_idx.numel(), chunk_max):
sub = active_idx[i0 : i0 + chunk_max]
t_KM = _ray_capsule_t(
flat_dirs[sub], starts, ends, ba_norm, ba_len, radius,
)
t_min, m_idx = t_KM.min(dim=1)
hit = t_min < INF
if hit.any():
winners = sub[hit]
flat_t[winners] = t_min[hit]
flat_m_idx[winners] = m_idx[hit]
# Shade via analytic normal (P - closest point on segment).
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
if background_rgb is not None:
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
hit_idx = torch.nonzero(flat_m_idx >= 0, as_tuple=False).squeeze(1)
if hit_idx.numel() > 0:
rd = flat_dirs[hit_idx]
t_h = flat_t[hit_idx]
m_h = flat_m_idx[hit_idx]
p_hit = rd * t_h.unsqueeze(-1)
A_h = starts[m_h]
n_h = ba_norm[m_h]
L_h = ba_len[m_h]
proj = ((p_hit - A_h) * n_h).sum(-1).clamp(min=0.0)
proj = torch.minimum(proj, L_h)
C_h = A_h + proj.unsqueeze(-1) * n_h
normals = p_hit - C_h
normals = normals / normals.norm(dim=-1, keepdim=True).clamp(min=1e-8)
col = colors[m_h, :3]
if flat_shade:
# Solid per-limb color (OpenPose look) — no lighting/depth.
out[hit_idx] = col
return out.view(H, W, 3).clamp(0.0, 1.0)
# SCAIL Blinn-Phong, headlight along +Z.
diff = torch.clamp(-(normals[:, 2]), min=0.0)
diffuse = 0.45 + 0.55 * diff
view_dir = -rd
half_dir = view_dir.clone()
half_dir[:, 2] -= 1.0
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
# Mild depth fade.
z_vals = p_hit[:, 2]
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
if z_hi - z_lo > 1e-6:
fade = 1.0 - (z_vals - z_lo) / (z_hi - z_lo)
depth_factor = 0.85 + 0.15 * fade
else:
depth_factor = torch.ones_like(z_vals)
base = col * diffuse.unsqueeze(-1) * depth_factor.unsqueeze(-1)
highlight = (0.5 * spec * depth_factor).unsqueeze(-1)
out[hit_idx] = base + highlight
return out.view(H, W, 3).clamp(0.0, 1.0)
def render_pose_data_capsules(
pose_data: Dict[str, Any],
*,
frame_idx: int,
W: int,
H: int,
background: Optional[torch.Tensor] = None,
composite: str = "over",
radius_m: float = 0.025,
include_hands: bool = False,
palette: str = "scail",
person_brightness_falloff: float = 0.0,
flat_shade: bool = False,
hand_radius_scale: float = 0.4,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Render a frame's pose_data as 3D capsules through the per-person camera.
Returns (H, W, 3) fp32 in [0, 1].
`composite='over'` paints over `background` (black if None); 'mesh_only'
uses a black canvas. `radius_m` is in meters; hand limbs use
`radius_m * hand_radius_scale`. fx/fy come from each person's `focal_length`.
"""
persons = pose_data["frames"][frame_idx]
if device is None:
device = comfy.model_management.get_torch_device()
# SAM3DBody shares one camera across the clip — use the first valid person.
fx = fy = float(min(H, W))
for person in persons:
f = person.get("focal_length")
if f is None:
continue
fx = fy = float(np.asarray(f, dtype=np.float32).reshape(-1)[0])
break
cx, cy = W * 0.5, H * 0.5
starts_np, ends_np, colors_np, is_hand_np = _build_specs_from_pose(
persons, pose_data, include_hands=include_hands, palette=palette,
person_brightness_falloff=person_brightness_falloff,
)
bg_t: Optional[torch.Tensor] = None
if composite == "over" and background is not None:
bg_t = background.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
if bg_t.shape[:2] != (H, W):
bg_t = bg_t.permute(2, 0, 1).unsqueeze(0)
bg_t = torch.nn.functional.interpolate(
bg_t, size=(H, W), mode="bilinear", align_corners=False,
)
bg_t = bg_t.squeeze(0).permute(1, 2, 0).contiguous()
if starts_np.shape[0] == 0:
if bg_t is not None:
return bg_t
return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
starts_t = torch.from_numpy(starts_np).to(device=device, dtype=torch.float32)
ends_t = torch.from_numpy(ends_np).to(device=device, dtype=torch.float32)
colors_t = torch.from_numpy(colors_np).to(device=device, dtype=torch.float32)
radii_np = np.where(is_hand_np, radius_m * hand_radius_scale, radius_m).astype(np.float32)
radii_t = torch.from_numpy(radii_np).to(device=device, dtype=torch.float32)
return _render_capsules_torch(
starts_t, ends_t, colors_t,
H=H, W=W, fx=fx, fy=fy, cx=cx, cy=cy,
radius=radii_t,
background_rgb=bg_t,
device=device,
flat_shade=flat_shade,
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,530 @@
"""GLB export — skeletal (real armature) mode.
Rebuilds an Armature with the MHR 127-bone rig: per-frame local TRS from
param_transform on `mhr_model_params`, rest verts from a zero-pose forward,
sparse skinning compacted to glTF's 4-influence form, and facial expression as
72 morph targets driven by `expr_params`. Optional octahedron bone-vis is
rigidly skinned alongside for viewers that don't draw bones. Shared infra lives
in `glb_shared.py`.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from .glb_shared import (
GLBWriter,
Rig,
bake_vertex_colors,
bone_locals_from_globals,
collect_tracks,
compute_normals,
compute_pastel_mix,
flat_shade_mesh,
gaussian_smooth_quats,
global_skel_state_from_pose_data,
global_skel_state_per_frame,
ibp_from_bind_global,
make_lit_material,
quat_sign_fix_per_joint,
rotation_align,
unflip,
)
from comfy_extras.sam3d_body.utils import jet_colormap
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
"""Per-bone RGB (NJ, 3) float32 in [0, 1]. None for 'white' (default material)."""
if scheme == "rainbow_y":
y = bind_pos_m[:, 1].astype(np.float32)
y_min, y_max = float(y.min()), float(y.max())
s = np.clip((y - y_min) / max(y_max - y_min, 1e-6), 0.0, 1.0)
return jet_colormap(s)
return None
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
"""Canonical Blender-style bone octahedron: head at origin, tail at +Y, unit
length, ridge at 1/10 height. 6 verts, 8 triangles, faces wound outward."""
v = np.array([
[0.0, 0.0, 0.0], # 0: head
[0.0, 1.0, 0.0], # 1: tail
[1.0, 0.1, 0.0], # 2: +X ridge (pre-scale; X/Z scale by half_width)
[-1.0, 0.1, 0.0], # 3: -X ridge
[0.0, 0.1, 1.0], # 4: +Z ridge
[0.0, 0.1, -1.0], # 5: -Z ridge
], dtype=np.float32)
f = np.array([
# head pyramid: outward = away from bone axis, slightly -Y
[0, 2, 4], [0, 5, 2], [0, 3, 5], [0, 4, 3],
# tail pyramid: outward = away from bone axis, slightly +Y
[1, 4, 2], [1, 3, 4], [1, 5, 3], [1, 2, 5],
], dtype=np.uint32)
return v, f
def _bone_edges(
joint_pos_m: np.ndarray, parents: np.ndarray,
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
"""One (parent_idx, child_idx, head_pos, tail_pos) per parent→child edge.
Skips edges whose parent is a root (world-anchor sticks) and zero-length
edges."""
NJ = joint_pos_m.shape[0]
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
for c in range(NJ):
p = int(parents[c])
if not (0 <= p < NJ and p != c):
continue
# Skip world-anchor sticks: parent itself is a root.
gp = int(parents[p])
if not (0 <= gp < NJ and gp != p):
continue
head = joint_pos_m[p].astype(np.float32)
tail = joint_pos_m[c].astype(np.float32)
if float(np.linalg.norm(tail - head)) < 1e-6:
continue
out.append((p, c, head, tail))
return out
def _build_bone_octahedrons_mesh(
bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""One octahedron per parent→child edge. Returns (verts, normals, faces,
joints, weights, child_idx_per_vert); child_idx feeds per-bone color."""
base_v, base_f = _octahedron_unit()
canonical = np.array([0.0, 1.0, 0.0], dtype=np.float32)
out_v: List[List[float]] = []
out_n: List[List[float]] = []
out_f: List[List[int]] = []
out_j: List[List[int]] = []
out_w: List[List[float]] = []
child_per_vert: List[int] = []
# Width scales with length (capped by half_width_m) so short bones aren't chunky.
WIDTH_RATIO = 0.1
MIN_WIDTH = 0.001
for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
direction = tail - head
length = float(np.linalg.norm(direction))
if length < 1e-6:
continue
unit_dir = direction / length
R = rotation_align(canonical, unit_dir)
half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
scale = np.array([half_width_eff, length, half_width_eff], dtype=np.float32)
v_local = base_v * scale
v_world = v_local @ R.T + head
# head pole outward = -Y, tail pole +Y, ridges outward in XZ.
n_local = np.zeros_like(base_v)
n_local[0] = [0.0, -1.0, 0.0]
n_local[1] = [0.0, 1.0, 0.0]
for k in range(2, 6):
n = base_v[k].copy()
n[1] = 0.0
n_norm = float(np.linalg.norm(n))
if n_norm > 0:
n_local[k] = n / n_norm
n_world = n_local @ R.T
v_off = len(out_v)
out_v.extend(v_world.tolist())
out_n.extend(n_world.tolist())
for face in base_f:
out_f.append([int(face[0]) + v_off, int(face[1]) + v_off, int(face[2]) + v_off])
# Dual skin (head→parent, tail→child); ridges blend by canonical Y so
# the bone stretches between joints instead of going rigid with one.
for k in range(base_v.shape[0]):
y_canon = float(base_v[k, 1])
w_parent = max(0.0, 1.0 - y_canon)
w_child = max(0.0, y_canon)
wsum = w_parent + w_child
if wsum > 0:
w_parent /= wsum
w_child /= wsum
out_j.append([int(parent_idx), int(child_idx), 0, 0])
out_w.append([w_parent, w_child, 0.0, 0.0])
child_per_vert.append(int(child_idx))
if not out_v:
return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32),
np.zeros((0, 3), dtype=np.uint32), np.zeros((0, 4), dtype=np.uint16),
np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64))
return (np.asarray(out_v, dtype=np.float32),
np.asarray(out_n, dtype=np.float32),
np.asarray(out_f, dtype=np.uint32),
np.asarray(out_j, dtype=np.uint16),
np.asarray(out_w, dtype=np.float32),
np.asarray(child_per_vert, dtype=np.int64))
def build_glb_skeletal(
pose_data: Dict[str, Any],
model: Any = None,
*,
fps: float = 24.0,
camera_translation: str = "off",
track_index: int = -1,
include_face_morphs: bool = True,
shader: str = "default",
rainbow_tilt_x_deg: float = 0.0,
rainbow_tilt_z_deg: float = 0.0,
person_palette_falloff: float = 0.6,
bone_smooth_window: int = 0,
use_stored_global_rots: bool = True,
bone_vis: str = "off",
bone_vis_radius_m: float = 0.04,
bone_vis_color: str = "white",
include_body_mesh: bool = True,
) -> bytes:
"""Build pose_data as a real Armature GLB with per-bone TRS keyframes. For
MHR, facial expression is exposed as 72 morph targets when
include_face_morphs=True.
External skeletons (e.g. ComfyUI-Kimodo) can supply
``pose_data["_skeleton_override"]`` to bypass MHR rig extraction (``model``
may be None then); per-frame state still reads ``pred_global_rots`` /
``pred_joint_coords``. See ``glb_shared._get_skeleton_override`` for the schema.
"""
frames = pose_data["frames"]
# Only `pred_cam_t` is camera-y-down; everything else is rig-native Y-up.
faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
tracks = collect_tracks(pose_data, track_index)
if not tracks:
raise ValueError("build_glb_skeletal: no valid tracks in pose_data")
rig = Rig.from_pose_data(pose_data, model)
NJ = rig.num_joints
NEXPR = rig.num_expr
parents = rig.parents
if not rig.can_rerun_fk:
# External rigs have no PCA pose params to re-run; use stored globals.
use_stored_global_rots = True
joint_coords_y_down = rig.per_frame_y_down
# Skin already compacted to ≤8 influences/vertex (some shoulder/hip verts
# need >4, else per-bone rotation noise leaks into the mesh).
joints_8 = rig.lbs_joints
weights_8 = rig.lbs_weights
actual_max_inf = rig.lbs_max_inf
joints_set0 = np.ascontiguousarray(joints_8[:, :4])
weights_set0 = np.ascontiguousarray(weights_8[:, :4])
use_set1 = actual_max_inf > 4
joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
# Derive bone locals from bind globals so any `parents`/FK mismatch is
# absorbed into the local TRS instead of producing wrong globals.
bind_global_m = rig.bind_global_m
bind_local = bone_locals_from_globals(bind_global_m[None], parents)[0]
# IBP = inverse of bind global → skin_matrix at rest is identity.
ibp_mat4 = ibp_from_bind_global(bind_global_m)
w = GLBWriter()
nodes: List[dict] = []
meshes: List[dict] = []
skins: List[dict] = []
materials: List[dict] = []
animations: List[dict] = []
scene_root_indices: List[int] = []
canonical_colors = pose_data.get("canonical_colors")
indices_acc = w.add_indices_u32(faces_native)
joints0_acc = w.add_joints_u16(joints_set0)
weights0_acc = w.add_weights_f32(weights_set0)
joints1_acc = w.add_joints_u16(joints_set1) if use_set1 else None
weights1_acc = w.add_weights_f32(weights_set1) if use_set1 else None
ibm_acc = w.add_mat4_f32(ibp_mat4)
expr_morph_accs: List[int] = []
if include_face_morphs and NEXPR > 0:
eb = rig.expr_basis.astype(np.float32) * 0.01
for e in range(NEXPR):
expr_morph_accs.append(w.add_vec3_f32_no_minmax(eb[e]))
samplers: List[dict] = []
channels: List[dict] = []
for track_i, (person_k, frame_indices) in enumerate(tracks):
person_root = {"name": f"track{track_i:02d}", "children": []}
nodes.append(person_root)
person_root_idx = len(nodes) - 1
scene_root_indices.append(person_root_idx)
bone_node_indices: List[int] = []
for j in range(NJ):
bone = {
"name": f"bone_{j:03d}",
"translation": bind_local[j, :3].tolist(),
"rotation": bind_local[j, 3:7].tolist(),
"scale": [float(bind_local[j, 7])] * 3,
}
nodes.append(bone)
bone_node_indices.append(len(nodes) - 1)
bone_children: List[List[int]] = [[] for _ in range(NJ)]
bone_root_indices: List[int] = []
for j in range(NJ):
p = int(parents[j])
if 0 <= p < NJ and p != j:
bone_children[p].append(bone_node_indices[j])
else:
bone_root_indices.append(bone_node_indices[j])
for j in range(NJ):
if bone_children[j]:
nodes[bone_node_indices[j]]["children"] = bone_children[j]
person_root["children"].extend(bone_root_indices)
skin = {
"joints": bone_node_indices,
"inverseBindMatrices": ibm_acc,
"skeleton": bone_root_indices[0] if bone_root_indices else bone_node_indices[0],
}
skins.append(skin)
skin_idx = len(skins) - 1
include_body = bool(include_body_mesh)
include_bones = bone_vis == "octahedrons"
body_mesh_node_idx: Optional[int] = None
if include_body:
# MHR rest verts depend on shape_params; external rigs ignore the arg.
shape_params_arr = np.asarray(
frames[frame_indices[0]][person_k].get("shape_params", []),
dtype=np.float32,
)
rest_v = rig.rest_verts_m(shape_params_arr)
normals = compute_normals(rest_v, faces_native)
positions_acc = w.add_vec3_f32(rest_v)
normals_acc = w.add_vec3_f32(normals)
pastel_mix = compute_pastel_mix(track_i, person_palette_falloff)
vcolor = bake_vertex_colors(
canonical_colors, shader,
rainbow_tilt_x_deg, rainbow_tilt_z_deg, pastel_mix,
)
color_acc = w.add_vec3_f32(vcolor) if vcolor is not None else None
attributes = {
"POSITION": positions_acc, "NORMAL": normals_acc,
"JOINTS_0": joints0_acc, "WEIGHTS_0": weights0_acc,
}
if joints1_acc is not None:
attributes["JOINTS_1"] = joints1_acc
attributes["WEIGHTS_1"] = weights1_acc
if color_acc is not None:
attributes["COLOR_0"] = color_acc
primitive = {
"attributes": attributes,
"indices": indices_acc,
"mode": 4,
}
# See-through body when bones are shown, else opaque (only if a
# shader baked COLOR_0; otherwise default material).
if color_acc is not None or include_bones:
materials.append(make_lit_material(opacity=0.35 if include_bones else 1.0))
primitive["material"] = len(materials) - 1
if expr_morph_accs:
primitive["targets"] = [{"POSITION": a} for a in expr_morph_accs]
mesh = {"primitives": [primitive]}
if expr_morph_accs:
mesh["weights"] = [0.0] * len(expr_morph_accs)
meshes.append(mesh)
mesh_idx = len(meshes) - 1
mesh_node = {
"name": f"track{track_i:02d}_mesh", "mesh": mesh_idx, "skin": skin_idx,
}
nodes.append(mesh_node)
body_mesh_node_idx = len(nodes) - 1
person_root["children"].append(body_mesh_node_idx)
if include_bones:
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
# Color by child joint so every bone has its own color.
color_idx_per_vert: Optional[np.ndarray] = None
hw = float(bone_vis_radius_m)
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
bind_global_m[:, :3], parents, half_width_m=hw,
)
if bv_v.shape[0] > 0:
F = bv_f.shape[0]
expanded_child = np.empty((F * 3,), dtype=np.int64)
for k in range(3):
expanded_child[k::3] = child_per_vert[bv_f[:, k]]
bv_v, bv_n, bv_f, bv_j, bv_w = flat_shade_mesh(bv_v, bv_f, bv_j, bv_w)
color_idx_per_vert = expanded_child
primitive_mode = 4
bv_idx_flat = bv_f.reshape(-1)
if bv_v.shape[0] > 0:
bv_pos_acc = w.add_vec3_f32(bv_v)
bv_idx_acc = w.add_indices_u32(bv_idx_flat)
bv_j_acc = w.add_joints_u16(bv_j)
bv_w_acc = w.add_weights_f32(bv_w)
bv_attrs = {
"POSITION": bv_pos_acc,
"JOINTS_0": bv_j_acc, "WEIGHTS_0": bv_w_acc,
}
if bv_n is not None:
bv_attrs["NORMAL"] = w.add_vec3_f32(bv_n)
if bone_palette is not None and color_idx_per_vert is not None:
bv_color = bone_palette[color_idx_per_vert].astype(np.float32)
bv_attrs["COLOR_0"] = w.add_vec3_f32(bv_color)
bv_primitive = {
"attributes": bv_attrs,
"indices": bv_idx_acc,
"mode": primitive_mode,
}
if bone_palette is not None:
materials.append(make_lit_material())
bv_primitive["material"] = len(materials) - 1
bv_mesh = {"primitives": [bv_primitive]}
meshes.append(bv_mesh)
bv_mesh_node = {
"name": f"track{track_i:02d}_bones",
"mesh": len(meshes) - 1,
"skin": skin_idx,
}
nodes.append(bv_mesh_node)
person_root["children"].append(len(nodes) - 1)
# Per-frame global skel state → bone locals via parent-inverse. Stored
# output by default; fallback re-runs FK.
if use_stored_global_rots:
rig_global_m = global_skel_state_from_pose_data(
pose_data, frame_indices, person_k, NJ,
joint_coords_y_down=joint_coords_y_down,
)
else:
mp_per_frame = np.stack([
np.asarray(frames[t][person_k]["mhr_model_params"], dtype=np.float32)
for t in frame_indices
], axis=0)
rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
rig_global_m = rig_global_cm.copy().astype(np.float32)
rig_global_m[..., :3] *= 0.01
# Sign-fix global quats BEFORE deriving locals: a parent's ±180° flip
# would otherwise propagate into the child's local translation and cause
# visible "axis resets" mid-animation.
rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
bone_local_anim = bone_locals_from_globals(rig_global_m, parents)
local_t = bone_local_anim[..., :3].astype(np.float32)
local_q = bone_local_anim[..., 3:7].astype(np.float32)
local_s = bone_local_anim[..., 7].astype(np.float32)
# Second pass on locals catches residual drift from the parent-inverse.
local_q = quat_sign_fix_per_joint(local_q)
# Align frame 0 with the bind quat so pause/play takes the short path.
bind_q = bind_local[:, 3:7].astype(np.float32)
if local_q.shape[0] > 0:
d0 = (bind_q * local_q[0]).sum(axis=-1)
sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
local_q[0] = local_q[0] * sign0
local_q = quat_sign_fix_per_joint(local_q)
# Optional smoothing for multi-frame rig spikes (e.g. q.w at handstand).
if bone_smooth_window and bone_smooth_window > 1:
local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
# fp64 renormalize → fp32; viewers' nlerp amplifies non-unit drift.
lq64 = local_q.astype(np.float64)
lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
local_q = lq64.astype(np.float32)
n_frames = len(frame_indices)
times = np.asarray(frame_indices, dtype=np.float32) / float(fps)
time_acc = w.add_scalar_f32(times)
for j in range(NJ):
t_j = local_t[:, j, :]
q_j = local_q[:, j, :]
s_j = np.broadcast_to(local_s[:, j:j+1], (n_frames, 3)).astype(np.float32)
t_const = (np.ptp(t_j, axis=0) < 1e-6).all()
q_const = (np.ptp(q_j, axis=0) < 1e-6).all()
s_const = (np.ptp(s_j, axis=0) < 1e-6).all()
if t_const:
nodes[bone_node_indices[j]]["translation"] = t_j[0].tolist()
else:
acc = w.add_vec3_f32_anim(t_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "translation"},
})
if q_const:
nodes[bone_node_indices[j]]["rotation"] = q_j[0].tolist()
else:
acc = w.add_vec4_f32(q_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "rotation"},
})
if s_const:
nodes[bone_node_indices[j]]["scale"] = s_j[0].tolist()
else:
acc = w.add_vec3_f32_anim(s_j)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": bone_node_indices[j], "path": "scale"},
})
if camera_translation != "off":
cam_t = np.stack([
unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
for t in frame_indices
], axis=0)
if camera_translation == "centered" and cam_t.shape[0] > 0:
cam_t = cam_t - cam_t[0:1]
if (np.ptp(cam_t, axis=0) < 1e-6).all():
person_root["translation"] = cam_t[0].tolist()
else:
acc = w.add_vec3_f32_anim(cam_t)
samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": person_root_idx, "path": "translation"},
})
# Body mesh only — bone-vis primitives have no morph targets.
if expr_morph_accs and body_mesh_node_idx is not None:
expr_per_frame = np.stack([
np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
for t in frame_indices
], axis=0).astype(np.float32)
weights_acc_anim = w.add_scalar_f32_flat(expr_per_frame, count=n_frames * NEXPR)
samplers.append({"input": time_acc, "output": weights_acc_anim, "interpolation": "LINEAR"})
channels.append({
"sampler": len(samplers) - 1,
"target": {"node": body_mesh_node_idx, "path": "weights"},
})
if samplers:
animations.append({
"name": "all_tracks",
"samplers": samplers, "channels": channels,
})
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI-SAM3DBody"},
"scene": 0,
"scenes": [{"nodes": scene_root_indices}],
"nodes": nodes,
"meshes": meshes,
"skins": skins,
}
if materials:
gltf["materials"] = materials
if animations:
gltf["animations"] = animations
return w.to_bytes(gltf)

View File

@ -0,0 +1,225 @@
"""2D OpenPose-style skeleton rendering for SAM 3D Body pose_data.
Body / hand drawing is delegated to `KeypointDraw.draw_wholebody_keypoints`
(shared with SDPose). SAM3D-specific: MHR70 -> DWPose-134 keypoint packing,
plus optional rig-projected face landmarks when `pred_face_keypoints_2d`
isn't present (and arbitrary-count face dots, since sapiens-238 doesn't fit
the DWPose face slot).
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
"""
import logging
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from PIL import Image
from comfy_extras.pose.keypoint_draw import KeypointDraw
from .glb_shared import (
OPENPOSE_HAND_COLORS_21,
openpose_render_keypoints,
select_face_landmark_vert_ids,
)
_KD = KeypointDraw()
# OpenPose hand palette as a (21, 3) int array (0..255) for KeypointDraw.
_HAND_DOT_PALETTE_OPENPOSE = (OPENPOSE_HAND_COLORS_21 * 255.0).astype(int)
def _project_face_landmarks_2d(
person: Dict[str, Any], face_vert_ids: np.ndarray, H: int, W: int,
) -> Optional[np.ndarray]:
"""Project `pred_vertices[face_vert_ids]` to 2D using each person's
pred_cam_t + focal_length. Same projection used by `_replay_mhr_with_overrides`."""
verts = person.get("pred_vertices")
cam_t = person.get("pred_cam_t")
focal = person.get("focal_length")
if verts is None or cam_t is None or focal is None:
return None
verts = np.asarray(verts, dtype=np.float32)
cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
f = float(np.asarray(focal, dtype=np.float32).reshape(-1)[0])
pts3 = verts[face_vert_ids] + cam_t[None, :]
z = np.maximum(pts3[:, 2:3], 1e-6)
xy = pts3[:, :2] * f
xy = xy + np.array([W * 0.5, H * 0.5], dtype=np.float32)[None, :] * z
return (xy / z).astype(np.float32)
def _pack_dwpose_134(
person: Dict[str, Any], pose_data: Dict[str, Any], *,
include_body: bool, include_hands: bool, H: int, W: int,
) -> Tuple[np.ndarray, np.ndarray]:
"""Pack a SAM3D person dict into (kp, scores): (134, 2) DWPose-layout
coords + (134,) confidence. Face slot (24-91) is left zeroed; face dots
are drawn separately so SAM3D's 238-sapiens / rig-fallback counts work.
Non-finite or out-of-band entries get score=0 and are filtered downstream.
Keypoints come from the shared provider: MHR reindexes `pred_keypoints_2d`,
external rigs (Kimodo) resolve + project from `pred_joint_coords`."""
kp = np.zeros((134, 2), dtype=np.float32)
scores = np.zeros(134, dtype=np.float32)
if include_body:
body_xy = openpose_render_keypoints(person, pose_data, "body", dim=2, H=H, W=W)
if body_xy is not None:
finite = np.isfinite(body_xy).all(axis=1)
kp[:18][finite] = body_xy[finite]
scores[:18][finite] = 1.0
if include_hands:
for slot_start, part in ((92, "hand_r"), (113, "hand_l")):
hand_xy = openpose_render_keypoints(person, pose_data, part, dim=2, H=H, W=W)
if hand_xy is None:
continue
finite = np.isfinite(hand_xy).all(axis=1)
kp[slot_start:slot_start + 21][finite] = hand_xy[finite]
scores[slot_start:slot_start + 21][finite] = 1.0
return kp, scores
def _draw_face_dots(
canvas: np.ndarray, face_xy: np.ndarray, marker_radius_px: int,
) -> None:
"""White face dots, variable count (238 sapiens / ~30 rig-projected)."""
H, W = canvas.shape[:2]
pad = int(marker_radius_px)
white = (255, 255, 255)
for i in range(face_xy.shape[0]):
x_, y_ = float(face_xy[i, 0]), float(face_xy[i, 1])
if not (np.isfinite(x_) and np.isfinite(y_)):
continue
x, y = int(round(x_)), int(round(y_))
if x + pad < 0 or x - pad >= W or y + pad < 0 or y - pad >= H:
continue
_KD.draw.circle(canvas, (x, y), int(marker_radius_px), white, thickness=-1)
def render_pose_data_openpose(
pose_data: Dict[str, Any],
*,
frame_idx: int,
W: int,
H: int,
background: Optional[torch.Tensor] = None,
composite: str = "over",
marker_radius_px: int = 4,
stick_width_px: int = 4,
limb_alpha: float = 0.6,
include_body: bool = True,
include_hands: bool = False,
face_style: str = "disabled",
hand_color_style: str = "dwpose",
hand_marker_radius_px: int = 0,
hand_stick_width_px: int = 0,
face_marker_radius_px: int = 3,
person_brightness_falloff: float = 0.0,
) -> torch.Tensor:
"""Render a 2D OpenPose-style skeleton onto an (H, W, 3) canvas.
`composite='over'` paints over `background` (else black canvas).
`hand_marker_radius_px` / `hand_stick_width_px`: 0 = auto = 0.7x / 0.5x
of the body sizes.
`face_style`: 'disabled' / 'full' / 'eyes_mouth'. eyes_mouth falls through
to the rig fallback since sapiens-238 has no documented subset.
`person_brightness_falloff` mixes each person's drawn pixels toward white
by `1 - falloff^k` (track 0 stays vivid). Applied post-draw so per-limb
alpha blending against the existing canvas remains correct.
"""
persons = pose_data["frames"][frame_idx]
if composite == "over" and background is not None:
bg = background.cpu().numpy()
canvas = (np.clip(bg, 0.0, 1.0) * 255.0).astype(np.uint8)
if canvas.shape[:2] != (H, W):
canvas = np.array(Image.fromarray(canvas).resize((W, H), Image.LANCZOS))
else:
canvas = np.zeros((H, W, 3), dtype=np.uint8)
# In-place draw needs a contiguous writable buffer.
canvas = np.ascontiguousarray(canvas)
if int(hand_marker_radius_px) <= 0:
hand_marker_radius_px = max(1, int(round(marker_radius_px * 0.7)))
if int(hand_stick_width_px) <= 0:
hand_stick_width_px = max(1, int(round(stick_width_px * 0.5)))
_EYES_MOUTH_IDX = np.array([6, 7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22], dtype=np.int64)
include_face = face_style != "disabled"
use_rig_only = face_style == "eyes_mouth"
face_vert_ids: Optional[np.ndarray] = None
if include_face:
any_real = (not use_rig_only) and any(
p.get("pred_face_keypoints_2d") is not None for p in persons
)
if not any_real:
cc = pose_data.get("canonical_colors") or {}
positions = cc.get("positions")
if positions is not None:
try:
face_vert_ids = select_face_landmark_vert_ids(
np.asarray(positions), face_mask=cc.get("face_mask"),
)
if use_rig_only:
face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
except (ValueError, IndexError) as e:
logging.warning(f"[SAM3DBody] face landmarks disabled - {e}")
face_vert_ids = None
hand_dot_color = (
_HAND_DOT_PALETTE_OPENPOSE if hand_color_style == "openpose"
else (0, 0, 255)
)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
for k, person in enumerate(persons):
pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
# Snapshot before this person's strokes so we can identify the pixels
# they touched and blend just those toward white. Drawing happens
# against the live canvas first so limb_alpha blends correctly.
pre = canvas.copy() if pastel > 0 else None
kp134, scores134 = _pack_dwpose_134(
person, pose_data, include_body=include_body,
include_hands=include_hands, H=H, W=W,
)
_KD.draw_wholebody_keypoints(
canvas, kp134, scores=scores134, threshold=0.5,
draw_body=include_body, draw_feet=False,
draw_face=False, # SAM3D draws face dots separately (variable count)
draw_hands=include_hands,
stick_width=stick_width_px,
marker_radius=marker_radius_px,
hand_stick_width=hand_stick_width_px,
hand_marker_radius=hand_marker_radius_px,
limb_alpha=limb_alpha,
hand_dot_color=hand_dot_color,
)
if include_face:
face_xy = None
real_face = person.get("pred_face_keypoints_2d")
if real_face is not None:
arr = np.asarray(real_face, dtype=np.float32)
if arr.ndim == 2 and arr.shape[1] == 2:
face_xy = arr
elif face_vert_ids is not None:
face_xy = _project_face_landmarks_2d(person, face_vert_ids, H, W)
if face_xy is not None:
_draw_face_dots(canvas, face_xy, face_marker_radius_px)
if pre is not None:
changed = (canvas != pre).any(axis=-1)
if changed.any():
touched = canvas[changed].astype(np.float32)
blended = touched * (1.0 - pastel) + 255.0 * pastel
canvas[changed] = np.clip(blended, 0.0, 255.0).astype(np.uint8)
return torch.from_numpy(canvas.astype(np.float32) / 255.0)

View File

@ -0,0 +1,516 @@
"""Face expression for SAM 3D Body.
Pipeline: comfy_extras.mediapipe.face_landmarker 52 ARKit blendshapes
72-dim MHR expr_params (mapping inlined below).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import comfy.model_management
# Bypass deadzone — jaw signals are clean (open or not)
_NOISE_FREE_BLENDSHAPES = {"jawOpen", "jawForward", "jawLeft", "jawRight"}
# Per-region gain — MP magnitudes vary by family (jaw up to 1.0, eye/brow
# rarely past 0.3), so a single global gain over/underdrives.
_REGION_PREFIXES = {
"mouth": ("jaw", "mouth"),
"eye": ("eye",),
"brow": ("brow", "cheek", "nose"), # cheek/nose read as upper-face
}
def _region_of(arkit_name: str) -> str:
for region, prefixes in _REGION_PREFIXES.items():
for p in prefixes:
if arkit_name.startswith(p):
return region
return "other"
# MHR axis → ARKit driver(s). Each axis collects 1-3 (name, weight) entries;
# the consumer takes max() across them so primary + aux contributions don't
# stack. MHR's 72 expression axes ship as anonymous `shape_c_N` channels in
# the upstream FBX (no semantic names), so this table is hand-derived by
# visual inspection of which axis each ARKit shape drives. Axes 2/3 and
# 12/13 are filled by aux routes only. ARKit shapes with no MHR analog are simply absent.
_AXIS_TO_ARKIT: Dict[int, List[Tuple[str, float]]] = {
0: [("browDownLeft", 1.0)],
1: [("browDownRight", 1.0)],
2: [("cheekPuff", 1.0)],
3: [("cheekPuff", 1.0)],
4: [("cheekSquintLeft", 1.0)],
5: [("cheekSquintRight", 1.0)],
6: [("mouthStretchLeft", 1.0)],
7: [("mouthStretchRight", 1.0)],
8: [("mouthShrugLower", 1.0)],
9: [("mouthShrugUpper", 1.0)],
10: [("mouthDimpleLeft", 1.0)],
11: [("mouthDimpleRight", 1.0)],
12: [("eyeLookDownLeft", 0.3)],
13: [("eyeLookDownRight", 0.3)],
14: [("eyeBlinkLeft", 1.0)],
15: [("eyeBlinkRight", 1.0)],
16: [("eyeLookOutLeft", 1.0)],
17: [("eyeLookInRight", 1.0)],
18: [("eyeLookInLeft", 1.0)],
19: [("eyeLookOutRight", 1.0)],
22: [("eyeLookUpLeft", 1.0), ("browInnerUp", 0.5)],
23: [("eyeLookUpRight", 1.0), ("browInnerUp", 0.5)],
24: [("jawOpen", 1.0), ("mouthLowerDownLeft", 0.3), ("mouthLowerDownRight", 0.3)],
25: [("jawLeft", 1.0)],
26: [("jawRight", 1.0)],
27: [("jawForward", 1.0)],
28: [("eyeSquintLeft", 1.0)],
29: [("eyeSquintRight", 1.0)],
32: [("mouthSmileLeft", 1.0)],
33: [("mouthSmileRight", 1.0)],
40: [("mouthLeft", 1.0)],
41: [("mouthRight", 1.0)],
42: [("mouthFrownLeft", 1.0)],
43: [("mouthFrownRight", 1.0)],
54: [("mouthLowerDownLeft", 1.0)],
55: [("mouthLowerDownRight", 1.0)],
60: [("noseSneerLeft", 1.0)],
61: [("noseSneerRight", 1.0)],
66: [("browOuterUpLeft", 1.0)],
67: [("browOuterUpRight", 1.0)],
68: [("eyeWideLeft", 1.0)],
69: [("eyeWideRight", 1.0)],
70: [("mouthUpperUpLeft", 1.0)],
71: [("mouthUpperUpRight", 1.0)],
}
def _deadzone(x: float, threshold: float) -> float:
"""Zero below threshold, linearly remap (threshold..1] → (0..1] so
amplification doesn't promote MP's per-blendshape noise floor."""
if threshold <= 0.0:
return x
if x <= threshold:
return 0.0
return (x - threshold) / (1.0 - threshold)
def arkit_to_expr_params(
blendshape_coefs: Dict[str, float],
strength: float = 1.0,
mouth_strength: float = 1.0,
eye_strength: float = 1.0,
brow_strength: float = 1.0,
input_threshold: float = 0.0,
n_axes: int = 72,
) -> np.ndarray:
"""Map MediaPipe's 52 ARKit blendshapes to MHR's 72 expr_params axes.
Multiple ARKit names per axis combine via max() so primary + aux routes
don't double up."""
expr = np.zeros(n_axes, dtype=np.float32)
region_scale = {
"mouth": float(mouth_strength), "eye": float(eye_strength),
"brow": float(brow_strength), "other": 1.0,
}
thr = float(input_threshold)
for axis, routes in _AXIS_TO_ARKIT.items():
best = 0.0
for name, weight in routes:
raw = float(blendshape_coefs.get(name, 0.0))
name_thr = 0.0 if name in _NOISE_FREE_BLENDSHAPES else thr
raw = _deadzone(raw, name_thr)
c = raw * region_scale[_region_of(name)] * float(weight)
if c > best:
best = c
expr[axis] = best * strength
return expr
def subtract_per_clip_baseline(
per_frame_coefs: List[Optional[Dict[str, float]]],
percentile: float = 5.0,
) -> List[Optional[Dict[str, float]]]:
"""Subtract per-blendshape p`percentile` baseline, clamp at 0. Adapts to
per-subject MP bias (e.g. resting browOuterUp ~0.15 permanent surprise
under brow_strength=2.0) that a global deadzone can't catch."""
if percentile <= 0.0:
return list(per_frame_coefs)
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
baselines: Dict[str, float] = {}
for n in names:
vals = [c[n] for c in per_frame_coefs if c is not None and n in c]
if vals:
baselines[n] = float(np.percentile(vals, percentile))
return [
None if c is None
else {n: max(0.0, float(v) - baselines.get(n, 0.0)) for n, v in c.items()}
for c in per_frame_coefs
]
def smooth_blendshape_series(
per_frame_coefs: List[Optional[Dict[str, float]]],
window: int = 7,
sigma: Optional[float] = None,
) -> List[Optional[Dict[str, float]]]:
"""Gaussian-smooth each coefficient across time. MP per-frame output swings
30-70% on static faces; smoothing pre-mapping cleans better than smoothing
mesh verts. None frames pass through unchanged."""
if window <= 1:
return list(per_frame_coefs)
if window % 2 == 0:
window += 1
if sigma is None:
sigma = max(1.0, window / 5.0)
x = np.arange(window) - (window - 1) / 2.0
k = np.exp(-(x ** 2) / (2 * sigma ** 2))
k = k / k.sum()
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
if not names:
return list(per_frame_coefs)
N = len(per_frame_coefs)
pad = window // 2
out: List[Optional[Dict[str, float]]] = [None] * N
for name in names:
series = np.zeros(N, dtype=np.float32)
mask = np.zeros(N, dtype=bool)
for i, c in enumerate(per_frame_coefs):
if c is not None:
series[i] = float(c.get(name, 0.0))
mask[i] = True
if not mask.any():
continue
if not mask.all():
idx = np.arange(N)
series = np.interp(idx, idx[mask], series[mask])
padded = np.concatenate(
[np.repeat(series[:1], pad), series, np.repeat(series[-1:], pad)]
)
filt = np.zeros_like(series)
for i, w in enumerate(k):
filt += w * padded[i: i + N]
for i in range(N):
if per_frame_coefs[i] is None:
continue
if out[i] is None:
out[i] = {}
out[i][name] = float(filt[i])
return out
def fill_detection_gaps(
per_frame_coefs: List[Optional[Dict[str, float]]],
method: str = "interpolate",
max_gap: int = 12,
) -> List[Optional[Dict[str, float]]]:
"""Fill missing per-frame dicts so the signal doesn't slam to zero at
undetected frames. method: 'interpolate' | 'hold' | 'zeros'.
`max_gap` applies to 'interpolate' and 'hold' gaps longer than that stay
None (don't fake too far). 'zeros' ignores `max_gap` on purpose: the goal
there is to relax to neutral on every miss, no matter how long, otherwise
long undetected runs would inherit Predict's per-frame expression."""
if method == "zeros":
names: set = set()
for c in per_frame_coefs:
if c is not None:
names.update(c.keys())
zero = {n: 0.0 for n in names}
return [dict(zero) if c is None else c for c in per_frame_coefs]
N = len(per_frame_coefs)
detected = [i for i, c in enumerate(per_frame_coefs) if c is not None]
if not detected:
return list(per_frame_coefs)
out: List[Optional[Dict[str, float]]] = list(per_frame_coefs)
for fi in range(N):
if out[fi] is not None:
continue
prev_i = next((k for k in range(fi - 1, -1, -1) if per_frame_coefs[k] is not None), None)
next_i = next((k for k in range(fi + 1, N) if per_frame_coefs[k] is not None), None)
if prev_i is None and next_i is None:
continue
max_dist = max(
(fi - prev_i) if prev_i is not None else 10**9,
(next_i - fi) if next_i is not None else 10**9,
)
if max_dist > max_gap:
continue
if method == "hold":
src = per_frame_coefs[prev_i] if prev_i is not None else per_frame_coefs[next_i]
out[fi] = dict(src)
elif method == "interpolate":
if prev_i is None:
out[fi] = dict(per_frame_coefs[next_i])
elif next_i is None:
out[fi] = dict(per_frame_coefs[prev_i])
else:
w = (fi - prev_i) / (next_i - prev_i)
a = per_frame_coefs[prev_i]
b = per_frame_coefs[next_i]
keys = set(a.keys()) | set(b.keys())
out[fi] = {k: (1.0 - w) * a.get(k, 0.0) + w * b.get(k, 0.0) for k in keys}
return out
def detect_faces_in_crop(
inner, image_rgb_uint8: np.ndarray, crop_xyxy: np.ndarray, num_faces: int = 1,
) -> List[dict]:
"""Run detection on a sub-region; remap bbox+landmarks back to full-image
coords. Helps small/distant faces that fall below BlazeFace's min size."""
H, W = image_rgb_uint8.shape[:2]
x1, y1, x2, y2 = (int(round(float(v))) for v in crop_xyxy)
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(W, x2), min(H, y2)
if x2 - x1 < 16 or y2 - y1 < 16:
return []
crop = np.ascontiguousarray(image_rgb_uint8[y1:y2, x1:x2])
faces = inner.face_landmarker.detect_batch([crop], num_faces=num_faces)[0]
bbox_off = np.array([x1, y1, x1, y1], dtype=np.float32)
xy_off = np.array([x1, y1], dtype=np.float32)
for f in faces:
f["bbox_xyxy"] = f["bbox_xyxy"] + bbox_off
f["landmarks_xy"] = f["landmarks_xy"] + xy_off
return faces
# Crop helpers — feed MP a tight head region so it doesn't downsample the face
# to 192px for full-frame detection.
def _expand_bbox(bbox_xyxy: np.ndarray, factor: float, W: int, H: int) -> np.ndarray:
x1, y1, x2, y2 = (float(v) for v in bbox_xyxy)
cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
hw, hh = 0.5 * (x2 - x1) * factor, 0.5 * (y2 - y1) * factor
return np.array([
max(0.0, cx - hw), max(0.0, cy - hh),
min(float(W), cx + hw), min(float(H), cy + hh),
], dtype=np.float32)
def head_region_crop(
person_bbox: np.ndarray, expand: float, W: int, H: int, head_h_frac: float = 0.4,
) -> np.ndarray:
"""Crop upper `head_h_frac` of a body bbox — cropping the whole body wastes
BlazeFace's 128² input on body pixels."""
x1, y1, x2, y2 = (float(v) for v in person_bbox)
body_h = y2 - y1
if body_h <= 0 or x2 - x1 <= 0:
return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
return _expand_bbox(np.array([x1, y1, x2, y1 + body_h * head_h_frac]), expand, W, H)
# mhr70 convention: first five kp are COCO-style face landmarks in pixel coords.
_FACE_KP_INDICES = (0, 1, 2, 3, 4) # nose, L-eye, R-eye, L-ear, R-ear
def head_crop_from_keypoints(
pred_keypoints_2d: np.ndarray, expand: float, W: int, H: int,
) -> Optional[np.ndarray]:
"""Head crop from SAM3D nose/eyes/ears kp. HEAD_FIT pads forehead/chin
since these only span the central face. None if <2 kp in-frame."""
if pred_keypoints_2d is None:
return None
kp = np.asarray(pred_keypoints_2d, dtype=np.float32)
if kp.ndim != 2 or kp.shape[0] <= max(_FACE_KP_INDICES):
return None
face = kp[list(_FACE_KP_INDICES), :2]
in_frame = (face[:, 0] > 0) & (face[:, 1] > 0) & (face[:, 0] < W) & (face[:, 1] < H)
valid = face[in_frame]
if len(valid) < 2:
return None
x1, x2 = float(valid[:, 0].min()), float(valid[:, 0].max())
y1, y2 = float(valid[:, 1].min()), float(valid[:, 1].max())
cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
span = max(x2 - x1, y2 - y1, 1.0)
half = 0.5 * span * 1.8 * float(expand) # 1.8 = pad forehead+chin
return np.array([
max(0.0, cx - half), max(0.0, cy - half),
min(float(W), cx + half), min(float(H), cy + half),
], dtype=np.float32)
# Face → person assignment when running full-frame detection.
def _iou_xyxy(a: np.ndarray, b: np.ndarray) -> float:
ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
inter = iw * ih
if inter <= 0.0:
return 0.0
aw, ah = max(0.0, a[2] - a[0]), max(0.0, a[3] - a[1])
bw, bh = max(0.0, b[2] - b[0]), max(0.0, b[3] - b[1])
union = aw * ah + bw * bh - inter
return float(inter / union) if union > 0.0 else 0.0
def assign_faces_to_persons(
face_bboxes: List[np.ndarray], person_bboxes: List[np.ndarray], min_iou: float = 0.01,
) -> List[Optional[int]]:
if not face_bboxes or not person_bboxes:
return [None] * len(person_bboxes)
assigned: List[Optional[int]] = [None] * len(person_bboxes)
used: set = set()
# Larger persons first — bigger bbox correlates with detectable face.
order = sorted(range(len(person_bboxes)),
key=lambda p: -((person_bboxes[p][2] - person_bboxes[p][0])
* (person_bboxes[p][3] - person_bboxes[p][1])))
for pi in order:
best_iou = min_iou
best_fi = None
pb = person_bboxes[pi]
for fi, fb in enumerate(face_bboxes):
if fi in used:
continue
cx, cy = 0.5 * (fb[0] + fb[2]), 0.5 * (fb[1] + fb[3])
inside = (pb[0] <= cx <= pb[2]) and (pb[1] <= cy <= pb[3])
score = max(_iou_xyxy(fb, pb), 0.5 if inside else 0.0)
if score > best_iou:
best_iou = score
best_fi = fi
if best_fi is not None:
assigned[pi] = best_fi
used.add(best_fi)
return assigned
# Re-run MHR forward after writing expr_params back into pose_frames; updates
# pred_vertices / pred_keypoints_2d/3d / pred_joint_coords / pred_global_rots.
def regenerate_mesh_from_params(inner, pose_frames: List[List[Dict[str, Any]]]) -> None:
"""Re-run MHR forward and write verts/kp3d/kp2d/joint back in place.
Drives MHR via euler params directly because hand refinement zeroes
pred_pose_raw."""
device = comfy.model_management.get_torch_device()
head = inner.head_pose
if head.mhr is None:
return
B = len(pose_frames)
max_p = max((len(f) for f in pose_frames), default=0)
for pid in range(max_p):
grots, bpps, hands, shapes, scales, exprs, cam_ts, fls = [], [], [], [], [], [], [], []
present: List[bool] = []
for fi in range(B):
if pid >= len(pose_frames[fi]):
present.append(False)
continue
p = pose_frames[fi][pid]
needed = ("global_rot", "body_pose_params", "hand_pose_params",
"shape_params", "scale_params", "expr_params",
"pred_cam_t", "focal_length")
if any(p.get(k) is None for k in needed):
present.append(False)
continue
grots.append(np.asarray(p["global_rot"], dtype=np.float32))
bpps.append(np.asarray(p["body_pose_params"], dtype=np.float32))
hands.append(np.asarray(p["hand_pose_params"], dtype=np.float32))
shapes.append(np.asarray(p["shape_params"], dtype=np.float32))
scales.append(np.asarray(p["scale_params"], dtype=np.float32))
exprs.append(np.asarray(p["expr_params"], dtype=np.float32))
cam_ts.append(np.asarray(p["pred_cam_t"], dtype=np.float32))
fls.append(float(np.asarray(p["focal_length"]).reshape(-1)[0]))
present.append(True)
if not any(present):
continue
global_rot_euler = torch.from_numpy(np.stack(grots)).to(device)
body_pose_euler = torch.from_numpy(np.stack(bpps)).to(device)
hand_t = torch.from_numpy(np.stack(hands)).to(device)
shape_t = torch.from_numpy(np.stack(shapes)).to(device)
scale_t = torch.from_numpy(np.stack(scales)).to(device)
expr_t = torch.from_numpy(np.stack(exprs)).to(device)
cam_t_t = torch.from_numpy(np.stack(cam_ts)).to(device)
f_t = torch.tensor(fls, device=device, dtype=torch.float32)
verts, kp3d_full, joint_coords, _, joint_rotmats = head.mhr_forward(
global_trans=torch.zeros_like(global_rot_euler),
global_rot=global_rot_euler,
body_pose_params=body_pose_euler,
hand_pose_params=hand_t,
scale_params=scale_t,
shape_params=shape_t,
expr_params=expr_t,
return_keypoints=True,
return_joint_coords=True,
return_model_params=True,
return_joint_rotations=True,
)
# y/z flip matches head_pose.forward (camera-y-down convention).
verts = verts.clone()
verts[..., [1, 2]] *= -1
kp3d = kp3d_full[:, :70].clone()
kp3d[..., [1, 2]] *= -1
# 238 sapiens face landmarks (70:308) — track retargeted expression
# so openpose face dots follow new mouth/eye/brow shape.
kp3d_face = kp3d_full[:, 70:].clone()
kp3d_face[..., [1, 2]] *= -1
joint_coords = joint_coords.clone()
joint_coords[..., [1, 2]] *= -1
# Recover principal point from any raw frame for reprojection.
cx = cy = 0.0
for fi in range(B):
if not present[fi]:
continue
raw = pose_frames[fi][pid]
kp2d_r = np.asarray(raw["pred_keypoints_2d"], dtype=np.float32)
kp3d_r = np.asarray(raw["pred_keypoints_3d"], dtype=np.float32)
ct_r = np.asarray(raw["pred_cam_t"], dtype=np.float32)
fl_r = float(np.asarray(raw["focal_length"]).reshape(-1)[0])
x, y, z = kp3d_r[0] + ct_r
cx = float(kp2d_r[0, 0] - fl_r * x / max(z, 1e-6))
cy = float(kp2d_r[0, 1] - fl_r * y / max(z, 1e-6))
break
def _project_kp(kp3d_local: torch.Tensor) -> torch.Tensor:
kp3d_cam = kp3d_local + cam_t_t.unsqueeze(1)
u = f_t[:, None] * kp3d_cam[..., 0] / kp3d_cam[..., 2].clamp(min=1e-6) + cx
v = f_t[:, None] * kp3d_cam[..., 1] / kp3d_cam[..., 2].clamp(min=1e-6) + cy
return torch.stack([u, v], dim=-1)
kp2d = _project_kp(kp3d)
kp2d_face = _project_kp(kp3d_face)
verts_np = verts.float().cpu().numpy()
kp3d_np = kp3d.float().cpu().numpy()
kp2d_np = kp2d.float().cpu().numpy()
kp3d_face_np = kp3d_face.float().cpu().numpy()
kp2d_face_np = kp2d_face.float().cpu().numpy()
jc_np = joint_coords.float().cpu().numpy()
jrot_np = joint_rotmats.float().cpu().numpy()
fi_active = 0
for fi in range(B):
if not present[fi]:
continue
pose_frames[fi][pid] = dict(pose_frames[fi][pid])
p = pose_frames[fi][pid]
p["pred_vertices"] = verts_np[fi_active]
p["pred_keypoints_3d"] = kp3d_np[fi_active]
p["pred_keypoints_2d"] = kp2d_np[fi_active]
p["pred_face_keypoints_3d"] = kp3d_face_np[fi_active]
p["pred_face_keypoints_2d"] = kp2d_face_np[fi_active]
p["pred_joint_coords"] = jc_np[fi_active]
p["pred_global_rots"] = jrot_np[fi_active]
fi_active += 1

View File

@ -0,0 +1,467 @@
"""Pure-PyTorch rasterizer for SAM 3D Body meshes.
Algorithm: forward triangle rasterizer with hard z-buffer. Per-face screen
bbox cull faces sorted by bbox size and chunked under a fixed pixel
budget inside-test via edge functions, barycentric interpolation, depth
test via `scatter_reduce_(amin)`.
"""
from typing import Sequence
import numpy as np
import torch
import comfy.model_management
from .utils import jet_colormap
_CANONICAL_PRESETS = {"rainbow", "rainbow_face_normal", "rainbow_face_semantic"}
_rainbow_cache: dict = {}
def rainbow_colors_from_canonical(
positions: np.ndarray,
tilt_x_deg: float = 0.0,
tilt_z_deg: float = 0.0,
) -> np.ndarray:
"""Compute per-vertex jet-colormap RGB from canonical (T-pose, Y-up) vertices.
Args:
positions: (N_v, 3) canonical vertex positions, Y-up (head at max Y).
tilt_x_deg: rotation of the jet axis around X (in degrees). Positive
biases the ramp toward +Z (front).
tilt_z_deg: rotation of the jet axis around Z (in degrees). Positive
biases the ramp toward +X (right, in body frame).
Returns:
(N_v, 3) float32 RGB in [0, 1].
"""
key = (hash(positions.tobytes()), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3))
cached = _rainbow_cache.get(key)
if cached is not None:
return cached
theta_x = np.deg2rad(tilt_x_deg)
theta_z = np.deg2rad(tilt_z_deg)
axis = np.array([
np.sin(theta_z),
np.cos(theta_z) * np.cos(theta_x),
np.cos(theta_z) * np.sin(theta_x),
], dtype=np.float32)
s = positions @ axis
s = (s - s.min()) / max(float(s.max() - s.min()), 1e-8)
s = np.clip(s * 0.98, 0.0, 1.0).astype(np.float32)
colors = jet_colormap(s)
_rainbow_cache[key] = colors
if len(_rainbow_cache) > 32:
_rainbow_cache.pop(next(iter(_rainbow_cache)))
return colors
def _vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
"""Area-weighted per-vertex normals; matches `_compute_vertex_normals`."""
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
vn = torch.zeros_like(verts)
vn.index_add_(0, faces[:, 0], fn)
vn.index_add_(0, faces[:, 1], fn)
vn.index_add_(0, faces[:, 2], fn)
return vn / vn.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def _build_vcolor(canonical_colors, shader_preset, tilt_x, tilt_z):
"""Mirrors the canonical_colors -> per-vertex RGB pipeline in
`rasterizer.render_pose_data`. Returns a numpy float32 (V, 3) table."""
positions = np.asarray(canonical_colors.get("positions"), dtype=np.float32)
vcolor = rainbow_colors_from_canonical(positions, tilt_x_deg=tilt_x, tilt_z_deg=tilt_z).copy()
if shader_preset in ("rainbow_face_normal", "rainbow_face_semantic"):
face_mask = canonical_colors.get("face_mask")
if face_mask is not None and face_mask.any():
if shader_preset == "rainbow_face_normal":
norm = np.asarray(canonical_colors["norm"], dtype=np.float32)
vcolor[face_mask] = norm[face_mask]
else: # rainbow_face_semantic
sem = np.asarray(canonical_colors["face_region_rgb"], dtype=np.float32)
assigned = sem.sum(axis=1) > 0
vcolor[assigned] = sem[assigned]
return vcolor
def _rasterize_chunk(
fv_pix: torch.Tensor, # (Fc, 3, 2) — pixel coords (sub-pixel float)
fv_z: torch.Tensor, # (Fc, 3) — image-frame z (smaller=closer)
bb_min_x: torch.Tensor, bb_max_x: torch.Tensor, # (Fc,) clamped int bboxes
bb_min_y: torch.Tensor, bb_max_y: torch.Tensor,
max_sx: int, max_sy: int,
W: int,
):
"""Rasterize a chunk of faces at pixel centers. Returns flat tensors of
inside fragments: (pixel_idx, depth, face_local, bary).
"""
device = fv_pix.device
if max_sx == 0 or max_sy == 0:
return None
sx = bb_max_x - bb_min_x
sy = bb_max_y - bb_min_y
px_off = torch.arange(max_sx, device=device)
py_off = torch.arange(max_sy, device=device)
# Pixel-center sample positions, broadcast to (Fc, max_sy, max_sx).
P_x = (bb_min_x[:, None, None] + px_off[None, None, :]).float() + 0.5
P_y = (bb_min_y[:, None, None] + py_off[None, :, None]).float() + 0.5
in_bb = (px_off[None, None, :] < sx[:, None, None]) & \
(py_off[None, :, None] < sy[:, None, None])
Ax = fv_pix[:, 0, 0][:, None, None]
Ay = fv_pix[:, 0, 1][:, None, None]
Bx = fv_pix[:, 1, 0][:, None, None]
By = fv_pix[:, 1, 1][:, None, None]
Cx = fv_pix[:, 2, 0][:, None, None]
Cy = fv_pix[:, 2, 1][:, None, None]
area2 = (Bx - Ax) * (Cy - Ay) - (By - Ay) * (Cx - Ax) # (Fc, 1, 1)
e_a = (Bx - P_x) * (Cy - P_y) - (By - P_y) * (Cx - P_x)
e_b = (Cx - P_x) * (Ay - P_y) - (Cy - P_y) * (Ax - P_x)
e_c = (Ax - P_x) * (By - P_y) - (Ay - P_y) * (Bx - P_x)
# Same-sign-as-area2 inside test (no back-face culling — match either winding).
nondegen = area2.abs() > 1e-6 # threshold rejects near-degenerate triangles
inside = (e_a * area2 >= 0) & (e_b * area2 >= 0) & (e_c * area2 >= 0)
inside = inside & in_bb & nondegen
if not inside.any():
return None
inv_a2 = torch.where(nondegen, 1.0 / area2, torch.zeros_like(area2))
w_a = e_a * inv_a2
w_b = e_b * inv_a2
w_c = e_c * inv_a2
z_a = fv_z[:, 0, None, None]
z_b = fv_z[:, 1, None, None]
z_c = fv_z[:, 2, None, None]
z_grid = w_a * z_a + w_b * z_b + w_c * z_c
fi, yi, xi = inside.nonzero(as_tuple=True)
px_pixel = bb_min_x[fi] + xi
py_pixel = bb_min_y[fi] + yi
pixel_idx = (py_pixel * W + px_pixel).long()
z_flat = z_grid[fi, yi, xi]
bary_flat = torch.stack([w_a[fi, yi, xi], w_b[fi, yi, xi], w_c[fi, yi, xi]], dim=-1)
return pixel_idx, z_flat, fi.long(), bary_flat
def _rasterize_person(
verts_world: torch.Tensor, faces: torch.Tensor,
focal: float, W: int, H: int,
z_buf: torch.Tensor, color_buf: torch.Tensor, mask_buf: torch.Tensor,
shade_fn,
):
# Project image-frame verts to pixel coords. Skip verts at/behind camera.
z_min_ok = 0.05
valid_v = verts_world[:, 2] > z_min_ok
safe_z = verts_world[:, 2].clamp(min=z_min_ok)
px = 0.5 * W + focal * verts_world[:, 0] / safe_z
py = 0.5 * H + focal * verts_world[:, 1] / safe_z
Fv_pix = torch.stack([px, py], dim=-1)[faces] # (F, 3, 2)
Fv_z = verts_world[faces][..., 2] # (F, 3)
Fv_valid = valid_v[faces].all(dim=-1)
sx_face = Fv_pix[..., 0]
sy_face = Fv_pix[..., 1]
bb_min_x = sx_face.amin(dim=-1).floor().long().clamp(min=0, max=W)
bb_max_x = (sx_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=W)
bb_min_y = sy_face.amin(dim=-1).floor().long().clamp(min=0, max=H)
bb_max_y = (sy_face.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=H)
sx_all = bb_max_x - bb_min_x
sy_all = bb_max_y - bb_min_y
valid_face = Fv_valid & (sx_all > 0) & (sy_all > 0)
keep = torch.where(valid_face)[0]
if keep.numel() == 0:
return
# Sort kept faces by max bbox dimension so chunks stay similarly-sized.
bbsize = torch.maximum(sx_all, sy_all)[keep]
order = torch.argsort(bbsize)
keep = keep[order]
n = keep.numel()
sx_cpu = sx_all[keep].tolist()
sy_cpu = sy_all[keep].tolist()
bbsize_cpu = bbsize[order].tolist()
PIXEL_BUDGET = 4_000_000
MAX_CHUNK = 8192
i = 0
while i < n:
e = min(i + MAX_CHUNK, n)
# Shrink chunk so worst-case per-face bbox stays within pixel budget.
while e > i + 1:
bb = bbsize_cpu[e - 1]
if (e - i) * bb * bb <= PIXEL_BUDGET:
break
e = max(i + 1, e - max(1, (e - i) // 4))
chunk = keep[i:e]
max_sx = max(sx_cpu[i:e])
max_sy = max(sy_cpu[i:e])
i = e
result = _rasterize_chunk(
Fv_pix[chunk], Fv_z[chunk],
bb_min_x[chunk], bb_max_x[chunk],
bb_min_y[chunk], bb_max_y[chunk],
max_sx, max_sy, W,
)
if result is None:
continue
pixel_idx, z_chunk, face_local, bary = result
face_global = chunk[face_local]
# Atomic depth test against z_buf.
old_at = z_buf[pixel_idx].clone()
z_buf.scatter_reduce_(0, pixel_idx, z_chunk, reduce='amin', include_self=True)
new_at = z_buf[pixel_idx]
is_min = (z_chunk == new_at) & (new_at < old_at)
if not is_min.any():
continue
# Multiple fragments can land on the same pixel and share the new min;
# stable-sort by pixel and keep the first of each run so shade_fn runs
# once per winning pixel. O(M) where M = surviving fragments
surv_pixel = pixel_idx[is_min]
surv_face = face_global[is_min]
surv_bary = bary[is_min]
sort_perm = torch.argsort(surv_pixel, stable=True)
sp = surv_pixel[sort_perm]
first = torch.ones_like(sp, dtype=torch.bool)
first[1:] = sp[1:] != sp[:-1]
selected = sort_perm[first]
wp_idx = surv_pixel[selected]
wp_face = surv_face[selected]
wp_bary = surv_bary[selected]
color_buf[wp_idx] = shade_fn(wp_face, wp_bary)
mask_buf[wp_idx] = True
def _make_shade_fn(
shader_preset, composite,
view_normals_v, view_pos_v, vcolor_v, faces,
base_color, light_dir, pastel_mix,
):
device = view_normals_v.device
base_color_t = torch.as_tensor(base_color, dtype=torch.float32, device=device)
light_dir_t = torch.as_tensor(light_dir, dtype=torch.float32, device=device)
# Light-vector constants — normalized once per render call.
l_unit = -light_dir_t
l_unit = l_unit / l_unit.norm().clamp(min=1e-8)
if pastel_mix <= 0.0:
apply_pastel = lambda rgb: rgb
else:
pm = float(pastel_mix)
apply_pastel = lambda rgb: rgb * (1.0 - pm) + pm
def gather_n(face_idx, bary):
n = (view_normals_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
return n / n.norm(dim=-1, keepdim=True).clamp(min=1e-8)
def gather_pos(face_idx, bary):
return (view_pos_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
def gather_color(face_idx, bary):
return (vcolor_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
if composite == "silhouette":
return lambda fi, ba: torch.ones((fi.shape[0], 3), device=device)
if shader_preset == "normals":
# View-space surface normal encoded as RGB (OpenGL Y+ convention).
# +X right → R; +Y up → G; +Z toward viewer → B. Each face shows mostly
# one channel, matching standard normal-map visualization.
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
return apply_pastel(((n + 1.0) * 0.5).clamp(0.0, 1.0))
return shade
use_vcolor = (shader_preset in _CANONICAL_PRESETS) and (vcolor_v is not None)
if not use_vcolor:
# default.frag: ambient + diffuse + rim
def shade(face_idx, bary):
n = gather_n(face_idx, bary)
v = -gather_pos(face_idx, bary)
v = v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
ndotv = (n * v).sum(dim=-1).clamp(min=0)
rim = (1.0 - ndotv).pow(3.0)
lit = 0.25 * base_color_t \
+ 0.75 * base_color_t * ndotl.unsqueeze(-1) \
+ 0.35 * rim.unsqueeze(-1)
return apply_pastel(lit)
return shade
if shader_preset == "rainbow":
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
return apply_pastel(base * (0.65 + 0.35 * ndotl).unsqueeze(-1))
return shade
# rainbow_face_* → rainbow_lit.frag. All light-direction & half-vector
# constants depend only on the (constant) light_dir, so precompute them.
key_l = l_unit
fill_l = torch.stack([-key_l[0], key_l[1].abs(), -key_l[2]])
view_dir = torch.tensor([0.0, 0.0, 1.0], device=device)
h = key_l + view_dir
h = h / h.norm().clamp(min=1e-8)
def shade(face_idx, bary):
base = gather_color(face_idx, bary)
n = gather_n(face_idx, bary)
key_ndotl = (n * key_l).sum(dim=-1).clamp(min=0)
fill_ndotl = (n * fill_l).sum(dim=-1).clamp(min=0)
rim = (1.0 - n[..., 2].clamp(min=0)).pow(2.5) * 0.30
shade_val = (0.45 + 0.45 * key_ndotl + 0.15 * fill_ndotl + rim * 0.5).clamp(min=0.0, max=1.25)
ndoth = (n * h).sum(dim=-1).clamp(min=0)
spec = ndoth.pow(48) * 0.12
lit = base * shade_val.unsqueeze(-1) + spec.unsqueeze(-1)
return apply_pastel(lit)
return shade
def render_pose_data_torch(
pose_data: dict,
frame_idx: int,
W: int,
H: int,
background=None, # Optional[np.ndarray | torch.Tensor] (H, W, 3) fp32 [0, 1]
composite: str = "over",
opacity: float = 1.0,
shader_preset: str = "default",
base_color: Sequence[float] = (0.68, 0.71, 0.78),
light_dir: Sequence[float] = (0.4, -0.7, -0.6),
rainbow_tilt_x_deg: float = 0.0,
rainbow_tilt_z_deg: float = 0.0,
person_brightness_falloff: float = 0.6,
) -> torch.Tensor:
"""Render one frame of persons from `pose_data` at resolution WxH.
Returns an (H, W, 3) float32 torch.Tensor on the comfy compute device,
ready to be stacked into the node's IMAGE output without a CPU round-trip."""
device = comfy.model_management.get_torch_device()
persons = pose_data["frames"][frame_idx] if frame_idx < len(pose_data["frames"]) else []
if len(persons) == 0:
if composite == "over" and background is not None:
if isinstance(background, np.ndarray):
bg = torch.as_tensor(background, dtype=torch.float32, device=device)
else:
bg = background.to(device=device, dtype=torch.float32) if (
background.device != device or background.dtype != torch.float32
) else background
return bg.clamp(0.0, 1.0)
return torch.zeros((H, W, 3), device=device, dtype=torch.float32)
faces = torch.as_tensor(np.asarray(pose_data["faces"], dtype=np.int64), device=device)
canonical_colors = pose_data.get("canonical_colors")
using_canonical = shader_preset in _CANONICAL_PRESETS
if using_canonical and canonical_colors is None:
shader_preset = "default"
using_canonical = False
vcolor = None
if using_canonical:
vcolor_np = _build_vcolor(canonical_colors, shader_preset,
rainbow_tilt_x_deg, rainbow_tilt_z_deg)
vcolor = torch.as_tensor(vcolor_np, dtype=torch.float32, device=device)
falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
person_pastel = [0.0 if k == 0 else (1.0 - falloff ** k) for k in range(len(persons))]
# Front-to-back draw order so nearer persons overdraw farther ones.
order = sorted(range(len(persons)),
key=lambda i: -float(np.asarray(persons[i]["pred_cam_t"]).reshape(-1)[2]))
HW = H * W
z_buf = torch.full((HW,), float('inf'), device=device, dtype=torch.float32)
color_buf = torch.zeros((HW, 3), device=device, dtype=torch.float32)
mask_buf = torch.zeros(HW, device=device, dtype=torch.bool)
for idx in order:
p = persons[idx]
verts_np = np.asarray(p["pred_vertices"], dtype=np.float32).reshape(-1, 3)
cam_t = np.asarray(p["pred_cam_t"], dtype=np.float32).reshape(3)
verts_world = torch.as_tensor(verts_np + cam_t[None, :],
device=device, dtype=torch.float32)
focal = float(np.asarray(p.get("focal_length", 5000.0)).reshape(-1)[0])
# Image-frame (+Y down, +Z forward) → view-space (+Y up, -Z forward)
# for shading, matching what the GL-style shader math expects.
view_pos_v = torch.stack(
[verts_world[:, 0], -verts_world[:, 1], -verts_world[:, 2]], dim=-1,
)
normals_world = _vertex_normals(verts_world, faces)
view_normals_v = torch.stack(
[normals_world[:, 0], -normals_world[:, 1], -normals_world[:, 2]], dim=-1,
)
vcolor_p = vcolor if (vcolor is not None and vcolor.shape[0] == verts_world.shape[0]) else None
# Only canonical-vcolor shaders need vcolor; geometric shaders
# ('normals', 'depth') and the lit default work without it.
if shader_preset in _CANONICAL_PRESETS and vcolor_p is None:
effective_preset = "default"
else:
effective_preset = shader_preset
shade_fn = _make_shade_fn(
effective_preset, composite,
view_normals_v, view_pos_v, vcolor_p, faces,
base_color, light_dir, person_pastel[idx],
)
_rasterize_person(
verts_world, faces, focal, W, H,
z_buf, color_buf, mask_buf, shade_fn,
)
# Stay on GPU through readback + composite.
if shader_preset == "depth":
# z_buf already holds linear image-frame z (smaller=closer; +inf where no mesh covers)
# Normalize within the rendered mesh's range: near=white, far=black, background=black
mask_2d = mask_buf.reshape(H, W)
z_2d = z_buf.reshape(H, W)
if mask_2d.any():
zin = z_2d[mask_2d]
zmin = zin.min()
zr = (zin.max() - zmin).clamp(min=1e-6)
norm = torch.where(mask_2d, 1.0 - (z_2d - zmin) / zr, torch.zeros_like(z_2d))
else:
norm = torch.zeros((H, W), device=device, dtype=torch.float32)
rendered = torch.stack([norm, norm, norm], dim=-1)
mask_f = mask_2d.float()
else:
rendered = color_buf.reshape(H, W, 3).clamp(0.0, 1.0)
mask_f = mask_buf.reshape(H, W).float()
if composite == "over" and background is not None:
if isinstance(background, np.ndarray):
bg = torch.as_tensor(background, dtype=torch.float32, device=device)
else:
bg = background.to(device=device, dtype=torch.float32)
a = mask_f.unsqueeze(-1)
if opacity != 1.0:
a = a * float(opacity)
rendered = torch.lerp(bg, rendered, a)
return rendered

View File

@ -0,0 +1,503 @@
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
from comfy.ldm.sam3d_body.utils import prepare_batch
from comfy.ldm.sam3.tracker import unpack_masks
from comfy.ldm.sam3d_body.model.model import SAM3DBody
import comfy.model_management
import comfy.utils
from tqdm import tqdm
def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
"""xyxy bounds of a binary mask, with sub-5px speckles filtered out."""
m = mask[..., 0] if mask.dim() == 3 else mask
m_bool = m > 0
if not m_bool.any():
return None
t = m_bool.to(torch.float32)[None, None]
eroded = -F.max_pool2d(-t, kernel_size=5, stride=1, padding=2)
keep = eroded[0, 0] > 0.5
if not keep.any():
keep = m_bool
ys, xs = torch.where(keep)
return torch.stack([
xs.min().float(), ys.min().float(),
(xs.max() + 1).float(), (ys.max() + 1).float(),
])
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
"""Unpack SAM3_TRACK_DATA into per-frame per-object bboxes + masks at image
resolution. Returns (None, None) on empty track / frame-count mismatch."""
packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
if packed is None:
return None, None
N, K = packed.shape[0], packed.shape[1]
if N != B or K == 0:
return None, None
device = comfy.model_management.get_torch_device()
unpacked = unpack_masks(packed.to(device)) # (N, K, Hm, Wm) bool
Hm, Wm = unpacked.shape[2], unpacked.shape[3]
resized = F.interpolate(
unpacked.float().reshape(N * K, 1, Hm, Wm),
size=(H, W), mode="bilinear", align_corners=False,
)
arr_gpu = (resized > 0.5).to(torch.uint8).reshape(N, K, H, W)
arr = arr_gpu.cpu()
per_frame_masks = [arr[f, :, :, :, None].contiguous() for f in range(N)]
full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
per_frame_bboxes = []
for f in range(N):
derived = []
for k in range(K):
# Erosion + argmax bbox on GPU; CPU max_pool2d over N*K full-res masks is slow.
b = _bbox_from_mask(arr_gpu[f, k])
derived.append(b.cpu() if b is not None else full_frame_bbox)
per_frame_bboxes.append(torch.stack(derived, dim=0))
return per_frame_bboxes, per_frame_masks
# Soft budget for the batched Predict path
BATCHED_CROPS_PER_CHUNK = 64
def _quat_to_mat_wxyz(w: float, x: float, y: float, z: float) -> np.ndarray:
"""(3,3) rotation from a wxyz quaternion; columns are the rotated axes."""
n = math.sqrt(w * w + x * x + y * y + z * z) or 1.0
w, x, y, z = w / n, x / n, y / n, z / n
return np.array([
[1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
[2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
[2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
], dtype=np.float32)
def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[torch.Tensor]:
"""(1,3,3) intrinsic matrix from a vertical FOV in degrees. Matches MoGe2's
convention (vertical focal for both axes). Returns None for fov<=0 so the
caller falls back to prepare_batch's diagonal-focal default."""
if fov_degrees <= 0:
return None
focal = height / (2.0 * math.tan(math.radians(fov_degrees) / 2.0))
return torch.tensor(
[[[focal, 0.0, width / 2.0],
[0.0, focal, height / 2.0],
[0.0, 0.0, 1.0]]],
dtype=torch.float32,
)
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
H: int, W: int) -> Dict[str, Any]:
"""Re-project every frame's pose through a Load3D 6DOF camera (position/
target/zoom + optional FOV). Returns a new mhr_pose_data, unchanged on
empty/invalid input."""
first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
if not first_frame:
return mhr_pose_data
# Per-person rig root (pred_cam_t) and body centroid (mesh mean), in camera space.
roots, centroids = [], []
for p in first_frame:
cam_t = p.get("pred_cam_t")
if cam_t is None:
continue
cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
roots.append(cam_t)
v = p.get("pred_vertices")
centroids.append(np.asarray(v, dtype=np.float32).reshape(-1, 3).mean(axis=0) + cam_t
if v is not None else cam_t)
if not roots:
return mhr_pose_data
subj_root = np.mean(np.stack(roots, axis=0), axis=0)
subj_centroid = np.mean(np.stack(centroids, axis=0), axis=0)
# Meter-scale, so Three.js coords map 1:1 (Three.js Y-up → flip Y,Z)
pos = camera_info.get("position") or {}
tgt = camera_info.get("target") or {}
pos_v = np.array([float(pos.get("x", 0.0)), -float(pos.get("y", 5.0)), -float(pos.get("z", 0.0))], dtype=np.float32)
tgt_v = np.array([float(tgt.get("x", 0.0)), -float(tgt.get("y", 0.0)), -float(tgt.get("z", 0.0))], dtype=np.float32)
offset = pos_v - tgt_v
has_offset = float(np.linalg.norm(offset)) >= 1e-6
q = camera_info.get("quaternion")
if not has_offset and not q:
return mhr_pose_data # no viewpoint and no orientation -> nothing to apply
zoom = float(camera_info.get("zoom", 1.0)) or 1.0
# SAM3D roots near the feet. A target at the origin -> center the body centroid
if float(np.linalg.norm(tgt_v)) < 1e-6:
target = subj_centroid
else:
target = subj_root + tgt_v
if q:
mv = lambda v: np.array([v[0], -v[1], -v[2]], dtype=np.float32)
norm = lambda v: v / max(1e-6, float(np.linalg.norm(v)))
Rc = _quat_to_mat_wxyz(
float(q.get("w", 1.0)), float(q.get("x", 0.0)),
float(q.get("y", 0.0)), float(q.get("z", 0.0)),
) # columns = camera world axes
x_axis = norm(mv(Rc[:, 0])) # camera +X -> image right
y_axis = norm(mv(-Rc[:, 1])) # image +Y is down -> negative of camera up
z_axis = norm(mv(-Rc[:, 2])) # camera looks down local -Z -> view direction
else:
# x degenerates only when looking straight along world-up -> world +X.
z_axis = -offset / float(np.linalg.norm(offset))
x_axis = np.cross(z_axis, np.array([0.0, -1.0, 0.0], dtype=np.float32))
x_norm = float(np.linalg.norm(x_axis))
x_axis = x_axis / x_norm if x_norm > 1e-6 else np.array([1.0, 0.0, 0.0], dtype=np.float32)
y_axis = np.cross(z_axis, x_axis)
R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
# Eye: dolly along the offset; rotation-only camera keeps the predicted
# viewing distance so only orientation/roll changes.
if has_offset:
eye = target + offset / max(0.01, zoom)
else:
d = max(0.1, float(target[2]))
eye = target - z_axis * (d / max(0.01, zoom))
# Lens: camera FoV if given, else the SAM3D predicted focal. Three.js fov
# is vertical → focal from image height.
cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
if cam_fov > 0:
new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
else:
f0 = first_frame[0].get("focal_length")
new_focal = (float(np.asarray(f0, dtype=np.float32).reshape(-1)[0]) if f0 is not None
else float(H) / (2.0 * float(np.tan(np.deg2rad(50.0) / 2.0))))
center = np.array([W * 0.5, H * 0.5], dtype=np.float32)
reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
# External rigs store pred_joint_coords Y-up; transform them through the
# camera too (in camera space, then back to Y-up) so they follow the override.
override = mhr_pose_data.get("_skeleton_override")
joints_y_up = override is not None and not bool(override.get("per_frame_y_down", False))
new_frames: List[List[Dict[str, Any]]] = []
for frame in mhr_pose_data["frames"]:
scaled = []
for p in frame:
p = dict(p)
cam_t = p.get("pred_cam_t")
if cam_t is None:
scaled.append(p)
continue
cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
for k in ("pred_keypoints_3d", "pred_vertices", "pred_face_keypoints_3d"):
v = p.get(k)
if v is None:
continue
cam = (np.asarray(v, dtype=np.float32) + cam_t - eye) @ R.T
p[k] = cam.astype(np.float32)
if k in reproj: # re-project the new 3D to 2D image coords
z = np.maximum(cam[..., 2:3], 1e-6)
p[reproj[k]] = (cam[..., :2] * new_focal / z + center).astype(np.float32)
jc = p.get("pred_joint_coords")
if jc is not None:
jc = np.asarray(jc, dtype=np.float32).copy()
if joints_y_up:
jc[..., 1] *= -1.0
jc[..., 2] *= -1.0
jc = (jc + cam_t - eye) @ R.T
if joints_y_up:
jc[..., 1] *= -1.0
jc[..., 2] *= -1.0
p["pred_joint_coords"] = jc.astype(np.float32)
p["pred_cam_t"] = np.zeros(3, dtype=np.float32)
p["focal_length"] = np.array(new_focal, dtype=np.float32)
scaled.append(p)
new_frames.append(scaled)
out = dict(mhr_pose_data)
out["frames"] = new_frames
return out
def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], per_frame_boxes: List[torch.Tensor],
per_frame_masks: Optional[List[torch.Tensor]], image_size: Tuple[int, int], inference_type: str, K: int,
cam_int: Optional[torch.Tensor] = None) -> List[List[Dict[str, Any]]]:
"""Run a SINGLE chunk of frames through run_inference in one forward."""
N = len(frames_rgb)
total = N * K
# Reset stateful caches on the model
for attr in ("batch", "image_embeddings", "output"):
if hasattr(inner, attr):
setattr(inner, attr, None)
inner.prev_prompt = []
boxes_stacked = torch.stack(
[per_frame_boxes[f][k] for f in range(N) for k in range(K)], dim=0
)
img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
if per_frame_masks is not None:
# One mask but multiple bboxes per frame → each bbox gets the same mask.
flat_masks = []
for f in range(N):
mf = per_frame_masks[f]
if mf.shape[0] == 1 and K > 1:
mf = mf.repeat_interleave(K, dim=0)
flat_masks.extend([mf[k] for k in range(K)])
masks_stacked = torch.stack(flat_masks, dim=0)
masks_score = torch.ones(total, dtype=torch.float32)
else:
masks_stacked = None
masks_score = None
batch = prepare_batch(
img_per_crop, boxes_stacked,
input_size=image_size,
masks=masks_stacked, masks_score=masks_score, cam_int=cam_int,
)
device = comfy.model_management.get_torch_device()
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
inner._initialize_batch(batch)
outputs = inner.run_inference(
img_per_crop,
batch,
inference_type=inference_type,
thresh_wrist_angle=1.4,
)
if inference_type == "full":
pose_output, batch_lhand, batch_rhand, _, _ = outputs
else:
pose_output = outputs
batch_lhand = batch_rhand = None
out = {k: v.float().cpu().numpy() for k, v in pose_output["mhr"].items()
if v is not None and k != "faces"}
# Snapshot batch['bbox'] to CPU before we release `batch` references
batch_bbox_cpu = batch["bbox"][0].cpu().numpy()
lhand_bboxes = rhand_bboxes = None
if inference_type == "full" and batch_lhand is not None and batch_rhand is not None:
lhand_bboxes = [_bbox_from_center_scale(batch_lhand, i) for i in range(total)]
rhand_bboxes = [_bbox_from_center_scale(batch_rhand, i) for i in range(total)]
del pose_output, batch, batch_lhand, batch_rhand, outputs
frames_out: List[List[Dict[str, Any]]] = []
for f in range(N):
persons: List[Dict[str, Any]] = []
for k in range(K):
idx = f * K + k
p: Dict[str, Any] = {
"bbox": batch_bbox_cpu[idx],
"focal_length": out["focal_length"][idx],
"pred_keypoints_3d": out["pred_keypoints_3d"][idx],
"pred_keypoints_2d": out["pred_keypoints_2d"][idx],
"pred_vertices": out["pred_vertices"][idx],
"pred_cam_t": out["pred_cam_t"][idx],
"pred_pose_raw": out["pred_pose_raw"][idx],
"global_rot": out["global_rot"][idx],
"body_pose_params": out["body_pose"][idx],
"hand_pose_params": out["hand"][idx],
"scale_params": out["scale"][idx],
"shape_params": out["shape"][idx],
"expr_params": out["face"][idx],
"mask": (per_frame_masks[f][k] if per_frame_masks[f].shape[0] > 1 else per_frame_masks[f][0])
if per_frame_masks is not None else None,
"pred_joint_coords": out["pred_joint_coords"][idx],
"pred_global_rots": out["joint_global_rots"][idx],
"mhr_model_params": out["mhr_model_params"][idx],
# 238 face landmarks from sapiens-308 (indices 70..308 of the pre-slice keypoint tensor).
"pred_face_keypoints_3d": out["pred_face_keypoints_3d"][idx] if "pred_face_keypoints_3d" in out else None,
"pred_face_keypoints_2d": out["pred_face_keypoints_2d"][idx] if "pred_face_keypoints_2d" in out else None,
}
if lhand_bboxes is not None:
p["lhand_bbox"] = lhand_bboxes[idx]
p["rhand_bbox"] = rhand_bboxes[idx]
persons.append(p)
frames_out.append(persons)
return frames_out
def run_batched_frames(
inner: SAM3DBody,
frames_rgb: List[torch.Tensor],
per_frame_boxes: List[torch.Tensor],
per_frame_masks: Optional[List[torch.Tensor]],
image_size: Tuple[int, int],
inference_type: str,
cam_int: Optional[torch.Tensor] = None,
pbar: Optional[comfy.utils.ProgressBar] = None,
crops_per_chunk: int = BATCHED_CROPS_PER_CHUNK,
) -> List[List[Dict[str, Any]]]:
"""Run the clip through chunked batched run_inference calls.
Supports K persons per frame (K must be the same across frames padded
externally). Splits frames into chunks so chunk_frames * K <= budget; each
chunk is one body forward + optional hand forwards over its person-crops
"""
N = len(frames_rgb)
assert N > 0, "empty frame list"
K_set = {len(b) for b in per_frame_boxes}
assert len(K_set) == 1, f"batched path requires same bbox count per frame, got {K_set}"
K = K_set.pop()
assert K >= 1, "need at least one bbox per frame"
chunk_frames = max(1, crops_per_chunk // K)
results: List[List[Dict[str, Any]]] = []
with tqdm(total=N, desc="SAM3D body inference") as t:
for start in range(0, N, chunk_frames):
end = min(N, start + chunk_frames)
sub_frames = frames_rgb[start:end]
sub_boxes = per_frame_boxes[start:end]
sub_masks = None if per_frame_masks is None else per_frame_masks[start:end]
chunk_result = run_batched_single_chunk(
inner, sub_frames, sub_boxes, sub_masks,
image_size, inference_type, K,
cam_int=cam_int,
)
results.extend(chunk_result)
t.update(end - start)
if pbar is not None:
pbar.update(end - start)
# Drop GPU caches so the next chunk starts from a clean allocator state
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def _bbox_from_center_scale(batch, idx: int) -> np.ndarray:
cx = batch["bbox_center"].flatten(0, 1)[idx][0].item()
cy = batch["bbox_center"].flatten(0, 1)[idx][1].item()
sx = batch["bbox_scale"].flatten(0, 1)[idx][0].item()
sy = batch["bbox_scale"].flatten(0, 1)[idx][1].item()
return np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2], dtype=np.float32)
# Wire types and small helpers shared across the SAM 3D Body node modules.
def image_to_uint8(image: torch.Tensor) -> torch.Tensor:
"""ComfyUI image tensor (any shape, float 0..1) → uint8 tensor in [0, 255] on CPU."""
return (image * 255.0).clamp(0.0, 255.0).to(dtype=torch.uint8, device="cpu")
def compute_canonical_colors(model) -> Dict[str, np.ndarray]:
"""Canonical rest-pose data for shader color lookups: positions (Nv,3),
norm (Nv,3 in [0,1]), face_mask, head_mask, and face_region_rgb
(per-region painted color from the .safetensors)."""
verts = model.head_pose.canonical_vertices().float().cpu().numpy()
faces = model.head_pose.faces.cpu().numpy()
v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]]
v2 = verts[faces[:, 2]]
fn = np.cross(v1 - v0, v2 - v0).astype(np.float32)
vn = np.zeros_like(verts, dtype=np.float32)
np.add.at(vn, faces[:, 0], fn)
np.add.at(vn, faces[:, 1], fn)
np.add.at(vn, faces[:, 2], fn)
ln = np.linalg.norm(vn, axis=1, keepdims=True)
ln[ln < 1e-8] = 1.0
vn = vn / ln
norm_map = ((vn + 1.0) * 0.5).astype(np.float32)
face_mask = _compute_face_mask(model)
# Head: above jaw-neck (y>1.43) and narrower than shoulders (|x|<0.11).
# Ears reach |x|≈0.09; shoulders start at |x|≈0.20.
head_mask = (verts[:, 1] > 1.43) & (np.abs(verts[:, 0]) < 0.11)
# Painted per-vertex face region RGB ships in the model .safetensors as
# `head_pose.face_region_rgb` and gets loaded by load_state_dict.
face_region_rgb = model.head_pose.face_region_rgb.detach().float().cpu().numpy()
return {
"positions": verts.astype(np.float32),
"norm": norm_map,
"face_mask": face_mask,
"head_mask": head_mask,
"face_region_rgb": face_region_rgb,
}
def compute_hand_vert_mask(model, hand_radius_m: float = 0.15, weight_threshold: float = 0.5) -> np.ndarray:
"""(Nv,) bool mask of hand-region verts. Picks joints within `hand_radius_m`
of the mhr70 hand keypoint clusters (indices 21..62), then sums sparse LBS
weights across them; verts above `weight_threshold` are hand verts."""
head = model.head_pose
mhr = head.mhr
device = head.scale_mean.device
zeros = lambda *s: torch.zeros(1, *s, device=device)
out = head.mhr_forward(
global_trans=zeros(3),
global_rot=zeros(3),
body_pose_params=zeros(130),
hand_pose_params=zeros(head.num_hand_comps * 2),
scale_params=zeros(head.num_scale_comps),
shape_params=zeros(head.num_shape_comps),
expr_params=zeros(head.num_face_comps),
return_keypoints=True,
return_joint_coords=True,
)
# Output order with these flags: (verts, kp, jcoords). See mhr_head.mhr_forward.
_, kp, jcoords = out[0], out[1], out[2]
kp = kp[0, :70].cpu().numpy()
jcoords = jcoords[0].cpu().numpy()
right_center = kp[21:42].mean(axis=0)
left_center = kp[42:63].mean(axis=0)
j_dist_r = np.linalg.norm(jcoords - right_center, axis=1)
j_dist_l = np.linalg.norm(jcoords - left_center, axis=1)
is_hand_joint = (j_dist_r < hand_radius_m) | (j_dist_l < hand_radius_m)
lbs_w = mhr.lbs_skin_weights.cpu().numpy()
lbs_v = mhr.lbs_vert_indices.cpu().numpy()
lbs_j = mhr.lbs_skin_indices.cpu().numpy()
is_hand_joint_f = is_hand_joint.astype(np.float32)
n_verts = mhr.NUM_VERTS
hand_mass = np.zeros(n_verts, dtype=np.float32)
np.add.at(hand_mass, lbs_v, lbs_w * is_hand_joint_f[lbs_j])
return hand_mass >= weight_threshold
def _compute_face_mask(model, disp_threshold_m: float = 1e-4) -> np.ndarray:
"""(Nv,) bool mask of verts that move with face expression. Sweeps each of
the 72 expression axes at coef=+1.0 and flags any vert that moves more
than `disp_threshold_m` for at least one axis."""
head = model.head_pose
device = head.scale_mean.device
num_face = head.num_face_comps
zeros = lambda *s: torch.zeros(1, *s, device=device)
neutral_kw = dict(
global_trans=zeros(3),
global_rot=zeros(3),
body_pose_params=zeros(130),
hand_pose_params=zeros(head.num_hand_comps * 2),
scale_params=zeros(head.num_scale_comps),
shape_params=zeros(head.num_shape_comps),
expr_params=zeros(num_face),
)
v0 = head.mhr_forward(**neutral_kw).cpu().numpy()[0] # (Nv, 3)
face_mask = np.zeros(v0.shape[0], dtype=bool)
for axis in range(num_face):
expr = zeros(num_face)
expr[0, axis] = 1.0
kw = dict(neutral_kw)
kw["expr_params"] = expr
v = head.mhr_forward(**kw).cpu().numpy()[0]
face_mask |= (np.linalg.norm(v - v0, axis=1) > disp_threshold_m)
return face_mask
def jet_colormap(s: np.ndarray) -> np.ndarray:
"""matplotlib jet, (N,) in [0,1] -> (N, 3) float32 RGB."""
s = np.asarray(s, dtype=np.float32).clip(0.0, 1.0)
r = np.interp(s, [0.0, 0.35, 0.66, 0.89, 1.0], [0.0, 0.0, 1.0, 1.0, 0.5])
g = np.interp(s, [0.0, 0.125, 0.375, 0.64, 0.91, 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
b = np.interp(s, [0.0, 0.11, 0.34, 0.65, 1.0], [0.5, 1.0, 1.0, 0.0, 0.0])
return np.stack([r, g, b], axis=-1).astype(np.float32)

View File

@ -2502,6 +2502,7 @@ async def init_builtin_extra_nodes():
"nodes_triposplat.py",
"nodes_depth_anything_3.py",
"nodes_seed.py",
"nodes_sam3d_body.py",
]
import_failed = []