mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
1096 lines
47 KiB
Python
1096 lines
47 KiB
Python
"""SAM 3D Body nodes — Loader, Predict, FaceExpression, Smooth and Render — plus their inference helpers."""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
from scipy.signal import savgol_coeffs
|
|
|
|
import comfy.model_management
|
|
import comfy.utils
|
|
from comfy_api.latest import io, ComfyExtension
|
|
from typing_extensions import override
|
|
import folder_paths
|
|
|
|
from comfy.ldm.sam3d_body.model.model import SAM3DBody
|
|
from comfy.ldm.sam3d_body.model.dinov3 import apply_dinov3_qkv_bias_mask
|
|
from comfy_extras.sam3d_body.utils import (
|
|
apply_camera_override,
|
|
cam_int_from_fov,
|
|
cam_int_from_moge,
|
|
inputs_from_sam3_track,
|
|
run_batched_frames,
|
|
run_batched_single_chunk,
|
|
compute_canonical_colors,
|
|
compute_hand_vert_mask,
|
|
)
|
|
|
|
from comfy_extras.sam3d_body.rasterizer import render_pose_data_torch as render_pose_data
|
|
from comfy_extras.sam3d_body.export.capsules import render_pose_data_capsules
|
|
from comfy_extras.sam3d_body.export.openpose_2d import render_pose_data_openpose
|
|
from comfy_extras.sam3d_body import face_expression as fx
|
|
from comfy_extras.sam3d_body.utils import image_to_uint8
|
|
|
|
|
|
# Custom socket types shared by the SAM 3D Body nodes.
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")

# MHR_POSE_DATA is the native output of SAM3DBody_Predict. It carries
# mhr_model_params, shape_params and expr_params, uses the MHR70 keypoint
# layout, and ships canonical_colors keyed to the MHR mesh plus a
# hand_vert_mask derived from MHR LBS. The export nodes (BuildPoseGLB /
# SavePoseBVH in comfy_extras/nodes_save_3d.py) also accept KIMODO_POSE_DATA
# through a MultiType union; those types are mirrored in that module.
MHRPoseData = io.Custom("MHR_POSE_DATA")

SAM3DBodyModel = io.Custom("SAM3D_BODY_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
|
|
|
# Loader
|
|
|
|
class SAM3DBody_Loader(io.ComfyNode):
    """Load SAM 3D Body weights from the 'detection' folder and wrap them in a model patcher."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Loader",
            display_name="Load SAM3D Body Model",
            category="image/detection/sam3dbody/",  # TODO: better category?
            inputs=[
                io.Combo.Input(
                    "model_file",
                    options=folder_paths.get_filename_list("detection"),
                    tooltip="SAM 3D Body weights (.safetensors) in the 'detection' folder",
                ),
            ],
            outputs=[SAM3DBodyModel.Output("model", display_name="sam3d_body_model")],
        )

    @classmethod
    def execute(cls, model_file) -> io.NodeOutput:
        """Build SAM3DBody, load `model_file` into it and return a CoreModelPatcher.

        Missing or unexpected checkpoint keys do not abort loading
        (strict=False), but they are logged so a wrong or partial file is not
        silently accepted.
        """
        # comfy.ops / comfy.model_patcher are used below but are not imported at
        # file level; import them explicitly rather than relying on some other
        # module having loaded them first. (Must stay at the top of this
        # function: it binds `comfy` as a local name.)
        import comfy.model_patcher
        import comfy.ops

        path = folder_paths.get_full_path_or_raise("detection", model_file)
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        # Collapse ".layers.0.0." to ".layers.0." — presumably the checkpoint
        # nests one extra level relative to this module's parameter names.
        sd = {k.replace(".layers.0.0.", ".layers.0."): v for k, v in sd.items()}

        load_device = comfy.model_management.get_torch_device()
        weight_dtype = comfy.utils.weight_dtype(sd)
        torch_dtype = comfy.model_management.unet_dtype(
            device=load_device, model_params=-1, weight_dtype=weight_dtype,
        )
        manual_cast_dtype = comfy.model_management.unet_manual_cast(torch_dtype, load_device)
        operations = comfy.ops.pick_operations(
            torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True,
        )

        model = SAM3DBody(dtype=torch_dtype, operations=operations)
        load_result = model.load_state_dict(sd, strict=False)
        # torch returns (missing_keys, unexpected_keys); tolerate overrides that don't.
        missing = list(getattr(load_result, "missing_keys", None) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing:
            logging.warning(
                "[SAM 3D Body Loader] %d parameter(s) missing from %s are left as "
                "constructed (first few: %s)",
                len(missing), model_file, missing[:5],
            )
        if unexpected:
            logging.info(
                "[SAM 3D Body Loader] ignoring %d unexpected key(s) in %s (first few: %s)",
                len(unexpected), model_file, unexpected[:5],
            )

        apply_dinov3_qkv_bias_mask(model.backbone)

        model.eval()
        model.backbone_dtype = torch_dtype
        # Cached attributes read back by the Predict node via getattr().
        model._sam3d_image_size = model.image_size
        model._sam3d_canonical_colors = compute_canonical_colors(model)
        model._sam3d_hand_vert_mask = compute_hand_vert_mask(model)

        patcher = comfy.model_patcher.CoreModelPatcher(
            model,
            load_device=load_device,
            offload_device=comfy.model_management.unet_offload_device(),
            size=comfy.model_management.module_size(model),
        )
        return io.NodeOutput(patcher)
|
|
|
|
# Predict
|
|
|
|
def _per_frame_bboxes_from_detections(bboxes, B: int):
|
|
# BoundingBox payload (RT-DETR etc.): dict | list[dict] | list[list[dict]].
|
|
if isinstance(bboxes, dict):
|
|
norm = [[bboxes]]
|
|
elif not bboxes:
|
|
return None
|
|
elif isinstance(bboxes[0], dict):
|
|
norm = [bboxes] # flat list → same detections every frame
|
|
else:
|
|
norm = list(bboxes)
|
|
if len(norm) == 1:
|
|
norm = norm * B
|
|
norm = (norm + [[]] * B)[:B]
|
|
out = []
|
|
for frame in norm:
|
|
if frame:
|
|
boxes = torch.tensor(
|
|
[[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in frame],
|
|
dtype=torch.float32,
|
|
)
|
|
else:
|
|
boxes = torch.zeros((0, 4), dtype=torch.float32)
|
|
out.append(boxes)
|
|
return out
|
|
|
|
|
|
class SAM3DBody_Predict(io.ComfyNode):
    """Run SAM 3D Body on every frame of an image batch and emit MHR pose data."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Predict",
            display_name="Predict SAM3D Body",
            category="image/detection/sam3dbody/",
            inputs=[
                SAM3DBodyModel.Input("sam3d_body_model"),
                io.Image.Input("image"),
                SAM3TrackData.Input(
                    "sam3_track_data", optional=True,
                    tooltip=("Output of SAM3 Video Track, required for multi-person detection"),
                ),
                io.BoundingBox.Input(
                    "bboxes", optional=True, force_input=True,
                    tooltip=(
                        "Per-frame person boxes (e.g. RT-DETR Detect with class_name='person'). "
                        "Use for better detection as alternative to SAM3 tracks."
                    ),
                ),
                io.Boolean.Input(
                    "run_hand_refinement", default=True,
                    tooltip="Improves hand pose at the cost of extra inference time and memory use"),
                io.Float.Input(
                    "fov_degrees",
                    default=0.0, min=0.0, max=170.0, step=0.5,
                    tooltip=(  # TODO: get FoV from moge another way?
                        "Vertical FOV in degrees. Affects predicted depth (cam_t.z) and "
                        "absolute scale. 0 = use moge_geometry or fall back to ~53° (16:9). "
                        "Any non-zero value overrides moge_geometry."
                    ),
                ),
                MoGeGeometry.Input(
                    "moge_geometry",
                    optional=True,
                    tooltip=(
                        # The two sentences were previously glued together
                        # ("view.For batches") by a missing space.
                        "MoGe geometry, used to calculate camera field of view. "
                        "For batches choose the most representative frame, or leave unset."
                    ),
                ),
                io.Int.Input(
                    "chunk_size",  # TODO: automate?
                    default=64, min=1, max=512, step=1, advanced=True,
                    tooltip=(
                        "Max person-crops per forward. Higher = throughput + VRAM; "
                        "per-chunk frame count is chunk_size / persons_per_frame."
                    ),
                ),
            ],
            outputs=[MHRPoseData.Output("mhr_pose_data")],
        )

    @classmethod
    def execute(cls, sam3d_body_model, image, sam3_track_data=None, bboxes=None,
                run_hand_refinement=True, fov_degrees=0.0, moge_geometry=None,
                chunk_size=64) -> io.NodeOutput:
        """Predict MHR bodies for each frame of `image` (B, H, W, C).

        Returns a dict with per-frame person lists ("frames"), the mesh "faces",
        the source "image_size" (H, W), and the "canonical_colors" /
        "hand_vert_mask" cached on the model by the loader (None if absent).
        """
        comfy.model_management.load_model_gpu(sam3d_body_model)
        inner: SAM3DBody = sam3d_body_model.model

        B, H, W, _ = image.shape
        image_size = getattr(inner, "_sam3d_image_size", (512, 512))

        # Person prompts. Precedence: SAM3 track (masks + boxes) > detector
        # boxes > one full-frame box per frame.
        per_frame_bboxes, per_frame_masks = None, None
        if sam3_track_data is not None:
            per_frame_bboxes, per_frame_masks = inputs_from_sam3_track(sam3_track_data, B, H, W)
        if per_frame_bboxes is None and bboxes:
            per_frame_bboxes = _per_frame_bboxes_from_detections(bboxes, B)
            per_frame_masks = None
        if per_frame_bboxes is None:
            # No track or detector boxes — single-person full-frame fallback.
            full_frame_bbox = torch.tensor([[0.0, 0.0, float(W), float(H)]], dtype=torch.float32)
            per_frame_bboxes = [full_frame_bbox.clone() for _ in range(B)]
            per_frame_masks = None

        inference_type = "full" if run_hand_refinement else "body"

        # Camera intrinsics: an explicit fov_degrees wins; otherwise defer to
        # cam_int_from_moge (which handles moge_geometry=None — see the
        # fov_degrees tooltip for the fallback FOV).
        cam_int = cam_int_from_fov(int(H), int(W), float(fov_degrees))
        if cam_int is None:
            cam_int = cam_int_from_moge(moge_geometry, int(H), int(W))

        # Frames without any person box are never converted or sent to the model.
        frames_rgb: List[Optional[torch.Tensor]] = [
            None if per_frame_bboxes[f].shape[0] == 0 else image_to_uint8(image[f])
            for f in range(B)
        ]

        # The batched path requires the same non-zero box count K on every frame.
        bbox_counts = {per_frame_bboxes[f].shape[0] for f in range(B) if frames_rgb[f] is not None}
        can_batch = (
            len(bbox_counts) == 1
            and 0 not in bbox_counts
            and all(frames_rgb[f] is not None for f in range(B))
        )

        frames_out: List[List[Dict[str, Any]]] = []
        pbar = comfy.utils.ProgressBar(B)

        if can_batch and B > 0:
            frames_out = run_batched_frames(
                inner, frames_rgb, per_frame_bboxes, per_frame_masks,
                image_size, inference_type,
                cam_int=cam_int,
                pbar=pbar,
                crops_per_chunk=int(chunk_size),
            )
        else:
            # Mixed K per frame — run the batched path once per frame.
            for f in range(B):
                if frames_rgb[f] is None or per_frame_bboxes[f].shape[0] == 0:
                    frames_out.append([])
                    pbar.update(1)
                    continue
                mask_f = [per_frame_masks[f]] if per_frame_masks is not None else None
                chunk = run_batched_single_chunk(
                    inner, [frames_rgb[f]], [per_frame_bboxes[f]], mask_f,
                    image_size, inference_type,
                    K=int(per_frame_bboxes[f].shape[0]),
                    cam_int=cam_int,
                )
                frames_out.append(chunk[0])
                pbar.update(1)

        mhr_pose_data = {
            "frames": frames_out,
            "faces": inner.head_pose.faces.cpu().numpy(),
            "image_size": (int(H), int(W)),
            "canonical_colors": getattr(inner, "_sam3d_canonical_colors", None),
            "hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None),
        }
        return io.NodeOutput(mhr_pose_data)
|
|
|
|
|
|
class SAM3DBody_FaceExpression(io.ComfyNode):
    """Drive MHR face blendshapes from the core MediaPipe Face Landmarker.

    Detects per-frame faces, IoU-matches each to a tracked person, maps the 52
    ARKit blendshapes onto MHR's 72-axis `expr_params`, and re-runs MHR forward
    so pred_vertices/pred_keypoints reflect the new expression.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_FaceExpression",
            description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.",
            display_name="Face Expression to SAM3D Body",  # TODO: better name?
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                SAM3DBodyModel.Input("sam3d_body_model"),
                io.Image.Input("image"),
                io.Float.Input(
                    "strength", default=1.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Global multiplier on all blendshapes. >1 exaggerates.",
                ),
                io.Float.Input(
                    "mouth_strength", default=1.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on mouth/jaw shapes. MP's jawOpen saturates near 1.0.",
                    advanced=True,
                ),
                io.Float.Input(
                    "eye_strength", default=2.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on eye shapes. MP rarely exceeds 0.5; 2-3x often needed.",
                    advanced=True,
                ),
                io.Float.Input(
                    "brow_strength", default=2.0, min=0.0, max=4.0, step=0.05,
                    tooltip="Multiplier on brow/cheek/sneer shapes. MP outputs ~0.1-0.3; 2-3x.",
                    advanced=True,
                ),
                io.Float.Input(
                    "input_threshold", default=0.02, min=0.0, max=0.5, step=0.01,
                    tooltip=(
                        "Deadzone on MediaPipe's raw output (below = zero, above = linear remap). "
                    ),
                    advanced=True,
                ),
                io.Int.Input(
                    "blendshape_smooth_window", default=7, min=1, max=31, step=2,
                    tooltip=(
                        "Gaussian window on MediaPipe's per-frame signal before MHR mapping. "
                        "MediaPipe's raw output swings 30-70% frame-to-frame on static faces. "
                        "1 = disabled. Use odd values."
                    ),
                    advanced=True,
                ),
            ],
            outputs=[
                MHRPoseData.Output("mhr_pose_data"),
            ],
        )

    @classmethod
    def execute(cls, mhr_pose_data, sam3d_body_model, image,
                strength=1.0, mouth_strength=1.0, eye_strength=1.0, brow_strength=1.0,
                input_threshold=0.15, blendshape_smooth_window=7) -> io.NodeOutput:
        """Replace each person's expr_params with values derived from MediaPipe.

        Persons are identified by their index within each frame's person list.
        Returns a shallow copy of `mhr_pose_data` with new per-person dicts;
        the input frames are not modified.
        """
        # NOTE(review): these keyword defaults differ from the schema defaults
        # (eye/brow 2.0, input_threshold 0.02). The node UI always passes the
        # schema values, so they only matter for direct Python callers.
        frames = mhr_pose_data["frames"]
        B = len(frames)
        if B == 0:
            return io.NodeOutput(mhr_pose_data)

        # Only load the model once there is actually work to do.
        comfy.model_management.load_model_gpu(sam3d_body_model)
        inner: SAM3DBody = sam3d_body_model.model

        img_np = (image * 255.0).clamp(0.0, 255.0).to(torch.uint8).cpu().numpy()
        new_frames: List[List[Dict[str, Any]]] = [[dict(p) for p in f] for f in frames]

        max_persons = max((len(f) for f in new_frames), default=0)
        # per_person_coefs[pid][fi] = blendshape dict, or None when no face was matched.
        per_person_coefs: List[List[Optional[Dict[str, float]]]] = [
            [None] * B for _ in range(max_persons)
        ]
        pbar = comfy.utils.ProgressBar(B)

        crop_factor = 1.2

        # Auto-pick full-frame vs per-person crops. BlazeFace full-range needs
        # a face of >=32px in its 192px input; below that, escalate to
        # per-person crops. Face height is taken as ~20% of body-bbox height.
        H_img0, W_img0 = img_np.shape[1], img_np.shape[2]
        min_bbox_px = 32.0 * max(H_img0, W_img0) / (192.0 * 0.20)
        use_per_person_crops = any(
            (p["bbox"][3] - p["bbox"][1]) < min_bbox_px
            for persons in new_frames for p in persons
        )

        for fi in tqdm(range(B), desc="SAM3D face expression detect"):
            persons = new_frames[fi]
            # Reuse the last image when the batch is shorter than the pose frames.
            img_fi = img_np[min(fi, img_np.shape[0] - 1)]

            if not persons:
                pbar.update(1)
                continue

            person_bboxes = [np.asarray(p["bbox"], dtype=np.float32) for p in persons]
            H_img, W_img = img_fi.shape[:2]

            if use_per_person_crops:
                # One MP call per person on a tight head crop — recovers small/
                # distant faces that the full-frame 192px BlazeFace would miss.
                # (pid < len(persons) <= max_persons always holds here.)
                for pid, pb in enumerate(person_bboxes):
                    cr = fx.head_crop_from_keypoints(
                        persons[pid].get("pred_keypoints_2d"), crop_factor, W_img, H_img,
                    )
                    if cr is None:
                        cr = fx.head_region_crop(pb, crop_factor, W_img, H_img)
                    faces = fx.detect_faces_in_crop(inner, img_fi, cr, num_faces=1)
                    if not faces:
                        continue
                    # Pick the face closest to the person bbox center when a neighbor leaks in.
                    pcx, pcy = 0.5 * (pb[0] + pb[2]), 0.5 * (pb[1] + pb[3])
                    best = min(
                        faces,
                        key=lambda f: (0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]) - pcx) ** 2
                        + (0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3]) - pcy) ** 2,
                    )
                    per_person_coefs[pid][fi] = best["blendshapes"]
            else:
                faces = inner.face_landmarker.detect_batch([img_fi], num_faces=max(1, len(persons)))[0]
                if faces:
                    face_bboxes = [f["bbox_xyxy"] for f in faces]
                    assignment = fx.assign_faces_to_persons(face_bboxes, person_bboxes)
                    for pid, face_idx in enumerate(assignment):
                        if face_idx is None or pid >= max_persons:
                            continue
                        per_person_coefs[pid][fi] = faces[face_idx]["blendshapes"]

            pbar.update(1)

        # Baseline subtraction. MP has a subject-specific rest bias; without
        # subtraction the strength multipliers bake it into every frame.
        # A percentile baseline needs ~30 observed frames or it zeroes out the
        # expression itself. The threshold is checked per person on frames
        # where that person's face was actually detected, so a rarely-seen
        # person is not flattened just because the clip is long.
        BASELINE_MIN_FRAMES = 30
        for pid in range(max_persons):
            n_detected = sum(c is not None for c in per_person_coefs[pid])
            if n_detected >= BASELINE_MIN_FRAMES:
                per_person_coefs[pid] = fx.subtract_per_clip_baseline(
                    per_person_coefs[pid], percentile=5.0,
                )
            elif n_detected > 0:
                logging.warning(
                    f"[SAM 3D Body FaceExpression] per-clip baseline subtraction "
                    f"needs ~{BASELINE_MIN_FRAMES}+ frames with detections; "
                    f"person {pid} got {n_detected}. Skipping subtraction."
                )

        # Smooth the raw signal AFTER baseline subtraction but BEFORE gap fill,
        # so MP's per-frame noise is averaged out at the source.
        bs_win = int(blendshape_smooth_window)
        if bs_win > 1:
            for pid in range(max_persons):
                per_person_coefs[pid] = fx.smooth_blendshape_series(
                    per_person_coefs[pid], window=bs_win,
                )

        for pid in range(max_persons):
            per_person_coefs[pid] = fx.fill_detection_gaps(
                per_person_coefs[pid], method="interpolate", max_gap=12,
            )

        n_written = 0
        for fi in range(B):
            for pid, p in enumerate(new_frames[fi]):
                if pid >= max_persons:
                    continue
                coefs = per_person_coefs[pid][fi]
                if coefs is None:
                    continue
                p["expr_params"] = fx.arkit_to_expr_params(
                    coefs,
                    strength=float(strength),
                    mouth_strength=float(mouth_strength),
                    eye_strength=float(eye_strength),
                    brow_strength=float(brow_strength),
                    input_threshold=float(input_threshold),
                ).astype(np.float32)
                n_written += 1

        # Re-run MHR forward so vertices/keypoints reflect the new expressions.
        if n_written > 0:
            fx.regenerate_mesh_from_params(inner, new_frames)

        new_pose = dict(mhr_pose_data)
        new_pose["frames"] = new_frames
        return io.NodeOutput(new_pose)
|
|
|
|
|
|
def _smooth_euler_zyx_to_matrix(eul: np.ndarray) -> np.ndarray:
    """Return R = Rz @ Ry @ Rx (float32) for euler angles ordered (z, y, x), in radians."""
    cz, sz = np.cos(eul[0]), np.sin(eul[0])
    cy, sy = np.cos(eul[1]), np.sin(eul[1])
    cx, sx = np.cos(eul[2]), np.sin(eul[2])
    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]], dtype=np.float32)
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float32)
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]], dtype=np.float32)
    return Rz @ Ry @ Rx


def _smooth_geom_blend(frames, pid, valid, base_blend, rot_thresh, window):
    """Per-frame blend weight for geometry keys, fading to 0 as the root spins faster.

    The frame-to-frame rotation angle of person `pid` is measured from its
    `global_rot` (euler ZYX). Each peak is spread over +/- window//2 frames,
    because neighbors of a rotated frame must back off too, or the temporal
    average pulls the rotated pose in and flattens the mesh anyway. The result
    is base_blend at no rotation and 0 at or above `rot_thresh` radians/frame.
    Returns a (B,) array.
    """
    B = len(frames)
    rotmats = []
    for fi in range(B):
        gr = frames[fi][pid].get("global_rot") if valid[fi] else None
        if gr is None:
            rotmats.append(None)
        else:
            rotmats.append(_smooth_euler_zyx_to_matrix(np.asarray(gr, dtype=np.float32).reshape(3)))

    ang = np.zeros(B, dtype=np.float32)
    for fi in range(1, B):
        prev, cur = rotmats[fi - 1], rotmats[fi]
        if prev is not None and cur is not None:
            R_delta = cur @ prev.T
            cos_a = float(np.clip((np.trace(R_delta) - 1.0) / 2.0, -1.0, 1.0))
            ang[fi] = abs(np.arccos(cos_a))

    half = max(1, window // 2)
    ang_peak = np.array(
        [ang[max(0, fi - half):min(B, fi + half + 1)].max() for fi in range(B)],
        dtype=np.float32,
    )
    ratio = np.clip(ang_peak / rot_thresh, 0.0, 1.0)
    return base_blend * (1.0 - ratio)


def _smooth_stack_key(frames, pid, valid, key):
    """Stack `key` of person `pid` over time as a float32 (B, *shape) array.

    Returns (stacked, original_dtype), or (None, None) when the key is absent
    on every valid frame or is non-numeric. Frames where the value is missing
    repeat the nearest earlier observation; frames before the first
    observation reuse the first observation. (Zero-filling them would drag the
    temporal filter toward the origin at the start of a track.)
    """
    B = len(frames)
    values = [frames[fi][pid].get(key) if valid[fi] else None for fi in range(B)]
    first = next((v for v in values if v is not None), None)
    if first is None:
        return None, None
    ref = np.asarray(first)
    if ref.dtype.kind not in "fiu":
        return None, None

    stacked = np.empty((B,) + ref.shape, dtype=np.float32)
    last = np.asarray(first, dtype=np.float32)
    for fi, v in enumerate(values):
        if v is not None:
            last = np.asarray(v, dtype=np.float32)
        stacked[fi] = last
    return stacked, ref.dtype


class SAM3DBody_Smooth(io.ComfyNode):
    """Reduce frame-to-frame jitter via vertex-space temporal averaging.

    Backs off on mesh-geometry keys when the subject rotates fast (averaging
    across a spin flattens the mesh); camera-space keys still get full
    smoothing.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SAM3DBody_Smooth",
            description="Reduce frame-to-frame jitter via vertex-space temporal averaging",
            display_name="Smooth SAM3D Body Pose Frames",
            category="image/detection/sam3dbody/",
            inputs=[
                MHRPoseData.Input("mhr_pose_data"),
                io.Float.Input(
                    "strength",
                    default=1.0, min=0.0, max=1.0, step=0.05,
                    tooltip="Blend raw (0) → smoothed (1).",
                ),
                io.Combo.Input(
                    "method",
                    options=["gaussian", "savgol"],
                    default="gaussian", advanced=True,
                    tooltip=(
                        "'gaussian': symmetric weighted average, best general-purpose smoother. "
                        "'savgol': sliding polynomial fit, preserves sharp peaks."
                    ),
                ),
                io.Int.Input(
                    "window",
                    default=7, min=1, max=51, step=2, advanced=True,
                    tooltip="Temporal window in frames (odd values).",
                ),
                io.Float.Input(
                    "rotation_threshold_deg",
                    default=15.0, min=0.0, max=45.0, step=1.0, advanced=True,
                    tooltip=(
                        "Disables smoothing for this root-rotation rate (deg/frame) to preserve fast spins. "
                        "15° suits most content, low values trigger on ordinary jitter and "
                        "silently sabotage smoothing. 0 = disable."
                    ),
                ),
            ],
            outputs=[MHRPoseData.Output("mhr_pose_data")],
        )

    @classmethod
    def execute(cls, mhr_pose_data, method, window, strength, rotation_threshold_deg) -> io.NodeOutput:
        """Temporally smooth every numeric per-person key listed below.

        Persons are tracked by their index within each frame's list. Returns a
        shallow copy of `mhr_pose_data`; touched person dicts are copied, so
        the input is not mutated.
        """
        if strength <= 0.0 or window <= 1:
            return io.NodeOutput(mhr_pose_data)

        frames = mhr_pose_data["frames"]
        B = len(frames)
        if B < 2:
            return io.NodeOutput(mhr_pose_data)
        max_p = max((len(f) for f in frames), default=0)
        if max_p == 0:
            return io.NodeOutput(mhr_pose_data)

        # Geometry keys rotate with the subject, so linear averaging during
        # fast spins flattens the mesh — these get per-frame adaptive strength.
        keys_geom = {
            "pred_vertices", "pred_keypoints_3d", "pred_joint_coords",
            "pred_global_rots", "mhr_model_params", "body_pose_params",
            "global_rot", "pred_pose_raw",
        }
        # Camera / appearance / 2D keys are safe to smooth linearly.
        keys_cam = {
            "pred_cam_t", "pred_keypoints_2d", "focal_length",
            "shape_params", "scale_params", "hand_pose_params", "expr_params",
        }
        all_keys = sorted(keys_geom | keys_cam)

        kernel = _smoothing_kernel(method, window)
        smoothed = [list(f) for f in frames]

        base_blend = float(strength)
        rot_thresh = float(np.deg2rad(max(0.0, rotation_threshold_deg)))

        for pid in range(max_p):
            valid = np.array([pid < len(f) for f in frames], dtype=bool)
            if valid.sum() < 2:
                continue

            if rot_thresh > 0.0:
                geom_blend = _smooth_geom_blend(frames, pid, valid, base_blend, rot_thresh, window)
            else:
                geom_blend = np.full(B, base_blend, dtype=np.float32)

            # Copy each of this person's dicts once, before writing any key.
            for fi in np.flatnonzero(valid):
                smoothed[fi][pid] = dict(smoothed[fi][pid])

            for key in all_keys:
                stacked, ref_dtype = _smooth_stack_key(frames, pid, valid, key)
                if stacked is None:
                    continue

                # `global_rot` holds euler angles: unwrap +/-pi jumps so the
                # linear filter never averages e.g. +3.13 and -3.13 to ~0 (a
                # 180-degree error the rotation-rate check cannot see). Values
                # are wrapped back into [-pi, pi) afterwards.
                rewrap = False
                if key == "global_rot":
                    unwrapped = np.unwrap(stacked, axis=0).astype(np.float32)
                    if not np.array_equal(unwrapped, stacked):
                        stacked, rewrap = unwrapped, True

                filtered = _apply_temporal_filter(stacked, kernel)

                if key in keys_geom:
                    b = geom_blend
                    while b.ndim < stacked.ndim:
                        b = b[..., None]
                    out = (1.0 - b) * stacked + b * filtered
                else:
                    out = (1.0 - base_blend) * stacked + base_blend * filtered

                if rewrap:
                    out = (out + np.pi) % (2.0 * np.pi) - np.pi

                for fi in np.flatnonzero(valid):
                    smoothed[fi][pid][key] = out[fi].astype(ref_dtype)

        new_pose = dict(mhr_pose_data)
        new_pose["frames"] = smoothed
        return io.NodeOutput(new_pose)
|
|
|
|
|
|
def _smoothing_kernel(method: str, window: int) -> np.ndarray:
|
|
window = max(1, int(window))
|
|
if window % 2 == 0:
|
|
window += 1
|
|
if method == "savgol":
|
|
order = 3 if window >= 5 else min(window - 1, 1)
|
|
return savgol_coeffs(window, order).astype(np.float32)
|
|
# gaussian (default)
|
|
sigma = max(1.0, window / 5.0)
|
|
x = np.arange(window) - (window - 1) / 2.0
|
|
k = np.exp(-(x ** 2) / (2 * sigma ** 2))
|
|
return (k / k.sum()).astype(np.float32)
|
|
|
|
|
|
def _apply_temporal_filter(stacked: np.ndarray, kernel: np.ndarray) -> np.ndarray:
|
|
"""stacked: (B, *feature_shape). Returns same shape, smoothed over axis 0."""
|
|
B = stacked.shape[0]
|
|
w = len(kernel)
|
|
pad = w // 2
|
|
flat = stacked.reshape(B, -1) # (B, K)
|
|
padded = np.concatenate(
|
|
[np.repeat(flat[:1], pad, axis=0), flat, np.repeat(flat[-1:], pad, axis=0)],
|
|
axis=0,
|
|
) # (B + 2*pad, K)
|
|
out = np.zeros_like(flat)
|
|
for i, k in enumerate(kernel):
|
|
out += k * padded[i : i + B]
|
|
return out.reshape(stacked.shape)
|
|
|
|
|
|
# Render
|
|
|
|
def rainbow_tilt_inputs():
    """Build the two rainbow-shader tilt inputs shared by the Render and ToGLB schemas.

    Every call returns a new list of new input objects, so each shader option
    that embeds them gets its own copies.
    """
    tilt_around_z = io.Float.Input(
        "rainbow_tilt_z", default=-35.0, min=-90.0, max=90.0, step=0.5,
        tooltip="Rotate rainbow jet axis around Z (forward). Differentiates left/right.",
    )
    tilt_around_x = io.Float.Input(
        "rainbow_tilt_x", default=0.0, min=-90.0, max=90.0, step=0.5,
        tooltip="Rotate rainbow jet axis around X (right). Differentiates front/back.",
    )
    return [tilt_around_z, tilt_around_x]
|
|
|
|
|
|
def _render_mesh_mode_inputs():
    """Inputs revealed by the "mesh" render_style: shader, opacity, palette falloff, region."""
    shader_choices = [
        io.DynamicCombo.Option("default", []),
        io.DynamicCombo.Option("normals", []),
        io.DynamicCombo.Option("rainbow", rainbow_tilt_inputs()),
        io.DynamicCombo.Option("rainbow_face_normal", rainbow_tilt_inputs()),
        io.DynamicCombo.Option("rainbow_face_semantic", rainbow_tilt_inputs()),
        io.DynamicCombo.Option("depth", []),
    ]
    shader_help = (
        "Preset shader. 'normals' = current surface normal in camera "
        "space (OpenGL Y+ normal-map convention: +X→R, +Y→G, +Z→B). "
        "'rainbow' = RealisDance style body-Y jet; the 'rainbow_face_*' "
        "variants override face verts with normal/per-region colors; "
        "'depth' = linear gray."
    )
    shader = io.DynamicCombo.Input("shader", options=shader_choices, tooltip=shader_help)

    opacity = io.Float.Input("opacity", default=1.0, min=0.0, max=1.0, step=0.01)

    palette_falloff = io.Float.Input(
        "person_palette_falloff",
        default=0.6, min=0.1, max=1.0, step=0.05,
        tooltip=(
            "Per-person desaturation toward white: track k gets a "
            "(1 - falloff^k) pastel mix (SCAIL 'softer second person'). 1.0 = off."
        ),
    )

    region = io.Combo.Input(
        "region",
        options=["full_body", "hands_only"],
        default="full_body",
        tooltip=(
            "'hands_only' filters faces via the precomputed `hand_vert_mask` "
            "(LBS weights against canonical hand KPs) — isolates the hand "
            "mesh for debugging. Falls back to full mesh if the mask is missing."
        ),
    )
    return [shader, opacity, palette_falloff, region]
|
|
|
|
|
|
def _render_capsules_mode_inputs():
    """Inputs revealed by the "scail" render_style (3D capsule body + 2D hand/face overlays)."""
    radius = io.Float.Input(
        "radius_m", default=0.022, min=0.005, max=0.2, step=0.001,
        tooltip="Capsule radius in meters (SCAIL reference: ~0.022 m).",
    )
    hands = io.Combo.Input(
        "hand_style",
        options=["disabled", "dwpose", "openpose"],
        default="dwpose",
        tooltip=(
            "Composite 2D OpenPose hands on top of the 3D capsule body "
            "(matches SCAIL — no 3D hand capsules). 'disabled' = no hands. "
            "'dwpose' = solid-blue hand dots; 'openpose' = rainbow dots. "
            "Sticks stay rainbow per-finger either way."
        ),
    )
    face = io.Combo.Input(
        "face_style",
        options=["disabled", "full", "eyes_mouth"],
        default="disabled",
        tooltip=(
            "'full' = all face landmarks (sapiens-238 if present, else "
            "rig-fallback ~30). 'eyes_mouth' = rig-fallback subset (~12 "
            "dots: eyes + outer lips only). 'disabled' = no face dots."
        ),
    )
    palette_falloff = io.Float.Input(
        "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
        tooltip=(
            "Per-person desaturation: track k blends toward white by "
            "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
        ),
    )
    return [radius, hands, face, palette_falloff]
|
|
|
|
|
|
def _render_openpose3d_mode_inputs():
    """Inputs revealed by the "openpose_3d" render_style (OpenPose skeleton as 3D capsules)."""
    limb_radius = io.Float.Input(
        "radius_m", default=0.015, min=0.004, max=0.1, step=0.001,
        tooltip="Limb capsule radius in meters (thin = stick-like).",
    )
    with_hands = io.Boolean.Input(
        "include_hands", default=True,
        tooltip="Draw 21+21 hand keypoints as 3D capsules.",
    )
    palette_falloff = io.Float.Input(
        "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
        tooltip=(
            "Per-person desaturation: track k blends toward white by "
            "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
        ),
    )
    return [limb_radius, with_hands, palette_falloff]
|
|
|
|
|
|
def _render_openpose_mode_inputs():
    """Inputs revealed by the "openpose_2d" render_style (flat DWPose-style skeleton)."""
    marker_radius = io.Int.Input(
        "marker_radius_px", default=4, min=1, max=32, step=1,
        tooltip="Body keypoint dot radius (px).",
    )
    stick_width = io.Int.Input(
        "stick_width_px", default=4, min=1, max=32, step=1,
        tooltip="Body limb ellipse half-width (px). DWPose default = 4.",
    )
    alpha = io.Float.Input(
        "limb_alpha", default=0.6, min=0.0, max=1.0, step=0.05,
        tooltip="Per-limb alpha. DWPose default = 0.6.",
    )
    face = io.Combo.Input(
        "face_style",
        options=["disabled", "full", "eyes_mouth"],
        default="disabled",
        tooltip=(
            "'full' = all face landmarks (sapiens-238 if present, else "
            "rig-fallback ~30). 'eyes_mouth' = rig-fallback subset (~12 "
            "dots: eyes + outer lips only). 'disabled' = no face dots."
        ),
    )
    hands = io.Combo.Input(
        "hand_style",
        options=["disabled", "dwpose", "openpose"],
        default="disabled",
        tooltip=(
            "Draw 21+21 hand keypoints + sticks. 'disabled' = no hands. "
            "'dwpose' = solid-blue dots; 'openpose' = rainbow dots."
        ),
    )
    palette_falloff = io.Float.Input(
        "person_palette_falloff", default=0.6, min=0.1, max=1.0, step=0.05,
        tooltip=(
            "Per-person desaturation: track k blends toward white by "
            "1 - falloff^k. Track 0 stays vivid; 1.0 disables falloff."
        ),
    )
    return [marker_radius, stick_width, alpha, face, hands, palette_falloff]
|
|
|
|
|
|
def _scale_pose_data(mhr_pose_data: Dict[str, Any], new_H: int, new_W: int) -> Dict[str, Any]:
|
|
# 2D coords must match the body's letterbox transform (uniform scale +
|
|
# center offset), else face/hand overlays drift off the body.
|
|
old_H, old_W = mhr_pose_data["image_size"]
|
|
if new_H == old_H and new_W == old_W:
|
|
return mhr_pose_data
|
|
rW = new_W / old_W
|
|
rH = new_H / old_H
|
|
r_focal = min(rW, rH)
|
|
offset_x = (new_W - r_focal * old_W) * 0.5
|
|
offset_y = (new_H - r_focal * old_H) * 0.5
|
|
new_frames: List[List[Dict[str, Any]]] = []
|
|
for frame in mhr_pose_data["frames"]:
|
|
scaled = []
|
|
for p in frame:
|
|
p = dict(p)
|
|
f = p.get("focal_length")
|
|
if f is not None:
|
|
p["focal_length"] = np.asarray(f, dtype=np.float32) * r_focal
|
|
for k in ("pred_keypoints_2d", "pred_face_keypoints_2d"):
|
|
v = p.get(k)
|
|
if v is not None:
|
|
arr = np.asarray(v, dtype=np.float32).copy()
|
|
arr[..., 0] = arr[..., 0] * r_focal + offset_x
|
|
arr[..., 1] = arr[..., 1] * r_focal + offset_y
|
|
p[k] = arr
|
|
bb = p.get("bbox")
|
|
if bb is not None:
|
|
bb = np.asarray(bb, dtype=np.float32).copy()
|
|
bb[..., [0, 2]] = bb[..., [0, 2]] * r_focal + offset_x
|
|
bb[..., [1, 3]] = bb[..., [1, 3]] * r_focal + offset_y
|
|
p["bbox"] = bb
|
|
scaled.append(p)
|
|
new_frames.append(scaled)
|
|
out = dict(mhr_pose_data)
|
|
out["image_size"] = (new_H, new_W)
|
|
out["frames"] = new_frames
|
|
return out
|
|
|
|
|
|
class SAM3DBody_Render(io.ComfyNode):
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SAM3DBody_Render",
|
|
display_name="Render SAM3D Body",
|
|
category="image/detection/sam3dbody/",
|
|
inputs=[
|
|
MHRPoseData.Input("mhr_pose_data"),
|
|
io.Image.Input(
|
|
"background",
|
|
optional=True,
|
|
tooltip="Per-frame background. Omitted = black canvas.",
|
|
),
|
|
io.Int.Input(
|
|
"width", default=0, min=0, max=16384, step=8,
|
|
tooltip=(
|
|
"Output width in pixels. 0 = use pose data's native "
|
|
"image_size. If only one of width/height is set, the "
|
|
"other is derived preserving the original aspect."
|
|
),
|
|
),
|
|
io.Int.Input(
|
|
"height", default=0, min=0, max=16384, step=8,
|
|
tooltip=(
|
|
"Output height in pixels. 0 = use pose data's native "
|
|
"image_size. If only one of width/height is set, the "
|
|
"other is derived preserving the original aspect."
|
|
),
|
|
),
|
|
io.Load3DCamera.Input(
|
|
"camera_info", optional=True,
|
|
tooltip=(
|
|
"Free 6DOF camera override. When wired, the pose is re-projected through this camera "
|
|
"(position/target/zoom) instead of the predicted one. "
|
|
),
|
|
),
|
|
io.Float.Input(
|
|
"camera_fov", default=0.0, min=0.0, max=170.0, step=0.5, advanced=True,
|
|
tooltip=(
|
|
"Vertical FOV for the camera_info override. 0 = keep the SAM3D "
|
|
"predicted camera's FOV (only the viewpoint changes). Any non-zero "
|
|
"value overrides the lens. Ignored when camera_info is unwired."
|
|
),
|
|
),
|
|
io.DynamicCombo.Input(
|
|
"render_style",
|
|
options=[
|
|
io.DynamicCombo.Option("mesh", _render_mesh_mode_inputs()),
|
|
io.DynamicCombo.Option("silhouette", []),
|
|
io.DynamicCombo.Option("openpose_2d", _render_openpose_mode_inputs()),
|
|
io.DynamicCombo.Option("openpose_3d", _render_openpose3d_mode_inputs()),
|
|
io.DynamicCombo.Option("scail", _render_capsules_mode_inputs()),
|
|
],
|
|
tooltip=(
|
|
"'mesh' = 3D MHR mesh rasterized through the camera. "
|
|
"'silhouette' = binary mask of the mesh. "
|
|
"'openpose_2d' = flat 2D skeleton "
|
|
"'openpose_3d' = openpose skeleton as flat-shaded 3D model "
|
|
"'scail' = SCAIL 3D capsules "
|
|
),
|
|
),
|
|
],
|
|
outputs=[io.Image.Output("image")],
|
|
)
|
|
|
|
|
|
@classmethod
def execute(cls, mhr_pose_data, background=None, width=0, height=0, camera_info=None, camera_fov=0.0, render_style=None) -> io.NodeOutput:
    """Render every frame of ``mhr_pose_data`` into an IMAGE batch.

    ``render_style`` (a DynamicCombo dict) selects the renderer:

    * ``"mesh"``        - ``render_pose_data`` with shader / opacity /
      palette-falloff / region sub-options.
    * ``"silhouette"``  - ``render_pose_data`` with ``composite="silhouette"``.
    * ``"openpose_2d"`` - ``render_pose_data_openpose`` flat skeleton.
    * ``"openpose_3d"`` - ``render_pose_data_capsules`` with the openpose
      palette and flat shading.
    * ``"scail"``       - ``render_pose_data_capsules`` with the SCAIL palette,
      optionally overlaid with 2D openpose hands and/or face.
    Any other key falls through to the mesh rasterizer with default settings.

    Args:
        mhr_pose_data: MHR_POSE_DATA dict. Reads ``"image_size"`` as
            ``(H, W)`` and ``"frames"``; for ``region="hands_only"`` also
            ``"hand_vert_mask"`` and ``"faces"``.
        background: Optional IMAGE batch. It is moved to the intermediate
            device as float32 and bilinearly resized when its spatial size
            differs from the render size. Frame ``f`` uses background frame
            ``min(f, len(background) - 1)``, so a single image is reused for
            all frames.
        width, height: Output size in pixels. Both 0 keeps the native
            ``image_size``. If only one is non-zero, the other is derived from
            the native aspect ratio and the pose data is rescaled with
            ``_scale_pose_data``.
        camera_info: Optional camera override applied via
            ``apply_camera_override``.
        camera_fov: Forwarded as ``fov_deg`` to ``apply_camera_override``;
            only used when ``camera_info`` is wired.
        render_style: DynamicCombo dict. ``None`` behaves like
            ``{"render_style": "mesh"}``.

    Returns:
        ``io.NodeOutput`` holding the per-frame renders stacked along a new
        leading batch dimension on the intermediate device. When there are no
        frames, a single black ``(1, H, W, 3)`` float32 image is returned on
        that same device.
    """
    render_style = render_style or {"render_style": "mesh"}
    mode_key = render_style.get("render_style", "mesh")
    # Resolved up front so the empty-input early return lands on the same
    # device as the normal output path.
    out_device = comfy.model_management.intermediate_device()

    # --- Output resolution --------------------------------------------------
    native_H, native_W = mhr_pose_data["image_size"]
    new_W, new_H = int(width), int(height)
    if new_W == 0 and new_H == 0:
        H, W = native_H, native_W
        px_scale = 1.0
    else:
        # Fill in the unset dimension from the native aspect ratio.
        if new_W == 0:
            new_W = max(1, round(native_W * new_H / native_H))
        elif new_H == 0:
            new_H = max(1, round(native_H * new_W / native_W))
        mhr_pose_data = _scale_pose_data(mhr_pose_data, new_H, new_W)
        H, W = new_H, new_W
        # Single uniform factor for pixel-sized strokes/markers: the smaller
        # of the two resize ratios.
        px_scale = min(new_W / native_W, new_H / native_H)

    if camera_info is not None:
        mhr_pose_data = apply_camera_override(mhr_pose_data, camera_info, H, W, fov_deg=float(camera_fov))

    B = len(mhr_pose_data["frames"])
    if B == 0:
        return io.NodeOutput(torch.zeros(1, H, W, 3, dtype=torch.float32, device=out_device))

    # --- Background ---------------------------------------------------------
    bg_t = None if background is None else background.to(device=out_device, dtype=torch.float32)
    if bg_t is not None and tuple(bg_t.shape[1:3]) != (H, W):
        # Match the background to the render resolution (common_upscale works
        # channels-first, hence the movedim round-trip).
        bg_t = comfy.utils.common_upscale(bg_t.movedim(-1, 1), W, H, "bilinear", "disabled").movedim(1, -1)

    if mode_key == "silhouette":
        composite = "silhouette"
    elif bg_t is not None:
        composite = "over"
    else:
        composite = "mesh_only"

    # --- Per-mode parameters (all loop-invariant) ---------------------------
    if mode_key == "openpose_2d":
        marker_radius_px = max(1, int(round(render_style.get("marker_radius_px", 4) * px_scale)))
        stick_width_px = max(1, int(round(render_style.get("stick_width_px", 4) * px_scale)))
        limb_alpha = float(render_style.get("limb_alpha", 0.6))
        face_style = str(render_style.get("face_style", "disabled"))
        hand_style = str(render_style.get("hand_style", "disabled"))
        include_hands = hand_style != "disabled"
        hand_color_style = hand_style if include_hands else "dwpose"
        person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
    elif mode_key == "openpose_3d":
        op3d_radius_m = float(render_style.get("radius_m", 0.015))
        op3d_include_hands = bool(render_style.get("include_hands", True))
        person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
    elif mode_key == "scail":
        cap_radius_m = float(render_style.get("radius_m", 0.030))
        cap_hand_style = str(render_style.get("hand_style", "disabled"))
        cap_include_hands = cap_hand_style != "disabled"
        cap_hand_color_style = cap_hand_style if cap_include_hands else "dwpose"
        cap_face_style = str(render_style.get("face_style", "disabled"))
        person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
        # The 2D overlay is only drawn when hands or face are requested; its
        # stroke sizes depend only on px_scale, so compute them once here
        # rather than once per frame.
        scail_overlay = cap_include_hands or cap_face_style != "disabled"
        scail_overlay_px = max(1, int(round(4 * px_scale)))
        scail_face_px = max(1, int(round(1 * px_scale)))
    elif mode_key == "mesh":
        shader_dict = render_style.get("shader") or {}
        shader_key = shader_dict.get("shader", "default")
        rainbow_tilt_x = float(shader_dict.get("rainbow_tilt_x", 0.0))
        rainbow_tilt_z = float(shader_dict.get("rainbow_tilt_z", -35.0))
        opacity = float(render_style.get("opacity", 1.0))
        person_palette_falloff = float(render_style.get("person_palette_falloff", 0.6))
        region = str(render_style.get("region", "full_body"))

        if region == "hands_only":
            # Keep only triangles whose three vertices are all hand vertices.
            hand_mask = mhr_pose_data["hand_vert_mask"]
            faces_full = np.asarray(mhr_pose_data["faces"])
            keep = hand_mask[faces_full].all(axis=1)
            # Shallow copy so the caller's pose dict is not mutated.
            mhr_pose_data = dict(mhr_pose_data)
            mhr_pose_data["faces"] = np.ascontiguousarray(
                faces_full[keep], dtype=faces_full.dtype,
            )
    else:  # silhouette — no shader/opacity controls, mask is binary
        shader_key = "default"
        rainbow_tilt_x = 0.0
        rainbow_tilt_z = -35.0
        opacity = 1.0
        person_palette_falloff = 0.6

    # --- Render loop --------------------------------------------------------
    frames_out = []
    pbar = comfy.utils.ProgressBar(B)
    desc = {
        "openpose_2d": "SAM3D openpose-2D render",
        "openpose_3d": "SAM3D openpose-3D render",
        "scail": "SAM3D SCAIL-3D render",
        "silhouette": "SAM3D silhouette",
    }.get(mode_key, "SAM3D render")
    for f in tqdm(range(B), desc=desc):
        bg_f = None
        if bg_t is not None:
            # Reuse the last background frame when there are fewer of them.
            bg_f = bg_t[min(f, bg_t.shape[0] - 1)]
        if mode_key == "openpose_2d":
            img = render_pose_data_openpose(
                mhr_pose_data, frame_idx=f, W=W, H=H,
                background=bg_f,
                composite=composite,
                marker_radius_px=marker_radius_px,
                stick_width_px=stick_width_px,
                limb_alpha=limb_alpha,
                include_hands=include_hands,
                face_style=face_style,
                hand_color_style=hand_color_style,
                person_brightness_falloff=person_palette_falloff,
            )
        elif mode_key == "openpose_3d":
            img = render_pose_data_capsules(
                mhr_pose_data, frame_idx=f, W=W, H=H,
                background=bg_f,
                composite=composite,
                radius_m=op3d_radius_m,
                include_hands=op3d_include_hands,
                palette="openpose",
                flat_shade=True,
                person_brightness_falloff=person_palette_falloff,
            )
        elif mode_key == "scail":
            # SCAIL renders body as 3D capsules + 2D openpose hands on top
            img = render_pose_data_capsules(
                mhr_pose_data, frame_idx=f, W=W, H=H,
                background=bg_f,
                composite=composite,
                radius_m=cap_radius_m,
                include_hands=False,
                palette="scail",
                person_brightness_falloff=person_palette_falloff,
            )
            if scail_overlay:
                img = render_pose_data_openpose(
                    mhr_pose_data, frame_idx=f, W=W, H=H,
                    background=img,
                    composite="over",
                    include_body=False,
                    include_hands=cap_include_hands,
                    face_style=cap_face_style,
                    marker_radius_px=scail_overlay_px,
                    stick_width_px=scail_overlay_px,
                    face_marker_radius_px=scail_face_px,
                    hand_color_style=cap_hand_color_style,
                    person_brightness_falloff=person_palette_falloff,
                )
        else:
            img = render_pose_data(
                mhr_pose_data, frame_idx=f, W=W, H=H,
                background=bg_f, composite=composite, opacity=opacity,
                shader_preset=shader_key,
                rainbow_tilt_x_deg=rainbow_tilt_x,
                rainbow_tilt_z_deg=rainbow_tilt_z,
                person_brightness_falloff=person_palette_falloff,
            )
        frames_out.append(img)
        pbar.update(1)

    out_image = torch.stack(frames_out, dim=0)
    if out_image.device != out_device:
        out_image = out_image.to(out_device)
    return io.NodeOutput(out_image)
|
|
|
|
|
|
class SAM3DBodyExtension(ComfyExtension):
    """ComfyExtension that exposes the SAM 3D Body node classes."""

    @override
    async def get_node_list(self) -> List[type[io.ComfyNode]]:
        """Return a fresh list of the node classes this extension provides."""
        node_classes = (
            SAM3DBody_Loader,
            SAM3DBody_Predict,
            SAM3DBody_FaceExpression,
            SAM3DBody_Smooth,
            SAM3DBody_Render,
        )
        return list(node_classes)
|
|
|
|
|
|
async def comfy_entrypoint() -> SAM3DBodyExtension:
    """Build and return a new ``SAM3DBodyExtension`` instance."""
    extension = SAM3DBodyExtension()
    return extension
|