Optimize rendering and meshing speed

Multiple times faster
This commit is contained in:
kijai 2026-05-31 15:04:05 +03:00
parent 7188dca1f1
commit 87f5b6a692

View File

@ -9,7 +9,7 @@ from io import BytesIO
import numpy as np
import torch
from typing_extensions import override
from scipy.ndimage import map_coordinates
from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components
@ -55,7 +55,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
"""Serialize render-ready gaussian tensors as a binary 3DGS .ply.
positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1];
sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3DGS storage convention
sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention
(log scale, logit opacity).
"""
xyz = positions.cpu().numpy().astype(np.float32)
@ -65,7 +65,7 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
normals = np.zeros_like(xyz)
f = sh.cpu().numpy().astype(np.float32) # (N, K, 3)
f_dc = f[:, 0, :] # (N, 3)
f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major, per 3DGS
f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major
op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6)
op = np.log(op / (1.0 - op)) # inverse sigmoid (logit)
scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8))
@ -87,8 +87,8 @@ def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes:
# then N 44-byte records. Bucketing/quantization only exist at levels >= 1. See SplatBuffer.js.
_KSPLAT_HEADER_BYTES = 4096
_KSPLAT_SECTION_HEADER_BYTES = 1024
_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4
_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion
_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4
_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion
def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes:
@ -482,11 +482,11 @@ class SplatToFile3D(IO.ComfyNode):
"Supports one item per batch only.",
inputs=[
IO.Splat.Input("splat"),
IO.Combo.Input("format", options=["ply", "ksplat", "spz"],
tooltip="ply: standard 3DGS with full spherical harmonics. "
"ksplat: mkkellogg SplatBuffer (level 0, uncompressed). "
"spz: Niantic gzip-compressed (~10x smaller). "
"ksplat/spz keep base color only, view-dependent spherical harmonics is dropped."),
IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it
tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. "
"ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only "
"spz: Niantic gzip-compressed (~10x smaller), base color only "
),
],
outputs=[IO.File3DAny.Output(display_name="model_3d")],
)
@ -512,7 +512,7 @@ class File3DToSplat(IO.ComfyNode):
category="3d/splat",
description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). "
"Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, "
" the other formats are base color only. Format is auto-detected from the file contents.",
"the other formats are base color only. Format is auto-detected from the file contents.",
inputs=[
IO.MultiType.Input(
IO.File3DAny.Input("model_3d"),
@ -664,133 +664,149 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
# selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU.
dev = comfy.model_management.get_torch_device()
t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev)
idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype()
xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1)
scale, rot = t(scale) * float(splat_scale), t(rot)
do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end
do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end
if do_linear:
rgb = _srgb_to_linear(rgb)
flat = width * height
bg_t = t(bg)
bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t # background blended in the same space as the splats
bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t # background blended in the same space as the splats
need_depth = render_style == "depth"
need_normal = render_style in ("normal", "clay") or headlight_shading > 0
def background_only(): # no splats to rasterize -> just the background + empty mask
def background_only(): # no splats to rasterize -> just the background + empty mask
img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev)
return img.cpu(), torch.zeros(height, width)
return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype)
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold)
return background_only()
eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info
W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera)
cam = (xyz - eye) @ W.T
fov = float(camera_info.get("fov", 0) or 0) or 35.0
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length
is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho")
cam_dist = float((target - eye).norm().clamp_min(1e-6)) # eye->target distance: sets the ortho pixel scale
yflip = 1.0 # splat +Y is image-down
xc, yc, zc = cam.unbind(-1)
keep = zc > 1e-2
xc, yc, zc, rgb, opacity, scale, rot = (a[keep] for a in (xc, yc, zc, rgb, opacity, scale, rot))
if xc.shape[0] == 0: # nothing in front of the camera -> background only
if xc.shape[0] == 0: # nothing in front of the camera -> background only
return background_only()
if render_style == "clay":
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
s = f / cam_dist # ortho: pixels per world unit at the target plane
invz = 1.0 / zc
f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom
cx0, cy0 = width / 2, height / 2
# Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative
# regularizer for a stable inverse (a pixel-size Mip low-pass would over-thicken flat surfels and blur).
Mw = W[None] @ _quat_to_mat(rot) # (N,3,3) world -> camera
Mw = W[None] @ _quat_to_mat(rot) # (N,3,3) world -> camera
cam_cov = (Mw * scale.square()[:, None, :]) @ Mw.transpose(1, 2)
cam_cov = cam_cov + (cam_cov.diagonal(dim1=-2, dim2=-1).mean(-1) * 1e-3)[:, None, None] * torch.eye(3, device=dev)
# Perspective-correct weighting: peak of the 3D Gaussian along each pixel ray. Precompute Si, Si@mu, mu^T Si mu.
mu = torch.stack([xc, yc, zc], -1)
si = torch.linalg.inv(cam_cov)
simu = (si @ mu[:, :, None])[:, :, 0] # (N,3)
musimu = (mu * simu).sum(-1) # (N,)
simu = (si @ mu[:, :, None])[:, :, 0] # (N,3)
musimu = (mu * simu).sum(-1) # (N,)
s00, s01, s02 = si[:, 0, 0], si[:, 0, 1], si[:, 0, 2]
s11, s12, s22 = si[:, 1, 1], si[:, 1, 2], si[:, 2, 2]
simu0, simu1, simu2 = simu.unbind(-1)
if need_normal: # surfel normal = thinnest axis, oriented toward camera
nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)] # (N,3) camera-space normal
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
if need_normal: # surfel normal = thinnest axis, oriented toward camera
nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)] # (N,3) camera-space normal
nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera)
# Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
# The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y.
jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
if is_ortho: # parallel projection: screen = s * (xc, yc)
cx, cy = cx0 + s * xc, cy0 + yflip * s * yc
if is_ortho: # parallel projection: screen = s * (xc, yc)
s = f / float((target - eye).norm().clamp_min(1e-6)) # pixels per world unit at the target plane
cx, cy = cx0 + s * xc, cy0 + s * yc
jm[:, 0, 0] = s
jm[:, 1, 1] = yflip * s
else: # perspective: screen = f * (xc, yc) / zc
cx, cy = cx0 + f * xc * invz, cy0 + yflip * f * yc * invz
jm[:, 1, 1] = s
else: # perspective: screen = f * (xc, yc) / zc
invz = 1.0 / zc
cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz
jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
jm[:, 1, 1], jm[:, 1, 2] = yflip * f * invz, -yflip * f * yc * invz.square()
jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square()
cov2 = jm @ cam_cov @ jm.transpose(1, 2)
a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
radius = 3.0 * max_eig.clamp_min(1e-8).sqrt()
K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item()))))
rng = torch.arange(-K, K + 1, device=dev, dtype=torch.float32)
oy, ox = torch.meshgrid(rng, rng, indexing="ij")
ox, oy = ox.reshape(-1), oy.reshape(-1) # (M,) kernel offsets
order = torch.argsort(zc) # front (small zc) -> back
# Per-splat kernel size: bucket splats by radius into a ladder of kernel sizes (the global K stays the cap) and use
# a tiny window instead of the worst-case one
levels = [L for L in (4, 8, 16, 32, 64, 128, 256) if L < K] + [K]
levels_t = torch.tensor(levels, device=dev, dtype=torch.float32)
grids = []
for L in levels:
rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32)
gy, gx = torch.meshgrid(rng, rng, indexing="ij")
grids.append((gx.reshape(-1), gy.reshape(-1)))
blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1) # window >= ~4 sigma
n = zc.shape[0]
ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped
order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs
bounds = torch.linspace(0, n, ns + 1, device=dev).round().long()
rank = torch.empty(n, dtype=torch.long, device=dev)
rank[order] = torch.arange(n, device=dev) # depth rank of each splat
slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1)
order = torch.argsort(slab_id * len(levels) + blevel) # group by slab, then kernel level (order-free within)
slab_bounds = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), torch.bincount(slab_id, minlength=ns).cumsum(0)]).tolist()
cxr, cyr = cx[order].round(), cy[order].round()
s00, s01, s02 = s00[order], s01[order], s02[order]
s11, s12, s22 = s11[order], s12[order], s22[order]
simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order]
opacity, rgb = opacity[order], rgb[order]
blevel = blevel[order]
zc_o = zc[order] if need_depth else None
nrm_o = nrm[order] if need_normal else None
mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)
def splat(lo, hi): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
px = cxr[lo:hi, None] + ox[None, :]
py = cyr[lo:hi, None] + oy[None, :]
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
rx = (px - cx0) / s - mux_o[lo:hi, None]
ry = yflip * (py - cy0) / s - muy_o[lo:hi, None]
rz = -muz_o[lo:hi, None] # constant per splat
ry = (py - cy0) / s - muy_o[lo:hi, None]
rz = -muz_o[lo:hi, None] # constant per splat
rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz
+ 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz))
dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min(0)
else: # perspective ray (dx,dy,1) through the camera origin
dx, dy = (px - cx0) / f, yflip * (py - cy0) / f
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min_(0)
else: # perspective ray (dx,dy,1) through the camera origin
dx, dy = (px - cx0) / f, (py - cy0) / f
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min(0) # ray->centre Mahalanobis^2
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp(0, 0.999)
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min_(0) # ray->centre Mahalanobis^2
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999)
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
return idx, alpha
# Front-to-back compositing over many depth slabs (equal splat counts) -> the global depth order is
# resolved finely, approaching a true per-pixel sort.
n = xc.shape[0]
ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped (result converges well below)
bounds = torch.linspace(0, n, ns + 1).round().long().tolist()
# Front-to-back compositing over the depth slabs set up above. Within a slab the accumulation is a pure
# sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window.
sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more
cacc = torch.zeros((flat, 3), device=dev)
trans = torch.ones((flat,), device=dev)
a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean)
tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent)
crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour
wbuf = torch.zeros((flat,), device=dev) # sum alpha^p -> colour normalizer (== a_buf when p==1)
dacc = torch.zeros((flat,), device=dev) # front-weighted depth
nacc = torch.zeros((flat, 3), device=dev) # front-weighted camera-space normal
zslab = torch.zeros((flat,), device=dev)
nslab = torch.zeros((flat, 3), device=dev)
sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by kernel size (caps peak VRAM)
for s0, s1 in zip(bounds[:-1], bounds[1:]):
a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean)
tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent)
crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour
wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only)
dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth
nacc = torch.zeros((flat, 3), device=dev) if need_normal else None # front-weighted camera-space normal
zslab = torch.zeros((flat,), device=dev) if need_depth else None
nslab = torch.zeros((flat, 3), device=dev) if need_normal else None
stale = 0 # consecutive fully-occluded slabs -> early-out
for si in range(ns):
s0, s1 = slab_bounds[si], slab_bounds[si + 1]
if s1 <= s0:
continue
a_buf.zero_()
@ -802,70 +818,83 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
zslab.zero_()
if need_normal:
nslab.zero_()
for lo in range(s0, s1, ch):
hi = min(lo + ch, s1)
idx, alpha = splat(lo, hi)
idx, af = idx.reshape(-1), alpha.reshape(-1)
a_buf.index_add_(0, idx, af)
tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge
apw = alpha.pow(sharpen) if sharp else alpha # bias colour toward the highest-alpha splat
crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3))
if sharp:
wbuf.index_add_(0, idx, apw.reshape(-1))
if need_depth:
zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1))
if need_normal:
nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3))
slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats
lev = blevel[s0:s1] # kernel levels in this slab, sorted ascending
pos = s0
while pos < s1:
ox, oy = grids[int(lev[pos - s0])]
run_end = s0 + int(torch.searchsorted(lev, lev[pos - s0], right=True)) # contiguous same-level run
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size
for lo in range(pos, run_end, ch):
hi = min(lo + ch, run_end)
idx, alpha = splat(lo, hi, ox, oy)
idx, af = idx.reshape(-1), alpha.reshape(-1)
a_buf.index_add_(0, idx, af)
tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge
apw = alpha.pow(sharpen) if sharp else alpha # bias colour toward the highest-alpha splat
crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3))
if sharp:
wbuf.index_add_(0, idx, apw.reshape(-1))
if need_depth:
zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1))
if need_normal:
nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3))
pos = run_end
slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats
front = trans * slab_a
ainv = a_buf.clamp_min(1e-8)
denom = wbuf if sharp else a_buf
cacc = cacc + front[:, None] * (crgb / denom.clamp_min(1e-8)[:, None])
if need_depth:
dacc = dacc + front * (zslab / ainv)
if need_normal:
nacc = nacc + front[:, None] * (nslab / ainv[:, None])
trans = trans * (1 - slab_a)
cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None]) # cacc += front * (crgb/denom)
if need_depth or need_normal:
ainv = a_buf.clamp_min(1e-8) # alpha-weighted-mean normalizer (depth/normal only)
if need_depth:
dacc.addcmul_(front, zslab / ainv)
if need_normal:
nacc.addcmul_(front[:, None], nslab / ainv[:, None])
trans.mul_(1 - slab_a)
if si % 8 == 7: # checkpoint every 8 slabs (a per-slab GPU sync would cost more)
if float(front.max()) < 1e-3: # this checkpoint slab is fully occluded by what is in front
stale += 1
if stale >= 2: # two occluded checkpoints running -> the rest are too -> stop
break
else:
stale = 0
cov = 1 - trans
covg = cov.reshape(height, width)
covm = covg > 0.5
covm = covg > 0.5 if render_style in ("depth", "normal") else None # silhouette mask (depth/normal styles only)
depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None
nrm_map = None
if need_normal:
# Coverage-weighted Gaussian blur of the accumulated normals. Per-splat surfel normals (the thinnest
# gaussian axis) are jittery where splats are near-isotropic, so blur (nacc) and the weight (cov)
# together and divide -- a masked blur that smooths the noise without bleeding across the silhouette.
# Per-splat surfel normals are jittery, so do a masked blur
nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None]
cb = cov.reshape(1, 1, height, width)
nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev)
normal = (nb / cb.clamp_min(1e-6))[0].permute(1, 2, 0)
nrm_map = normal / normal.norm(dim=-1, keepdim=True).clamp_min(1e-6)
if render_style == "depth": # near = bright, far = dark, 0 off-object
if render_style == "depth": # near = bright, far = dark, 0 off-object
d = torch.zeros(height, width, device=dev)
if bool(covm.any()):
lo, hi = depth_map[covm].min(), depth_map[covm].max()
d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d)
img = d[:, :, None].expand(height, width, 3)
elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer
enc = (nrm_map * t([1.0, -yflip, -1.0]) * 0.5 + 0.5).clamp(0, 1)
img = enc * covm[:, :, None] # black background (masked out)
else: # color / clay
elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer
enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1)
img = enc * covm[:, :, None]
else: # color / clay
img = cacc.reshape(height, width, 3)
if render_style == "clay": # studio key light + ambient -> sculpted matte look
kl = t([-0.4, -0.7 * yflip, -0.6]) # key from screen upper-left, angled toward the viewer
if render_style == "clay": # studio key light + ambient -> sculpted matte look
kl = t([-0.4, -0.7, -0.6]) # key from screen upper-left, angled toward the viewer
kl = kl / kl.norm()
hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side
img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key
elif headlight_shading > 0: # camera headlight: darken faces turned from view
hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side
img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key
elif headlight_shading > 0: # camera headlight: darken faces turned from view
k = float(headlight_shading)
ndotl = (-nrm_map[:, :, 2]).clamp(0, 1)
img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None]
img = img + trans.reshape(height, width, 1) * bg_comp
if do_linear: # back to display space after linear compositing
img = img.addcmul_(trans.reshape(height, width, 1), bg_comp)
if do_linear: # back to display space after linear compositing
img = _linear_to_srgb(img)
return img.clamp(0, 1).cpu(), covg.clamp(0, 1).cpu()
return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype)
class RenderSplat(IO.ComfyNode):
@ -901,9 +930,8 @@ class RenderSplat(IO.ComfyNode):
IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True,
tooltip="Cull gaussians with opacity below this (removes faint floaters)."),
IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"],
tooltip="What the image output shows: color (beauty), clay (neutral-albedo shaded - "
"pure geometry), depth (near=bright), normal (OpenGL normal map). The mask "
"output always carries the coverage regardless of this."),
tooltip="What the image output shows: color, clay (neutral-albedo shaded), "
"depth (near=bright), normal (OpenGL normal map)."),
IO.Color.Input("background", default="#000000"),
IO.Image.Input("bg_image", optional=True,
tooltip="Optional background plate composited behind the splat (overrides the solid "
@ -913,8 +941,7 @@ class RenderSplat(IO.ComfyNode):
tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera "
"Info node. If empty, the splat is auto-framed from a default 3/4 view."),
],
outputs=[IO.Image.Output(display_name="image"),
IO.Mask.Output(display_name="mask")],
outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")],
)
@classmethod
@ -922,13 +949,13 @@ class RenderSplat(IO.ComfyNode):
opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput:
bg = _hex_to_rgb(background)
bg_imgs = None
if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3)
if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3)
bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled")
bg_imgs = bi.movedim(1, -1).clamp(0, 1)
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
imgs, masks = [], []
device = comfy.model_management.get_torch_device() # render device; splat stays in torch here -> no roundtrip
device = comfy.model_management.get_torch_device()
total = splat.positions.shape[0] * n_frames
pbar = comfy.utils.ProgressBar(total) if total > 1 else None
k = 0
@ -938,7 +965,7 @@ class RenderSplat(IO.ComfyNode):
keep = opacity >= opacity_threshold
xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep]
base_cam = camera_info
if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat
if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat
center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device)
extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0]
else torch.tensor(1.0, device=device))
@ -959,7 +986,7 @@ class RenderSplat(IO.ComfyNode):
return IO.NodeOutput(torch.stack(imgs), torch.stack(masks))
class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file
class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file
@classmethod
def define_schema(cls):
return IO.Schema(
@ -1014,24 +1041,22 @@ class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file
)
@classmethod
def execute(cls, mode, target_x, target_y, target_z, roll, fov,
zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
dev = comfy.model_management.get_torch_device()
kind = mode["mode"]
if kind == "quaternion": # explicit world position + camera rotation
if kind == "quaternion": # explicit world position + camera rotation
position = [mode["position_x"], mode["position_y"], mode["position_z"]]
quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]]
return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type))
target = [target_x, target_y, target_z] # orbit pivot / aim; move it to pan the whole camera
if kind == "orbit": # yaw/pitch/distance about the target (world Y-up)
target = [target_x, target_y, target_z] # orbit pivot / aim; move it to pan the whole camera
if kind == "orbit": # yaw/pitch/distance about the target (world Y-up)
y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"])
cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p)
d = mode["distance"]
position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy]
else: # look_at: explicit world-space camera position
else: # look_at: explicit world-space camera position
position = [mode["position_x"], mode["position_y"], mode["position_z"]]
return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev,
zoom=zoom, camera_type=camera_type, roll=roll))
return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll))
class TransformSplat(IO.ComfyNode):
@ -1079,7 +1104,7 @@ class TransformSplat(IO.ComfyNode):
rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation
s2 = splat.scales.reshape(-1, 3).square()
cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats)
lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes
V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation
scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape)
@ -1108,7 +1133,7 @@ class GetSplatCount(IO.ComfyNode):
@classmethod
def execute(cls, splat) -> IO.NodeOutput:
count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0]))
if cls.hidden.unique_id: # show the count inline on the node
if cls.hidden.unique_id: # show the count inline on the node
PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id)
return IO.NodeOutput(splat, count)
@ -1142,8 +1167,8 @@ def _merge_gaussians(gaussians: list) -> Types.SPLAT:
scl_i.append(g.scales[i, :end])
rot_i.append(g.rotations[i, :end])
op_i.append(g.opacities[i, :end])
sh = g.sh[i, :end] # (end, K, 3)
if sh.shape[1] < max_k: # zero-pad lower-degree SH
sh = g.sh[i, :end] # (end, K, 3)
if sh.shape[1] < max_k: # zero-pad lower-degree SH
sh = torch.cat([sh, sh.new_zeros(sh.shape[0], max_k - sh.shape[1], sh.shape[2])], dim=1)
sh_i.append(sh)
pos_b.append(torch.cat(pos_i))
@ -1198,7 +1223,8 @@ def _inverse_covariance(scale, quat):
return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R)
def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None):
def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None,
col_dtype=torch.float16):
# Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid,
# plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`).
# Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper
@ -1211,11 +1237,11 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
dx, dy, dz = int(dims[0]), int(dims[1]), int(dims[2])
sinv = _inverse_covariance(scale, quat)
kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width
kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width
sharp = color_sharpen != 1.0
vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface)
colvol = torch.zeros(dx * dy * dz, 3, device=device) # Sum(w^p * rgb) colour numerator
wcol = torch.zeros(dx * dy * dz, device=device) if sharp else None # Sum(w^p) colour normaliser (p>1)
vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface)
colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype) # Sum(w^p * rgb) colour numerator
wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None # Sum(w^p) normaliser (p>1)
n, done = xyz.shape[0], 0
for k in range(1, int(kernel) + 1):
sel = (kreq == k).nonzero(as_tuple=True)[0]
@ -1238,9 +1264,9 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1)
vol.index_add_(0, flat, wgt.reshape(-1))
wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight
colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3))
colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype))
if sharp:
wcol.index_add_(0, flat, wp.reshape(-1))
wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype))
done += gi.numel()
if progress is not None:
progress(min(1.0, done / max(1, n)))
@ -1248,9 +1274,63 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel)
def _clean_components(verts, faces, min_verts):
def _connected_components_gpu(faces, nv):
# FastSV connected components: grandparent hooking + shortcutting, ~O(log nv) iterations.
# Returns per-vertex component labels (min node id, not densified).
a = torch.cat([faces[:, 0], faces[:, 1]]) # 2F edge endpoints: (v0,v1),(v1,v2)
b = torch.cat([faces[:, 1], faces[:, 2]])
f = torch.arange(nv, device=faces.device)
while True:
gp = f[f] # grandparent
ga, gb = gp[a], gp[b]
new = f.clone()
new.scatter_reduce_(0, f[a], gb, "amin", include_self=True) # stochastic hooking onto roots
new.scatter_reduce_(0, f[b], ga, "amin", include_self=True)
new.scatter_reduce_(0, a, gb, "amin", include_self=True) # aggressive hooking, both directions
new.scatter_reduce_(0, b, ga, "amin", include_self=True)
new = new[new] # shortcut (path compression)
if torch.equal(new, f):
return f
f = new
def _clean_components_gpu(verts, faces, min_verts, device):
    """Torch port of _clean_components: drop small components and enclosed, oppositely-wound shells.

    verts: (V, 3) vertex positions, faces: (F, 3) vertex indices (anything torch.as_tensor accepts).
    min_verts: components with fewer vertices than this are discarded.
    device: torch device the reductions run on.

    Returns (verts, faces) as numpy arrays, compacted to the vertices still referenced and with the
    face indices remapped accordingly. If no face survives, empty slices of the inputs are returned.
    Intended to mirror the numpy/scipy path in _clean_components.
    """
    vt = torch.as_tensor(verts, device=device)
    # int64 indices: scatter ops reject int32, and the remapped faces returned below are int64 anyway.
    ft = torch.as_tensor(faces, device=device).long()
    nv = vt.shape[0]
    # Root ids -> dense component ids 0..ncomp-1.
    _, label = torch.unique(_connected_components_gpu(ft, nv), return_inverse=True)
    ncomp = int(label.max()) + 1
    flabel = label[ft[:, 0]]  # component id per face
    keep = torch.bincount(label, minlength=ncomp) >= min_verts  # per-component vertex-count gate
    if int(keep.sum()) > 1:
        # Reference shell = surviving component with the most faces.
        fcount = torch.bincount(flabel, minlength=ncomp)
        largest = int(torch.where(keep, fcount, fcount.new_tensor(-1)).argmax())
        v0, v1, v2 = vt[ft[:, 0]], vt[ft[:, 1]], vt[ft[:, 2]]
        # Accumulators are allocated from vt so their dtype/device always match the scattered source
        # (scatter_add_/scatter_reduce_ raise on a dtype mismatch).
        cvol = vt.new_zeros(ncomp).scatter_add_(0, flabel, (v0 * torch.linalg.cross(v1, v2)).sum(-1))  # 6*signed vol
        idx3 = label[:, None].expand(-1, 3)  # per-component vertex bbox
        cmin = vt.new_full((ncomp, 3), float("inf")).scatter_reduce_(0, idx3, vt, "amin", include_self=True)
        cmax = vt.new_full((ncomp, 3), float("-inf")).scatter_reduce_(0, idx3, vt, "amax", include_self=True)
        tol = 1e-4 * (cmax[largest] - cmin[largest]).max()
        # Inner shell: bbox inside the reference bbox (within tol) and signed volume of opposite sign.
        enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1)
        inner = enclosed & (torch.sign(cvol) != torch.sign(cvol[largest])) & (torch.arange(ncomp, device=device) != largest)
        keep &= ~inner
    faces_k = ft[keep[flabel]]
    if faces_k.shape[0] == 0:
        return verts[:0], faces[:0]
    used = torch.unique(faces_k)  # sorted ids of the vertices still referenced
    remap = torch.full((nv,), -1, dtype=torch.int64, device=device)
    remap[used] = torch.arange(used.shape[0], device=device)
    return vt[used].cpu().numpy(), remap[faces_k].cpu().numpy()
def _clean_components(verts, faces, min_verts, device=None):
# Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density
# extracts a double wall (outer + inner cavity surface)
# extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x
# faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output.
if device is not None and device.type != "cpu" and \
comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes
return _clean_components_gpu(verts, faces, min_verts, device)
nv = len(verts)
e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0)
ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False)
@ -1261,10 +1341,9 @@ def _clean_components(verts, faces, min_verts):
largest = np.where(keep, fcount, -1).argmax()
v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol
fmin, fmax = verts[faces].min(1), verts[faces].max(1)
cmin, cmax = np.full((ncomp, 3), np.inf), np.full((ncomp, 3), -np.inf)
np.minimum.at(cmin, flabel, fmin)
np.maximum.at(cmax, flabel, fmax)
cidx = np.arange(ncomp) # per-component vertex bbox via ndimage (~6x faster than ufunc.at)
cmin = np.stack([_ndi_minimum(verts[:, a], label, cidx) for a in range(3)], 1)
cmax = np.stack([_ndi_maximum(verts[:, a], label, cidx) for a in range(3)], 1)
tol = 1e-4 * (cmax[largest] - cmin[largest]).max()
enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1)
inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest)
@ -1289,39 +1368,47 @@ def _surface_nets(vol, level, voxel, origin, device):
return empty
# Active = cells whose 8 corners aren't all in/all out.
inside = vol >= level # (dx,dy,dz) bool
inside = vol >= level # (dx,dy,dz) bool
cs8 = [inside[ox:ox + dx - 1, oy:oy + dy - 1, oz:oz + dz - 1]
for ox, oy, oz in ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0),
(0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1))]
any_in = cs8[0] | cs8[1] | cs8[2] | cs8[3] | cs8[4] | cs8[5] | cs8[6] | cs8[7]
all_in = cs8[0] & cs8[1] & cs8[2] & cs8[3] & cs8[4] & cs8[5] & cs8[6] & cs8[7]
active = any_in & ~all_in # (cx,cy,cz) straddling cells
active = any_in & ~all_in # (cx,cy,cz) straddling cells
nv = int(active.sum())
if nv == 0:
return empty
# Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings.
ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices
del any_in, all_in, cs8 # corner bool grids no longer needed
ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices
offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device)
ci = ac[:, None, :] + offs[None] # (nv,8,3)
cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (nv,8) corner values
csl = cval >= level
offf = offs.to(torch.float32)
edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3],
[2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device)
e0, e1 = edges[:, 0], edges[:, 1]
v0, v1 = cval[:, e0], cval[:, e1] # (nv,12)
cross = csl[:, e0] != csl[:, e1]
denom = v1 - v0
t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1)
offf = offs.to(torch.float32)
pts = offf[e0] + t[..., None] * (offf[e1] - offf[e0]) # (nv,12,3) local crossings
cf = cross[..., None].to(pts.dtype)
local = (pts * cf).sum(1) / cf.sum(1).clamp_min(1.0) # (nv,3) local vertex in [0,1]
verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space
oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints
cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too)
loc = []
for st in range(0, nv, cstep):
ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3)
cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values
csl = cval >= level
v0, v1 = cval[:, e0], cval[:, e1] # (m,12)
cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32)
denom = v1 - v0
t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1)
pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings (fused interp)
loc.append((pts * cross).sum(1) / cross.sum(1).clamp_min(1.0)) # (m,3) in [0,1]
local = torch.cat(loc, 0) if len(loc) > 1 else loc[0] # (nv,3)
verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space
del loc, local, ac
vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device)
vid[active] = torch.arange(nv, dtype=torch.int32, device=device)
del active
# Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding.
faces = []
@ -1358,12 +1445,12 @@ def _otsu_level(values, bins=256):
hist, edges = np.histogram(values, bins=bins)
hist = hist.astype(np.float64)
centers = (edges[:-1] + edges[1:]) * 0.5
w = np.cumsum(hist) # background-class weight at each split
w = np.cumsum(hist) # background-class weight at each split
mu = np.cumsum(hist * centers)
wf = w[-1] - w # foreground-class weight
wf = w[-1] - w # foreground-class weight
mb = mu / np.where(w > 0, w, 1.0)
mf = (mu[-1] - mu) / np.where(wf > 0, wf, 1.0)
var_b = w * wf * (mb - mf) ** 2 # between-class variance
var_b = w * wf * (mb - mf) ** 2 # between-class variance
var_b[(w <= 0) | (wf <= 0)] = -1.0
return float(centers[int(np.argmax(var_b))])
@ -1375,15 +1462,35 @@ def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53):
return verts
nv = len(verts)
e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0)
e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency
adj = coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr()
e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency
adj = coo_matrix((np.ones(len(e), np.float32), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr()
adj.data[:] = 1.0
deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None)[:, None]
v = verts.astype(np.float64)
deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None).astype(np.float32)[:, None]
v = verts.astype(np.float32) # fp32 matvec: ~2x faster, sub-micron drift on unit-scale verts
for _ in range(int(iters)):
for fac in (lam, mu):
v = v + fac * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v)
return np.ascontiguousarray(v.astype(np.float32))
v = v + np.float32(fac) * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v)
return np.ascontiguousarray(v)
def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device):
    """Trilinearly sample per-vertex colours from the colour volumes with torch grid_sample.

    colvol: (dx, dy, dz, 3) colour numerator tensor, colnorm: (dx, dy, dz) normaliser tensor.
    verts: (V, 3) positions, origin: (3,) position of voxel (0, 0, 0), voxel: voxel edge length.
    Numerator and normaliser are interpolated separately and divided afterwards; samples outside the
    grid take the border value. Meant to reproduce scipy map_coordinates(order=1, mode='nearest').
    Returns a (V, 3) numpy array.
    """
    nx, ny, nz = colnorm.shape
    points = torch.as_tensor(verts, device=device, dtype=torch.float32)
    base = torch.as_tensor(origin, device=device, dtype=torch.float32)
    gi = (points - base) / voxel  # (V,3) fractional voxel indices along (x,y,z)
    extent = torch.tensor([nx, ny, nz], device=device, dtype=torch.float32)
    span = (extent - 1).clamp_min(1.0)  # guard single-voxel axes against division by zero
    g = 2.0 * gi / span - 1.0  # normalised to [-1,1] for align_corners=True
    # grid_sample wants the last axis ordered (W, H, D) = (z, y, x) for an input laid out (D=x, H=y, W=z).
    grid = g.flip(-1).reshape(1, 1, 1, -1, 3)

    sampled = []
    for field in (colvol, colnorm[..., None]):  # 3-channel numerator, then 1-channel normaliser
        channels_first = field.permute(3, 0, 1, 2).contiguous()[None]  # (1,C,dx,dy,dz)
        channels_first = channels_first.to(device=device, dtype=torch.float32)
        out = torch.nn.functional.grid_sample(
            channels_first, grid, mode="bilinear", padding_mode="border", align_corners=True
        )  # (1,C,1,1,V)
        sampled.append(out[0, :, 0, 0, :])  # (C,V)
    num, den = sampled
    return (num / den.clamp_min(1e-8)).T.cpu().numpy()  # (V,3)
def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
@ -1406,24 +1513,37 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device,
color_sharpen=color_sharpen,
progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25%
vol_np = vol.cpu().numpy() # Sum(w) density (for the surface)
colvol_np = colvol.cpu().numpy() # Sum(w^p * rgb) colour numerator
colnorm_np = colnorm.cpu().numpy() # Sum(w^p) colour normaliser
# Colour: sample on the GPU (grid_sample) when there's headroom
colour_gpu = device.type != "cpu" and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4
if colour_gpu:
colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing
colvol_np = colnorm_np = None
else:
colvol_np = colvol.cpu().numpy().astype(np.float32) # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32)
colnorm_np = colnorm.cpu().numpy().astype(np.float32) # Sum(w^p) colour normaliser
del colvol, colnorm # free the colour grids before iso-surfacing
rep(0.40)
occ = vol_np[vol_np > vol_np.max() * 1e-3] # occupied voxels (skip the empty-space peak)
if occ.size == 0:
vmin, vmax = float(vol.min()), float(vol.max())
occ = vol[vol > vmax * 1e-3] # occupied voxels (skip the empty-space peak)
if occ.numel() == 0:
return None
# Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly
# inside the data range so a bias can't push the iso off the histogram (the old None / "no surface" bug).
lo, hi = float(vol_np.min()), float(vol_np.max())
level = min(max(_otsu_level(occ) * level_bias, lo + 1e-6 * (hi - lo)), hi - 1e-6 * (hi - lo))
# inside the data range so a bias can't push the iso off the histogram.
level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)),
vmax - 1e-6 * (vmax - vmin))
# Surface Nets on CPU: the grid is already on CPU, and this keeps iso-surfacing off the GPU's VRAM.
verts, faces = _surface_nets(torch.from_numpy(vol_np), level, voxel, origin, torch.device("cpu"))
# Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked
# Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM.
sn_dev = device
if device.type != "cpu" and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4:
sn_dev = torch.device("cpu")
vol = vol.cpu()
verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev)
del vol
rep(0.55)
if min_component > 0 and len(faces) > 0:
verts, faces = _clean_components(verts, faces, min_component)
verts, faces = _clean_components(verts, faces, min_component, device)
if len(verts) == 0 or len(faces) == 0:
return None
@ -1434,10 +1554,13 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
# Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb)
# and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density
# edge voxels from pulling colours toward black, and matches the gaussians that formed the surface.
coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes
num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1)
den = map_coordinates(colnorm_np, coords, order=1, mode="nearest")
col = num / np.clip(den, 1e-8, None)[:, None]
if colour_gpu:
col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device)
else:
coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes
num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1)
den = map_coordinates(colnorm_np, coords, order=1, mode="nearest")
col = num / np.clip(den, 1e-8, None)[:, None]
rep(1.0)
# The unlit material's COLOR_0 is linear and the viewer sRGB-encodes it on output; the splat colours
@ -1462,7 +1585,7 @@ class SplatToMesh(IO.ComfyNode):
description="Extract a coloured mesh from a gaussian splat.",
inputs=[
IO.Splat.Input("splat"),
IO.Int.Input("resolution", default=384, min=64, max=1024, step=16,
IO.Int.Input("resolution", default=384, min=64, max=768, step=16,
tooltip="Density-grid resolution along the longest axis. Higher = finer surface, "
"more VRAM/time (grows with resolution^3)."),
IO.Int.Input("kernel", default=5, min=1, max=8,