mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Initial sam3d body support
This commit is contained in:
parent
04879a8113
commit
1294041778
377
comfy/ldm/sam3d_body/mhr/mhr_head.py
Normal file
377
comfy/ldm/sam3d_body/mhr/mhr_head.py
Normal file
@ -0,0 +1,377 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import euler_to_rotmat, rot6d_to_rotmat, rotmat_to_euler, unitquat_to_rotmat
|
||||
from .mhr_utils import compact_cont_to_model_params_body, compact_cont_to_model_params_hand, compact_model_params_to_cont_body, mhr_param_hand_mask
|
||||
|
||||
from ..model.transformer import MLP
|
||||
|
||||
|
||||
class MHRHead(nn.Module):
    """Decode a pose token into MHR rig parameters and pose the shared rig.

    A small MLP (``proj``) maps the token to an ``npose``-wide vector laid out
    as ``[6D global rotation | continuous body pose | shape | scale |
    left+right hand | face]``. ``forward`` unpacks that vector and
    ``mhr_forward`` feeds it to the rig stored in ``self.mhr`` to obtain
    vertices, keypoints and joint states.
    """

    def __init__(
        self,
        input_dim: int,
        mhr_rig,
        mlp_depth: int = 1,
        extra_joint_regressor: str = "",
        mlp_channel_div_factor: int = 8,
        enable_hand_model=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        """Build the projection MLP and allocate the rig-side lookup tensors.

        Args:
            input_dim: width of the incoming pose token.
            mhr_rig: rig module, called as
                ``mhr_rig(shape_params, model_params, expr_params)`` and
                expected to return ``(vertices, skel_state)``.
            mlp_depth: number of layers of the projection MLP.
            extra_joint_regressor: accepted for interface compatibility; not
                used by this class.
            mlp_channel_div_factor: the MLP hidden width is
                ``input_dim // mlp_channel_div_factor``.
            enable_hand_model: when True, ``mhr_forward`` treats predictions
                as wrist-centric and zeroes the non-hand parameters.
            device: device for the MLP and for every tensor allocated here.
            dtype: forwarded to the MLP only; the lookup tensors keep the
                explicit dtypes given below.
            operations: layer factory forwarded to the MLP.
        """
        super().__init__()
        # Store the shared MHRRig as a non-registered Python attribute
        # (bypasses nn.Module.__setattr__, so it is not a submodule and its
        # weights do not appear in this module's state_dict).
        object.__setattr__(self, "mhr", mhr_rig)

        self.num_shape_comps = 45
        self.num_scale_comps = 28
        self.num_hand_comps = 54
        self.num_face_comps = 72
        self.enable_hand_model = enable_hand_model

        self.body_cont_dim = 260
        self.npose = (
            6  # Global Rotation
            + self.body_cont_dim  # then body
            + self.num_shape_comps
            + self.num_scale_comps
            + self.num_hand_comps * 2
            + self.num_face_comps
        )

        self.proj = MLP(
            input_dim=input_dim,
            hidden_dim=input_dim // mlp_channel_div_factor,
            output_dim=self.npose,
            num_layers=mlp_depth,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        # MHR Parameters
        self.num_hand_scale_comps = self.num_scale_comps - 18
        self.num_hand_pose_comps = self.num_hand_comps

        # Buffers populated by load_state_dict from the safetensors.
        # They are allocated on `device` so they live next to `proj`
        # (same convention as MHRRig); the values are uninitialised until
        # a checkpoint is loaded.
        def _p(*shape, dtype=torch.float32):
            return nn.Parameter(
                torch.empty(*shape, dtype=dtype, device=device),
                requires_grad=False,
            )

        self.joint_rotation = _p(127, 3, 3)
        self.scale_mean = _p(68)
        self.scale_comps = _p(28, 68)
        self.faces = _p(36874, 3, dtype=torch.int64)
        self.hand_pose_mean = _p(54)
        # Identity by default, i.e. hand coefficients pass through unchanged
        # unless the checkpoint overrides this matrix.
        self.hand_pose_comps = nn.Parameter(
            torch.eye(54, device=device), requires_grad=False
        )
        self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
        self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
        self.keypoint_mapping = _p(308, 18439 + 127)
        # Some special buffers for the hand-version
        self.right_wrist_coords = _p(3)
        self.root_coords = _p(3)
        self.local_to_world_wrist = _p(3, 3)
        self.nonhand_param_idxs = _p(145, dtype=torch.int64)
        # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
        # Optional — loaded from the .safetensors if present, otherwise the
        # render path falls back to a coarse geometric approximation.
        self.register_buffer(
            "face_region_rgb",
            torch.zeros(18439, 3, dtype=torch.float32, device=device),
        )

    def canonical_vertices(self, device=None):
        """Return the rest-pose vertices for the mean shape (scaled to meters).

        Runs ``mhr_forward`` with all-zero translation, rotation, body pose,
        hand coefficients, scale, shape and expression for a batch of one and
        returns the single mesh. Zero hand coefficients decode to
        ``hand_pose_mean``, so the hands take that pose.

        Args:
            device: device for the zero inputs; defaults to the device of
                ``scale_mean``.

        Returns:
            Tensor of per-vertex positions for the single batch element.
        """
        # `is None` rather than truthiness so a CUDA index of 0 is honoured.
        dev = self.scale_mean.device if device is None else device
        dt = self.scale_mean.dtype
        B = 1
        global_trans = torch.zeros(B, 3, device=dev, dtype=dt)
        global_rot = torch.zeros(B, 3, device=dev, dtype=dt)
        body_pose = torch.zeros(B, 130, device=dev, dtype=dt)
        hand_pose = torch.zeros(B, self.num_hand_comps * 2, device=dev, dtype=dt)
        scale = torch.zeros(B, self.num_scale_comps, device=dev, dtype=dt)
        shape = torch.zeros(B, self.num_shape_comps, device=dev, dtype=dt)
        expr = torch.zeros(B, self.num_face_comps, device=dev, dtype=dt)
        verts = self.mhr_forward(
            global_trans=global_trans,
            global_rot=global_rot,
            body_pose_params=body_pose,
            hand_pose_params=hand_pose,
            scale_params=scale,
            shape_params=shape,
            expr_params=expr,
        )  # single-tensor shape (1, N_v, 3) in meters
        return verts[0]

    def get_zero_pose_init(self, factor=1.0):
        """Return a ``(1, npose)`` initial estimate encoding the zero pose.

        The first six entries hold the 6D encoding ``[1, 0, 0, 0, 1, 0]``;
        the next ``body_cont_dim`` entries hold the continuous encoding of an
        all-zero 133-dim body pose multiplied by ``factor``. Everything else
        stays zero.
        """
        # Initialize pose token with zero-initialized learnable params
        # Note: bias/initial value should be zero-pose in cont, not all-zeros
        weights = torch.zeros(1, self.npose)
        weights[:, : 6 + self.body_cont_dim] = torch.cat(
            [
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32),
                compact_model_params_to_cont_body(torch.zeros(1, 133)).squeeze()
                * factor,
            ],
            dim=0,
        )
        return weights

    def replace_hands_in_pose(self, full_pose_params, hand_pose_params):
        """Write decoded hand poses into ``full_pose_params`` in place.

        Args:
            full_pose_params: ``B x 136`` pose vector; modified in place.
            hand_pose_params: ``B x (2 * num_hand_pose_comps)`` coefficients,
                left hand first, then right hand.

        Returns:
            The same ``full_pose_params`` tensor (``B x 136``) with the
            entries at ``hand_joint_idxs_left`` / ``hand_joint_idxs_right``
            overwritten.
        """
        assert full_pose_params.shape[1] == 136

        # This drops in the hand poses from hand_pose_params (PCA 6D) into full_pose_params.
        # Split into left and right hands
        left_hand_params, right_hand_params = torch.split(
            hand_pose_params,
            [self.num_hand_pose_comps, self.num_hand_pose_comps],
            dim=1,
        )

        # Change from cont to model params
        left_hand_params_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", left_hand_params, self.hand_pose_comps)
        )
        right_hand_params_model_params = compact_cont_to_model_params_hand(
            self.hand_pose_mean
            + torch.einsum("da,ab->db", right_hand_params, self.hand_pose_comps)
        )

        # Drop it in
        full_pose_params[:, self.hand_joint_idxs_left] = left_hand_params_model_params
        full_pose_params[:, self.hand_joint_idxs_right] = right_hand_params_model_params

        return full_pose_params  # B x 136

    def mhr_forward(
        self,
        global_trans,
        global_rot,
        body_pose_params,
        hand_pose_params,
        scale_params,
        shape_params,
        expr_params=None,
        return_keypoints=False,
        do_pcblend=True,
        return_joint_coords=False,
        return_model_params=False,
        return_joint_rotations=False,
        scale_offsets=None,
        vertex_offsets=None,
    ):
        """Assemble rig model parameters, run the rig and collect outputs.

        Args:
            global_trans: ``B x 3`` root translation.
            global_rot: ``B x 3`` root rotation as Euler angles.
            body_pose_params: body pose; only the first 130 entries of the
                last dimension are used.
            hand_pose_params: ``B x 108`` hand coefficients, or None to leave
                the hand entries of the body pose untouched.
            scale_params: scale coefficients (``B x 28`` or un-batched).
            shape_params: shape coefficients (``B x 45`` or un-batched).
            expr_params: expression coefficients forwarded to the rig as-is.
            return_keypoints: also return the 308 mapped keypoints.
            do_pcblend: accepted but not consulted by this method.
                NOTE(review): the rig exposes ``apply_correctives``; confirm
                whether this flag was meant to drive it.
            return_joint_coords: also return joint positions.
            return_model_params: also return the assembled model parameters.
            return_joint_rotations: also return joint rotation matrices.
            scale_offsets: optional additive offset on the decoded scales.
            vertex_offsets: accepted but not used by this method.

        Returns:
            The vertex tensor alone when no extra output is requested,
            otherwise a tuple ``(vertices, [keypoints], [joint_coords],
            [model_params], [joint_rotations])`` in that order, containing
            only the requested extras. Vertices and joint coordinates are the
            rig outputs divided by 100.
        """
        # Align everything to the static buffers
        dt = self.scale_mean.dtype
        global_trans = global_trans.to(dt)
        global_rot = global_rot.to(dt)
        body_pose_params = body_pose_params.to(dt)
        if hand_pose_params is not None:
            hand_pose_params = hand_pose_params.to(dt)
        scale_params = scale_params.to(dt)
        shape_params = shape_params.to(dt)
        if expr_params is not None:
            expr_params = expr_params.to(dt)

        if self.enable_hand_model:
            # Transfer wrist-centric predictions to the body.
            global_rot_ori = global_rot.clone()
            global_trans_ori = global_trans.clone()
            global_rot = rotmat_to_euler(
                "xyz",
                euler_to_rotmat("xyz", global_rot_ori) @ self.local_to_world_wrist,
            )
            global_trans = (
                -(
                    euler_to_rotmat("xyz", global_rot)
                    @ (self.right_wrist_coords - self.root_coords)
                    + self.root_coords
                )
                + global_trans_ori
            )

        body_pose_params = body_pose_params[..., :130]

        # Convert from scale and shape params to actual scales and vertices

        # Add singleton batches in case...
        if len(scale_params.shape) == 1:
            scale_params = scale_params[None]
        if len(shape_params.shape) == 1:
            shape_params = shape_params[None]
        # Convert scale...
        scales = self.scale_mean[None, :] + scale_params @ self.scale_comps
        if scale_offsets is not None:
            scales = scales + scale_offsets

        # Now, figure out the pose.
        ## 10 here is because it's more stable to optimize global translation in meters.
        full_pose_params = torch.cat(
            [global_trans * 10, global_rot, body_pose_params], dim=1
        )  # B x 136 (3 translation + 3 rotation + 130 body)
        ## Put in hands
        if hand_pose_params is not None:
            full_pose_params = self.replace_hands_in_pose(
                full_pose_params, hand_pose_params
            )
        model_params = torch.cat([full_pose_params, scales], dim=1)

        if self.enable_hand_model:
            # Zero out non-hand parameters
            model_params[:, self.nonhand_param_idxs] = 0

        curr_skinned_verts, curr_skel_state = self.mhr(
            shape_params, model_params, expr_params
        )
        # Skeleton state is split as 3 (position) + 4 (quaternion) + 1.
        curr_joint_coords, curr_joint_quats, _ = torch.split(
            curr_skel_state, [3, 4, 1], dim=2
        )
        curr_skinned_verts = curr_skinned_verts / 100
        curr_joint_coords = curr_joint_coords / 100
        curr_joint_rots = unitquat_to_rotmat(curr_joint_quats)

        # Prepare returns
        to_return = [curr_skinned_verts]
        if return_keypoints:
            # Get sapiens 308 keypoints
            model_vert_joints = torch.cat(
                [curr_skinned_verts, curr_joint_coords], dim=1
            )  # B x (num_verts + 127) x 3

            kp_map = self.keypoint_mapping.to(model_vert_joints.dtype)
            # Fold batch and xyz into one axis so a single matmul maps all
            # batch elements, then restore B x 308 x 3.
            model_keypoints_pred = (
                (kp_map @ model_vert_joints.permute(1, 0, 2).flatten(1, 2))
                .reshape(-1, model_vert_joints.shape[0], 3)
                .permute(1, 0, 2)
            )

            if self.enable_hand_model:
                # Zero out everything except for the right hand
                model_keypoints_pred[:, :21] = 0
                model_keypoints_pred[:, 42:] = 0

            to_return = to_return + [model_keypoints_pred]
        if return_joint_coords:
            to_return = to_return + [curr_joint_coords]
        if return_model_params:
            to_return = to_return + [model_params]
        if return_joint_rotations:
            to_return = to_return + [curr_joint_rots]

        if len(to_return) == 1:
            return to_return[0]
        return tuple(to_return)

    def forward(
        self,
        x: torch.Tensor,
        init_estimate: Optional[torch.Tensor] = None,
        do_pcblend=True,
        slim_keypoints=False,
        intermediate: bool = False,
    ):
        """Predict rig parameters from a pose token and pose the rig.

        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.npose], added to the MLP output.
            do_pcblend: forwarded to ``mhr_forward``.
            slim_keypoints: accepted but not used by this method.
            intermediate: when True, the caller only needs the keypoints/pose
                outputs needed by the per-layer keypoint-token update path —
                vertex output is suppressed so `camera_project` skips the
                18439-vertex perspective projection on intermediate decoder
                layers. The final layer must call with intermediate=False.

        Returns:
            Dict with the raw and decoded parameters, ``pred_keypoints_3d``,
            ``pred_vertices`` (None when ``intermediate``),
            ``pred_joint_coords``, ``faces`` (NumPy array),
            ``joint_global_rots`` and ``mhr_model_params``.
        """
        batch_size = x.shape[0]
        pred = self.proj(x)
        if init_estimate is not None:
            pred = pred + init_estimate

        # From pred, we want to pull out individual predictions.

        ## First, get globals
        ### Global rotation is first 6.
        count = 6
        global_rot_6d = pred[:, :count]
        global_rot_rotmat = rot6d_to_rotmat(global_rot_6d)  # B x 3 x 3
        global_rot_euler = rotmat_to_euler("ZYX", global_rot_rotmat)  # B x 3
        global_trans = torch.zeros_like(global_rot_euler)

        ## Next, get body pose.
        ### Hold onto raw, continuous version for iterative correction.
        pred_pose_cont = pred[:, count : count + self.body_cont_dim]
        count += self.body_cont_dim
        ### Convert to eulers (and trans)
        pred_pose_euler = compact_cont_to_model_params_body(pred_pose_cont)
        ### Zero-out hands
        pred_pose_euler[:, mhr_param_hand_mask] = 0
        ### Zero-out jaw
        pred_pose_euler[:, -3:] = 0

        ## Get remaining parameters
        pred_shape = pred[:, count : count + self.num_shape_comps]
        count += self.num_shape_comps
        pred_scale = pred[:, count : count + self.num_scale_comps]
        count += self.num_scale_comps
        pred_hand = pred[:, count : count + self.num_hand_comps * 2]
        count += self.num_hand_comps * 2
        # Face coefficients are forced to zero.
        pred_face = pred[:, count : count + self.num_face_comps] * 0
        count += self.num_face_comps

        # Run everything through mhr
        output = self.mhr_forward(
            global_trans=global_trans,
            global_rot=global_rot_euler,
            body_pose_params=pred_pose_euler,
            hand_pose_params=pred_hand,
            scale_params=pred_scale,
            shape_params=pred_shape,
            expr_params=pred_face,
            do_pcblend=do_pcblend,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )

        # Some existing code to get joints and fix camera system
        verts, j3d, jcoords, mhr_model_params, joint_global_rots = output
        j3d = j3d[:, :70]  # 308 --> 70 keypoints

        # Intermediate decoder layers only consume pred_keypoints_3d via the
        # keypoint-token update path; suppress verts so camera_project skips
        # the 18439-vertex perspective projection.
        if intermediate:
            verts = None
        if verts is not None:
            verts[..., [1, 2]] *= -1  # Camera system difference
        j3d[..., [1, 2]] *= -1  # Camera system difference
        if jcoords is not None:
            jcoords[..., [1, 2]] *= -1

        # Head-MLP outputs are promoted to fp32 here so the external
        # pose_output["mhr"] contract has a stable dtype regardless of what
        # the head ran at (fp16/bf16 for speed). MHR-derived outputs
        # (keypoints, vertices, joints, model params) are passed through
        # without a cast.
        output = {
            "pred_pose_raw": torch.cat([global_rot_6d, pred_pose_cont], dim=1).float(),
            "pred_pose_rotmat": None,
            "global_rot": global_rot_euler.float(),
            "body_pose": pred_pose_euler.float(),
            "shape": pred_shape.float(),
            "scale": pred_scale.float(),
            "hand": pred_hand.float(),
            "face": pred_face.float(),
            "pred_keypoints_3d": j3d.reshape(batch_size, -1, 3),
            "pred_vertices": verts.reshape(batch_size, -1, 3) if verts is not None else None,
            "pred_joint_coords": jcoords.reshape(batch_size, -1, 3) if jcoords is not None else None,
            "faces": self.faces.cpu().numpy(),
            "joint_global_rots": joint_global_rots,
            "mhr_model_params": mhr_model_params,
        }

        return output
|
||||
235
comfy/ldm/sam3d_body/mhr/mhr_rig.py
Normal file
235
comfy/ldm/sam3d_body/mhr/mhr_rig.py
Normal file
@ -0,0 +1,235 @@
|
||||
# Adapted from facebookresearch/MHR (Apache 2.0):
|
||||
# https://github.com/facebookresearch/MHR/blob/main/mhr/mhr.py
|
||||
# Skinning ops follow facebookincubator/momentum (Apache 2.0) — formulas
|
||||
# verbatim from the TorchScript source bundled in the upstream mhr_model.pt
|
||||
# (pymomentum.{skel_state,quaternion,backend.skel_state_backend}).
|
||||
# Original Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
torch.sparse.check_sparse_tensor_invariants.disable() # silence the "implicitly disabled" UserWarning.
|
||||
|
||||
from .mhr_utils import batch6DFromXYZ
|
||||
|
||||
# ln(2) rounded to float32 precision. Used as exp(x * _LN2), i.e. 2**x, when
# decoding the per-joint scale channel in MHRRig.forward.
_LN2 = 0.6931471824645996
|
||||
|
||||
def _euler_xyz_to_quat(angles):
    """Convert (roll, pitch, yaw) Euler angles to quaternions in (x, y, z, w) order.

    `angles` is (..., 3); the result is (..., 4). Matches
    pymomentum.quaternion.euler_xyz_to_quaternion.
    """
    roll, pitch, yaw = angles.unbind(-1)
    half_roll, half_pitch, half_yaw = roll * 0.5, pitch * 0.5, yaw * 0.5

    cr, sr = torch.cos(half_roll), torch.sin(half_roll)
    cp, sp = torch.cos(half_pitch), torch.sin(half_pitch)
    cy, sy = torch.cos(half_yaw), torch.sin(half_yaw)

    components = (
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
        cr * cp * cy + sr * sp * sy,  # w
    )
    return torch.stack(components, dim=-1)
|
||||
|
||||
|
||||
def _quat_multiply(q1, q2):
    """Hamilton product q1 * q2 of quaternions stored as (x, y, z, w).

    Both inputs are (..., 4) and must be broadcast-compatible; the result is
    (..., 4). No normalisation is applied.
    """
    ax, ay, az, aw = q1.unbind(-1)
    bx, by, bz, bw = q2.unbind(-1)

    prod_x = aw * bx + ax * bw + ay * bz - az * by
    prod_y = aw * by - ax * bz + ay * bw + az * bx
    prod_z = aw * bz + ax * by - ay * bx + az * bw
    prod_w = aw * bw - ax * bx - ay * by - az * bz

    return torch.stack((prod_x, prod_y, prod_z, prod_w), dim=-1)
|
||||
|
||||
|
||||
def _quat_rotate(q, v):
    """Rotate vectors v (..., 3) by unit quaternions q (..., 4) in xyzw order.

    Evaluates v + 2 * (w * (u x v) + u x (u x v)), where u is the vector part
    of q and w its scalar part.
    """
    vec_part, scalar_part = q[..., :3], q[..., 3:4]
    first_cross = torch.cross(vec_part, v, dim=-1)
    second_cross = torch.cross(vec_part, first_cross, dim=-1)
    return v + 2.0 * (first_cross * scalar_part + second_cross)
|
||||
|
||||
|
||||
def _skel_multiply(s1, s2):
    """Compose two skel states (..., 8), returning parent ∘ child.

    Each state is laid out as [translation(3) | quaternion xyzw(4) | scale(1)].
    Mirrors pymomentum.skel_state.multiply: both quaternions are renormalized
    before they are composed. Across many FK levels previously-normalized
    quaternions drift by a few ULPs; the JIT renormalizes defensively, so the
    same is done here to stay bit-close to its outputs.
    """
    parent_t, parent_scale = s1[..., :3], s1[..., 7:8]
    child_t, child_scale = s2[..., :3], s2[..., 7:8]
    parent_q = F.normalize(s1[..., 3:7], p=2, dim=-1, eps=1e-12)
    child_q = F.normalize(s2[..., 3:7], p=2, dim=-1, eps=1e-12)

    composed = (
        parent_t + parent_scale * _quat_rotate(parent_q, child_t),
        _quat_multiply(parent_q, child_q),
        parent_scale * child_scale,
    )
    return torch.cat(composed, dim=-1)
|
||||
|
||||
|
||||
def _skel_transform_points(skel_state, points):
    """Apply skel_state (..., 8) to points (..., 3) as t + q * (s * points).

    The quaternion inside skel_state is used as-is and is expected to be
    unit-norm; callers that cannot guarantee that should normalize first.
    """
    translation = skel_state[..., :3]
    rotation = skel_state[..., 3:7]
    scale = skel_state[..., 7:8]
    scaled_points = scale * points
    return translation + _quat_rotate(rotation, scaled_points)
|
||||
|
||||
|
||||
def _global_skel_state_from_local(local, pmi_levels):
    """Turn local joint states into global ones with a level-by-level FK walk.

    The walk runs in fp64 (matching the JIT's use_double_precision=True path)
    and the result is cast back to the dtype of `local`.

    `pmi_levels` is a precomputed list of (source_idx, target_idx) index-tensor
    pairs, one per BFS level; each level overwrites the `source` joints with
    state[target] ∘ state[source]. Precomputing the pairs avoids a per-call
    torch.split + tolist() sync.
    """
    result_dtype = local.dtype
    # clone() guarantees a private copy even when `local` is already fp64.
    state = local.to(torch.float64).clone()
    for child_idx, parent_idx in pmi_levels:
        parent_state = state.index_select(-2, parent_idx)
        child_state = state.index_select(-2, child_idx)
        state.index_copy_(-2, child_idx, _skel_multiply(parent_state, child_state))
    return state.to(result_dtype)
|
||||
|
||||
|
||||
class MHRRig(nn.Module):
    """Plain-PyTorch reimplementation of Meta's MHR rig.

    All math runs in the dtype of the rig's own parameters (allocated as fp32;
    FK is upcast to fp64 internally, matching the JIT's
    use_double_precision=True backend) regardless of the host model's dtype.
    """

    NUM_VERTS = 18439
    NUM_JOINTS = 127
    NUM_LBS_TRIPLETS = 51337
    NUM_IDENTITY = 45
    NUM_EXPR = 72
    PARAM_TRANSFORM_IN = 249   # = model_parameters(204) + identity_coeffs(45)
    PARAM_TRANSFORM_OUT = 889  # = NUM_JOINTS * 7
    POSE_CORR_IN = 750         # = (NUM_JOINTS - 2) * 6
    POSE_CORR_HIDDEN = 3000
    POSE_CORR_SPARSE_NNZ = 53136

    def __init__(self, device=None, dtype=None, operations=None):
        """Allocate every rig tensor, uninitialised, on `device`.

        `dtype` and `operations` are accepted for constructor-signature
        compatibility and deliberately discarded: parameters are always
        fp32 and index buffers keep explicit integer dtypes.
        """
        super().__init__()
        del dtype, operations
        f32 = torch.float32

        # All buffers are populated by load_state_dict from the `mhr.*` keys
        def _p(*shape, dtype=f32):
            return nn.Parameter(torch.empty(*shape, dtype=dtype, device=device), requires_grad=False)

        def _b(name, *shape, dtype):
            self.register_buffer(name, torch.empty(*shape, dtype=dtype, device=device))

        self.base_shape = _p(self.NUM_VERTS, 3)
        self.identity_basis = _p(self.NUM_IDENTITY, self.NUM_VERTS, 3)
        self.expr_basis = _p(self.NUM_EXPR, self.NUM_VERTS, 3)
        self.param_transform = _p(self.PARAM_TRANSFORM_OUT, self.PARAM_TRANSFORM_IN)

        self.skel_joint_translation_offsets = _p(self.NUM_JOINTS, 3)
        self.skel_joint_prerotations = _p(self.NUM_JOINTS, 4)
        _b("skel_joint_parents", self.NUM_JOINTS, dtype=torch.int32)
        _b("skel_pmi", 2, 266, dtype=torch.int64)
        _b("skel_pmi_buffer_sizes", 4, dtype=torch.int64)

        self.lbs_inverse_bind_pose = _p(self.NUM_JOINTS, 8)
        self.lbs_skin_weights = _p(self.NUM_LBS_TRIPLETS)
        _b("lbs_skin_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int32)
        _b("lbs_vert_indices", self.NUM_LBS_TRIPLETS, dtype=torch.int64)

        _b("pose_corr_sparse_indices", 2, self.POSE_CORR_SPARSE_NNZ, dtype=torch.int64)
        self.pose_corr_sparse_weight = _p(self.POSE_CORR_SPARSE_NNZ)

        self.register_buffer("pose_corr_sparse_shape", torch.tensor([self.POSE_CORR_HIDDEN, self.POSE_CORR_IN], dtype=torch.int64, device=device))
        self.pose_corr_weight = _p(self.NUM_VERTS * 3, self.POSE_CORR_HIDDEN)
        self.pose_corr_bias = None
        # Lazily built helpers derived from the buffers above; see
        # _sparse_w() and _pmi_levels().
        self._pose_corr_sparse_cache = None
        self._pmi_levels_cache = None

    def forward(self, identity_coeffs, model_parameters, expr_coeffs=None, apply_correctives: bool = True):
        """Pose and skin the rig.

        Args:
            identity_coeffs: (B, NUM_IDENTITY) identity blend coefficients.
            model_parameters: (B, 204) model parameters; zero-padded by
                NUM_IDENTITY columns before `param_transform` is applied.
            expr_coeffs: (B, NUM_EXPR) expression blend coefficients, or None
                to apply no expression offset.
            apply_correctives: add the pose-corrective offsets before skinning.

        Returns:
            (verts, skel_state): skinned vertices (B, NUM_VERTS, 3) and the
            global per-joint skel state (B, NUM_JOINTS, 8).
        """
        f32 = self.base_shape.dtype
        identity_coeffs = identity_coeffs.to(f32)
        model_parameters = model_parameters.to(f32)
        B = identity_coeffs.shape[0]

        identity_rest = self.base_shape + torch.einsum("nvd,bn->bvd", self.identity_basis, identity_coeffs)

        # The identity slots of the parameter transform input are zeroed.
        cat_in = torch.cat([model_parameters, torch.zeros_like(identity_coeffs)], dim=1)
        joint_parameters = torch.einsum("dn,bn->bd", self.param_transform, cat_in)

        # Per joint: 3 translation, 3 Euler XYZ, 1 log2-scale.
        jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
        local_t = jp[..., :3] + self.skel_joint_translation_offsets.unsqueeze(0)
        local_q = _euler_xyz_to_quat(jp[..., 3:6])
        local_q = _quat_multiply(self.skel_joint_prerotations.unsqueeze(0), local_q)
        local_s = torch.exp(jp[..., 6:7] * _LN2)
        local_state = torch.cat([local_t, local_q, local_s], dim=-1)

        skel_state = _global_skel_state_from_local(local_state, self._pmi_levels())

        unposed = identity_rest
        if expr_coeffs is not None:
            # Callers may legitimately pass None (no facial expression).
            face_expr = torch.einsum("nvd,bn->bvd", self.expr_basis, expr_coeffs.to(f32))
            unposed = unposed + face_expr
        if apply_correctives:
            unposed = unposed + self._pose_correctives(joint_parameters)

        verts = self._skin(skel_state, unposed)
        return verts, skel_state

    def _pose_correctives(self, joint_parameters):
        """Return per-vertex corrective offsets (B, NUM_VERTS, 3) for a pose."""
        B = joint_parameters.shape[0]
        jp = joint_parameters.view(B, self.NUM_JOINTS, 7)
        # Joints [2:] only — root and one more skipped. Take Euler XYZ (cols 3:6).
        feat = batch6DFromXYZ(jp[:, 2:, 3:6], return_9D=False)  # (B, 125, 6)
        # Entries 0 and 4 are the two ones of the identity's 6D encoding, so
        # after the subtraction a zero rotation yields an all-zero feature.
        feat[..., 0] -= 1.0
        feat[..., 4] -= 1.0
        feat = feat.flatten(1, 2)  # (B, 750)

        h = (self._sparse_w() @ feat.T).T  # (B, 3000)
        h = F.relu(h)
        out = F.linear(h, self.pose_corr_weight, self.pose_corr_bias)  # (B, 55317)
        return out.view(B, self.NUM_VERTS, 3)

    def _pmi_levels(self):
        """Return the cached list of (source_idx, target_idx) pairs per FK level.

        The list is rebuilt when no cache exists or when `skel_pmi` now lives
        on a different device than the cached index tensors.
        """
        cached = self._pmi_levels_cache
        pmi = self.skel_pmi
        if cached is not None and cached[0][0].device == pmi.device:
            return cached
        sizes = self.skel_pmi_buffer_sizes.tolist()
        parts = torch.split(pmi, sizes, dim=1)
        levels = [(p[0], p[1]) for p in parts]
        self._pmi_levels_cache = levels
        return levels

    def _sparse_w(self):
        """Return the coalesced sparse first-layer weight of the corrective net.

        Cached after the first build; rebuilt when the weight's device or
        dtype no longer matches the cached tensor.
        """
        cached = self._pose_corr_sparse_cache
        w = self.pose_corr_sparse_weight
        if cached is not None and cached.device == w.device and cached.dtype == w.dtype:
            return cached
        sparse = torch.sparse_coo_tensor(
            self.pose_corr_sparse_indices,
            w,
            tuple(self.pose_corr_sparse_shape.tolist()),
            check_invariants=False,
        ).coalesce()
        self._pose_corr_sparse_cache = sparse
        return sparse

    def _skin(self, skel_state, rest_verts):
        """Skin rest-pose vertices with the global joint states.

        Each (joint, vertex, weight) triplet transforms its rest vertex by
        `skel_state[joint] ∘ inverse_bind_pose[joint]`; the weighted results
        are summed per vertex.
        """
        B = skel_state.shape[0]
        ibp = self.lbs_inverse_bind_pose.unsqueeze(0).expand(B, self.NUM_JOINTS, 8)
        joint_xform = _skel_multiply(skel_state, ibp)

        # _skel_transform_points expects unit quaternions.
        norm_q = F.normalize(joint_xform[..., 3:7], p=2, dim=-1, eps=1e-12)
        joint_xform = torch.cat([joint_xform[..., :3], norm_q, joint_xform[..., 7:8]], dim=-1)

        sk_idx = self.lbs_skin_indices.long()
        v_idx = self.lbs_vert_indices
        w = self.lbs_skin_weights

        per_triplet_xform = joint_xform.index_select(-2, sk_idx)   # (B, 51337, 8)
        per_triplet_rest = rest_verts.index_select(-2, v_idx)      # (B, 51337, 3)
        contrib = _skel_transform_points(per_triplet_xform, per_triplet_rest) * w.unsqueeze(0).unsqueeze(-1)

        out = torch.zeros(B, self.NUM_VERTS, 3, dtype=rest_verts.dtype, device=rest_verts.device)
        out.index_add_(-2, v_idx, contrib)
        return out
|
||||
413
comfy/ldm/sam3d_body/mhr/mhr_utils.py
Normal file
413
comfy/ldm/sam3d_body/mhr/mhr_utils.py
Normal file
@ -0,0 +1,413 @@
|
||||
# MHR (Meta Human Rig) parameter packing/unpacking. The 6D-rotation helpers
|
||||
# (batch6DFromXYZ, batchXYZfrom6D, batch9Dfrom6D) are the continuity
|
||||
# representation from Zhou et al., "On the Continuity of Rotation
|
||||
# Representations in Neural Networks" (CVPR 2019, https://arxiv.org/abs/1812.07035),
|
||||
# implementations from papagina/RotationContinuity:
|
||||
# https://github.com/papagina/RotationContinuity/blob/758b0ce5/shapenet/code/tools.py
|
||||
# The compact_cont_to_model_params_* functions are MHR-rig-specific glue.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def rotation_angle_difference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the rotation angle (radians) separating two batches of SO(3) matrices.

    Args:
        A: Tensor of shape (*, 3, 3), batch of rotation matrices.
        B: Tensor of shape (*, 3, 3), batch of rotation matrices.

    Returns:
        Tensor of shape (*,) with the angle of the relative rotation A @ B^T.
    """
    relative = A @ B.transpose(-2, -1)
    diag_sum = relative[..., 0, 0] + relative[..., 1, 1] + relative[..., 2, 2]
    # trace(R) = 1 + 2 cos(theta); clamp so rounding error cannot push the
    # cosine outside acos's domain.
    cosine = ((diag_sum - 1) / 2).clamp(-1.0, 1.0)
    return cosine.acos()
|
||||
|
||||
|
||||
def fix_wrist_euler(
    wrist_xzy, limits_x=(-2.2, 1.0), limits_z=(-2.2, 1.5), limits_y=(-1.2, 1.5)
):
    """Pick, per wrist, whichever of two angle triples violates the limits less.

    The candidate triple is (x + pi, -(z + pi), y + pi), each wrapped to
    [-pi, pi]. It replaces the input only where its summed squared limit
    violation is strictly smaller.

    wrist_xzy: B x 2 x 3 (X, Z, Y angles)
    Returns: tensor of the same shape holding the selected angles.
    """

    def _wrap(angle):
        # Map an angle to its equivalent in [-pi, pi].
        return torch.atan2(torch.sin(angle), torch.cos(angle))

    def _excess_sq(angle, bounds):
        # Squared distance by which `angle` falls outside [bounds[0], bounds[1]].
        under = torch.clamp(bounds[0] - angle, min=0.0)
        over = torch.clamp(angle - bounds[1], min=0.0)
        return under**2 + over**2

    x, z, y = wrist_xzy[..., 0], wrist_xzy[..., 1], wrist_xzy[..., 2]

    x_flipped = _wrap(x + torch.pi)
    z_flipped = _wrap(-(z + torch.pi))
    y_flipped = _wrap(y + torch.pi)

    cost_given = (
        _excess_sq(x, limits_x)
        + _excess_sq(z, limits_z)
        + _excess_sq(y, limits_y)
    )
    cost_flipped = (
        _excess_sq(x_flipped, limits_x)
        + _excess_sq(z_flipped, limits_z)
        + _excess_sq(y_flipped, limits_y)
    )

    prefer_flipped = (cost_flipped < cost_given).unsqueeze(-1)
    flipped = torch.stack([x_flipped, z_flipped, y_flipped], dim=-1)
    return torch.where(prefer_flipped, flipped, wrist_xzy)
|
||||
|
||||
|
||||
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py
|
||||
def batch6DFromXYZ(r, return_9D=False):
    """Build rotation matrices from XYZ Euler angles.

    Args:
        r: ... x 3 Euler angles (x, y, z), in radians.
        return_9D: when True return the full ... x 3 x 3 matrices instead of
            the 6D encoding.

    Returns:
        ... x 6 tensor holding the first matrix column followed by the second
        (the 6D continuity representation), or the ... x 3 x 3 matrices when
        `return_9D` is True.
    """
    rc = torch.cos(r)
    rs = torch.sin(r)
    cx = rc[..., 0]
    cy = rc[..., 1]
    cz = rc[..., 2]
    sx = rs[..., 0]
    sy = rs[..., 1]
    sz = rs[..., 2]

    # Allocate straight on r's device; going through the host first would
    # cost an extra allocation and a host-to-device copy on every call.
    result = torch.empty(list(r.shape[:-1]) + [3, 3], dtype=r.dtype, device=r.device)

    result[..., 0, 0] = cy * cz
    result[..., 0, 1] = -cx * sz + sx * sy * cz
    result[..., 0, 2] = sx * sz + cx * sy * cz
    result[..., 1, 0] = cy * sz
    result[..., 1, 1] = cx * cz + sx * sy * sz
    result[..., 1, 2] = -sx * cz + cx * sy * sz
    result[..., 2, 0] = -sy
    result[..., 2, 1] = sx * cy
    result[..., 2, 2] = cx * cy

    if return_9D:
        return result
    return torch.cat([result[..., :, 0], result[..., :, 1]], dim=-1)
|
||||
|
||||
|
||||
# https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L82
|
||||
def batchXYZfrom6D(poses):
    """Convert 6D rotation encodings to XYZ Euler angles.

    Args:
        poses: ... x 6, the first and second rotation-matrix columns
            concatenated.

    Returns:
        ... x 3 Euler angles (x, y, z).
    """
    # First, rebuild the rotation matrix by orthonormalising the two columns.
    col_first = F.normalize(poses[..., :3], dim=-1)
    col_third = F.normalize(torch.cross(col_first, poses[..., 3:], dim=-1), dim=-1)
    col_second = torch.cross(col_third, col_first, dim=-1)
    rot = torch.stack([col_first, col_second, col_third], dim=-1)  # ... x 3 x 3

    # Now get it into euler
    # https://github.com/papagina/RotationContinuity/blob/758b0ce551c06372cab7022d4c0bdf331c89c696/shapenet/code/tools.py#L412
    planar = torch.sqrt(
        rot[..., 0, 0] * rot[..., 0, 0] + rot[..., 1, 0] * rot[..., 1, 0]
    )
    # 1.0 where the regular formulas degenerate (planar ~ 0), else 0.0.
    degenerate = (planar < 1e-6).float()

    x_regular = torch.atan2(rot[..., 2, 1], rot[..., 2, 2])
    y_regular = torch.atan2(-rot[..., 2, 0], planar)
    z_regular = torch.atan2(rot[..., 1, 0], rot[..., 0, 0])

    x_degenerate = torch.atan2(-rot[..., 1, 2], rot[..., 1, 1])
    y_degenerate = torch.atan2(-rot[..., 2, 0], planar)
    z_degenerate = rot[..., 1, 0] * 0

    # Blend arithmetically with the 0/1 mask; the result is written into a
    # buffer that carries the matrix dtype.
    euler = torch.zeros_like(rot[..., 0])
    euler[..., 0] = x_regular * (1 - degenerate) + x_degenerate * degenerate
    euler[..., 1] = y_regular * (1 - degenerate) + y_degenerate * degenerate
    euler[..., 2] = z_regular * (1 - degenerate) + z_degenerate * degenerate

    return euler
|
||||
|
||||
|
||||
# Degrees of freedom of each hand joint, in joint order (sums to 27).
_HAND_DOFS = [3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 2, 3, 1, 1]
_HAND_MASK_CACHE: dict = {}  # device -> dict of masks


def _hand_masks(device):
    """Return (and cache per device) the boolean masks separating hand DoF groups.

    Continuous ("cont") masks have 2 slots per DoF (54 in total); model-param
    masks have 1 slot per DoF (27 in total). "threedofs" marks slots of 3-DoF
    joints, "onedofs" marks slots of 1-DoF and 2-DoF joints.
    """
    cached = _HAND_MASK_CACHE.get(device)
    if cached is not None:
        return cached

    def _build(selected_dofs, slots_per_dof):
        # One constant-valued run per joint, concatenated in joint order.
        runs = [
            torch.full((slots_per_dof * k,), k in selected_dofs, dtype=torch.bool)
            for k in _HAND_DOFS
        ]
        return torch.cat(runs).to(device)

    masks = {
        "mask_cont_threedofs": _build((3,), 2),
        "mask_cont_onedofs": _build((1, 2), 2),
        "mask_model_params_threedofs": _build((3,), 1),
        "mask_model_params_onedofs": _build((1, 2), 1),
    }
    _HAND_MASK_CACHE[device] = masks
    return masks
|
||||
|
||||
|
||||
def compact_cont_to_model_params_hand(hand_cont):
    """Convert a 54-dim continuous hand pose into 27 hand model parameters.

    Args:
        hand_cont: tensor whose last dimension is 54, ordered by joint (not by
            model parameter). 3-DoF joints occupy 6 values each and 1-/2-DoF
            joints occupy a (sin, cos) pair per degree of freedom.

    Returns:
        Tensor with the same leading dimensions, dtype and device as
        ``hand_cont`` and a last dimension of 27, ordered by joint, then XYZ.
    """
    # These are ordered by joint, not model params ^^
    assert hand_cont.shape[-1] == 54
    masks = _hand_masks(hand_cont.device)

    # 3-DoF joints: groups of 6 continuous values -> 3 Euler values each
    # (conversion delegated to batchXYZfrom6D).
    cont_threedofs = hand_cont[..., masks["mask_cont_threedofs"]].unflatten(-1, (-1, 6))
    params_threedofs = batchXYZfrom6D(cont_threedofs).flatten(-2, -1)

    # 1-DoF joints: (sin, cos) pairs -> angle.
    cont_onedofs = hand_cont[..., masks["mask_cont_onedofs"]].unflatten(-1, (-1, 2))
    params_onedofs = torch.atan2(cont_onedofs[..., 0], cont_onedofs[..., 1])

    # Allocate directly with the input's dtype/device instead of creating a
    # default CPU tensor and converting it afterwards.
    hand_model_params = hand_cont.new_zeros((*hand_cont.shape[:-1], 27))
    hand_model_params[..., masks["mask_model_params_threedofs"]] = params_threedofs
    hand_model_params[..., masks["mask_model_params_onedofs"]] = params_onedofs

    return hand_model_params
|
||||
|
||||
|
||||
def compact_model_params_to_cont_hand(hand_model_params):
    """Convert 27 hand model parameters into the 54-dim continuous hand pose.

    Inverse layout of ``compact_cont_to_model_params_hand``: 3-DoF joints are
    expanded to 6 values each via ``batch6DFromXYZ`` and every 1-/2-DoF angle
    becomes a (sin, cos) pair.

    Args:
        hand_model_params: tensor whose last dimension is 27, ordered by joint.

    Returns:
        Tensor with the same leading dimensions, dtype and device as the input
        and a last dimension of 54.
    """
    # These are ordered by joint, not model params ^^
    assert hand_model_params.shape[-1] == 27
    # Reuse the cached, device-resident masks instead of rebuilding four CPU
    # masks on every call.
    masks = _hand_masks(hand_model_params.device)

    # 3-DoF joints: groups of 3 Euler values -> 6 continuous values each.
    params_threedofs = hand_model_params[
        ..., masks["mask_model_params_threedofs"]
    ].unflatten(-1, (-1, 3))
    cont_threedofs = batch6DFromXYZ(params_threedofs).flatten(-2, -1)

    # 1-DoF joints: angle -> interleaved (sin, cos).
    params_onedofs = hand_model_params[..., masks["mask_model_params_onedofs"]]
    cont_onedofs = torch.stack(
        [params_onedofs.sin(), params_onedofs.cos()], dim=-1
    ).flatten(-2, -1)

    # Finally, assemble into a 54-dim vector, ordered by joint.
    hand_cont = hand_model_params.new_zeros((*hand_model_params.shape[:-1], 54))
    hand_cont[..., masks["mask_cont_threedofs"]] = cont_threedofs
    hand_cont[..., masks["mask_cont_onedofs"]] = cont_onedofs

    return hand_cont
|
||||
|
||||
|
||||
def batch9Dfrom6D(poses):
    """Expand 6D rotation representations into flattened 3x3 matrices.

    Args:
        poses: ``... x 6`` tensor; the first three values and the last three
            values are the (unnormalised) first and second columns.

    Returns:
        ``... x 9`` tensor: the 3x3 matrix whose columns are the orthonormal
        basis built below, flattened row by row.
    """
    col_a = poses[..., :3]
    col_b = poses[..., 3:]

    # Gram-Schmidt-like construction: normalise the first column, derive the
    # third from a cross product, then recover an orthogonal second column.
    basis_x = F.normalize(col_a, dim=-1)
    basis_z = F.normalize(torch.cross(basis_x, col_b, dim=-1), dim=-1)
    basis_y = torch.cross(basis_z, basis_x, dim=-1)

    rotmat = torch.stack([basis_x, basis_y, basis_z], dim=-1)  # ... x 3 x 3
    return rotmat.flatten(-2, -1)  # ... x 9
|
||||
|
||||
|
||||
def batch4Dfrom2D(poses):
    """Expand (sin, cos) pairs into flattened 2x2 rotation matrices.

    Args:
        poses: ``... x 2`` tensor, where the last axis is (sin, cos). The pair
            is L2-normalised before use.

    Returns:
        ``... x 4`` tensor laid out as ``[cos, sin, -sin, cos]``.
    """
    unit = F.normalize(poses, dim=-1)
    sin_part = unit[..., 0]
    cos_part = unit[..., 1]

    # Flattened SO(2) matrix.
    return torch.stack([cos_part, sin_part, -sin_part, cos_part], dim=-1)
|
||||
|
||||
|
||||
def compact_cont_to_rotmat_body(body_pose_cont, inflate_trans=False):
    """Convert the continuous body pose into concatenated rotation-matrix entries.

    The input's last dimension is laid out as ``[3-DoF rotations (6 values per
    joint) | 1-DoF rotations (sin, cos per angle) | 1-DoF translations]``.

    Args:
        body_pose_cont: tensor whose last dimension equals
            ``2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans``.
        inflate_trans: reserved; any truthy value is rejected.

    Returns:
        Tensor concatenating, along the last dimension, the flattened 3x3
        matrices of the 3-DoF joints, the flattened 2x2 matrices of the 1-DoF
        joints, and the translation values passed through unchanged.

    Raises:
        AssertionError: if the last dimension has the wrong size, or if
            ``inflate_trans`` is requested (not implemented).
    """
    # Share the single cached copy of the index tables instead of rebuilding
    # three tensors per call just to read their lengths.
    idxs_3dof, idxs_1dof, idxs_trans, _ = _body_idxs(body_pose_cont.device)
    num_3dof_angles = len(idxs_3dof) * 3
    num_1dof_angles = len(idxs_1dof)
    num_1dof_trans = len(idxs_trans)
    assert body_pose_cont.shape[-1] == (
        2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans
    )
    if inflate_trans:
        # Raised explicitly (rather than `assert False`) so the guard also holds
        # under `python -O`, where assert statements are stripped.
        raise AssertionError(
            "This is left as a possibility to increase the space/contribution/supervision trans params gets compared to rots"
        )

    # Get subsets
    end_3dofs = 2 * num_3dof_angles
    end_1dofs = end_3dofs + 2 * num_1dof_angles
    body_cont_3dofs = body_pose_cont[..., :end_3dofs]
    body_cont_1dofs = body_pose_cont[..., end_3dofs:end_1dofs]
    body_cont_trans = body_pose_cont[..., end_1dofs:]

    # 3-DoF joints: 6 values -> 9 rotation-matrix entries each.
    body_rotmat_3dofs = batch9Dfrom6D(body_cont_3dofs.unflatten(-1, (-1, 6))).flatten(-2, -1)
    # 1-DoF joints: (sin, cos) -> 4 rotation-matrix entries each.
    body_rotmat_1dofs = batch4Dfrom2D(body_cont_1dofs.unflatten(-1, (-1, 2))).flatten(-2, -1)
    # Nothing to do for trans.
    body_rotmat_trans = body_cont_trans

    # Put them together
    return torch.cat([body_rotmat_3dofs, body_rotmat_1dofs, body_rotmat_trans], dim=-1)
|
||||
|
||||
|
||||
_BODY_IDX_CACHE: dict = {}  # device -> tuple of index tensors


def _body_idxs(device):
    """Return the body-parameter index tensors for ``device``, built only once.

    The result is a 4-tuple of ``torch.long`` tensors on ``device``:

    1. ``(23, 3)`` indices of the 3-DoF rotation parameters, one row per joint;
    2. indices of the 58 1-DoF rotation parameters;
    3. indices of the 6 1-DoF translation parameters;
    4. the first tensor flattened to 1-D.

    Together the indices cover positions 0..132 of the 133-dim body
    model-parameter vector. Results are memoised per device in
    ``_BODY_IDX_CACHE``.
    """
    hit = _BODY_IDX_CACHE.get(device)
    if hit is not None:
        return hit

    # fmt: off
    rot_3dof = [(0, 2, 4), (6, 8, 10), (12, 13, 14), (15, 16, 17), (18, 19, 20), (21, 22, 23), (24, 25, 26), (27, 28, 29), (34, 35, 36), (37, 38, 39), (44, 45, 46), (53, 54, 55), (64, 65, 66), (85, 69, 73), (86, 70, 79), (87, 71, 82), (88, 72, 76), (91, 92, 93), (112, 96, 100), (113, 97, 106), (114, 98, 109), (115, 99, 103), (130, 131, 132)]
    rot_1dof = [1, 3, 5, 7, 9, 11, 30, 31, 32, 33, 40, 41, 42, 43, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 67, 68, 74, 75, 77, 78, 80, 81, 83, 84, 89, 90, 94, 95, 101, 102, 104, 105, 107, 108, 110, 111, 116, 117, 118, 119, 120, 121, 122, 123]
    trans_1dof = [124, 125, 126, 127, 128, 129]
    # fmt: on

    rot_3dof_t = torch.tensor(rot_3dof, dtype=torch.long).to(device)
    rot_1dof_t = torch.tensor(rot_1dof, dtype=torch.long).to(device)
    trans_1dof_t = torch.tensor(trans_1dof, dtype=torch.long).to(device)

    hit = (rot_3dof_t, rot_1dof_t, trans_1dof_t, rot_3dof_t.flatten())
    _BODY_IDX_CACHE[device] = hit
    return hit
|
||||
|
||||
|
||||
def compact_cont_to_model_params_body(body_pose_cont):
    """Convert the continuous body pose into the 133-dim body model parameters.

    The input's last dimension is laid out as ``[3-DoF rotations (6 values per
    joint) | 1-DoF rotations (sin, cos per angle) | 1-DoF translations]``.
    Each group is converted and scattered to its parameter positions given by
    ``_body_idxs``.

    Args:
        body_pose_cont: tensor whose last dimension equals
            ``2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans``.

    Returns:
        Tensor with the same leading dimensions, dtype and device, and a last
        dimension of 133.
    """
    rot3_idxs, rot1_idxs, trans_idxs, rot3_flat = _body_idxs(body_pose_cont.device)
    n_rot3 = len(rot3_idxs) * 3
    n_rot1 = len(rot1_idxs)
    n_trans = len(trans_idxs)
    assert body_pose_cont.shape[-1] == 2 * n_rot3 + 2 * n_rot1 + n_trans

    # Split the continuous vector into its three consecutive segments.
    split_a = 2 * n_rot3
    split_b = split_a + 2 * n_rot1
    cont_rot3 = body_pose_cont[..., :split_a].unflatten(-1, (-1, 6))
    cont_rot1 = body_pose_cont[..., split_a:split_b].unflatten(-1, (-1, 2))  # (sincos)
    cont_trans = body_pose_cont[..., split_b:]

    # 6D -> 3 values per joint; (sin, cos) -> angle; translations pass through.
    params_rot3 = batchXYZfrom6D(cont_rot3).flatten(-2, -1)
    params_rot1 = torch.atan2(cont_rot1[..., -2], cont_rot1[..., -1])

    # Scatter every group to its slot in the model-parameter vector.
    out = torch.zeros(
        *body_pose_cont.shape[:-1],
        133,
        dtype=body_pose_cont.dtype,
        device=body_pose_cont.device,
    )
    out[..., rot3_flat] = params_rot3
    out[..., rot1_idxs] = params_rot1
    out[..., trans_idxs] = cont_trans
    return out
|
||||
|
||||
|
||||
def compact_model_params_to_cont_body(body_pose_params):
    """Convert the body model parameters into the continuous body pose.

    Inverse layout of ``compact_cont_to_model_params_body``: parameters are
    gathered by group and emitted as ``[3-DoF rotations (6 values per joint) |
    1-DoF rotations (sin, cos per angle) | 1-DoF translations]``.

    Args:
        body_pose_params: tensor whose last dimension equals
            ``num_3dof_angles + num_1dof_angles + num_1dof_trans``.

    Returns:
        Tensor with the same leading dimensions and a last dimension of
        ``2 * num_3dof_angles + 2 * num_1dof_angles + num_1dof_trans``.
    """
    # Use the cached index tensors that already live on the input's device,
    # rather than rebuilding CPU index tensors on every call.
    rot3_idxs, rot1_idxs, trans_idxs, rot3_flat = _body_idxs(body_pose_params.device)
    num_3dof_angles = len(rot3_idxs) * 3
    num_1dof_angles = len(rot1_idxs)
    num_1dof_trans = len(trans_idxs)
    assert body_pose_params.shape[-1] == (
        num_3dof_angles + num_1dof_angles + num_1dof_trans
    )

    # Take out params
    body_params_3dofs = body_pose_params[..., rot3_flat]
    body_params_1dofs = body_pose_params[..., rot1_idxs]
    body_params_trans = body_pose_params[..., trans_idxs]

    # params to cont
    body_cont_3dofs = batch6DFromXYZ(
        body_params_3dofs.unflatten(-1, (-1, 3))
    ).flatten(-2, -1)
    body_cont_1dofs = torch.stack(
        [body_params_1dofs.sin(), body_params_1dofs.cos()], dim=-1
    ).flatten(-2, -1)
    body_cont_trans = body_params_trans

    # Put them together
    return torch.cat([body_cont_3dofs, body_cont_1dofs, body_cont_trans], dim=-1)
|
||||
|
||||
|
||||
# Positions of the hand entries inside the 133-dim body model-parameter vector
# (one contiguous run: 62..115).
mhr_param_hand_idxs = list(range(62, 116))
# Positions of the hand entries inside the 260-dim continuous body vector
# (two contiguous runs: 72..131 and 190..237).
mhr_cont_hand_idxs = list(range(72, 132)) + list(range(190, 238))

# Boolean masks equivalent to the index lists above.
mhr_param_hand_mask = torch.zeros(133, dtype=torch.bool)
mhr_param_hand_mask[mhr_param_hand_idxs] = True
mhr_cont_hand_mask = torch.zeros(260, dtype=torch.bool)
mhr_cont_hand_mask[mhr_cont_hand_idxs] = True
|
||||
155
comfy/ldm/sam3d_body/model/camera_modules.py
Normal file
155
comfy/ldm/sam3d_body/model/camera_modules.py
Normal file
@ -0,0 +1,155 @@
|
||||
import math
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.cascade.common import LayerNorm2d_op
|
||||
from torch import nn
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from ..utils import perspective_projection
|
||||
from .transformer import MLP
|
||||
|
||||
class CameraEncoder(nn.Module):
    """Fuse image embeddings with a Fourier encoding of per-pixel camera rays.

    The rays are downsampled by ``1 / patch_size``, extended with a constant
    third coordinate of 1, Fourier-encoded, concatenated to the image
    embeddings along the channel axis and projected back to ``embed_dim`` with
    a bias-free 1x1 convolution followed by a 2D layer norm.
    """

    def __init__(self, embed_dim: int, patch_size: int = 14, device=None, dtype=None, operations=None):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.camera = FourierPositionEncoding(n=3, num_bands=16, max_resolution=64)

        # Derive the ray-embedding width from the encoder itself (99 for the
        # configuration above) so the conv input size cannot drift from it.
        ray_channels = self.camera.channels
        self.conv = operations.Conv2d(embed_dim + ray_channels, embed_dim, kernel_size=1, bias=False, device=device, dtype=dtype)
        self.norm = LayerNorm2d_op(operations)(embed_dim, device=device, dtype=dtype)

    def forward(self, img_embeddings: torch.Tensor, rays: torch.Tensor):
        """Return the ray-conditioned embeddings.

        Args:
            img_embeddings: 4-D tensor unpacked as ``(B, D, h, w)``.
            rays: 4-D tensor resized by ``1 / patch_size`` in both spatial
                dims; after resizing it must contain ``h * w`` positions per
                sample (2 channels, per the inline comment).
        """
        B, _, _h, _w = img_embeddings.shape

        scale = 1 / self.patch_size
        rays = F.interpolate(rays, scale_factor=(scale, scale), mode="bilinear", align_corners=False, antialias=True)
        rays = rays.permute(0, 2, 3, 1).contiguous()  # [b, h, w, 2]
        # Append a constant 1 as third coordinate.
        rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
        rays_embeddings = self.camera(pos=rays.reshape(B, -1, 3))  # (bs, N, C): rays fourier embedding
        rays_embeddings = einops.rearrange(rays_embeddings, "b (h w) c -> b c h w", h=_h, w=_w).contiguous()

        z = torch.cat([img_embeddings, rays_embeddings], dim=1)
        return self.norm(self.conv(z))
|
||||
|
||||
|
||||
class FourierPositionEncoding(nn.Module):
    """Parameter-free sin/cos Fourier feature encoder for ray positions.

    Args:
        n: number of position coordinates.
        num_bands: number of frequency bands per coordinate.
        max_resolution: resolution shared by all ``n`` coordinates; it sets the
            highest frequency used for each coordinate.
    """

    def __init__(self, n: int, num_bands: int, max_resolution: int):
        super().__init__()
        self.num_bands = num_bands
        # Same resolution for every coordinate.
        self.max_resolution = [max_resolution] * n

    @property
    def channels(self):
        """Width of the encoding: sin and cos per band, plus the raw coordinates."""
        coord_count = len(self.max_resolution)
        sin_cos_width = 2 * self.num_bands * coord_count
        return sin_cos_width + coord_count

    def forward(self, pos: torch.Tensor):
        """Encode ``pos`` with ``_generate_fourier_features``."""
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
|
||||
|
||||
|
||||
def _generate_fourier_features(pos: torch.Tensor, num_bands: int, max_resolution: List[int], min_freq: float = 1.0):
    """Return ``pos`` concatenated with its sin/cos Fourier features.

    Args:
        pos: positions of shape ``(b, n, d)`` with ``d == len(max_resolution)``.
        num_bands: number of linearly spaced frequencies per coordinate.
        max_resolution: per-coordinate resolution; frequencies for a
            coordinate run from ``min_freq`` to ``res / 2``.
        min_freq: lowest frequency.

    Returns:
        Tensor of shape ``(b, n, d + 2 * d * num_bands)`` laid out as
        ``[pos | sin(pi * pos * freq) | cos(pi * pos * freq)]``.
    """
    # (d, num_bands): one row of frequencies per coordinate.
    freq_bands = torch.stack(
        [
            torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=pos.device, dtype=pos.dtype)
            for res in max_resolution
        ],
        dim=0,
    )

    # Broadcast (b, n, d, 1) * (d, num_bands) in one multiply instead of a
    # Python loop over the batch; this also handles an empty batch.
    per_pos_features = (pos.unsqueeze(-1) * freq_bands).flatten(-2, -1)

    angles = math.pi * per_pos_features
    # Raw positions first, then sin features, then cos features.
    return torch.cat([pos, torch.sin(angles), torch.cos(angles)], dim=-1)
|
||||
|
||||
|
||||
class PerspectiveHead(nn.Module):
    """
    Predict camera translation (s, tx, ty) and perform full-perspective 2D reprojection (CLIFF/CameraHMR setup).
    """

    def __init__(
        self,
        input_dim: int,
        img_size: Union[int, Tuple[int, int]],  # model input size (W, H)
        mlp_depth: int = 1,
        mlp_channel_div_factor: int = 8,
        default_scale_factor: float = 1.0,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        # Metadata to compute 3D skeleton and 2D reprojection
        if isinstance(img_size, tuple):
            self.img_size = img_size
        else:
            self.img_size = (img_size, img_size)
        self.ncam = 3  # (s, tx, ty)
        self.default_scale_factor = default_scale_factor

        self.proj = MLP(
            input_dim=input_dim,
            hidden_dim=input_dim // mlp_channel_div_factor,
            output_dim=self.ncam,
            num_layers=mlp_depth,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(self, x: torch.Tensor, init_estimate: Optional[torch.Tensor] = None):
        """
        Args:
            x: pose token with shape [B, C], usually C=DECODER.DIM
            init_estimate: [B, self.ncam]; added to the projection when given.
        """
        cam = self.proj(x)
        if init_estimate is None:
            return cam
        return cam + init_estimate

    def perspective_projection(
        self,
        points_3d: torch.Tensor,
        pred_cam: torch.Tensor,
        bbox_center: torch.Tensor,  # [N, 2], in original image space (w, h)
        bbox_size: torch.Tensor,  # [N,], in original image space
        img_size: torch.Tensor,
        cam_int: torch.Tensor,  # [B, 3, 3]
        use_intrin_center: bool = False,
    ):
        """Turn (s, tx, ty) into a camera translation and project ``points_3d``.

        The input ``pred_cam`` is not modified (a clone is sign-flipped).
        Returns a dict with ``pred_keypoints_2d``, ``pred_keypoints_2d_depth``,
        ``pred_cam_t`` and ``focal_length``.
        """
        n_samples = points_3d.shape[0]

        cam = pred_cam.clone()
        cam[..., [0, 2]] *= -1  # Camera system difference

        # Compute camera translation: (scale, x, y) --> (x, y, depth)
        # depth ~= f / s, Note that f is in the NDC space
        s = cam[:, 0]
        tx = cam[:, 1]
        ty = cam[:, 2]
        bs = bbox_size * s * self.default_scale_factor + 1e-8
        focal_length = cam_int[:, 0, 0]
        tz = 2 * focal_length / bs

        # Reference point the bbox centre is measured from: the principal point
        # of the intrinsics, or the image centre.
        if use_intrin_center:
            ref_x = cam_int[:, 0, 2]
            ref_y = cam_int[:, 1, 2]
        else:
            ref_x = img_size[:, 0] / 2
            ref_y = img_size[:, 1] / 2
        cx = 2 * (bbox_center[:, 0] - ref_x) / bs
        cy = 2 * (bbox_center[:, 1] - ref_y) / bs

        pred_cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)

        # Move the points into the camera frame.
        j3d_cam = points_3d + pred_cam_t.unsqueeze(1)

        # Projection to the image plane, note that the projection output is in original image space now.
        j2d = perspective_projection(j3d_cam, cam_int)

        return {
            "pred_keypoints_2d": j2d.reshape(n_samples, -1, 2),
            "pred_keypoints_2d_depth": j3d_cam.reshape(n_samples, -1, 3)[:, :, 2],
            "pred_cam_t": pred_cam_t,
            "focal_length": focal_length,
        }
|
||||
250
comfy/ldm/sam3d_body/model/dinov3.py
Normal file
250
comfy/ldm/sam3d_body/model/dinov3.py
Normal file
@ -0,0 +1,250 @@
|
||||
# DINOv3 ViT-H+ backbone for SAM 3D Body.
|
||||
#
|
||||
# Single-file consolidation of the inference path. SAM 3D Body only ships a
|
||||
# `dinov3_vith16plus` checkpoint, so the architecture is hardcoded rather
|
||||
# than reconstructed from Hydra-flavoured configs.
|
||||
#
|
||||
# Adapted from facebookresearch/dinov3 (DINOv3 License Agreement). Trimmed
|
||||
# to what's actually exercised at inference: no multi-crop training path,
|
||||
# no DINOHead, no causal blocks, no rmsnorm/Mlp variants, no rope shift /
|
||||
# jitter / rescale (training-time augmentations).
|
||||
|
||||
#TODO: Unify with TRELLIS2
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from torch import Tensor, nn
|
||||
|
||||
# DINOv3 ViT-H+ architecture constants.
|
||||
EMBED_DIM = 1280          # token / channel width
DEPTH = 32                # number of transformer blocks
NUM_HEADS = 20            # attention heads per block
FFN_RATIO = 6.0           # FFN hidden size as a multiple of EMBED_DIM
PATCH_SIZE = 16           # patch-embedding kernel size and stride
LAYERSCALE_INIT = 1.0e-5  # initial value of every LayerScale gamma
N_STORAGE_TOKENS = 4      # storage tokens placed between CLS and the patches
LAYERNORM_EPS = 1e-5  # "layernormbf16" preset uses 1e-5
ROPE_BASE = 100.0         # base of the RoPE period schedule
|
||||
|
||||
# RoPE (axial sin/cos, no learnable weights)
|
||||
|
||||
def _rotate_half(x: Tensor) -> Tensor:
    """Split the last dim into halves ``(a, b)`` and return ``(-b, a)``."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
||||
|
||||
|
||||
def _apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Rotate ``x`` by the given sin/cos tables: ``x * cos + rotate_half(x) * sin``."""
    # Half rotation, written inline: (a, b) -> (-b, a) on the last dimension.
    front, back = x.chunk(2, dim=-1)
    rotated = torch.cat([-back, front], dim=-1)
    return x * cos + rotated * sin
|
||||
|
||||
|
||||
class RopePositionEmbedding(nn.Module):
    """Axial RoPE for 2D patch grids; periods buffer is deterministic."""

    def __init__(self, embed_dim: int, num_heads: int, dtype=torch.float32, device=None):
        super().__init__()
        assert embed_dim % (4 * num_heads) == 0
        head_dim = embed_dim // num_heads
        # The buffer is persistent so it round-trips through state_dict. Its
        # values follow from head_dim and ROPE_BASE alone, and load_state_dict
        # overwrites them with the saved buffer either way.
        exponents = 2 * torch.arange(head_dim // 4, dtype=dtype, device=device) / (head_dim // 2)
        self.register_buffer("periods", ROPE_BASE ** exponents, persistent=True)
        self._dtype = dtype

    def forward(self, H: int, W: int) -> Tuple[Tensor, Tensor]:
        """Return ``(sin, cos)`` tables of shape ``[H * W, 4 * len(periods)]``."""
        device = self.periods.device
        dtype = self._dtype
        # Cell centres of each axis, normalised to (0, 1).
        rows = torch.arange(0.5, H, device=device, dtype=dtype) / H
        cols = torch.arange(0.5, W, device=device, dtype=dtype) / W
        grid = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1)
        grid = 2.0 * grid.flatten(0, 1) - 1.0  # [HW, 2] in [-1, +1]
        angles = 2 * math.pi * grid[:, :, None] / self.periods[None, None, :]
        angles = angles.flatten(1, 2).tile(2)  # [HW, D_head]
        return torch.sin(angles), torch.cos(angles)
|
||||
|
||||
|
||||
def _apply_rope_to_qk(q: Tensor, k: Tensor, rope: Tuple[Tensor, Tensor]):
    """Apply RoPE only to the patch-token slice (skip CLS + storage tokens).

    The leading ``q.shape[-2] - sin.shape[-2]`` tokens are left untouched. The
    rotation runs in the dtype of the sin table and each result is cast back
    to its input's dtype.
    """
    sin, cos = rope
    work_dtype = sin.dtype
    n_prefix = q.shape[-2] - sin.shape[-2]

    def _rotate_tail(tokens):
        # Rotate everything after the prefix, then restore the caller's dtype.
        source_dtype = tokens.dtype
        tokens = tokens.to(work_dtype)
        untouched = tokens[..., :n_prefix, :]
        rotated = _apply_rope(tokens[..., n_prefix:, :], sin, cos)
        return torch.cat([untouched, rotated], dim=-2).to(source_dtype)

    return _rotate_tail(q), _rotate_tail(k)
|
||||
|
||||
# Layers
|
||||
|
||||
class LayerScale(nn.Module):
    """Learnable per-channel scale applied to the last dimension of the input."""

    def __init__(self, dim: int, init_values: float, device=None, dtype=None):
        super().__init__()
        initial = torch.full((dim,), init_values, device=device, dtype=dtype)
        self.gamma = nn.Parameter(initial)

    def forward(self, x: Tensor) -> Tensor:
        scaled = x * self.gamma
        return scaled
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
    """w3(silu(w1(x)) * w2(x))."""

    def __init__(self, in_features: int, hidden_features: int, align_to: int = 8,
                 device=None, dtype=None, operations=None):
        super().__init__()
        ops = nn if operations is None else operations
        # Hidden width: two thirds of the request, rounded up to a multiple
        # of `align_to`.
        two_thirds = int(hidden_features * 2 / 3)
        padding = -two_thirds % align_to
        width = two_thirds + padding
        self.w1 = ops.Linear(in_features, width, bias=True, device=device, dtype=dtype)
        self.w2 = ops.Linear(in_features, width, bias=True, device=device, dtype=dtype)
        self.w3 = ops.Linear(width, in_features, bias=True, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection and optional RoPE."""

    def __init__(self, dim: int, num_heads: int, device=None, dtype=None, operations=None):
        super().__init__()
        ops = nn if operations is None else operations
        self.num_heads = num_heads
        # DINOv3's `mask_k_bias` zeroes the K third of qkv.bias. The mask only
        # depends on out_features, so the loader applies it in-place once after
        # load_state_dict (see `apply_dinov3_qkv_bias_mask`) and the forward
        # stays a plain F.linear.
        self.qkv = ops.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype)

    def forward(self, x: Tensor, rope: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # (3, B, heads, N, head_dim): same views as unbind(2) + transpose(1, 2).
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        if rope is not None:
            q, k = _apply_rope_to_qk(q, k, rope)
        # low_precision_attention=False forces attention_sage (when enabled
        # globally in comfy) to fall back to pytorch SDPA. SAM 3D Body's
        # regression heads (camera projection, MHR rig math) are sensitive
        # to attention output precision; sage's int8/fp8 path drifts the
        # keypoints and mesh visibly.
        attended = optimized_attention(
            q,
            k,
            v,
            self.num_heads,
            skip_reshape=True,
            low_precision_attention=False,
        )
        return self.proj(attended)
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU FFN, each behind a
    LayerNorm and a LayerScale, each added back residually."""

    def __init__(self, dim: int, num_heads: int, ffn_ratio: float,
                 device=None, dtype=None, operations=None):
        super().__init__()
        ops = nn if operations is None else operations
        factory = dict(device=device, dtype=dtype)
        self.norm1 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, **factory)
        self.attn = SelfAttention(dim, num_heads, operations=operations, **factory)
        self.ls1 = LayerScale(dim, LAYERSCALE_INIT, **factory)
        self.norm2 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, **factory)
        self.mlp = SwiGLUFFN(dim, int(dim * ffn_ratio), operations=operations, **factory)
        self.ls2 = LayerScale(dim, LAYERSCALE_INIT, **factory)

    def forward(self, x: Tensor, rope=None) -> Tensor:
        # Attention branch.
        attn_out = self.attn(self.norm1(x), rope=rope)
        x = x + self.ls1(attn_out)
        # Feed-forward branch.
        ffn_out = self.mlp(self.norm2(x))
        return x + self.ls2(ffn_out)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """Holder for the patchifying convolution ``proj``.

    ``proj`` uses ``kernel_size == stride == patch_size``. The class defines no
    ``forward`` of its own.
    """

    def __init__(self, in_chans=3, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE,
                 device=None, dtype=None, operations=None):
        super().__init__()
        ops = nn if operations is None else operations
        conv_kwargs = dict(
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
            dtype=dtype,
        )
        self.proj = ops.Conv2d(in_chans, embed_dim, **conv_kwargs)
|
||||
|
||||
# Encoder + wrapper
|
||||
|
||||
class _DinoEncoder(nn.Module):
    """Inner ViT module. Held under `Dinov3Backbone.encoder` so state_dict
    keys (`backbone.encoder.*`) match the upstream layout."""

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        ops = nn if operations is None else operations
        factory = dict(device=device, dtype=dtype)
        self.patch_size = PATCH_SIZE
        self.embed_dim = EMBED_DIM

        self.patch_embed = PatchEmbed(
            embed_dim=EMBED_DIM,
            patch_size=PATCH_SIZE,
            operations=operations,
            **factory,
        )
        # Allocated uninitialised (torch.empty).
        self.cls_token = nn.Parameter(torch.empty(1, 1, EMBED_DIM, **factory))
        self.storage_tokens = nn.Parameter(torch.empty(1, N_STORAGE_TOKENS, EMBED_DIM, **factory))
        # The released config sets pos_embed_rope_dtype="fp32"; periods stays
        # in fp32 regardless of the backbone weight dtype.
        self.rope_embed = RopePositionEmbedding(EMBED_DIM, NUM_HEADS, dtype=torch.float32, device=device)

        layers = []
        for _ in range(DEPTH):
            layers.append(Block(EMBED_DIM, NUM_HEADS, FFN_RATIO, operations=operations, **factory))
        self.blocks = nn.ModuleList(layers)
        self.norm = ops.LayerNorm(EMBED_DIM, eps=LAYERNORM_EPS, **factory)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return the patch features as ``(B, EMBED_DIM, H, W)``,
        where ``H, W`` is the patch grid produced by the patch convolution."""
        tokens = self.patch_embed.proj(x)  # (B, embed_dim, H, W)
        batch, _, grid_h, grid_w = tokens.shape
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, H*W, embed_dim)

        # Prepend CLS + storage tokens.
        prefix = [
            self.cls_token.expand(batch, -1, -1),
            self.storage_tokens.expand(batch, -1, -1),
        ]
        tokens = torch.cat(prefix + [tokens], dim=1)

        rope = self.rope_embed(H=grid_h, W=grid_w)
        for layer in self.blocks:
            tokens = layer(tokens, rope)
        tokens = self.norm(tokens)

        # Drop CLS + storage tokens; reshape patch grid to (B, C, H, W).
        patches = tokens[:, 1 + N_STORAGE_TOKENS:]
        patches = patches.reshape(batch, grid_h, grid_w, EMBED_DIM)
        return patches.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
|
||||
class Dinov3Backbone(nn.Module):
    """Public backbone interface used by SAM3DBody."""

    def __init__(self, device=None, dtype=None, operations=None):
        super().__init__()
        self.encoder = _DinoEncoder(device=device, dtype=dtype, operations=operations)
        self.patch_size = PATCH_SIZE
        # The width is exposed under both attribute names.
        self.embed_dim = EMBED_DIM
        self.embed_dims = EMBED_DIM

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to the wrapped encoder."""
        features = self.encoder(x)
        return features
|
||||
|
||||
|
||||
def apply_dinov3_qkv_bias_mask(backbone: "Dinov3Backbone") -> None:
    """Zero the K third of every block's qkv.bias in-place.

    Implements DINOv3's `mask_k_bias` once at load time so the per-block forward
    stays a plain F.linear instead of cloning + slicing the bias every call.
    Blocks whose qkv layer has no bias are skipped.
    """
    for block in backbone.encoder.blocks:
        layer = block.attn.qkv
        bias = layer.bias
        if bias is None:
            continue
        # The fused bias is laid out as [Q | K | V]; the middle third is K.
        width = layer.out_features
        k_start = width // 3
        k_end = 2 * width // 3
        bias.data[k_start:k_end] = 0
|
||||
1197
comfy/ldm/sam3d_body/model/model.py
Normal file
1197
comfy/ldm/sam3d_body/model/model.py
Normal file
File diff suppressed because it is too large
Load Diff
272
comfy/ldm/sam3d_body/model/prompt.py
Normal file
272
comfy/ldm/sam3d_body/model/prompt.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""SAM 3D Body prompt pipeline: encode (keypoint, mask) prompts and run them
|
||||
through a cross-attention transformer decoder over (token, image) pairs.
|
||||
|
||||
Both adapted from the SAM-style prompt path (Meta, Apache 2.0):
|
||||
https://github.com/facebookresearch/segment-anything
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.cascade.common import LayerNorm2d_op
|
||||
from comfy.ldm.sam3.sam import PositionEmbeddingRandom
|
||||
|
||||
from .transformer import TransformerDecoderLayer
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_body_joints: int,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
"""
|
||||
super().__init__()
|
||||
ops = operations if operations is not None else nn
|
||||
self.embed_dim = embed_dim
|
||||
self.num_body_joints = num_body_joints
|
||||
|
||||
# Keypoint prompts
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
self.point_embeddings = nn.ModuleList(
|
||||
[ops.Embedding(1, embed_dim, device=device, dtype=dtype) for _ in range(self.num_body_joints)]
|
||||
)
|
||||
self.not_a_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
self.invalid_point_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# Mask prompt: 5-stage 2x2 strided conv downscaling to embed_dim.
|
||||
LN2d = LayerNorm2d_op(ops)
|
||||
mask_in_chans = 256
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
ops.Conv2d(1, mask_in_chans // 64, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 64, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 64, mask_in_chans // 16, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 16, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 16, mask_in_chans // 4, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans // 4, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2, device=device, dtype=dtype),
|
||||
LN2d(mask_in_chans, device=device, dtype=dtype),
|
||||
nn.GELU(),
|
||||
ops.Conv2d(mask_in_chans, embed_dim, kernel_size=1, device=device, dtype=dtype),
|
||||
)
|
||||
# Trained values for the gating conv and no_mask_embed are loaded from the state dict
|
||||
self.no_mask_embed = ops.Embedding(1, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def get_dense_pe(self, size: Tuple[int, int]) -> torch.Tensor:
    """Return the positional encoding for an image-embedding grid of ``size``.

    Simply forwards ``size`` to ``self.pe_layer``; the original note documents
    the result layout as (1, C, H, W).
    """
    dense_pe = self.pe_layer(size)
    return dense_pe
|
||||
|
||||
def _embed_keypoints(self, points: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Embeds point prompts.
|
||||
Assuming points have been normalized to [0, 1].
|
||||
|
||||
Output shape [B, N, C], mask shape [B, N]
|
||||
"""
|
||||
assert points.min() >= 0 and points.max() <= 1
|
||||
# PE compute in fp32 for precision (sin/cos of large coords), then cast back to the embedding weight dtype
|
||||
weight_dtype = self.invalid_point_embed.weight.dtype
|
||||
point_embedding = self.pe_layer._encode(points.to(torch.float)).to(weight_dtype)
|
||||
point_embedding[labels == -2] = 0.0 # invalid points
|
||||
point_embedding[labels == -2] += self.invalid_point_embed.weight.to(point_embedding)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding)
|
||||
for i in range(self.num_body_joints):
|
||||
point_embedding[labels == i] += self.point_embeddings[i].weight.to(point_embedding)
|
||||
|
||||
point_mask = labels > -2
|
||||
return point_embedding, point_mask
|
||||
|
||||
def _get_batch_size(self, keypoints: Optional[torch.Tensor], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor]) -> int:
|
||||
if keypoints is not None:
|
||||
return keypoints.shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def forward(
    self,
    keypoints: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Embed keypoint prompts into sparse tokens.

    Arguments:
      keypoints (torch.Tensor or None): the first two entries of the last
        axis are read as point coordinates and the final entry as the label.
      boxes (torch.Tensor or None): only consulted for batch size / device.
      masks (torch.Tensor or None): only consulted for batch size / device.

    Returns:
      torch.Tensor: sparse embeddings, B x N x embed_dim (N is 0 when no
        keypoints are given).
      torch.Tensor: matching B x N mask produced by ``_embed_keypoints``.
    """
    batch = self._get_batch_size(keypoints, boxes, masks)

    # Take the device from whichever prompt is present so an offloaded (CPU)
    # embedding weight is not used as the anchor under dynamic loading.
    if keypoints is not None:
        anchor = keypoints
    elif boxes is not None:
        anchor = boxes
    else:
        anchor = masks
    if anchor is not None:
        device = anchor.device
    else:
        device = self.point_embeddings[0].weight.device

    weight_dtype = self.invalid_point_embed.weight.dtype
    sparse_embeddings = torch.empty((batch, 0, self.embed_dim), device=device, dtype=weight_dtype)
    sparse_masks = torch.empty((batch, 0), device=device)

    if keypoints is None:
        return sparse_embeddings, sparse_masks

    xy = keypoints[:, :, :2]
    joint_labels = keypoints[:, :, -1]
    kp_embeddings, kp_mask = self._embed_keypoints(xy, joint_labels)
    sparse_embeddings = torch.cat([sparse_embeddings, kp_embeddings], dim=1)
    sparse_masks = torch.cat([sparse_masks, kp_mask], dim=1)
    return sparse_embeddings, sparse_masks
|
||||
|
||||
def get_mask_embeddings(
    self,
    masks: Optional[torch.Tensor] = None,
    bs: int = 1,
    size: Tuple[int, int] = (16, 16),  # [H, W]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Embed mask inputs.

    Args:
        masks: optional mask tensor fed through ``self.mask_downscaling``.
        bs: batch size used to expand the "no mask" embedding.
        size: (H, W) of the grid the "no mask" embedding is expanded to.

    Returns:
        (mask_embeddings, no_mask_embeddings). When ``masks`` is None the
        first element is the very same tensor as the second.
    """
    # Use masks as the device/dtype reference when present; otherwise fall
    # back to the first parameter of the downscaling stack.
    ref = masks if masks is not None else next(self.mask_downscaling.parameters())
    no_mask_embeddings = self.no_mask_embed.weight.to(ref).reshape(1, -1, 1, 1).expand(
        bs, -1, size[0], size[1]
    )
    if masks is None:
        return no_mask_embeddings, no_mask_embeddings
    return self.mask_downscaling(masks), no_mask_embeddings
|
||||
|
||||
|
||||
class PromptableDecoder(nn.Module):
    """Stack of cross-attention decoder layers run over (token, image) pairs."""

    def __init__(
        self,
        dims: int,
        context_dims: int,
        depth: int,
        num_heads: int = 8,
        head_dims: int = 64,
        mlp_dims: int = 1024,
        repeat_pe: bool = False,
        do_interm_preds: bool = False,
        keypoint_token_update: bool = False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        ops = nn if operations is None else operations

        # Only the very first layer is built with skip_first_pe=True.
        blocks = []
        for idx in range(depth):
            blocks.append(
                TransformerDecoderLayer(
                    token_dims=dims,
                    context_dims=context_dims,
                    num_heads=num_heads,
                    head_dims=head_dims,
                    mlp_dims=mlp_dims,
                    repeat_pe=repeat_pe,
                    skip_first_pe=(idx == 0),
                    device=device,
                    dtype=dtype,
                    operations=operations,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.norm_final = ops.LayerNorm(dims, eps=1e-6, device=device, dtype=dtype)
        self.do_interm_preds = do_interm_preds
        self.keypoint_token_update = keypoint_token_update

    def forward(
        self,
        token_embedding: torch.Tensor,
        image_embedding: torch.Tensor,
        token_augment: Optional[torch.Tensor] = None,
        image_augment: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
        token_to_pose_output_fn=None,
        keypoint_token_update_fn=None,
        hand_embeddings=None,
        hand_augment=None,
    ):
        """Run the decoder stack.

        Args:
            token_embedding: [B, N, C]
            image_embedding: [B, C, H, W]; flattened to [B, HW, C] here.
            token_augment / image_augment: optional positional terms handed to
                every layer (image_augment is flattened like the image).
            token_mask: optional token mask forwarded to every layer.
            token_to_pose_output_fn: required when ``do_interm_preds`` is set;
                called on the normalized tokens after each layer.
            keypoint_token_update_fn: required when ``keypoint_token_update``
                is set and intermediate predictions are enabled.
            hand_embeddings / hand_augment: optional extra context that is
                concatenated after the image tokens for every layer.

        Returns:
            The normalized tokens, or ``(tokens, all_pose_outputs)`` when
            ``do_interm_preds`` is enabled.
        """

        def as_tokens(feature_map):
            # [B, C, ...] -> [B, L, C] (channels-last for the transformer).
            return feature_map.flatten(2).permute(0, 2, 1)

        image_embedding = as_tokens(image_embedding)
        if image_augment is not None:
            image_augment = as_tokens(image_augment)

        use_hands = hand_embeddings is not None
        if use_hands:
            hand_embeddings = as_tokens(hand_embeddings)
            hand_augment = as_tokens(hand_augment)
            if len(hand_augment) == 1:
                # inflate batch dimension
                assert len(hand_augment.shape) == 3
                hand_augment = hand_augment.repeat(len(hand_embeddings), 1, 1)

        if self.do_interm_preds:
            all_pose_outputs = []
            assert token_to_pose_output_fn is not None
        else:
            all_pose_outputs = None

        layer_idx = 0
        for layer_idx, layer in enumerate(self.layers):
            if use_hands:
                context = torch.cat([image_embedding, hand_embeddings], dim=1)
                context_augment = torch.cat([image_augment, hand_augment], dim=1)
                token_embedding, image_embedding = layer(
                    token_embedding,
                    context,
                    token_augment,
                    context_augment,
                    token_mask,
                )
                # Drop the hand tokens again so only image tokens carry over.
                image_embedding = image_embedding[:, : image_augment.shape[1]]
            else:
                token_embedding, image_embedding = layer(
                    token_embedding, image_embedding,
                    token_augment, image_augment, token_mask,
                )

            is_last = not (layer_idx < len(self.layers) - 1)
            if not self.do_interm_preds or is_last:
                continue

            previous = all_pose_outputs[-1] if all_pose_outputs else None
            curr = token_to_pose_output_fn(
                self.norm_final(token_embedding),
                prev_pose_output=previous,
                layer_idx=layer_idx,
            )
            all_pose_outputs.append(curr)
            if self.keypoint_token_update:
                assert keypoint_token_update_fn is not None
                token_embedding, token_augment, _, _ = keypoint_token_update_fn(
                    token_embedding, token_augment, curr, layer_idx,
                )

        out = self.norm_final(token_embedding)
        if not self.do_interm_preds:
            return out

        previous = all_pose_outputs[-1] if all_pose_outputs else None
        curr = token_to_pose_output_fn(
            out,
            prev_pose_output=previous,
            layer_idx=layer_idx,
        )
        all_pose_outputs.append(curr)
        return out, all_pose_outputs
|
||||
104
comfy/ldm/sam3d_body/model/transformer.py
Normal file
104
comfy/ldm/sam3d_body/model/transformer.py
Normal file
@ -0,0 +1,104 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Plain multi-layer perceptron.

    Builds ``num_layers`` Linear layers going ``input_dim -> hidden_dim ->
    ... -> output_dim``; the activation is applied after every layer except
    the last one.

    Args:
        input_dim, hidden_dim, output_dim: feature sizes.
        num_layers: number of Linear layers.
        act_layer: activation class, instantiated with no arguments.
        device, dtype: forwarded to every Linear constructor.
        operations: namespace providing ``Linear``; falls back to
            ``torch.nn`` when None so the declared default is usable.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act_layer=nn.ReLU, device=None, dtype=None, operations=None):
        super().__init__()
        # The signature defaults operations to None; dereferencing it directly
        # would raise AttributeError, so resolve the fallback first.
        ops = operations if operations is not None else nn
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            ops.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
            for i in range(num_layers)
        )
        self.act = act_layer()

    def forward(self, x):
        """Apply the layers in order, activating all but the final output."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.act(x)
        return x
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head attention with separate q/k/v input widths.

    q, k and v are each projected to ``embed_dims``, split into ``num_heads``
    heads of ``embed_dims // num_heads`` channels, passed through
    ``optimized_attention`` and projected back to ``query_dims``.

    Args:
        embed_dims: width of the projected q/k/v.
        num_heads: number of attention heads.
        query_dims, key_dims, value_dims: input widths; each falls back to
            ``embed_dims`` when falsy.
        qkv_bias, proj_bias: bias flags for the input / output projections.
        device, dtype: forwarded to every Linear constructor.
        operations: namespace providing ``Linear``; falls back to
            ``torch.nn`` when None so the declared default is usable.
    """

    def __init__(self, embed_dims, num_heads, query_dims=None, key_dims=None, value_dims=None, qkv_bias=True, proj_bias=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        # operations defaults to None in the signature; resolve it before use.
        ops = operations if operations is not None else nn
        self.query_dims = query_dims or embed_dims
        self.key_dims = key_dims or embed_dims
        self.value_dims = value_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.head_dims = embed_dims // num_heads

        def lin(in_dims, out_dims, bias):
            return ops.Linear(in_dims, out_dims, bias=bias, device=device, dtype=dtype)

        self.q_proj = lin(self.query_dims, embed_dims, qkv_bias)
        self.k_proj = lin(self.key_dims, embed_dims, qkv_bias)
        self.v_proj = lin(self.value_dims, embed_dims, qkv_bias)
        self.proj = lin(embed_dims, self.query_dims, proj_bias)

    def _split(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [B, N, embed_dims] into per-head [B, num_heads, N, head_dims]."""
        b, n, _ = x.shape
        return x.reshape(b, n, self.num_heads, self.head_dims).transpose(1, 2)

    def forward(self, q, k, v, attn_mask: Optional[torch.Tensor] = None):
        """Project and split q/k/v, attend, then project the result back."""
        q = self._split(self.q_proj(q))
        k = self._split(self.k_proj(k))
        v = self._split(self.v_proj(v))
        # skip_reshape=True because the heads were already split above.
        x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True, low_precision_attention=False)
        return self.proj(x)
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer: token self-attention, cross-attention to a
    context sequence, then a feed-forward block, each with a residual add.

    Args:
        token_dims, context_dims: widths of the token / context sequences.
        num_heads, head_dims: attention geometry (attention width is their product).
        mlp_dims: hidden width of the feed-forward MLP.
        repeat_pe: when True, positional terms are re-normalized in this layer
            and added to the attention queries/keys.
        skip_first_pe: when True, self-attention ignores the token positional term.
        device, dtype: forwarded to the sub-module constructors.
        operations: namespace providing ``Linear`` / ``LayerNorm``; falls back
            to ``torch.nn`` when None so the declared default is usable.
    """

    def __init__(self, token_dims, context_dims, num_heads=8, head_dims=64, mlp_dims=1024,
                 repeat_pe=False, skip_first_pe=False, device=None, dtype=None, operations=None):
        super().__init__()
        # Resolve the ops namespace once and hand the resolved value to the
        # children so a None default works all the way down.
        ops = operations if operations is not None else nn
        self.repeat_pe = repeat_pe
        self.skip_first_pe = skip_first_pe

        def ln(width):
            return ops.LayerNorm(width, eps=1e-6, device=device, dtype=dtype)

        attn_dim = num_heads * head_dims
        attn_kwargs = dict(embed_dims=attn_dim, num_heads=num_heads, device=device, dtype=dtype, operations=ops)

        if repeat_pe:
            self.ln_pe_1, self.ln_pe_2 = ln(token_dims), ln(context_dims)

        self.ln1 = ln(token_dims)
        self.self_attn = Attention(query_dims=token_dims, key_dims=token_dims, value_dims=token_dims, **attn_kwargs)

        self.ln2_1, self.ln2_2 = ln(token_dims), ln(context_dims)
        self.cross_attn = Attention(query_dims=token_dims, key_dims=context_dims, value_dims=context_dims, **attn_kwargs)

        self.ln3 = ln(token_dims)
        self.ffn = MLP(token_dims, mlp_dims, token_dims, num_layers=2, act_layer=nn.GELU, device=device, dtype=dtype, operations=ops)

    def forward(self, x, context, x_pe=None, context_pe=None, x_mask=None):
        """x: [B, N_tokens, C], context: [B, N_ctx, C], x_mask: [B, N_tokens] or None.

        Returns ``(x, context)``; ``context`` is passed through unchanged.
        A missing ``x_pe`` is treated as "no positional term on the tokens"
        in every branch, matching the check self-attention already performs.
        """
        use_context_pe = self.repeat_pe and context_pe is not None

        # LaPE-style PE re-norm per layer.
        if use_context_pe:
            if x_pe is not None:
                x_pe = self.ln_pe_1(x_pe)
            context_pe = self.ln_pe_2(context_pe)

        # Self-attn over tokens.
        tokens = self.ln1(x)
        if self.repeat_pe and not self.skip_first_pe and x_pe is not None:
            q = k = tokens + x_pe
        else:
            q = k = tokens
        v = tokens

        attn_mask = None
        if x_mask is not None:
            # Outer product: a pair may attend only when both tokens are valid.
            attn_mask = x_mask[:, :, None] @ x_mask[:, None, :]
            attn_mask.diagonal(dim1=1, dim2=2).fill_(1)  # avoid all-invalid rows -> nan
            attn_mask = attn_mask > 0
        x = x + self.self_attn(q, k, v, attn_mask=attn_mask)

        # Cross-attn: tokens attend to image context.
        q = self.ln2_1(x)
        ctx = self.ln2_2(context)
        if use_context_pe:
            if x_pe is not None:
                q = q + x_pe
            k = ctx + context_pe
        else:
            k = ctx
        v = ctx
        x = x + self.cross_attn(q, k, v)

        x = x + self.ffn(self.ln3(x))
        return x, context
|
||||
341
comfy/ldm/sam3d_body/utils.py
Normal file
341
comfy/ldm/sam3d_body/utils.py
Normal file
@ -0,0 +1,341 @@
|
||||
# The bbox/affine math (xyxy<->cs, get_warp_matrices) is the standard
|
||||
# top-down pose-estimation crop pipeline from MMPose (Apache 2.0):
|
||||
# https://github.com/open-mmlab/mmpose — same algorithm as UDP (CVPR 2020).
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Bbox + affine math
|
||||
# All `output_size` / image-shape tuples in this block are (H, W) to match
|
||||
# the torch.Size convention used everywhere else in the codebase.
|
||||
|
||||
def bbox_xyxy2cs(bbox, padding: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert xyxy boxes to (center, scale).

    The center is the box midpoint and the scale is (width, height)
    multiplied by ``padding``. A single 1-D box yields 1-D outputs; a
    batch of boxes yields (N, 2) outputs. Inputs are converted to float32.
    """
    boxes = torch.as_tensor(bbox, dtype=torch.float32)
    single = boxes.dim() == 1
    if single:
        boxes = boxes.unsqueeze(0)

    top_left = boxes[:, 0:2]      # (x1, y1)
    bottom_right = boxes[:, 2:4]  # (x2, y2)
    center = (top_left + bottom_right) * 0.5
    scale = (bottom_right - top_left) * padding

    if single:
        return center[0], scale[0]
    return center, scale
|
||||
|
||||
|
||||
def fix_aspect_ratio(bbox_scale, aspect_ratio: float) -> torch.Tensor:
    """Grow the short side of each (w, h) scale so that w / h == ``aspect_ratio``.

    Boxes wider than the target keep their width and get a taller height;
    all others keep their height and get a wider width. 1-D input returns
    1-D output; batched input returns (N, 2).
    """
    scale = torch.as_tensor(bbox_scale, dtype=torch.float32)
    single = scale.dim() == 1
    if single:
        scale = scale.unsqueeze(0)

    width = scale[:, 0:1]
    height = scale[:, 1:2]
    too_wide = width > height * aspect_ratio
    new_width = torch.where(too_wide, width, height * aspect_ratio)
    new_height = torch.where(too_wide, width / aspect_ratio, height)
    fixed = torch.cat([new_width, new_height], dim=1)

    if single:
        return fixed[0]
    return fixed
|
||||
|
||||
|
||||
def get_warp_matrices(centers, scales, output_size: Tuple[int, int]) -> torch.Tensor:
    """Build batched 2x3 affines mapping each (center, scale) box onto the output crop.

    ``output_size`` is (H_out, W_out). The scale factor is uniform in x and y
    and derived from the box width only (``W_out / scale_w``); the translation
    places the box center at the center of the output. 1-D inputs are treated
    as a batch of one. Returns a float32 tensor of shape (N, 2, 3).
    """
    ctr = torch.as_tensor(centers, dtype=torch.float32)
    scl = torch.as_tensor(scales, dtype=torch.float32)
    if ctr.dim() == 1:
        ctr = ctr.unsqueeze(0)
        scl = scl.unsqueeze(0)

    out_h = float(output_size[0])
    out_w = float(output_size[1])

    # Uniform zoom from source box width to output width, per box.
    zoom = out_w / scl[:, 0]  # (N,)
    shift_x = out_w * 0.5 - zoom * ctr[:, 0]
    shift_y = out_h * 0.5 - zoom * ctr[:, 1]

    mats = torch.zeros((ctr.shape[0], 2, 3), dtype=torch.float32)
    mats[:, 0, 0] = zoom
    mats[:, 1, 1] = zoom
    mats[:, 0, 2] = shift_x
    mats[:, 1, 2] = shift_y
    return mats  # (N, 2, 3)
|
||||
|
||||
|
||||
def warp_affine_batched(
    src_t: torch.Tensor,            # (N, C, H_src, W_src) float
    mats: torch.Tensor,             # (N, 2, 3) float
    output_size: Tuple[int, int]    # (H_out, W_out)
) -> torch.Tensor:
    """Warp N source images with N forward (src->dst) 2x3 affines in a single
    ``grid_sample`` call.

    The function accepts arbitrary affines, not only the scale+translate
    matrices from ``get_warp_matrices``. Each matrix is inverted because
    ``grid_sample`` needs the dst->src mapping. Sampling is bilinear with
    zero padding and ``align_corners=False``.
    """
    out_h, out_w = int(output_size[0]), int(output_size[1])
    batch, _, src_h, src_w = src_t.shape
    device = src_t.device

    # Promote to 3x3 homogeneous form, invert, keep the top two rows.
    forward = mats.to(device=device, dtype=torch.float32)
    last_row = torch.tensor([0.0, 0.0, 1.0], device=device).expand(batch, 1, 3)
    inverse = torch.linalg.inv(torch.cat([forward, last_row], dim=1))[:, :2, :]  # (N, 2, 3)

    # Homogeneous coordinates of the output pixel centers (x+0.5, y+0.5, 1).
    row_centers = torch.arange(out_h, dtype=torch.float32, device=device) + 0.5
    col_centers = torch.arange(out_w, dtype=torch.float32, device=device) + 0.5
    ys, xs = torch.meshgrid(row_centers, col_centers, indexing="ij")
    homo = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1)  # (H_out, W_out, 3)

    src_pos = torch.einsum("nkl,ijl->nijk", inverse, homo)  # (N, H_out, W_out, 2)

    # Rescale source pixel coords to the [-1, 1] range grid_sample expects
    # (align_corners=False convention).
    grid_x = src_pos[..., 0] / src_w * 2 - 1
    grid_y = src_pos[..., 1] / src_h * 2 - 1
    grid = torch.stack([grid_x, grid_y], dim=-1)

    return F.grid_sample(src_t, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
# Batch construction (one prediction over N person crops from a single image)
|
||||
|
||||
def prepare_batch(
    img,                                    # (H, W, 3) uint8 torch tensor or list of such tensors
    boxes,                                  # (N, 4) xyxy (numpy or torch)
    input_size: Tuple[int, int],            # (W, H) of the model crop
    bbox_padding: float = 1.25,             # xyxy->cs padding multiplier (1.25 body, 0.9 hand)
    aspect_ratio: float = 0.75,             # w/h of the crop (0.75 matches HMR2/Sapiens)
    masks=None,                             # optional per-person masks
    masks_score=None,                       # optional per-person mask scores
    cam_int=None,                           # optional camera intrinsics
) -> Dict:
    """Assemble the batch dict for one prediction over N person crops.

    All N crops (and masks, when given) are produced with one batched
    ``grid_sample`` call each. Every tensor in the returned dict carries a
    leading singleton dimension in front of the N crops. When ``img`` is a
    list it must hold one image per box; the original image size is read
    from the first entry.
    """
    multi_image = isinstance(img, list)
    if multi_image:
        assert len(img) == boxes.shape[0]
    reference_img = img[0] if multi_image else img
    height, width = reference_img.shape[:2]

    n = int(boxes.shape[0])
    assert n > 0, "prepare_batch needs at least one box"

    crop_w, crop_h = int(input_size[0]), int(input_size[1])
    crop_hw = (crop_h, crop_w)

    # Per-box bbox math. The aspect fix runs twice: first towards the upstream
    # bbox aspect, then towards the model crop's own W/H (a no-op when equal).
    centers, scales = bbox_xyxy2cs(boxes, padding=bbox_padding)
    for ratio in (aspect_ratio, crop_w / crop_h):
        scales = fix_aspect_ratio(scales, ratio)
    mats = get_warp_matrices(centers, scales, crop_hw)  # (N, 2, 3)

    # Source images as one contiguous (N, 3, H, W) float tensor in [0, 255].
    if multi_image:
        stacked = torch.stack(list(img), dim=0)
    else:
        stacked = img.unsqueeze(0).expand(n, -1, -1, -1)
    src_t = stacked.permute(0, 3, 1, 2).contiguous().float()

    # Float warp -> floor (matches the legacy uint8 round-trip) -> /255.
    warped_t = warp_affine_batched(src_t, mats, crop_hw)  # (N, 3, H_out, W_out)
    img_t = torch.floor(warped_t).clamp_(0.0, 255.0) / 255.0

    boxes_t = torch.as_tensor(boxes, dtype=torch.float32)

    # Masks: zeros when absent, otherwise warped through the same matrices.
    if masks is not None:
        # masks holds N items, each (H, W) or (H, W, 1).
        masks_t = torch.stack([torch.as_tensor(masks[i]) for i in range(n)], dim=0)
        if masks_t.dim() == 4 and masks_t.shape[-1] == 1:
            masks_t = masks_t[..., 0]
        masks_src_t = masks_t.float().unsqueeze(1)  # (N, 1, H, W)
        warped_masks = warp_affine_batched(masks_src_t, mats, crop_hw)
        mask_t = torch.floor(warped_masks.squeeze(1)).clamp_(0.0, 255.0)
        if masks_score is None:
            mask_score_t = torch.ones((n,), dtype=torch.float32)
        else:
            mask_score_t = torch.as_tensor([masks_score[i] for i in range(n)], dtype=torch.float32)
    else:
        mask_t = torch.zeros((n, crop_h, crop_w), dtype=torch.float32)
        mask_score_t = torch.zeros((n,), dtype=torch.float32)

    img_size_t = torch.tensor([crop_w, crop_h], dtype=torch.float32).expand(n, 2).contiguous()
    ori_img_size_t = torch.tensor([width, height], dtype=torch.float32).expand(n, 2).contiguous()

    batch = {
        "img": img_t.unsqueeze(0),                    # (1, N, 3, H_out, W_out)
        "img_size": img_size_t.unsqueeze(0),          # (1, N, 2)
        "ori_img_size": ori_img_size_t.unsqueeze(0),  # (1, N, 2)
        "bbox_center": centers.unsqueeze(0),          # (1, N, 2)
        "bbox_scale": scales.unsqueeze(0),            # (1, N, 2)
        "bbox": boxes_t.unsqueeze(0),                 # (1, N, 4)
        "affine_trans": mats.unsqueeze(0),            # (1, N, 2, 3)
        "mask": mask_t.unsqueeze(0).unsqueeze(2),     # (1, N, 1, H_out, W_out)
        "mask_score": mask_score_t.unsqueeze(0),      # (1, N)
        "person_valid": torch.ones((1, n), dtype=torch.float32),
    }

    if cam_int is None:
        # Default intrinsics: focal = sqrt(W^2 + H^2), principal point = image center.
        f = (height ** 2 + width ** 2) ** 0.5
        cam_int = torch.tensor(
            [[[f, 0, width / 2.0], [0, f, height / 2.0], [0, 0, 1]]],
        )
    batch["cam_int"] = cam_int.to(batch["img"])

    return batch
|
||||
|
||||
|
||||
# Geometry utils
|
||||
|
||||
def rot6d_to_rotmat(
    x: torch.Tensor  # (B, 6) batch of 6-D rotation representations.
) -> torch.Tensor:   # (B, 3, 3) rotation matrices.
    """Turn the 6D continuous rotation representation (Zhou et al., CVPR 2019)
    into 3x3 rotation matrices.

    The first three numbers and the last three numbers of each row are
    treated as two 3-vectors, orthonormalized Gram-Schmidt style, and
    completed with their cross product; the three vectors become the
    columns of the result.
    """
    pairs = x.reshape(-1, 2, 3)
    a1 = pairs[:, 0]
    a2 = pairs[:, 1]

    b1 = F.normalize(a1)
    # Remove the b1 component from a2 before normalizing.
    overlap = torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1)
    b2 = F.normalize(a2 - overlap * b1)
    b3 = torch.linalg.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)
|
||||
|
||||
|
||||
def perspective_projection(
    x: torch.Tensor,  # (B, N, 3) 3D points in camera coords.
    K: torch.Tensor   # (B, 3, 3) camera intrinsics.
) -> torch.Tensor:    # (B, N, 2) 2D image-plane projections.
    """Project camera-frame 3D points to the image plane with intrinsics K.

    Each point is divided by its own last coordinate, multiplied by K,
    and the first two components of the result are returned.
    """
    depth = x[:, :, -1].unsqueeze(-1)
    normalized = x / depth                                  # perspective divide
    projected = torch.einsum("bij,bkj->bki", K, normalized)  # apply intrinsics
    return projected[:, :, :2]
|
||||
|
||||
|
||||
# Rotation conversions, behavior mirrors the roma library (https://github.com/naver/roma)
|
||||
|
||||
def _axis_rotmat(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotation matrices around a single coordinate axis. Shape (..., 3, 3)."""
|
||||
cos = torch.cos(angle)
|
||||
sin = torch.sin(angle)
|
||||
one = torch.ones_like(angle)
|
||||
zero = torch.zeros_like(angle)
|
||||
if axis == "X":
|
||||
flat = (one, zero, zero,
|
||||
zero, cos, -sin,
|
||||
zero, sin, cos)
|
||||
elif axis == "Y":
|
||||
flat = (cos, zero, sin,
|
||||
zero, one, zero,
|
||||
-sin, zero, cos)
|
||||
elif axis == "Z":
|
||||
flat = (cos, -sin, zero,
|
||||
sin, cos, zero,
|
||||
zero, zero, one)
|
||||
else:
|
||||
raise ValueError(f"Invalid axis {axis!r}; expected X/Y/Z.")
|
||||
return torch.stack(flat, dim=-1).reshape(angle.shape + (3, 3))
|
||||
|
||||
|
||||
def euler_to_rotmat(convention: str, angles: torch.Tensor) -> torch.Tensor:
    """Convert Euler angles to rotation matrices using a case-keyed convention
    string (as in roma).

    ``angles[..., i]`` rotates about the i-th letter of ``convention``. An
    all-lowercase convention composes the per-axis matrices as R2 @ R1 @ R0;
    anything else composes them as R0 @ R1 @ R2.
    """
    axes = convention.upper()
    first, second, third = (_axis_rotmat(axes[i], angles[..., i]) for i in range(3))
    if convention.islower():
        return third @ second @ first
    return first @ second @ third
|
||||
|
||||
|
||||
def _index_from_letter(letter: str) -> int:
|
||||
return {"X": 0, "Y": 1, "Z": 2}[letter]
|
||||
|
||||
|
||||
def _angle_from_tan(
|
||||
axis: str,
|
||||
other_axis: str,
|
||||
data: torch.Tensor,
|
||||
horizontal: bool,
|
||||
tait_bryan: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Extract an outer Euler angle from a row/column of a rotation matrix.
|
||||
|
||||
Adapted from PyTorch3D's matrix_to_euler_angles helper.
|
||||
"""
|
||||
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
||||
if horizontal:
|
||||
i2, i1 = i1, i2
|
||||
even = (axis + other_axis) in ("XY", "YZ", "ZX")
|
||||
if horizontal == even:
|
||||
return torch.atan2(data[..., i1], data[..., i2])
|
||||
if tait_bryan:
|
||||
return torch.atan2(-data[..., i2], data[..., i1])
|
||||
return torch.atan2(data[..., i2], -data[..., i1])
|
||||
|
||||
|
||||
def _matrix_to_euler_intrinsic(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """Split a rotation matrix into three Euler angles for an uppercase
    three-letter convention.

    Adapted from PyTorch3D's matrix_to_euler_angles. The middle angle comes
    from ``asin`` when the first and last axes differ (Tait-Bryan) and from
    ``acos`` when they coincide; the outer two come from ``_angle_from_tan``.
    The angles are stacked on a new trailing dimension.
    """
    first_idx = _index_from_letter(convention[0])
    last_idx = _index_from_letter(convention[2])
    tait_bryan = first_idx != last_idx

    if not tait_bryan:
        middle = torch.acos(matrix[..., first_idx, first_idx])
    else:
        sign = -1.0 if (first_idx - last_idx) in (-1, 2) else 1.0
        middle = torch.asin(matrix[..., first_idx, last_idx] * sign)

    leading = _angle_from_tan(convention[0], convention[1], matrix[..., last_idx], False, tait_bryan)
    trailing = _angle_from_tan(convention[2], convention[1], matrix[..., first_idx, :], True, tait_bryan)
    return torch.stack((leading, middle, trailing), dim=-1)
|
||||
|
||||
|
||||
def rotmat_to_euler(convention: str, matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to Euler angles; the inverse of
    :func:`euler_to_rotmat`.

    An all-uppercase convention is decomposed directly. Any other convention
    is handled by decomposing with the upper-cased, reversed letter order
    and then flipping the resulting angles along the last dimension, since
    the lowercase composition multiplies the per-axis matrices in reverse.

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f"Expected (..., 3, 3) rotation matrix, got {tuple(matrix.shape)}.")

    if not convention.isupper():
        reversed_convention = convention.upper()[::-1]
        angles = _matrix_to_euler_intrinsic(matrix, reversed_convention)
        return angles.flip(-1)
    return _matrix_to_euler_intrinsic(matrix, convention)
|
||||
|
||||
|
||||
def unitquat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """Convert scalar-last unit quaternions (x, y, z, w) to rotation matrices.

    Mirrors roma.unitquat_to_rotmat. No normalization is performed here, so
    the input is expected to already have unit norm.

    Args:
        quat: (..., 4) unit quaternion.
    Returns:
        (..., 3, 3) rotation matrix.
    """
    x, y, z, w = quat.unbind(dim=-1)

    # Doubled components, then every pairwise product the matrix needs.
    tx, ty, tz = 2 * x, 2 * y, 2 * z
    twx, twy, twz = tx * w, ty * w, tz * w
    txx, txy, txz = tx * x, ty * x, tz * x
    tyy, tyz, tzz = ty * y, tz * y, tz * z
    one = torch.ones_like(w)

    row0 = torch.stack((one - (tyy + tzz), txy - twz, txz + twy), dim=-1)
    row1 = torch.stack((txy + twz, one - (txx + tzz), tyz - twx), dim=-1)
    row2 = torch.stack((txz - twy, tyz + twx, one - (txx + tyy)), dim=-1)
    return torch.stack((row0, row1, row2), dim=-2)
|
||||
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -558,31 +558,47 @@ def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -
|
||||
|
||||
class FaceLandmarker(nn.Module):
|
||||
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
|
||||
(128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes
|
||||
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
|
||||
(128², ≤2m), 'full' (192² FPN, ≤5m), or 'both' — which holds both detectors
|
||||
alongside shared mesh/blendshapes so detect_batch can run them and per-frame
|
||||
keep whichever found more faces (tie → short). State dict uses inner-module
|
||||
prefixes `detector.*` (short or single-variant), `detector_full.*` (only
|
||||
under 'both'), `mesh.*`, `blendshapes.*`; the outer FaceLandmarkerModel
|
||||
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
|
||||
"""
|
||||
|
||||
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
|
||||
super().__init__()
|
||||
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
|
||||
|
||||
self.detector_variant = detector_variant
|
||||
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
|
||||
if detector_variant == "both":
|
||||
self.detector = BlazeFace(device=device, dtype=dtype, operations=operations)
|
||||
self.detector_full = BlazeFaceFullRange(device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}[detector_variant]
|
||||
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
|
||||
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
|
||||
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
|
||||
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
|
||||
|
||||
def _detector(self, variant: str) -> nn.Module:
|
||||
if self.detector_variant == "both":
|
||||
return self.detector_full if variant == "full" else self.detector
|
||||
return self.detector
|
||||
|
||||
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
|
||||
score_thresh: float = _BF_MIN_SCORE,
|
||||
iou_thresh: float = 0.5):
|
||||
iou_thresh: float = 0.5,
|
||||
variant: Optional[str] = None):
|
||||
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
|
||||
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
|
||||
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.
|
||||
`variant` overrides per-call (required on 'both'-mode instances)."""
|
||||
if not images_rgb_uint8:
|
||||
return [], [], [], []
|
||||
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
|
||||
if variant is None:
|
||||
variant = "short" if self.detector_variant == "both" else self.detector_variant
|
||||
detector = self._detector(variant)
|
||||
device, dtype = detector.stem.weight.device, detector.stem.weight.dtype
|
||||
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
|
||||
if self.detector_variant == "full"
|
||||
if variant == "full"
|
||||
else (_BF_INPUT_SIZE, _decode_blazeface))
|
||||
|
||||
# Same-size frames: stack once and transfer once. Variable size falls back
|
||||
@ -598,7 +614,7 @@ class FaceLandmarker(nn.Module):
|
||||
det_crops = [w[0] for w in warps]
|
||||
sub_rects = [(w[1], w[2], w[3]) for w in warps]
|
||||
|
||||
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
|
||||
regs_b, cls_b = detector(torch.stack(det_crops, dim=0))
|
||||
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
|
||||
per_frame = []
|
||||
for b in range(len(images_rgb_uint8)):
|
||||
@ -607,15 +623,22 @@ class FaceLandmarker(nn.Module):
|
||||
return img_raws, sub_rects, sizes, per_frame
|
||||
|
||||
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
|
||||
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
|
||||
score_thresh: float = _BF_MIN_SCORE,
|
||||
variant: Optional[str] = None) -> List[List[dict]]:
|
||||
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
|
||||
list per image (empty if nothing detected). Face dict:
|
||||
bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1],
|
||||
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
|
||||
192-canonical (pre-transformation) units, presence float (raw logit).
|
||||
On 'both'-mode instances `variant=None` runs both detectors and keeps
|
||||
whichever found more faces per frame (tie → short).
|
||||
"""
|
||||
if variant is None and self.detector_variant == "both":
|
||||
s = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="short")
|
||||
f = self.detect_batch(images_rgb_uint8, num_faces=num_faces, score_thresh=score_thresh, variant="full")
|
||||
return [s[b] if len(s[b]) >= len(f[b]) else f[b] for b in range(len(images_rgb_uint8))]
|
||||
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
|
||||
images_rgb_uint8, score_thresh=score_thresh,
|
||||
images_rgb_uint8, score_thresh=score_thresh, variant=variant,
|
||||
)
|
||||
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
|
||||
for b, decoded in enumerate(per_frame_dets):
|
||||
|
||||
1014
comfy_extras/nodes_sam3d_body.py
Normal file
1014
comfy_extras/nodes_sam3d_body.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,6 @@
|
||||
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node."""
|
||||
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB
|
||||
node, plus pose-data exporters (BuildPoseGLB / SavePoseBVH) that accept either
|
||||
SAM3DBody Predict's MHR pose data or external-rig pose data from Kimodo."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
@ -15,6 +17,15 @@ import folder_paths
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
from comfy_extras.sam3d_body.export.bvh import build_bvh
|
||||
from comfy_extras.sam3d_body.export.glb_openpose import build_glb_openpose
|
||||
from comfy_extras.sam3d_body.export.glb_skeletal import build_glb_skeletal
|
||||
|
||||
|
||||
MHRPoseData = IO.Custom("MHR_POSE_DATA")
|
||||
KimodoPoseData = IO.Custom("KIMODO_POSE_DATA")
|
||||
SAM3DBodyModel = IO.Custom("SAM3D_BODY_MODEL")
|
||||
|
||||
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None):
|
||||
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
|
||||
@ -386,10 +397,437 @@ class SaveGLB(IO.ComfyNode):
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
def rainbow_tilt_inputs():
    """Build the two rainbow-shader tilt inputs reused by several schemas.

    A new list with new input objects is created on every call (Z tilt first,
    then X tilt), so each schema that splices the result in owns its inputs.
    """
    # Both angles share the same slider range and step.
    angle_range = dict(min=-90.0, max=90.0, step=0.5)
    tilt_z = IO.Float.Input(
        "rainbow_tilt_z",
        default=-35.0,
        tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
        **angle_range,
    )
    tilt_x = IO.Float.Input(
        "rainbow_tilt_x",
        default=0.0,
        tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
        **angle_range,
    )
    return [tilt_z, tilt_x]
|
||||
|
||||
|
||||
def _pose_glb_bone_smooth_window_input():
    """Rotation-smoothing input shared by the 'body_mesh' and 'bones_only' options."""
    return IO.Int.Input(
        "bone_smooth_window",
        default=0, min=0, max=51, step=2,
        tooltip=(
            "Gaussian window on per-bone rotation keyframes. 0 = off. "
            "7-15 helps cartwheels/spins where upstream Smooth misses spikes."
        ),
    )


def _pose_glb_bone_vis_color_input(tooltip):
    """'bone_vis_color' combo; only the tooltip differs between bone shapes."""
    return IO.Combo.Input(
        "bone_vis_color",
        options=["white", "rainbow_y"],
        default="rainbow_y",
        tooltip=tooltip,
    )


def _pose_glb_bone_vis_input(allow_off):
    """'bone_vis' selector; ``allow_off`` prepends the 'off' choice (body_mesh only)."""
    options = [
        IO.DynamicCombo.Option("octahedrons", [
            IO.Float.Input(
                "bone_vis_radius_m",
                default=0.02, min=0.005, max=0.5, step=0.005,
                tooltip="Radius in m (sphere radius / octahedron half-width).",
            ),
            _pose_glb_bone_vis_color_input(
                "Per-bone vertex colors (unlit material). "
                "'white' = none, 'rainbow_y' = head→toe jet."
            ),
        ]),
        IO.DynamicCombo.Option("sticks", [
            _pose_glb_bone_vis_color_input("Per-bone vertex colors (see octahedrons)."),
        ]),
    ]
    if allow_off:
        # Keep 'off' first so the option order matches the body_mesh schema.
        options.insert(0, IO.DynamicCombo.Option("off", []))
    return IO.DynamicCombo.Input(
        "bone_vis",
        options=options,
        tooltip=(
            "Bone vis shape, rigidly skinned to each joint. "
            "'octahedrons' = Blender-style directional bones (joint → "
            "primary child); 'sticks' = thin lines."
        ),
    )


def _pose_glb_rainbow_shader_option(name):
    """One rainbow shader choice: the shared tilt inputs plus the palette falloff."""
    return IO.DynamicCombo.Option(name, [
        *rainbow_tilt_inputs(),
        IO.Float.Input(
            "person_palette_falloff",
            default=0.6, min=0.1, max=1.0, step=0.05,
            tooltip="Per-person desaturation: track k gets (1 - falloff^k) pastel mix.",
        ),
    ])


def _pose_glb_shader_input():
    """'shader' selector nested under the 'body_mesh' option."""
    return IO.DynamicCombo.Input(
        "shader",
        options=[
            IO.DynamicCombo.Option("default", []),
            _pose_glb_rainbow_shader_option("rainbow"),
            _pose_glb_rainbow_shader_option("rainbow_face_normal"),
            _pose_glb_rainbow_shader_option("rainbow_face_semantic"),
        ],
        tooltip=(
            "Bake per-vertex colors matching the Render node's shaders "
            "(COLOR_0 + KHR_materials_unlit). 'default' = no colors."
        ),
    )


def _pose_glb_openpose_option():
    """Inputs shown when mesh_style is 'openpose'."""
    return IO.DynamicCombo.Option("openpose", [
        IO.Float.Input(
            "marker_radius_m", default=0.010, min=0.005, max=0.1, step=0.001,
            tooltip="Sphere radius in m.",
        ),
        IO.Float.Input(
            "stick_radius_m", default=0.008, min=0.002, max=0.05, step=0.001,
            tooltip="Limb half-width in m. Auto-clamped to bone_length x 0.1.",
        ),
        IO.Boolean.Input(
            "include_hands", default=False,
            tooltip=(
                "Append 21+21 OpenPose hands (wrist + 5 fingers x 4 joints, "
                "base→tip) sourced from pred_keypoints_3d."
            ),
        ),
        IO.Float.Input(
            "hand_marker_radius_m", default=0.005, min=0.001, max=0.1, step=0.001,
            tooltip="Hand sphere radius in m.",
        ),
        IO.Float.Input(
            "hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001,
            tooltip="Hand limb half-width in m.",
        ),
        IO.Combo.Input(
            "face_source",
            options=["off", "rig"],
            default="off",
            tooltip=(
                "'rig' adds ~30 face-contour landmarks sampled from pred_vertices "
                "at fixed head-mesh vertex IDs (brow/eyes/nose/mouth/jaw); needs "
                "canonical_colors on pose_data."
            ),
        ),
        IO.Float.Input(
            "face_marker_radius_m", default=0.0, min=0.0, max=0.05, step=0.0005,
            tooltip="Face dot radius. 0 = auto = 0.3 x marker_radius_m.",
        ),
    ])


def _pose_glb_scail_option():
    """Inputs shown when mesh_style is 'scail'."""
    return IO.DynamicCombo.Option("scail", [
        IO.Float.Input(
            "stick_radius_m", default=0.022, min=0.002, max=0.1, step=0.001,
            tooltip=(
                "Cylinder radius in m. Bones are open cylinders at constant "
                "radius; joint spheres (auto-sized to match) cap the open ends. "
                "SCAIL reference = 0.0215 m."
            ),
        ),
        IO.Float.Input(
            "marker_radius_m", default=0.0, min=0.0, max=0.1, step=0.001,
            tooltip="Joint sphere radius. 0 = auto = stick_radius_m (flush cap).",
        ),
        IO.Float.Input(
            "material_roughness", default=0.3, min=0.0, max=1.0, step=0.05,
            tooltip="PBR roughness. SCAIL ref = 0.3. 1 = matte; 0 = chrome.",
        ),
        IO.Boolean.Input(
            "include_hands", default=False,
            tooltip="Append 21+21 hand keypoints + capsule sticks per track.",
        ),
        IO.Float.Input(
            "hand_marker_radius_m", default=0.005, min=0.001, max=0.05, step=0.001,
            tooltip="Hand sphere radius in m.",
        ),
        IO.Float.Input(
            "hand_stick_radius_m", default=0.003, min=0.001, max=0.05, step=0.001,
            tooltip="Hand cylinder radius in m.",
        ),
    ])


class BuildPoseGLB(IO.ComfyNode):
    """Convert pose_data to an in-memory animated GLB"""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs and its single GLB output.

        ``mesh_style`` is a dynamic combo whose four options ('body_mesh',
        'bones_only', 'openpose', 'scail') each carry their own nested inputs;
        the nested values reach ``execute`` inside the ``mesh_style`` dict.
        """
        return IO.Schema(
            node_id="BuildPoseGLB",
            display_name="Build Pose GLB",
            description="Convert pose data to an animated GLB",
            category="3d",
            inputs=[
                IO.MultiType.Input(
                    "pose_data", types=[MHRPoseData, KimodoPoseData],
                    tooltip=(
                        "MHR pose data from SAM3DBody_Predict, or external-rig "
                        "pose data from Kimodo (`_skeleton_override`-augmented)."
                    ),
                ),
                SAM3DBodyModel.Input("sam3d_body_model", optional=True),
                IO.DynamicCombo.Input(
                    "mesh_style",
                    options=[
                        IO.DynamicCombo.Option("body_mesh", [
                            _pose_glb_bone_smooth_window_input(),
                            _pose_glb_bone_vis_input(allow_off=True),
                            _pose_glb_shader_input(),
                        ]),
                        IO.DynamicCombo.Option("bones_only", [
                            _pose_glb_bone_smooth_window_input(),
                            _pose_glb_bone_vis_input(allow_off=False),
                        ]),
                        _pose_glb_openpose_option(),
                        _pose_glb_scail_option(),
                    ],
                    tooltip=(
                        "'body_mesh' = real Armature (127 bones, skinning, TRS "
                        "keyframes, 72 face morphs; needs model). "
                        "'bones_only' = bone-shape primitives at each joint (preview armature). "
                        "'openpose' = OpenPose-18 3D skeleton from keypoints "
                        "(no model needed). 'scail' = SCAIL 3D capsule rig (open "
                        "cylinders capped flush by joint spheres)."
                    ),
                ),
                IO.Float.Input(
                    "fps", default=24.0, min=1.0, max=240.0, step=1.0,
                    tooltip="Animation frame rate.",
                ),
                IO.Combo.Input(
                    "camera_translation",
                    options=["off", "centered", "absolute"],
                    default="off",
                    tooltip=(
                        "Bake pred_cam_t into per-track root translation. "
                        "'off' = origin; 'centered' = delta from frame 0; "
                        "'absolute' = raw (Z is camera depth — usually meters away)."
                    ),
                ),
                IO.Int.Input(
                    "track_index", default=-1, min=-1, max=15,
                    tooltip="-1 = all tracks; ≥0 = single track.",
                ),
            ],
            outputs=[IO.File3DGLB.Output("glb")],
        )

    @classmethod
    def execute(cls, pose_data, mesh_style, sam3d_body_model=None, fps=24.0, camera_translation="off", track_index=-1) -> IO.NodeOutput:
        """Build the GLB for the selected ``mesh_style`` and wrap it as a File3D.

        Args:
            pose_data: pose data forwarded unchanged to the GLB builder.
            mesh_style: dict holding the chosen option under ``"mesh_style"``
                plus that option's nested inputs; a falsy value selects
                'body_mesh' with fallback settings.
            sam3d_body_model: needed by 'body_mesh'/'bones_only' unless
                ``pose_data`` is a dict containing ``"_skeleton_override"``.
            fps, camera_translation, track_index: coerced to float/str/int and
                forwarded to the builder.

        Raises:
            ValueError: when a skeletal mode has neither a model nor an
                override, or when the mesh_style key is unknown.
        """
        mesh_style = mesh_style or {"mesh_style": "body_mesh"}
        mode_key = mesh_style["mesh_style"]
        # `shader` is nested in body_mesh; absent for bones_only.
        shader_dict = mesh_style.get("shader") or {}
        shader_key = shader_dict.get("shader", "default")
        # Only the skeletal builder receives `common`; the openpose/scail calls
        # pass fps/camera_translation/track_index explicitly.
        common = dict(
            fps=float(fps),
            camera_translation=str(camera_translation),
            track_index=int(track_index),
            shader=str(shader_key),
            rainbow_tilt_x_deg=float(shader_dict.get("rainbow_tilt_x", 0.0)),
            rainbow_tilt_z_deg=float(shader_dict.get("rainbow_tilt_z", 0.0)),
            person_palette_falloff=float(shader_dict.get("person_palette_falloff", 0.6)),
        )
        if mode_key in ("body_mesh", "bones_only"):
            # External rigs (e.g. ComfyUI-Kimodo) supply pose_data["_skeleton_override"]
            # so the GLB writer reads rig/bind/skin from there instead of MHR.
            has_external_rig = isinstance(pose_data, dict) and ("_skeleton_override" in pose_data)
            if sam3d_body_model is None and not has_external_rig:
                raise ValueError(
                    f"BuildPoseGLB: '{mode_key}' mode needs the `sam3d_body_model` input OR a "
                    "`_skeleton_override` dict in pose_data. Connect the SAM3DBody model "
                    "or feed pose_data from a node that supplies the override (e.g. KimodoSample)."
                )
            default_shape = "off" if mode_key == "body_mesh" else "octahedrons"
            bone_vis_dict = mesh_style.get("bone_vis", {"bone_vis": default_shape})
            bone_vis = str(bone_vis_dict.get("bone_vis", default_shape))
            # NOTE(review): these fallbacks (0.04 / "white") differ from the schema
            # defaults (0.02 / "rainbow_y"); 0.04 is also what 'sticks' receives,
            # since that option has no radius input. Confirm this is intended.
            bone_vis_radius_m = float(bone_vis_dict.get("bone_vis_radius_m", 0.04))
            bone_vis_color = str(bone_vis_dict.get("bone_vis_color", "white"))
            glb_bytes = build_glb_skeletal(
                pose_data, sam3d_body_model,
                bone_smooth_window=int(mesh_style.get("bone_smooth_window", 0)),
                bone_vis=bone_vis,
                bone_vis_radius_m=bone_vis_radius_m,
                bone_vis_color=bone_vis_color,
                include_body_mesh=(mode_key == "body_mesh"),
                **common,
            )
        elif mode_key == "openpose":
            # Rig-independent: sourced from pred_keypoints_3d. face_source='rig'
            # additionally reads canonical_colors for head-mesh vertex IDs.
            glb_bytes = build_glb_openpose(
                pose_data,
                fps=float(fps),
                camera_translation=str(camera_translation),
                track_index=int(track_index),
                # NOTE(review): fallback 0.025 differs from the schema default 0.010.
                marker_radius_m=float(mesh_style.get("marker_radius_m", 0.025)),
                stick_radius_m=float(mesh_style.get("stick_radius_m", 0.008)),
                include_hands=bool(mesh_style.get("include_hands", False)),
                hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
                hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
                face_source=str(mesh_style.get("face_source", "off")),
                face_marker_radius_m=float(mesh_style.get("face_marker_radius_m", 0.0)),
                palette="openpose",
                shape="ellipsoid",
            )
        elif mode_key == "scail":
            # SCAIL rig: open cylinders capped flush by joint spheres (sphere
            # radius defaults to cylinder radius for a seamless silhouette).
            cap_stick_radius = float(mesh_style.get("stick_radius_m", 0.022))
            cap_marker_radius = float(mesh_style.get("marker_radius_m", 0.0))
            if cap_marker_radius <= 0.0:
                cap_marker_radius = cap_stick_radius
            glb_bytes = build_glb_openpose(
                pose_data,
                fps=float(fps),
                camera_translation=str(camera_translation),
                track_index=int(track_index),
                marker_radius_m=cap_marker_radius,
                stick_radius_m=cap_stick_radius,
                include_hands=bool(mesh_style.get("include_hands", False)),
                hand_marker_radius_m=float(mesh_style.get("hand_marker_radius_m", 0.005)),
                hand_stick_radius_m=float(mesh_style.get("hand_stick_radius_m", 0.003)),
                face_source="off",
                palette="scail",
                shape="capsule",
                smooth_shade=True,
                # SCAIL material: slightly glossy (0.3) + double-sided so the
                # inside of the open cylinders shades sensibly at grazing angles.
                material_roughness=float(mesh_style.get("material_roughness", 0.3)),
                material_double_sided=True,
            )
        else:
            raise ValueError(f"BuildPoseGLB: unknown mesh_style {mode_key!r}")

        return IO.NodeOutput(Types.File3D(BytesIO(glb_bytes), file_format="glb"))
|
||||
|
||||
|
||||
class SavePoseBVH(IO.ComfyNode):
    """Output node that writes one track of pose data to a ``.bvh`` file."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs; it is an output node with no outputs."""
        pose_input = IO.MultiType.Input(
            "pose_data", types=[MHRPoseData, KimodoPoseData],
            tooltip=(
                "MHR pose data from SAM3DBody_Predict, or external-rig "
                "pose data from Kimodo."
            ),
        )
        fps_input = IO.Float.Input(
            "fps", default=24.0, min=1.0, max=240.0, step=1.0,
            tooltip="Animation frame rate (BVH `Frame Time`).",
        )
        camera_input = IO.Combo.Input(
            "camera_translation",
            options=["off", "centered", "absolute"],
            default="off",
            tooltip=(
                "Bake pred_cam_t into the root's position channels. "
                "'off' = bind position; 'centered' = delta from frame 0; "
                "'absolute' = raw (Z is camera depth — usually meters away)."
            ),
        )
        units_input = IO.Combo.Input(
            "units",
            options=["cm", "m"],
            default="cm",
            tooltip="BVH OFFSET/position units. 'cm' is the mocap standard.",
        )
        track_input = IO.Int.Input(
            "track_index", default=0, min=0, max=15,
            tooltip="Track to export. BVH carries one skeleton; export multi-person clips one at a time.",
        )
        return IO.Schema(
            node_id="SavePoseBVH",
            description="Save pose data as BVH mocap file",
            display_name="Save Pose BVH",
            category="3d",
            is_output_node=True,
            inputs=[
                pose_input,
                SAM3DBodyModel.Input("sam3d_body_model"),
                IO.String.Input("filename_prefix", default="3d/ComfyUI"),
                fps_input,
                camera_input,
                units_input,
                track_input,
            ],
            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
            outputs=[],
        )

    @classmethod
    def execute(cls, pose_data, sam3d_body_model, filename_prefix="3d/ComfyUI",
                fps=24.0, camera_translation="off", units="cm",
                track_index=0) -> IO.NodeOutput:
        """Serialize the pose to BVH bytes, write them to disk, and report the file.

        The file is named ``<filename>_<counter>_.bvh`` (counter zero-padded to
        five digits) inside the folder resolved from ``filename_prefix`` under
        the output directory. The returned UI payload lists that file under
        the "3d" key.
        """
        payload = build_bvh(
            pose_data, sam3d_body_model,
            fps=float(fps),
            camera_translation=str(camera_translation),
            track_index=int(track_index),
            units=str(units),
        )

        output_root = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, _ = \
            folder_paths.get_save_image_path(filename_prefix, output_root)
        bvh_name = f"{filename}_{counter:05}_.bvh"
        with open(os.path.join(full_output_folder, bvh_name), "wb") as bvh_file:
            bvh_file.write(payload)

        saved_entry = {
            "filename": bvh_name,
            "subfolder": subfolder,
            "type": "output",
        }
        return IO.NodeOutput(ui={"3d": [saved_entry]})
|
||||
|
||||
|
||||
class Save3DExtension(ComfyExtension):
    """Extension that registers this module's 3D save/export nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return every node class this module exposes.

        The stale ``return [SaveGLB]`` that preceded this list made it
        unreachable, leaving BuildPoseGLB and SavePoseBVH unregistered.
        """
        return [SaveGLB, BuildPoseGLB, SavePoseBVH]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Save3DExtension:
|
||||
|
||||
@ -2,11 +2,10 @@ import torch
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import numpy as np
|
||||
import math
|
||||
import colorsys
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.pose.keypoint_draw import KeypointDraw
|
||||
from comfy_extras.nodes_lotus import LotusConditioning
|
||||
|
||||
|
||||
@ -73,281 +72,6 @@ def _to_openpose_frames(all_keypoints, all_scores, height, width):
|
||||
return frames
|
||||
|
||||
|
||||
class KeypointDraw:
    """
    Pose keypoint drawing class that supports both numpy and cv2 backends.

    ``self.draw`` is the ``cv2`` module when OpenCV can be imported; otherwise
    it is the instance itself, whose static methods (``circle``, ``line``,
    ``fillConvexPoly``, ``ellipse2Poly``) stand in for the cv2 calls used here.
    """
    def __init__(self):
        try:
            import cv2
            self.draw = cv2
        except ImportError:
            self.draw = self

        # Hand connections (same for both hands), relative to the hand's first keypoint.
        self.hand_edges = [
            [0, 1], [1, 2], [2, 3], [3, 4],  # thumb
            [0, 5], [5, 6], [6, 7], [7, 8],  # index
            [0, 9], [9, 10], [10, 11], [11, 12],  # middle
            [0, 13], [13, 14], [14, 15], [15, 16],  # ring
            [0, 17], [17, 18], [18, 19], [19, 20],  # pinky
        ]

        # Body connections matching DWPose limbSeq. Entries are stored 1-indexed
        # and shifted to 0-indexed keypoint slots at draw time.
        self.body_limbSeq = [
            [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
            [1, 16], [16, 18]
        ]

        # Colors matching DWPose
        self.colors = [
            [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
            [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
            [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
        ]

    @staticmethod
    def circle(canvas_np, center, radius, color, **kwargs):
        """Draw a filled circle using NumPy vectorized operations.

        Extra keyword arguments (e.g. ``thickness``) are accepted and ignored.
        """
        cx, cy = center
        h, w = canvas_np.shape[:2]

        radius_int = int(np.ceil(radius))

        # Clip the bounding box to the canvas; bail out when nothing overlaps.
        y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
        x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)

        if y_max <= y_min or x_max <= x_min:
            return

        y, x = np.ogrid[y_min:y_max, x_min:x_max]
        mask = (x - cx)**2 + (y - cy)**2 <= radius**2
        canvas_np[y_min:y_max, x_min:x_max][mask] = color

    @staticmethod
    def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
        """Draw line using Bresenham's algorithm with NumPy operations.

        With ``thickness > 1`` a disc of radius ``thickness / 2 + 0.5`` is
        stamped at every traced pixel; otherwise only the in-bounds pixels are
        written. The trace ends when the cursor equals ``pt2`` exactly.
        """
        x0, y0, x1, y1 = *pt1, *pt2
        h, w = canvas_np.shape[:2]
        dx, dy = abs(x1 - x0), abs(y1 - y0)
        sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
        err, x, y, line_points = dx - dy, x0, y0, []

        while True:
            line_points.append((x, y))
            if x == x1 and y == y1:
                break
            e2 = 2 * err
            if e2 > -dy:
                err, x = err - dy, x + sx
            if e2 < dx:
                err, y = err + dx, y + sy

        if thickness > 1:
            radius = (thickness / 2.0) + 0.5
            radius_int = int(np.ceil(radius))
            for px, py in line_points:
                y_min, y_max = max(0, py - radius_int), min(h, py + radius_int + 1)
                x_min, x_max = max(0, px - radius_int), min(w, px + radius_int + 1)
                if y_max > y_min and x_max > x_min:
                    yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
                    canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
        else:
            line_points = np.array(line_points)
            valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
            if (valid_points := line_points[valid]).size:
                canvas_np[valid_points[:, 1], valid_points[:, 0]] = color

    @staticmethod
    def fillConvexPoly(canvas_np, pts, color, **kwargs):
        """Fill polygon using vectorized scanline algorithm.

        Each non-horizontal edge toggles (XOR) the pixels at or to the right
        of it on the rows it spans; fewer than three points is a no-op.
        """
        if len(pts) < 3:
            return
        pts = np.array(pts, dtype=np.int32)
        h, w = canvas_np.shape[:2]
        y_min, y_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1)
        x_min, x_max = max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
        if y_max <= y_min or x_max <= x_min:
            return
        yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
        mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)

        for i in range(len(pts)):
            p1, p2 = pts[i], pts[(i + 1) % len(pts)]
            y1, y2 = p1[1], p2[1]
            if y1 == y2:
                continue
            if y1 > y2:
                # Orient the edge so it runs from the smaller to the larger y.
                p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
            if not (edge_mask := (yy >= y1) & (yy < y2)).any():
                continue
            mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))

        canvas_np[y_min:y_max, x_min:x_max][mask] = color

    @staticmethod
    def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
        """Python implementation of cv2.ellipse2Poly.

        Returns a list of ``[x, y]`` integer points with consecutive duplicates
        removed; if at most one distinct point remains, the center is returned
        twice instead.
        """
        axes = (axes[0] + 0.5, axes[1] + 0.5)  # to better match cv2 output
        angle = angle % 360
        if arc_start > arc_end:
            arc_start, arc_end = arc_end, arc_start
        while arc_start < 0:
            arc_start, arc_end = arc_start + 360, arc_end + 360
        while arc_end > 360:
            arc_end, arc_start = arc_end - 360, arc_start - 360
        if arc_end - arc_start > 360:
            arc_start, arc_end = 0, 360

        angle_rad = math.radians(angle)
        alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
        pts = []
        for i in range(arc_start, arc_end + delta, delta):
            theta_rad = math.radians(min(i, arc_end))
            x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
            pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])

        unique_pts, prev_pt = [], (float('inf'), float('inf'))
        for pt in pts:
            if (pt_tuple := tuple(pt)) != prev_pt:
                unique_pts.append(pt)
                prev_pt = pt_tuple

        return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]

    @staticmethod
    def _score_ok(scores, idx, threshold):
        """Return False only when a score exists for ``idx`` and is below ``threshold``."""
        if scores is None or idx >= len(scores):
            return True
        return not (scores[idx] < threshold)

    def _draw_points(self, canvas, keypoints, scores, threshold, indices, radius, color_for, min_coord=None):
        """Draw one filled dot per index that passes the score and bounds checks.

        ``color_for`` maps a keypoint index to its color. When ``min_coord`` is
        given, points whose x or y is not greater than it are skipped.
        """
        H, W = canvas.shape[:2]
        for i in indices:
            if not self._score_ok(scores, i, threshold):
                continue
            x, y = int(keypoints[i][0]), int(keypoints[i][1])
            if min_coord is not None and not (x > min_coord and y > min_coord):
                continue
            if 0 <= x < W and 0 <= y < H:
                self.draw.circle(canvas, (x, y), radius, color_for(i), thickness=-1)

    def _draw_body_limbs(self, canvas, keypoints, scores, threshold, stick_width):
        """Draw every body limb as a filled ellipse between its two keypoints."""
        for i, limb in enumerate(self.body_limbSeq):
            # Convert from 1-indexed to 0-indexed
            idx1, idx2 = limb[0] - 1, limb[1] - 1

            if idx1 >= 18 or idx2 >= 18:
                continue
            if not (self._score_ok(scores, idx1, threshold) and self._score_ok(scores, idx2, threshold)):
                continue

            xs = [keypoints[idx1][0], keypoints[idx2][0]]
            ys = [keypoints[idx1][1], keypoints[idx2][1]]
            mid_x, mid_y = (xs[0] + xs[1]) / 2, (ys[0] + ys[1]) / 2
            length = math.sqrt((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2)

            # Limbs shorter than one pixel are not drawn.
            if length < 1:
                continue

            angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
            polygon = self.draw.ellipse2Poly((int(mid_x), int(mid_y)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
            self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)])

    def _draw_hand(self, canvas, keypoints, scores, threshold, offset):
        """Draw one hand: rainbow edges first, then a dot on each of its 21 keypoints.

        ``offset`` is the index of the hand's first keypoint in ``keypoints``.
        """
        H, W = canvas.shape[:2]
        eps = 0.01
        edge_count = float(len(self.hand_edges))
        for ie, edge in enumerate(self.hand_edges):
            idx1, idx2 = offset + edge[0], offset + edge[1]
            if not (self._score_ok(scores, idx1, threshold) and self._score_ok(scores, idx2, threshold)):
                continue

            x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
            x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])

            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
                    # HSV to RGB conversion for rainbow colors
                    r, g, b = colorsys.hsv_to_rgb(ie / edge_count, 1.0, 1.0)
                    color = (int(r * 255), int(g * 255), int(b * 255))
                    self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)

        self._draw_points(canvas, keypoints, scores, threshold, range(offset, offset + 21),
                          4, lambda _i: (0, 0, 255), min_coord=eps)

    def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
                                 draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
        """
        Draw wholebody keypoints (134 keypoints after processing) in DWPose style.

        Expected keypoint format (after neck insertion and remapping):
        - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
        - Foot: 18-23 (6 keypoints)
        - Face: 24-91 (68 landmarks)
        - Right hand: 92-112 (21 keypoints)
        - Left hand: 113-133 (21 keypoints)

        Each group is drawn only when its flag is set and ``keypoints`` is long
        enough to contain it. Draw order is body limbs, body dots, feet, right
        hand, left hand, face. A keypoint with no score entry (``scores`` is
        None or shorter than its index) is treated as passing the threshold.

        Args:
            canvas: The canvas to draw on (numpy array, H x W x C)
            keypoints: Array of keypoint coordinates
            scores: Optional confidence scores for each keypoint
            threshold: Minimum confidence threshold for drawing keypoints
            draw_body, draw_feet, draw_face, draw_hands: per-group switches
            stick_width: second semi-axis passed to ellipse2Poly for body limbs
            face_point_size: radius of the face dots

        Returns:
            canvas: The canvas with keypoints drawn
        """
        H, W, C = canvas.shape  # requires a 3-dimensional canvas
        count = len(keypoints)

        def body_color(i):
            return self.colors[i % len(self.colors)]

        # Body limbs, then body keypoints
        if draw_body and count >= 18:
            self._draw_body_limbs(canvas, keypoints, scores, threshold, stick_width)
            self._draw_points(canvas, keypoints, scores, threshold, range(18), 4, body_color)

        # Foot keypoints (18-23, 6 keypoints)
        if draw_feet and count >= 24:
            self._draw_points(canvas, keypoints, scores, threshold, range(18, 24), 4, body_color)

        # Right hand (92-112)
        if draw_hands and count >= 113:
            self._draw_hand(canvas, keypoints, scores, threshold, 92)

        # Left hand (113-133)
        if draw_hands and count >= 134:
            self._draw_hand(canvas, keypoints, scores, threshold, 113)

        # Face keypoints (24-91) - white dots only, no lines
        if draw_face and count >= 92:
            self._draw_points(canvas, keypoints, scores, threshold, range(24, 92),
                              face_point_size, lambda _i: (255, 255, 255), min_coord=0.01)

        return canvas
|
||||
|
||||
class SDPoseDrawKeypoints(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
||||
348
comfy_extras/pose/keypoint_draw.py
Normal file
348
comfy_extras/pose/keypoint_draw.py
Normal file
@ -0,0 +1,348 @@
|
||||
"""Pose keypoint drawing primitives shared across pose nodes.
|
||||
|
||||
`KeypointDraw` exposes a cv2-or-numpy backend so callers can use one drawing
|
||||
API regardless of whether OpenCV is installed:
|
||||
|
||||
kd = KeypointDraw()
|
||||
kd.draw.circle(canvas, (x, y), radius, color, thickness=-1)
|
||||
kd.draw.line(canvas, p1, p2, color, thickness=4)
|
||||
kd.draw.fillConvexPoly(canvas, polygon, color)
|
||||
kd.draw.ellipse2Poly(center, axes, angle, 0, 360, 1)
|
||||
|
||||
It also carries DWPose's body/hand topology + color tables, used by:
|
||||
- comfy_extras.nodes_sdpose (SDPose pose drawing)
|
||||
- comfy_extras.pose.export.openpose_2d (SAM 3D Body 2D pose viz)
|
||||
- comfy_extras.pose.export.glb_shared (SAM 3D Body GLB tables)
|
||||
"""
|
||||
|
||||
import colorsys
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KeypointDraw:
    """Pose keypoint drawing class that supports both numpy and cv2 backends.

    ``self.draw`` is the ``cv2`` module when OpenCV can be imported; otherwise
    it is the instance itself, whose static methods (``circle``, ``line``,
    ``fillConvexPoly``, ``ellipse2Poly``) provide NumPy fallbacks with
    cv2-like call signatures.
    """

    def __init__(self):
        try:
            import cv2
            self.draw = cv2
        except ImportError:
            self.draw = self

        # Hand connections (same for both hands), as offsets into a 21-point hand.
        self.hand_edges = [
            [0, 1], [1, 2], [2, 3], [3, 4],  # thumb
            [0, 5], [5, 6], [6, 7], [7, 8],  # index
            [0, 9], [9, 10], [10, 11], [11, 12],  # middle
            [0, 13], [13, 14], [14, 15], [15, 16],  # ring
            [0, 17], [17, 18], [18, 19], [19, 20],  # pinky
        ]

        # Body connections - matching DWPose limbSeq. Stored 1-indexed; they are
        # converted to 0-indexed at draw time in draw_wholebody_keypoints.
        self.body_limbSeq = [
            [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
            [1, 16], [16, 18]
        ]

        # Colors matching DWPose
        self.colors = [
            [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
            [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
            [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
        ]

    @staticmethod
    def circle(canvas_np, center, radius, color, **kwargs):
        """Draw a filled circle using NumPy vectorized operations.

        Extra keyword arguments (e.g. ``thickness``) are accepted for cv2
        call compatibility and ignored: the disc is always filled. The part
        of the disc outside the canvas is clipped.
        """
        cx, cy = center
        h, w = canvas_np.shape[:2]

        radius_int = int(np.ceil(radius))

        # Bounding box of the disc, clipped to the canvas.
        y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
        x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)

        if y_max <= y_min or x_max <= x_min:
            return

        y, x = np.ogrid[y_min:y_max, x_min:x_max]
        mask = (x - cx)**2 + (y - cy)**2 <= radius**2
        canvas_np[y_min:y_max, x_min:x_max][mask] = color

    @staticmethod
    def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
        """Draw line using Bresenham's algorithm with NumPy operations.

        For ``thickness > 1`` a disc of radius ``thickness / 2 + 0.5`` is
        stamped at every Bresenham point; otherwise the in-bounds points are
        written directly. Endpoints are expected to be integer pixel coords.
        """
        x0, y0, x1, y1 = *pt1, *pt2
        h, w = canvas_np.shape[:2]
        dx, dy = abs(x1 - x0), abs(y1 - y0)
        sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
        err, x, y, line_points = dx - dy, x0, y0, []

        while True:
            line_points.append((x, y))
            if x == x1 and y == y1:
                break
            e2 = 2 * err
            if e2 > -dy:
                err, x = err - dy, x + sx
            if e2 < dx:
                err, y = err + dx, y + sy

        if thickness > 1:
            radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
            for px, py in line_points:
                y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
                if y_max > y_min and x_max > x_min:
                    yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
                    canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
        else:
            line_points = np.array(line_points)
            valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
            if (valid_points := line_points[valid]).size:
                canvas_np[valid_points[:, 1], valid_points[:, 0]] = color

    @staticmethod
    def fillConvexPoly(canvas_np, pts, color, **kwargs):
        """Fill polygon using vectorized scanline algorithm.

        Polygons with fewer than three points are ignored. Each non-horizontal
        edge toggles (XOR) the pixels to its right within its row span, which
        leaves exactly the interior set.
        """
        if len(pts) < 3:
            return
        pts = np.array(pts, dtype=np.int32)
        h, w = canvas_np.shape[:2]
        y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
        if y_max <= y_min or x_max <= x_min:
            return
        yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
        mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)

        for i in range(len(pts)):
            p1, p2 = pts[i], pts[(i + 1) % len(pts)]
            y1, y2 = p1[1], p2[1]
            if y1 == y2:
                continue
            if y1 > y2:
                p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
            if not (edge_mask := (yy >= y1) & (yy < y2)).any():
                continue
            mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))

        canvas_np[y_min:y_max, x_min:x_max][mask] = color

    @staticmethod
    def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
        """Python implementation of cv2.ellipse2Poly.

        Returns a list of ``[x, y]`` integer points approximating the arc, with
        consecutive duplicates removed. If fewer than two distinct points
        remain, the center is returned twice.
        """
        axes = (axes[0] + 0.5, axes[1] + 0.5)  # to better match cv2 output
        angle = angle % 360
        if arc_start > arc_end:
            arc_start, arc_end = arc_end, arc_start
        while arc_start < 0:
            arc_start, arc_end = arc_start + 360, arc_end + 360
        while arc_end > 360:
            arc_end, arc_start = arc_end - 360, arc_start - 360
        if arc_end - arc_start > 360:
            arc_start, arc_end = 0, 360

        angle_rad = math.radians(angle)
        alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
        pts = []
        for i in range(arc_start, arc_end + delta, delta):
            theta_rad = math.radians(min(i, arc_end))
            x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
            pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])

        unique_pts, prev_pt = [], (float('inf'), float('inf'))
        for pt in pts:
            if (pt_tuple := tuple(pt)) != prev_pt:
                unique_pts.append(pt)
                prev_pt = pt_tuple

        return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]]] * 2

    def _draw_hand(self, canvas, keypoints, scores, threshold, base, W, H,
                   hand_stick_width, hand_marker_radius, dot_colors):
        """Draw one 21-keypoint hand whose first keypoint sits at index `base`.

        Edges are drawn first with a rainbow (HSV) color per edge, then the
        dots on top. Keypoints whose integer coordinates are not strictly
        positive are treated as missing and skipped.
        """
        eps = 0.01
        edge_count = float(len(self.hand_edges))
        for ie, (off1, off2) in enumerate(self.hand_edges):
            idx1, idx2 = base + off1, base + off2
            if scores is not None:
                if scores[idx1] < threshold or scores[idx2] < threshold:
                    continue

            x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
            x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])

            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
                    # HSV to RGB conversion for rainbow colors
                    r, g, b = colorsys.hsv_to_rgb(ie / edge_count, 1.0, 1.0)
                    color = (int(r * 255), int(g * 255), int(b * 255))
                    self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=hand_stick_width)

        for offset in range(21):
            i = base + offset
            # A keypoint with no score entry is drawn rather than rejected.
            if scores is not None and i < len(scores) and scores[i] < threshold:
                continue
            x, y = int(keypoints[i][0]), int(keypoints[i][1])
            if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
                self.draw.circle(canvas, (x, y), hand_marker_radius, dot_colors[offset], thickness=-1)

    def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
                                 draw_body=True, draw_feet=True, draw_face=True, draw_hands=True,
                                 stick_width=4, face_point_size=3,
                                 marker_radius=4, hand_stick_width=2, hand_marker_radius=4,
                                 limb_alpha=1.0, hand_dot_color=(0, 0, 255)):
        """
        Draw wholebody keypoints (134 keypoints after processing) in DWPose style.

        Expected keypoint format (after neck insertion and remapping):
        - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
        - Foot: 18-23 (6 keypoints)
        - Face: 24-91 (68 landmarks)
        - Right hand: 92-112 (21 keypoints)
        - Left hand: 113-133 (21 keypoints)

        Each group is drawn only when `keypoints` is long enough to contain it.
        Draw order is body limbs, body dots, feet, right hand, left hand, face.

        Args:
            canvas: The canvas to draw on (numpy array, mutated in place)
            keypoints: Array of keypoint coordinates, indexed as [i][0]=x, [i][1]=y
            scores: Optional confidence scores for each keypoint
            threshold: Minimum confidence threshold for drawing keypoints
            draw_body / draw_feet / draw_face / draw_hands: Per-group toggles.
            stick_width: Body limb half-width (passed to ellipse2Poly).
            face_point_size: Radius of the white face dots.
            marker_radius: Radius of body/foot dots. Defaults to 4 (DWPose).
            hand_stick_width: Thickness of hand limb lines. Defaults to 2.
            hand_marker_radius: Radius of hand dots. Defaults to 4.
            limb_alpha: Body-limb alpha blend (0..1). 1.0 = opaque fill (default),
                <1.0 enables per-limb bbox-clipped alpha overlay (DWPose semantics
                where overlapping limbs darken).
            hand_dot_color: Either an (R, G, B) tuple/list of ints for solid-color
                hand dots (default (0, 0, 255), DWPose blue), or a (21, 3) array
                for per-keypoint hand-dot colors (OpenPose-style rainbow palette).

        Returns:
            canvas: The canvas with keypoints drawn
        """
        H, W = canvas.shape[:2]

        # Normalize hand_dot_color to a (21, 3) int array.
        hdc_arr = np.asarray(hand_dot_color, dtype=int)
        if hdc_arr.ndim == 1:
            hdc_arr = np.tile(hdc_arr.reshape(1, 3), (21, 1))
        hand_dot_tuples = [tuple(int(c) for c in hdc_arr[i]) for i in range(21)]

        do_alpha = float(limb_alpha) < 1.0

        # Draw body limbs
        if draw_body and len(keypoints) >= 18:
            for i, limb in enumerate(self.body_limbSeq):
                # Convert from 1-indexed to 0-indexed
                idx1, idx2 = limb[0] - 1, limb[1] - 1

                if idx1 >= 18 or idx2 >= 18:
                    continue

                if scores is not None:
                    if scores[idx1] < threshold or scores[idx2] < threshold:
                        continue

                # DWPose naming: Y holds the x coordinates, X the y coordinates.
                Y = [keypoints[idx1][0], keypoints[idx2][0]]
                X = [keypoints[idx1][1], keypoints[idx2][1]]
                mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
                length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)

                if length < 1:
                    continue

                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))

                polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)

                color = self.colors[i % len(self.colors)]
                if do_alpha:
                    _fill_poly_alpha(canvas, polygon, color, limb_alpha, self.draw)
                else:
                    self.draw.fillConvexPoly(canvas, polygon, color)

        # Draw body keypoints
        if draw_body and len(keypoints) >= 18:
            for i in range(18):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)

        # Draw foot keypoints (18-23, 6 keypoints)
        if draw_feet and len(keypoints) >= 24:
            for i in range(18, 24):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), marker_radius, self.colors[i % len(self.colors)], thickness=-1)

        # Draw right hand (92-112)
        if draw_hands and len(keypoints) >= 113:
            self._draw_hand(canvas, keypoints, scores, threshold, 92, W, H,
                            hand_stick_width, hand_marker_radius, hand_dot_tuples)

        # Draw left hand (113-133)
        if draw_hands and len(keypoints) >= 134:
            self._draw_hand(canvas, keypoints, scores, threshold, 113, W, H,
                            hand_stick_width, hand_marker_radius, hand_dot_tuples)

        # Draw face keypoints (24-91) - white dots only, no lines
        if draw_face and len(keypoints) >= 92:
            eps = 0.01
            for i in range(24, 92):
                if scores is not None and scores[i] < threshold:
                    continue
                x, y = int(keypoints[i][0]), int(keypoints[i][1])
                if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
                    self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)

        return canvas
|
||||
|
||||
|
||||
def _fill_poly_alpha(canvas, polygon, color, alpha, draw_backend):
    """Bbox-clipped alpha-blended fillConvexPoly. `canvas` is mutated in-place.

    DWPose semantics: each limb blends with `alpha` independently so overlapping
    limbs darken further. Operates on the polygon's bbox to avoid copying the
    whole canvas per limb.

    Args:
        canvas: Image array indexed as [y, x, ...]; the result is written back
            as uint8.
        polygon: Sequence of (x, y) integer points. An empty polygon is a no-op.
        color: Fill color handed to `draw_backend.fillConvexPoly`.
        alpha: Blend weight of the filled copy (1 - alpha for the original).
        draw_backend: Object exposing `fillConvexPoly(img, pts, color)`.
    """
    H, W = canvas.shape[:2]
    poly_arr = np.asarray(polygon, dtype=np.int32)
    if poly_arr.size == 0:
        # Nothing to draw; indexing columns of an empty array would raise.
        return
    x0 = max(0, int(poly_arr[:, 0].min()))
    xN = min(W, int(poly_arr[:, 0].max()) + 1)
    y0 = max(0, int(poly_arr[:, 1].min()))
    yN = min(H, int(poly_arr[:, 1].max()) + 1)
    if xN <= x0 or yN <= y0:
        return
    # Shift the polygon into ROI-local coordinates and fill an opaque copy.
    local_poly = poly_arr - np.array([x0, y0], dtype=poly_arr.dtype)
    roi = canvas[y0:yN, x0:xN].copy()
    draw_backend.fillConvexPoly(roi, local_poly, color)
    a = float(alpha)
    canvas[y0:yN, x0:xN] = np.clip(
        roi.astype(np.float32) * a + canvas[y0:yN, x0:xN].astype(np.float32) * (1.0 - a),
        0, 255,
    ).astype(np.uint8)
|
||||
207
comfy_extras/sam3d_body/export/bvh.py
Normal file
207
comfy_extras/sam3d_body/export/bvh.py
Normal file
@ -0,0 +1,207 @@
|
||||
"""BVH export for SAM 3D Body pose_data.
|
||||
|
||||
BVH stores explicit bone OFFSETs per joint, so any standard importer
|
||||
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
|
||||
directly — no heuristic guessing as needed for glTF. We skip the rig's joint 0
|
||||
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
|
||||
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
|
||||
intrinsic Z-X-Y Euler degrees.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .glb_shared import (
|
||||
bind_skel_state,
|
||||
bone_locals_from_globals,
|
||||
collect_tracks,
|
||||
extract_rig_static,
|
||||
global_skel_state_from_pose_data,
|
||||
quat_sign_fix_per_joint,
|
||||
unflip,
|
||||
)
|
||||
|
||||
|
||||
def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
    """Convert xyzw quaternions to intrinsic Z-X-Y Euler angles in degrees.

    The result has shape (..., 3) ordered (z, x, y), matching BVH's
    `CHANNELS Zrotation Xrotation Yrotation`, and dtype float32.

    With R = Rz(c)·Rx(a)·Ry(b), the rotation-matrix entries used are:
        M[2][1] = 2(yz + xw)       =  sin(a)
        M[0][1] = 2(xy - zw)       = -cos(a) sin(c)
        M[1][1] = 1 - 2(x² + z²)   =  cos(a) cos(c)
        M[2][0] = 2(xz - yw)       = -cos(a) sin(b)
        M[2][2] = 1 - 2(x² + y²)   =  cos(a) cos(b)
    """
    q64 = np.asarray(quat, dtype=np.float64)
    qx = q64[..., 0]
    qy = q64[..., 1]
    qz = q64[..., 2]
    qw = q64[..., 3]

    # Clip guards arcsin against tiny overshoots from non-unit input.
    sin_a = np.clip(2.0 * (qy * qz + qx * qw), -1.0, 1.0)
    m01 = 2.0 * (qx * qy - qz * qw)
    m11 = 1.0 - 2.0 * (qx * qx + qz * qz)
    m20 = 2.0 * (qx * qz - qy * qw)
    m22 = 1.0 - 2.0 * (qx * qx + qy * qy)

    rot_x = np.arcsin(sin_a)
    rot_z = np.arctan2(-m01, m11)
    rot_y = np.arctan2(-m20, m22)

    zxy_deg = np.stack(
        [np.rad2deg(rot_z), np.rad2deg(rot_x), np.rad2deg(rot_y)],
        axis=-1,
    )
    return zxy_deg.astype(np.float32)
|
||||
|
||||
|
||||
def _find_bvh_root(parents: np.ndarray) -> int:
    """Pick the joint to use as the BVH ROOT.

    A "world anchor" is a joint whose parent index is out of range or equal
    to itself. The lowest-indexed child of the first world anchor is returned,
    so the static origin→body stick bone gets left out. If that anchor has no
    children the anchor itself is returned; if there is no anchor at all the
    result is 0.
    """
    joint_count = parents.shape[0]

    def _has_real_parent(joint: int) -> bool:
        parent = int(parents[joint])
        return 0 <= parent < joint_count and parent != joint

    anchor = next(
        (j for j in range(joint_count) if not _has_real_parent(j)),
        None,
    )
    if anchor is None:
        return 0

    # First joint (in index order) that hangs directly off the anchor.
    for joint in range(joint_count):
        if _has_real_parent(joint) and int(parents[joint]) == anchor:
            return joint
    return anchor
|
||||
|
||||
|
||||
def _build_children_map(parents: np.ndarray) -> List[List[int]]:
    """Invert a parent-index array into per-joint child lists.

    Entry `p` of the result lists, in ascending order, every joint whose
    parent is `p`. Joints whose parent is out of range or equal to themselves
    are treated as roots and appear in no list.
    """
    joint_count = parents.shape[0]
    kids: List[List[int]] = [[] for _ in range(joint_count)]
    for child in range(joint_count):
        parent = int(parents[child])
        if parent == child or not (0 <= parent < joint_count):
            continue
        kids[parent].append(child)
    return kids
|
||||
|
||||
|
||||
def build_bvh(
    pose_data: Dict[str, Any],
    model: Any,
    *,
    fps: float = 24.0,
    camera_translation: str = "off",
    track_index: int = -1,
    units: str = "cm",
) -> bytes:
    """Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.

    Only the first track returned by `collect_tracks` is exported.

    Args:
        pose_data: Pose dictionary; `pose_data["frames"][t][k]["pred_cam_t"]`
            is read when camera translation is enabled.
        model: Rig source passed to `extract_rig_static` / `bind_skel_state`.
        fps: Playback rate; must be positive. Written as `Frame Time: 1/fps`.
        camera_translation: "absolute" adds the per-frame camera translation
            to the root position, "centered" adds it relative to the first
            exported frame; any other value keeps the root at its bind position.
        track_index: Forwarded to `collect_tracks`.
        units: "cm" (default, standard mocap convention) or "m". Affects the
            OFFSET and root-position values; rotations are independent of units.

    Raises:
        ValueError: on an unknown `units`, a non-positive `fps`, no valid
            tracks, or a track with zero frames.
    """
    if units not in ("cm", "m"):
        raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
    fps_value = float(fps)
    if not fps_value > 0.0:
        # Reject up front instead of dividing by zero (or emitting a negative
        # Frame Time) after all the pose work is done.
        raise ValueError(f"build_bvh: fps must be positive, got {fps!r}")
    unit_scale = 100.0 if units == "cm" else 1.0

    rig_static = extract_rig_static(model)
    NJ = int(rig_static["num_joints"])
    parents = rig_static["parents"]
    frames = pose_data["frames"]

    tracks = collect_tracks(pose_data, track_index)
    if not tracks:
        raise ValueError("build_bvh: no valid tracks in pose_data")
    person_k, frame_indices = tracks[0]
    n_frames = len(frame_indices)
    if n_frames == 0:
        raise ValueError("build_bvh: track has zero frames")

    body_root = _find_bvh_root(parents)
    children_map = _build_children_map(parents)

    # Bone OFFSETs come from MHR's translation_offsets (joint position
    # relative to parent in parent's local-bind frame). For the BVH root,
    # we use its bind world position so the skeleton sits at the right
    # spot when imported.
    bind_global = bind_skel_state(model)  # (NJ, 8) cm
    bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01  # (NJ, 3) m
    offset_m = rig_static["joint_translation_offsets"].astype(np.float64) * 0.01

    # DFS order rooted at body_root — matches per-frame channel order.
    bvh_order: List[int] = []

    def _visit(j: int) -> None:
        bvh_order.append(j)
        for c in children_map[j]:
            _visit(c)

    _visit(body_root)

    # Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
    # rather than re-running rig.forward, then derive locals with body_root
    # treated as the hierarchy root in BVH-space.
    rig_global_m = global_skel_state_from_pose_data(
        pose_data, frame_indices, person_k, NJ,
    )
    rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
    bvh_parents = parents.copy()
    bvh_parents[body_root] = -1
    bone_local = bone_locals_from_globals(rig_global_m, bvh_parents)
    # Second pass catches sign discontinuities from the parent-inverse composition.
    bone_local[..., 3:7] = quat_sign_fix_per_joint(bone_local[..., 3:7])

    eulers_deg = _quat_to_zxy_euler_deg(bone_local[..., 3:7])

    if camera_translation in ("absolute", "centered"):
        cam_t = np.stack([
            unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
            for t in frame_indices
        ], axis=0).astype(np.float64)
        if camera_translation == "centered":
            cam_t = cam_t - cam_t[0:1]
        root_pos_m = bind_pos_m[body_root][None, :] + cam_t
    else:
        root_pos_m = np.tile(bind_pos_m[body_root], (n_frames, 1))

    lines: List[str] = ["HIERARCHY"]

    def _emit_joint(j: int, depth: int, is_root: bool) -> None:
        ind = " " * depth
        keyword = "ROOT" if is_root else "JOINT"
        name = "Hips" if is_root else f"joint_{j:03d}"
        lines.append(f"{ind}{keyword} {name}")
        lines.append(ind + "{")
        o = (bind_pos_m[j] if is_root else offset_m[j]) * unit_scale
        lines.append(f"{ind} OFFSET {o[0]:.6f} {o[1]:.6f} {o[2]:.6f}")
        if is_root:
            lines.append(ind + " CHANNELS 6 Xposition Yposition Zposition "
                         "Zrotation Xrotation Yrotation")
        else:
            lines.append(ind + " CHANNELS 3 Zrotation Xrotation Yrotation")
        kids = children_map[j]
        if kids:
            for c in kids:
                _emit_joint(c, depth + 1, is_root=False)
        else:
            # End Site (standard BVH spec) gives leaf bones a drawable length.
            lines.append(ind + " End Site")
            lines.append(ind + " {")
            tip = (offset_m[j] * unit_scale) * 0.3
            tip_norm = float(np.linalg.norm(tip))
            # Threshold is 0.005 m (5 mm) expressed in output units; shorter
            # tips fall back to a fixed 5 cm stub along +Y.
            if tip_norm < 0.5 * unit_scale * 0.01:
                tip = np.array([0.0, 0.05 * unit_scale, 0.0], dtype=np.float64)
            lines.append(f"{ind} OFFSET {tip[0]:.6f} {tip[1]:.6f} {tip[2]:.6f}")
            lines.append(ind + " }")
        lines.append(ind + "}")

    _emit_joint(body_root, 0, is_root=True)

    lines.append("MOTION")
    lines.append(f"Frames: {n_frames}")
    lines.append(f"Frame Time: {1.0 / fps_value:.6f}")

    # Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
    # frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
    # formatting beats Python f-strings by ~10× on long clips.
    non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
    motion = np.concatenate([
        root_pos_m * unit_scale,  # (N, 3)
        eulers_deg[:, body_root].astype(np.float64),  # (N, 3)
        eulers_deg[:, non_root_idx, :].reshape(n_frames, -1),  # (N, 3*(len(bvh_order)-1))
    ], axis=1)
    buf = io.StringIO()
    np.savetxt(buf, motion, fmt="%.6f")
    lines.append(buf.getvalue().rstrip("\n"))

    return ("\n".join(lines) + "\n").encode("utf-8")
|
||||
403
comfy_extras/sam3d_body/export/capsules.py
Normal file
403
comfy_extras/sam3d_body/export/capsules.py
Normal file
@ -0,0 +1,403 @@
|
||||
"""3D capsule rendering for OpenPose-style skeletons — SCAIL-Pose-equivalent
|
||||
torch ray-marching SDF renderer adapted to SAM3DBody pose_data.
|
||||
|
||||
Each limb is drawn as a true 3D capsule (cylinder + hemispherical caps),
|
||||
projected through the per-person camera (`pred_cam_t` + `focal_length` +
|
||||
image_size) so closer limbs appear thicker/brighter — the SCAIL-Pose
|
||||
visual style. Self-contained: no dependency on the SCAIL-Pose package.
|
||||
|
||||
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
from .glb_shared import (
|
||||
OPENPOSE_18_PAIRS,
|
||||
OPENPOSE18_TO_MHR70,
|
||||
OPENPOSE_RAINBOW_18,
|
||||
SCAIL_LIMB_COLORS_17,
|
||||
OPENPOSE_HAND_PAIRS,
|
||||
OPENPOSE_HAND21_TO_MHR70_R,
|
||||
OPENPOSE_HAND21_TO_MHR70_L,
|
||||
OPENPOSE_HAND_COLORS_21,
|
||||
)
|
||||
|
||||
|
||||
def _limb_palette_rgb01(palette: str) -> np.ndarray:
    """Return the per-limb RGB color table for the OpenPose-18 body limbs.

    `"scail"` selects `SCAIL_LIMB_COLORS_17`; any other value selects the
    first `len(OPENPOSE_18_PAIRS)` rows of `OPENPOSE_RAINBOW_18`. The table is
    returned as float32 and is not rescaled here (presumably it is already in
    [0, 1] — confirm against glb_shared).
    """
    if palette == "scail":
        table = SCAIL_LIMB_COLORS_17
    else:
        limb_count = len(OPENPOSE_18_PAIRS)
        table = OPENPOSE_RAINBOW_18[:limb_count]
    return table.astype(np.float32)
|
||||
|
||||
|
||||
def _build_specs_from_pose(
    persons: List[Dict[str, Any]],
    *,
    include_hands: bool,
    palette: str,
    person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten one frame's body (and optionally hand) limbs into capsule specs.

    Returns `(starts, ends, colors_rgba)` as float32 arrays of shape (M, 3),
    (M, 3) and (M, 4). Endpoints are `pred_keypoints_3d + pred_cam_t`. A limb
    is dropped when either endpoint is non-finite or has z <= 0. Persons
    missing either key, or whose keypoints are not (>=70, 3), are skipped.

    `person_brightness_falloff` is clamped to [0, 1]; the person at index `k`
    (k > 0) has each limb color mixed toward white by `1 - falloff**k`, while
    index 0 keeps the palette color untouched.
    NOTE(review): with the default of 0.0 that weight is 1.0, i.e. every
    person after the first is fully white — confirm this matches the callers'
    intent.
    """
    seg_starts: List[np.ndarray] = []
    seg_ends: List[np.ndarray] = []
    seg_colors: List[np.ndarray] = []

    body_palette = _limb_palette_rgb01(palette)
    hand_palette = OPENPOSE_HAND_COLORS_21.astype(np.float32)
    is_scail = palette == "scail"

    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))

    def _drawable(p: np.ndarray, q: np.ndarray) -> bool:
        # Both endpoints finite and strictly in front of the camera.
        if not (np.all(np.isfinite(p)) and np.all(np.isfinite(q))):
            return False
        return not (p[2] <= 0 or q[2] <= 0)

    def _emit(p: np.ndarray, q: np.ndarray, rgb: np.ndarray, whiten: float) -> None:
        if not (whiten <= 0):
            rgb = rgb * (1.0 - whiten) + whiten
        seg_starts.append(p)
        seg_ends.append(q)
        seg_colors.append(
            np.array([rgb[0], rgb[1], rgb[2], 1.0], dtype=np.float32)
        )

    for person_idx, person in enumerate(persons):
        raw_kp = person.get("pred_keypoints_3d")
        raw_cam_t = person.get("pred_cam_t")
        if raw_kp is None or raw_cam_t is None:
            continue
        kp = np.asarray(raw_kp, dtype=np.float32)
        if kp.ndim != 2 or kp.shape[1] != 3 or kp.shape[0] < 70:
            continue
        cam_offset = np.asarray(raw_cam_t, dtype=np.float32).reshape(3)
        # pred_keypoints_3d is camera frame (Y-down post-flip); add cam_t to
        # place the subject in front of the camera.
        kp_cam = kp + cam_offset[None, :]

        whiten = 0.0 if person_idx == 0 else (1.0 - falloff ** person_idx)

        # SCAIL drops the 4 face bones (13..16: nose↔eyes, eyes→ears) — in the
        # reference, NLF leaves those COCO slots at zero so its `sum==0` skip
        # silently culls them. The grey neck limb (12) blends spine direction
        # (mid-hip → neck, stable) with the neck→nose direction at 60/40 so
        # the stub tracks head pose lightly without flapping around like full
        # nose direction does.
        limb_count = 13 if is_scail else len(OPENPOSE_18_PAIRS)
        body_kp = kp_cam[OPENPOSE18_TO_MHR70]

        spine_dir = None
        if is_scail:
            hip_center = 0.5 * (body_kp[8] + body_kp[11])  # 8=RHip, 11=LHip
            spine_vec = body_kp[1] - hip_center  # 1=Neck
            spine_len = float(np.linalg.norm(spine_vec))
            if np.all(np.isfinite(spine_vec)) and spine_len > 1e-6:
                spine_dir = spine_vec / spine_len

        for limb_i, (a, b) in enumerate(OPENPOSE_18_PAIRS[:limb_count]):
            p0, p1 = body_kp[a], body_kp[b]
            if not _drawable(p0, p1):
                continue
            if is_scail and limb_i == 12:
                head_vec = p1 - p0
                head_len = float(np.linalg.norm(head_vec))
                have_head = head_len > 1e-6
                if have_head and spine_dir is not None:
                    head_dir = head_vec / head_len
                    blend = 0.6 * spine_dir + 0.4 * head_dir
                    blend = blend / max(float(np.linalg.norm(blend)), 1e-6)
                    p1 = p0 + blend * (head_len * 0.5)
                elif have_head:
                    p1 = p0 + head_vec * 0.5
                elif spine_dir is not None:
                    p1 = p0 + spine_dir * (spine_len * 0.3)
            _emit(p0, p1, body_palette[limb_i], whiten)

        if not include_hands:
            continue

        right_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_R]
        left_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_L]
        for a, b in OPENPOSE_HAND_PAIRS:
            # Right hand first, then left, for each hand limb.
            for hand_kp in (right_kp, left_kp):
                p0, p1 = hand_kp[a], hand_kp[b]
                if not _drawable(p0, p1):
                    continue
                _emit(p0, p1, hand_palette[(a + b) % len(hand_palette)], whiten)

    if not seg_starts:
        return (
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 4), dtype=np.float32),
        )
    return (
        np.stack(seg_starts).astype(np.float32),
        np.stack(seg_ends).astype(np.float32),
        np.stack(seg_colors).astype(np.float32),
    )
|
||||
|
||||
|
||||
def _ray_capsule_t(
    ray_dirs: torch.Tensor,  # (K, 3) unit rays from camera origin
    starts: torch.Tensor,  # (M, 3)
    ends: torch.Tensor,  # (M, 3)
    ba_norm: torch.Tensor,  # (M, 3) unit axis (A → B)
    ba_len: torch.Tensor,  # (M,) segment length
    radius: float,
) -> torch.Tensor:
    """Closed-form intersection of origin rays with a set of capsules.

    Returns a (K, M) tensor holding, per ray and capsule, the ray parameter t
    of the nearest accepted hit, or +inf where the ray misses. Each capsule is
    treated as the union of a cylinder body and one hemisphere per endpoint;
    each piece is solved as a quadratic and only its near root is considered.
    """
    inf = float("inf")
    radius_sq = float(radius) * float(radius)

    # Shared projections: ray direction and start point onto the capsule axis.
    axis_dot = ray_dirs @ ba_norm.transpose(0, 1)  # (K, M) — d·n
    start_axial = (starts * ba_norm).sum(-1)  # (M,) — A·n

    def _keep_hits(t: torch.Tensor, hit: torch.Tensor) -> torch.Tensor:
        return torch.where(hit, t, torch.full_like(t, inf))

    def _cap_t(centers: torch.Tensor, on_outer_side) -> torch.Tensor:
        # Near root of |t·d - C|² = r², kept only on the hemisphere that
        # faces away from the segment.
        d_dot_c = ray_dirs @ centers.transpose(0, 1)  # (K, M)
        c_sq = (centers * centers).sum(-1)  # (M,)
        disc = d_dot_c * d_dot_c - (c_sq - radius_sq)
        t_near = d_dot_c - torch.sqrt(disc.clamp(min=0.0))
        axial = t_near * axis_dot - start_axial
        hit = (disc >= 0) & (t_near > 1e-6) & on_outer_side(axial)
        return _keep_hits(t_near, hit)

    # Cylinder body: project onto plane ⊥ n and solve |P_⊥(t)|² = r².
    d_dot_a = ray_dirs @ starts.transpose(0, 1)  # (K, M) — d·A
    a_sq = (starts * starts).sum(-1)  # (M,) — |A|²
    qa = 1.0 - axis_dot * axis_dot  # (K, M)
    qb = -2.0 * (d_dot_a - axis_dot * start_axial)  # (K, M)
    qc = a_sq - start_axial * start_axial - radius_sq  # (M,)
    disc_cyl = qb * qb - 4.0 * qa * qc
    # Clamp keeps the division finite for rays (nearly) parallel to the axis;
    # those rays are rejected below via `qa > 1e-7`.
    t_body = (-qb - torch.sqrt(disc_cyl.clamp(min=0.0))) / (2.0 * qa.clamp(min=1e-9))
    axial_body = t_body * axis_dot - start_axial  # axial projection from A
    body_hit = (disc_cyl >= 0) & (qa > 1e-7) & (t_body > 1e-6) & \
        (axial_body >= 0.0) & (axial_body <= ba_len)
    t_body = _keep_hits(t_body, body_hit)

    # Sphere at A: axial projection ≤ 0. Sphere at B: axial projection ≥ ba_len.
    t_cap_a = _cap_t(starts, lambda s: s <= 0.0)
    t_cap_b = _cap_t(ends, lambda s: s >= ba_len)

    return torch.minimum(torch.minimum(t_body, t_cap_a), t_cap_b)
|
||||
|
||||
|
||||
def _render_capsules_torch(
|
||||
starts: torch.Tensor,
|
||||
ends: torch.Tensor,
|
||||
colors: torch.Tensor,
|
||||
H: int, W: int,
|
||||
fx: float, fy: float, cx: float, cy: float,
|
||||
radius: float,
|
||||
background_rgb: Optional[torch.Tensor],
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Analytic ray-capsule renderer for a union of capsules. Camera at
|
||||
origin looking down +Z; pixels in y-down screen coords."""
|
||||
M = int(starts.shape[0])
|
||||
if M == 0:
|
||||
if background_rgb is not None:
|
||||
return background_rgb.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
|
||||
return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
|
||||
|
||||
yy, xx = torch.meshgrid(
|
||||
torch.arange(H, device=device, dtype=torch.float32),
|
||||
torch.arange(W, device=device, dtype=torch.float32),
|
||||
indexing="ij",
|
||||
)
|
||||
u = (xx - cx) / fx
|
||||
v = (yy - cy) / fy
|
||||
z = torch.ones_like(u)
|
||||
ray_dirs = torch.stack([u, v, z], dim=-1)
|
||||
ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True)
|
||||
flat_dirs = ray_dirs.view(-1, 3)
|
||||
N = flat_dirs.shape[0]
|
||||
|
||||
ba = ends - starts
|
||||
ba_len = torch.linalg.norm(ba, dim=1).clamp(min=1e-6)
|
||||
ba_norm = ba / ba_len.unsqueeze(1)
|
||||
|
||||
z_min = float(min(starts[:, 2].min().item(), ends[:, 2].min().item()))
|
||||
z_near = max(0.05, z_min - radius)
|
||||
|
||||
# Union of per-capsule screen-space bboxes. Pixels outside this mask
|
||||
# provably can't hit any capsule, so the analytic intersection only runs
|
||||
# on the relevant subset of the canvas (~5-15% at 1080p for typical poses).
|
||||
sz = starts[:, 2].clamp(min=z_near)
|
||||
ez = ends[:, 2].clamp(min=z_near)
|
||||
sx_p = starts[:, 0] * fx / sz + cx
|
||||
sy_p = starts[:, 1] * fy / sz + cy
|
||||
ex_p = ends[:, 0] * fx / ez + cx
|
||||
ey_p = ends[:, 1] * fy / ez + cy
|
||||
# Projected radius using the closer endpoint — conservative bbox.
|
||||
r_pix = radius * fx / torch.minimum(sz, ez)
|
||||
pad = 2.0
|
||||
xmin_t = (torch.minimum(sx_p, ex_p) - r_pix - pad).floor().long().clamp(min=0, max=W)
|
||||
xmax_t = (torch.maximum(sx_p, ex_p) + r_pix + pad).ceil().long().clamp(min=0, max=W)
|
||||
ymin_t = (torch.minimum(sy_p, ey_p) - r_pix - pad).floor().long().clamp(min=0, max=H)
|
||||
ymax_t = (torch.maximum(sy_p, ey_p) + r_pix + pad).ceil().long().clamp(min=0, max=H)
|
||||
# One stack→tolist sync amortizes the GPU→CPU read over all M bboxes.
|
||||
bboxes_cpu = torch.stack([xmin_t, ymin_t, xmax_t, ymax_t], dim=1).tolist()
|
||||
coarse_mask = torch.zeros(H, W, dtype=torch.bool, device=device)
|
||||
for xmin_i, ymin_i, xmax_i, ymax_i in bboxes_cpu:
|
||||
if xmax_i > xmin_i and ymax_i > ymin_i:
|
||||
coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True
|
||||
|
||||
# Analytic ray-capsule intersection. One pass over the masked pixels —
|
||||
# the previous SDF marcher took up to MAX_STEPS=96 iterations per pixel
|
||||
# plus 6 SDF evaluations per hit pixel for finite-difference normals.
|
||||
INF = float("inf")
|
||||
flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
|
||||
flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
|
||||
active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
|
||||
if active_idx.numel() > 0:
|
||||
# Cap per-chunk (K, M) tensors to ~4M elements to keep peak memory
|
||||
# manageable when both K (image pixels) and M (capsules) are large.
|
||||
chunk_max = max(1, int(4_000_000 / max(M, 1)))
|
||||
for i0 in range(0, active_idx.numel(), chunk_max):
|
||||
sub = active_idx[i0 : i0 + chunk_max]
|
||||
t_KM = _ray_capsule_t(
|
||||
flat_dirs[sub], starts, ends, ba_norm, ba_len, radius,
|
||||
)
|
||||
t_min, m_idx = t_KM.min(dim=1)
|
||||
hit = t_min < INF
|
||||
if hit.any():
|
||||
winners = sub[hit]
|
||||
flat_t[winners] = t_min[hit]
|
||||
flat_m_idx[winners] = m_idx[hit]
|
||||
|
||||
# Shade: analytic normal (P - closest_point_on_segment) → soft Lambert × depth fade.
|
||||
out = torch.zeros((N, 3), dtype=torch.float32, device=device)
|
||||
if background_rgb is not None:
|
||||
out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
|
||||
hit_idx = torch.nonzero(flat_m_idx >= 0, as_tuple=False).squeeze(1)
|
||||
if hit_idx.numel() > 0:
|
||||
rd = flat_dirs[hit_idx]
|
||||
t_h = flat_t[hit_idx]
|
||||
m_h = flat_m_idx[hit_idx]
|
||||
p_hit = rd * t_h.unsqueeze(-1)
|
||||
|
||||
A_h = starts[m_h]
|
||||
n_h = ba_norm[m_h]
|
||||
L_h = ba_len[m_h]
|
||||
proj = ((p_hit - A_h) * n_h).sum(-1).clamp(min=0.0)
|
||||
proj = torch.minimum(proj, L_h)
|
||||
C_h = A_h + proj.unsqueeze(-1) * n_h
|
||||
normals = p_hit - C_h
|
||||
normals = normals / normals.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
|
||||
col = colors[m_h, :3]
|
||||
# SCAIL shading (render_torch.py:290-331). Light from camera (+Z toward
|
||||
# subject); diffuse term `N·-L` simplifies to `-N.z`. Specular uses the
|
||||
# proper Blinn-Phong half-vector `(view + (-L))` — using `diff` as a
|
||||
# shortcut would lock the highlight to image center.
|
||||
diff = torch.clamp(-(normals[:, 2]), min=0.0)
|
||||
diffuse = 0.45 + 0.55 * diff
|
||||
|
||||
view_dir = -rd
|
||||
half_dir = view_dir.clone()
|
||||
half_dir[:, 2] -= 1.0
|
||||
half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)
|
||||
|
||||
# SCAIL's reference depth fade uses mm-scale constants (`z_max + 6000`)
|
||||
# that translate to almost no fade in our meter units — `depth_factor`
|
||||
# stays ~0.85-1.0. Matching that with a mild ramp.
|
||||
z_vals = p_hit[:, 2]
|
||||
z_lo, z_hi = float(z_vals.min().item()), float(z_vals.max().item())
|
||||
if z_hi - z_lo > 1e-6:
|
||||
fade = 1.0 - (z_vals - z_lo) / (z_hi - z_lo)
|
||||
depth_factor = 0.85 + 0.15 * fade
|
||||
else:
|
||||
depth_factor = torch.ones_like(z_vals)
|
||||
|
||||
base = col * diffuse.unsqueeze(-1) * depth_factor.unsqueeze(-1)
|
||||
highlight = (0.5 * spec * depth_factor).unsqueeze(-1)
|
||||
out[hit_idx] = base + highlight
|
||||
|
||||
return out.view(H, W, 3).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def render_pose_data_capsules(
    pose_data: Dict[str, Any],
    *,
    frame_idx: int,
    W: int,
    H: int,
    background: Optional[torch.Tensor] = None,
    composite: str = "over",
    radius_m: float = 0.025,
    include_hands: bool = False,
    palette: str = "scail",
    person_brightness_falloff: float = 0.0,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Draw one frame of pose_data as shaded 3D capsules.

    The capsules are projected through a pinhole camera whose focal length
    is read from the first person in the frame that carries a
    `focal_length` entry (pixels); when nobody has one, `min(H, W)` is used.
    The principal point is the image center. Returns (H, W, 3) fp32 in [0, 1].

    `composite='over'` paints over `background` (black if None), resizing
    the background bilinearly when its first two dims are not (H, W).
    Any other `composite` value (e.g. 'mesh_only') uses a black canvas.

    `radius_m` is in METERS (matching `pred_keypoints_3d` / `pred_cam_t`).
    `include_hands`, `palette` and `person_brightness_falloff` are forwarded
    unchanged to `_build_specs_from_pose`.
    """
    frame_persons = pose_data["frames"][frame_idx]
    if device is None:
        device = comfy.model_management.get_torch_device()

    # SAM3DBody shares one camera across the clip, so the first person that
    # reports a focal length decides it for everyone.
    first_focal = next(
        (f for f in (p.get("focal_length") for p in frame_persons) if f is not None),
        None,
    )
    if first_focal is None:
        focal_px = float(min(H, W))
    else:
        focal_px = float(np.asarray(first_focal, dtype=np.float32).reshape(-1)[0])
    center_x = W * 0.5
    center_y = H * 0.5

    seg_starts, seg_ends, seg_colors = _build_specs_from_pose(
        frame_persons, include_hands=include_hands, palette=palette,
        person_brightness_falloff=person_brightness_falloff,
    )

    canvas: Optional[torch.Tensor] = None
    if composite == "over" and background is not None:
        canvas = background.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
        if canvas.shape[:2] != (H, W):
            # interpolate() wants NCHW, so go HWC -> 1CHW and back again.
            as_nchw = canvas.permute(2, 0, 1).unsqueeze(0)
            as_nchw = torch.nn.functional.interpolate(
                as_nchw, size=(H, W), mode="bilinear", align_corners=False,
            )
            canvas = as_nchw.squeeze(0).permute(1, 2, 0).contiguous()

    # Nothing to draw: hand back the bare canvas.
    if seg_starts.shape[0] == 0:
        if canvas is None:
            return torch.zeros(H, W, 3, dtype=torch.float32, device=device)
        return canvas

    def _as_f32(arr):
        """Move a numpy array onto `device` as float32."""
        return torch.from_numpy(arr).to(device=device, dtype=torch.float32)

    return _render_capsules_torch(
        _as_f32(seg_starts), _as_f32(seg_ends), _as_f32(seg_colors),
        H=H, W=W, fx=focal_px, fy=focal_px, cx=center_x, cy=center_y,
        radius=float(radius_m),
        background_rgb=canvas,
        device=device,
    )
|
||||
1138
comfy_extras/sam3d_body/export/glb_openpose.py
Normal file
1138
comfy_extras/sam3d_body/export/glb_openpose.py
Normal file
File diff suppressed because it is too large
Load Diff
1138
comfy_extras/sam3d_body/export/glb_shared.py
Normal file
1138
comfy_extras/sam3d_body/export/glb_shared.py
Normal file
File diff suppressed because it is too large
Load Diff
578
comfy_extras/sam3d_body/export/glb_skeletal.py
Normal file
578
comfy_extras/sam3d_body/export/glb_skeletal.py
Normal file
@ -0,0 +1,578 @@
|
||||
"""GLB export — skeletal (real armature) mode.
|
||||
|
||||
Rebuilds an Armature with the MHR 127-bone rig:
|
||||
- per-frame local TRS comes from re-running param_transform on the saved
|
||||
`mhr_model_params`;
|
||||
- rest verts come from a zero-pose forward with each person's `shape_params`;
|
||||
- sparse triplet skinning is compacted to at most 8 influences per vertex, split across glTF's 4-wide JOINTS_0/WEIGHTS_0 and (when needed) JOINTS_1/WEIGHTS_1 sets;
|
||||
- facial expression is re-exposed as 72 morph targets driven by `expr_params`
|
||||
so face animation survives plain glTF skinning.
|
||||
|
||||
Optional bone visualization (octahedrons / sticks) is rigidly
|
||||
skinned alongside the body mesh — used to preview the armature in glTF
|
||||
viewers that don't draw bones.
|
||||
|
||||
Shared GLB infra (writer, math, rig static extraction, shaders, normals)
|
||||
stays in `glb_shared.py`; only this mode's geometry + assembly live here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .glb_shared import (
|
||||
GLBWriter,
|
||||
bake_vertex_colors,
|
||||
bind_skel_state,
|
||||
bone_locals_from_globals,
|
||||
collect_tracks,
|
||||
compact_skin_to_n,
|
||||
compute_normals,
|
||||
compute_pastel_mix,
|
||||
extract_rig_static,
|
||||
flat_shade_mesh,
|
||||
gaussian_smooth_quats,
|
||||
global_skel_state_from_pose_data,
|
||||
global_skel_state_per_frame,
|
||||
ibp_from_bind_global,
|
||||
make_lit_material,
|
||||
quat_sign_fix_per_joint,
|
||||
rotation_align,
|
||||
unflip,
|
||||
zero_pose_rest_verts,
|
||||
)
|
||||
|
||||
from comfy_extras.sam3d_body.utils import jet_colormap
|
||||
|
||||
def _bone_colors_rgb(bind_pos_m: np.ndarray, scheme: str) -> Optional[np.ndarray]:
|
||||
"""Per-bone RGB color (NJ, 3) float32 in [0, 1]. Returns None for 'white'
|
||||
(no per-bone color → bone-vis mesh uses default unlit material)."""
|
||||
if scheme == "rainbow_y":
|
||||
y = bind_pos_m[:, 1].astype(np.float32)
|
||||
y_min, y_max = float(y.min()), float(y.max())
|
||||
s = np.clip((y - y_min) / max(y_max - y_min, 1e-6), 0.0, 1.0)
|
||||
return jet_colormap(s)
|
||||
return None
|
||||
|
||||
|
||||
def _octahedron_unit() -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Canonical Blender-style bone octahedron. Head at origin, tail at +Y,
|
||||
unit length, ridge at 1/10 height. 6 verts, 8 triangles. Faces wound
|
||||
so cross(v1-v0, v2-v0) points OUTWARD from the bone axis."""
|
||||
v = np.array([
|
||||
[0.0, 0.0, 0.0], # 0: head
|
||||
[0.0, 1.0, 0.0], # 1: tail
|
||||
[1.0, 0.1, 0.0], # 2: +X ridge (pre-scale; X/Z scale by half_width)
|
||||
[-1.0, 0.1, 0.0], # 3: -X ridge
|
||||
[0.0, 0.1, 1.0], # 4: +Z ridge
|
||||
[0.0, 0.1, -1.0], # 5: -Z ridge
|
||||
], dtype=np.float32)
|
||||
f = np.array([
|
||||
# head pyramid: outward = away from bone axis, slightly -Y
|
||||
[0, 2, 4], [0, 5, 2], [0, 3, 5], [0, 4, 3],
|
||||
# tail pyramid: outward = away from bone axis, slightly +Y
|
||||
[1, 4, 2], [1, 3, 4], [1, 5, 3], [1, 2, 5],
|
||||
], dtype=np.uint32)
|
||||
return v, f
|
||||
|
||||
|
||||
def _bone_edges(
|
||||
joint_pos_m: np.ndarray, parents: np.ndarray,
|
||||
) -> List[Tuple[int, int, np.ndarray, np.ndarray]]:
|
||||
"""Return one (parent_idx, child_idx, head_pos, tail_pos) tuple per
|
||||
parent→child edge in the hierarchy, skipping edges whose PARENT is a
|
||||
root joint (those typically anchor the skeleton at world origin and
|
||||
just look like a stray stick from origin to the body). Zero-length
|
||||
edges are skipped too."""
|
||||
NJ = joint_pos_m.shape[0]
|
||||
out: List[Tuple[int, int, np.ndarray, np.ndarray]] = []
|
||||
for c in range(NJ):
|
||||
p = int(parents[c])
|
||||
if not (0 <= p < NJ and p != c):
|
||||
continue
|
||||
# Skip if parent itself is a root — that bone is a world-anchor stick.
|
||||
gp = int(parents[p])
|
||||
if not (0 <= gp < NJ and gp != p):
|
||||
continue
|
||||
head = joint_pos_m[p].astype(np.float32)
|
||||
tail = joint_pos_m[c].astype(np.float32)
|
||||
if float(np.linalg.norm(tail - head)) < 1e-6:
|
||||
continue
|
||||
out.append((p, c, head, tail))
|
||||
return out
|
||||
|
||||
|
||||
def _build_bone_octahedrons_mesh(
    bind_joint_pos_m: np.ndarray, parents: np.ndarray, half_width_m: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build one Blender-style octahedron per drawable parent→child edge.

    Returns (verts, normals, faces, joints, weights, child_idx_per_vert).
    `joints` / `weights` are 4-wide skinning rows that reference only the
    edge's parent and child joints; `child_idx_per_vert` feeds the per-bone
    color lookup at the call site. All six arrays are empty (zero rows) when
    the hierarchy has no drawable edge.
    """
    unit_verts, unit_faces = _octahedron_unit()
    verts_per_bone = unit_verts.shape[0]
    y_axis = np.array([0.0, 1.0, 0.0], dtype=np.float32)

    # Width scales with length so short bones (fingers, face) don't look chunky
    # next to long ones (limbs, spine). `half_width_m` caps long bones.
    WIDTH_RATIO = 0.1
    MIN_WIDTH = 0.001

    # The canonical-space normals depend only on the unit octahedron, so they
    # are built once: head pole outward = -Y, tail pole +Y, ridges outward in XZ.
    unit_normals = np.zeros_like(unit_verts)
    unit_normals[0] = [0.0, -1.0, 0.0]
    unit_normals[1] = [0.0, 1.0, 0.0]
    for k in range(2, 6):
        radial = unit_verts[k].copy()
        radial[1] = 0.0
        radial_len = float(np.linalg.norm(radial))
        if radial_len > 0:
            unit_normals[k] = radial / radial_len

    # Dual skin head→parent, tail→child, ridges blend by canonical Y so the
    # bone stretches between joints instead of going rigid with one. These
    # weights are likewise identical for every bone.
    skin_pairs: List[Tuple[float, float]] = []
    for k in range(verts_per_bone):
        y_canon = float(unit_verts[k, 1])
        w_parent = max(0.0, 1.0 - y_canon)
        w_child = max(0.0, y_canon)
        total = w_parent + w_child
        if total > 0:
            w_parent /= total
            w_child /= total
        skin_pairs.append((w_parent, w_child))

    out_v: List[List[float]] = []
    out_n: List[List[float]] = []
    out_f: List[List[int]] = []
    out_j: List[List[int]] = []
    out_w: List[List[float]] = []
    child_per_vert: List[int] = []

    for parent_idx, child_idx, head, tail in _bone_edges(bind_joint_pos_m, parents):
        direction = tail - head
        length = float(np.linalg.norm(direction))
        if length < 1e-6:
            continue
        # Rotation taking the canonical +Y axis onto this bone's direction.
        R = rotation_align(y_axis, direction / length)

        half_width_eff = max(MIN_WIDTH, min(length * WIDTH_RATIO, half_width_m))
        scale = np.array([half_width_eff, length, half_width_eff], dtype=np.float32)
        v_world = (unit_verts * scale) @ R.T + head
        n_world = unit_normals @ R.T

        # Faces index into the growing vertex list, so offset by what's there.
        first_vert = len(out_v)
        out_v.extend(v_world.tolist())
        out_n.extend(n_world.tolist())
        out_f.extend(
            [int(tri[0]) + first_vert, int(tri[1]) + first_vert, int(tri[2]) + first_vert]
            for tri in unit_faces
        )
        out_j.extend([int(parent_idx), int(child_idx), 0, 0] for _ in range(verts_per_bone))
        out_w.extend([w_parent, w_child, 0.0, 0.0] for w_parent, w_child in skin_pairs)
        child_per_vert.extend([int(child_idx)] * verts_per_bone)

    if not out_v:
        return (np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32),
                np.zeros((0, 3), dtype=np.uint32), np.zeros((0, 4), dtype=np.uint16),
                np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64))
    return (np.asarray(out_v, dtype=np.float32),
            np.asarray(out_n, dtype=np.float32),
            np.asarray(out_f, dtype=np.uint32),
            np.asarray(out_j, dtype=np.uint16),
            np.asarray(out_w, dtype=np.float32),
            np.asarray(child_per_vert, dtype=np.int64))
|
||||
|
||||
|
||||
def build_glb_skeletal(
    pose_data: Dict[str, Any],
    model: Any = None,
    *,
    fps: float = 24.0,
    camera_translation: str = "off",
    track_index: int = -1,
    include_face_morphs: bool = True,
    shader: str = "default",
    rainbow_tilt_x_deg: float = 0.0,
    rainbow_tilt_z_deg: float = 0.0,
    person_palette_falloff: float = 0.6,
    bone_smooth_window: int = 0,
    use_stored_global_rots: bool = True,
    bone_vis: str = "off",
    bone_vis_radius_m: float = 0.04,
    bone_vis_color: str = "white",
    include_body_mesh: bool = True,
) -> bytes:
    """Build pose_data as a real Armature GLB blob with per-bone TRS keyframes.

    For MHR (default) facial expression is exposed as 72 morph targets driven
    by expr_params per frame when include_face_morphs=True.

    External skeletons (e.g. ComfyUI-Kimodo) can supply a
    ``pose_data["_skeleton_override"]`` dict to bypass the MHR rig extraction
    entirely. When present, ``model`` may be None and the rig data, bind pose,
    skin weights, and rest verts come from the override. Per-frame skeletal
    state still reads ``pred_global_rots`` / ``pred_joint_coords`` from each
    person dict (kimodo populates these from its own FK output). See
    ``glb_shared._get_skeleton_override`` for the override schema.

    Args:
        pose_data: Dict with at least ``"frames"`` and ``"faces"``;
            ``"canonical_colors"`` is read when present.
        model: Rig model handed to the `glb_shared` helpers.
        fps: Keyframe times are ``frame_index / fps``.
        camera_translation: ``"off"`` leaves the track root untranslated;
            ``"centered"`` applies ``pred_cam_t`` relative to the track's
            first frame; any other value applies ``pred_cam_t`` as-is.
        track_index: Forwarded to `collect_tracks` to select tracks.
        include_face_morphs: Emit expression morph targets when the rig has any.
        shader, rainbow_tilt_x_deg, rainbow_tilt_z_deg: Forwarded to
            `bake_vertex_colors`.
        person_palette_falloff: Forwarded to `compute_pastel_mix`.
        bone_smooth_window: When > 1, local bone quaternions are smoothed
            with `gaussian_smooth_quats` using this window.
        use_stored_global_rots: Read per-frame globals from pose_data instead
            of re-running the rig on ``mhr_model_params``. Forced to True for
            external rigs.
        bone_vis: ``"octahedrons"`` or ``"sticks"`` adds a skinned
            bone-visualization mesh; anything else adds none.
        bone_vis_radius_m: Half-width cap for ``"octahedrons"``
            (``"sticks"`` use a fixed 0.0035).
        bone_vis_color: ``"rainbow_y"`` colors bones by bind-pose height;
            other values add no per-bone color.
        include_body_mesh: Emit the skinned body mesh.

    Returns:
        The serialized GLB as bytes.

    Raises:
        ValueError: If `collect_tracks` yields no tracks.
    """
    frames = pose_data["frames"]
    # Only `pred_cam_t` is camera-y-down; mhr_model_params, lbs_*, expr_basis,
    # faces are all rig-native (Y-up).
    faces_native = np.ascontiguousarray(pose_data["faces"], dtype=np.uint32)
    tracks = collect_tracks(pose_data, track_index)
    if not tracks:
        raise ValueError("build_glb_skeletal: no valid tracks in pose_data")

    rig_static = extract_rig_static(model, pose_data)
    NJ = rig_static["num_joints"]
    NV = rig_static["num_verts"]
    NEXPR = rig_static["num_expr"]
    parents = rig_static["parents"]
    is_external = bool(rig_static.get("_external", False))
    if is_external:
        # External rigs have no PCA pose params to re-run; only stored globals
        # are available, and kimodo stores joint coords already Y-up.
        use_stored_global_rots = True
    joint_coords_y_down = not is_external
    # Compact sparse skinning to 8 influences per vertex into glTF's two
    # JOINTS_*/WEIGHTS_* sets. MHR averages ~2.8 influences/vert but some
    # shoulder/hip verts have 5-8 where multiple joints cancel — keeping only
    # 4 there leaks per-bone rotation noise into the rendered mesh.
    if is_external:
        joints_8 = rig_static["lbs_compact_joints"]
        weights_8 = rig_static["lbs_compact_weights"]
        actual_max_inf = rig_static["lbs_compact_max_inf"]
    else:
        joints_8, weights_8, actual_max_inf = compact_skin_to_n(
            rig_static["lbs_skin_indices"], rig_static["lbs_vert_indices"],
            rig_static["lbs_skin_weights"], NV, max_inf=8,
        )
    joints_set0 = np.ascontiguousarray(joints_8[:, :4])
    weights_set0 = np.ascontiguousarray(weights_8[:, :4])
    # The second attribute set is only written when some vertex needs it.
    use_set1 = actual_max_inf > 4
    joints_set1 = np.ascontiguousarray(joints_8[:, 4:8]) if use_set1 else None
    weights_set1 = np.ascontiguousarray(weights_8[:, 4:8]) if use_set1 else None
    # Derive bone locals from the rig's bind globals rather than recomputing
    # FK ourselves, so any mismatch between `parents` and the rig's actual FK
    # is absorbed into the local TRS instead of producing wrong globals.
    bind_global_cm = bind_skel_state(model, pose_data)
    bind_global_m = bind_global_cm.copy().astype(np.float32)
    bind_global_m[:, :3] *= 0.01
    bind_local = bone_locals_from_globals(bind_global_m[None], rig_static["parents"])[0]

    # IBP = inverse of bind global. With bone defaults set to bind_local and
    # FK composed via `parents`, skin_matrix at rest = identity.
    ibp_mat4 = ibp_from_bind_global(bind_global_m)

    w = GLBWriter()

    nodes: List[dict] = []
    meshes: List[dict] = []
    skins: List[dict] = []
    materials: List[dict] = []
    animations: List[dict] = []
    scene_root_indices: List[int] = []
    canonical_colors = pose_data.get("canonical_colors")

    # Accessors shared by every track: topology, skinning and inverse binds.
    indices_acc = w.add_indices_u32(faces_native)
    joints0_acc = w.add_joints_u16(joints_set0)
    weights0_acc = w.add_weights_f32(weights_set0)
    joints1_acc = w.add_joints_u16(joints_set1) if use_set1 else None
    weights1_acc = w.add_weights_f32(weights_set1) if use_set1 else None
    ibm_acc = w.add_mat4_f32(ibp_mat4)

    expr_morph_accs: List[int] = []
    if include_face_morphs and NEXPR > 0:
        eb = rig_static["expr_basis"].astype(np.float32) * 0.01
        for e in range(NEXPR):
            expr_morph_accs.append(w.add_vec3_f32_no_minmax(eb[e]))

    for track_i, (person_k, frame_indices) in enumerate(tracks):
        person_root = {"name": f"track{track_i:02d}", "children": []}
        nodes.append(person_root)
        person_root_idx = len(nodes) - 1
        scene_root_indices.append(person_root_idx)

        # One node per joint, defaulting to the bind-pose local TRS.
        bone_node_indices: List[int] = []
        for j in range(NJ):
            bone = {
                "name": f"bone_{j:03d}",
                "translation": bind_local[j, :3].tolist(),
                "rotation": bind_local[j, 3:7].tolist(),
                "scale": [float(bind_local[j, 7])] * 3,
            }
            nodes.append(bone)
            bone_node_indices.append(len(nodes) - 1)

        # Wire the hierarchy; joints without a valid parent hang off the track root.
        bone_children: List[List[int]] = [[] for _ in range(NJ)]
        bone_root_indices: List[int] = []
        for j in range(NJ):
            p = int(parents[j])
            if 0 <= p < NJ and p != j:
                bone_children[p].append(bone_node_indices[j])
            else:
                bone_root_indices.append(bone_node_indices[j])
        for j in range(NJ):
            if bone_children[j]:
                nodes[bone_node_indices[j]]["children"] = bone_children[j]
        person_root["children"].extend(bone_root_indices)

        skin = {
            "joints": bone_node_indices,
            "inverseBindMatrices": ibm_acc,
            "skeleton": bone_root_indices[0] if bone_root_indices else bone_node_indices[0],
        }
        skins.append(skin)
        skin_idx = len(skins) - 1

        include_body = bool(include_body_mesh)
        include_bones = bone_vis in ("octahedrons", "sticks")
        body_mesh_node_idx: Optional[int] = None

        if include_body:
            # External rigs have no PCA shape — `zero_pose_rest_verts` short-
            # circuits to `pose_data["_skeleton_override"]["rest_verts_m"]`,
            # so zeroed shape_params is safe there.
            if is_external:
                shape_params_arr = np.zeros(0, dtype=np.float32)
            else:
                shape_params_arr = np.asarray(
                    frames[frame_indices[0]][person_k]["shape_params"], dtype=np.float32,
                )
            rest_v = zero_pose_rest_verts(model, shape_params_arr, pose_data=pose_data)
            normals = compute_normals(rest_v, faces_native)
            positions_acc = w.add_vec3_f32(rest_v)
            normals_acc = w.add_vec3_f32(normals)

            pastel_mix = compute_pastel_mix(track_i, person_palette_falloff)
            vcolor = bake_vertex_colors(
                canonical_colors, shader,
                rainbow_tilt_x_deg, rainbow_tilt_z_deg, pastel_mix,
            )
            color_acc = w.add_vec3_f32(vcolor) if vcolor is not None else None

            attributes = {
                "POSITION": positions_acc, "NORMAL": normals_acc,
                "JOINTS_0": joints0_acc, "WEIGHTS_0": weights0_acc,
            }
            if joints1_acc is not None:
                attributes["JOINTS_1"] = joints1_acc
                attributes["WEIGHTS_1"] = weights1_acc
            if color_acc is not None:
                attributes["COLOR_0"] = color_acc
            primitive = {
                "attributes": attributes,
                "indices": indices_acc,
                "mode": 4,
            }
            if color_acc is not None:
                materials.append(make_lit_material())
                primitive["material"] = len(materials) - 1
            if expr_morph_accs:
                primitive["targets"] = [{"POSITION": a} for a in expr_morph_accs]

            mesh = {"primitives": [primitive]}
            if expr_morph_accs:
                mesh["weights"] = [0.0] * len(expr_morph_accs)
            meshes.append(mesh)
            mesh_idx = len(meshes) - 1

            mesh_node = {
                "name": f"track{track_i:02d}_mesh", "mesh": mesh_idx, "skin": skin_idx,
            }
            nodes.append(mesh_node)
            body_mesh_node_idx = len(nodes) - 1
            person_root["children"].append(body_mesh_node_idx)

        if include_bones:
            bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)

            # Indexes `bone_palette`: octahedrons/sticks use the bone's child
            # joint so every bone has its own color regardless of skin target.
            # 'sticks' = thin octahedrons. glTF LINES skinning is unreliable
            # (Three.js' GLTFLoader doesn't animate skinned line primitives),
            # so we render triangle tubes instead.
            color_idx_per_vert: Optional[np.ndarray] = None
            hw = float(bone_vis_radius_m) if bone_vis == "octahedrons" else 0.0035
            bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
                bind_global_m[:, :3], rig_static["parents"], half_width_m=hw,
            )
            if bv_v.shape[0] > 0:
                # Flat shading emits one vertex per face corner, so expand the
                # per-vertex child index to the same face-corner layout first.
                F = bv_f.shape[0]
                expanded_child = np.empty((F * 3,), dtype=np.int64)
                for k in range(3):
                    expanded_child[k::3] = child_per_vert[bv_f[:, k]]
                bv_v, bv_n, bv_f, bv_j, bv_w = flat_shade_mesh(bv_v, bv_f, bv_j, bv_w)
                color_idx_per_vert = expanded_child
                primitive_mode = 4
                bv_idx_flat = bv_f.reshape(-1)

            if bv_v.shape[0] > 0:
                bv_pos_acc = w.add_vec3_f32(bv_v)
                bv_idx_acc = w.add_indices_u32(bv_idx_flat)
                bv_j_acc = w.add_joints_u16(bv_j)
                bv_w_acc = w.add_weights_f32(bv_w)
                bv_attrs = {
                    "POSITION": bv_pos_acc,
                    "JOINTS_0": bv_j_acc, "WEIGHTS_0": bv_w_acc,
                }
                if bv_n is not None:
                    bv_attrs["NORMAL"] = w.add_vec3_f32(bv_n)
                if bone_palette is not None and color_idx_per_vert is not None:
                    bv_color = bone_palette[color_idx_per_vert].astype(np.float32)
                    bv_attrs["COLOR_0"] = w.add_vec3_f32(bv_color)
                bv_primitive = {
                    "attributes": bv_attrs,
                    "indices": bv_idx_acc,
                    "mode": primitive_mode,
                }
                if bone_palette is not None:
                    materials.append(make_lit_material())
                    bv_primitive["material"] = len(materials) - 1
                bv_mesh = {"primitives": [bv_primitive]}
                meshes.append(bv_mesh)
                bv_mesh_node = {
                    "name": f"track{track_i:02d}_bones",
                    "mesh": len(meshes) - 1,
                    "skin": skin_idx,
                }
                nodes.append(bv_mesh_node)
                person_root["children"].append(len(nodes) - 1)

        # Per-frame GLOBAL skel state → bone locals via parent-inverse.
        # Default uses the rig's stored output; the fallback re-runs FK.
        if use_stored_global_rots:
            rig_global_m = global_skel_state_from_pose_data(
                pose_data, frame_indices, person_k, NJ,
                joint_coords_y_down=joint_coords_y_down,
            )
        else:
            mp_per_frame = np.stack([
                np.asarray(frames[t][person_k]["mhr_model_params"], dtype=np.float32)
                for t in frame_indices
            ], axis=0)
            rig_global_cm = global_skel_state_per_frame(model, mp_per_frame)
            rig_global_m = rig_global_cm.copy().astype(np.float32)
            rig_global_m[..., :3] *= 0.01
        # Sign-fix on the GLOBAL quats BEFORE deriving locals. The rig's
        # Euler-XYZ parametrization wraps at ±180° for spinning joints; if we
        # only fix locals, the parent's flip propagates into the child's
        # local translation (t_local inherits parent sign via q_parent_inv)
        # and produces visible "axis resets" mid-animation.
        rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
        bone_local_anim = bone_locals_from_globals(rig_global_m, rig_static["parents"])
        local_t = bone_local_anim[..., :3].astype(np.float32)
        local_q = bone_local_anim[..., 3:7].astype(np.float32)
        local_s = bone_local_anim[..., 7].astype(np.float32)
        # Second pass on locals catches residual drift from the parent-inverse.
        local_q = quat_sign_fix_per_joint(local_q)
        # Hemisphere-align frame 0 with the bind quat so pause/play takes the
        # short path; then re-propagate.
        bind_q = bind_local[:, 3:7].astype(np.float32)
        if local_q.shape[0] > 0:
            d0 = (bind_q * local_q[0]).sum(axis=-1)
            sign0 = np.where(d0 < 0, -1.0, 1.0).astype(np.float32)[:, None]
            local_q[0] = local_q[0] * sign0
            local_q = quat_sign_fix_per_joint(local_q)
        # Optional smoothing for multi-frame rig spikes (e.g. q.w discontinuity
        # at handstand) that the upstream Smooth node may not catch.
        if bone_smooth_window and bone_smooth_window > 1:
            local_q = gaussian_smooth_quats(local_q, int(bone_smooth_window))
        # fp64 renormalize → fp32 keyframes. Viewers' nlerp amplifies non-unit
        # drift into visible flips otherwise.
        lq64 = local_q.astype(np.float64)
        lq64 = lq64 / np.maximum(np.linalg.norm(lq64, axis=-1, keepdims=True), 1e-12)
        local_q = lq64.astype(np.float32)

        n_frames = len(frame_indices)
        times = np.asarray(frame_indices, dtype=np.float32) / float(fps)
        time_acc = w.add_scalar_f32(times)

        samplers: List[dict] = []
        channels: List[dict] = []

        # A channel that never changes is baked into the node's static TRS
        # instead of being written as keyframes.
        for j in range(NJ):
            t_j = local_t[:, j, :]
            q_j = local_q[:, j, :]
            s_j = np.broadcast_to(local_s[:, j:j+1], (n_frames, 3)).astype(np.float32)

            t_const = (np.ptp(t_j, axis=0) < 1e-6).all()
            q_const = (np.ptp(q_j, axis=0) < 1e-6).all()
            s_const = (np.ptp(s_j, axis=0) < 1e-6).all()

            if t_const:
                nodes[bone_node_indices[j]]["translation"] = t_j[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(t_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "translation"},
                })

            if q_const:
                nodes[bone_node_indices[j]]["rotation"] = q_j[0].tolist()
            else:
                acc = w.add_vec4_f32(q_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "rotation"},
                })

            if s_const:
                nodes[bone_node_indices[j]]["scale"] = s_j[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(s_j)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": bone_node_indices[j], "path": "scale"},
                })

        if camera_translation != "off":
            cam_t = np.stack([
                unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
                for t in frame_indices
            ], axis=0)
            if camera_translation == "centered" and cam_t.shape[0] > 0:
                cam_t = cam_t - cam_t[0:1]
            if (np.ptp(cam_t, axis=0) < 1e-6).all():
                person_root["translation"] = cam_t[0].tolist()
            else:
                acc = w.add_vec3_f32_anim(cam_t)
                samplers.append({"input": time_acc, "output": acc, "interpolation": "LINEAR"})
                channels.append({
                    "sampler": len(samplers) - 1,
                    "target": {"node": person_root_idx, "path": "translation"},
                })

        # Body-mesh-only: bone-vis primitives have no morph targets.
        if expr_morph_accs and body_mesh_node_idx is not None:
            expr_per_frame = np.stack([
                np.asarray(frames[t][person_k]["expr_params"], dtype=np.float32)
                for t in frame_indices
            ], axis=0).astype(np.float32)
            weights_acc_anim = w.add_scalar_f32_flat(expr_per_frame, count=n_frames * NEXPR)
            samplers.append({"input": time_acc, "output": weights_acc_anim, "interpolation": "LINEAR"})
            channels.append({
                "sampler": len(samplers) - 1,
                "target": {"node": body_mesh_node_idx, "path": "weights"},
            })

        # glTF 2.0 requires `channels` and `samplers` to hold at least one
        # entry each. When every channel collapsed to a constant (e.g. a
        # single-frame track) the static node TRS already carries the pose,
        # so the animation is left out instead of being written empty.
        if channels:
            animations.append({
                "name": f"track{track_i:02d}",
                "samplers": samplers, "channels": channels,
            })

    gltf = {
        "asset": {"version": "2.0", "generator": "ComfyUI-SAM3DBody"},
        "scene": 0,
        "scenes": [{"nodes": scene_root_indices}],
        "nodes": nodes,
        "meshes": meshes,
        "skins": skins,
    }
    if materials:
        gltf["materials"] = materials
    if animations:
        gltf["animations"] = animations

    return w.to_bytes(gltf)
|
||||
233
comfy_extras/sam3d_body/export/openpose_2d.py
Normal file
233
comfy_extras/sam3d_body/export/openpose_2d.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""2D OpenPose-style skeleton rendering for SAM 3D Body pose_data.
|
||||
|
||||
Body / hand drawing is delegated to `KeypointDraw.draw_wholebody_keypoints`
|
||||
(shared with SDPose). SAM3D-specific: MHR70 -> DWPose-134 keypoint packing,
|
||||
plus optional rig-projected face landmarks when `pred_face_keypoints_2d`
|
||||
isn't present (and arbitrary-count face dots, since sapiens-238 doesn't fit
|
||||
the DWPose face slot).
|
||||
|
||||
Output: (H, W, 3) fp32 torch.Tensor in [0, 1].
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy_extras.pose.keypoint_draw import KeypointDraw
|
||||
|
||||
from .glb_shared import (
|
||||
OPENPOSE18_TO_MHR70,
|
||||
OPENPOSE_HAND21_TO_MHR70_L,
|
||||
OPENPOSE_HAND21_TO_MHR70_R,
|
||||
OPENPOSE_HAND_COLORS_21,
|
||||
select_face_landmark_vert_ids,
|
||||
)
|
||||
|
||||
|
||||
# Single shared KeypointDraw instance; the drawing calls in this module
# (`draw_wholebody_keypoints`, `draw.circle`) all go through it.
_KD = KeypointDraw()
# OpenPose hand palette scaled by 255 and cast to int for KeypointDraw.
# NOTE(review): presumably OPENPOSE_HAND_COLORS_21 is a (21, 3) float array
# in [0, 1] - confirm against glb_shared.
_HAND_DOT_PALETTE_OPENPOSE = (OPENPOSE_HAND_COLORS_21 * 255.0).astype(int)
|
||||
|
||||
|
||||
def _project_face_landmarks_2d(
    person: Dict[str, Any], face_vert_ids: np.ndarray, H: int, W: int,
) -> Optional[np.ndarray]:
    """Pinhole-project `pred_vertices[face_vert_ids]` of one person to pixels.

    Uses the person's own `pred_cam_t` and the first element of
    `focal_length`; the principal point is fixed at the image centre
    (W/2, H/2). Depth is clamped to >= 1e-6 before the divide.

    Returns an (len(face_vert_ids), 2) float32 array of (x, y) pixel
    coordinates, or None when any of the three inputs is missing.
    """
    raw_verts = person.get("pred_vertices")
    raw_cam_t = person.get("pred_cam_t")
    raw_focal = person.get("focal_length")
    if any(item is None for item in (raw_verts, raw_cam_t, raw_focal)):
        return None

    vertices = np.asarray(raw_verts, dtype=np.float32)
    translation = np.asarray(raw_cam_t, dtype=np.float32).reshape(3)
    focal_px = float(np.asarray(raw_focal, dtype=np.float32).reshape(-1)[0])

    cam_pts = vertices[face_vert_ids] + translation[None, :]
    depth = np.maximum(cam_pts[:, 2:3], 1e-6)
    centre = np.array([W * 0.5, H * 0.5], dtype=np.float32)
    # (f * xy + c * z) / z is f * xy / z + c, with a single divide at the end.
    numerator = cam_pts[:, :2] * focal_px
    numerator = numerator + centre[None, :] * depth
    return (numerator / depth).astype(np.float32)
|
||||
|
||||
|
||||
def _pack_dwpose_134(
    person: Dict[str, Any], *, include_body: bool, include_hands: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """Repack one person's `pred_keypoints_2d` into the 134-slot layout.

    Returns `(kp, scores)`: a (134, 2) float32 coordinate array and a (134,)
    float32 confidence array. Slots 0-17 receive the body keypoints, 92-112
    the right hand and 113-133 the left hand; every other slot (including
    the face range) stays zero because face dots are drawn separately.

    A slot gets score 1.0 only when both of its coordinates are finite.
    When `pred_keypoints_2d` is missing, is not (N, 2), or has fewer than 70
    rows, both arrays are returned all-zero.
    """
    kp = np.zeros((134, 2), dtype=np.float32)
    scores = np.zeros(134, dtype=np.float32)

    raw = person.get("pred_keypoints_2d")
    if raw is None:
        return kp, scores
    kp2d = np.asarray(raw, dtype=np.float32)
    well_formed = kp2d.ndim == 2 and kp2d.shape[1] == 2 and kp2d.shape[0] >= 70
    if not well_formed:
        return kp, scores

    # (first destination slot, slot count, source indices into kp2d)
    groups = []
    if include_body:
        groups.append((0, 18, OPENPOSE18_TO_MHR70))
    if include_hands:
        groups.append((92, 21, OPENPOSE_HAND21_TO_MHR70_R))
        groups.append((113, 21, OPENPOSE_HAND21_TO_MHR70_L))

    for first, count, source_idx in groups:
        xy = kp2d[source_idx]
        ok = np.isfinite(xy).all(axis=1)
        # Basic slices are views, so masked assignment writes into kp/scores.
        kp[first:first + count][ok] = xy[ok]
        scores[first:first + count][ok] = 1.0

    return kp, scores
|
||||
|
||||
|
||||
def _draw_face_dots(
    canvas: np.ndarray, face_xy: np.ndarray, marker_radius_px: int,
) -> None:
    """Draw one filled white circle per face landmark onto `canvas` in place.

    Works for any landmark count. Points with a non-finite coordinate are
    skipped, as are points whose circle would lie entirely off the canvas.
    """
    height, width = canvas.shape[:2]
    radius = int(marker_radius_px)
    colour = (255, 255, 255)
    for row in range(face_xy.shape[0]):
        fx = float(face_xy[row, 0])
        fy = float(face_xy[row, 1])
        if np.isfinite(fx) and np.isfinite(fy):
            px, py = int(round(fx)), int(round(fy))
            # Keep the dot if any part of its disc can touch the canvas.
            touches_x = px + radius >= 0 and px - radius < width
            touches_y = py + radius >= 0 and py - radius < height
            if touches_x and touches_y:
                _KD.draw.circle(canvas, (px, py), radius, colour, thickness=-1)
|
||||
|
||||
|
||||
def render_pose_data_openpose(
    pose_data: Dict[str, Any],
    *,
    frame_idx: int,
    W: int,
    H: int,
    background: Optional[torch.Tensor] = None,
    composite: str = "over",
    marker_radius_px: int = 4,
    stick_width_px: int = 4,
    limb_alpha: float = 0.6,
    include_body: bool = True,
    include_hands: bool = False,
    face_style: str = "disabled",
    hand_color_style: str = "dwpose",
    hand_marker_radius_px: int = 0,
    hand_stick_width_px: int = 0,
    face_marker_radius_px: int = 3,
    person_brightness_falloff: float = 0.0,
) -> torch.Tensor:
    """Render a 2D OpenPose-style skeleton onto an (H, W, 3) canvas.

    `composite='over'` paints over `background` (resized to (H, W) if
    needed); any other value, or no background, starts from a black canvas.
    `hand_marker_radius_px` / `hand_stick_width_px`: 0 = auto = 0.7x / 0.5x
    of the body sizes (never below 1).
    `face_style`: 'disabled' = no face dots, 'full' = all face landmarks
    (prefers `pred_face_keypoints_2d` if any person has it, else the
    rig-projected fallback), 'eyes_mouth' = the eyes + mouth subset of the
    rig fallback. The subset indices are only defined for the rig fallback,
    so 'eyes_mouth' always uses it, even when `pred_face_keypoints_2d` is
    present.
    `person_brightness_falloff` is clamped to [0, 1]; person k > 0 has the
    pixels it touched mixed toward white by `1 - falloff^k` (person 0 is
    never altered). Applied post-draw so per-limb alpha blending against the
    existing canvas remains correct.

    Returns an (H, W, 3) float32 tensor in [0, 1].
    """
    persons = pose_data["frames"][frame_idx]

    if composite == "over" and background is not None:
        # detach/float so autograd-tracked or reduced-precision images
        # convert to numpy cleanly; a no-op for plain fp32 tensors.
        bg = background.detach().float().cpu().numpy()
        canvas = (np.clip(bg, 0.0, 1.0) * 255.0).astype(np.uint8)
        if canvas.shape[:2] != (H, W):
            canvas = np.array(Image.fromarray(canvas).resize((W, H), Image.LANCZOS))
    else:
        canvas = np.zeros((H, W, 3), dtype=np.uint8)
    # In-place draw needs a contiguous writable buffer.
    canvas = np.ascontiguousarray(canvas)

    if int(hand_marker_radius_px) <= 0:
        hand_marker_radius_px = max(1, int(round(marker_radius_px * 0.7)))
    if int(hand_stick_width_px) <= 0:
        hand_stick_width_px = max(1, int(round(stick_width_px * 0.5)))

    # Eyes+mouth indices into the rig fallback (FACE_LANDMARK_TARGETS): 6..13
    # = both eyes, 19..22 = outer-lip ring. Brows/nose/chin/jaw are dropped.
    eyes_mouth_idx = np.array([6, 7, 8, 9, 10, 11, 12, 13, 19, 20, 21, 22], dtype=np.int64)

    include_face = face_style != "disabled"
    use_rig_only = face_style == "eyes_mouth"

    # Real face keypoints take priority for 'full'; 'eyes_mouth' always goes
    # through the rig path because the subset is only defined there.
    face_vert_ids: Optional[np.ndarray] = None
    if include_face:
        any_real = (not use_rig_only) and any(
            p.get("pred_face_keypoints_2d") is not None for p in persons
        )
        if not any_real:
            cc = pose_data.get("canonical_colors") or {}
            positions = cc.get("positions")
            if positions is not None:
                # Best effort: a failed landmark selection disables face dots
                # instead of aborting the whole render.
                try:
                    face_vert_ids = select_face_landmark_vert_ids(
                        np.asarray(positions), face_mask=cc.get("face_mask"),
                    )
                    if use_rig_only:
                        face_vert_ids = face_vert_ids[eyes_mouth_idx]
                except Exception as e:
                    print(f"[SAM3DBody] face landmarks disabled - {e}")
                    face_vert_ids = None

    hand_dot_color = (
        _HAND_DOT_PALETTE_OPENPOSE if hand_color_style == "openpose"
        else (0, 0, 255)
    )

    # NOTE(review): with the default 0.0 this gives pastel == 1.0 for every
    # person after the first (fully white strokes) - confirm callers always
    # pass an explicit value.
    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))

    for k, person in enumerate(persons):
        pastel = 0.0 if k == 0 else (1.0 - falloff ** k)
        # Snapshot before this person's strokes so we can identify the pixels
        # they touched and blend just those toward white. Drawing happens
        # against the live canvas first so limb_alpha blends correctly.
        pre = canvas.copy() if pastel > 0 else None

        kp134, scores134 = _pack_dwpose_134(
            person, include_body=include_body, include_hands=include_hands,
        )
        _KD.draw_wholebody_keypoints(
            canvas, kp134, scores=scores134, threshold=0.5,
            draw_body=include_body, draw_feet=False,
            draw_face=False,  # SAM3D draws face dots separately (variable count)
            draw_hands=include_hands,
            stick_width=stick_width_px,
            marker_radius=marker_radius_px,
            hand_stick_width=hand_stick_width_px,
            hand_marker_radius=hand_marker_radius_px,
            limb_alpha=limb_alpha,
            hand_dot_color=hand_dot_color,
        )

        if include_face:
            face_xy = None
            # Fix: 'eyes_mouth' must ignore per-person real keypoints;
            # previously they won here and every real dot was drawn.
            real_face = None if use_rig_only else person.get("pred_face_keypoints_2d")
            if real_face is not None:
                arr = np.asarray(real_face, dtype=np.float32)
                if arr.ndim == 2 and arr.shape[1] == 2:
                    face_xy = arr
            elif face_vert_ids is not None:
                face_xy = _project_face_landmarks_2d(person, face_vert_ids, H, W)
            if face_xy is not None:
                _draw_face_dots(canvas, face_xy, face_marker_radius_px)

        if pre is not None:
            changed = (canvas != pre).any(axis=-1)
            if changed.any():
                touched = canvas[changed].astype(np.float32)
                blended = touched * (1.0 - pastel) + 255.0 * pastel
                canvas[changed] = np.clip(blended, 0.0, 255.0).astype(np.uint8)

    return torch.from_numpy(canvas.astype(np.float32) / 255.0)
|
||||
516
comfy_extras/sam3d_body/face_expression.py
Normal file
516
comfy_extras/sam3d_body/face_expression.py
Normal file
@ -0,0 +1,516 @@
|
||||
"""Face expression for SAM 3D Body.
|
||||
|
||||
Pipeline: comfy_extras.mediapipe.face_landmarker → 52 ARKit blendshapes →
|
||||
72-dim MHR expr_params (mapping inlined below).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
# These jaw blendshapes bypass the input deadzone - their signal is clean
# (open or not).
_NOISE_FREE_BLENDSHAPES = {"jawOpen", "jawForward", "jawLeft", "jawRight"}

# Per-region gain - MP magnitudes vary by family (jaw up to 1.0, eye/brow
# rarely past 0.3), so a single global gain over/underdrives.
_REGION_PREFIXES = {
    "mouth": ("jaw", "mouth"),
    "eye": ("eye",),
    "brow": ("brow", "cheek", "nose"),  # cheek/nose read as upper-face
}


def _region_of(arkit_name: str) -> str:
    """Return the gain region for an ARKit blendshape name.

    The first `_REGION_PREFIXES` entry with a prefix matching the start of
    `arkit_name` wins; names matching no prefix map to "other".
    """
    matches = (
        region
        for region, prefixes in _REGION_PREFIXES.items()
        if arkit_name.startswith(prefixes)  # tuple form tests every prefix
    )
    return next(matches, "other")
|
||||
|
||||
|
||||
# MHR axis -> ARKit driver(s). Each axis collects 1-3 (name, weight) entries;
# the consumer takes max() across them so primary + aux contributions don't
# stack. MHR's 72 expression axes ship as anonymous `shape_c_N` channels in
# the upstream FBX (no semantic names), so this table is hand-derived by
# visual inspection of which axis each ARKit shape drives. Axes 2/3 and
# 12/13 are filled by aux routes only. ARKit shapes with no MHR analog are
# simply absent, and axes missing from this table are never written.
_AXIS_TO_ARKIT: Dict[int, List[Tuple[str, float]]] = {
    0: [("browDownLeft", 1.0)],
    1: [("browDownRight", 1.0)],
    # Both cheek axes share the single ARKit cheekPuff signal.
    2: [("cheekPuff", 1.0)],
    3: [("cheekPuff", 1.0)],
    4: [("cheekSquintLeft", 1.0)],
    5: [("cheekSquintRight", 1.0)],
    6: [("mouthStretchLeft", 1.0)],
    7: [("mouthStretchRight", 1.0)],
    8: [("mouthShrugLower", 1.0)],
    9: [("mouthShrugUpper", 1.0)],
    10: [("mouthDimpleLeft", 1.0)],
    11: [("mouthDimpleRight", 1.0)],
    # Down-gaze feeds these axes at a reduced 0.3 weight.
    12: [("eyeLookDownLeft", 0.3)],
    13: [("eyeLookDownRight", 0.3)],
    14: [("eyeBlinkLeft", 1.0)],
    15: [("eyeBlinkRight", 1.0)],
    16: [("eyeLookOutLeft", 1.0)],
    17: [("eyeLookInRight", 1.0)],
    18: [("eyeLookInLeft", 1.0)],
    19: [("eyeLookOutRight", 1.0)],
    # Up-gaze, with browInnerUp as a half-weight auxiliary driver.
    22: [("eyeLookUpLeft", 1.0), ("browInnerUp", 0.5)],
    23: [("eyeLookUpRight", 1.0), ("browInnerUp", 0.5)],
    # Jaw open, with the lower-lip-down shapes as 0.3-weight auxiliaries.
    24: [("jawOpen", 1.0), ("mouthLowerDownLeft", 0.3), ("mouthLowerDownRight", 0.3)],
    25: [("jawLeft", 1.0)],
    26: [("jawRight", 1.0)],
    27: [("jawForward", 1.0)],
    28: [("eyeSquintLeft", 1.0)],
    29: [("eyeSquintRight", 1.0)],
    32: [("mouthSmileLeft", 1.0)],
    33: [("mouthSmileRight", 1.0)],
    40: [("mouthLeft", 1.0)],
    41: [("mouthRight", 1.0)],
    42: [("mouthFrownLeft", 1.0)],
    43: [("mouthFrownRight", 1.0)],
    54: [("mouthLowerDownLeft", 1.0)],
    55: [("mouthLowerDownRight", 1.0)],
    60: [("noseSneerLeft", 1.0)],
    61: [("noseSneerRight", 1.0)],
    66: [("browOuterUpLeft", 1.0)],
    67: [("browOuterUpRight", 1.0)],
    68: [("eyeWideLeft", 1.0)],
    69: [("eyeWideRight", 1.0)],
    70: [("mouthUpperUpLeft", 1.0)],
    71: [("mouthUpperUpRight", 1.0)],
}
|
||||
|
||||
|
||||
def _deadzone(x: float, threshold: float) -> float:
    """Zero below threshold, linearly remap (threshold..1] -> (0..1] so
    amplification doesn't promote MP's per-blendshape noise floor.

    `threshold <= 0` disables the deadzone and returns `x` unchanged.
    `threshold >= 1` leaves no live range above the deadzone, so the result
    is always 0.0 (the remap would otherwise divide by zero at exactly 1.0
    and yield negative values above it).
    """
    if threshold <= 0.0:
        return x
    if threshold >= 1.0 or x <= threshold:
        return 0.0
    return (x - threshold) / (1.0 - threshold)
|
||||
|
||||
|
||||
def arkit_to_expr_params(
    blendshape_coefs: Dict[str, float],
    strength: float = 1.0,
    mouth_strength: float = 1.0,
    eye_strength: float = 1.0,
    brow_strength: float = 1.0,
    input_threshold: float = 0.0,
    n_axes: int = 72,
) -> np.ndarray:
    """Map MediaPipe's ARKit blendshapes to MHR expr_params axes.

    For every axis in `_AXIS_TO_ARKIT`, each routed blendshape is passed
    through the deadzone (`input_threshold`, bypassed for the names in
    `_NOISE_FREE_BLENDSHAPES`), scaled by its region gain and route weight,
    and the largest non-negative contribution is kept - routes combine via
    max() so primary + aux routes don't double up. The result is multiplied
    by the global `strength`.

    Blendshape names missing from `blendshape_coefs` count as 0.0. Axes with
    no table entry stay 0.0. Table axes at or beyond `n_axes` are skipped, so
    a smaller `n_axes` truncates the output instead of raising IndexError.

    Returns a (n_axes,) float32 array.
    """
    expr = np.zeros(n_axes, dtype=np.float32)
    region_scale = {
        "mouth": float(mouth_strength), "eye": float(eye_strength),
        "brow": float(brow_strength), "other": 1.0,
    }
    thr = float(input_threshold)
    for axis, routes in _AXIS_TO_ARKIT.items():
        if axis >= n_axes:
            # Table entry has no slot in the requested output length.
            continue
        best = 0.0
        for name, weight in routes:
            raw = float(blendshape_coefs.get(name, 0.0))
            name_thr = 0.0 if name in _NOISE_FREE_BLENDSHAPES else thr
            raw = _deadzone(raw, name_thr)
            c = raw * region_scale[_region_of(name)] * float(weight)
            if c > best:
                best = c
        expr[axis] = best * strength
    return expr
|
||||
|
||||
|
||||
def subtract_per_clip_baseline(
    per_frame_coefs: List[Optional[Dict[str, float]]],
    percentile: float = 5.0,
) -> List[Optional[Dict[str, float]]]:
    """Remove each blendshape's resting level, measured over the whole clip.

    The baseline of a blendshape is its `percentile`-th percentile across all
    frames that report it; every value has its baseline subtracted and is
    clamped at 0. This adapts to per-subject MP bias (e.g. a resting
    browOuterUp around 0.15) that a single global deadzone can't catch.

    None frames stay None. `percentile <= 0` disables the step and returns a
    shallow copy of the input list.
    """
    if percentile <= 0.0:
        return list(per_frame_coefs)

    # Gather every observed value per blendshape, in frame order.
    samples: Dict[str, List[float]] = {}
    for coefs in per_frame_coefs:
        if coefs is None:
            continue
        for name, value in coefs.items():
            samples.setdefault(name, []).append(value)

    baselines: Dict[str, float] = {
        name: float(np.percentile(values, percentile))
        for name, values in samples.items()
    }

    adjusted: List[Optional[Dict[str, float]]] = []
    for coefs in per_frame_coefs:
        if coefs is None:
            adjusted.append(None)
            continue
        adjusted.append({
            name: max(0.0, float(value) - baselines.get(name, 0.0))
            for name, value in coefs.items()
        })
    return adjusted
|
||||
|
||||
|
||||
def smooth_blendshape_series(
    per_frame_coefs: List[Optional[Dict[str, float]]],
    window: int = 7,
    sigma: Optional[float] = None,
) -> List[Optional[Dict[str, float]]]:
    """Gaussian-smooth every blendshape coefficient along the time axis.

    MP per-frame output is noisy even on static faces; smoothing before the
    axis mapping cleans better than smoothing mesh verts.

    `window` is the kernel width in frames (even values are bumped to the
    next odd number; `window <= 1` returns a shallow copy of the input).
    `sigma` defaults to `max(1.0, window / 5)`. Sequence ends are padded by
    repeating the first / last value. None frames are bridged by linear
    interpolation for the filter only and stay None in the output. A
    detected frame that lacks a name contributes 0.0 for it.
    """
    if window <= 1:
        return list(per_frame_coefs)
    if window % 2 == 0:
        window += 1
    if sigma is None:
        sigma = max(1.0, window / 5.0)

    taps = np.arange(window) - (window - 1) / 2.0
    kernel = np.exp(-(taps ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()

    channel_names: set = set()
    for coefs in per_frame_coefs:
        if coefs is not None:
            channel_names.update(coefs.keys())
    if not channel_names:
        return list(per_frame_coefs)

    n_frames = len(per_frame_coefs)
    half = window // 2
    frame_ids = np.arange(n_frames)
    # The detection mask is the same for every channel, so build it once.
    detected = np.array([coefs is not None for coefs in per_frame_coefs], dtype=bool)
    has_gaps = not detected.all()
    smoothed: List[Optional[Dict[str, float]]] = [
        None if coefs is None else {} for coefs in per_frame_coefs
    ]

    for name in channel_names:
        series = np.array(
            [0.0 if coefs is None else float(coefs.get(name, 0.0))
             for coefs in per_frame_coefs],
            dtype=np.float32,
        )
        if has_gaps:
            series = np.interp(frame_ids, frame_ids[detected], series[detected])
        padded = np.concatenate(
            [np.repeat(series[:1], half), series, np.repeat(series[-1:], half)]
        )
        # Direct-form FIR: accumulate one shifted, weighted copy per tap.
        acc = np.zeros_like(series)
        for offset, weight in enumerate(kernel):
            acc += weight * padded[offset: offset + n_frames]
        for frame, value in zip(smoothed, acc.tolist()):
            if frame is not None:
                frame[name] = value
    return smoothed
|
||||
|
||||
|
||||
def fill_detection_gaps(
    per_frame_coefs: List[Optional[Dict[str, float]]],
    method: str = "interpolate",
    max_gap: int = 12,
) -> List[Optional[Dict[str, float]]]:
    """Fill missing per-frame dicts so the signal doesn't slam to zero at
    undetected frames. method: 'interpolate' | 'hold' | 'zeros'.

    A gap is a run of consecutive None frames.

    * 'interpolate': linear blend between the detections on either side of
      the gap (a key missing on one side counts as 0.0). Leading / trailing
      gaps have only one neighbour and copy it.
    * 'hold': copy the detection before the gap (or the one after it for a
      leading gap).
    * 'zeros': replace every None frame with an all-zero dict over every
      name seen in the clip.

    `max_gap` applies to 'interpolate' and 'hold': a gap longer than
    `max_gap` frames is left entirely None (don't fake too far). 'zeros'
    ignores `max_gap` on purpose: the goal there is to relax to neutral on
    every miss, no matter how long, otherwise long undetected runs would
    inherit Predict's per-frame expression.

    Any other `method` returns a shallow copy of the input. Detected frames
    are passed through as the same dict objects; filled frames are new dicts.
    """
    if method == "zeros":
        names: set = set()
        for c in per_frame_coefs:
            if c is not None:
                names.update(c.keys())
        zero = {n: 0.0 for n in names}
        return [dict(zero) if c is None else c for c in per_frame_coefs]

    out: List[Optional[Dict[str, float]]] = list(per_frame_coefs)
    if method not in ("interpolate", "hold"):
        return out

    n_frames = len(per_frame_coefs)
    fi = 0
    # Single pass over runs of None frames: O(N) instead of rescanning for
    # the nearest detections from every missing frame.
    while fi < n_frames:
        if per_frame_coefs[fi] is not None:
            fi += 1
            continue
        start = fi
        while fi < n_frames and per_frame_coefs[fi] is None:
            fi += 1
        # The gap covers frames [start, fi); its neighbours are start-1 and fi.
        prev_i = start - 1 if start > 0 else None
        next_i = fi if fi < n_frames else None
        if prev_i is None and next_i is None:
            break  # no detection anywhere in the clip
        if fi - start > max_gap:
            continue  # too long to fake: leave the whole gap None

        for gi in range(start, fi):
            if prev_i is None:
                out[gi] = dict(per_frame_coefs[next_i])
            elif next_i is None or method == "hold":
                out[gi] = dict(per_frame_coefs[prev_i])
            else:
                w = (gi - prev_i) / (next_i - prev_i)
                a = per_frame_coefs[prev_i]
                b = per_frame_coefs[next_i]
                keys = set(a.keys()) | set(b.keys())
                out[gi] = {
                    k: (1.0 - w) * a.get(k, 0.0) + w * b.get(k, 0.0)
                    for k in keys
                }
    return out
|
||||
|
||||
|
||||
def detect_faces_in_crop(
    inner, image_rgb_uint8: np.ndarray, crop_xyxy: np.ndarray, num_faces: int = 1,
) -> List[dict]:
    """Run face detection on a sub-region and report results in full-image
    coordinates. Helps small/distant faces that are too small to detect in
    the full frame.

    `crop_xyxy` is rounded to integers and clipped to the image. When the
    clipped crop is under 16 px on either side, no detection is run and an
    empty list is returned. Otherwise `inner.face_landmarker.detect_batch`
    is called on the crop, and each returned face has the crop origin added
    to its `bbox_xyxy` and `landmarks_xy`.
    """
    img_h, img_w = image_rgb_uint8.shape[:2]
    left, top, right, bottom = (int(round(float(v))) for v in crop_xyxy)
    left = max(0, left)
    top = max(0, top)
    right = min(img_w, right)
    bottom = min(img_h, bottom)
    if min(right - left, bottom - top) < 16:
        return []

    patch = np.ascontiguousarray(image_rgb_uint8[top:bottom, left:right])
    detections = inner.face_landmarker.detect_batch([patch], num_faces=num_faces)[0]

    origin_xy = np.array([left, top], dtype=np.float32)
    origin_xyxy = np.array([left, top, left, top], dtype=np.float32)
    for det in detections:
        det["bbox_xyxy"] = det["bbox_xyxy"] + origin_xyxy
        det["landmarks_xy"] = det["landmarks_xy"] + origin_xy
    return detections
|
||||
|
||||
|
||||
# Crop helpers — feed MP a tight head region so it doesn't downsample the face
|
||||
# to 192px for full-frame detection.
|
||||
|
||||
def _expand_bbox(bbox_xyxy: np.ndarray, factor: float, W: int, H: int) -> np.ndarray:
    """Scale an xyxy box about its centre by `factor`, then clamp it to the
    image rectangle [0, W] x [0, H]. Returns a (4,) float32 array."""
    left, top, right, bottom = (float(v) for v in bbox_xyxy)
    centre_x = 0.5 * (left + right)
    centre_y = 0.5 * (top + bottom)
    half_w = 0.5 * (right - left) * factor
    half_h = 0.5 * (bottom - top) * factor
    grown = [
        max(0.0, centre_x - half_w),
        max(0.0, centre_y - half_h),
        min(float(W), centre_x + half_w),
        min(float(H), centre_y + half_h),
    ]
    return np.array(grown, dtype=np.float32)
|
||||
|
||||
|
||||
def head_region_crop(
    person_bbox: np.ndarray, expand: float, W: int, H: int, head_h_frac: float = 0.4,
) -> np.ndarray:
    """Return an expanded crop around the top `head_h_frac` of a body bbox.

    Cropping the whole body would spend most of the face detector's input on
    body pixels. The upper band of the box is scaled by `expand` about its
    centre and clamped to the image via `_expand_bbox`. A box with
    non-positive width or height yields an all-zero (4,) float32 box.
    """
    left, top, right, bottom = (float(v) for v in person_bbox)
    body_height = bottom - top
    body_width = right - left
    if body_height <= 0 or body_width <= 0:
        return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    head_band = np.array([left, top, right, top + body_height * head_h_frac])
    return _expand_bbox(head_band, expand, W, H)
|
||||
|
||||
|
||||
# mhr70 convention: first five kp are COCO-style face landmarks in pixel coords.
_FACE_KP_INDICES = (0, 1, 2, 3, 4)  # nose, L-eye, R-eye, L-ear, R-ear


def head_crop_from_keypoints(
    pred_keypoints_2d: np.ndarray, expand: float, W: int, H: int,
) -> Optional[np.ndarray]:
    """Build a square head crop from the nose/eye/ear keypoints.

    Only keypoints strictly inside the image (0 < x < W, 0 < y < H) are
    used. The crop is centred on their bounding box; its half-size is
    `0.5 * span * 1.8 * expand`, where `span` is the larger box side (at
    least 1 px) and 1.8 pads forehead and chin, which these landmarks don't
    cover. The result is clamped to the image.

    Returns a (4,) float32 xyxy box, or None when the input is None, is not
    2-D, has too few rows, or has fewer than two in-frame face keypoints.
    """
    if pred_keypoints_2d is None:
        return None
    kp = np.asarray(pred_keypoints_2d, dtype=np.float32)
    if kp.ndim != 2 or kp.shape[0] <= max(_FACE_KP_INDICES):
        return None

    face = kp[list(_FACE_KP_INDICES), :2]
    xs, ys = face[:, 0], face[:, 1]
    visible = face[(xs > 0) & (ys > 0) & (xs < W) & (ys < H)]
    if len(visible) < 2:
        return None

    lo = visible.min(axis=0)
    hi = visible.max(axis=0)
    x_min, y_min = float(lo[0]), float(lo[1])
    x_max, y_max = float(hi[0]), float(hi[1])
    centre_x = 0.5 * (x_min + x_max)
    centre_y = 0.5 * (y_min + y_max)
    span = max(x_max - x_min, y_max - y_min, 1.0)
    half = 0.5 * span * 1.8 * float(expand)  # 1.8 = pad forehead+chin
    box = [
        max(0.0, centre_x - half),
        max(0.0, centre_y - half),
        min(float(W), centre_x + half),
        min(float(H), centre_y + half),
    ]
    return np.array(box, dtype=np.float32)
|
||||
|
||||
|
||||
# Face → person assignment when running full-frame detection.
|
||||
|
||||
def _iou_xyxy(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two xyxy boxes.

    Returns 0.0 when the boxes do not overlap with positive area or when the
    union is not positive.
    """
    overlap_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    overlap_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    overlap = overlap_w * overlap_h
    if overlap <= 0.0:
        return 0.0

    area_a = max(0.0, a[2] - a[0]) * max(0.0, a[3] - a[1])
    area_b = max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
    union = area_a + area_b - overlap
    if union > 0.0:
        return float(overlap / union)
    return 0.0
|
||||
|
||||
|
||||
def assign_faces_to_persons(
    face_bboxes: List[np.ndarray], person_bboxes: List[np.ndarray], min_iou: float = 0.01,
) -> List[Optional[int]]:
    """Greedily match detected faces to persons.

    Persons are visited from largest to smallest bbox area (a bigger bbox
    correlates with a detectable face). Each person takes the still-unused
    face with the best score, where score = max(IoU(face, person), 0.5 if
    the face centre lies inside the person box else 0). A face is only
    assigned when its score is strictly greater than `min_iou`, and each
    face is used at most once.

    Returns one entry per person: the index into `face_bboxes`, or None.
    """
    n_persons = len(person_bboxes)
    result: List[Optional[int]] = [None] * n_persons
    if not face_bboxes or not person_bboxes:
        return result

    def _negated_area(person_idx: int):
        box = person_bboxes[person_idx]
        return -((box[2] - box[0]) * (box[3] - box[1]))

    # Face centres don't depend on the person, so compute them once.
    face_centres = [
        (0.5 * (fb[0] + fb[2]), 0.5 * (fb[1] + fb[3])) for fb in face_bboxes
    ]
    taken: set = set()

    for person_idx in sorted(range(n_persons), key=_negated_area):
        pbox = person_bboxes[person_idx]
        winner = None
        winner_score = min_iou
        for face_idx, fbox in enumerate(face_bboxes):
            if face_idx in taken:
                continue
            fx, fy = face_centres[face_idx]
            centre_inside = (pbox[0] <= fx <= pbox[2]) and (pbox[1] <= fy <= pbox[3])
            score = max(_iou_xyxy(fbox, pbox), 0.5 if centre_inside else 0.0)
            if score > winner_score:
                winner_score = score
                winner = face_idx
        if winner is not None:
            result[person_idx] = winner
            taken.add(winner)
    return result
|
||||
|
||||
|
||||
# Re-run MHR forward after writing expr_params back into pose_frames; updates
|
||||
# pred_vertices / pred_keypoints_2d/3d / pred_joint_coords / pred_global_rots.
|
||||
|
||||
def _principal_point_from_person(person: Dict[str, Any]) -> Optional[Tuple[float, float]]:
    """Back out the principal point (cx, cy) implied by a person's stored
    keypoint 0, or return None when the stored keypoints are unusable.

    Inverts `u = f * x / z + cx` using `pred_keypoints_2d[0]`,
    `pred_keypoints_3d[0] + pred_cam_t` and the first `focal_length` element.
    """
    kp2d = person.get("pred_keypoints_2d")
    kp3d = person.get("pred_keypoints_3d")
    if kp2d is None or kp3d is None:
        return None
    kp2d = np.asarray(kp2d, dtype=np.float32)
    kp3d = np.asarray(kp3d, dtype=np.float32)
    if kp2d.ndim != 2 or kp3d.ndim != 2 or len(kp2d) == 0 or len(kp3d) == 0:
        return None
    cam_t = np.asarray(person["pred_cam_t"], dtype=np.float32)
    focal = float(np.asarray(person["focal_length"]).reshape(-1)[0])
    x, y, z = kp3d[0] + cam_t
    depth = max(z, 1e-6)
    return (
        float(kp2d[0, 0] - focal * x / depth),
        float(kp2d[0, 1] - focal * y / depth),
    )


def regenerate_mesh_from_params(inner, pose_frames: List[List[Dict[str, Any]]]) -> None:
    """Re-run MHR forward and write verts/kp3d/kp2d/joint data back in place.

    Drives MHR via euler params directly because hand refinement zeroes
    pred_pose_raw (per the original author's note).

    For every person slot `pid`, the frames in which that person has all
    required parameters are batched through `inner.head_pose.mhr_forward`.
    Each such frame's person dict is replaced by a shallow copy carrying
    fresh `pred_vertices`, `pred_keypoints_3d`, `pred_keypoints_2d`,
    `pred_face_keypoints_3d`, `pred_face_keypoints_2d`, `pred_joint_coords`
    and `pred_global_rots`. Frames where the person is absent or is missing
    a required parameter are left untouched. Does nothing when
    `inner.head_pose.mhr` is None.

    2D keypoints are reprojected with each frame's own `pred_cam_t` and
    `focal_length`. The principal point is recovered once per person from
    the first frame whose stored 2D/3D keypoints are usable, and defaults
    to (0, 0) when no frame has them.
    """
    device = comfy.model_management.get_torch_device()
    head = inner.head_pose
    if head.mhr is None:
        return

    needed = ("global_rot", "body_pose_params", "hand_pose_params",
              "shape_params", "scale_params", "expr_params",
              "pred_cam_t", "focal_length")

    B = len(pose_frames)
    max_p = max((len(f) for f in pose_frames), default=0)

    for pid in range(max_p):
        grots, bpps, hands, shapes, scales, exprs, cam_ts, fls = [], [], [], [], [], [], [], []
        present: List[bool] = []
        for fi in range(B):
            if pid >= len(pose_frames[fi]):
                present.append(False)
                continue
            p = pose_frames[fi][pid]
            if any(p.get(k) is None for k in needed):
                present.append(False)
                continue
            grots.append(np.asarray(p["global_rot"], dtype=np.float32))
            bpps.append(np.asarray(p["body_pose_params"], dtype=np.float32))
            hands.append(np.asarray(p["hand_pose_params"], dtype=np.float32))
            shapes.append(np.asarray(p["shape_params"], dtype=np.float32))
            scales.append(np.asarray(p["scale_params"], dtype=np.float32))
            exprs.append(np.asarray(p["expr_params"], dtype=np.float32))
            cam_ts.append(np.asarray(p["pred_cam_t"], dtype=np.float32))
            fls.append(float(np.asarray(p["focal_length"]).reshape(-1)[0]))
            present.append(True)
        if not any(present):
            continue

        global_rot_euler = torch.from_numpy(np.stack(grots)).to(device)
        body_pose_euler = torch.from_numpy(np.stack(bpps)).to(device)
        hand_t = torch.from_numpy(np.stack(hands)).to(device)
        shape_t = torch.from_numpy(np.stack(shapes)).to(device)
        scale_t = torch.from_numpy(np.stack(scales)).to(device)
        expr_t = torch.from_numpy(np.stack(exprs)).to(device)
        cam_t_t = torch.from_numpy(np.stack(cam_ts)).to(device)
        f_t = torch.tensor(fls, device=device, dtype=torch.float32)

        verts, kp3d_full, joint_coords, _, joint_rotmats = head.mhr_forward(
            global_trans=torch.zeros_like(global_rot_euler),
            global_rot=global_rot_euler,
            body_pose_params=body_pose_euler,
            hand_pose_params=hand_t,
            scale_params=scale_t,
            shape_params=shape_t,
            expr_params=expr_t,
            return_keypoints=True,
            return_joint_coords=True,
            return_model_params=True,
            return_joint_rotations=True,
        )
        # y/z flip matches head_pose.forward (camera-y-down convention).
        verts = verts.clone()
        verts[..., [1, 2]] *= -1
        kp3d = kp3d_full[:, :70].clone()
        kp3d[..., [1, 2]] *= -1
        # Face landmarks (index 70 onward) - track retargeted expression
        # so openpose face dots follow new mouth/eye/brow shape.
        kp3d_face = kp3d_full[:, 70:].clone()
        kp3d_face[..., [1, 2]] *= -1
        joint_coords = joint_coords.clone()
        joint_coords[..., [1, 2]] *= -1

        # Recover principal point from the first frame that still carries
        # usable stored keypoints. These keys are not part of `needed`, so a
        # frame without them is skipped here instead of raising.
        cx = cy = 0.0
        for fi in range(B):
            if not present[fi]:
                continue
            recovered = _principal_point_from_person(pose_frames[fi][pid])
            if recovered is not None:
                cx, cy = recovered
                break

        def _project_kp(kp3d_local: torch.Tensor) -> torch.Tensor:
            kp3d_cam = kp3d_local + cam_t_t.unsqueeze(1)
            u = f_t[:, None] * kp3d_cam[..., 0] / kp3d_cam[..., 2].clamp(min=1e-6) + cx
            v = f_t[:, None] * kp3d_cam[..., 1] / kp3d_cam[..., 2].clamp(min=1e-6) + cy
            return torch.stack([u, v], dim=-1)

        kp2d = _project_kp(kp3d)
        kp2d_face = _project_kp(kp3d_face)

        verts_np = verts.float().cpu().numpy()
        kp3d_np = kp3d.float().cpu().numpy()
        kp2d_np = kp2d.float().cpu().numpy()
        kp3d_face_np = kp3d_face.float().cpu().numpy()
        kp2d_face_np = kp2d_face.float().cpu().numpy()
        jc_np = joint_coords.float().cpu().numpy()
        jrot_np = joint_rotmats.float().cpu().numpy()

        # Batch row i corresponds to the i-th frame with present[fi] == True.
        fi_active = 0
        for fi in range(B):
            if not present[fi]:
                continue
            # Shallow-copy so dicts shared with other holders aren't mutated.
            pose_frames[fi][pid] = dict(pose_frames[fi][pid])
            p = pose_frames[fi][pid]
            p["pred_vertices"] = verts_np[fi_active]
            p["pred_keypoints_3d"] = kp3d_np[fi_active]
            p["pred_keypoints_2d"] = kp2d_np[fi_active]
            p["pred_face_keypoints_3d"] = kp3d_face_np[fi_active]
            p["pred_face_keypoints_2d"] = kp2d_face_np[fi_active]
            p["pred_joint_coords"] = jc_np[fi_active]
            p["pred_global_rots"] = jrot_np[fi_active]
            fi_active += 1
|
||||
467
comfy_extras/sam3d_body/rasterizer.py
Normal file
467
comfy_extras/sam3d_body/rasterizer.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Pure-PyTorch rasterizer for SAM 3D Body meshes.
|
||||
|
||||
Algorithm: forward triangle rasterizer with hard z-buffer. Per-face screen
|
||||
bbox cull → faces sorted by bbox size and chunked under a fixed pixel
|
||||
budget → inside-test via edge functions, barycentric interpolation, depth
|
||||
test via `scatter_reduce_(amin)`.
|
||||
"""
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
from .utils import jet_colormap
|
||||
|
||||
# Shader preset names grouped as "canonical".
# NOTE(review): presumably the presets coloured from canonical (T-pose)
# vertex positions - confirm against the code that reads this set.
_CANONICAL_PRESETS = {"rainbow", "rainbow_face_normal", "rainbow_face_semantic"}
|
||||
|
||||
# (id(positions), tilt_x, tilt_z) -> (positions, colors). The positions array
# is stored with its colors so its id() cannot be recycled for a different
# array while the entry is alive.
_rainbow_cache: dict = {}


def rainbow_colors_from_canonical(
    positions: np.ndarray,
    tilt_x_deg: float = 0.0,
    tilt_z_deg: float = 0.0,
) -> np.ndarray:
    """Compute per-vertex jet-colormap RGB from canonical (T-pose, Y-up) vertices.

    Vertices are projected onto a ramp axis (Y-up by default, tilted by the
    two angles), the projection is normalised to [0, 1] over the mesh and
    scaled by 0.98, then mapped through `jet_colormap`.

    Results are cached per (array identity, rounded tilts); the cache keeps
    at most 32 entries and evicts the oldest insertion first. Identity
    caching means in-place edits of `positions` are not detected.

    Args:
        positions: (N_v, 3) canonical vertex positions, Y-up (head at max Y).
        tilt_x_deg: rotation of the jet axis around X (in degrees). Positive
            biases the ramp toward +Z (front).
        tilt_z_deg: rotation of the jet axis around Z (in degrees). Positive
            biases the ramp toward +X (right, in body frame).

    Returns:
        The `jet_colormap` output for the normalised ramp (documented
        upstream as (N_v, 3) float32 RGB in [0, 1]).
    """
    key = (id(positions), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3))
    entry = _rainbow_cache.get(key)
    # Only a hit if the pinned array is the very object passed in; a bare
    # id() match could belong to an unrelated array that reused the address.
    if entry is not None and entry[0] is positions:
        return entry[1]

    theta_x = np.deg2rad(tilt_x_deg)
    theta_z = np.deg2rad(tilt_z_deg)
    axis = np.array([
        np.sin(theta_z),
        np.cos(theta_z) * np.cos(theta_x),
        np.cos(theta_z) * np.sin(theta_x),
    ], dtype=np.float32)

    s = positions @ axis
    # The 1e-8 floor guards the divide when every vertex projects to one value.
    s = (s - s.min()) / max(float(s.max() - s.min()), 1e-8)
    s = np.clip(s * 0.98, 0.0, 1.0).astype(np.float32)
    colors = jet_colormap(s)
    _rainbow_cache[key] = (positions, colors)
    if len(_rainbow_cache) > 32:
        # dicts preserve insertion order, so this drops the oldest entry.
        _rainbow_cache.pop(next(iter(_rainbow_cache)))
    return colors
|
||||
|
||||
|
||||
def _vertex_normals(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
|
||||
"""Area-weighted per-vertex normals; matches `_compute_vertex_normals`."""
|
||||
v0 = verts[faces[:, 0]]
|
||||
v1 = verts[faces[:, 1]]
|
||||
v2 = verts[faces[:, 2]]
|
||||
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
vn = torch.zeros_like(verts)
|
||||
vn.index_add_(0, faces[:, 0], fn)
|
||||
vn.index_add_(0, faces[:, 1], fn)
|
||||
vn.index_add_(0, faces[:, 2], fn)
|
||||
return vn / vn.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
|
||||
|
||||
def _build_vcolor(canonical_colors, shader_preset, tilt_x, tilt_z):
    """Build the per-vertex RGB table used by the canonical shader presets.

    Starts from the rainbow ramp over the canonical positions. For the two
    `rainbow_face_*` presets, and only when `face_mask` is present and
    non-empty, part of the table is overwritten:

    * "rainbow_face_normal": verts selected by `face_mask` take the
      canonical `norm` colors.
    * "rainbow_face_semantic": verts whose `face_region_rgb` row sums to
      more than 0 take that painted color.

    Per the original note this mirrors the canonical_colors -> per-vertex RGB
    pipeline in `rasterizer.render_pose_data`.

    Returns:
        A fresh numpy (V, 3) table (copied, so any cached rainbow array is
        left untouched).
    """
    rest_positions = np.asarray(canonical_colors.get("positions"), dtype=np.float32)
    table = rainbow_colors_from_canonical(
        rest_positions, tilt_x_deg=tilt_x, tilt_z_deg=tilt_z
    ).copy()

    if shader_preset not in ("rainbow_face_normal", "rainbow_face_semantic"):
        return table

    face_mask = canonical_colors.get("face_mask")
    if face_mask is None or not face_mask.any():
        return table

    if shader_preset == "rainbow_face_normal":
        normal_rgb = np.asarray(canonical_colors["norm"], dtype=np.float32)
        table[face_mask] = normal_rgb[face_mask]
    else:  # rainbow_face_semantic
        region_rgb = np.asarray(canonical_colors["face_region_rgb"], dtype=np.float32)
        painted = region_rgb.sum(axis=1) > 0
        table[painted] = region_rgb[painted]
    return table
|
||||
|
||||
|
||||
def _rasterize_chunk(
|
||||
fv_pix: torch.Tensor, # (Fc, 3, 2) — pixel coords (sub-pixel float)
|
||||
fv_z: torch.Tensor, # (Fc, 3) — image-frame z (smaller=closer)
|
||||
bb_min_x: torch.Tensor, bb_max_x: torch.Tensor, # (Fc,) clamped int bboxes
|
||||
bb_min_y: torch.Tensor, bb_max_y: torch.Tensor,
|
||||
max_sx: int, max_sy: int,
|
||||
W: int,
|
||||
):
|
||||
"""Rasterize a chunk of faces at pixel centers. Returns flat tensors of
|
||||
inside fragments: (pixel_idx, depth, face_local, bary).
|
||||
"""
|
||||
device = fv_pix.device
|
||||
if max_sx == 0 or max_sy == 0:
|
||||
return None
|
||||
|
||||
sx = bb_max_x - bb_min_x
|
||||
sy = bb_max_y - bb_min_y
|
||||
|
||||
px_off = torch.arange(max_sx, device=device)
|
||||
py_off = torch.arange(max_sy, device=device)
|
||||
|
||||
# Pixel-center sample positions, broadcast to (Fc, max_sy, max_sx).
|
||||
P_x = (bb_min_x[:, None, None] + px_off[None, None, :]).float() + 0.5
|
||||
P_y = (bb_min_y[:, None, None] + py_off[None, :, None]).float() + 0.5
|
||||
|
||||
in_bb = (px_off[None, None, :] < sx[:, None, None]) & \
|
||||
(py_off[None, :, None] < sy[:, None, None])
|
||||
|
||||
Ax = fv_pix[:, 0, 0][:, None, None]
|
||||
Ay = fv_pix[:, 0, 1][:, None, None]
|
||||
Bx = fv_pix[:, 1, 0][:, None, None]
|
||||
By = fv_pix[:, 1, 1][:, None, None]
|
||||
Cx = fv_pix[:, 2, 0][:, None, None]
|
||||
Cy = fv_pix[:, 2, 1][:, None, None]
|
||||
|
||||
area2 = (Bx - Ax) * (Cy - Ay) - (By - Ay) * (Cx - Ax) # (Fc, 1, 1)
|
||||
e_a = (Bx - P_x) * (Cy - P_y) - (By - P_y) * (Cx - P_x)
|
||||
e_b = (Cx - P_x) * (Ay - P_y) - (Cy - P_y) * (Ax - P_x)
|
||||
e_c = (Ax - P_x) * (By - P_y) - (Ay - P_y) * (Bx - P_x)
|
||||
|
||||
# Same-sign-as-area2 inside test (no back-face culling — match either winding).
|
||||
nondegen = area2.abs() > 1e-6 # threshold rejects near-degenerate triangles
|
||||
inside = (e_a * area2 >= 0) & (e_b * area2 >= 0) & (e_c * area2 >= 0)
|
||||
inside = inside & in_bb & nondegen
|
||||
|
||||
if not inside.any():
|
||||
return None
|
||||
|
||||
inv_a2 = torch.where(nondegen, 1.0 / area2, torch.zeros_like(area2))
|
||||
w_a = e_a * inv_a2
|
||||
w_b = e_b * inv_a2
|
||||
w_c = e_c * inv_a2
|
||||
|
||||
z_a = fv_z[:, 0, None, None]
|
||||
z_b = fv_z[:, 1, None, None]
|
||||
z_c = fv_z[:, 2, None, None]
|
||||
z_grid = w_a * z_a + w_b * z_b + w_c * z_c
|
||||
|
||||
fi, yi, xi = inside.nonzero(as_tuple=True)
|
||||
px_pixel = bb_min_x[fi] + xi
|
||||
py_pixel = bb_min_y[fi] + yi
|
||||
pixel_idx = (py_pixel * W + px_pixel).long()
|
||||
z_flat = z_grid[fi, yi, xi]
|
||||
bary_flat = torch.stack([w_a[fi, yi, xi], w_b[fi, yi, xi], w_c[fi, yi, xi]], dim=-1)
|
||||
return pixel_idx, z_flat, fi.long(), bary_flat
|
||||
|
||||
|
||||
def _rasterize_person(
    verts_world: torch.Tensor, faces: torch.Tensor,
    focal: float, W: int, H: int,
    z_buf: torch.Tensor, color_buf: torch.Tensor, mask_buf: torch.Tensor,
    shade_fn,
):
    """Rasterize one mesh into the shared frame buffers (updated in place).

    Vertices are pinhole-projected with the principal point at the image
    center. A face is drawn only if all three of its vertices lie beyond the
    0.05 near limit and its clamped screen bbox is non-empty. Faces are
    processed in chunks sorted by bbox size; each chunk's fragments are
    depth-tested against `z_buf` and the winners are shaded.

    Args:
        verts_world: (V, 3) image-frame vertex positions (z = depth).
        faces: (F, 3) integer vertex indices.
        focal: focal length in pixels.
        W, H: frame size in pixels.
        z_buf: flat depth buffer, indexed by `y * W + x`.
        color_buf: flat color buffer written with `shade_fn` output.
        mask_buf: flat boolean coverage buffer.
        shade_fn: callable `(face_idx, bary) -> colors` for winning fragments.
    """
    # Project to pixel coordinates; verts at/behind the near limit are flagged.
    near_z = 0.05
    depth = verts_world[:, 2]
    in_front = depth > near_z
    depth_safe = depth.clamp(min=near_z)
    screen_x = 0.5 * W + focal * verts_world[:, 0] / depth_safe
    screen_y = 0.5 * H + focal * verts_world[:, 1] / depth_safe

    face_pix = torch.stack([screen_x, screen_y], dim=-1)[faces]  # (F, 3, 2)
    face_z = verts_world[faces][..., 2]                          # (F, 3)
    face_in_front = in_front[faces].all(dim=-1)

    def pixel_span(coords, limit):
        # Integer [lo, hi) pixel range covering the face, clamped to the frame.
        lo = coords.amin(dim=-1).floor().long().clamp(min=0, max=limit)
        hi = (coords.amax(dim=-1).ceil().long() + 1).clamp(min=0, max=limit)
        return lo, hi

    bb_min_x, bb_max_x = pixel_span(face_pix[..., 0], W)
    bb_min_y, bb_max_y = pixel_span(face_pix[..., 1], H)
    span_x = bb_max_x - bb_min_x
    span_y = bb_max_y - bb_min_y

    drawable = torch.where(face_in_front & (span_x > 0) & (span_y > 0))[0]
    if drawable.numel() == 0:
        return

    # Order faces by their larger bbox side so each chunk holds similar sizes.
    largest_side = torch.maximum(span_x, span_y)[drawable]
    by_size = torch.argsort(largest_side)
    drawable = drawable[by_size]
    total = drawable.numel()
    span_x_host = span_x[drawable].tolist()
    span_y_host = span_y[drawable].tolist()
    side_host = largest_side[by_size].tolist()

    PIXEL_BUDGET = 4_000_000
    MAX_CHUNK = 8192

    start = 0
    while start < total:
        stop = min(start + MAX_CHUNK, total)
        # Shrink the chunk until (faces x largest bbox side^2) fits the budget;
        # a single oversized face is still processed on its own.
        while stop > start + 1:
            side = side_host[stop - 1]
            if (stop - start) * side * side <= PIXEL_BUDGET:
                break
            stop = max(start + 1, stop - max(1, (stop - start) // 4))
        chunk = drawable[start:stop]
        max_sx = max(span_x_host[start:stop])
        max_sy = max(span_y_host[start:stop])
        start = stop

        fragments = _rasterize_chunk(
            face_pix[chunk], face_z[chunk],
            bb_min_x[chunk], bb_max_x[chunk],
            bb_min_y[chunk], bb_max_y[chunk],
            max_sx, max_sy, W,
        )
        if fragments is None:
            continue
        pixel_idx, frag_z, face_local, bary = fragments
        face_global = chunk[face_local]

        # Depth test: fold the fragments into z_buf with a min-reduction, then
        # keep the fragments that equal the new value and strictly improved it.
        z_before = z_buf[pixel_idx].clone()
        z_buf.scatter_reduce_(0, pixel_idx, frag_z, reduce='amin', include_self=True)
        z_after = z_buf[pixel_idx]
        won = (frag_z == z_after) & (z_after < z_before)
        if not won.any():
            continue

        # Several fragments may tie for the same pixel. Stable-sort the
        # winners by pixel and keep the first of each run so every pixel is
        # shaded exactly once; the sort dominates the cost of this step.
        win_pixel = pixel_idx[won]
        win_face = face_global[won]
        win_bary = bary[won]
        perm = torch.argsort(win_pixel, stable=True)
        sorted_pixel = win_pixel[perm]
        run_start = torch.ones_like(sorted_pixel, dtype=torch.bool)
        run_start[1:] = sorted_pixel[1:] != sorted_pixel[:-1]
        chosen = perm[run_start]

        target = win_pixel[chosen]
        color_buf[target] = shade_fn(win_face[chosen], win_bary[chosen])
        mask_buf[target] = True
|
||||
|
||||
|
||||
def _make_shade_fn(
|
||||
shader_preset, composite,
|
||||
view_normals_v, view_pos_v, vcolor_v, faces,
|
||||
base_color, light_dir, pastel_mix,
|
||||
):
|
||||
device = view_normals_v.device
|
||||
base_color_t = torch.as_tensor(base_color, dtype=torch.float32, device=device)
|
||||
light_dir_t = torch.as_tensor(light_dir, dtype=torch.float32, device=device)
|
||||
# Light-vector constants — normalized once per render call.
|
||||
l_unit = -light_dir_t
|
||||
l_unit = l_unit / l_unit.norm().clamp(min=1e-8)
|
||||
|
||||
if pastel_mix <= 0.0:
|
||||
apply_pastel = lambda rgb: rgb
|
||||
else:
|
||||
pm = float(pastel_mix)
|
||||
apply_pastel = lambda rgb: rgb * (1.0 - pm) + pm
|
||||
|
||||
def gather_n(face_idx, bary):
|
||||
n = (view_normals_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
|
||||
return n / n.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
|
||||
def gather_pos(face_idx, bary):
|
||||
return (view_pos_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
|
||||
|
||||
def gather_color(face_idx, bary):
|
||||
return (vcolor_v[faces[face_idx]] * bary.unsqueeze(-1)).sum(dim=1)
|
||||
|
||||
if composite == "silhouette":
|
||||
return lambda fi, ba: torch.ones((fi.shape[0], 3), device=device)
|
||||
|
||||
if shader_preset == "normals":
|
||||
# View-space surface normal encoded as RGB (OpenGL Y+ convention).
|
||||
# +X right → R; +Y up → G; +Z toward viewer → B. Each face shows mostly
|
||||
# one channel, matching standard normal-map visualization.
|
||||
def shade(face_idx, bary):
|
||||
n = gather_n(face_idx, bary)
|
||||
return apply_pastel(((n + 1.0) * 0.5).clamp(0.0, 1.0))
|
||||
return shade
|
||||
|
||||
use_vcolor = (shader_preset in _CANONICAL_PRESETS) and (vcolor_v is not None)
|
||||
|
||||
if not use_vcolor:
|
||||
# default.frag: ambient + diffuse + rim
|
||||
def shade(face_idx, bary):
|
||||
n = gather_n(face_idx, bary)
|
||||
v = -gather_pos(face_idx, bary)
|
||||
v = v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
|
||||
ndotv = (n * v).sum(dim=-1).clamp(min=0)
|
||||
rim = (1.0 - ndotv).pow(3.0)
|
||||
lit = 0.25 * base_color_t \
|
||||
+ 0.75 * base_color_t * ndotl.unsqueeze(-1) \
|
||||
+ 0.35 * rim.unsqueeze(-1)
|
||||
return apply_pastel(lit)
|
||||
return shade
|
||||
|
||||
if shader_preset == "rainbow":
|
||||
def shade(face_idx, bary):
|
||||
base = gather_color(face_idx, bary)
|
||||
n = gather_n(face_idx, bary)
|
||||
ndotl = (n * l_unit).sum(dim=-1).clamp(min=0)
|
||||
return apply_pastel(base * (0.65 + 0.35 * ndotl).unsqueeze(-1))
|
||||
return shade
|
||||
|
||||
# rainbow_face_* → rainbow_lit.frag. All light-direction & half-vector
|
||||
# constants depend only on the (constant) light_dir, so precompute them.
|
||||
key_l = l_unit
|
||||
fill_l = torch.stack([-key_l[0], key_l[1].abs(), -key_l[2]])
|
||||
view_dir = torch.tensor([0.0, 0.0, 1.0], device=device)
|
||||
h = key_l + view_dir
|
||||
h = h / h.norm().clamp(min=1e-8)
|
||||
|
||||
def shade(face_idx, bary):
|
||||
base = gather_color(face_idx, bary)
|
||||
n = gather_n(face_idx, bary)
|
||||
key_ndotl = (n * key_l).sum(dim=-1).clamp(min=0)
|
||||
fill_ndotl = (n * fill_l).sum(dim=-1).clamp(min=0)
|
||||
rim = (1.0 - n[..., 2].clamp(min=0)).pow(2.5) * 0.30
|
||||
shade_val = (0.45 + 0.45 * key_ndotl + 0.15 * fill_ndotl + rim * 0.5).clamp(min=0.0, max=1.25)
|
||||
ndoth = (n * h).sum(dim=-1).clamp(min=0)
|
||||
spec = ndoth.pow(48) * 0.12
|
||||
lit = base * shade_val.unsqueeze(-1) + spec.unsqueeze(-1)
|
||||
return apply_pastel(lit)
|
||||
return shade
|
||||
|
||||
|
||||
def render_pose_data_torch(
    pose_data: dict,
    frame_idx: int,
    W: int,
    H: int,
    background=None,  # Optional[np.ndarray | torch.Tensor] (H, W, 3) fp32 [0, 1]
    composite: str = "over",
    opacity: float = 1.0,
    shader_preset: str = "default",
    base_color: Sequence[float] = (0.68, 0.71, 0.78),
    light_dir: Sequence[float] = (0.4, -0.7, -0.6),
    rainbow_tilt_x_deg: float = 0.0,
    rainbow_tilt_z_deg: float = 0.0,
    person_brightness_falloff: float = 0.6,
) -> torch.Tensor:
    """Render every person of one frame of `pose_data` at resolution WxH.

    Reads `pose_data["frames"][frame_idx]` (a list of person dicts carrying
    `pred_vertices`, `pred_cam_t` and optionally `focal_length`),
    `pose_data["faces"]` and, for the canonical presets, the optional
    `pose_data["canonical_colors"]`. Canonical presets fall back to "default"
    shading when the color table is missing or its vertex count does not match
    a person's mesh.

    With `composite == "over"` and a background, the render is blended over
    the background using the coverage mask scaled by `opacity`. A frame with
    no persons (or an out-of-range `frame_idx`) returns the clamped
    background in that mode, and zeros otherwise.

    Returns an (H, W, 3) float32 torch.Tensor on the comfy compute device,
    ready to be stacked into the node's IMAGE output without a CPU round-trip.
    """
    device = comfy.model_management.get_torch_device()
    blend_over_background = composite == "over" and background is not None

    def load_background():
        # Accepts either a numpy array or a tensor; yields fp32 on `device`.
        if isinstance(background, np.ndarray):
            return torch.as_tensor(background, dtype=torch.float32, device=device)
        return background.to(device=device, dtype=torch.float32)

    all_frames = pose_data["frames"]
    persons = all_frames[frame_idx] if frame_idx < len(all_frames) else []
    if len(persons) == 0:
        if blend_over_background:
            return load_background().clamp(0.0, 1.0)
        return torch.zeros((H, W, 3), device=device, dtype=torch.float32)

    faces = torch.as_tensor(np.asarray(pose_data["faces"], dtype=np.int64), device=device)

    canonical_colors = pose_data.get("canonical_colors")
    if shader_preset in _CANONICAL_PRESETS and canonical_colors is None:
        # No color table shipped with this pose data: use plain lighting.
        shader_preset = "default"

    vcolor = None
    if shader_preset in _CANONICAL_PRESETS:
        vcolor = torch.as_tensor(
            _build_vcolor(canonical_colors, shader_preset, rainbow_tilt_x_deg, rainbow_tilt_z_deg),
            dtype=torch.float32, device=device,
        )

    # Person 0 keeps full saturation; person k is blended toward white by
    # 1 - falloff**k (k is the index in `persons`, not the depth rank).
    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
    person_pastel = [0.0 if k == 0 else (1.0 - falloff ** k) for k in range(len(persons))]

    def camera_depth(i):
        return float(np.asarray(persons[i]["pred_cam_t"]).reshape(-1)[2])

    # Descending camera z, i.e. the farthest person is drawn first. Visibility
    # itself is resolved per pixel by the z-buffer.
    draw_order = sorted(range(len(persons)), key=lambda i: -camera_depth(i))

    HW = H * W
    z_buf = torch.full((HW,), float('inf'), device=device, dtype=torch.float32)
    color_buf = torch.zeros((HW, 3), device=device, dtype=torch.float32)
    mask_buf = torch.zeros(HW, device=device, dtype=torch.bool)

    for idx in draw_order:
        person = persons[idx]
        local_verts = np.asarray(person["pred_vertices"], dtype=np.float32).reshape(-1, 3)
        cam_t = np.asarray(person["pred_cam_t"], dtype=np.float32).reshape(3)
        verts_world = torch.as_tensor(local_verts + cam_t[None, :],
                                      device=device, dtype=torch.float32)
        focal = float(np.asarray(person.get("focal_length", 5000.0)).reshape(-1)[0])

        # Image frame (+Y down, +Z forward) -> view space (+Y up, -Z forward),
        # which is what the GL-style shading math expects.
        view_pos_v = torch.stack(
            [verts_world[:, 0], -verts_world[:, 1], -verts_world[:, 2]], dim=-1,
        )
        normals_world = _vertex_normals(verts_world, faces)
        view_normals_v = torch.stack(
            [normals_world[:, 0], -normals_world[:, 1], -normals_world[:, 2]], dim=-1,
        )

        # The color table is only usable if it has one row per mesh vertex.
        table_fits = vcolor is not None and vcolor.shape[0] == verts_world.shape[0]
        vcolor_p = vcolor if table_fits else None
        # Only the canonical presets need vertex colors; the geometric ones
        # ('normals', 'depth') and the lit default work without them.
        needs_fallback = shader_preset in _CANONICAL_PRESETS and vcolor_p is None
        effective_preset = "default" if needs_fallback else shader_preset

        shade_fn = _make_shade_fn(
            effective_preset, composite,
            view_normals_v, view_pos_v, vcolor_p, faces,
            base_color, light_dir, person_pastel[idx],
        )

        _rasterize_person(
            verts_world, faces, focal, W, H,
            z_buf, color_buf, mask_buf, shade_fn,
        )

    # Readback and compositing stay on the compute device.
    coverage = mask_buf.reshape(H, W)
    if shader_preset == "depth":
        # z_buf holds linear image-frame z (smaller = closer, +inf where no
        # mesh landed). Normalize over the covered range: near = white,
        # far = black, uncovered = black.
        z_2d = z_buf.reshape(H, W)
        if coverage.any():
            covered_z = z_2d[coverage]
            zmin = covered_z.min()
            zrange = (covered_z.max() - zmin).clamp(min=1e-6)
            norm = torch.where(coverage, 1.0 - (z_2d - zmin) / zrange, torch.zeros_like(z_2d))
        else:
            norm = torch.zeros((H, W), device=device, dtype=torch.float32)
        rendered = torch.stack([norm, norm, norm], dim=-1)
        mask_f = coverage.float()
    else:
        rendered = color_buf.reshape(H, W, 3).clamp(0.0, 1.0)
        mask_f = coverage.float()

    if blend_over_background:
        bg = load_background()
        alpha = mask_f.unsqueeze(-1)
        if opacity != 1.0:
            alpha = alpha * float(opacity)
        rendered = torch.lerp(bg, rendered, alpha)

    return rendered
|
||||
397
comfy_extras/sam3d_body/utils.py
Normal file
397
comfy_extras/sam3d_body/utils.py
Normal file
@ -0,0 +1,397 @@
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from comfy.ldm.sam3d_body.utils import prepare_batch
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
from comfy.ldm.sam3d_body.model.model import SAM3DBody
|
||||
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import io
|
||||
import comfy.utils
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""xyxy bounds of a binary mask, with sub-5px speckles filtered out."""
|
||||
m = mask[..., 0] if mask.dim() == 3 else mask
|
||||
m_bool = m > 0
|
||||
if not m_bool.any():
|
||||
return None
|
||||
t = m_bool.to(torch.float32)[None, None]
|
||||
eroded = -F.max_pool2d(-t, kernel_size=5, stride=1, padding=2)
|
||||
keep = eroded[0, 0] > 0.5
|
||||
if not keep.any():
|
||||
keep = m_bool
|
||||
ys, xs = torch.where(keep)
|
||||
return torch.stack([
|
||||
xs.min().float(), ys.min().float(),
|
||||
(xs.max() + 1).float(), (ys.max() + 1).float(),
|
||||
])
|
||||
|
||||
|
||||
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
    """Turn SAM3_TRACK_DATA into per-frame, per-object bboxes and masks.

    The packed masks are unpacked, bilinearly resized to (H, W), thresholded
    at 0.5 and moved to the CPU. Each object's bbox comes from
    `_bbox_from_mask`; an object with an empty mask gets the full-frame box
    `[0, 0, W, H]`.

    Returns:
        `(per_frame_bboxes, per_frame_masks)`: two lists with one entry per
        frame, a (K, 4) float tensor and a (K, H, W, 1) uint8 tensor. Returns
        `(None, None)` when `track_data` is not a dict, has no
        "packed_masks", tracks zero objects, or its frame count differs
        from `B`.
    """
    packed = track_data.get("packed_masks") if isinstance(track_data, dict) else None
    if packed is None:
        return None, None

    unpacked = unpack_masks(packed)  # (N, K, Hm, Wm) bool
    n_frames, n_objs = unpacked.shape[:2]
    if n_frames != B or n_objs == 0:
        return None, None

    src_h, src_w = unpacked.shape[2], unpacked.shape[3]
    resized = F.interpolate(
        unpacked.float().reshape(n_frames * n_objs, 1, src_h, src_w),
        size=(H, W), mode="bilinear", align_corners=False,
    )
    binary = (resized > 0.5).to(torch.uint8).reshape(n_frames, n_objs, H, W).cpu()

    per_frame_masks = [binary[f, :, :, :, None].contiguous() for f in range(n_frames)]

    full_frame_bbox = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)

    def box_or_full_frame(obj_mask):
        box = _bbox_from_mask(obj_mask)
        return full_frame_bbox if box is None else box

    per_frame_bboxes = [
        torch.stack([box_or_full_frame(binary[f, k]) for k in range(n_objs)], dim=0)
        for f in range(n_frames)
    ]
    return per_frame_bboxes, per_frame_masks
|
||||
|
||||
|
||||
# Soft budget for the batched Predict path: default cap on person-crops
# (frames x persons per frame) per chunk. `run_batched_frames` uses it as the
# default `crops_per_chunk` and always processes at least one frame per chunk.
BATCHED_CROPS_PER_CHUNK = 64
|
||||
|
||||
|
||||
def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[torch.Tensor]:
    """Build a (1, 3, 3) pinhole intrinsic matrix from a vertical FOV in degrees.

    The focal length is `height / (2 * tan(fov / 2))` and is used for both
    axes (per the original note, this matches MoGe2's convention). The
    principal point is the image center.

    Returns:
        float32 tensor of shape (1, 3, 3), or None when `fov_degrees <= 0` so
        the caller can fall back to prepare_batch's diagonal-focal default.
    """
    if fov_degrees <= 0:
        return None

    half_angle = math.radians(fov_degrees) / 2.0
    focal = height / (2.0 * math.tan(half_angle))
    center_x, center_y = width / 2.0, height / 2.0
    matrix = [
        [focal, 0.0, center_x],
        [0.0, focal, center_y],
        [0.0, 0.0, 1.0],
    ]
    return torch.tensor([matrix], dtype=torch.float32)
|
||||
|
||||
|
||||
def cam_int_from_moge(moge_geometry, height: int, width: int) -> Optional[torch.Tensor]:
    """Build a (1, 3, 3) intrinsic matrix from a MoGe geometry payload.

    Reads the normalized `fy` entry (`[1, 1]`) of the payload's "intrinsics"
    matrix, taking the first matrix if it is batched, and multiplies it by
    `height` to get a focal length in pixels that is used for both axes. The
    principal point is forced to the image center, overriding MoGe's
    predicted cx/cy to match prepare_batch's convention.

    Returns:
        float32 tensor of shape (1, 3, 3), or None when the payload is
        missing, is not a dict, or carries no "intrinsics".
    """
    # MOGE_GEOMETRY is a dict with optional keys (see comfy_extras/nodes_moge.py).
    if not isinstance(moge_geometry, dict):
        return None
    normalized = moge_geometry.get("intrinsics")
    if normalized is None:
        return None
    if normalized.ndim == 3:
        normalized = normalized[0]

    # MoGe stores fy in height units; multiply by H to get pixels.
    focal = float(normalized[1, 1].item()) * height
    center_x, center_y = width / 2.0, height / 2.0
    matrix = [
        [focal, 0.0, center_x],
        [0.0, focal, center_y],
        [0.0, 0.0, 1.0],
    ]
    return torch.tensor([matrix], dtype=torch.float32)
|
||||
|
||||
|
||||
def run_batched_single_chunk(
    inner: SAM3DBody,
    frames_rgb: List[torch.Tensor],
    per_frame_boxes: List[torch.Tensor],
    per_frame_masks: Optional[List[torch.Tensor]],
    image_size: Tuple[int, int],
    inference_type: str,
    K: int,
    cam_int: Optional[torch.Tensor] = None,
) -> List[List[Dict[str, Any]]]:
    """Run a SINGLE chunk of frames through run_inference in one forward.

    The N frames x K boxes are flattened frame-major into N*K crops (crop
    index = f * K + k), batched with `prepare_batch`, moved to the comfy
    compute device and passed to `inner.run_inference`.

    Args:
        inner: model; its cached `batch` / `image_embeddings` / `output`
            attributes (when present) and `prev_prompt` are reset first.
        frames_rgb: one image per frame.
        per_frame_boxes: per frame, K boxes.
        per_frame_masks: optional per-frame masks. A frame holding a single
            mask has it shared by all K boxes.
        image_size: forwarded to `prepare_batch` as `input_size`.
        inference_type: "full" also yields the hand batches, from which hand
            bboxes are derived; anything else returns body output only.
        K: boxes (persons) per frame.
        cam_int: optional camera intrinsics forwarded to `prepare_batch`.

    Returns:
        One list per frame holding K person dicts of numpy outputs, plus the
        person's mask (or None) and, in "full" mode, `lhand_bbox` /
        `rhand_bbox`.
    """
    N = len(frames_rgb)
    total = N * K
    crop_slots = [(f, k) for f in range(N) for k in range(K)]

    # Reset stateful caches on the model.
    for attr in ("batch", "image_embeddings", "output"):
        if hasattr(inner, attr):
            setattr(inner, attr, None)
    inner.prev_prompt = []

    boxes_stacked = torch.stack([per_frame_boxes[f][k] for f, k in crop_slots], dim=0)
    img_per_crop = [frames_rgb[f] for f, _ in crop_slots]

    masks_stacked = None
    masks_score = None
    if per_frame_masks is not None:
        # A frame that carries one mask but several boxes shares that mask
        # across all of its boxes.
        flat_masks = []
        for f in range(N):
            frame_masks = per_frame_masks[f]
            if frame_masks.shape[0] == 1 and K > 1:
                frame_masks = frame_masks.repeat_interleave(K, dim=0)
            flat_masks.extend(frame_masks[k] for k in range(K))
        masks_stacked = torch.stack(flat_masks, dim=0)
        masks_score = torch.ones(total, dtype=torch.float32)

    batch = prepare_batch(
        img_per_crop, boxes_stacked,
        input_size=image_size,
        masks=masks_stacked, masks_score=masks_score, cam_int=cam_int,
    )
    device = comfy.model_management.get_torch_device()
    batch = {
        name: (value.to(device) if isinstance(value, torch.Tensor) else value)
        for name, value in batch.items()
    }
    inner._initialize_batch(batch)

    # NOTE(review): thresh_wrist_angle=1.4 is passed through verbatim; its
    # meaning and units are defined by run_inference -- confirm there.
    outputs = inner.run_inference(
        img_per_crop,
        batch,
        inference_type=inference_type,
        thresh_wrist_angle=1.4,
    )

    if inference_type == "full":
        pose_output, batch_lhand, batch_rhand, _, _ = outputs
    else:
        pose_output = outputs
        batch_lhand = batch_rhand = None

    out = {
        name: value.cpu().numpy()
        for name, value in pose_output["mhr"].items()
        if value is not None and name != "faces"
    }

    # Snapshot batch['bbox'] to CPU before the `batch` references are released.
    batch_bbox_cpu = batch["bbox"][0].cpu().numpy()
    lhand_bboxes = rhand_bboxes = None
    if inference_type == "full" and batch_lhand is not None and batch_rhand is not None:
        lhand_bboxes = [_bbox_from_center_scale(batch_lhand, i) for i in range(total)]
        rhand_bboxes = [_bbox_from_center_scale(batch_rhand, i) for i in range(total)]

    del pose_output, batch, batch_lhand, batch_rhand, outputs

    has_face_3d = "pred_face_keypoints_3d" in out
    has_face_2d = "pred_face_keypoints_2d" in out

    frames_out: List[List[Dict[str, Any]]] = []
    for f in range(N):
        persons: List[Dict[str, Any]] = []
        for k in range(K):
            idx = f * K + k

            if per_frame_masks is None:
                person_mask = None
            else:
                frame_masks = per_frame_masks[f]
                person_mask = frame_masks[k] if frame_masks.shape[0] > 1 else frame_masks[0]

            p: Dict[str, Any] = {
                "bbox": batch_bbox_cpu[idx],
                "focal_length": out["focal_length"][idx],
                "pred_keypoints_3d": out["pred_keypoints_3d"][idx],
                "pred_keypoints_2d": out["pred_keypoints_2d"][idx],
                "pred_vertices": out["pred_vertices"][idx],
                "pred_cam_t": out["pred_cam_t"][idx],
                "pred_pose_raw": out["pred_pose_raw"][idx],
                "global_rot": out["global_rot"][idx],
                "body_pose_params": out["body_pose"][idx],
                "hand_pose_params": out["hand"][idx],
                "scale_params": out["scale"][idx],
                "shape_params": out["shape"][idx],
                "expr_params": out["face"][idx],
                "mask": person_mask,
                "pred_joint_coords": out["pred_joint_coords"][idx],
                "pred_global_rots": out["joint_global_rots"][idx],
                "mhr_model_params": out["mhr_model_params"][idx],
                # 238 face landmarks from sapiens-308 (indices 70..308 of the pre-slice keypoint tensor).
                "pred_face_keypoints_3d": out["pred_face_keypoints_3d"][idx] if has_face_3d else None,
                "pred_face_keypoints_2d": out["pred_face_keypoints_2d"][idx] if has_face_2d else None,
            }
            if lhand_bboxes is not None:
                p["lhand_bbox"] = lhand_bboxes[idx]
                p["rhand_bbox"] = rhand_bboxes[idx]
            persons.append(p)
        frames_out.append(persons)
    return frames_out
|
||||
|
||||
|
||||
def run_batched_frames(
    inner: SAM3DBody,
    frames_rgb: List[torch.Tensor],
    per_frame_boxes: List[torch.Tensor],
    per_frame_masks: Optional[List[torch.Tensor]],
    image_size: Tuple[int, int],
    inference_type: str,
    cam_int: Optional[torch.Tensor] = None,
    pbar: Optional[comfy.utils.ProgressBar] = None,
    crops_per_chunk: int = BATCHED_CROPS_PER_CHUNK,
) -> List[List[Dict[str, Any]]]:
    """Run a whole clip through chunked, batched run_inference calls.

    Every frame must carry the same number K of boxes (the caller pads).
    Frames are grouped so that `frames_in_chunk * K` stays within
    `crops_per_chunk`, with at least one frame per chunk. Each chunk goes
    through `run_batched_single_chunk`, i.e. one body forward plus optional
    hand forwards over its person-crops.

    Progress is reported to a tqdm bar and, when given, to `pbar`. The CUDA
    allocator cache is released after each chunk when CUDA is available.

    Returns:
        One list of person dicts per input frame, in frame order.

    Raises:
        AssertionError: on an empty clip, differing box counts between
            frames, or zero boxes per frame.
    """
    N = len(frames_rgb)
    assert N > 0, "empty frame list"
    K_set = {len(b) for b in per_frame_boxes}
    assert len(K_set) == 1, f"batched path requires same bbox count per frame, got {K_set}"
    K = K_set.pop()
    assert K >= 1, "need at least one bbox per frame"

    chunk_frames = max(1, crops_per_chunk // K)
    results: List[List[Dict[str, Any]]] = []
    with tqdm(total=N, desc="SAM3D body inference") as progress:
        for start in range(0, N, chunk_frames):
            end = min(N, start + chunk_frames)
            window = slice(start, end)
            mask_window = per_frame_masks[window] if per_frame_masks is not None else None

            results.extend(
                run_batched_single_chunk(
                    inner, frames_rgb[window], per_frame_boxes[window], mask_window,
                    image_size, inference_type, K,
                    cam_int=cam_int,
                )
            )

            done = end - start
            progress.update(done)
            if pbar is not None:
                pbar.update(done)
            # Drop GPU caches so the next chunk starts from a clean allocator state.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return results
|
||||
|
||||
|
||||
def _bbox_from_center_scale(batch, idx: int) -> np.ndarray:
|
||||
cx = batch["bbox_center"].flatten(0, 1)[idx][0].item()
|
||||
cy = batch["bbox_center"].flatten(0, 1)[idx][1].item()
|
||||
sx = batch["bbox_scale"].flatten(0, 1)[idx][0].item()
|
||||
sy = batch["bbox_scale"].flatten(0, 1)[idx][1].item()
|
||||
return np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2], dtype=np.float32)
|
||||
|
||||
# Wire types and small helpers shared across the SAM 3D Body node modules.
|
||||
|
||||
def image_to_uint8(image: torch.Tensor) -> torch.Tensor:
    """ComfyUI image tensor (any shape, float 0..1) → uint8 tensor in [0, 255] on CPU.

    Values are scaled by 255 and clamped to [0, 255] before the cast; the
    cast drops the fractional part rather than rounding.
    """
    scaled = torch.clamp(image * 255.0, min=0.0, max=255.0)
    return scaled.to(dtype=torch.uint8, device="cpu")
|
||||
|
||||
|
||||
def compute_canonical_colors(model) -> Dict[str, np.ndarray]:
    """Collect canonical rest-pose data used for shader color lookups.

    Returns a dict with:
        positions: (Nv, 3) float32 canonical vertices.
        norm: (Nv, 3) area-weighted unit vertex normals remapped from
            [-1, 1] to [0, 1].
        face_mask: result of `_compute_face_mask(model)`.
        head_mask: (Nv,) bool, verts with y > 1.43 and |x| < 0.11.
        face_region_rgb: per-vertex painted region color read from
            `model.head_pose.face_region_rgb`.
    """
    verts = model.head_pose.canonical_vertices().float().cpu().numpy()
    faces = model.head_pose.faces.cpu().numpy()

    # Un-normalized face normals: cross-product length scales with face area.
    tri = verts[faces]  # (F, 3, 3)
    origin = tri[:, 0]
    face_normals = np.cross(tri[:, 1] - origin, tri[:, 2] - origin).astype(np.float32)

    vert_normals = np.zeros_like(verts, dtype=np.float32)
    for corner in range(3):
        np.add.at(vert_normals, faces[:, corner], face_normals)

    lengths = np.linalg.norm(vert_normals, axis=1, keepdims=True)
    lengths[lengths < 1e-8] = 1.0  # leave (near-)zero normals untouched
    vert_normals = vert_normals / lengths
    norm_map = ((vert_normals + 1.0) * 0.5).astype(np.float32)

    face_mask = _compute_face_mask(model)
    # Head: above jaw-neck (y>1.43) and narrower than shoulders (|x|<0.11).
    # Ears reach |x|≈0.09; shoulders start at |x|≈0.20.
    above_neck = verts[:, 1] > 1.43
    narrow = np.abs(verts[:, 0]) < 0.11
    head_mask = above_neck & narrow

    # The painted per-vertex face-region RGB ships in the model .safetensors
    # as `head_pose.face_region_rgb` and is loaded by load_state_dict.
    face_region_rgb = model.head_pose.face_region_rgb.detach().float().cpu().numpy()

    return {
        "positions": verts.astype(np.float32),
        "norm": norm_map,
        "face_mask": face_mask,
        "head_mask": head_mask,
        "face_region_rgb": face_region_rgb,
    }
|
||||
|
||||
|
||||
def compute_hand_vert_mask(model, hand_radius_m: float = 0.15, weight_threshold: float = 0.5) -> np.ndarray:
    """Return a (Nv,) bool mask marking hand-region vertices.

    A zero-parameter forward pass gives keypoints and joint coordinates.
    Joints lying within `hand_radius_m` of either hand keypoint cluster
    centre (mhr70 keypoint indices 21..62) count as hand joints. For each
    vertex the sparse LBS weights attached to hand joints are summed;
    vertices whose sum reaches `weight_threshold` are flagged.
    """
    head = model.head_pose
    rig = head.mhr
    device = head.scale_mean.device

    def blank(*shape):
        # Batch-of-one zero tensor on the model's device.
        return torch.zeros(1, *shape, device=device)

    rest_inputs = {
        "global_trans": blank(3),
        "global_rot": blank(3),
        "body_pose_params": blank(130),
        "hand_pose_params": blank(head.num_hand_comps * 2),
        "scale_params": blank(head.num_scale_comps),
        "shape_params": blank(head.num_shape_comps),
        "expr_params": blank(head.num_face_comps),
    }
    out = head.mhr_forward(
        **rest_inputs,
        return_keypoints=True,
        return_joint_coords=True,
    )
    # Output order with these flags: (verts, kp, jcoords). See mhr_head.mhr_forward.
    keypoints = out[1][0, :70].cpu().numpy()
    joint_xyz = out[2][0].cpu().numpy()

    # Centre of each hand's keypoint cluster.
    right_centre = keypoints[21:42].mean(axis=0)
    left_centre = keypoints[42:63].mean(axis=0)
    near_right = np.linalg.norm(joint_xyz - right_centre, axis=1) < hand_radius_m
    near_left = np.linalg.norm(joint_xyz - left_centre, axis=1) < hand_radius_m
    hand_joint_flag = np.logical_or(near_right, near_left).astype(np.float32)

    skin_weights = rig.lbs_skin_weights.cpu().numpy()
    skin_verts = rig.lbs_vert_indices.cpu().numpy()
    skin_joints = rig.lbs_skin_indices.cpu().numpy()

    # Sum, per vertex, the skinning weight contributed by hand joints only.
    hand_weight = np.zeros(rig.NUM_VERTS, dtype=np.float32)
    np.add.at(hand_weight, skin_verts, skin_weights * hand_joint_flag[skin_joints])
    return hand_weight >= weight_threshold
|
||||
|
||||
|
||||
def _compute_face_mask(model, disp_threshold_m: float = 1e-4) -> np.ndarray:
    """Return a (Nv,) bool mask of vertices that move with face expression.

    The mesh is evaluated once with all-zero parameters, then once per
    expression axis (`head.num_face_comps` of them) with that single
    coefficient set to +1.0. A vertex is flagged if it is displaced by more
    than `disp_threshold_m` for at least one axis.
    """
    head = model.head_pose
    device = head.scale_mean.device
    n_expr = head.num_face_comps

    def blank(*shape):
        # Batch-of-one zero tensor on the model's device.
        return torch.zeros(1, *shape, device=device)

    rest_inputs = {
        "global_trans": blank(3),
        "global_rot": blank(3),
        "body_pose_params": blank(130),
        "hand_pose_params": blank(head.num_hand_comps * 2),
        "scale_params": blank(head.num_scale_comps),
        "shape_params": blank(head.num_shape_comps),
        "expr_params": blank(n_expr),
    }
    rest_verts = head.mhr_forward(**rest_inputs).cpu().numpy()[0]  # (Nv, 3)

    moved = np.zeros(rest_verts.shape[0], dtype=bool)
    for axis in range(n_expr):
        one_hot = blank(n_expr)
        one_hot[0, axis] = 1.0
        # Same neutral inputs, with only the expression vector swapped out.
        posed_inputs = {**rest_inputs, "expr_params": one_hot}
        posed_verts = head.mhr_forward(**posed_inputs).cpu().numpy()[0]
        displacement = np.linalg.norm(posed_verts - rest_verts, axis=1)
        moved |= displacement > disp_threshold_m
    return moved
|
||||
|
||||
|
||||
def jet_colormap(s: np.ndarray) -> np.ndarray:
    """Map scalars in [0, 1] to matplotlib-jet RGB: (N,) -> (N, 3) float32.

    Inputs are clipped to [0, 1] first; each channel is then a piecewise
    linear interpolation through its own list of control points.
    """
    t = np.clip(np.asarray(s, dtype=np.float32), 0.0, 1.0)
    # (sample positions, channel values) for red, green and blue, in order.
    control_points = (
        ([0.0, 0.35, 0.66, 0.89, 1.0], [0.0, 0.0, 1.0, 1.0, 0.5]),
        ([0.0, 0.125, 0.375, 0.64, 0.91, 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]),
        ([0.0, 0.11, 0.34, 0.65, 1.0], [0.5, 1.0, 1.0, 0.0, 0.0]),
    )
    channels = [np.interp(t, xs, ys) for xs, ys in control_points]
    return np.stack(channels, axis=-1).astype(np.float32)
|
||||
Loading…
Reference in New Issue
Block a user