From 6617e76b1c5d4242db678d62a0bc1b79b001e448 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 13 May 2026 21:24:49 +0300 Subject: [PATCH] Cleanup --- comfy/ldm/moge/__init__.py | 0 comfy/ldm/moge/geometry.py | 12 +-- comfy/ldm/moge/model.py | 10 +-- comfy/ldm/moge/panorama.py | 14 ++- comfy/ldm/moge/state_dict.py | 94 -------------------- comfy/moge.py | 163 ----------------------------------- comfy_extras/nodes_moge.py | 136 ++++++++++++----------------- 7 files changed, 70 insertions(+), 359 deletions(-) delete mode 100644 comfy/ldm/moge/__init__.py delete mode 100644 comfy/ldm/moge/state_dict.py delete mode 100644 comfy/moge.py diff --git a/comfy/ldm/moge/__init__.py b/comfy/ldm/moge/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/comfy/ldm/moge/geometry.py b/comfy/ldm/moge/geometry.py index 3174bd613..9612bd5af 100644 --- a/comfy/ldm/moge/geometry.py +++ b/comfy/ldm/moge/geometry.py @@ -86,10 +86,10 @@ def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float) -> float def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None, focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64) ) -> Tuple[torch.Tensor, torch.Tensor]: - """Recover the focal length and z-shift that turn ``points`` into a metric point map. + """Recover the focal length and z-shift that turn points into a metric point map. Optical center is at the image center; returned focal is relative to half the image diagonal. - Returns ``(focal, shift)`` on the same device/dtype as ``points``. + Returns (focal, shift) on the same device/dtype as points. 
""" shape = points.shape H, W = shape[-3], shape[-2] @@ -155,11 +155,11 @@ def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Opti def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04, depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Triangulate a (H, W, 3) point map into ``(vertices, faces, uvs)`` on CPU. + """Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU. - Vertices: pixels with finite coords (passing optional ``mask``). Quads with four valid corners - become two triangles. ``depth`` overrides the scalar used for the rtol edge check; pass radial - depth for panoramas (the default ``points[..., 2]`` goes negative below the equator). + Vertices: pixels with finite coords (passing optional mask). Quads with four valid corners + become two triangles. depth overrides the scalar used for the rtol edge check; pass radial + depth for panoramas (the default points[..., 2] goes negative below the equator). 
""" points = points.detach().cpu() finite = torch.isfinite(points).all(dim=-1) diff --git a/comfy/ldm/moge/model.py b/comfy/ldm/moge/model.py index fe340f5e1..34246a8d9 100644 --- a/comfy/ldm/moge/model.py +++ b/comfy/ldm/moge/model.py @@ -24,7 +24,7 @@ from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid def _remap_points(points: torch.Tensor) -> torch.Tensor: - """Apply the ``exp`` remap: z -> exp(z), xy stays linear and gets scaled by the new z.""" + """Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z.""" xy, z = points.split([2, 1], dim=-1) z = torch.exp(z) return torch.cat([xy * z, z], dim=-1) @@ -156,7 +156,7 @@ class MoGeModelV2(nn.Module): @classmethod def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast): - """Detect the v2 encoder/neck/heads config from ``sd``, build a model, and load weights.""" + """Detect the v2 encoder/neck/heads config from sd, build a model, and load weights.""" sd = _remap_state_dict(sd) backbone = _detect_dinov2(sd, prefix="encoder.backbone.") depth = backbone["num_hidden_layers"] @@ -184,7 +184,7 @@ class MoGeModelV2(nn.Module): @staticmethod def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]: - """Reconstruct a ConvStack config from the keys under ``prefix``""" + """Reconstruct a ConvStack config from the keys under prefix""" in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")] n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys) @@ -230,7 +230,6 @@ _DINOV2_BLOCK_RENAMES = [ def _remap_state_dict(sd: dict) -> dict: - """Unwrap the upstream ``{"model": ..., "model_config": ...}`` envelope and remap DINOv2 keys""" if "model" in sd and "model_config" in sd: sd = sd["model"] prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone." 
@@ -272,6 +271,7 @@ class MoGeModel: """Loaded MoGe model + ComfyUI memory management.""" def __init__(self, state_dict: dict): + # Reuse the text-encoder device/dtype policy; it is the closest match for this model self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) @@ -288,7 +288,7 @@ class MoGeModel: force_projection: bool = True, apply_mask: bool = True, apply_metric_scale: bool = True ) -> Dict[str, torch.Tensor]: - """Run a single MoGe forward + post-process pass. ``image`` is (B, 3, H, W) in [0, 1].""" + """Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1].""" comfy.model_management.load_model_gpu(self.patcher) image = image.to(device=self.load_device, dtype=self.dtype) H, W = image.shape[-2:] diff --git a/comfy/ldm/moge/panorama.py b/comfy/ldm/moge/panorama.py index 76b2a4daf..de53ebe68 100644 --- a/comfy/ldm/moge/panorama.py +++ b/comfy/ldm/moge/panorama.py @@ -3,7 +3,7 @@ Splits an equirect into 12 perspective views via an icosahedron camera rig, runs the model per view, and stitches per-view distance maps back into a single equirect distance map via a multi-scale Poisson + gradient sparse solve. -Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU). +Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
""" from __future__ import annotations @@ -14,6 +14,10 @@ import numpy as np import torch import torch.nn.functional as F +from scipy.ndimage import convolve, map_coordinates +from scipy.sparse import vstack, csr_array +from scipy.sparse.linalg import lsmr + def _icosahedron_directions() -> np.ndarray: """12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order).""" @@ -122,7 +126,7 @@ def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarr def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor: - """Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border.""" + """Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border.""" grid = uv * 2.0 - 1.0 return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False) @@ -145,7 +149,6 @@ def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): """Sparse Laplacian operator over the H x W grid.""" - from scipy.sparse import csr_array grid_index = np.arange(H * W).reshape(H, W) grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge") grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge") @@ -162,7 +165,6 @@ def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): """Sparse forward-difference operator over the H x W grid.""" - from scipy.sparse import csr_array grid_index = np.arange(W * H).reshape(H, W) if wrap_x: grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap") @@ -191,7 +193,6 @@ def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False): def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray: 
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border.""" - from scipy.ndimage import map_coordinates H, W = img.shape[:2] yy = np.clip(sample_pixels[..., 1], 0, H - 1) xx = np.clip(sample_pixels[..., 0], 0, W - 1) @@ -218,9 +219,6 @@ def merge_panorama_depth(width: int, height: int, for the full-resolution solve. Optional callbacks fire per view processed and around each lsmr solve so callers can drive a progress bar. """ - from scipy.ndimage import convolve - from scipy.sparse import vstack - from scipy.sparse.linalg import lsmr if max(width, height) > 256: coarse_depth, _ = merge_panorama_depth(width // 2, height // 2, diff --git a/comfy/ldm/moge/state_dict.py b/comfy/ldm/moge/state_dict.py deleted file mode 100644 index a241d720e..000000000 --- a/comfy/ldm/moge/state_dict.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Translate MoGe checkpoint keys to the layouts our nn.Modules use. - -MoGe checkpoints embed DINOv2 with the original Meta naming -(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...). -The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming -(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``, -``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load -time and split the fused ``qkv`` weight into separate Q/K/V tensors. 
-""" - -from __future__ import annotations - -import re - - -_DINOV2_TOPLEVEL_RENAMES = { - "patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight", - "patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias", - "cls_token": "embeddings.cls_token", - "pos_embed": "embeddings.position_embeddings", - "register_tokens": "embeddings.register_tokens", - "mask_token": "embeddings.mask_token", - "norm.weight": "layernorm.weight", - "norm.bias": "layernorm.bias", -} - -_BLOCK_SUFFIX_RENAMES = [ - ("ls1.gamma", "layer_scale1.lambda1"), - ("ls2.gamma", "layer_scale2.lambda1"), - ("attn.proj.", "attention.output.dense."), - ("mlp.w12.", "mlp.weights_in."), - ("mlp.w3.", "mlp.weights_out."), -] - -_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$") - - -def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict: - """Rewrite Meta-style DINOv2 keys under ``src_prefix`` to comfy/HF naming. - - Splits each fused ``attn.qkv.{weight,bias}`` into separate - ``attention.attention.{query,key,value}.{weight,bias}`` tensors using a - chunk along the leading dim. - - Keys that do not start with ``src_prefix`` are returned unchanged. - """ - out: dict = {} - for k, v in sd.items(): - if not k.startswith(src_prefix): - out[k] = v - continue - rel = k[len(src_prefix):] - - # Top-level (cls token, pos embed, patch embed, mask token, register tokens, final norm). - if rel in _DINOV2_TOPLEVEL_RENAMES: - out[src_prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v - continue - - m = _BLOCK_RE.match(rel) - if not m: - out[k] = v - continue - - i, sub = m.group(1), m.group(2) - - # Split fused qkv into separate q / k / v tensors. 
- if sub == "attn.qkv.weight" or sub == "attn.qkv.bias": - q, kw, vw = v.chunk(3, dim=0) - tail = sub.rsplit(".", 1)[1] # weight / bias - base = "{}encoder.layer.{}.attention.attention".format(src_prefix, i) - out["{}.query.{}".format(base, tail)] = q - out["{}.key.{}".format(base, tail)] = kw - out["{}.value.{}".format(base, tail)] = vw - continue - - for old, new in _BLOCK_SUFFIX_RENAMES: - sub = sub.replace(old, new) - out["{}encoder.layer.{}.{}".format(src_prefix, i, sub)] = v - - return out - - -def remap_moge_state_dict(sd: dict) -> dict: - """Convert a full MoGe checkpoint state dict to the layout our modules expect. - - - v1 backbone lives under ``backbone.`` -> rewrite that subtree. - - v2 backbone lives under ``encoder.backbone.`` -> rewrite that subtree. - - Everything else (heads, neck, projections, image_mean/std buffers) keeps - its original key names and passes through unchanged. - """ - if any(k.startswith("encoder.backbone.") for k in sd): - return remap_dinov2_keys(sd, src_prefix="encoder.backbone.") - return remap_dinov2_keys(sd, src_prefix="backbone.") diff --git a/comfy/moge.py b/comfy/moge.py deleted file mode 100644 index a04d9f149..000000000 --- a/comfy/moge.py +++ /dev/null @@ -1,163 +0,0 @@ -"""High-level loader and inference wrapper for MoGe v1 / v2 checkpoints. - -Mirrors the structure of :mod:`comfy.clip_vision`: owns the ``nn.Module`` and -a :class:`comfy.model_patcher.CoreModelPatcher`, exposes a -:meth:`MoGeModel.infer` that runs preprocessing, forward, and post-processing. 
-""" - -from __future__ import annotations - -from numbers import Number -from typing import Dict, Optional, Union - -import torch - -import comfy.model_management -import comfy.model_patcher -import comfy.ops -import comfy.utils - -from .ldm.moge.geometry import ( - depth_map_to_point_map, - intrinsics_from_focal_center, - recover_focal_shift, -) -from .ldm.moge.model import detect_and_build -from .ldm.moge.state_dict import remap_moge_state_dict - - -class MoGeModel: - """Loaded MoGe model + ComfyUI memory management.""" - - def __init__(self, state_dict: dict): - self.load_device = comfy.model_management.text_encoder_device() - offload_device = comfy.model_management.text_encoder_offload_device() - self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - - sd = remap_moge_state_dict(state_dict) - self.model = detect_and_build(sd, dtype=self.dtype, device=offload_device, - operations=comfy.ops.manual_cast) - self.model.load_state_dict(sd, strict=True) - self.model.eval() - self.patcher = comfy.model_patcher.CoreModelPatcher( - self.model, load_device=self.load_device, offload_device=offload_device - ) - self.version = "v2" if hasattr(self.model, "encoder") else "v1" - self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5)) - nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600)) - self.num_tokens_range = (int(nt[0]), int(nt[1])) - - @torch.inference_mode() - def infer(self, - image: torch.Tensor, - num_tokens: Optional[int] = None, - resolution_level: int = 9, - fov_x: Optional[Union[Number, torch.Tensor]] = None, - force_projection: bool = True, - apply_mask: bool = True) -> Dict[str, torch.Tensor]: - """Run a single MoGe forward + post-process pass. - - ``image`` must already be ``(B, 3, H, W)`` in ``[0, 1]`` on any device. - Returns a dict with at least ``points``, ``depth``, ``intrinsics``, - ``mask``; v2 checkpoints additionally produce ``normal``. 
- """ - comfy.model_management.load_model_gpu(self.patcher) - device = self.load_device - image = image.to(device=device, dtype=self.dtype) - - if image.dim() == 3: - image = image.unsqueeze(0) - H, W = image.shape[-2:] - aspect_ratio = W / H - - if num_tokens is None: - lo, hi = self.num_tokens_range - num_tokens = int(lo + (resolution_level / 9) * (hi - lo)) - - out = self.model.forward(image, num_tokens=num_tokens) - points = out.get("points") - normal = out.get("normal") - mask = out.get("mask") - metric_scale = out.get("metric_scale") - - # Post-processing always runs in fp32 for numerical stability. - if points is not None: points = points.float() - if normal is not None: normal = normal.float() - if mask is not None: mask = mask.float() - if metric_scale is not None: metric_scale = metric_scale.float() - - mask_binary = (mask > self.mask_threshold) if mask is not None else None - - depth = None - intrinsics = None - if points is not None: - if fov_x is None: - focal, shift = recover_focal_shift(points, mask_binary) - # The unconstrained least-squares solver inside recover_focal_shift - # can converge to a degenerate solution where (z + shift) is - # negative for most pixels, which flips the sign of the - # estimated focal. Detect that and fall back to a sensible - # 60-degree-FoV default rather than emitting garbage geometry. 
- bad = ~torch.isfinite(focal) | (focal <= 0) - if bool(bad.any()): - default_fov = 60.0 - fov_t = torch.as_tensor(default_fov, device=points.device, dtype=points.dtype) - fallback_focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 \ - / torch.tan(torch.deg2rad(fov_t / 2)) - fallback_focal = fallback_focal.expand_as(focal).clone() - focal = torch.where(bad, fallback_focal, focal) - _, shift = recover_focal_shift(points, mask_binary, focal=focal) - else: - fov_t = torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) - focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(fov_t / 2)) - if focal.ndim == 0: - focal = focal[None].expand(points.shape[0]) - _, shift = recover_focal_shift(points, mask_binary, focal=focal) - fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio - fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 - half = torch.tensor(0.5, device=points.device, dtype=points.dtype) - intrinsics = intrinsics_from_focal_center(fx, fy, half, half) - points[..., 2] = points[..., 2] + shift[..., None, None] - # v2 upstream additionally filters mask by depth > 0 as a safeguard - # against negative-depth artifacts from the metric-scale path; v1 - # does not, and applying it there can cut out the foreground when - # shift recovery overshoots slightly. 
- if mask_binary is not None and self.version == "v2": - mask_binary = mask_binary & (points[..., 2] > 0) - depth = points[..., 2].clone() - - if force_projection and depth is not None and intrinsics is not None: - points = depth_map_to_point_map(depth, intrinsics=intrinsics) - - if metric_scale is not None: - if points is not None: - points = points * metric_scale[:, None, None, None] - if depth is not None: - depth = depth * metric_scale[:, None, None] - - if apply_mask and mask_binary is not None: - if points is not None: - points = torch.where(mask_binary[..., None], points, - torch.full_like(points, float("inf"))) - if depth is not None: - depth = torch.where(mask_binary, depth, - torch.full_like(depth, float("inf"))) - if normal is not None: - normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) - - result = { - "points": points, - "depth": depth, - "intrinsics": intrinsics, - "mask": mask_binary, - "normal": normal, - } - return {k: v for k, v in result.items() if v is not None} - - -def load(ckpt_path: str) -> MoGeModel: - """Load a MoGe ``.pt`` / ``.safetensors`` checkpoint into a :class:`MoGeModel`.""" - sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True) - if isinstance(sd, dict) and "model" in sd and "model_config" in sd: - sd = sd["model"] - return MoGeModel(sd) diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 9667aa30d..436ca27ea 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -2,9 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Optional - import torch import comfy.utils @@ -14,19 +11,21 @@ from typing_extensions import override from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.geometry import triangulate_grid_mesh +from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid +import comfy.model_management +from tqdm.auto 
import tqdm MoGeModelType = io.Custom("MOGE_MODEL") MoGeGeometry = io.Custom("MOGE_GEOMETRY") -@dataclass -class _MoGeGeometryPayload: - points: Optional[torch.Tensor] # (B, H, W, 3) - depth: Optional[torch.Tensor] # (B, H, W) - intrinsics: Optional[torch.Tensor] # (B, 3, 3) - mask: Optional[torch.Tensor] # (B, H, W) bool - normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1 - image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU +# MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): +# "points": torch.Tensor (B, H, W, 3) +# "depth": torch.Tensor (B, H, W) +# "intrinsics": torch.Tensor (B, 3, 3) -- perspective only +# "mask": torch.Tensor (B, H, W) bool +# "normal": torch.Tensor (B, H, W, 3) -- v2 only +# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present) def _turbo(x: torch.Tensor) -> torch.Tensor: @@ -137,15 +136,7 @@ class MoGePanoramaInference(io.ComfyNode): ) @classmethod - def execute(cls, moge_model, image, resolution_level, - split_resolution, merge_resolution, batch_size) -> io.NodeOutput: - from comfy.ldm.moge.panorama import ( - get_panorama_cameras, split_panorama_image, merge_panorama_depth, - spherical_uv_to_directions, _uv_grid, - ) - import comfy.model_management as cmm - import numpy as np - from tqdm.auto import tqdm + def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput: if image.shape[0] != 1: raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})") @@ -156,7 +147,7 @@ class MoGePanoramaInference(io.ComfyNode): extrinsics, intrinsics = get_panorama_cameras() - cmm.load_model_gpu(moge_model.patcher) + comfy.model_management.load_model_gpu(moge_model.patcher) device = moge_model.load_device img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype) splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution) @@ -218,34 +209,25 @@ class 
MoGePanoramaInference(io.ComfyNode): merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics, on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end) + pano_depth = torch.from_numpy(pano_depth) + pano_mask = torch.from_numpy(pano_mask) + if (merge_h, merge_w) != (H, W): - t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0) - pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear", - align_corners=False).squeeze().numpy().astype(np.float32) - t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float() - pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0) + pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze() + pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0 - # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve - # and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell - # woven through the foreground; here we lift them to a far skybox radius instead. 
+ # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1); push them out to a far skybox radius instead if pano_mask.any() and not pano_mask.all(): - far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0 - pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32) + far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0 + pano_depth = torch.where(pano_mask, pano_depth, far) - uv = _uv_grid(H, W) - directions = spherical_uv_to_directions(uv) - points_np = directions * pano_depth[..., None] + directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W))) + points = (directions * pano_depth[..., None]).unsqueeze(0) + depth = pano_depth.unsqueeze(0) + mask = pano_mask.unsqueeze(0) - points = torch.from_numpy(points_np).unsqueeze(0).float() - depth = torch.from_numpy(pano_depth).unsqueeze(0).float() - mask = torch.from_numpy(pano_mask).unsqueeze(0) - - # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation - # after triangulation -- rotating before would scramble the rtol depth-edge check.
- geometry = _MoGeGeometryPayload( - points=points, depth=depth, intrinsics=None, mask=mask, normal=None, - image=image.detach().cpu(), - ) + # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation + geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()} return io.NodeOutput(geometry) @@ -273,9 +255,7 @@ class MoGeInference(io.ComfyNode): ) @classmethod - def execute(cls, moge_model, image, resolution_level, fov_x_degrees, - batch_size, force_projection, apply_mask) -> io.NodeOutput: - from tqdm.auto import tqdm + def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput: bchw = image.movedim(-1, -3).contiguous() B = bchw.shape[0] @@ -295,20 +275,14 @@ class MoGeInference(io.ComfyNode): vals = [c[field] for c in chunks if field in c] return torch.cat(vals, dim=0) if vals else None - geometry = _MoGeGeometryPayload( - points=stack("points"), - depth=stack("depth"), - intrinsics=stack("intrinsics"), - mask=stack("mask"), - normal=stack("normal"), - image=image.detach().cpu(), - ) + geometry = {"image": image.cpu()} + for field in ("points", "depth", "intrinsics", "mask", "normal"): + v = stack(field) + if v is not None: + geometry[field] = v return io.NodeOutput(geometry) -_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"] - - class MoGeRender(io.ComfyNode): """Render a visualization or mask from a MOGE_GEOMETRY packet.""" @@ -320,36 +294,32 @@ class MoGeRender(io.ComfyNode): category="image/geometry", inputs=[ MoGeGeometry.Input("geometry"), - io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"), + io.Combo.Input("output", options=["depth", "depth_colored", "normal", "normal_screen", "mask"], default="depth"), ], outputs=[io.Image.Output()], ) @classmethod def execute(cls, geometry, output) -> io.NodeOutput: - from tqdm.auto import tqdm - # Pick the input tensor for 
the chosen mode and validate availability. if output in ("depth", "depth_colored", "normal_screen"): - if geometry.depth is None: + if "depth" not in geometry: raise ValueError("MoGeGeometry has no depth output.") - src = geometry.depth + src = geometry["depth"] elif output == "normal": - if geometry.normal is not None: - src = geometry.normal - elif geometry.points is not None: - src = geometry.points + if "normal" in geometry: + src = geometry["normal"] + elif "points" in geometry: + src = geometry["points"] else: raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.") elif output == "mask": - if geometry.mask is None: + if "mask" not in geometry: raise ValueError("MoGeGeometry has no mask output.") - src = geometry.mask + src = geometry["mask"] else: raise ValueError(f"Unknown output mode: {output}") - import comfy.model_management as cmm - B = src.shape[0] pbar = comfy.utils.ProgressBar(B) out: list[torch.Tensor] = [] @@ -361,7 +331,7 @@ class MoGeRender(io.ComfyNode): out.append(_turbo(d) if output == "depth_colored" else d.unsqueeze(-1).expand(*d.shape, 3).contiguous()) elif output == "normal": - n = slc if geometry.normal is not None else _normals_from_points(slc) + n = slc if "normal" in geometry else _normals_from_points(slc) out.append((n * 0.5 + 0.5).clamp(0.0, 1.0)) elif output == "normal_screen": n = _screen_normals_from_depth(slc) @@ -370,7 +340,7 @@ class MoGeRender(io.ComfyNode): out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous()) pbar.update_absolute(i + 1) tq.update(1) - result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype()) + result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return io.NodeOutput(result) @@ -385,7 +355,7 @@ class MoGePointMapToMesh(io.ComfyNode): category="3d", inputs=[ MoGeGeometry.Input("geometry"), - io.Int.Input("batch_index", default=0, min=0, max=64, 
+ io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts " "differ, so batches can't be stacked into a single MESH."), io.Int.Input("decimation", default=1, min=1, max=8, @@ -400,32 +370,32 @@ class MoGePointMapToMesh(io.ComfyNode): @classmethod def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput: - if geometry.points is None: + if "points" not in geometry: raise ValueError("MoGeGeometry has no points output.") - B = geometry.points.shape[0] + points = geometry["points"] + B = points.shape[0] if batch_index >= B: raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.") - # Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas + # Pass depth so the rtol edge check sees radial depth -- for panoramas # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half. - edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None + edge_depth = geometry["depth"][batch_index] if "depth" in geometry else None verts, faces, uvs = triangulate_grid_mesh( - geometry.points[batch_index], decimation=decimation, + points[batch_index], decimation=decimation, discontinuity_threshold=discontinuity_threshold, depth=edge_depth, ) if verts.shape[0] == 0 or faces.shape[0] == 0: raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.") - if geometry.intrinsics is None: - # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation - # preserves the natural inward winding (correct for inside-the-sphere viewing). 
+ if "intrinsics" not in geometry: + # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back); a pure rotation preserves the inward winding (correct for inside-the-sphere viewing) verts = verts[:, [1, 2, 0]].contiguous() else: # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip. verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) faces = faces[:, [0, 2, 1]].contiguous() - tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None + tex = geometry["image"][batch_index:batch_index + 1] if texture else None mesh = Types.MESH( vertices=verts.unsqueeze(0), faces=faces.unsqueeze(0),