This commit is contained in:
kijai 2026-05-13 21:24:49 +03:00
parent 84254b388d
commit 6617e76b1c
7 changed files with 70 additions and 359 deletions

View File

@ -86,10 +86,10 @@ def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Recover the focal length and z-shift that turn ``points`` into a metric point map.
"""Recover the focal length and z-shift that turn points into a metric point map.
Optical center is at the image center; returned focal is relative to half the image diagonal.
Returns ``(focal, shift)`` on the same device/dtype as ``points``.
Returns (focal, shift) on the same device/dtype as points.
"""
shape = points.shape
H, W = shape[-3], shape[-2]
@ -155,11 +155,11 @@ def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Opti
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU.
"""Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU.
Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners
become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial
depth for panoramas (the default ``points[..., 2]`` goes negative below the equator).
Vertices: pixels with finite coords (passing optional mask). Quads with four valid corners
become two triangles. depth overrides the scalar used for the rtol edge check; pass radial
depth for panoramas (the default points[..., 2] goes negative below the equator).
"""
points = points.detach().cpu()
finite = torch.isfinite(points).all(dim=-1)

View File

@ -24,7 +24,7 @@ from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
def _remap_points(points: torch.Tensor) -> torch.Tensor:
"""Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
"""Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
xy, z = points.split([2, 1], dim=-1)
z = torch.exp(z)
return torch.cat([xy * z, z], dim=-1)
@ -156,7 +156,7 @@ class MoGeModelV2(nn.Module):
@classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights."""
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
sd = _remap_state_dict(sd)
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
depth = backbone["num_hidden_layers"]
@ -184,7 +184,7 @@ class MoGeModelV2(nn.Module):
@staticmethod
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
"""Reconstruct a ConvStack config from the keys under ``prefix``"""
"""Reconstruct a ConvStack config from the keys under prefix"""
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
@ -230,7 +230,6 @@ _DINOV2_BLOCK_RENAMES = [
def _remap_state_dict(sd: dict) -> dict:
"""Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys"""
if "model" in sd and "model_config" in sd:
sd = sd["model"]
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
@ -272,6 +271,7 @@ class MoGeModel:
"""Loaded MoGe model + ComfyUI memory management."""
def __init__(self, state_dict: dict):
# text encoder dtype closest match
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -288,7 +288,7 @@ class MoGeModel:
force_projection: bool = True, apply_mask: bool = True,
apply_metric_scale: bool = True
) -> Dict[str, torch.Tensor]:
"""Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1]."""
"""Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1]."""
comfy.model_management.load_model_gpu(self.patcher)
image = image.to(device=self.load_device, dtype=self.dtype)
H, W = image.shape[-2:]

View File

@ -3,7 +3,7 @@
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
the model per view, and stitches per-view distance maps back into a single
equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU).
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
"""
from __future__ import annotations
@ -14,6 +14,10 @@ import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import convolve, map_coordinates
from scipy.sparse import vstack, csr_array
from scipy.sparse.linalg import lsmr
def _icosahedron_directions() -> np.ndarray:
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
@ -122,7 +126,7 @@ def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarr
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
"""Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border."""
"""Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border."""
grid = uv * 2.0 - 1.0
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
@ -145,7 +149,6 @@ def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse Laplacian operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(H * W).reshape(H, W)
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
@ -162,7 +165,6 @@ def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse forward-difference operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(W * H).reshape(H, W)
if wrap_x:
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
@ -191,7 +193,6 @@ def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
from scipy.ndimage import map_coordinates
H, W = img.shape[:2]
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
@ -218,9 +219,6 @@ def merge_panorama_depth(width: int, height: int,
for the full-resolution solve. Optional callbacks fire per view processed and around each
lsmr solve so callers can drive a progress bar.
"""
from scipy.ndimage import convolve
from scipy.sparse import vstack
from scipy.sparse.linalg import lsmr
if max(width, height) > 256:
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,

View File

@ -1,94 +0,0 @@
"""Translate MoGe checkpoint keys to the layouts our nn.Modules use.
MoGe checkpoints embed DINOv2 with the original Meta naming
(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...).
The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming
(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``,
``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load
time and split the fused ``qkv`` weight into separate Q/K/V tensors.
"""
from __future__ import annotations
import re
_DINOV2_TOPLEVEL_RENAMES = {
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
"cls_token": "embeddings.cls_token",
"pos_embed": "embeddings.position_embeddings",
"register_tokens": "embeddings.register_tokens",
"mask_token": "embeddings.mask_token",
"norm.weight": "layernorm.weight",
"norm.bias": "layernorm.bias",
}
_BLOCK_SUFFIX_RENAMES = [
("ls1.gamma", "layer_scale1.lambda1"),
("ls2.gamma", "layer_scale2.lambda1"),
("attn.proj.", "attention.output.dense."),
("mlp.w12.", "mlp.weights_in."),
("mlp.w3.", "mlp.weights_out."),
]
_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")
def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict:
"""Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming.
Splits each fused ``attn.qkv.{weight,bias}`` into separate
``attention.attention.{query,key,value}.{weight,bias}`` tensors using a
chunk along the leading dim.
Keys that do not start with ``src_prefix`` are returned unchanged.
"""
out: dict = {}
for k, v in sd.items():
if not k.startswith(src_prefix):
out[k] = v
continue
rel = k[len(src_prefix):]
# Top-level (cls token, pos embed, patch embed, mask token, register tokens, final norm).
if rel in _DINOV2_TOPLEVEL_RENAMES:
out[src_prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
continue
m = _BLOCK_RE.match(rel)
if not m:
out[k] = v
continue
i, sub = m.group(1), m.group(2)
# Split fused qkv into separate q / k / v tensors.
if sub == "attn.qkv.weight" or sub == "attn.qkv.bias":
q, kw, vw = v.chunk(3, dim=0)
tail = sub.rsplit(".", 1)[1] # weight / bias
base = "{}encoder.layer.{}.attention.attention".format(src_prefix, i)
out["{}.query.{}".format(base, tail)] = q
out["{}.key.{}".format(base, tail)] = kw
out["{}.value.{}".format(base, tail)] = vw
continue
for old, new in _BLOCK_SUFFIX_RENAMES:
sub = sub.replace(old, new)
out["{}encoder.layer.{}.{}".format(src_prefix, i, sub)] = v
return out
def remap_moge_state_dict(sd: dict) -> dict:
"""Convert a full MoGe checkpoint state dict to the layout our modules expect.
- v1 backbone lives under ``backbone.`` -> rewrite that subtree.
- v2 backbone lives under ``encoder.backbone.`` -> rewrite that subtree.
Everything else (heads, neck, projections, image_mean/std buffers) keeps
its original key names and passes through unchanged.
"""
if any(k.startswith("encoder.backbone.") for k in sd):
return remap_dinov2_keys(sd, src_prefix="encoder.backbone.")
return remap_dinov2_keys(sd, src_prefix="backbone.")

View File

@ -1,163 +0,0 @@
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
"""
from __future__ import annotations
from numbers import Number
from typing import Dict, Optional, Union
import torch
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
from .ldm.moge.geometry import (
depth_map_to_point_map,
intrinsics_from_focal_center,
recover_focal_shift,
)
from .ldm.moge.model import detect_and_build
from .ldm.moge.state_dict import remap_moge_state_dict
class MoGeModel:
"""Loaded MoGe model + ComfyUI memory management."""
def __init__(self, state_dict: dict):
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
sd = remap_moge_state_dict(state_dict)
self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
operations=comfy.ops.manual_cast)
self.model.load_state_dict(sd, strict=True)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(
self.model, load_device=self.load_device, offload_device=offload_device
)
self.version = "v2" if hasattr(self.model, "encoder") else "v1"
self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
self.num_tokens_range = (int(nt[0]), int(nt[1]))
@torch.inference_mode()
def infer(self,
image: torch.Tensor,
num_tokens: Optional[int] = None,
resolution_level: int = 9,
fov_x: Optional[Union[Number, torch.Tensor]] = None,
force_projection: bool = True,
apply_mask: bool = True) -> Dict[str, torch.Tensor]:
"""Run a single MoGe forward + post-process pass.
``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device.
Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
``mask``; v2 checkpoints additionally produce ``normal``.
"""
comfy.model_management.load_model_gpu(self.patcher)
device = self.load_device
image = image.to(device=device, dtype=self.dtype)
if image.dim() == 3:
image = image.unsqueeze(0)
H, W = image.shape[-2:]
aspect_ratio = W / H
if num_tokens is None:
lo, hi = self.num_tokens_range
num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
out = self.model.forward(image, num_tokens=num_tokens)
points = out.get("points")
normal = out.get("normal")
mask = out.get("mask")
metric_scale = out.get("metric_scale")
# Post-processing always runs in fp32 for numerical stability.
if points is not None: points = points.float()
if normal is not None: normal = normal.float()
if mask is not None: mask = mask.float()
if metric_scale is not None: metric_scale = metric_scale.float()
mask_binary = (mask > self.mask_threshold) if mask is not None else None
depth = None
intrinsics = None
if points is not None:
if fov_x is None:
focal, shift = recover_focal_shift(points, mask_binary)
# The unconstrained least-squares solver inside recover_focal_shift
# can converge to a degenerate solution where (z + shift) is
# negative for most pixels, which flips the sign of the
# estimated focal. Detect that and fall back to a sensible
# 60-degree-FoV default rather than emitting garbage geometry.
bad = ~torch.isfinite(focal) | (focal <= 0)
if bool(bad.any()):
default_fov = 60.0
fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
/ torch.tan(torch.deg2rad(fov_t / 2))
fallback_focal = fallback_focal.expand_as(focal).clone()
focal = torch.where(bad, fallback_focal, focal)
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
else:
fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
if focal.ndim == 0:
focal = focal[None].expand(points.shape[0])
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
points[..., 2] = points[..., 2] + shift[..., None, None]
# v2 upstream additionally filters mask by depth > 0 as a safeguard
# against negative-depth artifacts from the metric-scale path; v1
# does not, and applying it there can cut out the foreground when
# shift recovery overshoots slightly.
if mask_binary is not None and self.version == "v2":
mask_binary = mask_binary & (points[..., 2] > 0)
depth = points[..., 2].clone()
if force_projection and depth is not None and intrinsics is not None:
points = depth_map_to_point_map(depth, intrinsics=intrinsics)
if metric_scale is not None:
if points is not None:
points = points * metric_scale[:, None, None, None]
if depth is not None:
depth = depth * metric_scale[:, None, None]
if apply_mask and mask_binary is not None:
if points is not None:
points = torch.where(mask_binary[..., None], points,
torch.full_like(points, float("inf")))
if depth is not None:
depth = torch.where(mask_binary, depth,
torch.full_like(depth, float("inf")))
if normal is not None:
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
result = {
"points": points,
"depth": depth,
"intrinsics": intrinsics,
"mask": mask_binary,
"normal": normal,
}
return {k: v for k, v in result.items() if v is not None}
def load(ckpt_path: str) -> MoGeModel:
"""Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if isinstance(sd, dict) and "model" in sd and "model_config" in sd:
sd = sd["model"]
return MoGeModel(sd)

View File

@ -2,9 +2,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import comfy.utils
@ -14,19 +11,21 @@ from typing_extensions import override
from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
import comfy.model_management
from tqdm.auto import tqdm
MoGeModelType = io.Custom("MOGE_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
@dataclass
class _MoGeGeometryPayload:
points: Optional[torch.Tensor] # (B, H, W, 3)
depth: Optional[torch.Tensor] # (B, H, W)
intrinsics: Optional[torch.Tensor] # (B, 3, 3)
mask: Optional[torch.Tensor] # (B, H, W) bool
normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1
image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU
# MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
# "points": torch.Tensor (B, H, W, 3)
# "depth": torch.Tensor (B, H, W)
# "intrinsics": torch.Tensor (B, 3, 3) -- perspective only
# "mask": torch.Tensor (B, H, W) bool
# "normal": torch.Tensor (B, H, W, 3) -- v2 only
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
def _turbo(x: torch.Tensor) -> torch.Tensor:
@ -137,15 +136,7 @@ class MoGePanoramaInference(io.ComfyNode):
)
@classmethod
def execute(cls, moge_model, image, resolution_level,
split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
from comfy.ldm.moge.panorama import (
get_panorama_cameras, split_panorama_image, merge_panorama_depth,
spherical_uv_to_directions, _uv_grid,
)
import comfy.model_management as cmm
import numpy as np
from tqdm.auto import tqdm
def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
if image.shape[0] != 1:
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
@ -156,7 +147,7 @@ class MoGePanoramaInference(io.ComfyNode):
extrinsics, intrinsics = get_panorama_cameras()
cmm.load_model_gpu(moge_model.patcher)
comfy.model_management.load_model_gpu(moge_model.patcher)
device = moge_model.load_device
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
@ -218,34 +209,25 @@ class MoGePanoramaInference(io.ComfyNode):
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
pano_depth = torch.from_numpy(pano_depth)
pano_mask = torch.from_numpy(pano_mask)
if (merge_h, merge_w) != (H, W):
t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0)
pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear",
align_corners=False).squeeze().numpy().astype(np.float32)
t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float()
pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0)
pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze()
pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve
# and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell
# woven through the foreground; here we lift them to a far skybox radius instead.
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1)
if pano_mask.any() and not pano_mask.all():
far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0
pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32)
far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0
pano_depth = torch.where(pano_mask, pano_depth, far)
uv = _uv_grid(H, W)
directions = spherical_uv_to_directions(uv)
points_np = directions * pano_depth[..., None]
directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W)))
points = (directions * pano_depth[..., None]).unsqueeze(0)
depth = pano_depth.unsqueeze(0)
mask = pano_mask.unsqueeze(0)
points = torch.from_numpy(points_np).unsqueeze(0).float()
depth = torch.from_numpy(pano_depth).unsqueeze(0).float()
mask = torch.from_numpy(pano_mask).unsqueeze(0)
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation
# after triangulation -- rotating before would scramble the rtol depth-edge check.
geometry = _MoGeGeometryPayload(
points=points, depth=depth, intrinsics=None, mask=mask, normal=None,
image=image.detach().cpu(),
)
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation
geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()}
return io.NodeOutput(geometry)
@ -273,9 +255,7 @@ class MoGeInference(io.ComfyNode):
)
@classmethod
def execute(cls, moge_model, image, resolution_level, fov_x_degrees,
batch_size, force_projection, apply_mask) -> io.NodeOutput:
from tqdm.auto import tqdm
def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput:
bchw = image.movedim(-1, -3).contiguous()
B = bchw.shape[0]
@ -295,20 +275,14 @@ class MoGeInference(io.ComfyNode):
vals = [c[field] for c in chunks if field in c]
return torch.cat(vals, dim=0) if vals else None
geometry = _MoGeGeometryPayload(
points=stack("points"),
depth=stack("depth"),
intrinsics=stack("intrinsics"),
mask=stack("mask"),
normal=stack("normal"),
image=image.detach().cpu(),
)
geometry = {"image": image.cpu()}
for field in ("points", "depth", "intrinsics", "mask", "normal"):
v = stack(field)
if v is not None:
geometry[field] = v
return io.NodeOutput(geometry)
_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"]
class MoGeRender(io.ComfyNode):
"""Render a visualization or mask from a MOGE_GEOMETRY packet."""
@ -320,36 +294,32 @@ class MoGeRender(io.ComfyNode):
category="image/geometry",
inputs=[
MoGeGeometry.Input("geometry"),
io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"),
io.Combo.Input("output", options=["depth", "depth_colored", "normal", "normal_screen", "mask"], default="depth"),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, geometry, output) -> io.NodeOutput:
from tqdm.auto import tqdm
# Pick the input tensor for the chosen mode and validate availability.
if output in ("depth", "depth_colored", "normal_screen"):
if geometry.depth is None:
if "depth" not in geometry:
raise ValueError("MoGeGeometry has no depth output.")
src = geometry.depth
src = geometry["depth"]
elif output == "normal":
if geometry.normal is not None:
src = geometry.normal
elif geometry.points is not None:
src = geometry.points
if "normal" in geometry:
src = geometry["normal"]
elif "points" in geometry:
src = geometry["points"]
else:
raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
elif output == "mask":
if geometry.mask is None:
if "mask" not in geometry:
raise ValueError("MoGeGeometry has no mask output.")
src = geometry.mask
src = geometry["mask"]
else:
raise ValueError(f"Unknown output mode: {output}")
import comfy.model_management as cmm
B = src.shape[0]
pbar = comfy.utils.ProgressBar(B)
out: list[torch.Tensor] = []
@ -361,7 +331,7 @@ class MoGeRender(io.ComfyNode):
out.append(_turbo(d) if output == "depth_colored"
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
elif output == "normal":
n = slc if geometry.normal is not None else _normals_from_points(slc)
n = slc if "normal" in geometry else _normals_from_points(slc)
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
elif output == "normal_screen":
n = _screen_normals_from_depth(slc)
@ -370,7 +340,7 @@ class MoGeRender(io.ComfyNode):
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
pbar.update_absolute(i + 1)
tq.update(1)
result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype())
result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(result)
@ -385,7 +355,7 @@ class MoGePointMapToMesh(io.ComfyNode):
category="3d",
inputs=[
MoGeGeometry.Input("geometry"),
io.Int.Input("batch_index", default=0, min=0, max=64,
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
"differ, so batches can't be stacked into a single MESH."),
io.Int.Input("decimation", default=1, min=1, max=8,
@ -400,32 +370,32 @@ class MoGePointMapToMesh(io.ComfyNode):
@classmethod
def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
if geometry.points is None:
if "points" not in geometry:
raise ValueError("MoGeGeometry has no points output.")
B = geometry.points.shape[0]
points = geometry["points"]
B = points.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")
# Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas
# Pass depth so the rtol edge check sees radial depth -- for panoramas
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None
edge_depth = geometry["depth"][batch_index] if "depth" in geometry else None
verts, faces, uvs = triangulate_grid_mesh(
geometry.points[batch_index], decimation=decimation,
points[batch_index], decimation=decimation,
discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
)
if verts.shape[0] == 0 or faces.shape[0] == 0:
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
if geometry.intrinsics is None:
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation
# preserves the natural inward winding (correct for inside-the-sphere viewing).
if "intrinsics" not in geometry:
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back), correct for inside-the-sphere viewing)
verts = verts[:, [1, 2, 0]].contiguous()
else:
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
faces = faces[:, [0, 2, 1]].contiguous()
tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None
tex = geometry["image"][batch_index:batch_index + 1] if texture else None
mesh = Types.MESH(
vertices=verts.unsqueeze(0),
faces=faces.unsqueeze(0),