Cleanup, add MoGeGeometryToFOV

This commit is contained in:
kijai 2026-06-01 01:05:40 +03:00
parent ebd9c6e620
commit 59dc7ac152
5 changed files with 61 additions and 66 deletions

View File

@ -1,6 +1,8 @@
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration.""" """ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
import math
import torch import torch
import comfy.utils import comfy.utils
@ -403,10 +405,57 @@ class MoGePointMapToMesh(io.ComfyNode):
return io.NodeOutput(mesh) return io.NodeOutput(mesh)
class MoGeGeometryToFOV(io.ComfyNode):
"""Extract horizontal/vertical FOV from MoGe intrinsics, e.g. fov_y to feed SAM3DBody_Predict."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeGeometryToFOV",
search_aliases=["moge", "fov", "geometry", "intrinsics", "field of view"],
display_name="Get FoV from MoGe Geometry",
description="Derive the field of view and focal length from MoGe intrinsics.",
category="image/geometry estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Combo.Input("axis", options=["vertical", "horizontal", "diagonal"], default="vertical",
tooltip="'vertical' (fov_y), 'horizontal' (fov_x), or 'diagonal'."),
io.Combo.Input("unit", options=["degrees", "radians"], default="degrees",
tooltip="Output unit for the FOV."),
],
outputs=[
io.Float.Output(display_name="fov"),
io.Float.Output(display_name="focal_pixels"),
],
)
@classmethod
def execute(cls, moge_geometry, axis, unit) -> io.NodeOutput:
K = moge_geometry.get("intrinsics") if isinstance(moge_geometry, dict) else None
if K is None:
raise ValueError("moge_geometry has no intrinsics (panorama geometry has none).")
if K.ndim == 3:
K = K[0]
# MoGe normalizes fx by width and fy by height; with cx=cy=0.5 the half-extent
# in normalized units is 0.5, so fov = 2*atan(0.5 / f) per axis (hypot for diagonal).
hx = 0.5 / float(K[0, 0].item())
hy = 0.5 / float(K[1, 1].item())
half_tan = {"horizontal": hx, "vertical": hy, "diagonal": math.hypot(hx, hy)}[axis]
fov_radians = 2.0 * math.atan(half_tan)
fov = fov_radians if unit == "radians" else math.degrees(fov_radians)
# Pixels are square here, so fy*H == fx*W is the single lens focal in pixels.
src = next((moge_geometry[k] for k in ("image", "points", "depth") if k in moge_geometry), None)
if src is None:
raise ValueError("moge_geometry has no image/points/depth to read the pixel height from.")
H = int(src.shape[1])
focal_pixels = float(K[1, 1].item()) * H
return io.NodeOutput(fov, focal_pixels)
class MoGeExtension(ComfyExtension): class MoGeExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh] return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV]
async def comfy_entrypoint() -> MoGeExtension: async def comfy_entrypoint() -> MoGeExtension:

View File

@ -19,7 +19,6 @@ from comfy.ldm.sam3d_body.model.dinov3 import apply_dinov3_qkv_bias_mask
from comfy_extras.sam3d_body.utils import ( from comfy_extras.sam3d_body.utils import (
apply_camera_override, apply_camera_override,
cam_int_from_fov, cam_int_from_fov,
cam_int_from_moge,
inputs_from_sam3_track, inputs_from_sam3_track,
run_batched_frames, run_batched_frames,
run_batched_single_chunk, run_batched_single_chunk,
@ -42,7 +41,6 @@ SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# KIMODO_POSE_DATA via a MultiType union — those types are mirrored there. # KIMODO_POSE_DATA via a MultiType union — those types are mirrored there.
MHRPoseData = io.Custom("MHR_POSE_DATA") MHRPoseData = io.Custom("MHR_POSE_DATA")
SAM3DBodyModel = io.Custom("SAM3D_BODY_MODEL") SAM3DBodyModel = io.Custom("SAM3D_BODY_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
# Loader # Loader
@ -153,18 +151,10 @@ class SAM3DBody_Predict(io.ComfyNode):
io.Float.Input( io.Float.Input(
"fov_degrees", "fov_degrees",
default=0.0, min=0.0, max=170.0, step=0.5, default=0.0, min=0.0, max=170.0, step=0.5,
tooltip=( #TODO: get FoV from moge another way?
"Vertical FOV in degrees. Affects predicted depth (cam_t.z) and "
"absolute scale. 0 = use moge_geometry or fall back to ~53° (16:9). "
"Any non-zero value overrides moge_geometry."
),
),
MoGeGeometry.Input(
"moge_geometry",
optional=True,
tooltip=( tooltip=(
"MoGe geometry, used to calculate camera field of view." "Vertical FOV in degrees. Affects predicted depth (cam_t.z) and "
"For batches choose the most representative frame, or leave unset" "absolute scale. 0 = fall back to ~53° (16:9). Feed MoGeGeometryToFOV "
"here to derive it from a MoGe estimate."
), ),
), ),
io.Int.Input( io.Int.Input(
@ -180,7 +170,7 @@ class SAM3DBody_Predict(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, sam3d_body_model, image, sam3_track_data=None, bboxes=None, run_hand_refinement=True, fov_degrees=0.0, moge_geometry=None, chunk_size=64) -> io.NodeOutput: def execute(cls, sam3d_body_model, image, sam3_track_data=None, bboxes=None, run_hand_refinement=True, fov_degrees=0.0, chunk_size=64) -> io.NodeOutput:
comfy.model_management.load_model_gpu(sam3d_body_model) comfy.model_management.load_model_gpu(sam3d_body_model)
inner: SAM3DBody = sam3d_body_model.model inner: SAM3DBody = sam3d_body_model.model
@ -200,10 +190,8 @@ class SAM3DBody_Predict(io.ComfyNode):
per_frame_bboxes = [full_frame_bbox.clone() for _ in range(B)] per_frame_bboxes = [full_frame_bbox.clone() for _ in range(B)]
per_frame_masks = None per_frame_masks = None
inference_type = "full" if run_hand_refinement else "body" inference_type = "full" if run_hand_refinement else "body"
# Precedence: explicit fov_degrees > MoGe estimate > diagonal default. # fov_degrees > 0 sets intrinsics; else None falls back to prepare_batch's diagonal default.
cam_int = cam_int_from_fov(int(H), int(W), float(fov_degrees)) cam_int = cam_int_from_fov(int(H), int(W), float(fov_degrees))
if cam_int is None:
cam_int = cam_int_from_moge(moge_geometry, int(H), int(W))
frames_rgb: List[Optional[torch.Tensor]] = [] frames_rgb: List[Optional[torch.Tensor]] = []
for f in range(B): for f in range(B):

View File

@ -480,14 +480,6 @@ class BuildPoseGLB(IO.ComfyNode):
), ),
), ),
]), ]),
IO.DynamicCombo.Option("sticks", [
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip="Per-bone vertex colors (see octahedrons).",
),
]),
], ],
tooltip=("Bone vis shape, rigidly skinned to each joint. "), tooltip=("Bone vis shape, rigidly skinned to each joint. "),
), ),
@ -546,19 +538,11 @@ class BuildPoseGLB(IO.ComfyNode):
), ),
), ),
]), ]),
IO.DynamicCombo.Option("sticks", [
IO.Combo.Input(
"bone_vis_color",
options=["white", "rainbow_y"],
default="rainbow_y",
tooltip="Per-bone vertex colors (see octahedrons).",
),
]),
], ],
tooltip=( tooltip=(
"Bone vis shape, rigidly skinned to each joint. " "Bone vis shape, rigidly skinned to each joint. "
"'octahedrons' = Blender-style directional bones (joint → " "'octahedrons' = Blender-style directional bones (joint → "
"primary child); 'sticks' = thin lines." "primary child)."
), ),
), ),
]), ]),

View File

@ -8,7 +8,7 @@ Rebuilds an Armature with the MHR 127-bone rig:
- facial expression is re-exposed as 72 morph targets driven by `expr_params` - facial expression is re-exposed as 72 morph targets driven by `expr_params`
so face animation survives plain glTF skinning. so face animation survives plain glTF skinning.
Optional bone visualization (octahedrons / sticks) is rigidly Optional bone visualization (octahedrons) is rigidly
skinned alongside the body mesh used to preview the armature in glTF skinned alongside the body mesh used to preview the armature in glTF
viewers that don't draw bones. viewers that don't draw bones.
@ -323,7 +323,7 @@ def build_glb_skeletal(
skin_idx = len(skins) - 1 skin_idx = len(skins) - 1
include_body = bool(include_body_mesh) include_body = bool(include_body_mesh)
include_bones = bone_vis in ("octahedrons", "sticks") include_bones = bone_vis == "octahedrons"
body_mesh_node_idx: Optional[int] = None body_mesh_node_idx: Optional[int] = None
if include_body: if include_body:
@ -386,13 +386,10 @@ def build_glb_skeletal(
if include_bones: if include_bones:
bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color) bone_palette = _bone_colors_rgb(bind_global_m[:, :3], bone_vis_color)
# Indexes `bone_palette`: octahedrons/sticks use the bone's child # Indexes `bone_palette`: octahedrons use the bone's child joint so
# joint so every bone has its own color regardless of skin target. # every bone has its own color regardless of skin target.
# 'sticks' = thin octahedrons. glTF LINES skinning is unreliable
# (Three.js' GLTFLoader doesn't animate skinned line primitives),
# so we render triangle tubes instead.
color_idx_per_vert: Optional[np.ndarray] = None color_idx_per_vert: Optional[np.ndarray] = None
hw = float(bone_vis_radius_m) if bone_vis == "octahedrons" else 0.0035 hw = float(bone_vis_radius_m)
bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh( bv_v, bv_n, bv_f, bv_j, bv_w, child_per_vert = _build_bone_octahedrons_mesh(
bind_global_m[:, :3], rig_static["parents"], half_width_m=hw, bind_global_m[:, :3], rig_static["parents"], half_width_m=hw,
) )

View File

@ -82,29 +82,6 @@ def cam_int_from_fov(height: int, width: int, fov_degrees: float) -> Optional[to
) )
def cam_int_from_moge(moge_geometry, height: int, width: int) -> Optional[torch.Tensor]:
"""(1,3,3) intrinsic matrix from a MoGe geometry payload. Uses MoGe's
vertical focal for both axes; forces principal point to image center
(overrides MoGe's predicted cx/cy to match prepare_batch's convention)."""
if moge_geometry is None:
return None
# MOGE_GEOMETRY is a dict with optional keys (see comfy_extras/nodes_moge.py).
K_norm = moge_geometry.get("intrinsics") if isinstance(moge_geometry, dict) else None
if K_norm is None:
return None
if K_norm.ndim == 3:
K_norm = K_norm[0]
# MoGe stores fy in height-units (multiply by H to get pixels); vfov = fy.
fy_norm = float(K_norm[1, 1].item())
focal = fy_norm * height
return torch.tensor(
[[[focal, 0.0, width / 2.0],
[0.0, focal, height / 2.0],
[0.0, 0.0, 1.0]]],
dtype=torch.float32,
)
def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any], def apply_camera_override(mhr_pose_data: Dict[str, Any], camera_info: Dict[str, Any],
H: int, W: int, fov_deg: float = 0.0) -> Dict[str, Any]: H: int, W: int, fov_deg: float = 0.0) -> Dict[str, Any]:
"""Re-project every frame's pose through a Load3D 6DOF camera (position/ """Re-project every frame's pose through a Load3D 6DOF camera (position/