"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration.""" from __future__ import annotations import torch import comfy.utils import folder_paths from comfy_api.latest import ComfyExtension, Types, io from typing_extensions import override from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.geometry import triangulate_grid_mesh from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid import comfy.model_management from tqdm.auto import tqdm MoGeModelType = io.Custom("MOGE_MODEL") MoGeGeometry = io.Custom("MOGE_GEOMETRY") # MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): # "points": torch.Tensor (B, H, W, 3) # "depth": torch.Tensor (B, H, W) # "intrinsics": torch.Tensor (B, 3, 3) -- perspective only # "mask": torch.Tensor (B, H, W) bool # "normal": torch.Tensor (B, H, W, 3) -- v2 only # "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present) def _turbo(x: torch.Tensor) -> torch.Tensor: """Anton Mikhailov polynomial approximation of the turbo colormap.""" x = x.clamp(0.0, 1.0) x2 = x * x x3 = x2 * x x4 = x2 * x2 x5 = x4 * x r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) def _normals_from_points(points: torch.Tensor) -> torch.Tensor: """Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback).""" finite = torch.isfinite(points).all(dim=-1) pts = torch.where(finite.unsqueeze(-1), points, torch.zeros_like(points)) dx = pts[..., :, 2:, :] - pts[..., :, :-2, :] dy = pts[..., 2:, :, :] - pts[..., :-2, :, :] dx = torch.nn.functional.pad(dx.permute(0, 3, 1, 2), (1, 1, 0, 0)).permute(0, 2, 3, 1) dy = torch.nn.functional.pad(dy.permute(0, 3, 1, 2), (0, 0, 1, 1)).permute(0, 2, 3, 1) # dy x dx (not dx x dy) so the result is outward-facing in OpenCV (Y-down flips the right-hand rule), matching v2's predicted normals. n = torch.cross(dy, dx, dim=-1) n = torch.nn.functional.normalize(n, dim=-1) return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n)) def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor: """Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping.""" out = torch.zeros_like(depth) for i in range(depth.shape[0]): d = depth[i] valid = torch.isfinite(d) & (d > 0) if not valid.any(): continue disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d)) disp_valid = disp[valid] lo = torch.quantile(disp_valid, 0.001) hi = torch.quantile(disp_valid, 0.999) scale = (hi - lo).clamp_min(1e-6) norm = ((disp - lo) / scale).clamp(0.0, 1.0) out[i] = torch.where(valid, norm, torch.zeros_like(norm)) return out class LoadMoGeModel(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="LoadMoGeModel", display_name="Load MoGe Model", category="loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("moge")), ], outputs=[MoGeModelType.Output()], ) @classmethod def execute(cls, model_name) -> io.NodeOutput: path = folder_paths.get_full_path_or_raise("moge", model_name) sd = comfy.utils.load_torch_file(path, safe_load=True) return io.NodeOutput(MoGeModel(sd)) class MoGePanoramaInference(io.ComfyNode): """Equirectangular panorama inference: split into 12 perspective views, run MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve. v2's predicted normals and metric scale are ignored (per-view scales would not align across seams). """ @classmethod def define_schema(cls): return io.Schema( node_id="MoGePanoramaInference", display_name="MoGe Panorama Inference", category="image/geometry", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."), io.Int.Input("resolution_level", default=9, min=0, max=9, tooltip="Per-view detail (0 = fast, 9 = slow)."), io.Int.Input("split_resolution", default=512, min=256, max=1024, tooltip="Resolution of each perspective split."), io.Int.Input("merge_resolution", default=1920, min=256, max=8192, tooltip="Long-side resolution of the merged equirect distance map."), io.Int.Input("batch_size", default=4, min=1, max=12, tooltip="Views per inference batch (12 splits total)."), ], outputs=[MoGeGeometry.Output(display_name="geometry")], ) @classmethod def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput: if image.shape[0] != 1: raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})") image = image[..., :3] H, W = int(image.shape[1]), int(image.shape[2]) scale = min(merge_resolution / max(H, W), 1.0) merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32) extrinsics, intrinsics = get_panorama_cameras() comfy.model_management.load_model_gpu(moge_model.patcher) device = moge_model.load_device img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype) splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution) n_views = splits.shape[0] # Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle. merge_levels: list[tuple[int, int]] = [] w_, h_ = merge_w, merge_h while True: merge_levels.append((w_, h_)) if max(w_, h_) <= 256: break w_, h_ = w_ // 2, h_ // 2 merge_levels.reverse() solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)} n_merge_view_units = n_views * len(merge_levels) n_merge_solve_units = sum(solve_weight.values()) pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units) done = 0 distance_maps: list = [] masks: list = [] with tqdm(total=n_views, desc="MoGe panorama inference") as tq: for i in range(0, n_views, batch_size): batch = splits[i:i + batch_size] # apply_metric_scale=False: per-view scales would not align across overlap seams. result = moge_model.infer(batch, resolution_level=resolution_level, fov_x=90.0, force_projection=True, apply_mask=False, apply_metric_scale=False) distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy())) masks.extend(list(result["mask"].cpu().numpy())) n = batch.shape[0] done += n pbar.update_absolute(done) tq.update(n) with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq: def _on_merge_view(): nonlocal done done += 1 pbar.update_absolute(done) tq.update(1) def _on_solve_start(w, h): tq.set_description(f"MoGe panorama merge: solving {w}x{h}") def _on_solve_end(w, h): nonlocal done weight = solve_weight[(w, h)] done += weight pbar.update_absolute(done) tq.update(weight) tq.set_description("MoGe panorama merge: views") pano_depth, pano_mask = merge_panorama_depth( merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics, on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end) pano_depth = torch.from_numpy(pano_depth) pano_mask = torch.from_numpy(pano_mask) if (merge_h, merge_w) != (H, W): pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze() pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0 # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1) if pano_mask.any() and not pano_mask.all(): far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0 pano_depth = torch.where(pano_mask, pano_depth, far) directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W))) points = (directions * pano_depth[..., None]).unsqueeze(0) depth = pano_depth.unsqueeze(0) mask = pano_mask.unsqueeze(0) # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()} return io.NodeOutput(geometry) class MoGeInference(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="MoGeInference", display_name="MoGe Inference", category="image/geometry", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image"), io.Int.Input("resolution_level", default=9, min=0, max=9, tooltip="0 = fastest, 9 = most detail."), io.Float.Input("fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1, tooltip="Horizontal field of view of the source camera. Sets the focal length used to unproject the depth map into 3D. 0 = auto-recover from the predicted points."), io.Int.Input("batch_size", default=4, min=1, max=64, tooltip="Images per inference call. Lower if you OOM on a long video / image set."), io.Boolean.Input("force_projection", default=True), io.Boolean.Input("apply_mask", default=True, tooltip="Set masked-out (sky / invalid) pixels to inf in points and depth so meshing culls them. Disable to keep the raw predicted geometry everywhere; the mask is still returned separately."), ], outputs=[MoGeGeometry.Output(display_name="geometry")], ) @classmethod def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput: image = image[..., :3] bchw = image.movedim(-1, -3).contiguous() B = bchw.shape[0] fov = None if fov_x_degrees <= 0 else float(fov_x_degrees) pbar = comfy.utils.ProgressBar(B) chunks: list[dict] = [] with tqdm(total=B, desc="MoGe inference") as tq: for i in range(0, B, batch_size): chunk = bchw[i:i + batch_size] chunks.append(moge_model.infer(chunk, resolution_level=resolution_level, fov_x=fov, force_projection=force_projection, apply_mask=apply_mask)) pbar.update_absolute(min(i + batch_size, B)) tq.update(chunk.shape[0]) def stack(field): vals = [c[field] for c in chunks if field in c] return torch.cat(vals, dim=0) if vals else None geometry = {"image": image.cpu()} for field in ("points", "depth", "intrinsics", "mask", "normal"): v = stack(field) if v is not None: geometry[field] = v return io.NodeOutput(geometry) class MoGeRender(io.ComfyNode): """Render a visualization or mask from a MOGE_GEOMETRY packet.""" @classmethod def define_schema(cls): return io.Schema( node_id="MoGeRender", display_name="MoGe Render", category="image/geometry", inputs=[ MoGeGeometry.Input("geometry"), io.Combo.Input("output", options=["depth", "depth_colored", "normal_opengl", "normal_directx", "mask"], default="depth", tooltip="DirectX vs OpenGL controls the normal-map green-channel convention. DirectX: green = -Y down (Unreal). OpenGL: green = +Y up (Blender, Substance, Unity, glTF)."), ], outputs=[io.Image.Output()], ) @classmethod def execute(cls, geometry, output) -> io.NodeOutput: is_normal = output in ("normal_directx", "normal_opengl") opengl = output.endswith("_opengl") # Pick the input tensor for the chosen mode and validate availability. if output in ("depth", "depth_colored"): if "depth" not in geometry: raise ValueError("MoGeGeometry has no depth output.") src = geometry["depth"] elif is_normal: if "normal" in geometry: src = geometry["normal"] elif "points" in geometry: src = geometry["points"] else: raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.") elif output == "mask": if "mask" not in geometry: raise ValueError("MoGeGeometry has no mask output.") src = geometry["mask"] else: raise ValueError(f"Unknown output mode: {output}") B = src.shape[0] pbar = comfy.utils.ProgressBar(B) out: list[torch.Tensor] = [] with tqdm(total=B, desc=f"MoGe render: {output}") as tq: for i in range(B): slc = src[i:i + 1].float() if output in ("depth", "depth_colored"): d = _normalize_disparity(slc) out.append(_turbo(d) if output == "depth_colored" else d.unsqueeze(-1).expand(*d.shape, 3).contiguous()) elif is_normal: n = slc if "normal" in geometry else _normals_from_points(slc) # MoGe is OpenCV (Z+ into scene); normal-map convention is Z+ out of surface, so flip Z. y_sign = -1.0 if opengl else 1.0 n = n * n.new_tensor([1.0, y_sign, -1.0]) out.append((n * 0.5 + 0.5).clamp(0.0, 1.0)) elif output == "mask": out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous()) pbar.update_absolute(i + 1) tq.update(1) result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return io.NodeOutput(result) class MoGePointMapToMesh(io.ComfyNode): """Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture).""" @classmethod def define_schema(cls): return io.Schema( node_id="MoGePointMapToMesh", display_name="MoGe Point Map to Mesh", category="3d", inputs=[ MoGeGeometry.Input("geometry"), io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts " "differ, so batches can't be stacked into a single MESH."), io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride; 1 = full resolution."), io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."), io.Boolean.Input("texture", default=True, tooltip="Carry the source image through as the baseColor texture."), ], outputs=[io.Mesh.Output()], ) @classmethod def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput: if "points" not in geometry: raise ValueError("MoGeGeometry has no points output.") points = geometry["points"] B = points.shape[0] if batch_index >= B: raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.") # Pass depth so the rtol edge check sees radial depth -- for panoramas # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half. edge_depth = geometry["depth"][batch_index] if "depth" in geometry else None verts, faces, uvs = triangulate_grid_mesh( points[batch_index], decimation=decimation, discontinuity_threshold=discontinuity_threshold, depth=edge_depth, ) if verts.shape[0] == 0 or faces.shape[0] == 0: raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.") if "intrinsics" not in geometry: # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back), correct for inside-the-sphere viewing) verts = verts[:, [1, 2, 0]].contiguous() else: # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip. verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) faces = faces[:, [0, 2, 1]].contiguous() tex = geometry["image"][batch_index:batch_index + 1] if texture else None mesh = Types.MESH( vertices=verts.unsqueeze(0), faces=faces.unsqueeze(0), uvs=uvs.unsqueeze(0), texture=tex, ) return io.NodeOutput(mesh) class MoGeExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh] async def comfy_entrypoint() -> MoGeExtension: return MoGeExtension()