# mirror of https://github.com/comfyanonymous/ComfyUI.git
# synced 2026-05-14 19:17:32 +08:00
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
|
|
|
|
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
|
|
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
|
|
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
|
|
"""
|
|
|
|
from __future__ import annotations

from numbers import Number
from typing import Dict, Optional, Union

import torch

import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils

from .ldm.moge.geometry import (
    depth_map_to_point_map,
    intrinsics_from_focal_center,
    recover_focal_shift,
)
from .ldm.moge.model import detect_and_build
from .ldm.moge.state_dict import remap_moge_state_dict


class MoGeModel:
    """Loaded MoGe model + ComfyUI memory management.

    Owns the underlying ``nn.Module`` built from a remapped checkpoint state
    dict, plus a :class:`comfy.model_patcher.CoreModelPatcher` so ComfyUI can
    shuttle the weights between the load and offload devices on demand.
    """

    def __init__(self, state_dict: dict):
        # Device/dtype policy is borrowed from ComfyUI's text-encoder path:
        # run on the text-encoder device with the text-encoder dtype.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)

        # Remap upstream checkpoint keys to this package's module layout, then
        # let detect_and_build pick the architecture from the keys themselves.
        sd = remap_moge_state_dict(state_dict)
        self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
                                      operations=comfy.ops.manual_cast)
        self.model.load_state_dict(sd, strict=True)
        self.model.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(
            self.model, load_device=self.load_device, offload_device=offload_device
        )
        # Duck-typed version probe: v2-built models expose an "encoder"
        # submodule, v1 models do not. NOTE(review): this must stay in sync
        # with detect_and_build — confirm against that helper.
        self.version = "v2" if hasattr(self.model, "encoder") else "v1"
        # Probability threshold used to binarize the predicted validity mask.
        self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
        # (min, max) token budget used to translate resolution_level into
        # num_tokens in infer(); the v2 default allows a larger maximum.
        nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
        self.num_tokens_range = (int(nt[0]), int(nt[1]))

    @torch.inference_mode()
    def infer(self,
              image: torch.Tensor,
              num_tokens: Optional[int] = None,
              resolution_level: int = 9,
              fov_x: Optional[Union[Number, torch.Tensor]] = None,
              force_projection: bool = True,
              apply_mask: bool = True) -> Dict[str, torch.Tensor]:
        """Run a single MoGe forward + post-process pass.

        ``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device.
        Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
        ``mask``; v2 checkpoints additionally produce ``normal``.

        Args:
            image: Input batch; a 3-D ``(3, H, W)`` tensor is also accepted
                and treated as a batch of one.
            num_tokens: Explicit ViT token budget; when ``None`` it is derived
                from ``resolution_level``.
            resolution_level: 0..9 quality slider, mapped linearly onto
                ``num_tokens_range`` (9 = maximum tokens).
            fov_x: Known horizontal field of view in degrees (scalar or
                per-batch tensor); when ``None`` the focal is recovered from
                the predicted point map.
            force_projection: Re-derive points from depth + recovered
                intrinsics so the outputs form a consistent pinhole model.
            apply_mask: Replace invalid points/depth with ``inf`` and zero
                out invalid normals.
        """
        comfy.model_management.load_model_gpu(self.patcher)
        device = self.load_device
        image = image.to(device=device, dtype=self.dtype)

        if image.dim() == 3:
            image = image.unsqueeze(0)
        H, W = image.shape[-2:]
        aspect_ratio = W / H

        if num_tokens is None:
            # Linear interpolation of the 0..9 slider onto the model's
            # supported token range.
            lo, hi = self.num_tokens_range
            num_tokens = int(lo + (resolution_level / 9) * (hi - lo))

        out = self.model.forward(image, num_tokens=num_tokens)
        points = out.get("points")
        normal = out.get("normal")
        mask = out.get("mask")
        metric_scale = out.get("metric_scale")

        # Post-processing always runs in fp32 for numerical stability.
        if points is not None: points = points.float()
        if normal is not None: normal = normal.float()
        if mask is not None: mask = mask.float()
        if metric_scale is not None: metric_scale = metric_scale.float()

        mask_binary = (mask > self.mask_threshold) if mask is not None else None

        depth = None
        intrinsics = None
        if points is not None:
            if fov_x is None:
                # No FoV given: recover both the focal and the per-image depth
                # shift from the predicted point map.
                focal, shift = recover_focal_shift(points, mask_binary)
                # The unconstrained least-squares solver inside recover_focal_shift
                # can converge to a degenerate solution where (z + shift) is
                # negative for most pixels, which flips the sign of the
                # estimated focal. Detect that and fall back to a sensible
                # 60-degree-FoV default rather than emitting garbage geometry.
                bad = ~torch.isfinite(focal) | (focal <= 0)
                if bool(bad.any()):
                    default_fov = 60.0
                    fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
                    fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
                        / torch.tan(torch.deg2rad(fov_t / 2))
                    fallback_focal = fallback_focal.expand_as(focal).clone()
                    focal = torch.where(bad, fallback_focal, focal)
                    # Re-solve for the shift only, with the focal now pinned.
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            else:
                # Known FoV: convert it to a (presumably normalized) focal
                # via the same aspect-ratio formula as the fallback above,
                # then solve only for the depth shift.
                fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
                focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
                if focal.ndim == 0:
                    # Scalar FoV: broadcast one focal per batch element.
                    focal = focal[None].expand(points.shape[0])
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            # Split the recovered focal into per-axis values; principal point
            # is the image center (0.5, 0.5 — normalized coordinates,
            # judging by intrinsics_from_focal_center's usage here).
            fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
            fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
            half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
            intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
            points[..., 2] = points[..., 2] + shift[..., None, None]
            # v2 upstream additionally filters mask by depth > 0 as a safeguard
            # against negative-depth artifacts from the metric-scale path; v1
            # does not, and applying it there can cut out the foreground when
            # shift recovery overshoots slightly.
            if mask_binary is not None and self.version == "v2":
                mask_binary = mask_binary & (points[..., 2] > 0)
            depth = points[..., 2].clone()

        if force_projection and depth is not None and intrinsics is not None:
            # Re-project depth through the recovered intrinsics so points,
            # depth, and intrinsics are mutually consistent.
            points = depth_map_to_point_map(depth, intrinsics=intrinsics)

        if metric_scale is not None:
            # Metric checkpoints predict a per-image scale; apply it to both
            # the point map and the depth map.
            if points is not None:
                points = points * metric_scale[:, None, None, None]
            if depth is not None:
                depth = depth * metric_scale[:, None, None]

        if apply_mask and mask_binary is not None:
            # Invalid pixels become inf (points/depth) or zero (normals) so
            # downstream consumers can detect and skip them.
            if points is not None:
                points = torch.where(mask_binary[..., None], points,
                                     torch.full_like(points, float("inf")))
            if depth is not None:
                depth = torch.where(mask_binary, depth,
                                    torch.full_like(depth, float("inf")))
            if normal is not None:
                normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))

        result = {
            "points": points,
            "depth": depth,
            "intrinsics": intrinsics,
            "mask": mask_binary,
            "normal": normal,
        }
        # Drop keys this checkpoint did not produce (e.g. "normal" on v1).
        return {k: v for k, v in result.items() if v is not None}


def load(ckpt_path: str) -> MoGeModel:
    """Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
    state_dict = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    # Training-style checkpoints wrap the weights: unwrap when both the
    # "model" and "model_config" keys are present.
    wrapped = isinstance(state_dict, dict) and {"model", "model_config"} <= state_dict.keys()
    if wrapped:
        state_dict = state_dict["model"]
    return MoGeModel(state_dict)