This commit is contained in:
kijai 2026-05-13 21:24:49 +03:00
parent 84254b388d
commit 6617e76b1c
7 changed files with 70 additions and 359 deletions

View File

@ -86,10 +86,10 @@ def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None, def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64) focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Recover the focal length and z-shift that turn ``points`` into a metric point map. """Recover the focal length and z-shift that turn points into a metric point map.
Optical center is at the image center; returned focal is relative to half the image diagonal. Optical center is at the image center; returned focal is relative to half the image diagonal.
Returns ``(focal, shift)`` on the same device/dtype as ``points``. Returns (focal, shift) on the same device/dtype as points.
""" """
shape = points.shape shape = points.shape
H, W = shape[-3], shape[-2] H, W = shape[-3], shape[-2]
@ -155,11 +155,11 @@ def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Opti
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04, def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU. """Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU.
Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners Vertices: pixels with finite coords (passing optional mask). Quads with four valid corners
become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial become two triangles. depth overrides the scalar used for the rtol edge check; pass radial
depth for panoramas (the default ``points[..., 2]`` goes negative below the equator). depth for panoramas (the default points[..., 2] goes negative below the equator).
""" """
points = points.detach().cpu() points = points.detach().cpu()
finite = torch.isfinite(points).all(dim=-1) finite = torch.isfinite(points).all(dim=-1)

View File

@ -24,7 +24,7 @@ from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
def _remap_points(points: torch.Tensor) -> torch.Tensor: def _remap_points(points: torch.Tensor) -> torch.Tensor:
"""Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z.""" """Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
xy, z = points.split([2, 1], dim=-1) xy, z = points.split([2, 1], dim=-1)
z = torch.exp(z) z = torch.exp(z)
return torch.cat([xy * z, z], dim=-1) return torch.cat([xy * z, z], dim=-1)
@ -156,7 +156,7 @@ class MoGeModelV2(nn.Module):
@classmethod @classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights.""" """Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd) sd = _remap_state_dict(sd)
backbone = _detect_dinov2(sd, prefix="encoder.backbone.") backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
depth = backbone["num_hidden_layers"] depth = backbone["num_hidden_layers"]
@ -184,7 +184,7 @@ class MoGeModelV2(nn.Module):
@staticmethod @staticmethod
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]: def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
"""Reconstruct a ConvStack config from the keys under ``prefix``""" """Reconstruct a ConvStack config from the keys under prefix"""
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")] in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys) n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
@ -230,7 +230,6 @@ _DINOV2_BLOCK_RENAMES = [
def _remap_state_dict(sd: dict) -> dict: def _remap_state_dict(sd: dict) -> dict:
"""Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys"""
if "model" in sd and "model_config" in sd: if "model" in sd and "model_config" in sd:
sd = sd["model"] sd = sd["model"]
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone." prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
@ -272,6 +271,7 @@ class MoGeModel:
"""Loaded MoGe model + ComfyUI memory management.""" """Loaded MoGe model + ComfyUI memory management."""
def __init__(self, state_dict: dict): def __init__(self, state_dict: dict):
# text encoder dtype closest match
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -288,7 +288,7 @@ class MoGeModel:
force_projection: bool = True, apply_mask: bool = True, force_projection: bool = True, apply_mask: bool = True,
apply_metric_scale: bool = True apply_metric_scale: bool = True
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1].""" """Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1]."""
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
image = image.to(device=self.load_device, dtype=self.dtype) image = image.to(device=self.load_device, dtype=self.dtype)
H, W = image.shape[-2:] H, W = image.shape[-2:]

View File

@ -3,7 +3,7 @@
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
the model per view, and stitches per-view distance maps back into a single the model per view, and stitches per-view distance maps back into a single
equirect distance map via a multi-scale Poisson + gradient sparse solve. equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU). Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
""" """
from __future__ import annotations from __future__ import annotations
@ -14,6 +14,10 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from scipy.ndimage import convolve, map_coordinates
from scipy.sparse import vstack, csr_array
from scipy.sparse.linalg import lsmr
def _icosahedron_directions() -> np.ndarray: def _icosahedron_directions() -> np.ndarray:
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order).""" """12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
@ -122,7 +126,7 @@ def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarr
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor: def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
"""Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border.""" """Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border."""
grid = uv * 2.0 - 1.0 grid = uv * 2.0 - 1.0
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False) return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
@ -145,7 +149,6 @@ def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse Laplacian operator over the H x W grid.""" """Sparse Laplacian operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(H * W).reshape(H, W) grid_index = np.arange(H * W).reshape(H, W)
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge") grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge") grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
@ -162,7 +165,6 @@ def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse forward-difference operator over the H x W grid.""" """Sparse forward-difference operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(W * H).reshape(H, W) grid_index = np.arange(W * H).reshape(H, W)
if wrap_x: if wrap_x:
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap") grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
@ -191,7 +193,6 @@ def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray: def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border.""" """Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
from scipy.ndimage import map_coordinates
H, W = img.shape[:2] H, W = img.shape[:2]
yy = np.clip(sample_pixels[..., 1], 0, H - 1) yy = np.clip(sample_pixels[..., 1], 0, H - 1)
xx = np.clip(sample_pixels[..., 0], 0, W - 1) xx = np.clip(sample_pixels[..., 0], 0, W - 1)
@ -218,9 +219,6 @@ def merge_panorama_depth(width: int, height: int,
for the full-resolution solve. Optional callbacks fire per view processed and around each for the full-resolution solve. Optional callbacks fire per view processed and around each
lsmr solve so callers can drive a progress bar. lsmr solve so callers can drive a progress bar.
""" """
from scipy.ndimage import convolve
from scipy.sparse import vstack
from scipy.sparse.linalg import lsmr
if max(width, height) > 256: if max(width, height) > 256:
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2, coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,

View File

@ -1,94 +0,0 @@
"""Translate MoGe checkpoint keys to the layouts our nn.Modules use.
MoGe checkpoints embed DINOv2 with the original Meta naming
(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...).
The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming
(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``,
``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load
time and split the fused ``qkv`` weight into separate Q/K/V tensors.
"""
from __future__ import annotations
import re
_DINOV2_TOPLEVEL_RENAMES = {
    "patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
    "patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
    "cls_token": "embeddings.cls_token",
    "pos_embed": "embeddings.position_embeddings",
    "register_tokens": "embeddings.register_tokens",
    "mask_token": "embeddings.mask_token",
    "norm.weight": "layernorm.weight",
    "norm.bias": "layernorm.bias",
}

_BLOCK_SUFFIX_RENAMES = [
    ("ls1.gamma", "layer_scale1.lambda1"),
    ("ls2.gamma", "layer_scale2.lambda1"),
    ("attn.proj.", "attention.output.dense."),
    ("mlp.w12.", "mlp.weights_in."),
    ("mlp.w3.", "mlp.weights_out."),
]

_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")


def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict:
    """Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming.

    Each fused ``attn.qkv.{weight,bias}`` tensor is split into separate
    ``attention.attention.{query,key,value}.{weight,bias}`` tensors by
    chunking along the leading dimension.

    Keys that do not start with ``src_prefix`` are returned unchanged.
    """
    remapped: dict = {}
    for key, tensor in sd.items():
        if not key.startswith(src_prefix):
            remapped[key] = tensor
            continue
        rel = key[len(src_prefix):]
        # Top-level entries: cls token, pos embed, patch embed, mask token,
        # register tokens, final norm.
        renamed = _DINOV2_TOPLEVEL_RENAMES.get(rel)
        if renamed is not None:
            remapped[src_prefix + renamed] = tensor
            continue
        match = _BLOCK_RE.match(rel)
        if match is None:
            remapped[key] = tensor
            continue
        layer, suffix = match.groups()
        if suffix in ("attn.qkv.weight", "attn.qkv.bias"):
            # Split the fused qkv tensor into separate q / k / v tensors.
            q_t, k_t, v_t = tensor.chunk(3, dim=0)
            tail = suffix.rsplit(".", 1)[1]  # "weight" or "bias"
            base = f"{src_prefix}encoder.layer.{layer}.attention.attention"
            remapped[f"{base}.query.{tail}"] = q_t
            remapped[f"{base}.key.{tail}"] = k_t
            remapped[f"{base}.value.{tail}"] = v_t
            continue
        for old, new in _BLOCK_SUFFIX_RENAMES:
            suffix = suffix.replace(old, new)
        remapped[f"{src_prefix}encoder.layer.{layer}.{suffix}"] = tensor
    return remapped
def remap_moge_state_dict(sd: dict) -> dict:
    """Convert a full MoGe checkpoint state dict to the layout our modules expect.

    The DINOv2 backbone subtree (``encoder.backbone.`` for v2 checkpoints,
    ``backbone.`` for v1) is rewritten to comfy/HF naming.  Everything else
    (heads, neck, projections, image_mean/std buffers) keeps its original
    key names and passes through unchanged.
    """
    v2_prefix = "encoder.backbone."
    is_v2 = any(k.startswith(v2_prefix) for k in sd)
    return remap_dinov2_keys(sd, src_prefix=v2_prefix if is_v2 else "backbone.")

View File

@ -1,163 +0,0 @@
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
"""
from __future__ import annotations
from numbers import Number
from typing import Dict, Optional, Union
import torch
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
from .ldm.moge.geometry import (
depth_map_to_point_map,
intrinsics_from_focal_center,
recover_focal_shift,
)
from .ldm.moge.model import detect_and_build
from .ldm.moge.state_dict import remap_moge_state_dict
class MoGeModel:
    """Loaded MoGe model + ComfyUI memory management.

    Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module``
    plus a ``CoreModelPatcher`` so ComfyUI can move the weights between the
    load and offload devices on demand.
    """

    def __init__(self, state_dict: dict):
        """Detect, build and load a MoGe v1/v2 model from a raw checkpoint dict."""
        # Text-encoder device/dtype heuristics are the closest match for this
        # auxiliary model (same placement approach as comfy.clip_vision).
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        sd = remap_moge_state_dict(state_dict)
        self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
                                      operations=comfy.ops.manual_cast)
        self.model.load_state_dict(sd, strict=True)
        self.model.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(
            self.model, load_device=self.load_device, offload_device=offload_device
        )
        # v2 models expose an ``encoder`` attribute; v1 models do not.
        self.version = "v2" if hasattr(self.model, "encoder") else "v1"
        self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
        # Fall back to per-version defaults when the module doesn't declare a range.
        nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
        self.num_tokens_range = (int(nt[0]), int(nt[1]))

    @torch.inference_mode()
    def infer(self,
              image: torch.Tensor,
              num_tokens: Optional[int] = None,
              resolution_level: int = 9,
              fov_x: Optional[Union[Number, torch.Tensor]] = None,
              force_projection: bool = True,
              apply_mask: bool = True) -> Dict[str, torch.Tensor]:
        """Run a single MoGe forward + post-process pass.

        ``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device.
        Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
        ``mask``; v2 checkpoints additionally produce ``normal``.
        """
        comfy.model_management.load_model_gpu(self.patcher)
        device = self.load_device
        image = image.to(device=device, dtype=self.dtype)
        if image.dim() == 3:
            # Accept a single unbatched (3, H, W) image as well.
            image = image.unsqueeze(0)
        H, W = image.shape[-2:]
        aspect_ratio = W / H
        if num_tokens is None:
            # Map resolution_level 0..9 linearly onto the model's token range.
            lo, hi = self.num_tokens_range
            num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
        out = self.model.forward(image, num_tokens=num_tokens)
        points = out.get("points")
        normal = out.get("normal")
        mask = out.get("mask")
        metric_scale = out.get("metric_scale")
        # Post-processing always runs in fp32 for numerical stability.
        if points is not None: points = points.float()
        if normal is not None: normal = normal.float()
        if mask is not None: mask = mask.float()
        if metric_scale is not None: metric_scale = metric_scale.float()
        mask_binary = (mask > self.mask_threshold) if mask is not None else None
        depth = None
        intrinsics = None
        if points is not None:
            if fov_x is None:
                focal, shift = recover_focal_shift(points, mask_binary)
                # The unconstrained least-squares solver inside recover_focal_shift
                # can converge to a degenerate solution where (z + shift) is
                # negative for most pixels, which flips the sign of the
                # estimated focal. Detect that and fall back to a sensible
                # 60-degree-FoV default rather than emitting garbage geometry.
                bad = ~torch.isfinite(focal) | (focal <= 0)
                if bool(bad.any()):
                    default_fov = 60.0
                    fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
                    fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
                        / torch.tan(torch.deg2rad(fov_t / 2))
                    fallback_focal = fallback_focal.expand_as(focal).clone()
                    focal = torch.where(bad, fallback_focal, focal)
                    # Re-solve for shift only, with the repaired focal held fixed.
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            else:
                # Caller supplied the horizontal FoV: derive the diagonal-relative
                # focal directly and only solve for the z-shift.
                fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
                focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
                if focal.ndim == 0:
                    # Broadcast a scalar FoV across the batch.
                    focal = focal[None].expand(points.shape[0])
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            # Diagonal-relative focal -> normalized fx / fy; principal point at
            # the image center (0.5, 0.5).
            fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
            fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
            half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
            intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
            points[..., 2] = points[..., 2] + shift[..., None, None]
            # v2 upstream additionally filters mask by depth > 0 as a safeguard
            # against negative-depth artifacts from the metric-scale path; v1
            # does not, and applying it there can cut out the foreground when
            # shift recovery overshoots slightly.
            if mask_binary is not None and self.version == "v2":
                mask_binary = mask_binary & (points[..., 2] > 0)
            depth = points[..., 2].clone()
        if force_projection and depth is not None and intrinsics is not None:
            # Re-project depth through the recovered intrinsics so the point
            # map is exactly consistent with them.
            points = depth_map_to_point_map(depth, intrinsics=intrinsics)
        if metric_scale is not None:
            # Scale geometry to metric units (per-batch scalar).
            if points is not None:
                points = points * metric_scale[:, None, None, None]
            if depth is not None:
                depth = depth * metric_scale[:, None, None]
        if apply_mask and mask_binary is not None:
            # Masked-out pixels become inf (points/depth) or zero (normals).
            if points is not None:
                points = torch.where(mask_binary[..., None], points,
                                     torch.full_like(points, float("inf")))
            if depth is not None:
                depth = torch.where(mask_binary, depth,
                                    torch.full_like(depth, float("inf")))
            if normal is not None:
                normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
        result = {
            "points": points,
            "depth": depth,
            "intrinsics": intrinsics,
            "mask": mask_binary,
            "normal": normal,
        }
        # Drop keys the model did not produce.
        return {k: v for k, v in result.items() if v is not None}
def load(ckpt_path: str) -> MoGeModel:
    """Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
    state = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    # Upstream checkpoints may wrap the weights in a {"model", "model_config"} envelope.
    is_envelope = isinstance(state, dict) and "model" in state and "model_config" in state
    if is_envelope:
        state = state["model"]
    return MoGeModel(state)

View File

@ -2,9 +2,6 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch import torch
import comfy.utils import comfy.utils
@ -14,19 +11,21 @@ from typing_extensions import override
from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh from comfy.ldm.moge.geometry import triangulate_grid_mesh
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
import comfy.model_management
from tqdm.auto import tqdm
MoGeModelType = io.Custom("MOGE_MODEL") MoGeModelType = io.Custom("MOGE_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY") MoGeGeometry = io.Custom("MOGE_GEOMETRY")
@dataclass # MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
class _MoGeGeometryPayload: # "points": torch.Tensor (B, H, W, 3)
points: Optional[torch.Tensor] # (B, H, W, 3) # "depth": torch.Tensor (B, H, W)
depth: Optional[torch.Tensor] # (B, H, W) # "intrinsics": torch.Tensor (B, 3, 3) -- perspective only
intrinsics: Optional[torch.Tensor] # (B, 3, 3) # "mask": torch.Tensor (B, H, W) bool
mask: Optional[torch.Tensor] # (B, H, W) bool # "normal": torch.Tensor (B, H, W, 3) -- v2 only
normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1 # "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU
def _turbo(x: torch.Tensor) -> torch.Tensor: def _turbo(x: torch.Tensor) -> torch.Tensor:
@ -137,15 +136,7 @@ class MoGePanoramaInference(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, moge_model, image, resolution_level, def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
from comfy.ldm.moge.panorama import (
get_panorama_cameras, split_panorama_image, merge_panorama_depth,
spherical_uv_to_directions, _uv_grid,
)
import comfy.model_management as cmm
import numpy as np
from tqdm.auto import tqdm
if image.shape[0] != 1: if image.shape[0] != 1:
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})") raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
@ -156,7 +147,7 @@ class MoGePanoramaInference(io.ComfyNode):
extrinsics, intrinsics = get_panorama_cameras() extrinsics, intrinsics = get_panorama_cameras()
cmm.load_model_gpu(moge_model.patcher) comfy.model_management.load_model_gpu(moge_model.patcher)
device = moge_model.load_device device = moge_model.load_device
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype) img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution) splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
@ -218,34 +209,25 @@ class MoGePanoramaInference(io.ComfyNode):
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics, merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end) on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
pano_depth = torch.from_numpy(pano_depth)
pano_mask = torch.from_numpy(pano_mask)
if (merge_h, merge_w) != (H, W): if (merge_h, merge_w) != (H, W):
t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0) pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze()
pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear", pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0
align_corners=False).squeeze().numpy().astype(np.float32)
t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float()
pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0)
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1)
# and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell
# woven through the foreground; here we lift them to a far skybox radius instead.
if pano_mask.any() and not pano_mask.all(): if pano_mask.any() and not pano_mask.all():
far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0 far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0
pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32) pano_depth = torch.where(pano_mask, pano_depth, far)
uv = _uv_grid(H, W) directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W)))
directions = spherical_uv_to_directions(uv) points = (directions * pano_depth[..., None]).unsqueeze(0)
points_np = directions * pano_depth[..., None] depth = pano_depth.unsqueeze(0)
mask = pano_mask.unsqueeze(0)
points = torch.from_numpy(points_np).unsqueeze(0).float() # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation
depth = torch.from_numpy(pano_depth).unsqueeze(0).float() geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()}
mask = torch.from_numpy(pano_mask).unsqueeze(0)
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation
# after triangulation -- rotating before would scramble the rtol depth-edge check.
geometry = _MoGeGeometryPayload(
points=points, depth=depth, intrinsics=None, mask=mask, normal=None,
image=image.detach().cpu(),
)
return io.NodeOutput(geometry) return io.NodeOutput(geometry)
@ -273,9 +255,7 @@ class MoGeInference(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, moge_model, image, resolution_level, fov_x_degrees, def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput:
batch_size, force_projection, apply_mask) -> io.NodeOutput:
from tqdm.auto import tqdm
bchw = image.movedim(-1, -3).contiguous() bchw = image.movedim(-1, -3).contiguous()
B = bchw.shape[0] B = bchw.shape[0]
@ -295,20 +275,14 @@ class MoGeInference(io.ComfyNode):
vals = [c[field] for c in chunks if field in c] vals = [c[field] for c in chunks if field in c]
return torch.cat(vals, dim=0) if vals else None return torch.cat(vals, dim=0) if vals else None
geometry = _MoGeGeometryPayload( geometry = {"image": image.cpu()}
points=stack("points"), for field in ("points", "depth", "intrinsics", "mask", "normal"):
depth=stack("depth"), v = stack(field)
intrinsics=stack("intrinsics"), if v is not None:
mask=stack("mask"), geometry[field] = v
normal=stack("normal"),
image=image.detach().cpu(),
)
return io.NodeOutput(geometry) return io.NodeOutput(geometry)
_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"]
class MoGeRender(io.ComfyNode): class MoGeRender(io.ComfyNode):
"""Render a visualization or mask from a MOGE_GEOMETRY packet.""" """Render a visualization or mask from a MOGE_GEOMETRY packet."""
@ -320,36 +294,32 @@ class MoGeRender(io.ComfyNode):
category="image/geometry", category="image/geometry",
inputs=[ inputs=[
MoGeGeometry.Input("geometry"), MoGeGeometry.Input("geometry"),
io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"), io.Combo.Input("output", options=["depth", "depth_colored", "normal", "normal_screen", "mask"], default="depth"),
], ],
outputs=[io.Image.Output()], outputs=[io.Image.Output()],
) )
@classmethod @classmethod
def execute(cls, geometry, output) -> io.NodeOutput: def execute(cls, geometry, output) -> io.NodeOutput:
from tqdm.auto import tqdm
# Pick the input tensor for the chosen mode and validate availability. # Pick the input tensor for the chosen mode and validate availability.
if output in ("depth", "depth_colored", "normal_screen"): if output in ("depth", "depth_colored", "normal_screen"):
if geometry.depth is None: if "depth" not in geometry:
raise ValueError("MoGeGeometry has no depth output.") raise ValueError("MoGeGeometry has no depth output.")
src = geometry.depth src = geometry["depth"]
elif output == "normal": elif output == "normal":
if geometry.normal is not None: if "normal" in geometry:
src = geometry.normal src = geometry["normal"]
elif geometry.points is not None: elif "points" in geometry:
src = geometry.points src = geometry["points"]
else: else:
raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.") raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
elif output == "mask": elif output == "mask":
if geometry.mask is None: if "mask" not in geometry:
raise ValueError("MoGeGeometry has no mask output.") raise ValueError("MoGeGeometry has no mask output.")
src = geometry.mask src = geometry["mask"]
else: else:
raise ValueError(f"Unknown output mode: {output}") raise ValueError(f"Unknown output mode: {output}")
import comfy.model_management as cmm
B = src.shape[0] B = src.shape[0]
pbar = comfy.utils.ProgressBar(B) pbar = comfy.utils.ProgressBar(B)
out: list[torch.Tensor] = [] out: list[torch.Tensor] = []
@ -361,7 +331,7 @@ class MoGeRender(io.ComfyNode):
out.append(_turbo(d) if output == "depth_colored" out.append(_turbo(d) if output == "depth_colored"
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous()) else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
elif output == "normal": elif output == "normal":
n = slc if geometry.normal is not None else _normals_from_points(slc) n = slc if "normal" in geometry else _normals_from_points(slc)
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0)) out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
elif output == "normal_screen": elif output == "normal_screen":
n = _screen_normals_from_depth(slc) n = _screen_normals_from_depth(slc)
@ -370,7 +340,7 @@ class MoGeRender(io.ComfyNode):
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous()) out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
pbar.update_absolute(i + 1) pbar.update_absolute(i + 1)
tq.update(1) tq.update(1)
result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype()) result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(result) return io.NodeOutput(result)
@ -385,7 +355,7 @@ class MoGePointMapToMesh(io.ComfyNode):
category="3d", category="3d",
inputs=[ inputs=[
MoGeGeometry.Input("geometry"), MoGeGeometry.Input("geometry"),
io.Int.Input("batch_index", default=0, min=0, max=64, io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts " tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
"differ, so batches can't be stacked into a single MESH."), "differ, so batches can't be stacked into a single MESH."),
io.Int.Input("decimation", default=1, min=1, max=8, io.Int.Input("decimation", default=1, min=1, max=8,
@ -400,32 +370,32 @@ class MoGePointMapToMesh(io.ComfyNode):
@classmethod @classmethod
def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput: def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
if geometry.points is None: if "points" not in geometry:
raise ValueError("MoGeGeometry has no points output.") raise ValueError("MoGeGeometry has no points output.")
B = geometry.points.shape[0] points = geometry["points"]
B = points.shape[0]
if batch_index >= B: if batch_index >= B:
raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.") raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")
# Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas # Pass depth so the rtol edge check sees radial depth -- for panoramas
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half. # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None edge_depth = geometry["depth"][batch_index] if "depth" in geometry else None
verts, faces, uvs = triangulate_grid_mesh( verts, faces, uvs = triangulate_grid_mesh(
geometry.points[batch_index], decimation=decimation, points[batch_index], decimation=decimation,
discontinuity_threshold=discontinuity_threshold, depth=edge_depth, discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
) )
if verts.shape[0] == 0 or faces.shape[0] == 0: if verts.shape[0] == 0 or faces.shape[0] == 0:
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.") raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
if geometry.intrinsics is None: if "intrinsics" not in geometry:
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back), correct for inside-the-sphere viewing)
# preserves the natural inward winding (correct for inside-the-sphere viewing).
verts = verts[:, [1, 2, 0]].contiguous() verts = verts[:, [1, 2, 0]].contiguous()
else: else:
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip. # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
faces = faces[:, [0, 2, 1]].contiguous() faces = faces[:, [0, 2, 1]].contiguous()
tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None tex = geometry["image"][batch_index:batch_index + 1] if texture else None
mesh = Types.MESH( mesh = Types.MESH(
vertices=verts.unsqueeze(0), vertices=verts.unsqueeze(0),
faces=faces.unsqueeze(0), faces=faces.unsqueeze(0),