ComfyUI/comfy_extras/sam3d_body/utils.py
2026-06-16 00:38:29 +03:00

487 lines
20 KiB
Python

import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
from comfy.ldm.sam3d_body.utils import prepare_batch
from comfy.ldm.sam3.tracker import unpack_masks
from comfy.ldm.sam3d_body.model.model import SAM3DBody
import comfy.model_management
import comfy.utils
from tqdm import tqdm
def _bbox_from_mask(mask: torch.Tensor) -> Optional[torch.Tensor]:
"""xyxy bounds of a binary mask, with sub-5px speckles filtered out."""
m = mask[..., 0] if mask.dim() == 3 else mask
m_bool = m > 0
if not m_bool.any():
return None
t = m_bool.to(torch.float32)[None, None]
eroded = -F.max_pool2d(-t, kernel_size=5, stride=1, padding=2)
keep = eroded[0, 0] > 0.5
if not keep.any():
keep = m_bool
ys, xs = torch.where(keep)
return torch.stack([
xs.min().float(), ys.min().float(),
(xs.max() + 1).float(), (ys.max() + 1).float(),
])
def inputs_from_sam3_track(track_data, B: int, H: int, W: int):
    """Turn SAM3_TRACK_DATA into per-frame, per-object bboxes and masks at
    image resolution.

    Returns (per_frame_bboxes, per_frame_masks): two lists with one entry per
    frame, holding a (K, 4) float xyxy tensor and a (K, H, W, 1) uint8 tensor
    respectively. An object whose resized mask is empty gets the full-frame
    box. Returns (None, None) when `track_data` is not a dict, carries no
    "packed_masks", has a frame count different from `B`, or tracks no objects.
    """
    if not isinstance(track_data, dict):
        return None, None
    packed = track_data.get("packed_masks")
    if packed is None:
        return None, None

    unpacked = unpack_masks(packed)  # (N, K, Hm, Wm) bool
    n_frames, n_objs = unpacked.shape[:2]
    if n_frames != B or n_objs == 0:
        return None, None

    src_h, src_w = unpacked.shape[2], unpacked.shape[3]
    upsampled = F.interpolate(
        unpacked.float().reshape(n_frames * n_objs, 1, src_h, src_w),
        size=(H, W), mode="bilinear", align_corners=False,
    )
    # Re-binarise after the bilinear resize and keep the result on the CPU.
    binary = (upsampled > 0.5).to(torch.uint8).reshape(n_frames, n_objs, H, W).cpu()

    per_frame_masks = [binary[f, :, :, :, None].contiguous() for f in range(n_frames)]

    whole_image = torch.tensor([0.0, 0.0, float(W), float(H)], dtype=torch.float32)
    per_frame_bboxes = []
    for frame_masks in binary:
        boxes = []
        for obj_mask in frame_masks:
            box = _bbox_from_mask(obj_mask)
            boxes.append(whole_image if box is None else box)
        per_frame_bboxes.append(torch.stack(boxes, dim=0))
    return per_frame_bboxes, per_frame_masks
# Soft budget for the batched Predict path: the default number of person-crops
# (frames x persons per frame) that run_batched_frames packs into one chunk.
# "Soft" because a chunk always holds at least one frame, even when that
# frame's person count alone exceeds this value.
BATCHED_CROPS_PER_CHUNK = 64
def _quat_to_mat_wxyz(w: float, x: float, y: float, z: float) -> np.ndarray:
"""(3,3) rotation from a wxyz quaternion; columns are the rotated axes."""
n = math.sqrt(w * w + x * x + y * y + z * z) or 1.0
w, x, y, z = w / n, x / n, y / n, z / n
return np.array([
[1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
[2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
[2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
], dtype=np.float32)
def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[torch.Tensor]:
    """(1,3,3) intrinsic matrix from a vertical FOV in degrees. Matches MoGe2's
    convention (vertical focal for both axes).

    The focal length is `height / (2 * tan(fov / 2))` and the principal point
    is the image centre `(width / 2, height / 2)`.

    Returns None when the FOV is not a usable pinhole angle, i.e. outside the
    open interval (0, 180) degrees or NaN, so the caller falls back to
    prepare_batch's diagonal-focal default instead of receiving a zero,
    negative or NaN focal length.
    """
    # Written as a chained comparison so NaN is rejected too: a plain
    # `fov_degrees <= 0` test is False for NaN and would let it through.
    if not 0.0 < fov_degrees < 180.0:
        return None
    focal = height / (2.0 * math.tan(math.radians(fov_degrees) / 2.0))
    return torch.tensor(
        [[[focal, 0.0, width / 2.0],
          [0.0, focal, height / 2.0],
          [0.0, 0.0, 1.0]]],
        dtype=torch.float32,
    )
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
                          H: int, W: int) -> Dict[str, Any]:
    """Re-project every frame's pose through a Load3D 6DOF camera (position/
    target/zoom + optional quaternion and FOV).

    The new camera is placed relative to the subjects of the FIRST frame only
    (mean `pred_cam_t` and mean body centroid); the same eye, rotation and
    focal length are then applied to every person in every frame.

    Returns a new mhr_pose_data dict whose "frames" hold shallow-copied person
    dicts with rewritten 3D points, re-projected 2D keypoints, a zeroed
    `pred_cam_t` and the new `focal_length`. The input object is returned
    unchanged when the first frame is empty, when no first-frame person has a
    `pred_cam_t`, or when the camera has neither a position/target offset nor
    a quaternion.
    """
    first_frame = mhr_pose_data["frames"][0] if mhr_pose_data["frames"] else []
    if not first_frame:
        return mhr_pose_data
    # Per-person rig root (pred_cam_t) and body centroid (mesh mean), in camera space.
    roots, centroids = [], []
    for p in first_frame:
        cam_t = p.get("pred_cam_t")
        if cam_t is None:
            continue
        cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
        roots.append(cam_t)
        v = p.get("pred_vertices")
        # Without a mesh the centroid falls back to the rig root itself.
        centroids.append(np.asarray(v, dtype=np.float32).reshape(-1, 3).mean(axis=0) + cam_t
                         if v is not None else cam_t)
    if not roots:
        return mhr_pose_data
    subj_root = np.mean(np.stack(roots, axis=0), axis=0)
    subj_centroid = np.mean(np.stack(centroids, axis=0), axis=0)
    # Meter-scale, so Three.js coords map 1:1 (Three.js Y-up → flip Y,Z)
    # Missing components default to 0, except position.y which defaults to 5.0.
    pos = camera_info.get("position") or {}
    tgt = camera_info.get("target") or {}
    pos_v = np.array([float(pos.get("x", 0.0)), -float(pos.get("y", 5.0)), -float(pos.get("z", 0.0))], dtype=np.float32)
    tgt_v = np.array([float(tgt.get("x", 0.0)), -float(tgt.get("y", 0.0)), -float(tgt.get("z", 0.0))], dtype=np.float32)
    offset = pos_v - tgt_v
    has_offset = float(np.linalg.norm(offset)) >= 1e-6
    q = camera_info.get("quaternion")
    if not has_offset and not q:
        return mhr_pose_data  # no viewpoint and no orientation -> nothing to apply
    # A zoom of 0 (or a missing zoom) is treated as 1.0.
    zoom = float(camera_info.get("zoom", 1.0)) or 1.0
    # SAM3D roots near the feet. A target at the origin -> center the body centroid
    if float(np.linalg.norm(tgt_v)) < 1e-6:
        target = subj_centroid
    else:
        target = subj_root + tgt_v
    if q:
        # Orientation comes straight from the camera quaternion.
        mv = lambda v: np.array([v[0], -v[1], -v[2]], dtype=np.float32)  # same Y/Z flip as pos_v/tgt_v
        norm = lambda v: v / max(1e-6, float(np.linalg.norm(v)))  # safe normalise
        Rc = _quat_to_mat_wxyz(
            float(q.get("w", 1.0)), float(q.get("x", 0.0)),
            float(q.get("y", 0.0)), float(q.get("z", 0.0)),
        )  # columns = camera world axes
        x_axis = norm(mv(Rc[:, 0]))  # camera +X -> image right
        y_axis = norm(mv(-Rc[:, 1]))  # image +Y is down -> negative of camera up
        z_axis = norm(mv(-Rc[:, 2]))  # camera looks down local -Z -> view direction
    else:
        # No quaternion: build a look-at frame from the offset. This branch is
        # only reached with has_offset True, so the division below is safe.
        # x degenerates only when looking straight along world-up -> world +X.
        z_axis = -offset / float(np.linalg.norm(offset))
        x_axis = np.cross(z_axis, np.array([0.0, -1.0, 0.0], dtype=np.float32))
        x_norm = float(np.linalg.norm(x_axis))
        x_axis = x_axis / x_norm if x_norm > 1e-6 else np.array([1.0, 0.0, 0.0], dtype=np.float32)
        y_axis = np.cross(z_axis, x_axis)
    # Rows of R are the new camera axes, so `point @ R.T` yields the point's
    # coordinates along (x_axis, y_axis, z_axis).
    R = np.stack([x_axis, y_axis, z_axis], axis=0).astype(np.float32)
    # Eye: dolly along the given offset; for a rotation-only camera (position ==
    # target) keep the predicted viewing distance so only orientation/roll changes.
    # In both cases the distance is divided by zoom, clamped to at least 0.01.
    if has_offset:
        eye = target + offset / max(0.01, zoom)
    else:
        d = max(0.1, float(target[2]))
        eye = target - z_axis * (d / max(0.01, zoom))
    # Lens: use the camera's own FoV; else the SAM3D predicted focal (viewpoint-
    # only change). Three.js fov is vertical → focal from image height.
    cam_fov = float(camera_info.get("fov", 0.0) or 0.0)
    if cam_fov > 0:
        new_focal = float(H) / (2.0 * float(np.tan(np.deg2rad(cam_fov) / 2.0)))
    else:
        # Predicted focal of the first person in the first frame; if that is
        # missing too, fall back to a 50-degree vertical FOV.
        f0 = first_frame[0].get("focal_length")
        new_focal = (float(np.asarray(f0, dtype=np.float32).reshape(-1)[0]) if f0 is not None
                     else float(H) / (2.0 * float(np.tan(np.deg2rad(50.0) / 2.0))))
    center = np.array([W * 0.5, H * 0.5], dtype=np.float32)  # principal point = image centre
    # 3D key -> 2D key that must be re-projected after the 3D points move.
    reproj = {"pred_keypoints_3d": "pred_keypoints_2d", "pred_face_keypoints_3d": "pred_face_keypoints_2d"}
    new_frames: List[List[Dict[str, Any]]] = []
    for frame in mhr_pose_data["frames"]:
        scaled = []
        for p in frame:
            p = dict(p)  # shallow copy so the caller's person dict is not mutated
            cam_t = p.get("pred_cam_t")
            if cam_t is None:
                # Nothing to anchor the person to; pass the copy through untouched.
                scaled.append(p)
                continue
            cam_t = np.asarray(cam_t, dtype=np.float32).reshape(3)
            for k in ("pred_keypoints_3d", "pred_vertices", "pred_face_keypoints_3d"):
                v = p.get(k)
                if v is None:
                    continue
                # Old camera space (local point + cam_t) -> new camera space.
                cam = (np.asarray(v, dtype=np.float32) + cam_t - eye) @ R.T
                p[k] = cam.astype(np.float32)
                if k in reproj:  # re-project the new 3D to 2D image coords
                    z = np.maximum(cam[..., 2:3], 1e-6)  # clamp depth to avoid dividing by <= 0
                    p[reproj[k]] = (cam[..., :2] * new_focal / z + center).astype(np.float32)
            # The translation is now baked into the points themselves.
            p["pred_cam_t"] = np.zeros(3, dtype=np.float32)
            p["focal_length"] = np.array(new_focal, dtype=np.float32)
            scaled.append(p)
        new_frames.append(scaled)
    out = dict(mhr_pose_data)
    out["frames"] = new_frames
    return out
def run_batched_single_chunk(inner: SAM3DBody, frames_rgb: List[torch.Tensor], per_frame_boxes: List[torch.Tensor],
                             per_frame_masks: Optional[List[torch.Tensor]], image_size: Tuple[int, int], inference_type: str, K: int,
                             cam_int: Optional[torch.Tensor] = None) -> List[List[Dict[str, Any]]]:
    """Run a SINGLE chunk of frames through run_inference in one forward.

    Every frame contributes exactly `K` person-crops, flattened frame-major
    (crop index = frame * K + person). `per_frame_masks`, when given, holds
    one mask bundle per frame; a bundle with a single mask is shared by all K
    boxes of that frame.

    Returns one list per frame, each holding K person dicts with the model
    outputs converted to NumPy. When `inference_type == "full"` and both hand
    batches are returned, each person also carries "lhand_bbox"/"rhand_bbox".
    """
    N = len(frames_rgb)
    total = N * K
    # Reset stateful caches on the model
    for attr in ("batch", "image_embeddings", "output"):
        if hasattr(inner, attr):
            setattr(inner, attr, None)
    inner.prev_prompt = []
    # Flatten boxes and images in the same frame-major order so crop i of the
    # batch pairs box i with image i.
    boxes_stacked = torch.stack(
        [per_frame_boxes[f][k] for f in range(N) for k in range(K)], dim=0
    )
    img_per_crop = [frames_rgb[f] for f in range(N) for _ in range(K)]
    if per_frame_masks is not None:
        # Broadcast a single-mask bundle to per-bbox: when the user supplied one
        # mask but multiple bboxes per frame, each bbox gets the same mask.
        flat_masks = []
        for f in range(N):
            mf = per_frame_masks[f]
            if mf.shape[0] == 1 and K > 1:
                mf = mf.repeat_interleave(K, dim=0)
            flat_masks.extend([mf[k] for k in range(K)])
        masks_stacked = torch.stack(flat_masks, dim=0)
        # Every supplied mask gets a score of 1.0.
        masks_score = torch.ones(total, dtype=torch.float32)
    else:
        masks_stacked = None
        masks_score = None
    batch = prepare_batch(
        img_per_crop, boxes_stacked,
        input_size=image_size,
        masks=masks_stacked, masks_score=masks_score, cam_int=cam_int,
    )
    device = comfy.model_management.get_torch_device()
    # Move tensor entries to the compute device; non-tensor entries pass through.
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    inner._initialize_batch(batch)
    outputs = inner.run_inference(
        img_per_crop,
        batch,
        inference_type=inference_type,
        thresh_wrist_angle=1.4,
    )
    # "full" mode returns a 5-tuple (pose, left-hand batch, right-hand batch,
    # two unused items); any other mode returns the pose output alone.
    if inference_type == "full":
        pose_output, batch_lhand, batch_rhand, _, _ = outputs
    else:
        pose_output = outputs
        batch_lhand = batch_rhand = None
    # Convert every non-None "mhr" output except "faces" to float32 NumPy on the CPU.
    out = {k: v.float().cpu().numpy() for k, v in pose_output["mhr"].items()
           if v is not None and k != "faces"}
    # Snapshot batch['bbox'] to CPU before we release `batch` references
    batch_bbox_cpu = batch["bbox"][0].cpu().numpy()
    lhand_bboxes = rhand_bboxes = None
    if inference_type == "full" and batch_lhand is not None and batch_rhand is not None:
        lhand_bboxes = [_bbox_from_center_scale(batch_lhand, i) for i in range(total)]
        rhand_bboxes = [_bbox_from_center_scale(batch_rhand, i) for i in range(total)]
    # Drop the local references to the model-side objects before the per-person
    # dicts are assembled; everything needed below is already a NumPy copy.
    del pose_output, batch, batch_lhand, batch_rhand, outputs
    frames_out: List[List[Dict[str, Any]]] = []
    for f in range(N):
        persons: List[Dict[str, Any]] = []
        for k in range(K):
            idx = f * K + k  # position of this person-crop in the flattened batch
            p: Dict[str, Any] = {
                "bbox": batch_bbox_cpu[idx],
                "focal_length": out["focal_length"][idx],
                "pred_keypoints_3d": out["pred_keypoints_3d"][idx],
                "pred_keypoints_2d": out["pred_keypoints_2d"][idx],
                "pred_vertices": out["pred_vertices"][idx],
                "pred_cam_t": out["pred_cam_t"][idx],
                "pred_pose_raw": out["pred_pose_raw"][idx],
                "global_rot": out["global_rot"][idx],
                "body_pose_params": out["body_pose"][idx],
                "hand_pose_params": out["hand"][idx],
                "scale_params": out["scale"][idx],
                "shape_params": out["shape"][idx],
                "expr_params": out["face"][idx],
                # Person k's own mask, or the frame's only mask when the bundle holds just one.
                "mask": (per_frame_masks[f][k] if per_frame_masks[f].shape[0] > 1 else per_frame_masks[f][0])
                        if per_frame_masks is not None else None,
                "pred_joint_coords": out["pred_joint_coords"][idx],
                "pred_global_rots": out["joint_global_rots"][idx],
                "mhr_model_params": out["mhr_model_params"][idx],
                # 238 face landmarks from sapiens-308 (indices 70..308 of the pre-slice keypoint tensor).
                "pred_face_keypoints_3d": out["pred_face_keypoints_3d"][idx] if "pred_face_keypoints_3d" in out else None,
                "pred_face_keypoints_2d": out["pred_face_keypoints_2d"][idx] if "pred_face_keypoints_2d" in out else None,
            }
            if lhand_bboxes is not None:
                p["lhand_bbox"] = lhand_bboxes[idx]
                p["rhand_bbox"] = rhand_bboxes[idx]
            persons.append(p)
        frames_out.append(persons)
    return frames_out
def run_batched_frames(
    inner: SAM3DBody,
    frames_rgb: List[torch.Tensor],
    per_frame_boxes: List[torch.Tensor],
    per_frame_masks: Optional[List[torch.Tensor]],
    image_size: Tuple[int, int],
    inference_type: str,
    cam_int: Optional[torch.Tensor] = None,
    pbar: Optional[comfy.utils.ProgressBar] = None,
    crops_per_chunk: int = BATCHED_CROPS_PER_CHUNK,
) -> List[List[Dict[str, Any]]]:
    """Process a whole clip as a sequence of batched run_inference chunks.

    All frames must carry the same number K of boxes (pad externally if
    needed). The clip is cut into runs of `max(1, crops_per_chunk // K)`
    frames, and each run goes through run_batched_single_chunk as one batch.
    Both the console tqdm bar and the optional `pbar` advance by the number of
    frames finished, and the CUDA cache is emptied after every chunk when CUDA
    is available. Returns the per-frame person lists in clip order.
    """
    N = len(frames_rgb)
    assert N > 0, "empty frame list"
    K_set = {len(b) for b in per_frame_boxes}
    assert len(K_set) == 1, f"batched path requires same bbox count per frame, got {K_set}"
    K = next(iter(K_set))
    assert K >= 1, "need at least one bbox per frame"

    frames_per_chunk = max(1, crops_per_chunk // K)
    collected: List[List[Dict[str, Any]]] = []
    with tqdm(total=N, desc="SAM3D body inference") as bar:
        start = 0
        while start < N:
            stop = min(start + frames_per_chunk, N)
            window = slice(start, stop)
            masks_window = per_frame_masks[window] if per_frame_masks is not None else None
            collected.extend(
                run_batched_single_chunk(
                    inner, frames_rgb[window], per_frame_boxes[window], masks_window,
                    image_size, inference_type, K,
                    cam_int=cam_int,
                )
            )
            finished = stop - start
            bar.update(finished)
            if pbar is not None:
                pbar.update(finished)
            # Drop GPU caches so the next chunk starts from a clean allocator state
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            start = stop
    return collected
def _bbox_from_center_scale(batch, idx: int) -> np.ndarray:
cx = batch["bbox_center"].flatten(0, 1)[idx][0].item()
cy = batch["bbox_center"].flatten(0, 1)[idx][1].item()
sx = batch["bbox_scale"].flatten(0, 1)[idx][0].item()
sy = batch["bbox_scale"].flatten(0, 1)[idx][1].item()
return np.array([cx - sx / 2, cy - sy / 2, cx + sx / 2, cy + sy / 2], dtype=np.float32)
# Wire types and small helpers shared across the SAM 3D Body node modules.
def image_to_uint8(image: torch.Tensor) -> torch.Tensor:
    """Map a ComfyUI float image (any shape, nominally 0..1) to a CPU uint8 tensor.

    Values are multiplied by 255, clamped to [0, 255] and then cast to uint8.
    """
    scaled = image * 255.0
    clipped = torch.clamp(scaled, min=0.0, max=255.0)
    return clipped.to(device="cpu", dtype=torch.uint8)
def compute_canonical_colors(model) -> Dict[str, np.ndarray]:
    """Gather the canonical rest-pose arrays used for shader color lookups.

    Returned keys: "positions" (Nv,3) rest vertices, "norm" (Nv,3) vertex
    normals remapped from [-1,1] to [0,1], "face_mask" and "head_mask" (Nv,)
    bool, and "face_region_rgb" (the per-vertex painted region color read from
    `model.head_pose.face_region_rgb`).
    """
    head = model.head_pose
    rest_verts = head.canonical_vertices().float().cpu().numpy()
    tris = head.faces.cpu().numpy()

    # Un-normalised triangle normals, accumulated onto each of their three
    # corners and then normalised per vertex.
    corner_a = rest_verts[tris[:, 0]]
    corner_b = rest_verts[tris[:, 1]]
    corner_c = rest_verts[tris[:, 2]]
    tri_normals = np.cross(corner_b - corner_a, corner_c - corner_a).astype(np.float32)
    vert_normals = np.zeros_like(rest_verts, dtype=np.float32)
    for corner in range(3):
        np.add.at(vert_normals, tris[:, corner], tri_normals)
    lengths = np.linalg.norm(vert_normals, axis=1, keepdims=True)
    lengths[lengths < 1e-8] = 1.0  # leave degenerate normals as (near-)zero vectors
    vert_normals = vert_normals / lengths
    normal_colors = ((vert_normals + 1.0) * 0.5).astype(np.float32)

    expression_mask = _compute_face_mask(model)

    # Head: above jaw-neck (y>1.43) and narrower than shoulders (|x|<0.11).
    # Ears reach |x|≈0.09; shoulders start at |x|≈0.20.
    above_neck = rest_verts[:, 1] > 1.43
    within_head_width = np.abs(rest_verts[:, 0]) < 0.11
    head_region = above_neck & within_head_width

    # Painted per-vertex face region RGB ships in the model .safetensors as
    # `head_pose.face_region_rgb` and gets loaded by load_state_dict.
    region_rgb = head.face_region_rgb.detach().float().cpu().numpy()

    return {
        "positions": rest_verts.astype(np.float32),
        "norm": normal_colors,
        "face_mask": expression_mask,
        "head_mask": head_region,
        "face_region_rgb": region_rgb,
    }
def compute_hand_vert_mask(model, hand_radius_m: float = 0.15, weight_threshold: float = 0.5) -> np.ndarray:
    """(Nv,) bool mask of hand-region verts.

    Runs one all-zero-parameter forward pass, takes the mean of keypoints
    21..41 and of keypoints 42..62 as the two hand centers, and marks every
    joint within `hand_radius_m` of either center as a hand joint. Each
    vertex's sparse LBS weights on hand joints are then summed; vertices whose
    sum is at least `weight_threshold` are hand verts.

    The forward pass runs under `torch.no_grad()` and its outputs are detached,
    so the function also works when it is called with autograd enabled or when
    the model's outputs require grad (`.numpy()` raises on such tensors).
    """
    head = model.head_pose
    mhr = head.mhr
    device = head.scale_mean.device

    def zeros(*shape: int) -> torch.Tensor:
        return torch.zeros(1, *shape, device=device)

    with torch.no_grad():
        out = head.mhr_forward(
            global_trans=zeros(3),
            global_rot=zeros(3),
            body_pose_params=zeros(130),
            hand_pose_params=zeros(head.num_hand_comps * 2),
            scale_params=zeros(head.num_scale_comps),
            shape_params=zeros(head.num_shape_comps),
            expr_params=zeros(head.num_face_comps),
            return_keypoints=True,
            return_joint_coords=True,
        )
    # Output order with these flags: (verts, kp, jcoords). See mhr_head.mhr_forward.
    kp = out[1][0, :70].detach().cpu().numpy()
    jcoords = out[2][0].detach().cpu().numpy()

    right_center = kp[21:42].mean(axis=0)
    left_center = kp[42:63].mean(axis=0)
    j_dist_r = np.linalg.norm(jcoords - right_center, axis=1)
    j_dist_l = np.linalg.norm(jcoords - left_center, axis=1)
    is_hand_joint = (j_dist_r < hand_radius_m) | (j_dist_l < hand_radius_m)

    # Sparse skinning table: entry i says vertex lbs_v[i] follows joint
    # lbs_j[i] with weight lbs_w[i].
    lbs_w = mhr.lbs_skin_weights.detach().cpu().numpy()
    lbs_v = mhr.lbs_vert_indices.cpu().numpy()
    lbs_j = mhr.lbs_skin_indices.cpu().numpy()

    hand_mass = np.zeros(mhr.NUM_VERTS, dtype=np.float32)
    # np.add.at accumulates correctly when a vertex index appears several times.
    np.add.at(hand_mass, lbs_v, lbs_w * is_hand_joint.astype(np.float32)[lbs_j])
    return hand_mass >= weight_threshold
def _compute_face_mask(model, disp_threshold_m: float = 1e-4) -> np.ndarray:
"""(Nv,) bool mask of verts that move with face expression. Sweeps each of
the 72 expression axes at coef=+1.0 and flags any vert that moves more
than `disp_threshold_m` for at least one axis."""
head = model.head_pose
device = head.scale_mean.device
num_face = head.num_face_comps
zeros = lambda *s: torch.zeros(1, *s, device=device)
neutral_kw = dict(
global_trans=zeros(3),
global_rot=zeros(3),
body_pose_params=zeros(130),
hand_pose_params=zeros(head.num_hand_comps * 2),
scale_params=zeros(head.num_scale_comps),
shape_params=zeros(head.num_shape_comps),
expr_params=zeros(num_face),
)
v0 = head.mhr_forward(**neutral_kw).cpu().numpy()[0] # (Nv, 3)
face_mask = np.zeros(v0.shape[0], dtype=bool)
for axis in range(num_face):
expr = zeros(num_face)
expr[0, axis] = 1.0
kw = dict(neutral_kw)
kw["expr_params"] = expr
v = head.mhr_forward(**kw).cpu().numpy()[0]
face_mask |= (np.linalg.norm(v - v0, axis=1) > disp_threshold_m)
return face_mask
def jet_colormap(s: np.ndarray) -> np.ndarray:
    """matplotlib jet, (N,) in [0,1] -> (N, 3) float32 RGB.

    Inputs are clipped to [0, 1] and each channel is a piecewise-linear
    interpolation between the control points listed below.
    """
    t = np.clip(np.asarray(s, dtype=np.float32), 0.0, 1.0)
    # (positions, values) control points per channel, in R, G, B order.
    control_points = (
        ([0.0, 0.35, 0.66, 0.89, 1.0], [0.0, 0.0, 1.0, 1.0, 0.5]),
        ([0.0, 0.125, 0.375, 0.64, 0.91, 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]),
        ([0.0, 0.11, 0.34, 0.65, 1.0], [0.5, 1.0, 1.0, 0.0, 0.0]),
    )
    channels = [np.interp(t, xs, ys) for xs, ys in control_points]
    return np.stack(channels, axis=-1).astype(np.float32)