Support ortho camera

This commit is contained in:
kijai 2026-05-31 12:10:12 +03:00
parent 1fd7c500ff
commit 74fd3c4b62

View File

@@ -682,11 +682,13 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
return background_only() return background_only()
eye, _, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
cam = (xyz - eye) @ W.T cam = (xyz - eye) @ W.T
fov = float(camera_info.get("fov", 0) or 0) or 35.0 fov = float(camera_info.get("fov", 0) or 0) or 35.0
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho")
cam_dist = float((target - eye).norm().clamp_min(1e-6)) # eye->target distance: sets the ortho pixel scale
yflip = 1.0 # splat +Y is image-down yflip = 1.0 # splat +Y is image-down
xc, yc, zc = cam.unbind(-1) xc, yc, zc = cam.unbind(-1)
@@ -698,6 +700,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
s = f / cam_dist # ortho: pixels per world unit at the target plane
invz = 1.0 / zc invz = 1.0 / zc
cx0, cy0 = width / 2, height / 2 cx0, cy0 = width / 2, height / 2
@@ -720,10 +723,15 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
# Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz
jm = torch.zeros(xc.shape[0], 2, 3, device=dev) jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() if is_ortho: # parallel projection: screen = s * (xc, yc)
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square() cx, cy = cx0 + s * xc, cy0 + yflip * s * yc
jm[:, 0, 0] = s
jm[:, 1, 1] = yflip * s
else: # perspective: screen = f * (xc, yc) / zc
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square()
cov2 = jm @ cam_cov @ jm.transpose(1, 2) cov2 = jm @ cam_cov @ jm.transpose(1, 2)
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
@@ -741,16 +749,26 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
opacity, rgb = opacity[order], rgb[order] opacity, rgb = opacity[order], rgb[order]
zc_o = zc[order] if need_depth else None zc_o = zc[order] if need_depth else None
nrm_o = nrm[order] if need_normal else None nrm_o = nrm[order] if need_normal else None
mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)
def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
px = cxr[lo:hi, None] + ox[None, :] px = cxr[lo:hi, None] + ox[None, :]
py = cyr[lo:hi, None] + oy[None, :] py = cyr[lo:hi, None] + oy[None, :]
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f # ray direction in camera space (z = 1) if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None] rx = (px - cx0) / s - mux_o[lo:hi, None]
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy)) ry = yflip * (py - cy0) / s - muy_o[lo:hi, None]
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None] rz = -muz_o[lo:hi, None] # constant per splat
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2 rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz
+ 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz))
dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min(0)
else: # perspective ray (dx,dy,1) through the camera origin
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999) alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999)
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
return idx, alpha return idx, alpha