diff --git a/comfy/ldm/depth_anything_3/ray_pose.py b/comfy/ldm/depth_anything_3/ray_pose.py
index f9a3878db..e72264dd5 100644
--- a/comfy/ldm/depth_anything_3/ray_pose.py
+++ b/comfy/ldm/depth_anything_3/ray_pose.py
@@ -38,7 +38,10 @@ def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
     P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
     A_tilde = A @ P
-    Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
+    # CUDA QR is not implemented for fp16/bf16; upcast just for this call.
+    Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
+    Q_tilde = Q_tilde.to(A.dtype)
+    R_tilde = R_tilde.to(A.dtype)
     Q = Q_tilde @ P
     L = P @ R_tilde @ P
     d = torch.diag(L)
@@ -75,7 +78,9 @@ def _find_homography_weighted_lsq(
     A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
     A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
     A = torch.cat([A1, A2], dim=0)  # (2N, 9)
-    _, _, Vh = torch.linalg.svd(A)
+    # CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
+    _, _, Vh = torch.linalg.svd(A.float())
+    Vh = Vh.to(A.dtype)
     H = Vh[-1].reshape(3, 3)
     return H / H[-1, -1]
 
@@ -96,7 +101,9 @@ def _find_homography_weighted_lsq_batched(
     A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
     A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
     A = torch.cat([A1, A2], dim=1)  # (B, 2K, 9)
-    _, _, Vh = torch.linalg.svd(A)
+    # CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
+    _, _, Vh = torch.linalg.svd(A.float())
+    Vh = Vh.to(A.dtype)
     H = Vh[:, -1].reshape(B, 3, 3)
     return H / H[:, 2:3, 2:3]
 
@@ -260,8 +267,9 @@ def _camray_to_caminfo(
         max_inlier_num=8000,
     )
     # Flip sign on dets that come out < 0 (so that the QL produces a
-    # right-handed rotation).
-    flip = torch.linalg.det(A) < 0
+    # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
+    # do the comparison in fp32.
+    flip = torch.linalg.det(A.float()) < 0
     A = torch.where(flip[:, None, None], -A, A)
     A_list.append(A)
     A = torch.cat(A_list, dim=0)  # (B*S, 3, 3)