Prevent crash at huge splats

This commit is contained in:
kijai 2026-05-31 12:09:41 +03:00
parent 3c7d35f4e0
commit 1fd7c500ff

View File

@ -43,6 +43,14 @@ def _hex_to_rgb(h: str) -> tuple[float, float, float]:
return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
def _quantile(x, q):
# torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate.
lim = 1 << 24
if x.numel() > lim:
x = x[:: x.numel() // lim + 1]
return torch.quantile(x, q)
def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
"""Serialize render-ready gaussian tensors as a binary 3DGS .ply.
@ -720,7 +728,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
radius = 3.0 * max_eig.clamp_min(1e-8).sqrt()
K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(torch.quantile(radius, 0.995).item()))))
K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item()))))
rng = torch.arange(-K, K + 1, device=dev, dtype=torch.float32)
oy, ox = torch.meshgrid(rng, rng, indexing="ij")
ox, oy = ox.reshape(-1), oy.reshape(-1) # (M,) kernel offsets
@ -916,7 +924,7 @@ class RenderSplat(IO.ComfyNode):
base_cam = camera_info
if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat
center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device)
extent = ((xyz - center).norm(dim=-1).quantile(0.99).clamp_min(1e-4) if xyz.shape[0]
extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0]
else torch.tensor(1.0, device=device))
dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9))
base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device)