"""3D capsule rendering for OpenPose-style skeletons.

SCAIL-Pose-equivalent torch renderer adapted to SAM3DBody ``pose_data``. Each
limb is drawn as a true 3D capsule (cylinder + hemispherical caps), found with
a closed-form ray-capsule intersection and projected through the per-person
camera (``pred_cam_t`` + ``focal_length`` + image size), so closer limbs appear
thicker (perspective) and, with lighting enabled, brighter (depth fade) -- the
SCAIL-Pose visual style.

Self-contained: no dependency on the SCAIL-Pose package.
Output: (H, W, 3) fp32 ``torch.Tensor`` in [0, 1].
"""

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

import comfy.model_management

from .glb_shared import (
    OPENPOSE_18_PAIRS,
    OPENPOSE18_TO_MHR70,
    OPENPOSE_RAINBOW_18,
    SCAIL_LIMB_COLORS_17,
    OPENPOSE_HAND_PAIRS,
    OPENPOSE_HAND21_TO_MHR70_R,
    OPENPOSE_HAND21_TO_MHR70_L,
    OPENPOSE_HAND_COLORS_21,
)


def _limb_palette_rgb01(palette: str) -> np.ndarray:
    """Return one fp32 RGB color in [0, 1] per OpenPose-18 body limb.

    ``palette == "scail"`` selects ``SCAIL_LIMB_COLORS_17``; any other value
    selects the first ``len(OPENPOSE_18_PAIRS)`` rainbow colors.
    """
    if palette == "scail":
        return SCAIL_LIMB_COLORS_17.astype(np.float32)
    return OPENPOSE_RAINBOW_18[: len(OPENPOSE_18_PAIRS)].astype(np.float32)


def _tint_rgb(rgb: np.ndarray, pastel: float) -> np.ndarray:
    """Blend ``rgb`` toward white by ``pastel`` (0 leaves it unchanged, 1 is white)."""
    if pastel <= 0:
        return rgb
    return rgb * (1.0 - pastel) + pastel


def _is_drawable(sa: np.ndarray, sb: np.ndarray) -> bool:
    """True when both segment endpoints are finite and strictly in front of the camera."""
    if not (np.all(np.isfinite(sa)) and np.all(np.isfinite(sb))):
        return False
    return bool(sa[2] > 0 and sb[2] > 0)


def _scail_head_stub(
    sa: np.ndarray,
    sb: np.ndarray,
    spine_dir: Optional[np.ndarray],
    spine_len: float,
) -> np.ndarray:
    """Return the shortened end point of SCAIL's neck->nose head stub.

    The stub starts at the neck ``sa``. When both directions are available it
    points along a 60/40 blend of the spine direction and the neck->nose
    direction, with half the neck->nose length. Otherwise it falls back to half
    the neck->nose vector, then to 30% of the spine length along the spine.
    ``spine_len`` is only read when ``spine_dir`` is not None.
    """
    nose_vec = sb - sa
    nose_len = float(np.linalg.norm(nose_vec))
    if nose_len > 1e-6 and spine_dir is not None:
        nose_dir = nose_vec / nose_len
        mixed = 0.6 * spine_dir + 0.4 * nose_dir
        mixed = mixed / max(float(np.linalg.norm(mixed)), 1e-6)
        return sa + mixed * (nose_len * 0.5)
    if nose_len > 1e-6:
        return sa + nose_vec * 0.5
    if spine_dir is not None:
        return sa + spine_dir * (spine_len * 0.3)
    return sb


def _build_specs_from_pose(
    persons: List[Dict[str, Any]],
    *,
    include_hands: bool,
    palette: str,
    person_brightness_falloff: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Flatten one frame's body (+ optional hand) limbs into capsule specs.

    Returns ``(starts (M,3), ends (M,3), colors_rgba (M,4), is_hand (M,))`` in
    camera coordinates (Y-down, +Z forward). Keypoints are taken from
    ``pred_keypoints_3d`` (at least 70 rows of xyz) offset by ``pred_cam_t``.
    Persons missing either field, or with a malformed keypoint array, are
    skipped. A limb is dropped when either endpoint is non-finite or not
    strictly in front of the camera (z <= 0).

    ``is_hand`` flags hand limbs so the renderer can draw them thinner.

    ``person_brightness_falloff`` (clamped to [0, 1]) blends track ``k``'s
    colors toward white by ``1 - falloff**k``. Track 0 always stays vivid,
    ``falloff = 1`` disables tinting, and ``falloff = 0`` (the default) renders
    every track after the first pure white. This matches the mesh rasterizer
    and GLB exporters.
    """
    starts: List[np.ndarray] = []
    ends: List[np.ndarray] = []
    colors: List[np.ndarray] = []
    is_hand: List[bool] = []

    body_limb_colors = _limb_palette_rgb01(palette)
    hand_limb_colors = OPENPOSE_HAND_COLORS_21.astype(np.float32)
    falloff = max(0.0, min(1.0, float(person_brightness_falloff)))
    # SCAIL skips the face bones (13..16); other palettes draw every body pair.
    body_limb_count = 13 if palette == "scail" else len(OPENPOSE_18_PAIRS)

    def _emit(sa: np.ndarray, sb: np.ndarray, rgb: np.ndarray, hand: bool) -> None:
        starts.append(sa)
        ends.append(sb)
        is_hand.append(hand)
        colors.append(np.array([rgb[0], rgb[1], rgb[2], 1.0], dtype=np.float32))

    for k, person in enumerate(persons):
        kp3d = person.get("pred_keypoints_3d")
        cam_t = person.get("pred_cam_t")
        if kp3d is None or cam_t is None:
            continue
        kp = np.asarray(kp3d, dtype=np.float32)
        if kp.ndim != 2 or kp.shape[1] != 3 or kp.shape[0] < 70:
            continue
        cam_t_np = np.asarray(cam_t, dtype=np.float32).reshape(3)
        # pred_keypoints_3d is in the camera frame (Y-down post-flip); adding
        # cam_t places the subject in front of the camera.
        kp_cam = kp + cam_t_np[None, :]
        pastel = 0.0 if k == 0 else (1.0 - falloff ** k)

        body_kp = kp_cam[OPENPOSE18_TO_MHR70]

        # SCAIL redirects limb 12 into a short head stub that blends the
        # spine direction with the neck->nose direction.
        spine_dir: Optional[np.ndarray] = None
        spine_len = 0.0
        if palette == "scail":
            mid_hip = 0.5 * (body_kp[8] + body_kp[11])  # 8=RHip, 11=LHip
            sd = body_kp[1] - mid_hip  # 1=Neck
            spine_len = float(np.linalg.norm(sd))
            if np.all(np.isfinite(sd)) and spine_len > 1e-6:
                spine_dir = sd / spine_len

        for limb_i, (a, b) in enumerate(OPENPOSE_18_PAIRS[:body_limb_count]):
            sa, sb = body_kp[a], body_kp[b]
            if not _is_drawable(sa, sb):
                continue
            if palette == "scail" and limb_i == 12:
                sb = _scail_head_stub(sa, sb, spine_dir, spine_len)
            _emit(sa, sb, _tint_rgb(body_limb_colors[limb_i], pastel), False)

        if include_hands:
            r_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_R]
            l_kp = kp_cam[OPENPOSE_HAND21_TO_MHR70_L]
            for a, b in OPENPOSE_HAND_PAIRS:
                # NOTE(review): the color index (a + b) % len(colors) is kept
                # as-is; confirm it matches the other exporters' hand coloring.
                rgb = _tint_rgb(hand_limb_colors[(a + b) % len(hand_limb_colors)], pastel)
                for hand_kp in (r_kp, l_kp):
                    sa, sb = hand_kp[a], hand_kp[b]
                    if _is_drawable(sa, sb):
                        _emit(sa, sb, rgb, True)

    if not starts:
        return (
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 3), dtype=np.float32),
            np.zeros((0, 4), dtype=np.float32),
            np.zeros((0,), dtype=bool),
        )
    return (
        np.stack(starts).astype(np.float32),
        np.stack(ends).astype(np.float32),
        np.stack(colors).astype(np.float32),
        np.asarray(is_hand, dtype=bool),
    )


def _ray_capsule_t(
    ray_dirs: torch.Tensor,  # (K, 3) unit rays from camera origin
    starts: torch.Tensor,  # (M, 3)
    ends: torch.Tensor,  # (M, 3)
    ba_norm: torch.Tensor,  # (M, 3) unit axis (A -> B)
    ba_len: torch.Tensor,  # (M,) segment length
    radius: torch.Tensor,  # (M,) per-capsule radius
) -> torch.Tensor:
    """Closed-form ray-capsule intersection.

    Returns a (K, M) tensor of ray parameters t to the nearest valid hit per
    capsule, +inf where the ray misses. A capsule is the union of (cylinder
    body, hemisphere at A, hemisphere at B); each component is a quadratic
    root-find.
    """
    INF = float("inf")
    r_sq = radius * radius  # (M,)

    # Cached dot products.
    dn = ray_dirs @ ba_norm.transpose(0, 1)  # (K, M) d.n
    dA = ray_dirs @ starts.transpose(0, 1)  # (K, M) d.A
    dB = ray_dirs @ ends.transpose(0, 1)  # (K, M) d.B
    An = (starts * ba_norm).sum(-1)  # (M,) A.n
    A_sq = (starts * starts).sum(-1)  # (M,) |A|^2
    B_sq = (ends * ends).sum(-1)  # (M,) |B|^2

    # Cylinder body: project onto the plane perpendicular to n and solve
    # |P_perp(t)|^2 = r^2. Rays (almost) parallel to the axis (a_c ~ 0) are
    # left to the caps.
    a_c = 1.0 - dn * dn  # (K, M)
    b_c = -2.0 * (dA - dn * An)  # (K, M)
    c_c = A_sq - An * An - r_sq  # (M,)
    disc_c = b_c * b_c - 4.0 * a_c * c_c
    safe_a = a_c.clamp(min=1e-9)
    sqrt_c = torch.sqrt(disc_c.clamp(min=0.0))
    t_cyl = (-b_c - sqrt_c) / (2.0 * safe_a)
    s_cyl = t_cyl * dn - An  # axial projection from A
    cyl_ok = (
        (disc_c >= 0)
        & (a_c > 1e-7)
        & (t_cyl > 1e-6)
        & (s_cyl >= 0.0)
        & (s_cyl <= ba_len)
    )
    t_cyl = torch.where(cyl_ok, t_cyl, torch.full_like(t_cyl, INF))

    # Sphere at A, restricted to the hemisphere with axial projection <= 0.
    # The other half lies inside the cylinder and is covered above.
    disc_a = dA * dA - (A_sq - r_sq)
    sqrt_a = torch.sqrt(disc_a.clamp(min=0.0))
    t_sa = dA - sqrt_a
    s_a = t_sa * dn - An
    a_ok = (disc_a >= 0) & (t_sa > 1e-6) & (s_a <= 0.0)
    t_sa = torch.where(a_ok, t_sa, torch.full_like(t_sa, INF))

    # Sphere at B, restricted to the hemisphere with axial projection >= ba_len.
    disc_b = dB * dB - (B_sq - r_sq)
    sqrt_b = torch.sqrt(disc_b.clamp(min=0.0))
    t_sb = dB - sqrt_b
    s_b = t_sb * dn - An
    b_ok = (disc_b >= 0) & (t_sb > 1e-6) & (s_b >= ba_len)
    t_sb = torch.where(b_ok, t_sb, torch.full_like(t_sb, INF))

    return torch.minimum(torch.minimum(t_cyl, t_sa), t_sb)


def _sphere_axis_extent_px(
    lateral: torch.Tensor,  # (M,) sphere-center coordinate along the image axis (x or y)
    depth: torch.Tensor,  # (M,) sphere-center z
    radius: torch.Tensor,  # (M,)
    focal: float,
    center: float,
    full_extent: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Conservative pixel interval [lo, hi] covering each sphere's projection on one axis.

    Each sphere lies inside its axis-aligned box
    [c - r, c + r] x [z - r, z + r]. When z - r > 0, the ratio c/z over that
    box reaches its extremes at box corners, so the bound holds off-axis and
    close to the camera. Spheres that reach the camera plane (z - r <= eps)
    project to an unbounded region and get the full [0, full_extent] interval.
    """
    eps = 1e-6
    near = depth - radius
    far = depth + radius
    reaches_camera = near <= eps
    near_safe = near.clamp(min=eps)

    hi_num = lateral + radius
    lo_num = lateral - radius
    hi = torch.where(hi_num >= 0, hi_num / near_safe, hi_num / far)
    lo = torch.where(lo_num >= 0, lo_num / far, lo_num / near_safe)

    a_px = lo * focal + center
    b_px = hi * focal + center
    # min/max keep the interval ordered even for a (unusual) negative focal.
    lo_px = torch.where(reaches_camera, torch.zeros_like(a_px), torch.minimum(a_px, b_px))
    hi_px = torch.where(
        reaches_camera,
        torch.full_like(b_px, float(full_extent)),
        torch.maximum(a_px, b_px),
    )
    return lo_px, hi_px


def _render_capsules_torch(
    starts: torch.Tensor,
    ends: torch.Tensor,
    colors: torch.Tensor,
    H: int,
    W: int,
    fx: float,
    fy: float,
    cx: float,
    cy: float,
    radius: torch.Tensor,  # scalar or (M,) per-capsule radius
    background_rgb: Optional[torch.Tensor],
    device: torch.device,
    flat_shade: bool = False,
) -> torch.Tensor:
    """Analytic ray-capsule renderer for a union of capsules.

    The camera sits at the origin looking down +Z; pixels use y-down screen
    coordinates. ``starts``/``ends`` are (M, 3) capsule end points,
    ``colors`` is (M, >=3) (only RGB is used), and ``background_rgb``, when
    given, must reshape to (H*W, 3).

    ``flat_shade=True`` paints solid per-limb colors. Otherwise the renderer
    applies a headlight Blinn-Phong term plus a mild depth fade.

    Returns an (H, W, 3) fp32 tensor clamped to [0, 1].
    """
    M = int(starts.shape[0])
    if M == 0:
        if background_rgb is not None:
            return background_rgb.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
        return torch.zeros(H, W, 3, dtype=torch.float32, device=device)

    # One normalized ray through each integer pixel coordinate.
    yy, xx = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing="ij",
    )
    ray_dirs = torch.stack([(xx - cx) / fx, (yy - cy) / fy, torch.ones_like(xx)], dim=-1)
    ray_dirs = ray_dirs / torch.linalg.norm(ray_dirs, dim=-1, keepdim=True)
    flat_dirs = ray_dirs.view(-1, 3)
    N = flat_dirs.shape[0]

    radius = torch.as_tensor(radius, device=device, dtype=torch.float32)
    if radius.ndim == 0:
        radius = radius.expand(M)

    ba = ends - starts
    ba_len = torch.linalg.norm(ba, dim=1).clamp(min=1e-6)
    ba_norm = ba / ba_len.unsqueeze(1)

    # Union of per-capsule screen-space bboxes. A capsule is the convex hull of
    # its two end spheres, and perspective projection keeps geometry in front
    # of the camera convex. The union of the two end-sphere bounds therefore
    # covers the capsule's silhouette, and pixels outside the mask cannot hit
    # any capsule. The analytic intersection only runs on that subset.
    xs_lo, xs_hi = _sphere_axis_extent_px(starts[:, 0], starts[:, 2], radius, fx, cx, W)
    xe_lo, xe_hi = _sphere_axis_extent_px(ends[:, 0], ends[:, 2], radius, fx, cx, W)
    ys_lo, ys_hi = _sphere_axis_extent_px(starts[:, 1], starts[:, 2], radius, fy, cy, H)
    ye_lo, ye_hi = _sphere_axis_extent_px(ends[:, 1], ends[:, 2], radius, fy, cy, H)
    pad = 2.0  # covers integer-pixel sampling at the bbox edges
    xmin_t = (torch.minimum(xs_lo, xe_lo) - pad).clamp(0.0, float(W)).floor().long()
    xmax_t = (torch.maximum(xs_hi, xe_hi) + pad).clamp(0.0, float(W)).ceil().long()
    ymin_t = (torch.minimum(ys_lo, ye_lo) - pad).clamp(0.0, float(H)).floor().long()
    ymax_t = (torch.maximum(ys_hi, ye_hi) + pad).clamp(0.0, float(H)).ceil().long()
    # One stack->tolist sync amortizes the GPU->CPU read over all M bboxes.
    bboxes_cpu = torch.stack([xmin_t, ymin_t, xmax_t, ymax_t], dim=1).tolist()
    coarse_mask = torch.zeros(H, W, dtype=torch.bool, device=device)
    for xmin_i, ymin_i, xmax_i, ymax_i in bboxes_cpu:
        if xmax_i > xmin_i and ymax_i > ymin_i:
            coarse_mask[ymin_i:ymax_i, xmin_i:xmax_i] = True

    # Analytic ray-capsule intersection: one pass over the masked pixels.
    INF = float("inf")
    flat_t = torch.full((N,), INF, device=device, dtype=torch.float32)
    flat_m_idx = torch.full((N,), -1, device=device, dtype=torch.long)
    active_idx = torch.nonzero(coarse_mask.view(-1), as_tuple=False).squeeze(1)
    if active_idx.numel() > 0:
        # Cap the per-chunk (K, M) tensors at ~4M elements to bound peak
        # memory when both K (pixels) and M (capsules) are large.
        chunk_max = max(1, int(4_000_000 / max(M, 1)))
        for i0 in range(0, active_idx.numel(), chunk_max):
            sub = active_idx[i0 : i0 + chunk_max]
            t_KM = _ray_capsule_t(flat_dirs[sub], starts, ends, ba_norm, ba_len, radius)
            t_min, m_idx = t_KM.min(dim=1)
            hit = t_min < INF
            if hit.any():
                winners = sub[hit]
                flat_t[winners] = t_min[hit]
                flat_m_idx[winners] = m_idx[hit]

    if background_rgb is not None:
        out = background_rgb.to(device=device, dtype=torch.float32).reshape(N, 3).clone()
    else:
        out = torch.zeros((N, 3), dtype=torch.float32, device=device)

    hit_idx = torch.nonzero(flat_m_idx >= 0, as_tuple=False).squeeze(1)
    if hit_idx.numel() == 0:
        return out.view(H, W, 3).clamp(0.0, 1.0)

    m_h = flat_m_idx[hit_idx]
    col = colors[m_h, :3]
    if flat_shade:
        # Solid per-limb color (OpenPose look): no lighting or depth modulation.
        out[hit_idx] = col
        return out.view(H, W, 3).clamp(0.0, 1.0)

    # Analytic normal: hit point minus the closest point on the capsule axis.
    rd = flat_dirs[hit_idx]
    p_hit = rd * flat_t[hit_idx].unsqueeze(-1)
    A_h = starts[m_h]
    n_h = ba_norm[m_h]
    proj = ((p_hit - A_h) * n_h).sum(-1).clamp(min=0.0)
    proj = torch.minimum(proj, ba_len[m_h])
    normals = p_hit - (A_h + proj.unsqueeze(-1) * n_h)
    normals = normals / normals.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    # SCAIL Blinn-Phong with a headlight at the camera: the direction toward
    # the light is -Z, so the diffuse term uses -n.z.
    diff = torch.clamp(-(normals[:, 2]), min=0.0)
    diffuse = 0.45 + 0.55 * diff
    half_dir = -rd  # view direction
    half_dir[:, 2] -= 1.0  # + light direction (0, 0, -1)
    half_dir = half_dir / half_dir.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    spec = torch.clamp((normals * half_dir).sum(dim=-1), min=0.0).pow(32)

    # Mild depth fade (nearest hit 1.0, farthest 0.85), computed on-device to
    # avoid GPU->CPU syncs. A degenerate depth range disables the fade.
    z_vals = p_hit[:, 2]
    z_lo = z_vals.min()
    z_span = z_vals.max() - z_lo
    fade = 1.0 - (z_vals - z_lo) / z_span.clamp(min=1e-6)
    depth_factor = torch.where(z_span > 1e-6, 0.85 + 0.15 * fade, torch.ones_like(z_vals))

    base = col * diffuse.unsqueeze(-1) * depth_factor.unsqueeze(-1)
    highlight = (0.5 * spec * depth_factor).unsqueeze(-1)
    out[hit_idx] = base + highlight
    return out.view(H, W, 3).clamp(0.0, 1.0)


def render_pose_data_capsules(
    pose_data: Dict[str, Any],
    *,
    frame_idx: int,
    W: int,
    H: int,
    background: Optional[torch.Tensor] = None,
    composite: str = "over",
    radius_m: float = 0.025,
    include_hands: bool = False,
    palette: str = "scail",
    person_brightness_falloff: float = 0.0,
    flat_shade: bool = False,
    hand_radius_scale: float = 0.4,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Render one frame of ``pose_data`` as 3D capsules through the per-person camera.

    Returns an (H, W, 3) fp32 tensor in [0, 1].

    - ``composite='over'`` paints over ``background``, which is resized
      bilinearly if needed; the canvas is black when ``background`` is None.
      Any other value (e.g. ``'mesh_only'``) always uses a black canvas.
    - ``radius_m`` is in METERS, matching ``pred_keypoints_3d`` /
      ``pred_cam_t``. Hand limbs use ``radius_m * hand_radius_scale`` because
      their bones are far shorter than body limbs.
    - fx = fy comes from the first person with a finite, positive
      ``focal_length``. It falls back to ``min(H, W)``. The principal point is
      the image center.
    - ``device`` defaults to ComfyUI's torch device.
    """
    persons = pose_data["frames"][frame_idx]
    if device is None:
        device = comfy.model_management.get_torch_device()

    # SAM3DBody shares one camera across the clip, so use the first usable
    # focal length. Invalid values would poison every ray.
    fx = fy = float(min(H, W))
    for person in persons:
        f = person.get("focal_length")
        if f is None:
            continue
        f_arr = np.asarray(f, dtype=np.float32).reshape(-1)
        if f_arr.size == 0 or not np.isfinite(f_arr[0]) or f_arr[0] <= 0:
            continue
        fx = fy = float(f_arr[0])
        break
    cx, cy = W * 0.5, H * 0.5

    starts_np, ends_np, colors_np, is_hand_np = _build_specs_from_pose(
        persons,
        include_hands=include_hands,
        palette=palette,
        person_brightness_falloff=person_brightness_falloff,
    )

    bg_t: Optional[torch.Tensor] = None
    if composite == "over" and background is not None:
        bg_t = background.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
        if bg_t.shape[:2] != (H, W):
            bg_t = bg_t.permute(2, 0, 1).unsqueeze(0)
            bg_t = torch.nn.functional.interpolate(
                bg_t, size=(H, W), mode="bilinear", align_corners=False,
            )
            bg_t = bg_t.squeeze(0).permute(1, 2, 0).contiguous()

    if starts_np.shape[0] == 0:
        if bg_t is not None:
            return bg_t
        return torch.zeros(H, W, 3, dtype=torch.float32, device=device)

    starts_t = torch.from_numpy(starts_np).to(device=device, dtype=torch.float32)
    ends_t = torch.from_numpy(ends_np).to(device=device, dtype=torch.float32)
    colors_t = torch.from_numpy(colors_np).to(device=device, dtype=torch.float32)
    radii_np = np.where(is_hand_np, radius_m * hand_radius_scale, radius_m).astype(np.float32)
    radii_t = torch.from_numpy(radii_np).to(device=device, dtype=torch.float32)

    return _render_capsules_torch(
        starts_t,
        ends_t,
        colors_t,
        H=H,
        W=W,
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        radius=radii_t,
        background_rgb=bg_t,
        device=device,
        flat_shade=flat_shade,
    )