mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-05 22:07:25 +08:00
Support ortho camera
This commit is contained in:
parent
1fd7c500ff
commit
74fd3c4b62
@ -682,11 +682,13 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
|||||||
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
|
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
|
||||||
return background_only()
|
return background_only()
|
||||||
|
|
||||||
eye, _, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
|
eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
|
||||||
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
|
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
|
||||||
cam = (xyz - eye) @ W.T
|
cam = (xyz - eye) @ W.T
|
||||||
fov = float(camera_info.get("fov", 0) or 0) or 35.0
|
fov = float(camera_info.get("fov", 0) or 0) or 35.0
|
||||||
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
|
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
|
||||||
|
is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho")
|
||||||
|
cam_dist = float((target - eye).norm().clamp_min(1e-6)) # eye->target distance: sets the ortho pixel scale
|
||||||
yflip = 1.0 # splat +Y is image-down
|
yflip = 1.0 # splat +Y is image-down
|
||||||
xc, yc, zc = cam.unbind(-1)
|
xc, yc, zc = cam.unbind(-1)
|
||||||
|
|
||||||
@ -698,6 +700,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
|||||||
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
|
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
|
||||||
|
|
||||||
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
|
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
|
||||||
|
s = f / cam_dist # ortho: pixels per world unit at the target plane
|
||||||
invz = 1.0 / zc
|
invz = 1.0 / zc
|
||||||
cx0, cy0 = width / 2, height / 2
|
cx0, cy0 = width / 2, height / 2
|
||||||
|
|
||||||
@ -720,10 +723,15 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
|||||||
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
|
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
|
||||||
|
|
||||||
# Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
|
# Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
|
||||||
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz
|
|
||||||
jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
|
jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
|
||||||
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
|
if is_ortho: # parallel projection: screen = s * (xc, yc)
|
||||||
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square()
|
cx, cy = cx0 + s * xc, cy0 + yflip * s * yc
|
||||||
|
jm[:, 0, 0] = s
|
||||||
|
jm[:, 1, 1] = yflip * s
|
||||||
|
else: # perspective: screen = f * (xc, yc) / zc
|
||||||
|
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz
|
||||||
|
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
|
||||||
|
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square()
|
||||||
cov2 = jm @ cam_cov @ jm.transpose(1, 2)
|
cov2 = jm @ cam_cov @ jm.transpose(1, 2)
|
||||||
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
|
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
|
||||||
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
|
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
|
||||||
@ -741,16 +749,26 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
|||||||
opacity, rgb = opacity[order], rgb[order]
|
opacity, rgb = opacity[order], rgb[order]
|
||||||
zc_o = zc[order] if need_depth else None
|
zc_o = zc[order] if need_depth else None
|
||||||
nrm_o = nrm[order] if need_normal else None
|
nrm_o = nrm[order] if need_normal else None
|
||||||
|
mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)
|
||||||
|
|
||||||
def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
|
def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
|
||||||
px = cxr[lo:hi, None] + ox[None, :]
|
px = cxr[lo:hi, None] + ox[None, :]
|
||||||
py = cyr[lo:hi, None] + oy[None, :]
|
py = cyr[lo:hi, None] + oy[None, :]
|
||||||
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
|
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
|
||||||
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f # ray direction in camera space (z = 1)
|
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
|
||||||
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
|
rx = (px - cx0) / s - mux_o[lo:hi, None]
|
||||||
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
|
ry = yflip * (py - cy0) / s - muy_o[lo:hi, None]
|
||||||
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
|
rz = -muz_o[lo:hi, None] # constant per splat
|
||||||
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2
|
rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz
|
||||||
|
+ 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz))
|
||||||
|
dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz
|
||||||
|
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min(0)
|
||||||
|
else: # perspective ray (dx,dy,1) through the camera origin
|
||||||
|
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f
|
||||||
|
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
|
||||||
|
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
|
||||||
|
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
|
||||||
|
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2
|
||||||
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999)
|
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999)
|
||||||
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
|
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
|
||||||
return idx, alpha
|
return idx, alpha
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user