feat: MediaPipe face detection (CORE-235) (#14009)

* Initial mediapipe face detection support

* Update face_geometry.py

* Account for diff sized batch input

* Model folder placeholder
This commit is contained in:
Jukka Seppänen 2026-05-21 02:07:48 +03:00 committed by GitHub
parent a8d2519058
commit 4d6a058bf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1298 additions and 0 deletions

View File

@ -0,0 +1,111 @@
"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode)
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
"""
from __future__ import annotations
import math
import numpy as np
def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray:
"""Weighted orthogonal Procrustes (similarity). Returns 4x4 M with
`target M @ homogeneous(source)` in the weighted LS sense. fp64 for
SVD stability. Port of procrustes_solver.cc."""
sqrt_w = np.sqrt(weights.astype(np.float64))
w_total = float((sqrt_w ** 2).sum())
ws = src.astype(np.float64) * sqrt_w
wt = tgt.astype(np.float64) * sqrt_w
c_w = (ws @ sqrt_w) / w_total
centered = ws - np.outer(c_w, sqrt_w)
U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True)
# Disallow reflection: flip the least-significant axis when det(U)·det(V)<0.
post, pre = U.copy(), Vt.T.copy()
if np.linalg.det(post) * np.linalg.det(pre) < 0:
post[:, 2] *= -1.0
R = post @ pre.T
denom = float((centered * ws).sum())
if denom < 1e-12:
raise ValueError("Procrustes denominator collapsed (degenerate source).")
scale = float((R @ centered * wt).sum()) / denom
translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total
M = np.eye(4, dtype=np.float64)
M[:3, :3] = scale * R
M[:3, 3] = translation
return M
def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float:
"""scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale."""
return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0]))
def solve_facial_transformation_matrix(
landmarks_normalized: np.ndarray,
canonical_vertices: np.ndarray,
procrustes_indices: np.ndarray,
procrustes_weights: np.ndarray,
image_width: int,
image_height: int,
# face_geometry_calculator_options.pbtxt defaults
vertical_fov_degrees: float = 63.0,
near: float = 1.0,
) -> np.ndarray:
"""4x4 facial transformation matrix via two-pass scale recovery
`landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y
in [0,1] with TOP-LEFT origin, z in width-scaled units.
"""
h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees))
w_near = image_width * h_near / image_height
sub = procrustes_indices.astype(np.int64)
screen = landmarks_normalized[sub].T.astype(np.float64).copy()
canon = canonical_vertices[sub].T.astype(np.float64).copy()
weights = procrustes_weights.astype(np.float64)
# ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale).
screen[1] = 1.0 - screen[1]
screen[0] = screen[0] * w_near - 0.5 * w_near
screen[1] = screen[1] * h_near - 0.5 * h_near
screen[2] = screen[2] * w_near
depth_offset = float(screen[2].mean())
def _unproject(s: np.ndarray, scale: float) -> np.ndarray:
s = s.copy()
s[2] = (s[2] - depth_offset + near) / scale
s[0] *= s[2] / near
s[1] *= s[2] / near
s[2] *= -1.0
return s
first = screen.copy()
first[2] *= -1.0
s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY
s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY
return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32)
def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray:
"""Adapt a FaceLandmarker face dict to MP's normalized convention and solve.
FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects
z_norm = z_canonical * scale_x / image_width"""
lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"]
aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1)
M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None)
scale_x = float(np.linalg.norm(M[0]))
z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width
normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32)
normalized[:, 0] = lmks_xy[:, 0] / image_width
normalized[:, 1] = lmks_xy[:, 1] / image_height
normalized[:, 2] = lmks_3d[:, 2] * z_scale
return solve_facial_transformation_matrix(
normalized, canonical_data["canonical_vertices"],
canonical_data["procrustes_indices"], canonical_data["procrustes_weights"],
image_width=image_width, image_height=image_height,
)

View File

@ -0,0 +1,682 @@
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
BlazeFace detector FaceMesh v2 ARKit-52 blendshapes."""
from __future__ import annotations
import math
from functools import lru_cache
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import expit
from torch import Tensor, nn
# Values below must stay verbatim with the published face_landmarker_v2 graph
# face_blendshapes_graph.cc::kLandmarksSubsetIdxs
_BS_INPUT_INDICES: Tuple[int, ...] = (
0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54,
55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95,
103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152,
153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178,
181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282,
283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314,
317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374,
375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397,
398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474,
475, 476, 477,
)
# face_blendshapes_graph.cc::kCategoryNames
BLENDSHAPE_NAMES: Tuple[str, ...] = (
"_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft",
"browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight",
"eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight",
"eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight",
"eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight",
"eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen",
"jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight",
"mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft",
"mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft",
"mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower",
"mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft",
"mouthSmileRight", "mouthStretchLeft", "mouthStretchRight",
"mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight",
)
# face_detection.pbtxt — short-range BlazeFace.
_BF_NUM_LAYERS = 4
_BF_INPUT_SIZE = 128
_BF_STRIDES = (8, 16, 16, 16)
_BF_ANCHOR_OFFSET_X = 0.5
_BF_ANCHOR_OFFSET_Y = 0.5
_BF_ASPECT_RATIOS = (1.0,)
_BF_INTERP_SCALE_AR = 1.0
_BF_BOX_SCALE = 128.0
_BF_KP_OFFSET = 4
_BF_SCORE_CLIP = 100.0
_BF_MIN_SCORE = 0.5
# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell.
_BF_FR_INPUT_SIZE = 192
_BF_FR_GRID = 48
_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID
_BF_FR_BOX_SCALE = 192.0
_BF_FR_SCORE_CLIP = 100.0
_FM_INPUT_SIZE = 192
# Face ROI: 1.5xbbox rect warped anisotropically into 192x192.
_FACE_LEFT_EYE_KP = 0
_FACE_RIGHT_EYE_KP = 1
_FACE_ROI_SCALE_X = 1.5
_FACE_ROI_SCALE_Y = 1.5
_FACE_ROI_TARGET_ANGLE = 0.0
def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor:
"""TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px)."""
H, W = x.shape[-2], x.shape[-1]
pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0)
pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0)
if pad_h == 0 and pad_w == 0:
return x
return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at
# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride):
_BLAZEFACE_BLOCKS = [
(24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1),
(36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1),
(64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2),
(96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1),
]
class BlazeFaceBlock(nn.Module):
"""DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return F.relu(self.pointwise(self.depthwise(x)) + residual)
class BlazeFace(nn.Module):
"""Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw)
self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _BLAZEFACE_BLOCKS)
# 16²x2 + 8²x6 = 512 + 384 = 896 anchors.
self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw)
self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw)
self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw)
self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2)))
# 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11).
for i in range(11):
x = self.blocks[i](x)
feat_16 = x
for i in range(11, 16):
x = self.blocks[i](x)
feat_8 = x
def flat(t, a, k): # NHWC flatten → (B, H*W*A, K)
B, _, H, W = t.shape
return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k)
cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1)
reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1)
return reg, cls
# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish
# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid.
class FRBlock(nn.Module):
"""Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual].
Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2
is ReLU only when no residual (else ReLU fuses into the ADD).
"""
def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.has_residual = (in_ch == out_ch and stride == 1)
self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw)
self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw)
self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = x if self.has_residual else None
x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1)))))
x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1))))
return F.relu(x + residual) if residual is not None else F.relu(x)
# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128
# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*).
_FR_BACKBONE_BLOCKS = [
(32, 8, 32, 1), (32, 8, 32, 1), # 96²x32
(32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0]
(64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1]
(128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2]
(192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384
]
_FR_LATERAL_TAP_INDICES = (4, 7, 10)
_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv
# Decoder blocks per FPN level (after upsample-and-merge with the lateral).
_FR_DECODER_BLOCKS = [
[(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96
[(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64
[(48, 24, 48, 1)], # 48²x48 — feeds the heads
]
def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor:
"""TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)).
pixel_shuffle uses CRD which permutes output channels for c_out > 1."""
B_, _, H_, W_ = t.shape
t = t.reshape(B_, r, r, c_out, H_, W_)
t = t.permute(0, 3, 4, 1, 5, 2).contiguous()
return t.reshape(B_, c_out, H_ * r, W_ * r)
class BlazeFaceFullRange(nn.Module):
"""Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values."""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations)
self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw)
self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS)
self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS)
self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw)
self.decoder_levels = nn.ModuleList(
nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS
)
# 96→64 before 12→24, 64→48 before 24→48.
self.decoder_reduce_convs = nn.ModuleList([
ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw),
ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw),
])
# Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2.
self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw)
self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw)
self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw)
self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw)
def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
# Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME).
x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1))))
tap_set = set(_FR_LATERAL_TAP_INDICES)
laterals: list[Tensor] = []
for i, blk in enumerate(self.backbone):
x = blk(x)
if i in tap_set:
laterals.append(x)
# top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite.
p = F.relu(self.top_conv(x))
laterals_rev = list(reversed(laterals))
lateral_convs_rev = list(reversed(self.lateral_convs))
for level in range(len(self.decoder_levels)):
lateral = laterals_rev[level]
p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False)
p = p + F.relu(lateral_convs_rev[level](lateral))
for blk in self.decoder_levels[level]:
p = blk(p)
if level < len(self.decoder_reduce_convs):
p = F.relu(self.decoder_reduce_convs[level](p))
c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1)))
c = _dcr_depth_to_space(c, r=2, c_out=1)
r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1)))
r = _dcr_depth_to_space(r, r=2, c_out=16)
B = c.shape[0]
cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1)
reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16)
return reg_out, cls_out
@lru_cache(maxsize=1)
def _blazeface_full_range_anchors() -> np.ndarray:
"""2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size)."""
feat = _BF_FR_GRID
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx)
return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4)
def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Same decode as short-range with 2304-anchor grid and box_scale=192."""
scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_FR_BOX_SCALE
a = _blazeface_full_range_anchors()[keep]
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock,
# 17 blocks, heads for 478x3 landmarks + presence.
_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride)
(16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2),
(64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2),
(128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1),
]
class FaceMeshBlock(nn.Module):
"""PReLU BlazeBlock: PReLU between DW and PW, and after the residual add."""
def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride
self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw)
self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw)
self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw)
self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw)
def forward(self, x: Tensor) -> Tensor:
residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x
if self.out_ch > self.in_ch:
residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch))
x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1))
return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual)
class FaceMesh(nn.Module):
NUM_LANDMARKS = 478
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw)
self.prelu_stem = nn.PReLU(num_parameters=16, **kw)
self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations)
for (i, o, s) in _FACEMESH_BLOCKS)
self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw)
self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw)
self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations)
self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw)
self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw)
def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]:
"""(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence)."""
x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2)))
for blk in self.blocks:
x = blk(x)
x = self.prelu_head_reduce(self.head_reduce(x))
x = self.head_block(x)
B = x.shape[0]
presence = self.head_presence(x).reshape(B)
lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3)
return lmks, presence
# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"):
# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52.
_BS_NUM_INPUT_LANDMARKS = 146
_BS_NUM_TOKENS_REDUCED = 96
_BS_NUM_TOKENS = 97 # +1 cls
_BS_TOKEN_DIM = 64
_BS_TOKEN_MIX_HIDDEN = 384
_BS_CHANNEL_MIX_HIDDEN = 256
_BS_NUM_BLENDSHAPES = 52
_BS_LN_EPS = 1e-6
class MlpMixerBlock(nn.Module):
"""MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim).
Both pre-LN, both residual. LN has no beta (bias=False) to match MP."""
def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int,
device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
# bias=False → no LN beta (matches MP).
self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw)
self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw)
self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw)
self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw)
self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw)
def forward(self, x: Tensor) -> Tensor:
y = self.ln1(x).transpose(1, 2)
x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2)
return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x))))
class FaceBlendshapes(nn.Module):
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
ops = operations if operations is not None else nn
kw = dict(device=device, dtype=dtype)
self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw)
self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw)
self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw))
self.blocks = nn.ModuleList(
MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN,
device=device, dtype=dtype, operations=operations) for _ in range(4)
)
self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw)
@staticmethod
def _input_normalize(landmarks_2d: Tensor) -> Tensor:
# Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training.
centroid = landmarks_2d.mean(dim=1, keepdim=True)
x = landmarks_2d - centroid
mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True))
scale = mag.mean(dim=1, keepdim=True)
return (x / scale.clamp(min=1e-12)) * 0.5
def forward(self, landmarks_2d: Tensor) -> Tensor:
"""(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize)."""
x = self._input_normalize(landmarks_2d)
x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2)
x = self.token_embed(x)
cls = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls, x], dim=1)
for blk in self.blocks:
x = blk(x)
return torch.sigmoid(self.head(x[:, 0]))
@lru_cache(maxsize=1)
def _blazeface_anchors() -> np.ndarray:
"""896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1)."""
per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0)
layer_anchors: List[np.ndarray] = []
layer = 0
while layer < _BF_NUM_LAYERS:
stride = _BF_STRIDES[layer]
last = layer
while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride:
last += 1
per_cell = per_ar * (last - layer)
feat = (_BF_INPUT_SIZE + stride - 1) // stride
yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij")
cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx)
cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4)
layer_anchors.append(np.repeat(cell, per_cell, axis=0))
layer = last
out = np.concatenate(layer_anchors, axis=0)
assert out.shape == (896, 4), out.shape
return out
def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray,
score_thresh: float = _BF_MIN_SCORE) -> np.ndarray:
"""Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1]."""
scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP))
keep = scores >= score_thresh
if not keep.any():
return np.empty((0, 17), dtype=np.float32)
r = regressors[keep] / _BF_BOX_SCALE
a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1
cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4]
xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys
w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs
out = np.empty((r.shape[0], 17), dtype=np.float32)
out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs
out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys
out[:, 16] = scores[keep]
return out
def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray:
"""MP weighted NMS — kept boxes are score-weighted averages of overlapping detections."""
if detections.shape[0] == 0:
return detections
dets = detections[np.argsort(-detections[:, 16])]
N = dets.shape[0]
areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None)
kept: List[np.ndarray] = []
used = np.zeros(N, dtype=bool)
for i in range(N):
if used[i]:
continue
ax1, ay1, ax2, ay2 = dets[i, 0:4]
merge_idx = [i]
for j in range(i + 1, N):
if used[j]:
continue
bx1, by1, bx2, by2 = dets[j, 0:4]
iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
ih = max(0.0, min(ay2, by2) - max(ay1, by1))
inter = iw * ih
union = areas[i] + areas[j] - inter
if union > 0 and inter / union > iou_thresh: # strict > matches MP
merge_idx.append(j)
used[j] = True
used[i] = True
cluster = dets[merge_idx]
ws = cluster[:, 16:17]
ws_sum = ws.sum()
merged = np.copy(cluster[0])
if ws_sum > 0:
merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum
kept.append(merged)
return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32)
def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]:
"""Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic)."""
xmin, ymin, xmax, ymax = detection[0:4]
lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w
ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h
rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w
ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h
# Image-y-down convention: angle = target - atan2(-dy, dx).
angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx)
return (float((xmin + xmax) * 0.5 * image_w),
float((ymin + ymax) * 0.5 * image_h),
float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X),
float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y),
float(angle))
def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor:
"""Bilinear-sample image_chw at corner-aligned (src_x, src_y)."""
H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1])
grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0,
(2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0)
return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear",
align_corners=False, padding_mode=padding_mode).squeeze(0)
def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float,
angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor:
"""Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1]."""
s_x, s_y = width / output_size, height / output_size
cos_a, sin_a = math.cos(angle), math.sin(angle)
arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a
src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a
return _sample_warp(image_chw, src_x, src_y, "border")
def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]:
"""Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm.
Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%.
Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) the triplet maps
tensor-normalized [0,1] detections back to image pixels.
"""
H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2])
sub_rect_size = float(max(W, H))
sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5
s = sub_rect_size / target
arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5
v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij")
out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros")
return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size
class FaceLandmarker(nn.Module):
"""BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short'
(128², 2m) or 'full' (192² FPN, 5m). State dict uses inner-module prefixes
`detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel
wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading.
"""
def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"):
super().__init__()
det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant)
self.detector_variant = detector_variant
self.detector = det_cls(device=device, dtype=dtype, operations=operations)
self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations)
self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations)
self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False)
def run_detector_batch(self, images_rgb_uint8: List[np.ndarray],
score_thresh: float = _BF_MIN_SCORE,
iou_thresh: float = 0.5):
"""Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded)
where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords."""
if not images_rgb_uint8:
return [], [], [], []
device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype
det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range)
if self.detector_variant == "full"
else (_BF_INPUT_SIZE, _decode_blazeface))
# Same-size frames: stack once and transfer once. Variable size falls back
# to per-image (only triggers for SAM3DBody's head crops).
sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8]
if len(set(sizes)) == 1:
batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous()
img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])]
else:
img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8]
warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws]
det_crops = [w[0] for w in warps]
sub_rects = [(w[1], w[2], w[3]) for w in warps]
regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0))
regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy()
per_frame = []
for b in range(len(images_rgb_uint8)):
decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh)
per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded)
return img_raws, sub_rects, sizes, per_frame
def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1,
score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]:
"""Full pipeline batched across `images_rgb_uint8`. Returns one face-dict
list per image (empty if nothing detected). Face dict:
bbox_xyxy (4,) image pixels, blendshapes {52} [0,1],
landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in
192-canonical (pre-transformation) units, presence float (raw logit).
"""
img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch(
images_rgb_uint8, score_thresh=score_thresh,
)
# tensor-normalized → image-normalized [0,1] for _detection_to_face_rect.
for b, decoded in enumerate(per_frame_dets):
if decoded.shape[0] == 0:
continue
cx, cy, size = sub_rects[b]
H, W = sizes[b]
sx0, sy0 = cx - size * 0.5, cy - size * 0.5
decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W
decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H
if num_faces > 0:
per_frame_dets[b] = decoded[: int(num_faces)]
# Collect every detected face across all frames into one mesh input.
face_params: List[Tuple[int, float, float, float, float, float, float]] = []
mesh_crops: List[Tensor] = []
for b, dets in enumerate(per_frame_dets):
if dets.shape[0] == 0:
continue
H, W = sizes[b]
img_for_mesh = img_raws[b] / 255.0
for det in dets:
cx, cy, w, h, angle = _detection_to_face_rect(det, W, H)
mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE))
face_params.append((b, float(det[16]), cx, cy, w, h, angle))
results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))]
if not mesh_crops:
return results
lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0))
bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2])
# Batched canonical→image affine
params_t = torch.tensor(
[(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params],
device=lmks_canon_b.device, dtype=lmks_canon_b.dtype,
)
cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1)
inv = 1.0 / _FM_INPUT_SIZE
u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5
v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5
lmks_xy_t = torch.stack([
cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None],
cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None],
], dim=-1)
lmks_xy_np = lmks_xy_t.float().cpu().numpy()
lmks_canon_np = lmks_canon_b.float().cpu().numpy()
presence_np = presence_b.float().cpu().numpy()
bs_np = bs_out_b.float().cpu().numpy()
for i, (b, score, *_) in enumerate(face_params):
lmks_xy = lmks_xy_np[i]
mn, mx = lmks_xy.min(0), lmks_xy.max(0)
results[b].append({
"bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32),
"blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())),
"landmarks_xy": lmks_xy,
"landmarks_3d": lmks_canon_np[i],
"presence": float(presence_np[i]),
"score": score,
})
return results

View File

@ -0,0 +1,502 @@
"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port.
Custom IO types:
FACE_LANDMARKER FaceLandmarkerModel wrapper (ModelPatcher inside)
FACE_LANDMARKS {"frames": List[List[face_dict]], "image_size": (H, W),
"connection_sets": dict[str, frozenset[(int, int)]]}
face_dict: bbox_xyxy, blendshapes, landmarks_xy,
landmarks_3d, presence, score, transformation_matrix
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type pair with DrawBBoxes.
"""
from __future__ import annotations
import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw
from tqdm.auto import tqdm
from typing_extensions import override
import comfy.model_management
import comfy.model_patcher
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, io
from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceLandmarkerType = io.Custom("FACE_LANDMARKER")
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips")
class FaceLandmarkerModel:
"""Loaded FaceLandmarker variants + ModelPatcher per variant.
Safetensors layout: `detector_short.*` / `detector_full.*` plus shared
`mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`.
PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices).
"""
def __init__(self, state_dict: dict):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = torch.float32
# FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*.
base: dict[str, frozenset] = {}
for k in [k for k in state_dict if k.startswith("topology.")]:
base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist()))
base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS))
base["all"] = base["contours"] | base["irises"] | base["nose"]
self.connection_sets: dict[str, frozenset] = base
self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS}
shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))}
self.models: dict[str, FaceLandmarker] = {}
self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {}
for variant in ("short", "full"):
prefix = f"detector_{variant}."
sub = dict(shared)
sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)})
fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval()
fl.load_state_dict(sub, strict=False)
self.models[variant] = fl
self.patchers[variant] = comfy.model_patcher.CoreModelPatcher(
fl, load_device=self.load_device, offload_device=offload_device,
size=comfy.model_management.module_size(fl),
)
def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str):
comfy.model_management.load_model_gpu(self.patchers[variant])
return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh)
def _image_to_uint8(image: torch.Tensor) -> np.ndarray:
return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy()
def _parse_color(color: str) -> tuple[int, int, int]:
try:
return ImageColor.getrgb(color)[:3]
except ValueError:
return (0, 255, 0)
def _copy_face(face: dict) -> dict:
"""Shallow copy of a face_dict with array-fields cloned so callers can mutate."""
return {
"bbox_xyxy": face["bbox_xyxy"].copy(),
"blendshapes": dict(face["blendshapes"]),
"landmarks_xy": face["landmarks_xy"].copy(),
"landmarks_3d": face["landmarks_3d"].copy(),
"presence": face["presence"],
"score": face["score"],
}
def _lerp_face(a: dict, b: dict, t: float) -> dict:
return {
"bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"],
"blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]},
"landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"],
"landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"],
"presence": (1 - t) * a["presence"] + t * b["presence"],
"score": (1 - t) * a["score"] + t * b["score"],
}
def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]:
"""Greedy nearest-neighbour pairing of faces between two frames by bbox
centre distance. Unmatched (when counts differ) are dropped."""
if not a or not b:
return []
centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a])
centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b])
dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1)
pairs: list[tuple[int, int]] = []
used_a: set[int] = set()
used_b: set[int] = set()
candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b)))
for _, ia, ib in candidates:
if ia in used_a or ib in used_b:
continue
pairs.append((ia, ib))
used_a.add(ia)
used_b.add(ib)
return pairs
def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None:
"""In-place fill empty frame slots from neighbouring detections. Multi-face
aware: pairs faces across bracketing frames by greedy bbox-centre NN.
When counts differ, unmatched faces are dropped from the synthesised frame."""
if mode == "empty":
return
valid = [i for i, fr in enumerate(frames) if fr]
if not valid:
return # nothing to fill from
if mode == "previous":
last: list[dict] = []
for i, fr in enumerate(frames):
if fr:
last = fr
elif last:
frames[i] = [_copy_face(f) for f in last]
return
# interpolate: lerp between bracketing valid frames; clamp at ends.
for i in range(len(frames)):
if frames[i]:
continue
prev_i = max((v for v in valid if v < i), default=None)
next_i = min((v for v in valid if v > i), default=None)
if prev_i is None:
frames[i] = [_copy_face(f) for f in frames[next_i]]
elif next_i is None:
frames[i] = [_copy_face(f) for f in frames[prev_i]]
else:
t = (i - prev_i) / (next_i - prev_i)
pairs = _match_faces(frames[prev_i], frames[next_i])
frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs]
def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]:
"""Walk an unordered edge set into one or more closed-loop vertex rings
(handles multi-loop sets like FACEMESH_LIPS: outer + inner)."""
adj: dict[int, set[int]] = {}
for a, b in edges:
adj.setdefault(a, set()).add(b)
adj.setdefault(b, set()).add(a)
visited: set[int] = set()
rings: list[list[int]] = []
for start in adj:
if start in visited:
continue
ring = [start]
visited.add(start)
prev, cur = -1, start
while True:
nxt = next((v for v in adj[cur] if v != prev), None)
if nxt is None or nxt == start:
break
ring.append(nxt)
visited.add(nxt)
prev, cur = cur, nxt
rings.append(ring)
return rings
class LoadMediaPipeFaceLandmarker(io.ComfyNode):
"""Load MediaPipe Face Landmarker v2 weights. Contains both detector variants
(short / full), shared mesh, blendshapes, and canonical geometry."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadMediaPipeFaceLandmarker",
display_name="Load MediaPipe Face Landmarker",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"),
tooltip="Face Landmarker safetensors from models/mediapipe/."),
],
outputs=[FaceLandmarkerType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper)
# Per-frame fallback modes for detection failures in a batch.
_FALLBACK_MODES = ("empty", "previous", "interpolate")
class MediaPipeFaceLandmarker(io.ComfyNode):
"""BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the
input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face)
pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize
for the mesh overlay."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceLandmarker",
display_name="MediaPipe Face Landmarker",
category="image/detection",
inputs=[
FaceLandmarkerType.Input("face_landmarker"),
io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces "
"(within ~2 m of the camera); 'full' covers farther / smaller "
"faces (up to ~5 m) but is slower. 'both' runs both detectors and "
"keeps whichever found more faces per frame (~2× detection cost)."),
io.Int.Input("num_faces", default=1, min=0, max=16, step=1,
tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."),
io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True,
tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."),
io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True,
tooltip="Per-frame behaviour when detection fails in a batch. "
"'empty' leaves the frame faceless. 'previous' copies the most recent successful "
"detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing "
"successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."),
],
outputs=[
FaceLandmarksType.Output(display_name="face_landmarks"),
io.BoundingBox.Output("bboxes"),
],
)
@classmethod
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput:
canonical = face_landmarker.canonical_data
img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3]
chunk = 16
is_both = detector_variant == "both"
total_work = 2 * B if is_both else B
pbar = comfy.utils.ProgressBar(total_work)
def _run(variant: str) -> list[list[dict]]:
res: list[list[dict]] = []
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk):
end = min(i + chunk, B)
res.extend(face_landmarker.detect_batch(
[img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces),
score_thresh=float(min_confidence),
variant=variant,
))
pbar.update_absolute(min(pbar.current + (end - i), total_work))
tq.update(end - i)
return res
if is_both:
short_res = _run("short")
full_res = _run("full")
# Per-frame keep whichever found more faces (tie → short).
frames: list[list[dict]] = [
short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi]
for bi in range(B)
]
else:
frames = _run(detector_variant)
_fill_missing_frames(frames, missing_frame_fallback)
bboxes = []
for per_frame in frames:
per_bb = []
for f in per_frame:
f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical)
x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"])
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_landmarker.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose")
_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", True),
("left_eye", True),
("right_eye", True),
("left_eyebrow", True),
("right_eyebrow", True),
("irises", True),
("nose", True),
("tesselation", False),
)
class MediaPipeFaceMeshVisualize(io.ComfyNode):
"""Draw a FACEMESH_* subset over an image. Topology travels with the
FACE_LANDMARKS payload (set at detection time)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMeshVisualize",
display_name="MediaPipe Face Mesh Visualize",
category="image/detection",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
io.DynamicCombo.Input(
"connections",
tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("fill", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(feat, default=default,
tooltip=f"Draw the '{feat}' connection set.")
for feat, default in _CUSTOM_FEATURES
]),
],
),
io.Color.Input("color", default="#00ff00"),
io.Int.Input("thickness", default=1, min=0, max=8, step=1,
tooltip="Edge line thickness in pixels. 0 disables edge drawing."),
io.Int.Input("point_size", default=2, min=0, max=16, step=1,
tooltip="Landmark dot radius in pixels. 0 disables point drawing."),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = connections["connections"]
fill_rings: list[list[int]] | None = None
if sel == "fill":
fill_rings = _ordered_rings(sets["face_oval"])
edges = frozenset()
elif sel == "custom":
parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)]
edges = frozenset().union(*(sets[p] for p in parts))
else: # "all"
edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS))
rgb, thick, psize = _parse_color(color), int(thickness), int(point_size)
frames = face_landmarks["frames"]
if image is None:
H, W = face_landmarks["image_size"]
img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8)
else:
img_np = _image_to_uint8(image)
B = img_np.shape[0]
n_frames = len(frames)
pbar = comfy.utils.ProgressBar(B)
out = np.empty_like(img_np)
for bi in range(B):
faces = frames[bi] if bi < n_frames else []
out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(out).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
def _draw_mesh(image_rgb: np.ndarray, faces: list, edges,
rgb: tuple[int, int, int], thickness: int,
point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray:
draw_edges = thickness > 0 and edges
if not faces or (fill_rings is None and not draw_edges and point_size <= 0):
return image_rgb.copy()
pil = Image.fromarray(image_rgb)
draw = ImageDraw.Draw(pil)
r = point_size * 0.5
if fill_rings is not None:
for f in faces:
lmks = f["landmarks_xy"]
for ring in fill_rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb)
return np.asarray(pil)
for f in faces:
lmks = f["landmarks_xy"]
n = lmks.shape[0]
if draw_edges:
for a, b in edges:
if a < n and b < n:
draw.line([(float(lmks[a, 0]), float(lmks[a, 1])),
(float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness)
if point_size == 1:
draw.point(lmks.flatten().tolist(), fill=rgb)
elif point_size > 1:
for x, y in lmks:
draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb)
return np.asarray(pil)
# Mask region presets — closed-loop topologies only.
_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises")
_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
("face_oval", True),
("lips", False),
("left_eye", False),
("right_eye", False),
("irises", False),
)
class MediaPipeFaceMask(io.ComfyNode):
"""Binary mask from face landmarks, filled polygon per face. One mask per
frame in the batch; faces in the same frame composite (union)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MediaPipeFaceMask",
display_name="MediaPipe Face Mask",
category="image/detection",
inputs=[
FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input(
"regions",
tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.",
options=[
io.DynamicCombo.Option("all", []),
io.DynamicCombo.Option("custom", [
io.Boolean.Input(reg, default=default,
tooltip=f"Include the '{reg}' region in the mask.")
for reg, default in _MASK_CUSTOM_FEATURES
]),
],
),
],
outputs=[io.Mask.Output()],
)
@classmethod
def execute(cls, face_landmarks, regions) -> io.NodeOutput:
sets = face_landmarks["connection_sets"]
sel = regions["regions"]
if sel == "custom":
picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)]
else:
picked = list(_MASK_REGIONS)
rings = [r for reg in picked for r in _ordered_rings(sets[reg])]
frames = face_landmarks["frames"]
H, W = face_landmarks["image_size"]
masks = np.zeros((len(frames), H, W), dtype=np.uint8)
pbar = comfy.utils.ProgressBar(len(frames))
for bi, per_frame in enumerate(frames):
if per_frame:
pil = Image.new("L", (W, H), 0)
draw = ImageDraw.Draw(pil)
for f in per_frame:
lmks = f["landmarks_xy"]
for ring in rings:
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255)
masks[bi] = np.asarray(pil)
pbar.update_absolute(bi + 1)
return io.NodeOutput(torch.from_numpy(masks).to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
).div_(255.0))
class MediaPipeFaceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask]
async def comfy_entrypoint() -> MediaPipeFaceExtension:
return MediaPipeFaceExtension()

View File

@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

View File

@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes():
"nodes_hidream_o1.py",
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
]
import_failed = []