diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 450b87e89..111034ee7 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -682,11 +682,13 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) return background_only() - eye, _, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info + eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) cam = (xyz - eye) @ W.T fov = float(camera_info.get("fov", 0) or 0) or 35.0 zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length + is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho") + cam_dist = float((target - eye).norm().clamp_min(1e-6)) # eye->target distance: sets the ortho pixel scale yflip = 1.0 # splat +Y is image-down xc, yc, zc = cam.unbind(-1) @@ -698,6 +700,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom + s = f / cam_dist # ortho: pixels per world unit at the target plane invz = 1.0 / zc cx0, cy0 = width / 2, height / 2 @@ -720,10 +723,15 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). - cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz jm = torch.zeros(xc.shape[0], 2, 3, device=dev) - jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() - jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square() + if is_ortho: # parallel projection: screen = s * (xc, yc) + cx, cy = cx0 + s * xc, cy0 + yflip * s * yc + jm[:, 0, 0] = s + jm[:, 1, 1] = yflip * s + else: # perspective: screen = f * (xc, yc) / zc + cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz + jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() + jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square() cov2 = jm @ cam_cov @ jm.transpose(1, 2) a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() @@ -741,16 +749,26 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, opacity, rgb = opacity[order], rgb[order] zc_o = zc[order] if need_depth else None nrm_o = nrm[order] if need_normal else None + mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray px = cxr[lo:hi, None] + ox[None, :] py = cyr[lo:hi, None] + oy[None, :] valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) - dx, dy = (px - cx0) / f, yflip * (py - cy0) / f # ray direction in camera space (z = 1) - dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None] - + 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy)) - dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None] - q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2 + if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0) + rx = (px - cx0) / s - mux_o[lo:hi, None] + ry = yflip * (py - cy0) / s - muy_o[lo:hi, None] + rz = -muz_o[lo:hi, None] # constant per splat + rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz + + 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz)) + dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz + q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min(0) + else: # perspective ray (dx,dy,1) through the camera origin + dx, dy = (px - cx0) / f, yflip * (py - cy0) / f + dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None] + + 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy)) + dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None] + q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2 alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999) idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) return idx, alpha