mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
208 lines
7.9 KiB
Python
208 lines
7.9 KiB
Python
"""BVH export for SAM 3D Body pose_data.
|
||
|
||
BVH stores explicit bone OFFSETs per joint, so any standard importer
|
||
(Blender, Maya, MotionBuilder, etc.) reconstructs anatomical bone orientations
|
||
directly, with none of the heuristic guessing that glTF needs. We skip the rig's joint 0
|
||
(static world anchor) and use joint 1 as the BVH ROOT (6 channels: XYZ pos +
|
||
ZXY rot); every other joint gets 3 channels (ZXY rot only). Rotations are
|
||
intrinsic Z-X-Y Euler degrees.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import io
|
||
from typing import Any, Dict, List
|
||
|
||
import numpy as np
|
||
|
||
from .glb_shared import (
|
||
bind_skel_state,
|
||
bone_locals_from_globals,
|
||
collect_tracks,
|
||
extract_rig_static,
|
||
global_skel_state_from_pose_data,
|
||
quat_sign_fix_per_joint,
|
||
unflip,
|
||
)
|
||
|
||
|
||
def _quat_to_zxy_euler_deg(quat: np.ndarray) -> np.ndarray:
|
||
"""xyzw quat → intrinsic Z-X-Y Euler degrees, returned as (..., 3) in
|
||
(z, x, y) order to match BVH's `CHANNELS Zrotation Xrotation Yrotation`."""
|
||
q = np.asarray(quat, dtype=np.float64)
|
||
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||
# ZXY decomposition R = Rz(c)·Rx(a)·Ry(b):
|
||
# M[2][1] = 2(yz + xw) → sin(a)
|
||
# M[0][1] = 2(xy - zw) → -cos(a) sin(c)
|
||
# M[1][1] = 1 - 2(x² + z²) → cos(a) cos(c)
|
||
# M[2][0] = 2(xz - yw) → -cos(a) sin(b)
|
||
# M[2][2] = 1 - 2(x² + y²) → cos(a) cos(b)
|
||
M21 = np.clip(2.0 * (y * z + x * w), -1.0, 1.0)
|
||
M01 = 2.0 * (x * y - z * w)
|
||
M11 = 1.0 - 2.0 * (x * x + z * z)
|
||
M20 = 2.0 * (x * z - y * w)
|
||
M22 = 1.0 - 2.0 * (x * x + y * y)
|
||
a = np.arcsin(M21)
|
||
c = np.arctan2(-M01, M11)
|
||
b = np.arctan2(-M20, M22)
|
||
out = np.stack([np.rad2deg(c), np.rad2deg(a), np.rad2deg(b)], axis=-1)
|
||
return out.astype(np.float32)
|
||
|
||
|
||
def _find_bvh_root(parents: np.ndarray) -> int:
|
||
"""First child of the rig's world anchor so the static origin→body stick
|
||
bone gets left out. Falls back to the first root joint."""
|
||
NJ = parents.shape[0]
|
||
world_anchors = [j for j in range(NJ)
|
||
if not (0 <= int(parents[j]) < NJ and int(parents[j]) != j)]
|
||
if not world_anchors:
|
||
return 0
|
||
children: List[List[int]] = [[] for _ in range(NJ)]
|
||
for j in range(NJ):
|
||
p = int(parents[j])
|
||
if 0 <= p < NJ and p != j:
|
||
children[p].append(j)
|
||
wa = world_anchors[0]
|
||
if children[wa]:
|
||
return children[wa][0]
|
||
return wa
|
||
|
||
|
||
def _build_children_map(parents: np.ndarray) -> List[List[int]]:
|
||
NJ = parents.shape[0]
|
||
out: List[List[int]] = [[] for _ in range(NJ)]
|
||
for j in range(NJ):
|
||
p = int(parents[j])
|
||
if 0 <= p < NJ and p != j:
|
||
out[p].append(j)
|
||
return out
|
||
|
||
|
||
def build_bvh(
    pose_data: Dict[str, Any],
    model: Any,
    *,
    fps: float = 24.0,
    camera_translation: str = "off",
    track_index: int = -1,
    units: str = "cm",
) -> bytes:
    """Build a BVH file from pose_data. Returns UTF-8 encoded text bytes.

    Args:
        pose_data: Dict with a ``"frames"`` entry. When camera translation
            is enabled, ``frames[t][person_k]["pred_cam_t"]`` is read for
            every exported frame. Track selection and per-frame global joint
            states come from the shared ``collect_tracks`` /
            ``global_skel_state_from_pose_data`` helpers.
        model: Passed to ``extract_rig_static`` and ``bind_skel_state`` to
            obtain the rig topology and bind pose.
        fps: Playback rate, written as ``Frame Time: 1/fps``. Must be a
            positive finite number.
        camera_translation: ``"absolute"`` adds the per-frame (unflipped)
            ``pred_cam_t`` to the root's bind position. ``"centered"`` does
            the same, relative to the first exported frame. Any other value
            (default ``"off"``) keeps the root at its bind position.
        track_index: Forwarded to ``collect_tracks``. Only the first
            returned track is exported.
        units: ``"cm"`` (default, standard mocap convention) or ``"m"``.
            Affects the OFFSET and root-position values; rotations are
            independent of units.

    Raises:
        ValueError: unsupported ``units``, non-positive or non-finite
            ``fps``, no valid tracks, or a track with zero frames. The shared
            helpers may raise their own errors.
    """
    if units not in ("cm", "m"):
        raise ValueError(f"build_bvh: units must be 'cm' or 'm', got {units!r}")
    unit_scale = 100.0 if units == "cm" else 1.0

    # Validate before touching the model. fps <= 0 would otherwise surface as
    # a ZeroDivisionError, or as a negative/NaN Frame Time in the output.
    fps = float(fps)
    if not np.isfinite(fps) or fps <= 0.0:
        raise ValueError(
            f"build_bvh: fps must be a positive finite number, got {fps!r}"
        )

    rig_static = extract_rig_static(model)
    NJ = int(rig_static["num_joints"])
    parents = rig_static["parents"]
    frames = pose_data["frames"]

    tracks = collect_tracks(pose_data, track_index)
    if not tracks:
        raise ValueError("build_bvh: no valid tracks in pose_data")
    # Only the first selected track is exported.
    person_k, frame_indices = tracks[0]
    n_frames = len(frame_indices)
    if n_frames == 0:
        raise ValueError("build_bvh: track has zero frames")

    body_root = _find_bvh_root(parents)
    children_map = _build_children_map(parents)

    # Bone OFFSETs come from MHR's translation_offsets (joint position
    # relative to parent in parent's local-bind frame). For the BVH root,
    # we use its bind world position so the skeleton sits at the right
    # spot when imported.
    bind_global = bind_skel_state(model)  # (NJ, 8) cm
    bind_pos_m = bind_global[:, :3].astype(np.float64) * 0.01  # (NJ, 3) m
    offset_m = rig_static["joint_translation_offsets"].astype(np.float64) * 0.01

    # DFS order rooted at body_root — matches per-frame channel order.
    bvh_order: List[int] = []

    def _visit(j: int) -> None:
        bvh_order.append(j)
        for c in children_map[j]:
            _visit(c)

    _visit(body_root)

    # Use pose_data's stored pred_global_rots/pred_joint_coords (authoritative)
    # rather than re-running rig.forward, then derive locals with body_root
    # treated as the hierarchy root in BVH-space.
    rig_global_m = global_skel_state_from_pose_data(
        pose_data, frame_indices, person_k, NJ,
    )
    rig_global_m[..., 3:7] = quat_sign_fix_per_joint(rig_global_m[..., 3:7])
    bvh_parents = parents.copy()
    bvh_parents[body_root] = -1
    bone_local = bone_locals_from_globals(rig_global_m, bvh_parents)
    # Second pass catches sign discontinuities from the parent-inverse composition.
    bone_local[..., 3:7] = quat_sign_fix_per_joint(bone_local[..., 3:7])

    eulers_deg = _quat_to_zxy_euler_deg(bone_local[..., 3:7])

    if camera_translation in ("absolute", "centered"):
        cam_t = np.stack([
            unflip(np.asarray(frames[t][person_k]["pred_cam_t"], dtype=np.float32))
            for t in frame_indices
        ], axis=0).astype(np.float64)
        if camera_translation == "centered":
            cam_t = cam_t - cam_t[0:1]
        root_pos_m = bind_pos_m[body_root][None, :] + cam_t
    else:
        root_pos_m = np.tile(bind_pos_m[body_root], (n_frames, 1))

    lines: List[str] = ["HIERARCHY"]

    def _emit_joint(j: int, depth: int, is_root: bool) -> None:
        ind = " " * depth
        keyword = "ROOT" if is_root else "JOINT"
        name = "Hips" if is_root else f"joint_{j:03d}"
        lines.append(f"{ind}{keyword} {name}")
        lines.append(ind + "{")
        o = (bind_pos_m[j] if is_root else offset_m[j]) * unit_scale
        lines.append(f"{ind} OFFSET {o[0]:.6f} {o[1]:.6f} {o[2]:.6f}")
        if is_root:
            lines.append(ind + " CHANNELS 6 Xposition Yposition Zposition "
                         "Zrotation Xrotation Yrotation")
        else:
            lines.append(ind + " CHANNELS 3 Zrotation Xrotation Yrotation")
        kids = children_map[j]
        if kids:
            for c in kids:
                _emit_joint(c, depth + 1, is_root=False)
        else:
            # End Site (standard BVH spec) gives leaf bones a drawable length.
            lines.append(ind + " End Site")
            lines.append(ind + " {")
            tip = (offset_m[j] * unit_scale) * 0.3
            tip_norm = float(np.linalg.norm(tip))
            # Threshold is 0.005 m (5 mm) expressed in output units; shorter
            # tips fall back to a 5 cm segment along +Y.
            if tip_norm < 0.5 * unit_scale * 0.01:
                tip = np.array([0.0, 0.05 * unit_scale, 0.0], dtype=np.float64)
            lines.append(f"{ind} OFFSET {tip[0]:.6f} {tip[1]:.6f} {tip[2]:.6f}")
            lines.append(ind + " }")
        lines.append(ind + "}")

    _emit_joint(body_root, 0, is_root=True)

    lines.append("MOTION")
    lines.append(f"Frames: {n_frames}")
    lines.append(f"Frame Time: {1.0 / float(fps):.6f}")

    # Channel matrix: root pos (3) + root rot (3) + non-root rots (3 each) per
    # frame, columns in `bvh_order` order. Vectorized — savetxt's C-side
    # formatting beats Python f-strings by ~10× on long clips.
    non_root_idx = np.asarray(bvh_order[1:], dtype=np.int64)
    motion = np.concatenate([
        root_pos_m * unit_scale,                                # (N, 3)
        eulers_deg[:, body_root].astype(np.float64),            # (N, 3)
        eulers_deg[:, non_root_idx, :].reshape(n_frames, -1),   # (N, 3*(NJ-1))
    ], axis=1)
    buf = io.StringIO()
    np.savetxt(buf, motion, fmt="%.6f")
    lines.append(buf.getvalue().rstrip("\n"))

    return ("\n".join(lines) + "\n").encode("utf-8")
|