mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Cleanup
This commit is contained in:
parent
84254b388d
commit
6617e76b1c
@ -86,10 +86,10 @@ def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float
|
|||||||
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
||||||
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
|
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Recover the focal length and z-shift that turn ``points`` into a metric point map.
|
"""Recover the focal length and z-shift that turn points into a metric point map.
|
||||||
|
|
||||||
Optical center is at the image center; returned focal is relative to half the image diagonal.
|
Optical center is at the image center; returned focal is relative to half the image diagonal.
|
||||||
Returns ``(focal, shift)`` on the same device/dtype as ``points``.
|
Returns (focal, shift) on the same device/dtype as points.
|
||||||
"""
|
"""
|
||||||
shape = points.shape
|
shape = points.shape
|
||||||
H, W = shape[-3], shape[-2]
|
H, W = shape[-3], shape[-2]
|
||||||
@ -155,11 +155,11 @@ def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Opti
|
|||||||
|
|
||||||
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
|
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
|
||||||
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU.
|
"""Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU.
|
||||||
|
|
||||||
Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners
|
Vertices: pixels with finite coords (passing optional mask). Quads with four valid corners
|
||||||
become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial
|
become two triangles. depth overrides the scalar used for the rtol edge check; pass radial
|
||||||
depth for panoramas (the default ``points[..., 2]`` goes negative below the equator).
|
depth for panoramas (the default points[..., 2] goes negative below the equator).
|
||||||
"""
|
"""
|
||||||
points = points.detach().cpu()
|
points = points.detach().cpu()
|
||||||
finite = torch.isfinite(points).all(dim=-1)
|
finite = torch.isfinite(points).all(dim=-1)
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
|
|||||||
|
|
||||||
|
|
||||||
def _remap_points(points: torch.Tensor) -> torch.Tensor:
|
def _remap_points(points: torch.Tensor) -> torch.Tensor:
|
||||||
"""Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
|
"""Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
|
||||||
xy, z = points.split([2, 1], dim=-1)
|
xy, z = points.split([2, 1], dim=-1)
|
||||||
z = torch.exp(z)
|
z = torch.exp(z)
|
||||||
return torch.cat([xy * z, z], dim=-1)
|
return torch.cat([xy * z, z], dim=-1)
|
||||||
@ -156,7 +156,7 @@ class MoGeModelV2(nn.Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
|
||||||
"""Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights."""
|
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
|
||||||
sd = _remap_state_dict(sd)
|
sd = _remap_state_dict(sd)
|
||||||
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
|
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
|
||||||
depth = backbone["num_hidden_layers"]
|
depth = backbone["num_hidden_layers"]
|
||||||
@ -184,7 +184,7 @@ class MoGeModelV2(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
|
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
|
||||||
"""Reconstruct a ConvStack config from the keys under ``prefix``"""
|
"""Reconstruct a ConvStack config from the keys under prefix"""
|
||||||
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
|
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
|
||||||
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
|
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
|
||||||
|
|
||||||
@ -230,7 +230,6 @@ _DINOV2_BLOCK_RENAMES = [
|
|||||||
|
|
||||||
|
|
||||||
def _remap_state_dict(sd: dict) -> dict:
|
def _remap_state_dict(sd: dict) -> dict:
|
||||||
"""Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys"""
|
|
||||||
if "model" in sd and "model_config" in sd:
|
if "model" in sd and "model_config" in sd:
|
||||||
sd = sd["model"]
|
sd = sd["model"]
|
||||||
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
|
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
|
||||||
@ -272,6 +271,7 @@ class MoGeModel:
|
|||||||
"""Loaded MoGe model + ComfyUI memory management."""
|
"""Loaded MoGe model + ComfyUI memory management."""
|
||||||
|
|
||||||
def __init__(self, state_dict: dict):
|
def __init__(self, state_dict: dict):
|
||||||
|
# text encoder dtype closest match
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
self.load_device = comfy.model_management.text_encoder_device()
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||||
@ -288,7 +288,7 @@ class MoGeModel:
|
|||||||
force_projection: bool = True, apply_mask: bool = True,
|
force_projection: bool = True, apply_mask: bool = True,
|
||||||
apply_metric_scale: bool = True
|
apply_metric_scale: bool = True
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1]."""
|
"""Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1]."""
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
image = image.to(device=self.load_device, dtype=self.dtype)
|
image = image.to(device=self.load_device, dtype=self.dtype)
|
||||||
H, W = image.shape[-2:]
|
H, W = image.shape[-2:]
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
|
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
|
||||||
the model per view, and stitches per-view distance maps back into a single
|
the model per view, and stitches per-view distance maps back into a single
|
||||||
equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
||||||
Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU).
|
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -14,6 +14,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from scipy.ndimage import convolve, map_coordinates
|
||||||
|
from scipy.sparse import vstack, csr_array
|
||||||
|
from scipy.sparse.linalg import lsmr
|
||||||
|
|
||||||
|
|
||||||
def _icosahedron_directions() -> np.ndarray:
|
def _icosahedron_directions() -> np.ndarray:
|
||||||
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
|
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
|
||||||
@ -122,7 +126,7 @@ def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarr
|
|||||||
|
|
||||||
|
|
||||||
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
|
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
|
||||||
"""Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border."""
|
"""Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border."""
|
||||||
grid = uv * 2.0 - 1.0
|
grid = uv * 2.0 - 1.0
|
||||||
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
|
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
|
||||||
|
|
||||||
@ -145,7 +149,6 @@ def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics
|
|||||||
|
|
||||||
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||||
"""Sparse Laplacian operator over the H x W grid."""
|
"""Sparse Laplacian operator over the H x W grid."""
|
||||||
from scipy.sparse import csr_array
|
|
||||||
grid_index = np.arange(H * W).reshape(H, W)
|
grid_index = np.arange(H * W).reshape(H, W)
|
||||||
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
|
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
|
||||||
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
|
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
|
||||||
@ -162,7 +165,6 @@ def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False
|
|||||||
|
|
||||||
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
||||||
"""Sparse forward-difference operator over the H x W grid."""
|
"""Sparse forward-difference operator over the H x W grid."""
|
||||||
from scipy.sparse import csr_array
|
|
||||||
grid_index = np.arange(W * H).reshape(H, W)
|
grid_index = np.arange(W * H).reshape(H, W)
|
||||||
if wrap_x:
|
if wrap_x:
|
||||||
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
|
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
|
||||||
@ -191,7 +193,6 @@ def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
|
|||||||
|
|
||||||
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
|
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
|
||||||
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
|
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
|
||||||
from scipy.ndimage import map_coordinates
|
|
||||||
H, W = img.shape[:2]
|
H, W = img.shape[:2]
|
||||||
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
|
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
|
||||||
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
|
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
|
||||||
@ -218,9 +219,6 @@ def merge_panorama_depth(width: int, height: int,
|
|||||||
for the full-resolution solve. Optional callbacks fire per view processed and around each
|
for the full-resolution solve. Optional callbacks fire per view processed and around each
|
||||||
lsmr solve so callers can drive a progress bar.
|
lsmr solve so callers can drive a progress bar.
|
||||||
"""
|
"""
|
||||||
from scipy.ndimage import convolve
|
|
||||||
from scipy.sparse import vstack
|
|
||||||
from scipy.sparse.linalg import lsmr
|
|
||||||
|
|
||||||
if max(width, height) > 256:
|
if max(width, height) > 256:
|
||||||
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
|
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
|
||||||
|
|||||||
@ -1,94 +0,0 @@
|
|||||||
"""Translate MoGe checkpoint keys to the layouts our nn.Modules use.
|
|
||||||
|
|
||||||
MoGe checkpoints embed DINOv2 with the original Meta naming
|
|
||||||
(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...).
|
|
||||||
The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming
|
|
||||||
(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``,
|
|
||||||
``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load
|
|
||||||
time and split the fused ``qkv`` weight into separate Q/K/V tensors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
_DINOV2_TOPLEVEL_RENAMES = {
|
|
||||||
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
|
|
||||||
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
|
|
||||||
"cls_token": "embeddings.cls_token",
|
|
||||||
"pos_embed": "embeddings.position_embeddings",
|
|
||||||
"register_tokens": "embeddings.register_tokens",
|
|
||||||
"mask_token": "embeddings.mask_token",
|
|
||||||
"norm.weight": "layernorm.weight",
|
|
||||||
"norm.bias": "layernorm.bias",
|
|
||||||
}
|
|
||||||
|
|
||||||
_BLOCK_SUFFIX_RENAMES = [
|
|
||||||
("ls1.gamma", "layer_scale1.lambda1"),
|
|
||||||
("ls2.gamma", "layer_scale2.lambda1"),
|
|
||||||
("attn.proj.", "attention.output.dense."),
|
|
||||||
("mlp.w12.", "mlp.weights_in."),
|
|
||||||
("mlp.w3.", "mlp.weights_out."),
|
|
||||||
]
|
|
||||||
|
|
||||||
_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")
|
|
||||||
|
|
||||||
|
|
||||||
def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict:
|
|
||||||
"""Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming.
|
|
||||||
|
|
||||||
Splits each fused ``attn.qkv.{weight,bias}`` into separate
|
|
||||||
``attention.attention.{query,key,value}.{weight,bias}`` tensors using a
|
|
||||||
chunk along the leading dim.
|
|
||||||
|
|
||||||
Keys that do not start with ``src_prefix`` are returned unchanged.
|
|
||||||
"""
|
|
||||||
out: dict = {}
|
|
||||||
for k, v in sd.items():
|
|
||||||
if not k.startswith(src_prefix):
|
|
||||||
out[k] = v
|
|
||||||
continue
|
|
||||||
rel = k[len(src_prefix):]
|
|
||||||
|
|
||||||
# Top-level (cls token, pos embed, patch embed, mask token, register tokens, final norm).
|
|
||||||
if rel in _DINOV2_TOPLEVEL_RENAMES:
|
|
||||||
out[src_prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = _BLOCK_RE.match(rel)
|
|
||||||
if not m:
|
|
||||||
out[k] = v
|
|
||||||
continue
|
|
||||||
|
|
||||||
i, sub = m.group(1), m.group(2)
|
|
||||||
|
|
||||||
# Split fused qkv into separate q / k / v tensors.
|
|
||||||
if sub == "attn.qkv.weight" or sub == "attn.qkv.bias":
|
|
||||||
q, kw, vw = v.chunk(3, dim=0)
|
|
||||||
tail = sub.rsplit(".", 1)[1] # weight / bias
|
|
||||||
base = "{}encoder.layer.{}.attention.attention".format(src_prefix, i)
|
|
||||||
out["{}.query.{}".format(base, tail)] = q
|
|
||||||
out["{}.key.{}".format(base, tail)] = kw
|
|
||||||
out["{}.value.{}".format(base, tail)] = vw
|
|
||||||
continue
|
|
||||||
|
|
||||||
for old, new in _BLOCK_SUFFIX_RENAMES:
|
|
||||||
sub = sub.replace(old, new)
|
|
||||||
out["{}encoder.layer.{}.{}".format(src_prefix, i, sub)] = v
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def remap_moge_state_dict(sd: dict) -> dict:
|
|
||||||
"""Convert a full MoGe checkpoint state dict to the layout our modules expect.
|
|
||||||
|
|
||||||
- v1 backbone lives under ``backbone.`` -> rewrite that subtree.
|
|
||||||
- v2 backbone lives under ``encoder.backbone.`` -> rewrite that subtree.
|
|
||||||
|
|
||||||
Everything else (heads, neck, projections, image_mean/std buffers) keeps
|
|
||||||
its original key names and passes through unchanged.
|
|
||||||
"""
|
|
||||||
if any(k.startswith("encoder.backbone.") for k in sd):
|
|
||||||
return remap_dinov2_keys(sd, src_prefix="encoder.backbone.")
|
|
||||||
return remap_dinov2_keys(sd, src_prefix="backbone.")
|
|
||||||
163
comfy/moge.py
163
comfy/moge.py
@ -1,163 +0,0 @@
|
|||||||
"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints.
|
|
||||||
|
|
||||||
Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and
|
|
||||||
a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a
|
|
||||||
:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from numbers import Number
|
|
||||||
from typing import Dict, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.model_patcher
|
|
||||||
import comfy.ops
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from .ldm.moge.geometry import (
|
|
||||||
depth_map_to_point_map,
|
|
||||||
intrinsics_from_focal_center,
|
|
||||||
recover_focal_shift,
|
|
||||||
)
|
|
||||||
from .ldm.moge.model import detect_and_build
|
|
||||||
from .ldm.moge.state_dict import remap_moge_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
class MoGeModel:
|
|
||||||
"""Loaded MoGe model + ComfyUI memory management."""
|
|
||||||
|
|
||||||
def __init__(self, state_dict: dict):
|
|
||||||
self.load_device = comfy.model_management.text_encoder_device()
|
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
|
||||||
|
|
||||||
sd = remap_moge_state_dict(state_dict)
|
|
||||||
self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device,
|
|
||||||
operations=comfy.ops.manual_cast)
|
|
||||||
self.model.load_state_dict(sd, strict=True)
|
|
||||||
self.model.eval()
|
|
||||||
self.patcher = comfy.model_patcher.CoreModelPatcher(
|
|
||||||
self.model, load_device=self.load_device, offload_device=offload_device
|
|
||||||
)
|
|
||||||
self.version = "v2" if hasattr(self.model, "encoder") else "v1"
|
|
||||||
self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
|
|
||||||
nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
|
|
||||||
self.num_tokens_range = (int(nt[0]), int(nt[1]))
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def infer(self,
|
|
||||||
image: torch.Tensor,
|
|
||||||
num_tokens: Optional[int] = None,
|
|
||||||
resolution_level: int = 9,
|
|
||||||
fov_x: Optional[Union[Number, torch.Tensor]] = None,
|
|
||||||
force_projection: bool = True,
|
|
||||||
apply_mask: bool = True) -> Dict[str, torch.Tensor]:
|
|
||||||
"""Run a single MoGe forward + post-process pass.
|
|
||||||
|
|
||||||
``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device.
|
|
||||||
Returns a dict with at least ``points``, ``depth``, ``intrinsics``,
|
|
||||||
``mask``; v2 checkpoints additionally produce ``normal``.
|
|
||||||
"""
|
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
|
||||||
device = self.load_device
|
|
||||||
image = image.to(device=device, dtype=self.dtype)
|
|
||||||
|
|
||||||
if image.dim() == 3:
|
|
||||||
image = image.unsqueeze(0)
|
|
||||||
H, W = image.shape[-2:]
|
|
||||||
aspect_ratio = W / H
|
|
||||||
|
|
||||||
if num_tokens is None:
|
|
||||||
lo, hi = self.num_tokens_range
|
|
||||||
num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
|
|
||||||
|
|
||||||
out = self.model.forward(image, num_tokens=num_tokens)
|
|
||||||
points = out.get("points")
|
|
||||||
normal = out.get("normal")
|
|
||||||
mask = out.get("mask")
|
|
||||||
metric_scale = out.get("metric_scale")
|
|
||||||
|
|
||||||
# Post-processing always runs in fp32 for numerical stability.
|
|
||||||
if points is not None: points = points.float()
|
|
||||||
if normal is not None: normal = normal.float()
|
|
||||||
if mask is not None: mask = mask.float()
|
|
||||||
if metric_scale is not None: metric_scale = metric_scale.float()
|
|
||||||
|
|
||||||
mask_binary = (mask > self.mask_threshold) if mask is not None else None
|
|
||||||
|
|
||||||
depth = None
|
|
||||||
intrinsics = None
|
|
||||||
if points is not None:
|
|
||||||
if fov_x is None:
|
|
||||||
focal, shift = recover_focal_shift(points, mask_binary)
|
|
||||||
# The unconstrained least-squares solver inside recover_focal_shift
|
|
||||||
# can converge to a degenerate solution where (z + shift) is
|
|
||||||
# negative for most pixels, which flips the sign of the
|
|
||||||
# estimated focal. Detect that and fall back to a sensible
|
|
||||||
# 60-degree-FoV default rather than emitting garbage geometry.
|
|
||||||
bad = ~torch.isfinite(focal) | (focal <= 0)
|
|
||||||
if bool(bad.any()):
|
|
||||||
default_fov = 60.0
|
|
||||||
fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype)
|
|
||||||
fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \
|
|
||||||
/ torch.tan(torch.deg2rad(fov_t / 2))
|
|
||||||
fallback_focal = fallback_focal.expand_as(focal).clone()
|
|
||||||
focal = torch.where(bad, fallback_focal, focal)
|
|
||||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
|
||||||
else:
|
|
||||||
fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype)
|
|
||||||
focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2))
|
|
||||||
if focal.ndim == 0:
|
|
||||||
focal = focal[None].expand(points.shape[0])
|
|
||||||
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
|
|
||||||
fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
|
|
||||||
fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
|
|
||||||
half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
|
|
||||||
intrinsics = intrinsics_from_focal_center(fx, fy, half, half)
|
|
||||||
points[..., 2] = points[..., 2] + shift[..., None, None]
|
|
||||||
# v2 upstream additionally filters mask by depth > 0 as a safeguard
|
|
||||||
# against negative-depth artifacts from the metric-scale path; v1
|
|
||||||
# does not, and applying it there can cut out the foreground when
|
|
||||||
# shift recovery overshoots slightly.
|
|
||||||
if mask_binary is not None and self.version == "v2":
|
|
||||||
mask_binary = mask_binary & (points[..., 2] > 0)
|
|
||||||
depth = points[..., 2].clone()
|
|
||||||
|
|
||||||
if force_projection and depth is not None and intrinsics is not None:
|
|
||||||
points = depth_map_to_point_map(depth, intrinsics=intrinsics)
|
|
||||||
|
|
||||||
if metric_scale is not None:
|
|
||||||
if points is not None:
|
|
||||||
points = points * metric_scale[:, None, None, None]
|
|
||||||
if depth is not None:
|
|
||||||
depth = depth * metric_scale[:, None, None]
|
|
||||||
|
|
||||||
if apply_mask and mask_binary is not None:
|
|
||||||
if points is not None:
|
|
||||||
points = torch.where(mask_binary[..., None], points,
|
|
||||||
torch.full_like(points, float("inf")))
|
|
||||||
if depth is not None:
|
|
||||||
depth = torch.where(mask_binary, depth,
|
|
||||||
torch.full_like(depth, float("inf")))
|
|
||||||
if normal is not None:
|
|
||||||
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"points": points,
|
|
||||||
"depth": depth,
|
|
||||||
"intrinsics": intrinsics,
|
|
||||||
"mask": mask_binary,
|
|
||||||
"normal": normal,
|
|
||||||
}
|
|
||||||
return {k: v for k, v in result.items() if v is not None}
|
|
||||||
|
|
||||||
|
|
||||||
def load(ckpt_path: str) -> MoGeModel:
|
|
||||||
"""Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`."""
|
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
|
||||||
if isinstance(sd, dict) and "model" in sd and "model_config" in sd:
|
|
||||||
sd = sd["model"]
|
|
||||||
return MoGeModel(sd)
|
|
||||||
@ -2,9 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -14,19 +11,21 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from comfy.ldm.moge.model import MoGeModel
|
from comfy.ldm.moge.model import MoGeModel
|
||||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||||
|
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
|
||||||
|
import comfy.model_management
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
MoGeModelType = io.Custom("MOGE_MODEL")
|
MoGeModelType = io.Custom("MOGE_MODEL")
|
||||||
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
# MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||||
class _MoGeGeometryPayload:
|
# "points": torch.Tensor (B, H, W, 3)
|
||||||
points: Optional[torch.Tensor] # (B, H, W, 3)
|
# "depth": torch.Tensor (B, H, W)
|
||||||
depth: Optional[torch.Tensor] # (B, H, W)
|
# "intrinsics": torch.Tensor (B, 3, 3) -- perspective only
|
||||||
intrinsics: Optional[torch.Tensor] # (B, 3, 3)
|
# "mask": torch.Tensor (B, H, W) bool
|
||||||
mask: Optional[torch.Tensor] # (B, H, W) bool
|
# "normal": torch.Tensor (B, H, W, 3) -- v2 only
|
||||||
normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1
|
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
|
||||||
image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU
|
|
||||||
|
|
||||||
|
|
||||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -137,15 +136,7 @@ class MoGePanoramaInference(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, moge_model, image, resolution_level,
|
def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
|
||||||
split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
|
|
||||||
from comfy.ldm.moge.panorama import (
|
|
||||||
get_panorama_cameras, split_panorama_image, merge_panorama_depth,
|
|
||||||
spherical_uv_to_directions, _uv_grid,
|
|
||||||
)
|
|
||||||
import comfy.model_management as cmm
|
|
||||||
import numpy as np
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
if image.shape[0] != 1:
|
if image.shape[0] != 1:
|
||||||
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
|
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
|
||||||
@ -156,7 +147,7 @@ class MoGePanoramaInference(io.ComfyNode):
|
|||||||
|
|
||||||
extrinsics, intrinsics = get_panorama_cameras()
|
extrinsics, intrinsics = get_panorama_cameras()
|
||||||
|
|
||||||
cmm.load_model_gpu(moge_model.patcher)
|
comfy.model_management.load_model_gpu(moge_model.patcher)
|
||||||
device = moge_model.load_device
|
device = moge_model.load_device
|
||||||
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
|
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
|
||||||
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
|
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
|
||||||
@ -218,34 +209,25 @@ class MoGePanoramaInference(io.ComfyNode):
|
|||||||
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
|
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
|
||||||
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
|
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
|
||||||
|
|
||||||
|
pano_depth = torch.from_numpy(pano_depth)
|
||||||
|
pano_mask = torch.from_numpy(pano_mask)
|
||||||
|
|
||||||
if (merge_h, merge_w) != (H, W):
|
if (merge_h, merge_w) != (H, W):
|
||||||
t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0)
|
pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze()
|
||||||
pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear",
|
pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0
|
||||||
align_corners=False).squeeze().numpy().astype(np.float32)
|
|
||||||
t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float()
|
|
||||||
pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0)
|
|
||||||
|
|
||||||
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve
|
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1)
|
||||||
# and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell
|
|
||||||
# woven through the foreground; here we lift them to a far skybox radius instead.
|
|
||||||
if pano_mask.any() and not pano_mask.all():
|
if pano_mask.any() and not pano_mask.all():
|
||||||
far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0
|
far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0
|
||||||
pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32)
|
pano_depth = torch.where(pano_mask, pano_depth, far)
|
||||||
|
|
||||||
uv = _uv_grid(H, W)
|
directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W)))
|
||||||
directions = spherical_uv_to_directions(uv)
|
points = (directions * pano_depth[..., None]).unsqueeze(0)
|
||||||
points_np = directions * pano_depth[..., None]
|
depth = pano_depth.unsqueeze(0)
|
||||||
|
mask = pano_mask.unsqueeze(0)
|
||||||
|
|
||||||
points = torch.from_numpy(points_np).unsqueeze(0).float()
|
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation
|
||||||
depth = torch.from_numpy(pano_depth).unsqueeze(0).float()
|
geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()}
|
||||||
mask = torch.from_numpy(pano_mask).unsqueeze(0)
|
|
||||||
|
|
||||||
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation
|
|
||||||
# after triangulation -- rotating before would scramble the rtol depth-edge check.
|
|
||||||
geometry = _MoGeGeometryPayload(
|
|
||||||
points=points, depth=depth, intrinsics=None, mask=mask, normal=None,
|
|
||||||
image=image.detach().cpu(),
|
|
||||||
)
|
|
||||||
return io.NodeOutput(geometry)
|
return io.NodeOutput(geometry)
|
||||||
|
|
||||||
|
|
||||||
@ -273,9 +255,7 @@ class MoGeInference(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, moge_model, image, resolution_level, fov_x_degrees,
|
def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput:
|
||||||
batch_size, force_projection, apply_mask) -> io.NodeOutput:
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
bchw = image.movedim(-1, -3).contiguous()
|
bchw = image.movedim(-1, -3).contiguous()
|
||||||
B = bchw.shape[0]
|
B = bchw.shape[0]
|
||||||
@ -295,20 +275,14 @@ class MoGeInference(io.ComfyNode):
|
|||||||
vals = [c[field] for c in chunks if field in c]
|
vals = [c[field] for c in chunks if field in c]
|
||||||
return torch.cat(vals, dim=0) if vals else None
|
return torch.cat(vals, dim=0) if vals else None
|
||||||
|
|
||||||
geometry = _MoGeGeometryPayload(
|
geometry = {"image": image.cpu()}
|
||||||
points=stack("points"),
|
for field in ("points", "depth", "intrinsics", "mask", "normal"):
|
||||||
depth=stack("depth"),
|
v = stack(field)
|
||||||
intrinsics=stack("intrinsics"),
|
if v is not None:
|
||||||
mask=stack("mask"),
|
geometry[field] = v
|
||||||
normal=stack("normal"),
|
|
||||||
image=image.detach().cpu(),
|
|
||||||
)
|
|
||||||
return io.NodeOutput(geometry)
|
return io.NodeOutput(geometry)
|
||||||
|
|
||||||
|
|
||||||
_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"]
|
|
||||||
|
|
||||||
|
|
||||||
class MoGeRender(io.ComfyNode):
|
class MoGeRender(io.ComfyNode):
|
||||||
"""Render a visualization or mask from a MOGE_GEOMETRY packet."""
|
"""Render a visualization or mask from a MOGE_GEOMETRY packet."""
|
||||||
|
|
||||||
@ -320,36 +294,32 @@ class MoGeRender(io.ComfyNode):
|
|||||||
category="image/geometry",
|
category="image/geometry",
|
||||||
inputs=[
|
inputs=[
|
||||||
MoGeGeometry.Input("geometry"),
|
MoGeGeometry.Input("geometry"),
|
||||||
io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"),
|
io.Combo.Input("output", options=["depth", "depth_colored", "normal", "normal_screen", "mask"], default="depth"),
|
||||||
],
|
],
|
||||||
outputs=[io.Image.Output()],
|
outputs=[io.Image.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, geometry, output) -> io.NodeOutput:
|
def execute(cls, geometry, output) -> io.NodeOutput:
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
# Pick the input tensor for the chosen mode and validate availability.
|
# Pick the input tensor for the chosen mode and validate availability.
|
||||||
if output in ("depth", "depth_colored", "normal_screen"):
|
if output in ("depth", "depth_colored", "normal_screen"):
|
||||||
if geometry.depth is None:
|
if "depth" not in geometry:
|
||||||
raise ValueError("MoGeGeometry has no depth output.")
|
raise ValueError("MoGeGeometry has no depth output.")
|
||||||
src = geometry.depth
|
src = geometry["depth"]
|
||||||
elif output == "normal":
|
elif output == "normal":
|
||||||
if geometry.normal is not None:
|
if "normal" in geometry:
|
||||||
src = geometry.normal
|
src = geometry["normal"]
|
||||||
elif geometry.points is not None:
|
elif "points" in geometry:
|
||||||
src = geometry.points
|
src = geometry["points"]
|
||||||
else:
|
else:
|
||||||
raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
|
raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
|
||||||
elif output == "mask":
|
elif output == "mask":
|
||||||
if geometry.mask is None:
|
if "mask" not in geometry:
|
||||||
raise ValueError("MoGeGeometry has no mask output.")
|
raise ValueError("MoGeGeometry has no mask output.")
|
||||||
src = geometry.mask
|
src = geometry["mask"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown output mode: {output}")
|
raise ValueError(f"Unknown output mode: {output}")
|
||||||
|
|
||||||
import comfy.model_management as cmm
|
|
||||||
|
|
||||||
B = src.shape[0]
|
B = src.shape[0]
|
||||||
pbar = comfy.utils.ProgressBar(B)
|
pbar = comfy.utils.ProgressBar(B)
|
||||||
out: list[torch.Tensor] = []
|
out: list[torch.Tensor] = []
|
||||||
@ -361,7 +331,7 @@ class MoGeRender(io.ComfyNode):
|
|||||||
out.append(_turbo(d) if output == "depth_colored"
|
out.append(_turbo(d) if output == "depth_colored"
|
||||||
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
|
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
|
||||||
elif output == "normal":
|
elif output == "normal":
|
||||||
n = slc if geometry.normal is not None else _normals_from_points(slc)
|
n = slc if "normal" in geometry else _normals_from_points(slc)
|
||||||
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
|
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
|
||||||
elif output == "normal_screen":
|
elif output == "normal_screen":
|
||||||
n = _screen_normals_from_depth(slc)
|
n = _screen_normals_from_depth(slc)
|
||||||
@ -370,7 +340,7 @@ class MoGeRender(io.ComfyNode):
|
|||||||
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
|
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
|
||||||
pbar.update_absolute(i + 1)
|
pbar.update_absolute(i + 1)
|
||||||
tq.update(1)
|
tq.update(1)
|
||||||
result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype())
|
result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return io.NodeOutput(result)
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
@ -385,7 +355,7 @@ class MoGePointMapToMesh(io.ComfyNode):
|
|||||||
category="3d",
|
category="3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
MoGeGeometry.Input("geometry"),
|
MoGeGeometry.Input("geometry"),
|
||||||
io.Int.Input("batch_index", default=0, min=0, max=64,
|
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||||
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
|
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
|
||||||
"differ, so batches can't be stacked into a single MESH."),
|
"differ, so batches can't be stacked into a single MESH."),
|
||||||
io.Int.Input("decimation", default=1, min=1, max=8,
|
io.Int.Input("decimation", default=1, min=1, max=8,
|
||||||
@ -400,32 +370,32 @@ class MoGePointMapToMesh(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
|
def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
|
||||||
if geometry.points is None:
|
if "points" not in geometry:
|
||||||
raise ValueError("MoGeGeometry has no points output.")
|
raise ValueError("MoGeGeometry has no points output.")
|
||||||
B = geometry.points.shape[0]
|
points = geometry["points"]
|
||||||
|
B = points.shape[0]
|
||||||
if batch_index >= B:
|
if batch_index >= B:
|
||||||
raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")
|
raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")
|
||||||
|
|
||||||
# Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas
|
# Pass depth so the rtol edge check sees radial depth -- for panoramas
|
||||||
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
|
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
|
||||||
edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None
|
edge_depth = geometry["depth"][batch_index] if "depth" in geometry else None
|
||||||
verts, faces, uvs = triangulate_grid_mesh(
|
verts, faces, uvs = triangulate_grid_mesh(
|
||||||
geometry.points[batch_index], decimation=decimation,
|
points[batch_index], decimation=decimation,
|
||||||
discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
|
discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
|
||||||
)
|
)
|
||||||
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||||
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
|
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
|
||||||
|
|
||||||
if geometry.intrinsics is None:
|
if "intrinsics" not in geometry:
|
||||||
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation
|
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back); a pure rotation preserves the natural inward winding (correct for inside-the-sphere viewing).
|
||||||
# preserves the natural inward winding (correct for inside-the-sphere viewing).
|
|
||||||
verts = verts[:, [1, 2, 0]].contiguous()
|
verts = verts[:, [1, 2, 0]].contiguous()
|
||||||
else:
|
else:
|
||||||
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
|
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
|
||||||
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||||
faces = faces[:, [0, 2, 1]].contiguous()
|
faces = faces[:, [0, 2, 1]].contiguous()
|
||||||
|
|
||||||
tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None
|
tex = geometry["image"][batch_index:batch_index + 1] if texture else None
|
||||||
mesh = Types.MESH(
|
mesh = Types.MESH(
|
||||||
vertices=verts.unsqueeze(0),
|
vertices=verts.unsqueeze(0),
|
||||||
faces=faces.unsqueeze(0),
|
faces=faces.unsqueeze(0),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user