Update MediaPipe nodes to standardize with existing code base (CORE-242) (#14025)

This commit is contained in:
Alexis Rolland 2026-05-21 14:39:30 +08:00 committed by GitHub
parent 1668aaf037
commit 7b7c5fed7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 15 deletions

View File

@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker
from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection
FaceLandmarkerType = io.Custom("FACE_LANDMARKER") FaceDetectionType = io.Custom("FACE_DETECTION_MODEL")
FaceLandmarksType = io.Custom("FACE_LANDMARKS") FaceLandmarksType = io.Custom("FACE_LANDMARKS")
_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights") _CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights")
@ -204,18 +204,19 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="LoadMediaPipeFaceLandmarker", node_id="LoadMediaPipeFaceLandmarker",
display_name="Load MediaPipe Face Landmarker", search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Load Face Detection Model (MediaPipe)",
category="loaders", category="loaders",
inputs=[ inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"), io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"),
tooltip="Face Landmarker safetensors from models/mediapipe/."), tooltip="Face detection model from models/detection/."),
], ],
outputs=[FaceLandmarkerType.Output()], outputs=[FaceDetectionType.Output()],
) )
@classmethod @classmethod
def execute(cls, model_name) -> io.NodeOutput: def execute(cls, model_name) -> io.NodeOutput:
sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True) sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True)
wrapper = FaceLandmarkerModel(sd) wrapper = FaceLandmarkerModel(sd)
return io.NodeOutput(wrapper) return io.NodeOutput(wrapper)
@ -234,10 +235,12 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="MediaPipeFaceLandmarker", node_id="MediaPipeFaceLandmarker",
display_name="MediaPipe Face Landmarker", search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"],
display_name="Detect Face Landmarks (MediaPipe)",
category="image/detection", category="image/detection",
description="Detects facial landmarks using MediaPipe model.",
inputs=[ inputs=[
FaceLandmarkerType.Input("face_landmarker"), FaceDetectionType.Input("face_detection_model"),
io.Image.Input("image"), io.Image.Input("image"),
io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short", io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short",
tooltip="Face detector range. 'short' is tuned for close-up faces " tooltip="Face detector range. 'short' is tuned for close-up faces "
@ -261,9 +264,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence, def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence,
missing_frame_fallback) -> io.NodeOutput: missing_frame_fallback) -> io.NodeOutput:
canonical = face_landmarker.canonical_data canonical = face_detection_model.canonical_data
img_np = _image_to_uint8(image) img_np = _image_to_uint8(image)
B, H, W = img_np.shape[:3] B, H, W = img_np.shape[:3]
chunk = 16 chunk = 16
@ -276,7 +279,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq: with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq:
for i in range(0, B, chunk): for i in range(0, B, chunk):
end = min(i + chunk, B) end = min(i + chunk, B)
res.extend(face_landmarker.detect_batch( res.extend(face_detection_model.detect_batch(
[img_np[bi] for bi in range(i, end)], [img_np[bi] for bi in range(i, end)],
num_faces=int(num_faces), num_faces=int(num_faces),
score_thresh=float(min_confidence), score_thresh=float(min_confidence),
@ -306,7 +309,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode):
per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])}) per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])})
bboxes.append(per_bb) bboxes.append(per_bb)
return io.NodeOutput({"frames": frames, "image_size": (H, W), return io.NodeOutput({"frames": frames, "image_size": (H, W),
"connection_sets": face_landmarker.connection_sets}, bboxes) "connection_sets": face_detection_model.connection_sets}, bboxes)
# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose). # Topology keys unioned by the 'all' connections preset (contour parts + irises + nose).
@ -332,8 +335,10 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="MediaPipeFaceMeshVisualize", node_id="MediaPipeFaceMeshVisualize",
display_name="MediaPipe Face Mesh Visualize", search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"],
display_name="Visualize Face Landmarks (MediaPipe)",
category="image/detection", category="image/detection",
description="Draws face landmarks mesh on the input image.",
inputs=[ inputs=[
FaceLandmarksType.Input("face_landmarks"), FaceLandmarksType.Input("face_landmarks"),
io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."), io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."),
@ -443,8 +448,10 @@ class MediaPipeFaceMask(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="MediaPipeFaceMask", node_id="MediaPipeFaceMask",
display_name="MediaPipe Face Mask", search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"],
display_name="Draw Face Mask (MediaPipe)",
category="image/detection", category="image/detection",
description="Draws a mask from face landmarks.",
inputs=[ inputs=[
FaceLandmarksType.Input("face_landmarks"), FaceLandmarksType.Input("face_landmarks"),
io.DynamicCombo.Input( io.DynamicCombo.Input(

View File

@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions) folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")