mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 23:47:25 +08:00
510 lines
22 KiB
Python
510 lines
22 KiB
Python
"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port.
|
||
|
||
Custom IO types:
|
||
FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside)
|
||
FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W),
|
||
"connection_sets": dict[str, frozenset[(int, int)]]}
|
||
face_dict: bbox_xyxy, blendshapes, landmarks_xy,
|
||
landmarks_3d, presence, score, transformation_matrix
|
||
|
||
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image, ImageColor, ImageDraw
|
||
from tqdm.auto import tqdm
|
||
from typing_extensions import override
|
||
|
||
import comfy.model_management
|
||
import comfy.model_patcher
|
||
import comfy.utils
|
||
import folder_paths
|
||
from comfy_api.latest import ComfyExtension, io
|
||
|
||
from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
|
||
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
|
||
|
||
|
||
FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
|
||
FaceLandmarksType = io.Custom("FACE_LANDMARKS")
|
||
|
||
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
|
||
_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips")
|
||
|
||
|
||
class FaceLandmarkerModel:
|
||
"""Loaded FaceLandmarker variants + ModelPatcher per variant.
|
||
|
||
Safetensors layout: `detector_short.*` / `detector_full.*` plus shared
|
||
`mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`.
|
||
PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices).
|
||
"""
|
||
|
||
def __init__(self, state_dict: dict):
|
||
self.load_device = comfy.model_management.text_encoder_device()
|
||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||
self.dtype = torch.float32
|
||
|
||
# FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*.
|
||
base: dict[str, frozenset] = {}
|
||
for k in [k for k in state_dict if k.startswith("topology.")]:
|
||
base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist()))
|
||
base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS))
|
||
base["all"] = base["contours"] | base["irises"] | base["nose"]
|
||
|
||
self.connection_sets: dict[str, frozenset] = base
|
||
self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS}
|
||
|
||
shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))}
|
||
|
||
self.models: dict[str, FaceLandmarker] = {}
|
||
self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {}
|
||
for variant in ("short", "full"):
|
||
prefix = f"detector_{variant}."
|
||
sub = dict(shared)
|
||
sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)})
|
||
fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval()
|
||
fl.load_state_dict(sub, strict=False)
|
||
|
||
self.models[variant] = fl
|
||
self.patchers[variant] = comfy.model_patcher.CoreModelPatcher(
|
||
fl, load_device=self.load_device, offload_device=offload_device,
|
||
size=comfy.model_management.module_size(fl),
|
||
)
|
||
|
||
def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str):
|
||
comfy.model_management.load_model_gpu(self.patchers[variant])
|
||
return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh)
|
||
|
||
|
||
def _image_to_uint8(image: torch.Tensor) -> np.ndarray:
|
||
return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy()
|
||
|
||
|
||
def _parse_color(color: str) -> tuple[int, int, int]:
|
||
try:
|
||
return ImageColor.getrgb(color)[:3]
|
||
except ValueError:
|
||
return (0, 255, 0)
|
||
|
||
|
||
def _copy_face(face: dict) -> dict:
|
||
"""Shallow copy of a face_dict with array-fields cloned so callers can mutate."""
|
||
return {
|
||
"bbox_xyxy": face["bbox_xyxy"].copy(),
|
||
"blendshapes": dict(face["blendshapes"]),
|
||
"landmarks_xy": face["landmarks_xy"].copy(),
|
||
"landmarks_3d": face["landmarks_3d"].copy(),
|
||
"presence": face["presence"],
|
||
"score": face["score"],
|
||
}
|
||
|
||
|
||
def _lerp_face(a: dict, b: dict, t: float) -> dict:
|
||
return {
|
||
"bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"],
|
||
"blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]},
|
||
"landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"],
|
||
"landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"],
|
||
"presence": (1 - t) * a["presence"] + t * b["presence"],
|
||
"score": (1 - t) * a["score"] + t * b["score"],
|
||
}
|
||
|
||
|
||
def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]:
|
||
"""Greedy nearest-neighbour pairing of faces between two frames by bbox
|
||
centre distance. Unmatched (when counts differ) are dropped."""
|
||
if not a or not b:
|
||
return []
|
||
centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
|
||
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a])
|
||
centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]),
|
||
0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b])
|
||
dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1)
|
||
pairs: list[tuple[int, int]] = []
|
||
used_a: set[int] = set()
|
||
used_b: set[int] = set()
|
||
candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b)))
|
||
for _, ia, ib in candidates:
|
||
if ia in used_a or ib in used_b:
|
||
continue
|
||
pairs.append((ia, ib))
|
||
used_a.add(ia)
|
||
used_b.add(ib)
|
||
return pairs
|
||
|
||
|
||
def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None:
|
||
"""In-place fill empty frame slots from neighbouring detections. Multi-face
|
||
aware: pairs faces across bracketing frames by greedy bbox-centre NN.
|
||
When counts differ, unmatched faces are dropped from the synthesised frame."""
|
||
if mode == "empty":
|
||
return
|
||
valid = [i for i, fr in enumerate(frames) if fr]
|
||
if not valid:
|
||
return # nothing to fill from
|
||
if mode == "previous":
|
||
last: list[dict] = []
|
||
for i, fr in enumerate(frames):
|
||
if fr:
|
||
last = fr
|
||
elif last:
|
||
frames[i] = [_copy_face(f) for f in last]
|
||
return
|
||
# interpolate: lerp between bracketing valid frames; clamp at ends.
|
||
for i in range(len(frames)):
|
||
if frames[i]:
|
||
continue
|
||
prev_i = max((v for v in valid if v < i), default=None)
|
||
next_i = min((v for v in valid if v > i), default=None)
|
||
if prev_i is None:
|
||
frames[i] = [_copy_face(f) for f in frames[next_i]]
|
||
elif next_i is None:
|
||
frames[i] = [_copy_face(f) for f in frames[prev_i]]
|
||
else:
|
||
t = (i - prev_i) / (next_i - prev_i)
|
||
pairs = _match_faces(frames[prev_i], frames[next_i])
|
||
frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs]
|
||
|
||
|
||
def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]:
|
||
"""Walk an unordered edge set into one or more closed-loop vertex rings
|
||
(handles multi-loop sets like FACEMESH_LIPS: outer + inner)."""
|
||
adj: dict[int, set[int]] = {}
|
||
for a, b in edges:
|
||
adj.setdefault(a, set()).add(b)
|
||
adj.setdefault(b, set()).add(a)
|
||
visited: set[int] = set()
|
||
rings: list[list[int]] = []
|
||
for start in adj:
|
||
if start in visited:
|
||
continue
|
||
ring = [start]
|
||
visited.add(start)
|
||
prev, cur = -1, start
|
||
while True:
|
||
nxt = next((v for v in adj[cur] if v != prev), None)
|
||
if nxt is None or nxt == start:
|
||
break
|
||
ring.append(nxt)
|
||
visited.add(nxt)
|
||
prev, cur = cur, nxt
|
||
rings.append(ring)
|
||
return rings
|
||
|
||
|
||
class LoadMediaPipeFaceLandmarker(io.ComfyNode):
|
||
"""Load MediaPipe Face Landmarker v2 weights. Contains both detector variants
|
||
(short / full), shared mesh, blendshapes, and canonical geometry."""
|
||
|
||
@classmethod
|
||
def define_schema(cls):
|
||
return io.Schema(
|
||
node_id="LoadMediaPipeFaceLandmarker",
|
||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||
display_name="Load Face Detection Model (MediaPipe)",
|
||
category="loaders",
|
||
inputs=[
|
||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
|
||
tooltip="Face detection model from models/detection/."),
|
||
],
|
||
outputs=[FaceDetectionType.Output()],
|
||
)
|
||
|
||
@classmethod
|
||
def execute(cls, model_name) -> io.NodeOutput:
|
||
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
|
||
wrapper = FaceLandmarkerModel(sd)
|
||
return io.NodeOutput(wrapper)
|
||
|
||
|
||
# Per-frame fallback modes for detection failures in a batch.
|
||
_FALLBACK_MODES = ("empty", "previous", "interpolate")
|
||
|
||
|
||
class MediaPipeFaceLandmarker(io.ComfyNode):
|
||
"""BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the
|
||
input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) —
|
||
pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize
|
||
for the mesh overlay."""
|
||
|
||
@classmethod
|
||
def define_schema(cls):
|
||
return io.Schema(
|
||
node_id="MediaPipeFaceLandmarker",
|
||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
|
||
display_name="Detect Face Landmarks (MediaPipe)",
|
||
category="image/detection",
|
||
description="Detects facial landmarks using MediaPipe model.",
|
||
inputs=[
|
||
FaceDetectionType.Input("face_detection_model"),
|
||
io.Image.Input("image"),
|
||
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
|
||
tooltip="Face detector range. 'short' is tuned for close-up faces "
|
||
"(within ~2 m of the camera); 'full' covers farther / smaller "
|
||
"faces (up to ~5 m) but is slower. 'both' runs both detectors and "
|
||
"keeps whichever found more faces per frame (~2× detection cost)."),
|
||
io.Int.Input("num_faces", default=1, min=0, max=16, step=1,
|
||
tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."),
|
||
io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True,
|
||
tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."),
|
||
io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True,
|
||
tooltip="Per-frame behaviour when detection fails in a batch. "
|
||
"'empty' leaves the frame faceless. 'previous' copies the most recent successful "
|
||
"detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing "
|
||
"successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."),
|
||
],
|
||
outputs=[
|
||
FaceLandmarksType.Output(display_name="face_landmarks"),
|
||
io.BoundingBox.Output("bboxes"),
|
||
],
|
||
)
|
||
|
||
@classmethod
|
||
def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
|
||
missing_frame_fallback) -> io.NodeOutput:
|
||
canonical = face_detection_model.canonical_data
|
||
img_np = _image_to_uint8(image)
|
||
B, H, W = img_np.shape[:3]
|
||
chunk = 16
|
||
is_both = detector_variant == "both"
|
||
total_work = 2 * B if is_both else B
|
||
pbar = comfy.utils.ProgressBar(total_work)
|
||
|
||
def _run(variant: str) -> list[list[dict]]:
|
||
res: list[list[dict]] = []
|
||
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
|
||
for i in range(0, B, chunk):
|
||
end = min(i + chunk, B)
|
||
res.extend(face_detection_model.detect_batch(
|
||
[img_np[bi] for bi in range(i, end)],
|
||
num_faces=int(num_faces),
|
||
score_thresh=float(min_confidence),
|
||
variant=variant,
|
||
))
|
||
pbar.update_absolute(min(pbar.current + (end - i), total_work))
|
||
tq.update(end - i)
|
||
return res
|
||
|
||
if is_both:
|
||
short_res = _run("short")
|
||
full_res = _run("full")
|
||
# Per-frame keep whichever found more faces (tie → short).
|
||
frames: list[list[dict]] = [
|
||
short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi]
|
||
for bi in range(B)
|
||
]
|
||
else:
|
||
frames = _run(detector_variant)
|
||
_fill_missing_frames(frames, missing_frame_fallback)
|
||
bboxes = []
|
||
for per_frame in frames:
|
||
per_bb = []
|
||
for f in per_frame:
|
||
f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical)
|
||
x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"])
|
||
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
|
||
bboxes.append(per_bb)
|
||
return io.NodeOutput({"frames": frames, "image_size": (H, W),
|
||
"connection_sets": face_detection_model.connection_sets}, bboxes)
|
||
|
||
|
||
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
|
||
_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose")
|
||
_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
|
||
("face_oval", True),
|
||
("lips", True),
|
||
("left_eye", True),
|
||
("right_eye", True),
|
||
("left_eyebrow", True),
|
||
("right_eyebrow", True),
|
||
("irises", True),
|
||
("nose", True),
|
||
("tesselation", False),
|
||
)
|
||
|
||
|
||
class MediaPipeFaceMeshVisualize(io.ComfyNode):
|
||
"""Draw a FACEMESH_* subset over an image. Topology travels with the
|
||
FACE_LANDMARKS payload (set at detection time)."""
|
||
|
||
@classmethod
|
||
def define_schema(cls):
|
||
return io.Schema(
|
||
node_id="MediaPipeFaceMeshVisualize",
|
||
search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
|
||
display_name="Visualize Face Landmarks (MediaPipe)",
|
||
category="image/detection",
|
||
description="Draws face landmarks mesh on the input image.",
|
||
inputs=[
|
||
FaceLandmarksType.Input("face_landmarks"),
|
||
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
|
||
io.DynamicCombo.Input(
|
||
"connections",
|
||
tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).",
|
||
options=[
|
||
io.DynamicCombo.Option("all", []),
|
||
io.DynamicCombo.Option("fill", []),
|
||
io.DynamicCombo.Option("custom", [
|
||
io.Boolean.Input(feat, default=default,
|
||
tooltip=f"Draw the '{feat}' connection set.")
|
||
for feat, default in _CUSTOM_FEATURES
|
||
]),
|
||
],
|
||
),
|
||
io.Color.Input("color", default="#00ff00"),
|
||
io.Int.Input("thickness", default=1, min=0, max=8, step=1,
|
||
tooltip="Edge line thickness in pixels. 0 disables edge drawing."),
|
||
io.Int.Input("point_size", default=2, min=0, max=16, step=1,
|
||
tooltip="Landmark dot radius in pixels. 0 disables point drawing."),
|
||
],
|
||
outputs=[io.Image.Output()],
|
||
)
|
||
|
||
@classmethod
|
||
def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput:
|
||
sets = face_landmarks["connection_sets"]
|
||
sel = connections["connections"]
|
||
fill_rings: list[list[int]] | None = None
|
||
if sel == "fill":
|
||
fill_rings = _ordered_rings(sets["face_oval"])
|
||
edges = frozenset()
|
||
elif sel == "custom":
|
||
parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)]
|
||
edges = frozenset().union(*(sets[p] for p in parts))
|
||
else: # "all"
|
||
edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS))
|
||
rgb, thick, psize = _parse_color(color), int(thickness), int(point_size)
|
||
frames = face_landmarks["frames"]
|
||
if image is None:
|
||
H, W = face_landmarks["image_size"]
|
||
img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8)
|
||
else:
|
||
img_np = _image_to_uint8(image)
|
||
B = img_np.shape[0]
|
||
n_frames = len(frames)
|
||
pbar = comfy.utils.ProgressBar(B)
|
||
out = np.empty_like(img_np)
|
||
for bi in range(B):
|
||
faces = frames[bi] if bi < n_frames else []
|
||
out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings)
|
||
pbar.update_absolute(bi + 1)
|
||
return io.NodeOutput(torch.from_numpy(out).to(
|
||
device=comfy.model_management.intermediate_device(),
|
||
dtype=comfy.model_management.intermediate_dtype(),
|
||
).div_(255.0))
|
||
|
||
|
||
def _draw_mesh(image_rgb: np.ndarray, faces: list, edges,
|
||
rgb: tuple[int, int, int], thickness: int,
|
||
point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray:
|
||
draw_edges = thickness > 0 and edges
|
||
if not faces or (fill_rings is None and not draw_edges and point_size <= 0):
|
||
return image_rgb.copy()
|
||
pil = Image.fromarray(image_rgb)
|
||
draw = ImageDraw.Draw(pil)
|
||
r = point_size * 0.5
|
||
if fill_rings is not None:
|
||
for f in faces:
|
||
lmks = f["landmarks_xy"]
|
||
for ring in fill_rings:
|
||
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb)
|
||
return np.asarray(pil)
|
||
for f in faces:
|
||
lmks = f["landmarks_xy"]
|
||
n = lmks.shape[0]
|
||
if draw_edges:
|
||
for a, b in edges:
|
||
if a < n and b < n:
|
||
draw.line([(float(lmks[a, 0]), float(lmks[a, 1])),
|
||
(float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness)
|
||
if point_size == 1:
|
||
draw.point(lmks.flatten().tolist(), fill=rgb)
|
||
elif point_size > 1:
|
||
for x, y in lmks:
|
||
draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb)
|
||
return np.asarray(pil)
|
||
|
||
|
||
# Mask region presets — closed-loop topologies only.
|
||
_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises")
|
||
_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = (
|
||
("face_oval", True),
|
||
("lips", False),
|
||
("left_eye", False),
|
||
("right_eye", False),
|
||
("irises", False),
|
||
)
|
||
|
||
|
||
class MediaPipeFaceMask(io.ComfyNode):
|
||
"""Binary mask from face landmarks, filled polygon per face. One mask per
|
||
frame in the batch; faces in the same frame composite (union)."""
|
||
|
||
@classmethod
|
||
def define_schema(cls):
|
||
return io.Schema(
|
||
node_id="MediaPipeFaceMask",
|
||
search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
|
||
display_name="Draw Face Mask (MediaPipe)",
|
||
category="image/detection",
|
||
description="Draws a mask from face landmarks.",
|
||
inputs=[
|
||
FaceLandmarksType.Input("face_landmarks"),
|
||
io.DynamicCombo.Input(
|
||
"regions",
|
||
tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.",
|
||
options=[
|
||
io.DynamicCombo.Option("all", []),
|
||
io.DynamicCombo.Option("custom", [
|
||
io.Boolean.Input(reg, default=default,
|
||
tooltip=f"Include the '{reg}' region in the mask.")
|
||
for reg, default in _MASK_CUSTOM_FEATURES
|
||
]),
|
||
],
|
||
),
|
||
],
|
||
outputs=[io.Mask.Output()],
|
||
)
|
||
|
||
@classmethod
|
||
def execute(cls, face_landmarks, regions) -> io.NodeOutput:
|
||
sets = face_landmarks["connection_sets"]
|
||
sel = regions["regions"]
|
||
if sel == "custom":
|
||
picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)]
|
||
else:
|
||
picked = list(_MASK_REGIONS)
|
||
rings = [r for reg in picked for r in _ordered_rings(sets[reg])]
|
||
frames = face_landmarks["frames"]
|
||
H, W = face_landmarks["image_size"]
|
||
masks = np.zeros((len(frames), H, W), dtype=np.uint8)
|
||
pbar = comfy.utils.ProgressBar(len(frames))
|
||
for bi, per_frame in enumerate(frames):
|
||
if per_frame:
|
||
pil = Image.new("L", (W, H), 0)
|
||
draw = ImageDraw.Draw(pil)
|
||
for f in per_frame:
|
||
lmks = f["landmarks_xy"]
|
||
for ring in rings:
|
||
draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255)
|
||
masks[bi] = np.asarray(pil)
|
||
pbar.update_absolute(bi + 1)
|
||
return io.NodeOutput(torch.from_numpy(masks).to(
|
||
device=comfy.model_management.intermediate_device(),
|
||
dtype=comfy.model_management.intermediate_dtype(),
|
||
).div_(255.0))
|
||
|
||
|
||
class MediaPipeFaceExtension(ComfyExtension):
|
||
@override
|
||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||
return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask]
|
||
|
||
|
||
async def comfy_entrypoint() -> MediaPipeFaceExtension:
|
||
return MediaPipeFaceExtension()
|