mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 13:27:41 +08:00
Further improve splat render speed
~30% faster
This commit is contained in:
parent
cd47b31639
commit
c0f91c6782
@ -751,43 +751,65 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
||||
|
||||
n = zc.shape[0]
|
||||
ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped
|
||||
nl = len(levels)
|
||||
order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs
|
||||
bounds = torch.linspace(0, n, ns + 1, device=dev).round().long()
|
||||
rank = torch.empty(n, dtype=torch.long, device=dev)
|
||||
rank[order] = torch.arange(n, device=dev) # depth rank of each splat
|
||||
slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1)
|
||||
order = torch.argsort(slab_id * len(levels) + blevel) # group by slab, then kernel level (order-free within)
|
||||
slab_bounds = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), torch.bincount(slab_id, minlength=ns).cumsum(0)]).tolist()
|
||||
key = slab_id * nl + blevel # group by slab, then kernel level (order-free within)
|
||||
order = torch.argsort(key)
|
||||
key = key[order]
|
||||
|
||||
cxr, cyr = cx[order].round(), cy[order].round()
|
||||
s00, s01, s02 = s00[order], s01[order], s02[order]
|
||||
s11, s12, s22 = s11[order], s12[order], s22[order]
|
||||
s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2 # doubled cross terms for the fused quadratic forms
|
||||
simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order]
|
||||
opacity, rgb = opacity[order], rgb[order]
|
||||
blevel = blevel[order]
|
||||
zc_o = zc[order] if need_depth else None
|
||||
nrm_o = nrm[order] if need_normal else None
|
||||
mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)
|
||||
|
||||
# Pack the per-splat scalars into one tensor so each chunk slices once
|
||||
common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity]
|
||||
pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu]))
|
||||
|
||||
# Precompute the (slab, level) run table on-GPU and pull it to the CPU once
|
||||
starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1])
|
||||
ks = key[starts]
|
||||
run_lo = starts.tolist() + [n]
|
||||
run_lev = (ks % nl).tolist()
|
||||
run_slab = torch.div(ks, nl, rounding_mode="floor").tolist()
|
||||
slab_runs = [[] for _ in range(ns)]
|
||||
for r in range(len(run_lev)):
|
||||
slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r]))
|
||||
|
||||
def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
|
||||
px = cxr[lo:hi, None] + ox[None, :]
|
||||
py = cyr[lo:hi, None] + oy[None, :]
|
||||
cols = pstack[:, lo:hi, None].unbind(0)
|
||||
cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9] # a* = Si components; b* = 2 * cross terms
|
||||
px = cxr_ + ox[None, :]
|
||||
py = cyr_ + oy[None, :]
|
||||
valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
|
||||
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0)
|
||||
rx = (px - cx0) / s - mux_o[lo:hi, None]
|
||||
ry = (py - cy0) / s - muy_o[lo:hi, None]
|
||||
rz = -muz_o[lo:hi, None] # constant per splat
|
||||
rSr = (s00[lo:hi, None] * rx * rx + s11[lo:hi, None] * ry * ry + s22[lo:hi, None] * rz * rz
|
||||
+ 2 * (s01[lo:hi, None] * rx * ry + s02[lo:hi, None] * rx * rz + s12[lo:hi, None] * ry * rz))
|
||||
dsr = s02[lo:hi, None] * rx + s12[lo:hi, None] * ry + s22[lo:hi, None] * rz
|
||||
q = (rSr - dsr * dsr / s22[lo:hi, None].clamp_min(1e-12)).clamp_min_(0)
|
||||
if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat
|
||||
c02, c12, mx, my, mz = cols[9:]
|
||||
rx = (px - cx0) / s - mx
|
||||
ry = (py - cy0) / s - my
|
||||
rz = -mz
|
||||
a22rz = a22 * rz
|
||||
inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry) # a00 rx + b01 ry + b02 rz
|
||||
rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry))
|
||||
dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry)
|
||||
q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0)
|
||||
else: # perspective ray (dx,dy,1) through the camera origin
|
||||
su0, su1, su2, mus = cols[9:]
|
||||
dx, dy = (px - cx0) / f, (py - cy0) / f
|
||||
dsid = (s00[lo:hi, None] * dx * dx + s11[lo:hi, None] * dy * dy + s22[lo:hi, None]
|
||||
+ 2 * (s01[lo:hi, None] * dx * dy + s02[lo:hi, None] * dx + s12[lo:hi, None] * dy))
|
||||
dsimu = dx * simu0[lo:hi, None] + dy * simu1[lo:hi, None] + simu2[lo:hi, None]
|
||||
q = (musimu[lo:hi, None] - dsimu * dsimu / dsid.clamp_min(1e-12)).clamp_min_(0) # ray->centre Mahalanobis^2
|
||||
alpha = (opacity[lo:hi, None] * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999)
|
||||
dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx)) # a22 + dx*(a00 dx + b02)
|
||||
dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy)) # + dy*(a11 dy + b12)
|
||||
dsid = dsid.addcmul_(b01 * dx, dy) # + (2 s01) dx dy
|
||||
dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1)
|
||||
q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0)
|
||||
alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999)
|
||||
idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
|
||||
return idx, alpha
|
||||
|
||||
@ -797,7 +819,7 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
||||
cacc = torch.zeros((flat, 3), device=dev)
|
||||
trans = torch.ones((flat,), device=dev)
|
||||
a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean)
|
||||
tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) (order-independent)
|
||||
tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha)
|
||||
crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour
|
||||
wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only)
|
||||
dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth
|
||||
@ -806,8 +828,8 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
||||
nslab = torch.zeros((flat, 3), device=dev) if need_normal else None
|
||||
stale = 0 # consecutive fully-occluded slabs -> early-out
|
||||
for si in range(ns):
|
||||
s0, s1 = slab_bounds[si], slab_bounds[si + 1]
|
||||
if s1 <= s0:
|
||||
runs = slab_runs[si]
|
||||
if not runs:
|
||||
continue
|
||||
a_buf.zero_()
|
||||
tau_buf.zero_()
|
||||
@ -818,14 +840,11 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
||||
zslab.zero_()
|
||||
if need_normal:
|
||||
nslab.zero_()
|
||||
lev = blevel[s0:s1] # kernel levels in this slab, sorted ascending
|
||||
pos = s0
|
||||
while pos < s1:
|
||||
ox, oy = grids[int(lev[pos - s0])]
|
||||
run_end = s0 + int(torch.searchsorted(lev, lev[pos - s0], right=True)) # contiguous same-level run
|
||||
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size
|
||||
for lo in range(pos, run_end, ch):
|
||||
hi = min(lo + ch, run_end)
|
||||
for r_lo, r_hi, li in runs: # contiguous same-kernel-level runs in this slab
|
||||
ox, oy = grids[li]
|
||||
ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size
|
||||
for lo in range(r_lo, r_hi, ch):
|
||||
hi = min(lo + ch, r_hi)
|
||||
idx, alpha = splat(lo, hi, ox, oy)
|
||||
idx, af = idx.reshape(-1), alpha.reshape(-1)
|
||||
a_buf.index_add_(0, idx, af)
|
||||
@ -838,7 +857,6 @@ def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
|
||||
zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1))
|
||||
if need_normal:
|
||||
nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3))
|
||||
pos = run_end
|
||||
slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats
|
||||
front = trans * slab_a
|
||||
denom = wbuf if sharp else a_buf
|
||||
@ -1233,8 +1251,7 @@ def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sh
|
||||
lo = xyz.amin(0) - pad
|
||||
hi = xyz.amax(0) + pad
|
||||
voxel = ((hi - lo).max() / res).clamp_min(1e-8)
|
||||
dims = (torch.ceil((hi - lo) / voxel).long() + 1).tolist()
|
||||
dx, dy, dz = int(dims[0]), int(dims[1]), int(dims[2])
|
||||
dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist()
|
||||
|
||||
sinv = _inverse_covariance(scale, quat)
|
||||
kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width
|
||||
@ -1328,7 +1345,7 @@ def _clean_components(verts, faces, min_verts, device=None):
|
||||
# Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density
|
||||
# extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x
|
||||
# faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output.
|
||||
if device is not None and device.type != "cpu" and \
|
||||
if device is not None and not comfy.model_management.is_device_cpu(device) and \
|
||||
comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes
|
||||
return _clean_components_gpu(verts, faces, min_verts, device)
|
||||
nv = len(verts)
|
||||
@ -1484,8 +1501,8 @@ def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device):
|
||||
g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners)
|
||||
grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x)
|
||||
|
||||
def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32
|
||||
inp = v.permute(3, 0, 1, 2).contiguous()[None].to(device=device, dtype=torch.float32)
|
||||
def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 on device
|
||||
inp = v.to(device).permute(3, 0, 1, 2)[None].float()
|
||||
o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True)
|
||||
return o[0, :, 0, 0, :]
|
||||
num = samp(colvol) # (3,V)
|
||||
@ -1514,7 +1531,7 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
|
||||
color_sharpen=color_sharpen,
|
||||
progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25%
|
||||
# Colour: sample on the GPU (grid_sample) when there's headroom
|
||||
colour_gpu = device.type != "cpu" and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4
|
||||
colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4
|
||||
if colour_gpu:
|
||||
colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing
|
||||
colvol_np = colnorm_np = None
|
||||
@ -1536,7 +1553,7 @@ def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_co
|
||||
# Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked
|
||||
# Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM.
|
||||
sn_dev = device
|
||||
if device.type != "cpu" and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4:
|
||||
if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4:
|
||||
sn_dev = torch.device("cpu")
|
||||
vol = vol.cpu()
|
||||
verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user