# ComfyUI/comfy/moge.py
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
"""
from __future__ import annotations
from numbers import Number
from typing import Dict, Optional, Union
import torch
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
from .ldm.moge.geometry import (
depth_map_to_point_map,
intrinsics_from_focal_center,
recover_focal_shift,
)
from .ldm.moge.model import detect_and_build
from .ldm.moge.state_dict import remap_moge_state_dict
class MoGeModel:
    """Loaded MoGe model + ComfyUI memory management.

    Owns the bare ``nn.Module`` plus a ``CoreModelPatcher`` so ComfyUI can
    move the weights between the load and offload devices on demand.
    """
    def __init__(self, state_dict: dict):
        # Reuse the text-encoder placement policy (device / offload / dtype);
        # MoGe has comparable residency requirements.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        # Translate upstream MoGe checkpoint key names into this repo's layout.
        sd = remap_moge_state_dict(state_dict)
        # detect_and_build inspects the remapped state dict and constructs the
        # matching architecture, with manual_cast ops for on-the-fly dtype casting.
        self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
                                      operations=comfy.ops.manual_cast)
        self.model.load_state_dict(sd, strict=True)
        self.model.eval()
        # The patcher is what ComfyUI's memory manager loads/offloads.
        self.patcher = comfy.model_patcher.CoreModelPatcher(
            self.model, load_device=self.load_device, offload_device=offload_device
        )
        # NOTE(review): version detection assumes only v2 builds expose an
        # `encoder` attribute — confirm against detect_and_build's model classes.
        self.version = "v2" if hasattr(self.model, "encoder") else "v1"
        # Probability threshold used to binarize the predicted validity mask.
        self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
        # Token-budget range; fallback defaults differ per version when the
        # model does not declare its own `num_tokens_range`.
        nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
        self.num_tokens_range = (int(nt[0]), int(nt[1]))
    @torch.inference_mode()
    def infer(self,
              image: torch.Tensor,
              num_tokens: Optional[int] = None,
              resolution_level: int = 9,
              fov_x: Optional[Union[Number, torch.Tensor]] = None,
              force_projection: bool = True,
              apply_mask: bool = True) -> Dict[str, torch.Tensor]:
        """Run a single MoGe forward + post-process pass.

        ``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device
        (a 3-D ``(3, H, W)`` tensor is promoted to a batch of one).
        Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
        ``mask``; v2 checkpoints additionally produce ``normal``.

        Args:
            num_tokens: Explicit ViT token budget; when ``None`` it is
                interpolated from ``resolution_level`` over ``num_tokens_range``.
            resolution_level: 0-9 slider mapped linearly onto the token range
                (only consulted when ``num_tokens`` is ``None``).
            fov_x: Known horizontal field of view in degrees (scalar or
                per-batch tensor); when given, focal estimation is skipped and
                only the depth shift is recovered.
            force_projection: Re-derive ``points`` from the recovered depth and
                intrinsics so the point map is an exact pinhole projection.
            apply_mask: Replace invalid pixels with ``inf`` (points/depth) or
                zeros (normal). The returned ``mask`` is the binarized mask.
        """
        # Ensure the weights are resident on the compute device.
        comfy.model_management.load_model_gpu(self.patcher)
        device = self.load_device
        image = image.to(device=device, dtype=self.dtype)
        if image.dim() == 3:
            image = image.unsqueeze(0)
        H, W = image.shape[-2:]
        aspect_ratio = W / H
        if num_tokens is None:
            # Linear interpolation: level 0 -> lo tokens, level 9 -> hi tokens.
            lo, hi = self.num_tokens_range
            num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
        out = self.model.forward(image, num_tokens=num_tokens)
        points = out.get("points")
        normal = out.get("normal")
        mask = out.get("mask")
        metric_scale = out.get("metric_scale")
        # Post-processing always runs in fp32 for numerical stability.
        if points is not None: points = points.float()
        if normal is not None: normal = normal.float()
        if mask is not None: mask = mask.float()
        if metric_scale is not None: metric_scale = metric_scale.float()
        mask_binary = (mask > self.mask_threshold) if mask is not None else None
        depth = None
        intrinsics = None
        if points is not None:
            if fov_x is None:
                focal, shift = recover_focal_shift(points, mask_binary)
                # The unconstrained least-squares solver inside recover_focal_shift
                # can converge to a degenerate solution where (z + shift) is
                # negative for most pixels, which flips the sign of the
                # estimated focal. Detect that and fall back to a sensible
                # 60-degree-FoV default rather than emitting garbage geometry.
                bad = ~torch.isfinite(focal) | (focal <= 0)
                if bool(bad.any()):
                    default_fov = 60.0
                    fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
                    fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
                        / torch.tan(torch.deg2rad(fov_t / 2))
                    fallback_focal = fallback_focal.expand_as(focal).clone()
                    focal = torch.where(bad, fallback_focal, focal)
                    # Re-solve for the shift only, with the repaired focal fixed.
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            else:
                # Caller supplied the horizontal FoV: derive the normalized
                # focal directly and only recover the depth shift.
                fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
                focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
                if focal.ndim == 0:
                    # Scalar FoV -> broadcast one focal per batch element.
                    focal = focal[None].expand(points.shape[0])
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            # Convert the normalized focal into per-axis pinhole intrinsics
            # with the principal point at the (normalized) image center.
            fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
            fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
            half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
            intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
            # Apply the recovered depth shift along z.
            points[..., 2] = points[..., 2] + shift[..., None, None]
            # v2 upstream additionally filters mask by depth > 0 as a safeguard
            # against negative-depth artifacts from the metric-scale path; v1
            # does not, and applying it there can cut out the foreground when
            # shift recovery overshoots slightly.
            if mask_binary is not None and self.version == "v2":
                mask_binary = mask_binary & (points[..., 2] > 0)
            depth = points[..., 2].clone()
        if force_projection and depth is not None and intrinsics is not None:
            # Snap the point map onto the exact pinhole projection of `depth`.
            points = depth_map_to_point_map(depth, intrinsics=intrinsics)
        if metric_scale is not None:
            # Metric checkpoints predict one scale per image; broadcast it over
            # the spatial (and, for points, channel) dimensions.
            if points is not None:
                points = points * metric_scale[:, None, None, None]
            if depth is not None:
                depth = depth * metric_scale[:, None, None]
        if apply_mask and mask_binary is not None:
            # Invalid pixels: inf for geometry, zero for normals.
            if points is not None:
                points = torch.where(mask_binary[..., None], points,
                                     torch.full_like(points, float("inf")))
            if depth is not None:
                depth = torch.where(mask_binary, depth,
                                    torch.full_like(depth, float("inf")))
            if normal is not None:
                normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
        result = {
            "points": points,
            "depth": depth,
            "intrinsics": intrinsics,
            "mask": mask_binary,
            "normal": normal,
        }
        # Drop None entries so key presence signals availability to callers.
        return {k: v for k, v in result.items() if v is not None}
def load(ckpt_path: str) -> MoGeModel:
    """Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
    state_dict = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    # Training-style checkpoints nest the weights under "model" alongside a
    # "model_config" entry; unwrap to the bare state dict in that case.
    if isinstance(state_dict, dict) and {"model", "model_config"} <= state_dict.keys():
        state_dict = state_dict["model"]
    return MoGeModel(state_dict)