"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration.""" from __future__ import annotations from dataclasses import dataclass from typing import Optional import torch import comfy.utils import folder_paths from comfy_api.latest import ComfyExtension, Types, io from typing_extensions import override from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.geometry import triangulate_grid_mesh MoGeModelType = io.Custom("MOGE_MODEL") MoGeGeometry = io.Custom("MOGE_GEOMETRY") @dataclass class _MoGeGeometryPayload: points: Optional[torch.Tensor] # (B, H, W, 3) depth: Optional[torch.Tensor] # (B, H, W) intrinsics: Optional[torch.Tensor] # (B, 3, 3) mask: Optional[torch.Tensor] # (B, H, W) bool normal: Optional[torch.Tensor] # (B, H, W, 3) or None for v1 image: torch.Tensor # (B, H, W, 3) in [0, 1], CPU def _turbo(x: torch.Tensor) -> torch.Tensor: """Anton Mikhailov polynomial approximation of the turbo colormap.""" x = x.clamp(0.0, 1.0) x2 = x * x x3 = x2 * x x4 = x2 * x2 x5 = x4 * x r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) def _normals_from_points(points: torch.Tensor) -> torch.Tensor: """Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback).""" finite = torch.isfinite(points).all(dim=-1) pts = torch.where(finite.unsqueeze(-1), points, torch.zeros_like(points)) dx = pts[..., :, 2:, :] - pts[..., :, :-2, :] dy = pts[..., 2:, :, :] - pts[..., :-2, :, :] dx = torch.nn.functional.pad(dx.permute(0, 3, 1, 2), (1, 1, 0, 0)).permute(0, 2, 3, 1) dy = torch.nn.functional.pad(dy.permute(0, 3, 1, 2), (0, 0, 1, 1)).permute(0, 2, 3, 1) n = torch.cross(dx, dy, dim=-1) n = torch.nn.functional.normalize(n, dim=-1) return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n)) def _screen_normals_from_depth(depth: torch.Tensor) -> torch.Tensor: """Screen-space surface normals (X right, Y down, Z into scene).""" finite = torch.isfinite(depth) & (depth > 0) d = torch.where(finite, depth, torch.zeros_like(depth)) H, W = d.shape[-2:] d4d = d.unsqueeze(1) # Scale gradients to normalized image coords so a 45 deg tilt lands as a 45 deg normal regardless of resolution. dz_dx = (d4d[..., :, 2:] - d4d[..., :, :-2]) * (W / 2.0) dz_dy = (d4d[..., 2:, :] - d4d[..., :-2, :]) * (H / 2.0) dz_dx = torch.nn.functional.pad(dz_dx, (1, 1, 0, 0)).squeeze(1) dz_dy = torch.nn.functional.pad(dz_dy, (0, 0, 1, 1)).squeeze(1) n = torch.stack([-dz_dx, -dz_dy, torch.ones_like(d)], dim=-1) n = torch.nn.functional.normalize(n, dim=-1) return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n)) def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor: """Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping.""" out = torch.zeros_like(depth) for i in range(depth.shape[0]): d = depth[i] valid = torch.isfinite(d) & (d > 0) if not valid.any(): continue disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d)) disp_valid = disp[valid] lo = torch.quantile(disp_valid, 0.001) hi = torch.quantile(disp_valid, 0.999) scale = (hi - lo).clamp_min(1e-6) norm = ((disp - lo) / scale).clamp(0.0, 1.0) out[i] = torch.where(valid, norm, torch.zeros_like(norm)) return out class LoadMoGeModel(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="LoadMoGeModel", display_name="Load MoGe Model", category="loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("moge")), ], outputs=[MoGeModelType.Output()], ) @classmethod def execute(cls, model_name) -> io.NodeOutput: path = folder_paths.get_full_path_or_raise("moge", model_name) sd = comfy.utils.load_torch_file(path, safe_load=True) return io.NodeOutput(MoGeModel(sd)) class MoGePanoramaInference(io.ComfyNode): """Equirectangular panorama inference: split into 12 perspective views, run MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve. v2's predicted normals and metric scale are ignored (per-view scales would not align across seams). """ @classmethod def define_schema(cls): return io.Schema( node_id="MoGePanoramaInference", display_name="MoGe Panorama Inference", category="image/geometry", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."), io.Int.Input("resolution_level", default=9, min=0, max=9, tooltip="Per-view detail (0 = fast, 9 = slow)."), io.Int.Input("split_resolution", default=512, min=256, max=1024, tooltip="Resolution of each perspective split."), io.Int.Input("merge_resolution", default=1920, min=256, max=8192, tooltip="Long-side resolution of the merged equirect distance map."), io.Int.Input("batch_size", default=4, min=1, max=12, tooltip="Views per inference batch (12 splits total)."), ], outputs=[MoGeGeometry.Output(display_name="geometry")], ) @classmethod def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput: from comfy.ldm.moge.panorama import ( get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid, ) import comfy.model_management as cmm import numpy as np from tqdm.auto import tqdm if image.shape[0] != 1: raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})") H, W = int(image.shape[1]), int(image.shape[2]) scale = min(merge_resolution / max(H, W), 1.0) merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32) extrinsics, intrinsics = get_panorama_cameras() cmm.load_model_gpu(moge_model.patcher) device = moge_model.load_device img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype) splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution) n_views = splits.shape[0] # Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle. merge_levels: list[tuple[int, int]] = [] w_, h_ = merge_w, merge_h while True: merge_levels.append((w_, h_)) if max(w_, h_) <= 256: break w_, h_ = w_ // 2, h_ // 2 merge_levels.reverse() solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)} n_merge_view_units = n_views * len(merge_levels) n_merge_solve_units = sum(solve_weight.values()) pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units) done = 0 distance_maps: list = [] masks: list = [] with tqdm(total=n_views, desc="MoGe panorama inference") as tq: for i in range(0, n_views, batch_size): batch = splits[i:i + batch_size] # apply_metric_scale=False: per-view scales would not align across overlap seams. result = moge_model.infer(batch, resolution_level=resolution_level, fov_x=90.0, force_projection=True, apply_mask=False, apply_metric_scale=False) distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy())) masks.extend(list(result["mask"].cpu().numpy())) n = batch.shape[0] done += n pbar.update_absolute(done) tq.update(n) with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq: def _on_merge_view(): nonlocal done done += 1 pbar.update_absolute(done) tq.update(1) def _on_solve_start(w, h): tq.set_description(f"MoGe panorama merge: solving {w}x{h}") def _on_solve_end(w, h): nonlocal done weight = solve_weight[(w, h)] done += weight pbar.update_absolute(done) tq.update(weight) tq.set_description("MoGe panorama merge: views") pano_depth, pano_mask = merge_panorama_depth( merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics, on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end) if (merge_h, merge_w) != (H, W): t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0) pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear", align_corners=False).squeeze().numpy().astype(np.float32) t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float() pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0) # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve # and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell # woven through the foreground; here we lift them to a far skybox radius instead. if pano_mask.any() and not pano_mask.all(): far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0 pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32) uv = _uv_grid(H, W) directions = spherical_uv_to_directions(uv) points_np = directions * pano_depth[..., None] points = torch.from_numpy(points_np).unsqueeze(0).float() depth = torch.from_numpy(pano_depth).unsqueeze(0).float() mask = torch.from_numpy(pano_mask).unsqueeze(0) # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation # after triangulation -- rotating before would scramble the rtol depth-edge check. geometry = _MoGeGeometryPayload( points=points, depth=depth, intrinsics=None, mask=mask, normal=None, image=image.detach().cpu(), ) return io.NodeOutput(geometry) class MoGeInference(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="MoGeInference", display_name="MoGe Inference", category="image/geometry", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image"), io.Int.Input("resolution_level", default=9, min=0, max=9, tooltip="0 = fastest, 9 = most detail."), io.Float.Input("fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1, tooltip="Override horizontal FoV. 0 = auto."), io.Int.Input("batch_size", default=4, min=1, max=64, tooltip="Images per inference call. Lower if you OOM on a long video / image set."), io.Boolean.Input("force_projection", default=True), io.Boolean.Input("apply_mask", default=True, tooltip="Set masked-out points/depth to inf."), ], outputs=[MoGeGeometry.Output(display_name="geometry")], ) @classmethod def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput: from tqdm.auto import tqdm bchw = image.movedim(-1, -3).contiguous() B = bchw.shape[0] fov = None if fov_x_degrees <= 0 else float(fov_x_degrees) pbar = comfy.utils.ProgressBar(B) chunks: list[dict] = [] with tqdm(total=B, desc="MoGe inference") as tq: for i in range(0, B, batch_size): chunk = bchw[i:i + batch_size] chunks.append(moge_model.infer(chunk, resolution_level=resolution_level, fov_x=fov, force_projection=force_projection, apply_mask=apply_mask)) pbar.update_absolute(min(i + batch_size, B)) tq.update(chunk.shape[0]) def stack(field): vals = [c[field] for c in chunks if field in c] return torch.cat(vals, dim=0) if vals else None geometry = _MoGeGeometryPayload( points=stack("points"), depth=stack("depth"), intrinsics=stack("intrinsics"), mask=stack("mask"), normal=stack("normal"), image=image.detach().cpu(), ) return io.NodeOutput(geometry) _RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"] class MoGeRender(io.ComfyNode): """Render a visualization or mask from a MOGE_GEOMETRY packet.""" @classmethod def define_schema(cls): return io.Schema( node_id="MoGeRender", display_name="MoGe Render", category="image/geometry", inputs=[ MoGeGeometry.Input("geometry"), io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored"), ], outputs=[io.Image.Output()], ) @classmethod def execute(cls, geometry, output) -> io.NodeOutput: from tqdm.auto import tqdm # Pick the input tensor for the chosen mode and validate availability. if output in ("depth", "depth_colored", "normal_screen"): if geometry.depth is None: raise ValueError("MoGeGeometry has no depth output.") src = geometry.depth elif output == "normal": if geometry.normal is not None: src = geometry.normal elif geometry.points is not None: src = geometry.points else: raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.") elif output == "mask": if geometry.mask is None: raise ValueError("MoGeGeometry has no mask output.") src = geometry.mask else: raise ValueError(f"Unknown output mode: {output}") import comfy.model_management as cmm B = src.shape[0] pbar = comfy.utils.ProgressBar(B) out: list[torch.Tensor] = [] with tqdm(total=B, desc=f"MoGe render: {output}") as tq: for i in range(B): slc = src[i:i + 1].float() if output in ("depth", "depth_colored"): d = _normalize_disparity(slc) out.append(_turbo(d) if output == "depth_colored" else d.unsqueeze(-1).expand(*d.shape, 3).contiguous()) elif output == "normal": n = slc if geometry.normal is not None else _normals_from_points(slc) out.append((n * 0.5 + 0.5).clamp(0.0, 1.0)) elif output == "normal_screen": n = _screen_normals_from_depth(slc) out.append((n * 0.5 + 0.5).clamp(0.0, 1.0)) elif output == "mask": out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous()) pbar.update_absolute(i + 1) tq.update(1) result = torch.cat(out, dim=0).to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype()) return io.NodeOutput(result) class MoGePointMapToMesh(io.ComfyNode): """Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture).""" @classmethod def define_schema(cls): return io.Schema( node_id="MoGePointMapToMesh", display_name="MoGe Point Map to Mesh", category="3d", inputs=[ MoGeGeometry.Input("geometry"), io.Int.Input("batch_index", default=0, min=0, max=64, tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts " "differ, so batches can't be stacked into a single MESH."), io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride; 1 = full resolution."), io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."), io.Boolean.Input("texture", default=True, tooltip="Carry the source image through as the baseColor texture."), ], outputs=[io.Mesh.Output()], ) @classmethod def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput: if geometry.points is None: raise ValueError("MoGeGeometry has no points output.") B = geometry.points.shape[0] if batch_index >= B: raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.") # Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half. edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None verts, faces, uvs = triangulate_grid_mesh( geometry.points[batch_index], decimation=decimation, discontinuity_threshold=discontinuity_threshold, depth=edge_depth, ) if verts.shape[0] == 0 or faces.shape[0] == 0: raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.") if geometry.intrinsics is None: # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation # preserves the natural inward winding (correct for inside-the-sphere viewing). verts = verts[:, [1, 2, 0]].contiguous() else: # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip. verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) faces = faces[:, [0, 2, 1]].contiguous() tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None mesh = Types.MESH( vertices=verts.unsqueeze(0), faces=faces.unsqueeze(0), uvs=uvs.unsqueeze(0), texture=tex, ) return io.NodeOutput(mesh) class MoGeExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh] async def comfy_entrypoint() -> MoGeExtension: return MoGeExtension()