"""SAM 3D Body — ComfyUI nodes for loading the model, predicting per-frame MHR
body poses, driving face expressions, temporally smoothing pose frames and
rendering them, plus their inference helpers."""

import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from tqdm import tqdm
from scipy.signal import savgol_coeffs

import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
import folder_paths

from comfy.ldm.sam3d_body.model.model import SAM3DBody
from comfy.ldm.sam3d_body.model.dinov3 import apply_dinov3_qkv_bias_mask
from comfy_extras.sam3d_body.utils import (
    cam_int_from_fov,
    cam_int_from_moge,
    inputs_from_sam3_track,
    run_batched_frames,
    run_batched_single_chunk,
    compute_canonical_colors,
    compute_hand_vert_mask,
)
from comfy_extras.sam3d_body.rasterizer import render_pose_data_torch as render_pose_data
from comfy_extras.sam3d_body.export.capsules import render_pose_data_capsules
from comfy_extras.sam3d_body.export.openpose_2d import render_pose_data_openpose
from comfy_extras.sam3d_body import face_expression as fx
from comfy_extras.sam3d_body.utils import image_to_uint8

SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# MHRPoseData = SAM3DBody_Predict's native output (carries mhr_model_params,
# shape_params, expr_params, MHR70 keypoint layout, canonical_colors keyed to
# MHR mesh, hand_vert_mask from MHR LBS). The export-side consumers
# (BuildPoseGLB / SavePoseBVH in comfy_extras/nodes_save_3d.py) also accept
# KIMODO_POSE_DATA via a MultiType union — those types are mirrored there.
MHRPoseData = io.Custom("MHR_POSE_DATA")
SAM3DBodyModel = io.Custom("SAM3D_BODY_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")


# Loader


class SAM3DBody_Loader(io.ComfyNode):
    """Load SAM 3D Body weights from the 'detection' folder and wrap the model
    in a CoreModelPatcher for device management."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Loader",
            display_name="Load SAM3D Body Model",
            category="image/detection/sam3dbody/",  # TODO: better category?
            inputs=[
                io.Combo.Input(
                    "model_file",
                    options=folder_paths.get_filename_list("detection"),
                    tooltip="SAM 3D Body weights (.safetensors) in the 'detection' folder",
                ),
            ],
            outputs=[SAM3DBodyModel.Output("model", display_name="sam3d_body_model")],
        )

    @classmethod
    def execute(cls, model_file) -> io.NodeOutput:
        path = folder_paths.get_full_path_or_raise("detection", model_file)
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        # Key remap: checkpoint stores ".layers.0.0." where the module expects ".layers.0.".
        sd = {k.replace(".layers.0.0.", ".layers.0."): v for k, v in sd.items()}

        load_device = comfy.model_management.get_torch_device()
        weight_dtype = comfy.utils.weight_dtype(sd)
        torch_dtype = comfy.model_management.unet_dtype(
            device=load_device, model_params=-1, weight_dtype=weight_dtype,
        )
        manual_cast_dtype = comfy.model_management.unet_manual_cast(torch_dtype, load_device)
        operations = comfy.ops.pick_operations(
            torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True,
        )

        model = SAM3DBody(dtype=torch_dtype, operations=operations)
        model.load_state_dict(sd, strict=False)
        apply_dinov3_qkv_bias_mask(model.backbone)
        model.eval()
        model.backbone_dtype = torch_dtype
        # Cache per-model constants that downstream nodes read via getattr().
        model._sam3d_image_size = model.image_size
        model._sam3d_canonical_colors = compute_canonical_colors(model)
        model._sam3d_hand_vert_mask = compute_hand_vert_mask(model)

        patcher = comfy.model_patcher.CoreModelPatcher(
            model,
            load_device=load_device,
            offload_device=comfy.model_management.unet_offload_device(),
            size=comfy.model_management.module_size(model),
        )
        return io.NodeOutput(patcher)


# Predict


class SAM3DBody_Predict(io.ComfyNode):
    """Run SAM 3D Body on every frame of an image batch.

    Person boxes/masks come from SAM3 Video Track when wired; otherwise each
    frame is treated as a single person spanning the full frame. Output is an
    MHR_POSE_DATA dict with per-frame lists of per-person prediction dicts.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Predict",
            display_name="Predict SAM3D Body",
            category="image/detection/sam3dbody/",
            inputs=[
                SAM3DBodyModel.Input("sam3d_body_model"),
                io.Image.Input("image"),
                SAM3TrackData.Input(
                    "sam3_track_data",
                    optional=True,
                    tooltip=("Output of SAM3 Video Track, required for multi-person detection"),
                ),
                io.Boolean.Input(
                    "run_hand_refinement",
                    default=True,
                    tooltip="Improves hand pose at the cost of extra inference time and memory use",
                ),
                io.Float.Input(
                    "fov_degrees",
                    default=0.0,
                    min=0.0,
                    max=170.0,
                    step=0.5,
                    tooltip=(
                        "Vertical FOV in degrees. Affects predicted depth (cam_t.z) and "
                        "absolute scale. 0 = use moge_geometry or fall back to ~53° (16:9). "
                        "Any non-zero value overrides moge_geometry."
                    ),
                ),
                MoGeGeometry.Input(
                    "moge_geometry",
                    optional=True,
                    tooltip=(
                        "MoGe geometry (from MoGeInference), used to calculate camera field of view."
                        "For batches choose the most representative frame, or leave unset"
                    ),
                ),
                io.Int.Input(
                    "chunk_size",  # TODO: automate?
                    default=64,
                    min=1,
                    max=512,
                    step=1,
                    advanced=True,
                    tooltip=(
                        "Max person-crops per forward. Higher = throughput + VRAM; "
                        "per-chunk frame count is chunk_size / persons_per_frame."
                    ),
                ),
            ],
            outputs=[MHRPoseData.Output("mhr_pose_data")],
        )

    @classmethod
    def execute(cls, sam3d_body_model, image, sam3_track_data=None, run_hand_refinement=True,
                fov_degrees=0.0, moge_geometry=None, chunk_size=64) -> io.NodeOutput:
        # chunk_size default matches the schema default (64).
        comfy.model_management.load_model_gpu(sam3d_body_model)
        inner: SAM3DBody = sam3d_body_model.model

        B, H, W, _ = image.shape
        image_size = getattr(inner, "_sam3d_image_size", (512, 512))

        per_frame_bboxes, per_frame_masks = (None, None)
        if sam3_track_data is not None:
            per_frame_bboxes, per_frame_masks = inputs_from_sam3_track(sam3_track_data, B, H, W)
        if per_frame_bboxes is None:
            # No track wired (or empty / frame count mismatch) — single-person
            # full-frame fallback. Multi-person scenes need SAM3 Video Track.
            full_frame_bbox = torch.tensor([[0.0, 0.0, float(W), float(H)]], dtype=torch.float32)
            per_frame_bboxes = [full_frame_bbox.clone() for _ in range(B)]
            per_frame_masks = None

        inference_type = "full" if run_hand_refinement else "body"

        # Precedence: explicit fov_degrees > MoGe estimate > diagonal default.
        cam_int = cam_int_from_fov(int(H), int(W), float(fov_degrees))
        if cam_int is None:
            cam_int = cam_int_from_moge(moge_geometry, int(H), int(W))

        # Frames without any person box are skipped (None) instead of converted.
        frames_rgb: List[Optional[torch.Tensor]] = []
        for f in range(B):
            if per_frame_bboxes[f].shape[0] == 0:
                frames_rgb.append(None)
            else:
                frames_rgb.append(image_to_uint8(image[f]))

        # Batched path requires uniform non-zero K across all frames.
        bbox_counts = {per_frame_bboxes[f].shape[0] for f in range(B) if frames_rgb[f] is not None}
        can_batch = (
            len(bbox_counts) == 1
            and 0 not in bbox_counts
            and all(frames_rgb[f] is not None for f in range(B))
        )

        frames_out: List[List[Dict[str, Any]]] = []
        pbar = comfy.utils.ProgressBar(B)

        if can_batch and B > 0:
            frames_out = run_batched_frames(
                inner, frames_rgb, per_frame_bboxes, per_frame_masks,
                image_size, inference_type, cam_int=cam_int, pbar=pbar,
                crops_per_chunk=int(chunk_size),
            )
        else:
            # Mixed K per frame — call the batched path once per frame.
            for f in range(B):
                if frames_rgb[f] is None or per_frame_bboxes[f].shape[0] == 0:
                    frames_out.append([])
                    pbar.update(1)
                    continue
                mask_f = [per_frame_masks[f]] if per_frame_masks is not None else None
                chunk = run_batched_single_chunk(
                    inner, [frames_rgb[f]], [per_frame_bboxes[f]], mask_f,
                    image_size, inference_type, K=int(per_frame_bboxes[f].shape[0]),
                    cam_int=cam_int,
                )
                frames_out.append(chunk[0])
                pbar.update(1)

        mhr_pose_data = {
            "frames": frames_out,
            "faces": inner.head_pose.faces.cpu().numpy(),
            "image_size": (int(H), int(W)),
            "canonical_colors": getattr(inner, "_sam3d_canonical_colors", None),
            "hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None),
        }
        return io.NodeOutput(mhr_pose_data)


class SAM3DBody_FaceExpression(io.ComfyNode):
    """Drive MHR face blendshapes from the core MediaPipe Face Landmarker.

    Detects per-frame faces, IoU-matches each to a tracked person, maps the 52
    ARKit blendshapes onto MHR's 72-axis `expr_params`, and re-runs MHR forward
    so pred_vertices/pred_keypoints reflect the new expression.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_FaceExpression",
            description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.",
            display_name="Face Expression to SAM3D Body",  # TODO: better name?
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                SAM3DBodyModel.Input("sam3d_body_model"),
                io.Image.Input("image"),
                io.Float.Input(
                    "strength", default=1.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Global multiplier on all blendshapes. >1 exaggerates.",
                ),
                io.Float.Input(
                    "mouth_strength", default=1.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on mouth/jaw shapes. MP's jawOpen saturates near 1.0.",
                    advanced=True,
                ),
                io.Float.Input(
                    "eye_strength", default=2.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on eye shapes. MP rarely exceeds 0.5; 2-3x often needed.",
                    advanced=True,
                ),
                io.Float.Input(
                    "brow_strength", default=2.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on brow/cheek/sneer shapes. MP outputs ~0.1-0.3; 2-3x.",
                    advanced=True,
                ),
                io.Float.Input(
                    "input_threshold", default=0.02, min=0.0, max=0.5, step=0.01,
                    tooltip=(
                        "Deadzone on MediaPipe's raw output (below = zero, above = linear remap). "
                    ),
                    advanced=True,
                ),
                io.Int.Input(
                    "blendshape_smooth_window", default=7, min=1, max=31, step=2,
                    tooltip=(
                        "Gaussian window on MediaPipe's per-frame signal before MHR mapping. "
                        "MediaPipe's raw output swings 30-70% frame-to-frame on static faces. "
                        "1 = disabled. Use odd values."
                    ),
                    advanced=True,
                ),
            ],
            outputs=[
                MHRPoseData.Output("mhr_pose_data"),
            ],
        )

    @classmethod
    def execute(cls, mhr_pose_data, sam3d_body_model, image, strength=1.0, mouth_strength=1.0,
                eye_strength=2.0, brow_strength=2.0, input_threshold=0.02,
                blendshape_smooth_window=7) -> io.NodeOutput:
        # Keyword defaults mirror the schema defaults above.
        comfy.model_management.load_model_gpu(sam3d_body_model)
        inner: SAM3DBody = sam3d_body_model.model

        frames = mhr_pose_data["frames"]
        B = len(frames)
        if B == 0:
            return io.NodeOutput(mhr_pose_data)

        img_np = (image * 255.0).clamp(0.0, 255.0).to(torch.uint8).cpu().numpy()

        # Shallow-copy every person dict so the input pose data is not mutated.
        new_frames: List[List[Dict[str, Any]]] = [[dict(p) for p in f] for f in frames]
        max_persons = max((len(f) for f in new_frames), default=0)
        # per_person_coefs[pid][fi] = blendshape dict, or None when undetected.
        per_person_coefs: List[List[Optional[Dict[str, float]]]] = [
            [None] * B for _ in range(max_persons)
        ]

        pbar = comfy.utils.ProgressBar(B)
        n_total_frames_with_persons = 0
        crop_factor = 1.2

        # Auto-pick full-frame vs per-person crops. BlazeFace full-range needs
        # ≥32px face in its 192px input; below that we escalate to per-person
        # crops. Face height ≈ 20% of body-bbox height (rough but stable).
        H_img0, W_img0 = img_np.shape[1], img_np.shape[2]
        min_bbox_px = 32.0 * max(H_img0, W_img0) / (192.0 * 0.20)
        use_per_person_crops = any(
            (p["bbox"][3] - p["bbox"][1]) < min_bbox_px
            for persons in new_frames
            for p in persons
        )

        for fi in tqdm(range(B), desc="SAM3D face expression detect"):
            persons = new_frames[fi]
            # Reuse the last image when the pose data has more frames than the image batch.
            img_fi = img_np[min(fi, img_np.shape[0] - 1)]
            if not persons:
                pbar.update(1)
                continue
            n_total_frames_with_persons += 1

            person_bboxes = [np.asarray(p["bbox"], dtype=np.float32) for p in persons]
            H_img, W_img = img_fi.shape[:2]

            if use_per_person_crops:
                # One MP call per person on a tight head crop — recovers small/
                # distant faces that the full-frame 192px BlazeFace would miss.
                for pid, pb in enumerate(person_bboxes):
                    if pid >= max_persons:
                        continue
                    cr = fx.head_crop_from_keypoints(
                        persons[pid].get("pred_keypoints_2d"), crop_factor, W_img, H_img,
                    )
                    if cr is None:
                        cr = fx.head_region_crop(pb, crop_factor, W_img, H_img)
                    faces = fx.detect_faces_in_crop(inner, img_fi, cr, num_faces=1)
                    if not faces:
                        continue
                    # Pick face closest to person bbox center when a neighbor leaks in.
                    pcx, pcy = 0.5 * (pb[0] + pb[2]), 0.5 * (pb[1] + pb[3])
                    best = min(
                        faces,
                        key=lambda f: (0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]) - pcx) ** 2
                        + (0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3]) - pcy) ** 2,
                    )
                    per_person_coefs[pid][fi] = best["blendshapes"]
            else:
                faces = inner.face_landmarker.detect_batch(
                    [img_fi], num_faces=max(1, len(persons)),
                )[0]
                if faces:
                    face_bboxes = [f["bbox_xyxy"] for f in faces]
                    assignment = fx.assign_faces_to_persons(face_bboxes, person_bboxes)
                    for pid, face_idx in enumerate(assignment):
                        if face_idx is None or pid >= max_persons:
                            continue
                        per_person_coefs[pid][fi] = faces[face_idx]["blendshapes"]
            pbar.update(1)

        # Baseline subtraction. MP has subject-specific rest bias (e.g.
        # naturally-raised brow at 0.15); without subtraction, strength
        # multipliers bake that into every frame. Per-clip needs ~30 frames
        # or it would zero out the expression.
        BASELINE_MIN_FRAMES = 30
        if n_total_frames_with_persons >= BASELINE_MIN_FRAMES:
            for pid in range(max_persons):
                per_person_coefs[pid] = fx.subtract_per_clip_baseline(
                    per_person_coefs[pid], percentile=5.0,
                )
        else:
            logging.warning(
                f"[SAM 3D Body FaceExpression] per-clip baseline subtraction "
                f"needs ~{BASELINE_MIN_FRAMES}+ frames with detections; "
                f"got {n_total_frames_with_persons}. Skipping subtraction."
            )

        # Smooth raw signal AFTER baseline subtraction but BEFORE gap fill —
        # MP's per-frame noise gets averaged out at the source.
        bs_win = int(blendshape_smooth_window)
        if bs_win > 1:
            for pid in range(max_persons):
                per_person_coefs[pid] = fx.smooth_blendshape_series(
                    per_person_coefs[pid], window=bs_win,
                )

        for pid in range(max_persons):
            per_person_coefs[pid] = fx.fill_detection_gaps(
                per_person_coefs[pid], method="interpolate", max_gap=12,
            )

        n_written = 0
        for fi in range(B):
            for pid, p in enumerate(new_frames[fi]):
                if pid >= max_persons:
                    continue
                coefs = per_person_coefs[pid][fi]
                if coefs is None:
                    continue
                p["expr_params"] = fx.arkit_to_expr_params(
                    coefs,
                    strength=float(strength),
                    mouth_strength=float(mouth_strength),
                    eye_strength=float(eye_strength),
                    brow_strength=float(brow_strength),
                    input_threshold=float(input_threshold),
                ).astype(np.float32)
                n_written += 1

        # Only re-run MHR forward when at least one expression was written.
        if n_written > 0:
            fx.regenerate_mesh_from_params(inner, new_frames)

        new_pose = dict(mhr_pose_data)
        new_pose["frames"] = new_frames
        return io.NodeOutput(new_pose)


class SAM3DBody_Smooth(io.ComfyNode):
    """Reduce frame-to-frame jitter via vertex-space temporal averaging.

    Backs off on mesh-geometry keys when the subject rotates fast (averaging
    across a spin flattens the mesh); camera-space keys still get full smoothing.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Smooth",
            description="Reduce frame-to-frame jitter via vertex-space temporal averaging",
            display_name="Smooth SAM3D Body Pose Frames",
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                io.Float.Input(
                    "strength", default=1.0, min=0.0, max=1.0, step=0.05,
                    tooltip="Blend raw (0) → smoothed (1).",
                ),
                io.Combo.Input(
                    "method", options=["gaussian", "savgol"], default="gaussian",
                    advanced=True,
                    tooltip=(
                        "'gaussian': symmetric weighted average — phase-preserving "
                        "(no time-shift), best general-purpose smoother. "
                        "'savgol': sliding polynomial fit — preserves sharp peaks "
                    ),
                ),
                io.Int.Input(
                    "window", default=7, min=1, max=51, step=2,
                    advanced=True,
                    tooltip="Temporal window in frames (odd values).",
                ),
                io.Float.Input(
                    "rotation_threshold_deg", default=15.0, min=0.0, max=45.0, step=1.0,
                    advanced=True,
                    tooltip=(
                        "Geometry smoothing drops to RAW above this root-rotation "
                        "rate (deg/frame) to preserve fast spins. 15° suits most "
                        "content; low values trigger on ordinary jitter and "
                        "silently sabotage smoothing. 0 = disable backoff."
                    ),
                ),
            ],
            outputs=[MHRPoseData.Output("mhr_pose_data")],
        )

    @classmethod
    def execute(cls, mhr_pose_data, method, window, strength, rotation_threshold_deg) -> io.NodeOutput:
        if strength <= 0.0 or window <= 1:
            return io.NodeOutput(mhr_pose_data)

        frames = mhr_pose_data["frames"]
        B = len(frames)
        if B < 2:
            return io.NodeOutput(mhr_pose_data)

        max_p = max((len(f) for f in frames), default=0)
        if max_p == 0:
            return io.NodeOutput(mhr_pose_data)

        # Geometry keys rotate with the subject, so linear averaging during
        # fast spins flattens the mesh — these get per-frame adaptive strength.
        keys_geom = {
            "pred_vertices", "pred_keypoints_3d", "pred_joint_coords",
            "pred_global_rots", "mhr_model_params", "body_pose_params",
            "global_rot", "pred_pose_raw",
        }
        # Camera / appearance / 2D keys are safe to smooth linearly.
        keys_cam = {
            "pred_cam_t", "pred_keypoints_2d", "focal_length",
            "shape_params", "scale_params", "hand_pose_params", "expr_params",
        }
        all_keys = sorted(keys_geom | keys_cam)

        kernel = _smoothing_kernel(method, window)
        smoothed = [list(f) for f in frames]
        base_blend = float(strength)
        rot_thresh = float(np.deg2rad(max(0.0, rotation_threshold_deg)))

        for pid in range(max_p):
            # A person slot is "valid" in frames that have at least pid+1 persons.
            valid = np.array([pid < len(f) for f in frames], dtype=bool)
            if valid.sum() < 2:
                continue

            # Adaptive blend per frame from `global_rot` (euler ZYX).
            geom_blend = np.full(B, base_blend, dtype=np.float32)
            if rot_thresh > 0.0:
                root_rotmats = []
                valid_root = []
                for fi in range(B):
                    if not valid[fi]:
                        root_rotmats.append(np.eye(3, dtype=np.float32))
                        valid_root.append(False)
                        continue
                    gr = frames[fi][pid].get("global_rot")
                    if gr is None:
                        root_rotmats.append(np.eye(3, dtype=np.float32))
                        valid_root.append(False)
                        continue
                    eul = np.asarray(gr, dtype=np.float32).reshape(3)
                    # ZYX convention: R = Rz @ Ry @ Rx
                    cz, sz = np.cos(eul[0]), np.sin(eul[0])
                    cy, sy = np.cos(eul[1]), np.sin(eul[1])
                    cx, sx = np.cos(eul[2]), np.sin(eul[2])
                    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=np.float32)
                    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float32)
                    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=np.float32)
                    root_rotmats.append(Rz @ Ry @ Rx)
                    valid_root.append(True)

                # Per-frame root rotation angle (radians) relative to the previous frame.
                ang = np.zeros(B, dtype=np.float32)
                for fi in range(1, B):
                    if valid_root[fi] and valid_root[fi - 1]:
                        R_delta = root_rotmats[fi] @ root_rotmats[fi - 1].T
                        cos_a = float(np.clip((np.trace(R_delta) - 1.0) / 2.0, -1.0, 1.0))
                        ang[fi] = abs(np.arccos(cos_a))

                # Peak-smear over ±window/2 — neighbors of a rotated frame
                # must back off too, or the temporal average pulls the rotated
                # pose in and flattens the mesh anyway.
                half = max(1, window // 2)
                ang_smooth = np.zeros_like(ang)
                for fi in range(B):
                    lo = max(0, fi - half)
                    hi = min(B, fi + half + 1)
                    ang_smooth[fi] = ang[lo:hi].max() if hi > lo else 0.0

                # base_blend at no rotation, 0 at threshold.
                ratio = np.clip(ang_smooth / rot_thresh, 0.0, 1.0)
                geom_blend = base_blend * (1.0 - ratio)

            # Copy each valid person dict once so the input pose data is not mutated.
            for fi in range(B):
                if valid[fi]:
                    smoothed[fi][pid] = dict(smoothed[fi][pid])

            for key in all_keys:
                ref = None
                for fi in range(B):
                    if valid[fi] and frames[fi][pid].get(key) is not None:
                        ref = np.asarray(frames[fi][pid][key])
                        break
                if ref is None or ref.dtype.kind not in "fiu":
                    continue

                # Stack the time series; gaps hold the previous sample.
                stacked = np.zeros((B,) + ref.shape, dtype=np.float32)
                present = np.zeros(B, dtype=bool)
                for fi in range(B):
                    value = frames[fi][pid].get(key) if valid[fi] else None
                    if value is not None:
                        stacked[fi] = np.asarray(value, dtype=np.float32)
                        present[fi] = True
                    elif fi > 0:
                        stacked[fi] = stacked[fi - 1]
                # Leading gaps would otherwise stay at zero and drag the first
                # real frames toward zero through the filter — back-fill them
                # with the first real sample (ref found ⇒ one exists).
                first = int(np.argmax(present))
                stacked[:first] = stacked[first]

                filtered = _apply_temporal_filter(stacked, kernel)

                if key in keys_geom:
                    b = geom_blend
                    while b.ndim < stacked.ndim:
                        b = b[..., None]
                    out = (1.0 - b) * stacked + b * filtered
                else:
                    out = (1.0 - base_blend) * stacked + base_blend * filtered

                for fi in range(B):
                    if valid[fi]:
                        smoothed[fi][pid][key] = out[fi].astype(ref.dtype)

        new_pose = dict(mhr_pose_data)
        new_pose["frames"] = smoothed
        return io.NodeOutput(new_pose)


def _smoothing_kernel(method: str, window: int) -> np.ndarray:
    """Return a 1-D float32 temporal kernel of odd length (even windows are bumped by 1).

    'savgol' uses Savitzky-Golay coefficients (polynomial order 3 for windows
    >= 5, lower for smaller ones); anything else yields a normalized Gaussian
    with sigma = max(1, window / 5).
    """
    window = max(1, int(window))
    if window % 2 == 0:
        window += 1
    if method == "savgol":
        order = 3 if window >= 5 else min(window - 1, 1)
        return savgol_coeffs(window, order).astype(np.float32)
    # gaussian (default)
    sigma = max(1.0, window / 5.0)
    x = np.arange(window) - (window - 1) / 2.0
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return (k / k.sum()).astype(np.float32)


def _apply_temporal_filter(stacked: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """stacked: (B, *feature_shape). Returns same shape, smoothed over axis 0.

    Edges are padded by repeating the first/last frame. The kernel is applied
    as a correlation, which matches convolution for the symmetric kernels
    produced by `_smoothing_kernel`.
    """
    B = stacked.shape[0]
    w = len(kernel)
    pad = w // 2
    flat = stacked.reshape(B, -1)  # (B, K)
    padded = np.concatenate(
        [np.repeat(flat[:1], pad, axis=0), flat, np.repeat(flat[-1:], pad, axis=0)],
        axis=0,
    )  # (B + 2*pad, K)
    out = np.zeros_like(flat)
    for i, k in enumerate(kernel):
        out += k * padded[i: i + B]
    return out.reshape(stacked.shape)


# Render


def rainbow_tilt_inputs():
    """Shared rainbow-shader tilt inputs (used by Render and ToGLB schemas)."""
    return [
        io.Float.Input(
            "rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5,
            tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
        ),
        io.Float.Input(
            "rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5,
            tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
        ),
    ]


def _render_mesh_mode_inputs():
    """Inputs shown under render_style='mesh'."""
    return [
        io.DynamicCombo.Input(
            "shader",
            options=[
                io.DynamicCombo.Option("default", []),
                io.DynamicCombo.Option("normals", []),
                io.DynamicCombo.Option("rainbow", rainbow_tilt_inputs()),
                io.DynamicCombo.Option("rainbow_face_normal", rainbow_tilt_inputs()),
                io.DynamicCombo.Option("rainbow_face_semantic", rainbow_tilt_inputs()),
                io.DynamicCombo.Option("depth", []),
            ],
            tooltip=(
                "Preset shader. 'normals' = current surface normal in camera "
                "space (OpenGL Y+ normal-map convention: +X→R, +Y→G, +Z→B). "
                "'rainbow' = RealisDance style body-Y jet; the 'rainbow_face_*' "
                "variants override face verts with normal/per-region colors; "
                "'depth' = linear gray."
            ),
        ),
        io.Float.Input("opacity", default=1.0, min=0.0, max=1.0, step=0.01),
        io.Float.Input(
            "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
            tooltip=(
                "Per-person desaturation toward white: track k gets a "
                "(1 - falloff^k) pastel mix (SCAIL 'softer second person'). 1.0 = off."
            ),
        ),
        io.Combo.Input(
            "region", options=["full_body", "hands_only"], default="full_body",
            tooltip=(
                "'hands_only' filters faces via the precomputed `hand_vert_mask` "
                "(LBS weights against canonical hand KPs) — isolates the hand "
                "mesh for debugging. Falls back to full mesh if the mask is missing."
            ),
        ),
    ]


def _render_capsules_mode_inputs():
    """Inputs shown under render_style='scail'."""
    return [
        io.Float.Input(
            "radius_m", default=0.022, min=0.005, max=0.2, step=0.001,
            tooltip="Capsule radius in meters (SCAIL reference: ~0.022 m).",
        ),
        io.Combo.Input(
            "hand_style", options=["disabled", "dwpose", "openpose"], default="dwpose",
            tooltip=(
                "Composite 2D OpenPose hands on top of the 3D capsule body "
                "(matches SCAIL — no 3D hand capsules). 'disabled' = no hands. "
                "'dwpose' = solid-blue hand dots; 'openpose' = rainbow dots. "
                "Sticks stay rainbow per-finger either way."
            ),
        ),
        io.Combo.Input(
            "face_style", options=["disabled", "full", "eyes_mouth"], default="disabled",
            tooltip=(
                "'full' = all face landmarks (sapiens-238 if present, else "
                "rig-fallback ~30). 'eyes_mouth' = rig-fallback subset (~12 "
                "dots: eyes + outer lips only). 'disabled' = no face dots."
            ),
        ),
        io.Float.Input(
            "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
            tooltip=(
                "Per-person desaturation: track k blends toward white by "
                "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
            ),
        ),
    ]


def _render_openpose_mode_inputs():
    """Inputs shown under render_style='openpose'."""
    return [
        io.Int.Input(
            "marker_radius_px", default=4, min=1, max=32, step=1,
            tooltip="Body keypoint dot radius (px).",
        ),
        io.Int.Input(
            "stick_width_px", default=4, min=1, max=32, step=1,
            tooltip="Body limb ellipse half-width (px). DWPose default = 4.",
        ),
        io.Float.Input(
            "limb_alpha", default=0.6, min=0.0, max=1.0, step=0.05,
            tooltip="Per-limb alpha. DWPose default = 0.6.",
        ),
        io.Combo.Input(
            "face_style", options=["disabled", "full", "eyes_mouth"], default="disabled",
            tooltip=(
                "'full' = all face landmarks (sapiens-238 if present, else "
                "rig-fallback ~30). 'eyes_mouth' = rig-fallback subset (~12 "
                "dots: eyes + outer lips only). 'disabled' = no face dots."
            ),
        ),
        io.Combo.Input(
            "hand_style", options=["disabled", "dwpose", "openpose"], default="disabled",
            tooltip=(
                "Draw 21+21 hand keypoints + sticks. 'disabled' = no hands. "
                "'dwpose' = solid-blue dots; 'openpose' = rainbow dots."
            ),
        ),
        io.Float.Input(
            "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
            tooltip=(
                "Per-person desaturation: track k blends toward white by "
                "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
            ),
        ),
    ]


def _scale_pose_data(mhr_pose_data: Dict[str, Any], new_H: int, new_W: int) -> Dict[str, Any]:
    """Rescale per-person camera intrinsics + 2D coords to a new canvas size.

    Pose data records focal_length in pixels of the original image; without
    scaling, the FOV would change and subjects would be cropped/zoomed. When
    the new aspect differs from the original, the body (3D-projected through
    focal_length on a centered principal point) lands in a letterboxed region
    of the new canvas. 2D-prestored coords must follow the same uniform scale +
    center offset so face/hand overlays align with the body — per-axis
    stretching would split them apart.

    Returns the input unchanged when the size already matches; otherwise a new
    dict with copied frames (the input is not mutated).
    """
    old_H, old_W = mhr_pose_data["image_size"]
    if new_H == old_H and new_W == old_W:
        return mhr_pose_data
    rW = new_W / old_W
    rH = new_H / old_H
    r_focal = min(rW, rH)
    offset_x = (new_W - r_focal * old_W) * 0.5
    offset_y = (new_H - r_focal * old_H) * 0.5

    new_frames: List[List[Dict[str, Any]]] = []
    for frame in mhr_pose_data["frames"]:
        scaled = []
        for p in frame:
            p = dict(p)
            f = p.get("focal_length")
            if f is not None:
                p["focal_length"] = np.asarray(f, dtype=np.float32) * r_focal
            for k in ("pred_keypoints_2d", "pred_face_keypoints_2d"):
                v = p.get(k)
                if v is not None:
                    arr = np.asarray(v, dtype=np.float32).copy()
                    arr[..., 0] = arr[..., 0] * r_focal + offset_x
                    arr[..., 1] = arr[..., 1] * r_focal + offset_y
                    p[k] = arr
            bb = p.get("bbox")
            if bb is not None:
                bb = np.asarray(bb, dtype=np.float32).copy()
                bb[..., [0, 2]] = bb[..., [0, 2]] * r_focal + offset_x
                bb[..., [1, 3]] = bb[..., [1, 3]] * r_focal + offset_y
                p["bbox"] = bb
            scaled.append(p)
        new_frames.append(scaled)
    out = dict(mhr_pose_data)
    out["image_size"] = (new_H, new_W)
    out["frames"] = new_frames
    return out


class SAM3DBody_Render(io.ComfyNode):
    """Render MHR pose data per frame as a shaded mesh, a silhouette mask, a
    flat 2D OpenPose skeleton, or SCAIL-style 3D capsules."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Render",
            display_name="Render SAM3D Body",
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                io.Image.Input(
                    "background", optional=True,
                    tooltip="Per-frame background. Omitted = black canvas.",
                ),
                io.Int.Input(
                    "width", default=0, min=0, max=16384, step=8,
                    tooltip=(
                        "Output width in pixels. 0 = use pose data's native "
                        "image_size. If only one of width/height is set, the "
                        "other is derived preserving the original aspect."
                    ),
                ),
                io.Int.Input(
                    "height", default=0, min=0, max=16384, step=8,
                    tooltip=(
                        "Output height in pixels. 0 = use pose data's native "
                        "image_size. If only one of width/height is set, the "
                        "other is derived preserving the original aspect."
                    ),
                ),
                io.DynamicCombo.Input(
                    "render_style",
                    options=[
                        io.DynamicCombo.Option("mesh", _render_mesh_mode_inputs()),
                        io.DynamicCombo.Option("silhouette", []),
                        io.DynamicCombo.Option("openpose", _render_openpose_mode_inputs()),
                        io.DynamicCombo.Option("scail", _render_capsules_mode_inputs()),
                    ],
                    tooltip=(
                        "'mesh' = 3D MHR mesh rasterized through the camera. "
                        "'silhouette' = binary mask of the mesh (white-on-black, "
                        "background ignored). 'openpose' = flat 2D skeleton "
                        "from pred_keypoints_2d (DWPose look). 'scail' = SCAIL "
                        "3D capsules via torch SDF ray-march (proper occlusion / depth)."
                    ),
                ),
            ],
            outputs=[io.Image.Output("image")],
        )

    @classmethod
    def execute(cls, mhr_pose_data, background=None, width=0, height=0,
                render_style=None) -> io.NodeOutput:
        render_style = render_style or {"render_style": "mesh"}
        mode_key = render_style.get("render_style", "mesh")

        native_H, native_W = mhr_pose_data["image_size"]
        new_W, new_H = int(width), int(height)
        if new_W == 0 and new_H == 0:
            H, W = native_H, native_W
            px_scale = 1.0
        else:
            if new_W == 0:
                new_W = max(1, round(native_W * new_H / native_H))
            elif new_H == 0:
                new_H = max(1, round(native_H * new_W / native_W))
            mhr_pose_data = _scale_pose_data(mhr_pose_data, new_H, new_W)
            H, W = new_H, new_W
            # Marker/stick px constants are authored for native resolution —
            # scale them so the openpose overlay reads at the same relative size.
            px_scale = min(new_W / native_W, new_H / native_H)

        out_device = comfy.model_management.intermediate_device()

        B = len(mhr_pose_data["frames"])
        if B == 0:
            return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32, device=out_device))

        bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)

        if mode_key == "silhouette":
            composite = "silhouette"
        elif bg_t is not None:
            composite = "over"
        else:
            composite = "mesh_only"

        # Fallback values below mirror the schema defaults of each mode.
        if mode_key == "openpose":
            marker_radius_px = max(1, int(round(render_style.get("marker_radius_px", 4) * px_scale)))
            stick_width_px = max(1, int(round(render_style.get("stick_width_px", 4) * px_scale)))
            limb_alpha = float(render_style.get("limb_alpha", 0.6))
            face_style = str(render_style.get("face_style", "disabled"))
            hand_style = str(render_style.get("hand_style", "disabled"))
            include_hands = hand_style != "disabled"
            hand_color_style = hand_style if include_hands else "dwpose"
            person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
        elif mode_key == "scail":
            cap_radius_m = float(render_style.get("radius_m", 0.022))
            cap_hand_style = str(render_style.get("hand_style", "dwpose"))
            cap_include_hands = cap_hand_style != "disabled"
            cap_hand_color_style = cap_hand_style if cap_include_hands else "dwpose"
            cap_face_style = str(render_style.get("face_style", "disabled"))
            person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
        elif mode_key == "mesh":
            shader_dict = render_style.get("shader") or {}
            shader_key = shader_dict.get("shader", "default")
            rainbow_tilt_x = float(shader_dict.get("rainbow_tilt_x", 0.0))
            rainbow_tilt_z = float(shader_dict.get("rainbow_tilt_z", -35.0))
            opacity = float(render_style.get("opacity", 1.0))
            person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
            region = str(render_style.get("region", "full_body"))
            if region == "hands_only":
                hand_mask = mhr_pose_data.get("hand_vert_mask")
                if hand_mask is None:
                    # Predict stores None when the model has no precomputed mask;
                    # honor the documented fallback instead of crashing.
                    logging.warning(
                        "[SAM 3D Body Render] region='hands_only' requested but the "
                        "pose data has no hand_vert_mask; rendering the full mesh."
                    )
                else:
                    faces_full = np.asarray(mhr_pose_data["faces"])
                    # Keep only triangles whose three vertices are all hand vertices.
                    keep = hand_mask[faces_full].all(axis=1)
                    mhr_pose_data = dict(mhr_pose_data)
                    mhr_pose_data["faces"] = np.ascontiguousarray(
                        faces_full[keep], dtype=faces_full.dtype,
                    )
        else:
            # silhouette — no shader/opacity controls, mask is binary
            shader_key = "default"
            rainbow_tilt_x = 0.0
            rainbow_tilt_z = -35.0
            opacity = 1.0
            person_palette_falloff = 0.6

        frames_out = []
        pbar = comfy.utils.ProgressBar(B)
        desc = (
            "SAM3D openpose-2D render" if mode_key == "openpose"
            else "SAM3D SCAIL-3D render" if mode_key == "scail"
            else "SAM3D silhouette" if mode_key == "silhouette"
            else "SAM3D render"
        )
        for f in tqdm(range(B), desc=desc):
            bg_f = None
            if bg_t is not None:
                # Reuse the last background frame when the batch is shorter.
                bg_f = bg_t[min(f, bg_t.shape[0] - 1)]
            if mode_key == "openpose":
                img = render_pose_data_openpose(
                    mhr_pose_data, frame_idx=f, W=W, H=H,
                    background=bg_f, composite=composite,
                    marker_radius_px=marker_radius_px,
                    stick_width_px=stick_width_px,
                    limb_alpha=limb_alpha,
                    include_hands=include_hands,
                    face_style=face_style,
                    hand_color_style=hand_color_style,
                    person_brightness_falloff=person_palette_falloff,
                )
            elif mode_key == "scail":
                # SCAIL renders body as 3D capsules + 2D openpose hands on top
                img = render_pose_data_capsules(
                    mhr_pose_data, frame_idx=f, W=W, H=H,
                    background=bg_f, composite=composite,
                    radius_m=cap_radius_m,
                    include_hands=False,
                    palette="scail",
                    person_brightness_falloff=person_palette_falloff,
                )
                if cap_include_hands or cap_face_style != "disabled":
                    scail_overlay_px = max(1, int(round(4 * px_scale)))
                    scail_face_px = max(1, int(round(1 * px_scale)))
                    img = render_pose_data_openpose(
                        mhr_pose_data, frame_idx=f, W=W, H=H,
                        background=img, composite="over",
                        include_body=False,
                        include_hands=cap_include_hands,
                        face_style=cap_face_style,
                        marker_radius_px=scail_overlay_px,
                        stick_width_px=scail_overlay_px,
                        face_marker_radius_px=scail_face_px,
                        hand_color_style=cap_hand_color_style,
                        person_brightness_falloff=person_palette_falloff,
                    )
            else:
                img = render_pose_data(
                    mhr_pose_data, frame_idx=f, W=W, H=H,
                    background=bg_f, composite=composite,
                    opacity=opacity,
                    shader_preset=shader_key,
                    rainbow_tilt_x_deg=rainbow_tilt_x,
                    rainbow_tilt_z_deg=rainbow_tilt_z,
                    person_brightness_falloff=person_palette_falloff,
                )
            frames_out.append(img)
            pbar.update(1)

        out_image = torch.stack(frames_out, dim=0)
        if out_image.device != out_device:
            out_image = out_image.to(out_device)
        return io.NodeOutput(out_image)


class SAM3DBodyExtension(ComfyExtension):
    """Registers the SAM 3D Body nodes with ComfyUI."""

    @override
    async def get_node_list(self) -> List[type[io.ComfyNode]]:
        return [
            SAM3DBody_Loader,
            SAM3DBody_Predict,
            SAM3DBody_FaceExpression,
            SAM3DBody_Smooth,
            SAM3DBody_Render,
        ]


async def comfy_entrypoint() -> SAM3DBodyExtension:
    return SAM3DBodyExtension()