ComfyUI/comfy/ldm/moge/panorama.py
2026-05-12 16:09:24 +03:00

316 lines
14 KiB
Python

"""Panorama (equirectangular) inference helpers for MoGe.
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
the model per view, and stitches per-view distance maps back into a single
equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses ``F.grid_sample`` (GPU); the sparse solve uses ``lsmr`` (CPU).
"""
from __future__ import annotations
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
def _icosahedron_directions() -> np.ndarray:
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
A = (1.0 + np.sqrt(5.0)) / 2.0
return np.array([
[0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
[1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
[A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1],
], dtype=np.float32)
def _intrinsics_from_fov(fov_x_rad: float, fov_y_rad: float) -> np.ndarray:
"""Normalised-image (unit-square) K matrix."""
fx = 0.5 / np.tan(fov_x_rad / 2)
fy = 0.5 / np.tan(fov_y_rad / 2)
return np.array([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=np.float32)
def _extrinsics_look_at(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
"""OpenCV-convention world->camera extrinsics for an array of look-at targets (N, 4, 4)."""
eye = np.asarray(eye, dtype=np.float32)
target = np.asarray(target, dtype=np.float32)
up = np.asarray(up, dtype=np.float32)
if target.ndim == 1:
target = target[None]
fwd = target - eye
fwd = fwd / np.linalg.norm(fwd, axis=-1, keepdims=True).clip(1e-12)
right = np.cross(fwd, up)
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
# Fall back to an arbitrary perpendicular if forward is parallel to up.
parallel = right_norm.squeeze(-1) < 1e-6
if parallel.any():
alt_up = np.array([1, 0, 0], dtype=np.float32)
right = np.where(parallel[:, None], np.cross(fwd, alt_up), right)
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
right = right / right_norm.clip(1e-12)
new_up = np.cross(fwd, right)
R = np.stack([right, new_up, fwd], axis=-2)
t = -np.einsum("nij,j->ni", R, eye)
E = np.zeros((R.shape[0], 4, 4), dtype=np.float32)
E[:, :3, :3] = R
E[:, :3, 3] = t
E[:, 3, 3] = 1.0
return E
def get_panorama_cameras() -> Tuple[np.ndarray, List[np.ndarray]]:
"""Returns (extrinsics (12, 4, 4), [intrinsics] * 12) for icosahedron views at 90 deg FoV."""
targets = _icosahedron_directions()
eye = np.zeros(3, dtype=np.float32)
up = np.array([0, 0, 1], dtype=np.float32)
extrinsics = _extrinsics_look_at(eye, targets, up)
K = _intrinsics_from_fov(np.deg2rad(90.0), np.deg2rad(90.0))
return extrinsics, [K] * len(targets)
def spherical_uv_to_directions(uv: np.ndarray) -> np.ndarray:
"""Equirect UV in [0, 1] -> 3D unit-direction (Z up)."""
theta = (1 - uv[..., 0]) * (2 * np.pi)
phi = uv[..., 1] * np.pi
return np.stack([
np.sin(phi) * np.cos(theta),
np.sin(phi) * np.sin(theta),
np.cos(phi),
], axis=-1).astype(np.float32)
def directions_to_spherical_uv(directions: np.ndarray) -> np.ndarray:
"""3D direction -> equirect UV in [0, 1]."""
n = np.linalg.norm(directions, axis=-1, keepdims=True).clip(1e-12)
d = directions / n
u = 1 - np.arctan2(d[..., 1], d[..., 0]) / (2 * np.pi) % 1.0
v = np.arccos(d[..., 2].clip(-1, 1)) / np.pi
return np.stack([u, v], axis=-1).astype(np.float32)
def _uv_grid(H: int, W: int) -> np.ndarray:
"""Pixel-center UV grid in [0, 1]; (H, W, 2)."""
u = (np.arange(W, dtype=np.float32) + 0.5) / W
v = (np.arange(H, dtype=np.float32) + 0.5) / H
return np.stack(np.meshgrid(u, v, indexing="xy"), axis=-1)
def _unproject_cv(uv: np.ndarray, depth: np.ndarray,
extrinsics: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
"""Back-project pixels into world coords (OpenCV convention)."""
pix = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1)
K_inv = np.linalg.inv(intrinsics)
cam = pix @ K_inv.T * depth[..., None]
cam_h = np.concatenate([cam, np.ones_like(cam[..., :1])], axis=-1)
E_inv = np.linalg.inv(extrinsics)
return (cam_h @ E_inv.T)[..., :3]
def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""World coords -> (uv, depth) in the camera (OpenCV convention)."""
pts_h = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
cam = pts_h @ extrinsics.T
cam_xyz = cam[..., :3]
depth = cam_xyz[..., 2]
proj = cam_xyz @ intrinsics.T
uv = proj[..., :2] / proj[..., 2:3].clip(1e-12)
return uv.astype(np.float32), depth.astype(np.float32)
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
"""Sample img_bchw at UV-in-[0,1] coords ``uv`` of shape (B, H, W, 2); replicate-border."""
grid = uv * 2.0 - 1.0
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics: List[np.ndarray], resolution: int) -> torch.Tensor:
"""(3, Hp, Wp) equirect on any device -> (N, 3, R, R) perspective crops on the same device."""
device = image.device
N = len(extrinsics)
uv = _uv_grid(resolution, resolution)
sample_uvs = []
for i in range(N):
world = _unproject_cv(uv, np.ones(uv.shape[:-1], dtype=np.float32), extrinsics[i], intrinsics[i])
sample_uvs.append(directions_to_spherical_uv(world))
sample_uvs = np.stack(sample_uvs, axis=0)
img_bchw = image.unsqueeze(0).expand(N, -1, -1, -1).contiguous()
sample_uvs_t = torch.from_numpy(sample_uvs).to(device=device, dtype=image.dtype)
return _grid_sample_uv(img_bchw, sample_uvs_t, mode="bilinear")
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse Laplacian operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(H * W).reshape(H, W)
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(H * W, axis=0).reshape(-1)
indices = np.stack([
grid_index[1:-1, 1:-1],
grid_index[:-2, 1:-1], grid_index[2:, 1:-1],
grid_index[1:-1, :-2], grid_index[1:-1, 2:],
], axis=-1).reshape(-1)
indptr = np.arange(0, H * W * 5 + 1, 5)
return csr_array((data, indices, indptr), shape=(H * W, H * W))
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse forward-difference operator over the H x W grid."""
from scipy.sparse import csr_array
grid_index = np.arange(W * H).reshape(H, W)
if wrap_x:
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
if wrap_y:
grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode="wrap")
data = np.concatenate([
np.concatenate([
np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
-np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
], axis=1).reshape(-1),
np.concatenate([
np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
-np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
], axis=1).reshape(-1),
])
indices = np.concatenate([
np.concatenate([grid_index[:, :-1].reshape(-1, 1), grid_index[:, 1:].reshape(-1, 1)], axis=1).reshape(-1),
np.concatenate([grid_index[:-1, :].reshape(-1, 1), grid_index[1:, :].reshape(-1, 1)], axis=1).reshape(-1),
])
nx = grid_index.shape[0] * (grid_index.shape[1] - 1)
ny = (grid_index.shape[0] - 1) * grid_index.shape[1]
indptr = np.arange(0, nx * 2 + ny * 2 + 1, 2)
return csr_array((data, indices, indptr), shape=(nx + ny, H * W))
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
from scipy.ndimage import map_coordinates
H, W = img.shape[:2]
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
order = 1 if mode == "bilinear" else 0
if img.ndim == 2:
return map_coordinates(img, [yy, xx], order=order, mode="nearest").astype(img.dtype)
out = np.stack([
map_coordinates(img[..., c], [yy, xx], order=order, mode="nearest")
for c in range(img.shape[-1])
], axis=-1)
return out.astype(img.dtype)
def merge_panorama_depth(width: int, height: int,
distance_maps: List[np.ndarray], pred_masks: List[np.ndarray],
extrinsics: List[np.ndarray], intrinsics: List[np.ndarray],
on_view: Optional[Callable[[], None]] = None,
on_solve_start: Optional[Callable[[int, int], None]] = None,
on_solve_end: Optional[Callable[[int, int], None]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Stitch per-view distance maps into a single equirect distance map.
Recursive multi-scale solve: solves at half resolution first and uses that as the lsmr init
for the full-resolution solve. Optional callbacks fire per view processed and around each
lsmr solve so callers can drive a progress bar.
"""
from scipy.ndimage import convolve
from scipy.sparse import vstack
from scipy.sparse.linalg import lsmr
if max(width, height) > 256:
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
distance_maps, pred_masks, extrinsics, intrinsics,
on_view=on_view,
on_solve_start=on_solve_start,
on_solve_end=on_solve_end)
t = torch.from_numpy(coarse_depth).unsqueeze(0).unsqueeze(0)
t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False)
depth_init = t.squeeze().numpy().astype(np.float32)
else:
depth_init = None
spherical_directions = spherical_uv_to_directions(_uv_grid(height, width))
pano_log_grad_maps, pano_grad_masks = [], []
pano_log_lap_maps, pano_lap_masks = [], []
pano_pred_masks: List[np.ndarray] = []
for i in range(len(distance_maps)):
proj_uv, proj_depth = _project_cv(spherical_directions, extrinsics[i], intrinsics[i])
proj_valid = (proj_depth > 0) & (proj_uv > 0).all(axis=-1) & (proj_uv < 1).all(axis=-1)
Hd, Wd = distance_maps[i].shape[:2]
proj_pixels = np.clip(proj_uv, 0, 1) * np.array([Wd - 1, Hd - 1], dtype=np.float32)
log_dist = np.log(np.clip(distance_maps[i], 1e-6, None))
sampled = _scipy_remap_bilinear(log_dist, proj_pixels, mode="bilinear")
pano_log = np.where(proj_valid, sampled, 0.0).astype(np.float32)
sampled_mask = _scipy_remap_bilinear(pred_masks[i].astype(np.uint8), proj_pixels, mode="nearest")
pano_pred = proj_valid & (sampled_mask > 0)
# Equirect wraps horizontally but not vertically: wrap pad along x, edge pad along y.
padded = np.pad(pano_log, ((0, 0), (0, 1)), mode="wrap")
gx, gy = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
padded_m = np.pad(pano_pred, ((0, 0), (0, 1)), mode="wrap")
mx, my = padded_m[:, :-1] & padded_m[:, 1:], padded_m[:-1, :] & padded_m[1:, :]
pano_log_grad_maps.append((gx, gy))
pano_grad_masks.append((mx, my))
padded = np.pad(pano_log, ((1, 1), (0, 0)), mode="edge")
padded = np.pad(padded, ((0, 0), (1, 1)), mode="wrap")
lap_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
lap = convolve(padded, lap_kernel)[1:-1, 1:-1]
padded_m = np.pad(pano_pred, ((1, 1), (0, 0)), mode="edge")
padded_m = np.pad(padded_m, ((0, 0), (1, 1)), mode="wrap")
m_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
lap_mask = convolve(padded_m.astype(np.uint8), m_kernel)[1:-1, 1:-1] == 5
pano_log_lap_maps.append(lap)
pano_lap_masks.append(lap_mask)
pano_pred_masks.append(pano_pred)
if on_view is not None:
on_view()
gx = np.stack([m[0] for m in pano_log_grad_maps], axis=0)
gy = np.stack([m[1] for m in pano_log_grad_maps], axis=0)
mx = np.stack([m[0] for m in pano_grad_masks], axis=0)
my = np.stack([m[1] for m in pano_grad_masks], axis=0)
gx_avg = (gx * mx).sum(axis=0) / mx.sum(axis=0).clip(1e-3)
gy_avg = (gy * my).sum(axis=0) / my.sum(axis=0).clip(1e-3)
laps = np.stack(pano_log_lap_maps, axis=0)
lap_masks = np.stack(pano_lap_masks, axis=0)
lap_avg = (laps * lap_masks).sum(axis=0) / lap_masks.sum(axis=0).clip(1e-3)
grad_x_mask = mx.any(axis=0).reshape(-1)
grad_y_mask = my.any(axis=0).reshape(-1)
grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
lap_mask_flat = lap_masks.any(axis=0).reshape(-1)
A = vstack([
_grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
_poisson_equation(width, height, wrap_x=True, wrap_y=False)[lap_mask_flat],
])
b = np.concatenate([
gx_avg.reshape(-1)[grad_x_mask],
gy_avg.reshape(-1)[grad_y_mask],
lap_avg.reshape(-1)[lap_mask_flat],
])
x0 = np.log(np.clip(depth_init, 1e-6, None)).reshape(-1) if depth_init is not None else None
if on_solve_start is not None:
on_solve_start(width, height)
x, *_ = lsmr(A, b, atol=1e-5, btol=1e-5, x0=x0, show=False)
if on_solve_end is not None:
on_solve_end(width, height)
pano_depth = np.exp(x).reshape(height, width).astype(np.float32)
pano_mask = np.any(pano_pred_masks, axis=0)
return pano_depth, pano_mask