Optimize rendering and meshing speed

Multiple times faster
This commit is contained in:
kijai 2026-05-31 15:04:05 +03:00
parent 7188dca1f1
commit 87f5b6a692

View File

@ -9,7 +9,7 @@ from io import BytesIO
import numpy as np import numpy as np
import torch import torch
from typing_extensions import override from typing_extensions import override
from scipy.ndimage import map_coordinates from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum
from scipy.sparse import coo_matrix from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import connected_components
@ -55,7 +55,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
"""Serialize render-ready gaussian tensors as a binary 3DGS .ply. """Serialize render-ready gaussian tensors as a binary 3DGS .ply.
positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1]; positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1];
sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3DGS storage convention sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention
(log scale, logit opacity). (log scale, logit opacity).
""" """
xyz = positions.cpu().numpy().astype(np.float32) xyz = positions.cpu().numpy().astype(np.float32)
@ -65,7 +65,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
normals = np.zeros_like(xyz) normals = np.zeros_like(xyz)
f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) f = sh.cpu().numpy().astype(np.float32) # (N, K, 3)
f_dc = f[:, 0, :] # (N, 3) f_dc = f[:, 0, :] # (N, 3)
f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major, per 3DGS f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major
op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6) op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6)
op = np.log(op / (1.0 - op)) # inverse sigmoid (logit) op = np.log(op / (1.0 - op)) # inverse sigmoid (logit)
scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8)) scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8))
@ -482,11 +482,11 @@ class SplatToFile3D(IO.ComfyNode):
"Supports one item per batch only.", "Supports one item per batch only.",
inputs=[ inputs=[
IO.Splat.Input("splat"), IO.Splat.Input("splat"),
IO.Combo.Input("format", options=["ply", "ksplat", "spz"], IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it
tooltip="ply: standard 3DGS with full spherical harmonics. " tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. "
"ksplat: mkkellogg SplatBuffer (level 0, uncompressed). " "ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only "
"spz: Niantic gzip-compressed (~10x smaller). " "spz: Niantic gzip-compressed (~10x smaller), base color only "
"ksplat/spz keep base color only, view-dependent spherical harmonics is dropped."), ),
], ],
outputs=[IO.File3DAny.Output(display_name="model_3d")], outputs=[IO.File3DAny.Output(display_name="model_3d")],
) )
@ -512,7 +512,7 @@ class File3DToSplat(IO.ComfyNode):
category="3d/splat", category="3d/splat",
description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). " description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). "
"Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, " "Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, "
" the other formats are base color only. Format is auto-detected from the file contents.", "the other formats are base color only. Format is auto-detected from the file contents.",
inputs=[ inputs=[
IO.MultiType.Input( IO.MultiType.Input(
IO.File3DAny.Input("model_3d"), IO.File3DAny.Input("model_3d"),
@ -664,6 +664,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
# selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU.
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev) t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev)
idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype()
xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1) xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1)
scale, rot = t(scale) * float(splat_scale), t(rot) scale, rot = t(scale) * float(splat_scale), t(rot)
do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end
@ -677,7 +678,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
def background_only(): # no splats to rasterize -> just the background + empty mask def background_only(): # no splats to rasterize -> just the background + empty mask
img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev) img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev)
return img.cpu(), torch.zeros(height, width) return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype)
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
return background_only() return background_only()
@ -688,8 +689,6 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
fov = float(camera_info.get("fov", 0) or 0) or 35.0 fov = float(camera_info.get("fov", 0) or 0) or 35.0
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho") is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho")
cam_dist = float((target - eye).norm().clamp_min(1e-6)) # eye->target distance: sets the ortho pixel scale
yflip = 1.0 # splat +Y is image-down
xc, yc, zc = cam.unbind(-1) xc, yc, zc = cam.unbind(-1)
keep = zc > 1e-2 keep = zc > 1e-2
@ -700,8 +699,6 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
s = f / cam_dist # ortho: pixels per world unit at the target plane
invz = 1.0 / zc
cx0, cy0 = width / 2, height / 2 cx0, cy0 = width / 2, height / 2
# Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative # Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative
@ -723,74 +720,93 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
# Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
# The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y.
jm = torch.zeros(xc.shape[0], 2, 3, device=dev) jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
if is_ortho: # parallel projection: screen = s * (xc, yc) if is_ortho: # parallel projection: screen = s * (xc, yc)
cx, cy = cx0 + s * xc, cy0 + yflip * s * yc s = f / float((target - eye).norm().clamp_min(1e-6)) # pixels per world unit at the target plane
cx, cy = cx0 + s * xc, cy0 + s * yc
jm[:, 0, 0] = s jm[:, 0, 0] = s
jm[:, 1, 1] = yflip * s jm[:, 1, 1] = s
else: # perspective: screen = f * (xc, yc) / zc else: # perspective: screen = f * (xc, yc) / zc
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz invz = 1.0 / zc
cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square() jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square()
cov2 = jm @ cam_cov @ jm.transpose(1, 2) cov2 = jm @ cam_cov @ jm.transpose(1, 2)
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
radius = 3.0 * max_eig.clamp_min(1e-8).sqrt() radius = 3.0 * max_eig.clamp_min(1e-8).sqrt()
K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item())))) K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item()))))
rng = torch.arange(-K, K + 1, device=dev, dtype=torch.float32)
oy, ox = torch.meshgrid(rng, rng, indexing="ij")
ox, oy = ox.reshape(-1), oy.reshape(-1) # (M,) kernel offsets
order = torch.argsort(zc) # front (small zc) -> back # Per-splat kernel size: bucket splats by radius into a ladder of kernel sizes (the global K stays the cap) and use
# a tiny window instead of the worst-case one
levels = [L for L in (4, 8, 16, 32, 64, 128, 256) if L < K] + [K]
levels_t = torch.tensor(levels, device=dev, dtype=torch.float32)
grids = []
for L in levels:
rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32)
gy, gx = torch.meshgrid(rng, rng, indexing="ij")
grids.append((gx.reshape(-1), gy.reshape(-1)))
blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1) # window >= ~4 sigma
n = zc.shape[0]
ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped
order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs
bounds = torch.linspace(0, n, ns + 1, device=dev).round().long()
rank = torch.empty(n, dtype=torch.long, device=dev)
rank[order] = torch.arange(n, device=dev) # depth rank of each splat
slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1)
order = torch.argsort(slab_id * len(levels) + blevel) # group by slab, then kernel level (order-free within)
slab_bounds = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), torch.bincount(slab_id, minlength=ns).cumsum(0)]).tolist()
cxr, cyr = cx[order].round(), cy[order].round() cxr, cyr = cx[order].round(), cy[order].round()
s00, s01, s02 = s00[order], s01[order], s02[order] s00, s01, s02 = s00[order], s01[order], s02[order]
s11, s12, s22 = s11[order], s12[order], s22[order] s11, s12, s22 = s11[order], s12[order], s22[order]
simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order] simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order]
opacity, rgb = opacity[order], rgb[order] opacity, rgb = opacity[order], rgb[order]
blevel = blevel[order]
zc_o = zc[order] if need_depth else None zc_o = zc[order] if need_depth else None
nrm_o = nrm[order] if need_normal else None nrm_o = nrm[order] if need_normal else None
mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)
def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
px = cxr[lo:hi, None] + ox[None, :] px = cxr[lo:hi, None] + ox[None, :]
py = cyr[lo:hi, None] + oy[None, :] py = cyr[lo:hi, None] + oy[None, :]
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0) if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
rx = (px - cx0) / s - mux_o[lo:hi, None] rx = (px - cx0) / s - mux_o[lo:hi, None]
ry = yflip * (py - cy0) / s - muy_o[lo:hi, None] ry = (py - cy0) / s - muy_o[lo:hi, None]
rz = -muz_o[lo:hi, None] # constant per splat rz = -muz_o[lo:hi, None] # constant per splat
rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz
+ 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz)) + 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz))
dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min(0) q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min_(0)
else: # perspective ray (dx,dy,1) through the camera origin else: # perspective ray (dx,dy,1) through the camera origin
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f dx, dy = (px - cx0) / f, (py - cy0) / f
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None] dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy)) + 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None] dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2 q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min_(0) # ray->centre Mahalanobis^2
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999) alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999)
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
return idx, alpha return idx, alpha
# Front-to-back compositing over many depth slabs (equal splat counts) -> the global depth order is # Front-to-back compositing over the depth slabs set up above. Within a slab the accumulation is a pure
# resolved finely, approaching a true per-pixel sort. # sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window.
n = xc.shape[0] sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more
ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped (result converges well below)
bounds = torch.linspace(0, n, ns + 1).round().long().tolist()
cacc = torch.zeros((flat, 3), device=dev) cacc = torch.zeros((flat, 3), device=dev)
trans = torch.ones((flat,), device=dev) trans = torch.ones((flat,), device=dev)
a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean) a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean)
tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent) tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent)
crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour
wbuf = torch.zeros((flat,), device=dev) # sum alpha^p -> colour normalizer (== a_buf when p==1) wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only)
dacc = torch.zeros((flat,), device=dev) # front-weighted depth dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth
nacc = torch.zeros((flat, 3), device=dev) # front-weighted camera-space normal nacc = torch.zeros((flat, 3), device=dev) if need_normal else None # front-weighted camera-space normal
zslab = torch.zeros((flat,), device=dev) zslab = torch.zeros((flat,), device=dev) if need_depth else None
nslab = torch.zeros((flat, 3), device=dev) nslab = torch.zeros((flat, 3), device=dev) if need_normal else None
sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more stale = 0 # consecutive fully-occluded slabs -> early-out
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by kernel size (caps peak VRAM) for si in range(ns):
for s0, s1 in zip(bounds[:-1], bounds[1:]): s0, s1 = slab_bounds[si], slab_bounds[si + 1]
if s1 <= s0: if s1 <= s0:
continue continue
a_buf.zero_() a_buf.zero_()
@ -802,9 +818,15 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
zslab.zero_() zslab.zero_()
if need_normal: if need_normal:
nslab.zero_() nslab.zero_()
for lo in range(s0, s1, ch): lev = blevel[s0:s1] # kernel levels in this slab, sorted ascending
hi = min(lo + ch, s1) pos = s0
idx, alpha = splat(lo, hi) while pos < s1:
ox, oy = grids[int(lev[pos - s0])]
run_end = s0 + int(torch.searchsorted(lev, lev[pos - s0], right=True)) # contiguous same-level run
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size
for lo in range(pos, run_end, ch):
hi = min(lo + ch, run_end)
idx, alpha = splat(lo, hi, ox, oy)
idx, af = idx.reshape(-1), alpha.reshape(-1) idx, af = idx.reshape(-1), alpha.reshape(-1)
a_buf.index_add_(0, idx, af) a_buf.index_add_(0, idx, af)
tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge
@ -816,26 +838,33 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1)) zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1))
if need_normal: if need_normal:
nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3)) nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3))
pos = run_end
slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats
front = trans * slab_a front = trans * slab_a
ainv = a_buf.clamp_min(1e-8)
denom = wbuf if sharp else a_buf denom = wbuf if sharp else a_buf
cacc = cacc + front[:, None] * (crgb / denom.clamp_min(1e-8)[:, None]) cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None]) # cacc += front * (crgb/denom)
if need_depth or need_normal:
ainv = a_buf.clamp_min(1e-8) # alpha-weighted-mean normalizer (depth/normal only)
if need_depth: if need_depth:
dacc = dacc + front * (zslab / ainv) dacc.addcmul_(front, zslab / ainv)
if need_normal: if need_normal:
nacc = nacc + front[:, None] * (nslab / ainv[:, None]) nacc.addcmul_(front[:, None], nslab / ainv[:, None])
trans = trans * (1 - slab_a) trans.mul_(1 - slab_a)
if si % 8 == 7: # checkpoint every 8 slabs (a per-slab GPU sync would cost more)
if float(front.max()) < 1e-3: # this checkpoint slab is fully occluded by what is in front
stale += 1
if stale >= 2: # two occluded checkpoints running -> the rest are too -> stop
break
else:
stale = 0
cov = 1 - trans cov = 1 - trans
covg = cov.reshape(height, width) covg = cov.reshape(height, width)
covm = covg > 0.5 covm = covg > 0.5 if render_style in ("depth", "normal") else None # silhouette mask (depth/normal styles only)
depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None
nrm_map = None nrm_map = None
if need_normal: if need_normal:
# Coverage-weighted Gaussian blur of the accumulated normals. Per-splat surfel normals (the thinnest # Per-splat surfel normals are jittery, so do a masked blur
# gaussian axis) are jittery where splats are near-isotropic, so blur (nacc) and the weight (cov)
# together and divide -- a masked blur that smooths the noise without bleeding across the silhouette.
nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None] nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None]
cb = cov.reshape(1, 1, height, width) cb = cov.reshape(1, 1, height, width)
nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev) nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev)
@ -849,12 +878,12 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d) d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d)
img = d[:, :, None].expand(height, width, 3) img = d[:, :, None].expand(height, width, 3)
elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer
enc = (nrm_map * t([1.0, -yflip, -1.0]) * 0.5 + 0.5).clamp(0, 1) enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1)
img = enc * covm[:, :, None] # black background (masked out) img = enc * covm[:, :, None]
else: # color / clay else: # color / clay
img = cacc.reshape(height, width, 3) img = cacc.reshape(height, width, 3)
if render_style == "clay": # studio key light + ambient -> sculpted matte look if render_style == "clay": # studio key light + ambient -> sculpted matte look
kl = t([-0.4, -0.7 * yflip, -0.6]) # key from screen upper-left, angled toward the viewer kl = t([-0.4, -0.7, -0.6]) # key from screen upper-left, angled toward the viewer
kl = kl / kl.norm() kl = kl / kl.norm()
hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side
img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key
@ -862,10 +891,10 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
k = float(headlight_shading) k = float(headlight_shading)
ndotl = (-nrm_map[:, :, 2]).clamp(0, 1) ndotl = (-nrm_map[:, :, 2]).clamp(0, 1)
img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None] img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None]
img = img + trans.reshape(height, width, 1) * bg_comp img = img.addcmul_(trans.reshape(height, width, 1), bg_comp)
if do_linear: # back to display space after linear compositing if do_linear: # back to display space after linear compositing
img = _linear_to_srgb(img) img = _linear_to_srgb(img)
return img.clamp(0, 1).cpu(), covg.clamp(0, 1).cpu() return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype)
class RenderSplat(IO.ComfyNode): class RenderSplat(IO.ComfyNode):
@ -901,9 +930,8 @@ class RenderSplat(IO.ComfyNode):
IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True, IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True,
tooltip="Cull gaussians with opacity below this (removes faint floaters)."), tooltip="Cull gaussians with opacity below this (removes faint floaters)."),
IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"], IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"],
tooltip="What the image output shows: color (beauty), clay (neutral-albedo shaded - " tooltip="What the image output shows: color, clay (neutral-albedo shaded), "
"pure geometry), depth (near=bright), normal (OpenGL normal map). The mask " "depth (near=bright), normal (OpenGL normal map)."),
"output always carries the coverage regardless of this."),
IO.Color.Input("background", default="#000000"), IO.Color.Input("background", default="#000000"),
IO.Image.Input("bg_image", optional=True, IO.Image.Input("bg_image", optional=True,
tooltip="Optional background plate composited behind the splat (overrides the solid " tooltip="Optional background plate composited behind the splat (overrides the solid "
@ -913,8 +941,7 @@ class RenderSplat(IO.ComfyNode):
tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera " tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera "
"Info node. If empty, the splat is auto-framed from a default 3/4 view."), "Info node. If empty, the splat is auto-framed from a default 3/4 view."),
], ],
outputs=[IO.Image.Output(display_name="image"), outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")],
IO.Mask.Output(display_name="mask")],
) )
@classmethod @classmethod
@ -928,7 +955,7 @@ class RenderSplat(IO.ComfyNode):
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
imgs, masks = [], [] imgs, masks = [], []
device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip device = comfy.model_management.get_torch_device()
total = splat.positions.shape[0] * n_frames total = splat.positions.shape[0] * n_frames
pbar = comfy.utils.ProgressBar(total) if total > 1 else None pbar = comfy.utils.ProgressBar(total) if total > 1 else None
k = 0 k = 0
@ -1014,8 +1041,7 @@ class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file
) )
@classmethod @classmethod
def execute(cls, mode, target_x, target_y, target_z, roll, fov, def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
dev = comfy.model_management.get_torch_device() dev = comfy.model_management.get_torch_device()
kind = mode["mode"] kind = mode["mode"]
if kind == "quaternion": # explicit world position + camera rotation if kind == "quaternion": # explicit world position + camera rotation
@ -1030,8 +1056,7 @@ class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file
position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy] position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy]
else: # look_at: explicit world-space camera position else: # look_at: explicit world-space camera position
position = [mode["position_x"], mode["position_y"], mode["position_z"]] position = [mode["position_x"], mode["position_y"], mode["position_z"]]
return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll))
zoom=zoom, camera_type=camera_type, roll=roll))
class TransformSplat(IO.ComfyNode): class TransformSplat(IO.ComfyNode):
@ -1198,7 +1223,8 @@ def _inverse_covariance(scale, quat):
return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R) return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R)
def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None): def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None,
col_dtype=torch.float16):
# Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid, # Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid,
# plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`). # plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`).
# Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper # Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper
@ -1214,8 +1240,8 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width
sharp = color_sharpen != 1.0 sharp = color_sharpen != 1.0
vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface) vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface)
colvol = torch.zeros(dx * dy * dz, 3, device=device) # Sum(w^p * rgb) colour numerator colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype) # Sum(w^p * rgb) colour numerator
wcol = torch.zeros(dx * dy * dz, device=device) if sharp else None # Sum(w^p) colour normaliser (p>1) wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None # Sum(w^p) normaliser (p>1)
n, done = xyz.shape[0], 0 n, done = xyz.shape[0], 0
for k in range(1, int(kernel) + 1): for k in range(1, int(kernel) + 1):
sel = (kreq == k).nonzero(as_tuple=True)[0] sel = (kreq == k).nonzero(as_tuple=True)[0]
@ -1238,9 +1264,9 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1) flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1)
vol.index_add_(0, flat, wgt.reshape(-1)) vol.index_add_(0, flat, wgt.reshape(-1))
wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight
colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3)) colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype))
if sharp: if sharp:
wcol.index_add_(0, flat, wp.reshape(-1)) wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype))
done += gi.numel() done += gi.numel()
if progress is not None: if progress is not None:
progress(min(1.0, done / max(1, n))) progress(min(1.0, done / max(1, n)))
@ -1248,9 +1274,63 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel) return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel)
def _clean_components(verts, faces, min_verts): def _connected_components_gpu(faces, nv):
# FastSV connected components: grandparent hooking + shortcutting, ~O(log nv) iterations.
# Returns per-vertex component labels (min node id, not densified).
a = torch.cat([faces[:, 0], faces[:, 1]]) # 2F edge endpoints: (v0,v1),(v1,v2)
b = torch.cat([faces[:, 1], faces[:, 2]])
f = torch.arange(nv, device=faces.device)
while True:
gp = f[f] # grandparent
ga, gb = gp[a], gp[b]
new = f.clone()
new.scatter_reduce_(0, f[a], gb, "amin", include_self=True) # stochastic hooking onto roots
new.scatter_reduce_(0, f[b], ga, "amin", include_self=True)
new.scatter_reduce_(0, a, gb, "amin", include_self=True) # aggressive hooking, both directions
new.scatter_reduce_(0, b, ga, "amin", include_self=True)
new = new[new] # shortcut (path compression)
if torch.equal(new, f):
return f
f = new
def _clean_components_gpu(verts, faces, min_verts, device):
# GPU port of _clean_components: FastSV components + scatter reductions. Byte-identical to the numpy path
vt = torch.as_tensor(verts, device=device)
ft = torch.as_tensor(faces, device=device)
nv = vt.shape[0]
_, label = torch.unique(_connected_components_gpu(ft, nv), return_inverse=True) # dense 0..ncomp-1
ncomp = int(label.max()) + 1
flabel = label[ft[:, 0]] # component id per face
keep = torch.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate
if int(keep.sum()) > 1:
fcount = torch.bincount(flabel, minlength=ncomp)
largest = int(torch.where(keep, fcount, fcount.new_tensor(-1)).argmax())
v0, v1, v2 = vt[ft[:, 0]], vt[ft[:, 1]], vt[ft[:, 2]]
cvol = torch.zeros(ncomp, device=device).scatter_add_(0, flabel, (v0 * torch.linalg.cross(v1, v2)).sum(-1))
idx3 = label[:, None].expand(-1, 3) # per-component vertex bbox
cmin = torch.full((ncomp, 3), float("inf"), device=device).scatter_reduce_(0, idx3, vt, "amin", include_self=True)
cmax = torch.full((ncomp, 3), float("-inf"), device=device).scatter_reduce_(0, idx3, vt, "amax", include_self=True)
tol = 1e-4 * (cmax[largest] - cmin[largest]).max()
enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1)
inner = enclosed & (torch.sign(cvol) != torch.sign(cvol[largest])) & (torch.arange(ncomp, device=device) != largest)
keep &= ~inner
faces_k = ft[keep[flabel]]
if faces_k.shape[0] == 0:
return verts[:0], faces[:0]
used = torch.unique(faces_k) # sorted, matches np.unique
remap = torch.full((nv,), -1, dtype=torch.int64, device=device)
remap[used] = torch.arange(used.shape[0], device=device)
return vt[used].cpu().numpy(), remap[faces_k].cpu().numpy()
def _clean_components(verts, faces, min_verts, device=None):
# Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density # Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density
# extracts a double wall (outer + inner cavity surface) # extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x
# faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output.
if device is not None and device.type != "cpu" and \
comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes
return _clean_components_gpu(verts, faces, min_verts, device)
nv = len(verts) nv = len(verts)
e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0)
ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False) ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False)
@ -1261,10 +1341,9 @@ def _clean_components(verts, faces, min_verts):
largest = np.where(keep, fcount, -1).argmax() largest = np.where(keep, fcount, -1).argmax()
v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]] v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol
fmin, fmax = verts[faces].min(1), verts[faces].max(1) cidx = np.arange(ncomp) # per-component vertex bbox via ndimage (~6x faster than ufunc.at)
cmin, cmax = np.full((ncomp, 3), np.inf), np.full((ncomp, 3), -np.inf) cmin = np.stack([_ndi_minimum(verts[:, a], label, cidx) for a in range(3)], 1)
np.minimum.at(cmin, flabel, fmin) cmax = np.stack([_ndi_maximum(verts[:, a], label, cidx) for a in range(3)], 1)
np.maximum.at(cmax, flabel, fmax)
tol = 1e-4 * (cmax[largest] - cmin[largest]).max() tol = 1e-4 * (cmax[largest] - cmin[largest]).max()
enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1)
inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest) inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest)
@ -1301,27 +1380,35 @@ def _surface_nets(vol, level, voxel, origin, device):
return empty return empty
# Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings. # Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings.
del any_in, all_in, cs8 # corner bool grids no longer needed
ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices
offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device) [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device)
ci = ac[:, None, :] + offs[None] # (nv,8,3) offf = offs.to(torch.float32)
cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (nv,8) corner values
csl = cval >= level
edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3],
[2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device) [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device)
e0, e1 = edges[:, 0], edges[:, 1] e0, e1 = edges[:, 0], edges[:, 1]
v0, v1 = cval[:, e0], cval[:, e1] # (nv,12) oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints
cross = csl[:, e0] != csl[:, e1]
cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too)
loc = []
for st in range(0, nv, cstep):
ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3)
cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values
csl = cval >= level
v0, v1 = cval[:, e0], cval[:, e1] # (m,12)
cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32)
denom = v1 - v0 denom = v1 - v0
t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1) t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1)
offf = offs.to(torch.float32) pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings (fused interp)
pts = offf[e0] + t[..., None] * (offf[e1] - offf[e0]) # (nv,12,3) local crossings loc.append((pts * cross).sum(1) / cross.sum(1).clamp_min(1.0)) # (m,3) in [0,1]
cf = cross[..., None].to(pts.dtype) local = torch.cat(loc, 0) if len(loc) > 1 else loc[0] # (nv,3)
local = (pts * cf).sum(1) / cf.sum(1).clamp_min(1.0) # (nv,3) local vertex in [0,1]
verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space
del loc, local, ac
vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device) vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device)
vid[active] = torch.arange(nv, dtype=torch.int32, device=device) vid[active] = torch.arange(nv, dtype=torch.int32, device=device)
del active
# Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding. # Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding.
faces = [] faces = []
@ -1376,14 +1463,34 @@ def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53):
nv = len(verts) nv = len(verts)
e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0)
e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency
adj = coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr() adj = coo_matrix((np.ones(len(e), np.float32), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr()
adj.data[:] = 1.0 adj.data[:] = 1.0
deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None)[:, None] deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None).astype(np.float32)[:, None]
v = verts.astype(np.float64) v = verts.astype(np.float32) # fp32 matvec: ~2x faster, sub-micron drift on unit-scale verts
for _ in range(int(iters)): for _ in range(int(iters)):
for fac in (lam, mu): for fac in (lam, mu):
v = v + fac * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v) v = v + np.float32(fac) * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v)
return np.ascontiguousarray(v.astype(np.float32)) return np.ascontiguousarray(v)
def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device):
# GPU trilinear sampling of the colour numerator (3ch) and normaliser (1ch) at vertex grid-coords
# reproduces scipy map_coordinates(order=1, mode='nearest'). Returns col (V,3) numpy.
dx, dy, dz = colnorm.shape
vt = torch.as_tensor(verts, device=device, dtype=torch.float32)
org = torch.as_tensor(origin, device=device, dtype=torch.float32)
gi = (vt - org) / voxel # (V,3) grid-index coords (x,y,z)
size = torch.tensor([dx, dy, dz], device=device, dtype=torch.float32)
g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners)
grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x)
def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32
inp = v.permute(3, 0, 1, 2).contiguous()[None].to(device=device, dtype=torch.float32)
o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True)
return o[0, :, 0, 0, :]
num = samp(colvol) # (3,V)
den = samp(colnorm[..., None]) # (1,V)
return (num / den.clamp_min(1e-8)).T.cpu().numpy() # (V,3)
def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
@ -1406,24 +1513,37 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device,
color_sharpen=color_sharpen, color_sharpen=color_sharpen,
progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25% progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25%
vol_np = vol.cpu().numpy() # Sum(w) density (for the surface) # Colour: sample on the GPU (grid_sample) when there's headroom
colvol_np = colvol.cpu().numpy() # Sum(w^p * rgb) colour numerator colour_gpu = device.type != "cpu" and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4
colnorm_np = colnorm.cpu().numpy() # Sum(w^p) colour normaliser if colour_gpu:
colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing
colvol_np = colnorm_np = None
else:
colvol_np = colvol.cpu().numpy().astype(np.float32) # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32)
colnorm_np = colnorm.cpu().numpy().astype(np.float32) # Sum(w^p) colour normaliser
del colvol, colnorm # free the colour grids before iso-surfacing
rep(0.40) rep(0.40)
occ = vol_np[vol_np > vol_np.max() * 1e-3] # occupied voxels (skip the empty-space peak) vmin, vmax = float(vol.min()), float(vol.max())
if occ.size == 0: occ = vol[vol > vmax * 1e-3] # occupied voxels (skip the empty-space peak)
if occ.numel() == 0:
return None return None
# Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly # Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly
# inside the data range so a bias can't push the iso off the histogram (the old None / "no surface" bug). # inside the data range so a bias can't push the iso off the histogram.
lo, hi = float(vol_np.min()), float(vol_np.max()) level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)),
level = min(max(_otsu_level(occ) * level_bias, lo + 1e-6 * (hi - lo)), hi - 1e-6 * (hi - lo)) vmax - 1e-6 * (vmax - vmin))
# Surface Nets on CPU: the grid is already on CPU, and this keeps iso-surfacing off the GPU's VRAM. # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked
verts, faces = _surface_nets(torch.from_numpy(vol_np), level, voxel, origin, torch.device("cpu")) # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM.
sn_dev = device
if device.type != "cpu" and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4:
sn_dev = torch.device("cpu")
vol = vol.cpu()
verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev)
del vol
rep(0.55) rep(0.55)
if min_component > 0 and len(faces) > 0: if min_component > 0 and len(faces) > 0:
verts, faces = _clean_components(verts, faces, min_component) verts, faces = _clean_components(verts, faces, min_component, device)
if len(verts) == 0 or len(faces) == 0: if len(verts) == 0 or len(faces) == 0:
return None return None
@ -1434,6 +1554,9 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
# Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb) # Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb)
# and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density # and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density
# edge voxels from pulling colours toward black, and matches the gaussians that formed the surface. # edge voxels from pulling colours toward black, and matches the gaussians that formed the surface.
if colour_gpu:
col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device)
else:
coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes
num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1) num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1)
den = map_coordinates(colnorm_np, coords, order=1, mode="nearest") den = map_coordinates(colnorm_np, coords, order=1, mode="nearest")
@ -1462,7 +1585,7 @@ class SplatToMesh(IO.ComfyNode):
description="Extract a coloured mesh from a gaussian splat.", description="Extract a coloured mesh from a gaussian splat.",
inputs=[ inputs=[
IO.Splat.Input("splat"), IO.Splat.Input("splat"),
IO.Int.Input("resolution", default=384, min=64, max=1024, step=16, IO.Int.Input("resolution", default=384, min=64, max=768, step=16,
tooltip="Density-grid resolution along the longest axis. Higher = finer surface, " tooltip="Density-grid resolution along the longest axis. Higher = finer surface, "
"more VRAM/time (grows with resolution^3)."), "more VRAM/time (grows with resolution^3)."),
IO.Int.Input("kernel", default=5, min=1, max=8, IO.Int.Input("kernel", default=5, min=1, max=8,