mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 11:37:32 +08:00
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
321 lines
13 KiB
Python
321 lines
13 KiB
Python
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.

Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera
ray vectors, predicted on the per-view local feature map) into per-view
extrinsics + intrinsics. Implementation is a 1:1 port of
``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC
homography fit followed by a QL decomposition.

No learned parameters; pure tensor math. Output:

* ``R`` -- ``(B, S, 3, 3)`` rotation matrix
* ``T`` -- ``(B, S, 3)`` camera-space translation
* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2)
* ``principal_points`` -- ``(B, S, 2)`` ditto

:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix
that the public node converts back into pixel-space intrinsics.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Linear-algebra helpers
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular.
|
|
|
|
Implemented in terms of QR by reversing the columns/rows; the standard
|
|
trick from the upstream reference. Inputs ``A`` are ``(3, 3)``.
|
|
"""
|
|
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
|
|
device=A.device, dtype=A.dtype)
|
|
A_tilde = A @ P
|
|
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
|
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
|
Q_tilde = Q_tilde.to(A.dtype)
|
|
R_tilde = R_tilde.to(A.dtype)
|
|
Q = Q_tilde @ P
|
|
L = P @ R_tilde @ P
|
|
d = torch.diag(L)
|
|
sign = torch.sign(d)
|
|
Q = Q * sign[None, :] # scale columns of Q
|
|
L = L * sign[:, None] # scale rows of L
|
|
return Q, L
|
|
|
|
|
|
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
|
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Weighted-LSQ + RANSAC homography (batched)
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def _find_homography_weighted_lsq(
|
|
src_pts: torch.Tensor,
|
|
dst_pts: torch.Tensor,
|
|
confident_weight: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Solve a single ``H`` with weighted least-squares (DLT)."""
|
|
N = src_pts.shape[0]
|
|
if N < 4:
|
|
raise ValueError("At least 4 points are required to compute a homography.")
|
|
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
|
|
x = src_pts[:, 0:1]
|
|
y = src_pts[:, 1:2]
|
|
u = dst_pts[:, 0:1]
|
|
v = dst_pts[:, 1:2]
|
|
zeros = torch.zeros_like(x)
|
|
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
|
|
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
|
|
A = torch.cat([A1, A2], dim=0) # (2N, 9)
|
|
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
|
_, _, Vh = torch.linalg.svd(A.float())
|
|
Vh = Vh.to(A.dtype)
|
|
H = Vh[-1].reshape(3, 3)
|
|
return H / H[-1, -1]
|
|
|
|
|
|
def _find_homography_weighted_lsq_batched(
|
|
src_pts_batch: torch.Tensor,
|
|
dst_pts_batch: torch.Tensor,
|
|
confident_weight_batch: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``."""
|
|
B, K, _ = src_pts_batch.shape
|
|
w = confident_weight_batch.sqrt().unsqueeze(2)
|
|
x = src_pts_batch[:, :, 0:1]
|
|
y = src_pts_batch[:, :, 1:2]
|
|
u = dst_pts_batch[:, :, 0:1]
|
|
v = dst_pts_batch[:, :, 1:2]
|
|
zeros = torch.zeros_like(x)
|
|
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
|
|
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
|
|
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
|
|
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
|
_, _, Vh = torch.linalg.svd(A.float())
|
|
Vh = Vh.to(A.dtype)
|
|
H = Vh[:, -1].reshape(B, 3, 3)
|
|
return H / H[:, 2:3, 2:3]
|
|
|
|
|
|
def _ransac_find_homography_weighted_batched(
    src_pts: torch.Tensor,  # (B, N, 2)
    dst_pts: torch.Tensor,  # (B, N, 2)
    confident_weight: torch.Tensor,  # (B, N)
    n_sample: int,
    n_iter: int = 100,
    reproj_threshold: float = 3.0,
    num_sample_for_ransac: int = 8,
    max_inlier_num: int = 10000,
    rand_sample_iters_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched weighted-RANSAC homography estimator.

    All ``n_iter`` minimal fits are solved in one batched DLT call; only the
    final inlier refit loops per batch element (inlier counts vary).

    Args:
        src_pts: ``(B, N, 2)`` source points.
        dst_pts: ``(B, N, 2)`` destination points.
        confident_weight: ``(B, N)`` per-correspondence weights; also used as
            the per-inlier vote weight when scoring hypotheses.
        n_sample: size of the top-confidence candidate pool that minimal
            samples are drawn from.
        n_iter: number of RANSAC hypotheses.
        reproj_threshold: reprojection-distance cutoff for inliers.
        num_sample_for_ransac: points per minimal sample.
        max_inlier_num: cap on the number of inliers used in the final refit.
        rand_sample_iters_idx: optional precomputed ``(n_iter, k)`` indices
            into the candidate pool, shared across calls for determinism
            between chunks; drawn fresh with ``randperm`` when ``None``.

    Returns:
        ``(B, 3, 3)`` homography matrices (identity rows where RANSAC found
        fewer than 4 inliers).
    """
    B, N, _ = src_pts.shape
    assert N >= 4
    device = src_pts.device

    # Restrict minimal samples to the n_sample highest-confidence points.
    sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
    candidate_idx = sorted_idx[:, :n_sample]  # (B, n_sample)

    if rand_sample_iters_idx is None:
        # One random k-subset of the candidate pool per RANSAC iteration.
        rand_sample_iters_idx = torch.stack(
            [torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
             for _ in range(n_iter)],
            dim=0,
        )

    # Gather the sampled correspondences for every (batch, iteration) pair.
    rand_idx = candidate_idx[:, rand_sample_iters_idx]  # (B, n_iter, k)
    b_idx = (
        torch.arange(B, device=device)
        .view(B, 1, 1)
        .expand(B, n_iter, num_sample_for_ransac)
    )
    src_b = src_pts[b_idx, rand_idx]
    dst_b = dst_pts[b_idx, rand_idx]
    w_b = confident_weight[b_idx, rand_idx]

    # Solve all B * n_iter minimal problems in a single batched DLT call.
    cB, cN = src_b.shape[:2]
    H_batch = _find_homography_weighted_lsq_batched(
        src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
    ).unflatten(0, (cB, cN))  # (B, n_iter, 3, 3)

    # Score every hypothesis: reproject all N points through each H and
    # accumulate the confidence weight of points within the threshold.
    src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
    proj = torch.bmm(
        src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
        H_batch.reshape(-1, 3, 3).transpose(1, 2),
    )  # (B*n_iter, N, 3)
    proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
    err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt()  # (B, n_iter, N)
    inlier_mask = err < reproj_threshold
    score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
    best_idx = torch.argmax(score, dim=1)
    best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]

    # Refit with the inlier set (per-batch, since the inlier counts vary).
    H_inlier_list = []
    for b in range(B):
        mask = best_inlier_mask[b]
        in_src = src_pts[b][mask]
        in_dst = dst_pts[b][mask]
        in_w = confident_weight[b][mask]
        if in_src.shape[0] < 4:
            # Fall back to identity when RANSAC fails to find enough inliers.
            H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
            continue
        sorted_w = torch.argsort(in_w, descending=True)
        if len(sorted_w) > max_inlier_num:
            # Keep (at least) the top max_inlier_num by weight, randomly
            # subsampled from the top 95% to bound the refit problem size.
            keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
            sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
        H_inlier_list.append(
            _find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
        )
    return torch.stack(H_inlier_list, dim=0)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Camera-ray utilities
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
|
|
device, dtype) -> torch.Tensor:
|
|
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane.
|
|
|
|
Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel
|
|
coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to
|
|
camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic
|
|
``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``.
|
|
"""
|
|
dx = 1.0 / num_x
|
|
dy = 1.0 / num_y
|
|
# Centered camera-space coords directly (skip the K^-1 step since it's
|
|
# just a translation by -1 on x and y when K is identity-with-center=1).
|
|
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
|
|
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
|
|
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
|
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
|
|
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
|
|
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
|
|
|
|
|
|
def _camray_to_caminfo(
    camray: torch.Tensor,  # (B, S, h, w, 6)
    confidence: Optional[torch.Tensor] = None,  # (B, S, h, w)
    reproj_threshold: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert per-pixel camera rays to per-view (R, T, focal, principal).

    The first 3 channels of ``camray`` are treated as ray directions and
    fitted (against an identity-intrinsic ray grid) with a weighted-RANSAC
    homography, which is then split into rotation + intrinsics via QL
    decomposition. The last 3 channels are confidence-averaged into ``T``.

    Args:
        camray: ``(B, S, h, w, 6)`` per-pixel ray output.
        confidence: optional ``(B, S, h, w)`` weights; all-ones when ``None``.
        reproj_threshold: inlier threshold forwarded to the RANSAC fit.

    Returns:
        ``R`` ``(B, S, 3, 3)``, ``T`` ``(B, S, 3)``, focal ``(B, S, 2)``,
        principal point ``(B, S, 2)`` (the latter two already converted to
        the upstream 1/focal and pp+1 convention).
    """
    if confidence is None:
        confidence = torch.ones_like(camray[..., 0])
    B, S, h, w, _ = camray.shape
    device = camray.device
    dtype = camray.dtype

    rays_target = camray[..., :3]  # (B, S, h, w, 3)
    rays_origin = _unproject_identity(h, w, B, S, device, dtype)

    # Flatten (B*S, h*w, *) for the RANSAC routine.
    rays_target = rays_target.flatten(0, 1).flatten(1, 2)
    rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
    weights = confidence.flatten(0, 1).flatten(1, 2).clone()

    # Project to 2D in homogeneous form (the upstream calls this "perspective
    # division"). Rays with |z| below the threshold are left undivided and
    # get zero weight so they cannot influence the fit.
    z_thresh = 1e-4
    mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
    weights = torch.where(mask, weights, torch.zeros_like(weights))
    src = rays_origin.clone()
    dst = rays_target.clone()
    src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
    src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
    dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
    dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
    src = src[..., :2]
    dst = dst[..., :2]

    # RANSAC configuration: sample minimal sets of 8 from the top-30%
    # (at least 8) most-confident points; one shared random-index table is
    # drawn up front so every chunk below uses identical samples.
    N = src.shape[1]
    n_iter = 100
    sample_ratio = 0.3
    num_sample_for_ransac = 8
    n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
    rand_idx = torch.stack(
        [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
        dim=0,
    )

    # Chunk along the view axis to keep peak memory predictable.
    chunk = 2
    A_list = []
    for i in range(0, src.shape[0], chunk):
        A = _ransac_find_homography_weighted_batched(
            src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
            n_sample=n_sample, n_iter=n_iter,
            num_sample_for_ransac=num_sample_for_ransac,
            reproj_threshold=reproj_threshold,
            rand_sample_iters_idx=rand_idx,
            max_inlier_num=8000,
        )
        # Flip sign on dets that come out < 0 (so that the QL produces a
        # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
        # do the comparison in fp32.
        flip = torch.linalg.det(A.float()) < 0
        A = torch.where(flip[:, None, None], -A, A)
        A_list.append(A)
    A = torch.cat(A_list, dim=0)  # (B*S, 3, 3)

    # Split each homography A = R @ L: R is the rotation, and the normalised
    # lower-triangular factor L carries focal (diagonal) and principal point
    # (bottom row) in the normalised 2x2-image convention.
    R_list, f_list, pp_list = [], [], []
    for i in range(A.shape[0]):
        R, L = _ql_decomposition(A[i])
        L = L / L[2][2]
        f_list.append(torch.stack((L[0][0], L[1][1])))
        pp_list.append(torch.stack((L[2][0], L[2][1])))
        R_list.append(R)
    R = torch.stack(R_list).reshape(B, S, 3, 3)
    focal = torch.stack(f_list).reshape(B, S, 2)
    pp = torch.stack(pp_list).reshape(B, S, 2)

    # Translation: confidence-weighted average of camray direction(s).
    cf = confidence.flatten(0, 1).flatten(1, 2)
    T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
    T = T / cf.sum(dim=-1, keepdim=True)
    T = T.reshape(B, S, 3)

    # Match upstream output convention: focal -> 1/focal, pp + 1.
    return R, T, 1.0 / focal, pp + 1.0
|
|
|
|
|
|
def get_extrinsic_from_camray(
    camray: torch.Tensor,  # (B, S, h, w, 6)
    conf: torch.Tensor,  # (B, S, h, w, 1) or (B, S, h, w)
    patch_size_y: int,
    patch_size_x: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Wrap a 4x4 extrinsic + per-view focal + principal-point output.

    ``patch_size_y``/``patch_size_x`` are accepted for interface
    compatibility; this computation does not read them.

    Returns:
        * extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
          what gets stored in ``output.extrinsics`` by the caller).
        * focals ``(B, S, 2)`` in normalised image space.
        * pp ``(B, S, 2)`` in normalised image space.
    """
    # Drop a trailing singleton channel on the confidence map, if present.
    if conf.ndim == 5 and conf.shape[-1] == 1:
        conf = conf.squeeze(-1)
    R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
    # Assemble [R | T] over a constant homogeneous bottom row [0, 0, 0, 1].
    top = torch.cat((R, T.unsqueeze(-1)), dim=-1)  # (B, S, 3, 4)
    bottom = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
    bottom = bottom.expand(R.shape[0], R.shape[1], 1, 4)
    return torch.cat((top, bottom), dim=-2), focal, pp
|