# ComfyUI/comfy_extras/nodes_moge.py
# 2026-05-12 16:09:24 +03:00
#
# 446 lines
# 20 KiB
# Python

"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, Types, io
from typing_extensions import override
from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh
# Custom socket types used by the nodes in this file: MOGE_MODEL carries the
# loaded MoGeModel, MOGE_GEOMETRY carries a _MoGeGeometryPayload between nodes.
MoGeModelType = io.Custom("MOGE_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
@dataclass
class _MoGeGeometryPayload:
    """Bundle of MoGe inference outputs passed through the MOGE_GEOMETRY socket.

    Every field except ``image`` may be None when the producing node did not
    compute it (the panorama node, for example, leaves ``intrinsics`` and
    ``normal`` unset).
    """

    points: Optional[torch.Tensor]      # (B, H, W, 3)
    depth: Optional[torch.Tensor]       # (B, H, W)
    intrinsics: Optional[torch.Tensor]  # (B, 3, 3)
    mask: Optional[torch.Tensor]        # (B, H, W) bool
    normal: Optional[torch.Tensor]      # (B, H, W, 3) or None for v1
    image: torch.Tensor                 # (B, H, W, 3) in [0, 1], CPU
def _turbo(x: torch.Tensor) -> torch.Tensor:
    """Colorize scalar values with a polynomial fit of the turbo colormap.

    Uses Anton Mikhailov's degree-5 polynomial approximation. The input is
    clamped to [0, 1] first; the result gains a trailing RGB axis of size 3
    and is clamped to [0, 1] as well.
    """
    t = x.clamp(0.0, 1.0)
    t2 = t * t
    t3 = t2 * t
    t4 = t2 * t2
    t5 = t4 * t
    powers = (t, t2, t3, t4, t5)

    # One row per channel (R, G, B): constant term, then signed coefficients for t .. t^5.
    coefficients = (
        (0.13572138, 4.61539260, -42.66032258, 132.13108234, -152.94239396, 59.28637943),
        (0.09140261, 2.19418839, 4.84296658, -14.18503333, 4.27729857, 2.82956604),
        (0.10667330, 12.64194608, -60.58204836, 110.36276771, -89.90310912, 27.34824973),
    )

    channels = []
    for constant, *weights in coefficients:
        # Accumulate terms from lowest to highest order.
        value = constant
        for weight, power in zip(weights, powers):
            value = value + weight * power
        channels.append(value)
    return torch.stack(channels, dim=-1).clamp(0.0, 1.0)
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
    """Derive camera-space normals from a (B, H, W, 3) point map (v1 fallback).

    Normals are the normalized cross product of the horizontal and vertical
    central differences. Non-finite points are zeroed before differencing and
    their normals are returned as zero; the one-pixel image border, where a
    central difference is unavailable, also ends up with zero normals.
    """
    functional = torch.nn.functional
    valid = torch.isfinite(points).all(dim=-1).unsqueeze(-1)
    clean = torch.where(valid, points, torch.zeros_like(points))

    # Central differences along the width and height axes.
    grad_w = clean[..., :, 2:, :] - clean[..., :, :-2, :]
    grad_h = clean[..., 2:, :, :] - clean[..., :-2, :, :]

    # Zero-pad back to (B, H, W, 3); pad tuples run from the last axis backwards.
    grad_w = functional.pad(grad_w, (0, 0, 1, 1))
    grad_h = functional.pad(grad_h, (0, 0, 0, 0, 1, 1))

    normals = functional.normalize(torch.cross(grad_w, grad_h, dim=-1), dim=-1)
    return torch.where(valid, normals, torch.zeros_like(normals))
def _screen_normals_from_depth(depth: torch.Tensor) -> torch.Tensor:
    """Estimate screen-space normals (X right, Y down, Z into scene) from depth.

    Pixels whose depth is non-finite or not strictly positive are treated as
    zero depth while differencing and receive a zero normal in the result.
    """
    functional = torch.nn.functional
    usable = torch.isfinite(depth) & (depth > 0)
    clean = torch.where(usable, depth, torch.zeros_like(depth))
    height, width = clean.shape[-2:]
    planes = clean.unsqueeze(1)

    # Scale gradients to normalized image coords so a 45 deg tilt lands as a 45 deg normal regardless of resolution.
    slope_x = (planes[..., :, 2:] - planes[..., :, :-2]) * (width / 2.0)
    slope_y = (planes[..., 2:, :] - planes[..., :-2, :]) * (height / 2.0)

    # Central differences drop one pixel per side; zero-pad them back in.
    slope_x = functional.pad(slope_x, (1, 1, 0, 0)).squeeze(1)
    slope_y = functional.pad(slope_y, (0, 0, 1, 1)).squeeze(1)

    components = [-slope_x, -slope_y, torch.ones_like(clean)]
    normals = functional.normalize(torch.stack(components, dim=-1), dim=-1)
    return torch.where(usable.unsqueeze(-1), normals, torch.zeros_like(normals))
def _quantile_bounds(values: torch.Tensor, q_lo: float, q_hi: float,
                     max_elements: int = 2 ** 24) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``q_lo`` and ``q_hi`` quantiles of a 1-D float tensor.

    torch.quantile refuses inputs above 2**24 elements, which a single large
    frame or a full-resolution panorama can exceed. Larger inputs are
    stride-subsampled down to at most ``max_elements`` first; for percentile
    clipping of a visualization that estimate is sufficient.
    """
    count = values.numel()
    if count > max_elements:
        stride = -(-count // max_elements)  # ceil division
        values = values[::stride]
    # One call with both quantiles sorts the data only once.
    q = torch.tensor([q_lo, q_hi], dtype=values.dtype, device=values.device)
    lo, hi = torch.quantile(values, q)
    return lo, hi


def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor:
    """Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping.

    Args:
        depth: (B, ...) depth tensor; non-finite or non-positive entries are
            treated as invalid.

    Returns:
        Tensor shaped and typed like ``depth``. Invalid pixels, and batch
        items without any valid pixel, are 0.
    """
    out = torch.zeros_like(depth)
    for i in range(depth.shape[0]):
        d = depth[i]
        # torch.quantile only accepts float32/float64; promote anything else.
        if d.dtype not in (torch.float32, torch.float64):
            d = d.float()
        valid = torch.isfinite(d) & (d > 0)
        if not valid.any():
            continue
        disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d))
        lo, hi = _quantile_bounds(disp[valid], 0.001, 0.999)
        scale = (hi - lo).clamp_min(1e-6)  # guard against a constant-depth frame
        norm = ((disp - lo) / scale).clamp(0.0, 1.0)
        out[i] = torch.where(valid, norm, torch.zeros_like(norm))
    return out
class LoadMoGeModel(io.ComfyNode):
    """Loader node: read a checkpoint from the "moge" model folder and wrap it in MoGeModel."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one checkpoint combo input, one MOGE_MODEL output."""
        checkpoint_choices = folder_paths.get_filename_list("moge")
        return io.Schema(
            node_id="LoadMoGeModel",
            display_name="Load MoGe Model",
            category="loaders",
            inputs=[io.Combo.Input("model_name", options=checkpoint_choices)],
            outputs=[MoGeModelType.Output()],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Resolve ``model_name`` to a file, load its state dict and build the model."""
        checkpoint_path = folder_paths.get_full_path_or_raise("moge", model_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)
        model = MoGeModel(state_dict)
        return io.NodeOutput(model)
class MoGePanoramaInference(io.ComfyNode):
    """Equirectangular panorama inference: split into 12 perspective views, run
    MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve.
    v2's predicted normals and metric scale are ignored (per-view scales would not align across seams).

    Only a single image is accepted; a batch raises ValueError. The emitted
    geometry packet holds points, radial depth and mask at the input image
    resolution, with ``intrinsics`` and ``normal`` set to None.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="MoGePanoramaInference",
            display_name="MoGe Panorama Inference",
            category="image/geometry",
            inputs=[
                MoGeModelType.Input("moge_model"),
                io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
                io.Int.Input("resolution_level", default=9, min=0, max=9,
                             tooltip="Per-view detail (0 = fast, 9 = slow)."),
                io.Int.Input("split_resolution", default=512, min=256, max=1024,
                             tooltip="Resolution of each perspective split."),
                io.Int.Input("merge_resolution", default=1920, min=256, max=8192,
                             tooltip="Long-side resolution of the merged equirect distance map."),
                io.Int.Input("batch_size", default=4, min=1, max=12,
                             tooltip="Views per inference batch (12 splits total)."),
            ],
            outputs=[MoGeGeometry.Output(display_name="geometry")],
        )

    @classmethod
    def execute(cls, moge_model, image, resolution_level,
                split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
        # Imported locally rather than at module level.
        from comfy.ldm.moge.panorama import (
            get_panorama_cameras, split_panorama_image, merge_panorama_depth,
            spherical_uv_to_directions, _uv_grid,
        )
        import comfy.model_management as cmm
        import numpy as np
        from tqdm.auto import tqdm

        if image.shape[0] != 1:
            raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")

        H, W = int(image.shape[1]), int(image.shape[2])
        # Merge at no more than merge_resolution on the long side (never upscale),
        # keeping each side at least 32 px.
        scale = min(merge_resolution / max(H, W), 1.0)
        merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32)

        extrinsics, intrinsics = get_panorama_cameras()

        cmm.load_model_gpu(moge_model.patcher)
        device = moge_model.load_device
        # Channels-first copy of the panorama on the model's device and dtype.
        img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
        splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
        n_views = splits.shape[0]

        # Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle.
        # The size list is built by halving until the long side is <= 256, then ordered coarse -> fine.
        # NOTE(review): presumably mirrors the pyramid used inside merge_panorama_depth; the (w, h)
        # pairs handed to on_solve_end must be keys of solve_weight or the lookup raises KeyError --
        # confirm against panorama.py.
        merge_levels: list[tuple[int, int]] = []
        w_, h_ = merge_w, merge_h
        while True:
            merge_levels.append((w_, h_))
            if max(w_, h_) <= 256:
                break
            w_, h_ = w_ // 2, h_ // 2
        merge_levels.reverse()
        solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)}
        # Progress budget: one unit per inferred view, one per (view, level) during the merge,
        # plus the weighted solve units.
        n_merge_view_units = n_views * len(merge_levels)
        n_merge_solve_units = sum(solve_weight.values())

        pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units)
        done = 0  # running count shared by the inference loop and the merge callbacks

        distance_maps: list = []
        masks: list = []
        with tqdm(total=n_views, desc="MoGe panorama inference") as tq:
            for i in range(0, n_views, batch_size):
                batch = splits[i:i + batch_size]
                # apply_metric_scale=False: per-view scales would not align across overlap seams.
                result = moge_model.infer(batch, resolution_level=resolution_level,
                                          fov_x=90.0, force_projection=True,
                                          apply_mask=False, apply_metric_scale=False)
                # Per-pixel radial distance is the norm of the predicted 3D point.
                distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy()))
                masks.extend(list(result["mask"].cpu().numpy()))
                n = batch.shape[0]
                done += n
                pbar.update_absolute(done)
                tq.update(n)

        with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq:
            # Callbacks passed to merge_panorama_depth to drive both progress bars.
            def _on_merge_view():
                nonlocal done
                done += 1
                pbar.update_absolute(done)
                tq.update(1)

            def _on_solve_start(w, h):
                tq.set_description(f"MoGe panorama merge: solving {w}x{h}")

            def _on_solve_end(w, h):
                nonlocal done
                weight = solve_weight[(w, h)]
                done += weight
                pbar.update_absolute(done)
                tq.update(weight)
                tq.set_description("MoGe panorama merge: views")

            pano_depth, pano_mask = merge_panorama_depth(
                merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
                on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)

        # Resize the merged maps back to the input resolution: bilinear for depth,
        # nearest for the boolean mask.
        if (merge_h, merge_w) != (H, W):
            t = torch.from_numpy(pano_depth).unsqueeze(0).unsqueeze(0)
            pano_depth = torch.nn.functional.interpolate(t, size=(H, W), mode="bilinear",
                                                         align_corners=False).squeeze().numpy().astype(np.float32)
            t = torch.from_numpy(pano_mask.astype(np.uint8)).unsqueeze(0).unsqueeze(0).float()
            pano_mask = (torch.nn.functional.interpolate(t, size=(H, W), mode="nearest").squeeze().numpy() > 0)

        # Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve
        # and stay at log_depth=0 (depth=1) -- without this push-out they form a sphere shell
        # woven through the foreground; here we lift them to a far skybox radius instead.
        if pano_mask.any() and not pano_mask.all():
            # Skybox radius: 5x the 95th percentile of the covered depths.
            far = float(np.quantile(pano_depth[pano_mask], 0.95)) * 5.0
            pano_depth = np.where(pano_mask, pano_depth, far).astype(np.float32)

        # Scale the per-pixel directions by depth to obtain the point map
        # (presumably unit ray directions from spherical_uv_to_directions -- confirm in panorama.py).
        uv = _uv_grid(H, W)
        directions = spherical_uv_to_directions(uv)
        points_np = directions * pano_depth[..., None]

        # Add the batch axis expected by downstream nodes.
        points = torch.from_numpy(points_np).unsqueeze(0).float()
        depth = torch.from_numpy(pano_depth).unsqueeze(0).float()
        mask = torch.from_numpy(pano_mask).unsqueeze(0)

        # Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation
        # after triangulation -- rotating before would scramble the rtol depth-edge check.
        geometry = _MoGeGeometryPayload(
            points=points, depth=depth, intrinsics=None, mask=mask, normal=None,
            image=image.detach().cpu(),
        )
        return io.NodeOutput(geometry)
class MoGeInference(io.ComfyNode):
    """Run MoGe on a batch of perspective images and emit a MOGE_GEOMETRY packet."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs (model, image, inference options) and its geometry output."""
        node_inputs = [
            MoGeModelType.Input("moge_model"),
            io.Image.Input("image"),
            io.Int.Input("resolution_level", default=9, min=0, max=9,
                         tooltip="0 = fastest, 9 = most detail."),
            io.Float.Input("fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1,
                           tooltip="Override horizontal FoV. 0 = auto."),
            io.Int.Input("batch_size", default=4, min=1, max=64,
                         tooltip="Images per inference call. Lower if you OOM on a long video / image set."),
            io.Boolean.Input("force_projection", default=True),
            io.Boolean.Input("apply_mask", default=True,
                             tooltip="Set masked-out points/depth to inf."),
        ]
        return io.Schema(
            node_id="MoGeInference",
            display_name="MoGe Inference",
            category="image/geometry",
            inputs=node_inputs,
            outputs=[MoGeGeometry.Output(display_name="geometry")],
        )

    @classmethod
    def execute(cls, moge_model, image, resolution_level, fov_x_degrees,
                batch_size, force_projection, apply_mask) -> io.NodeOutput:
        """Infer geometry in chunks of ``batch_size`` images and concatenate the results.

        Result fields that no chunk returned are stored as None in the packet.
        """
        from tqdm.auto import tqdm

        frames = image.movedim(-1, -3).contiguous()
        total = frames.shape[0]

        # A non-positive FoV means "let the model estimate it".
        if fov_x_degrees <= 0:
            fov_override = None
        else:
            fov_override = float(fov_x_degrees)

        progress = comfy.utils.ProgressBar(total)
        predictions: list[dict] = []
        with tqdm(total=total, desc="MoGe inference") as bar:
            for start in range(0, total, batch_size):
                stop = start + batch_size
                frame_batch = frames[start:stop]
                prediction = moge_model.infer(
                    frame_batch,
                    resolution_level=resolution_level,
                    fov_x=fov_override,
                    force_projection=force_projection,
                    apply_mask=apply_mask,
                )
                predictions.append(prediction)
                progress.update_absolute(min(stop, total))
                bar.update(frame_batch.shape[0])

        def gather(key):
            # Concatenate one result field across chunks; None if it was never produced.
            parts = [pred[key] for pred in predictions if key in pred]
            if not parts:
                return None
            return torch.cat(parts, dim=0)

        fields = {key: gather(key) for key in ("points", "depth", "intrinsics", "mask", "normal")}
        payload = _MoGeGeometryPayload(image=image.detach().cpu(), **fields)
        return io.NodeOutput(payload)
# Visualization modes offered by the MoGeRender node's "output" combo.
_RENDER_MODES = ["depth", "depth_colored", "normal", "normal_screen", "mask"]
class MoGeRender(io.ComfyNode):
    """Render a visualization or mask from a MOGE_GEOMETRY packet."""

    @classmethod
    def define_schema(cls):
        """Describe the node: a geometry input, a render-mode combo and one image output."""
        mode_selector = io.Combo.Input("output", options=_RENDER_MODES, default="depth_colored")
        return io.Schema(
            node_id="MoGeRender",
            display_name="MoGe Render",
            category="image/geometry",
            inputs=[MoGeGeometry.Input("geometry"), mode_selector],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, geometry, output) -> io.NodeOutput:
        """Render every batch item of ``geometry`` in the requested ``output`` mode.

        Raises:
            ValueError: if the mode is unknown or the packet lacks the field
                the mode needs.
        """
        from tqdm.auto import tqdm

        # Choose the tensor that feeds the requested mode, failing early if it is missing.
        if output in ("depth", "depth_colored", "normal_screen"):
            source = geometry.depth
            if source is None:
                raise ValueError("MoGeGeometry has no depth output.")
        elif output == "normal":
            # Prefer predicted normals; otherwise they are derived from the point map below.
            if geometry.normal is not None:
                source = geometry.normal
            elif geometry.points is not None:
                source = geometry.points
            else:
                raise ValueError("MoGeGeometry has neither normals nor points to derive normals from.")
        elif output == "mask":
            source = geometry.mask
            if source is None:
                raise ValueError("MoGeGeometry has no mask output.")
        else:
            raise ValueError(f"Unknown output mode: {output}")

        import comfy.model_management as cmm

        def gray_to_rgb(gray):
            # Replicate a single-channel map across three channels.
            return gray.unsqueeze(-1).expand(*gray.shape, 3).contiguous()

        def encode_normals(normals):
            # Map components from [-1, 1] to [0, 1] for display.
            return (normals * 0.5 + 0.5).clamp(0.0, 1.0)

        count = source.shape[0]
        progress = comfy.utils.ProgressBar(count)
        frames: list[torch.Tensor] = []
        with tqdm(total=count, desc=f"MoGe render: {output}") as bar:
            for index in range(count):
                item = source[index:index + 1].float()
                if output in ("depth", "depth_colored"):
                    disparity = _normalize_disparity(item)
                    if output == "depth_colored":
                        frames.append(_turbo(disparity))
                    else:
                        frames.append(gray_to_rgb(disparity))
                elif output == "normal":
                    if geometry.normal is not None:
                        frames.append(encode_normals(item))
                    else:
                        frames.append(encode_normals(_normals_from_points(item)))
                elif output == "normal_screen":
                    frames.append(encode_normals(_screen_normals_from_depth(item)))
                elif output == "mask":
                    frames.append(gray_to_rgb(item))
                progress.update_absolute(index + 1)
                bar.update(1)

        stacked = torch.cat(frames, dim=0)
        rendered = stacked.to(device=cmm.intermediate_device(), dtype=cmm.intermediate_dtype())
        return io.NodeOutput(rendered)
class MoGePointMapToMesh(io.ComfyNode):
    """Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture)."""

    @classmethod
    def define_schema(cls):
        """Describe the node: geometry plus meshing options in, a single MESH out."""
        return io.Schema(
            node_id="MoGePointMapToMesh",
            display_name="MoGe Point Map to Mesh",
            category="3d",
            inputs=[
                MoGeGeometry.Input("geometry"),
                # NOTE(review): max=64 caps the selectable frame even though a geometry batch
                # may be larger -- confirm whether that limit is intended.
                io.Int.Input("batch_index", default=0, min=0, max=64,
                             tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
                                     "differ, so batches can't be stacked into a single MESH."),
                io.Int.Input("decimation", default=1, min=1, max=8,
                             tooltip="Vertex stride; 1 = full resolution."),
                io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
                               tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."),
                io.Boolean.Input("texture", default=True,
                                 tooltip="Carry the source image through as the baseColor texture."),
            ],
            outputs=[io.Mesh.Output()],
        )

    @classmethod
    def execute(cls, geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
        """Mesh the ``batch_index``-th point map of ``geometry``.

        Raises:
            ValueError: if the packet has no points, ``batch_index`` is out of
                range, or triangulation yields no vertices or faces.
        """
        if geometry.points is None:
            raise ValueError("MoGeGeometry has no points output.")
        B = geometry.points.shape[0]
        if batch_index >= B:
            raise ValueError(f"batch_index {batch_index} out of range; geometry has batch size {B}.")

        # Pass geometry.depth so the rtol edge check sees radial depth -- for panoramas
        # points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
        edge_depth = geometry.depth[batch_index] if geometry.depth is not None else None
        verts, faces, uvs = triangulate_grid_mesh(
            geometry.points[batch_index], decimation=decimation,
            discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
        )
        if verts.shape[0] == 0 or faces.shape[0] == 0:
            raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")

        # A packet without intrinsics is treated as a panorama.
        if geometry.intrinsics is None:
            # Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back). Pure rotation
            # preserves the natural inward winding (correct for inside-the-sphere viewing).
            verts = verts[:, [1, 2, 0]].contiguous()
        else:
            # Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
            # Build the sign vector on verts' device so the multiply also works when the
            # point map lives on an accelerator instead of the CPU.
            axis_flip = torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype, device=verts.device)
            verts = verts * axis_flip
            faces = faces[:, [0, 2, 1]].contiguous()

        tex = geometry.image[batch_index:batch_index + 1] if texture and geometry.image is not None else None
        # Types.MESH receives every component with a leading batch axis of size 1.
        mesh = Types.MESH(
            vertices=verts.unsqueeze(0),
            faces=faces.unsqueeze(0),
            uvs=uvs.unsqueeze(0),
            texture=tex,
        )
        return io.NodeOutput(mesh)
class MoGeExtension(ComfyExtension):
    """Extension object that exposes the MoGe node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            LoadMoGeModel,
            MoGeInference,
            MoGePanoramaInference,
            MoGeRender,
            MoGePointMapToMesh,
        ]
        return node_classes
async def comfy_entrypoint() -> MoGeExtension:
    """Create and return the MoGeExtension instance for this module."""
    extension = MoGeExtension()
    return extension