From 1fd7c500ff8c3224a3f7ee4b26af7d00db266063 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sun, 31 May 2026 12:09:41 +0300 Subject: [PATCH] Prevent crash at huge splats --- comfy_extras/nodes_gaussian_splat.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 3808fd1c2..450b87e89 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -43,6 +43,14 @@ def _hex_to_rgb(h: str) -> tuple[float, float, float]: return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) +def _quantile(x, q): + # torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate. + lim = 1 << 24 + if x.numel() > lim: + x = x[:: x.numel() // lim + 1] + return torch.quantile(x, q) + + def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: """Serialize render-ready gaussian tensors as a binary 3DGS .ply. @@ -720,7 +728,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() radius = 3.0 * max_eig.clamp_min(1e-8).sqrt() - K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(torch.quantile(radius, 0.995).item())))) + K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item())))) rng = torch.arange(-K, K + 1, device=dev, dtype=torch.float32) oy, ox = torch.meshgrid(rng, rng, indexing="ij") ox, oy = ox.reshape(-1), oy.reshape(-1) # (M,) kernel offsets @@ -916,7 +924,7 @@ class RenderSplat(IO.ComfyNode): base_cam = camera_info if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device) - extent = ((xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) if xyz.shape[0] + extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0] else torch.tensor(1.0, device=device)) dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9)) base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device)